Repository: stanfordnlp/stanza Branch: main Commit: 516b07140fdf Files: 579 Total size: 3.8 MB Directory structure: gitextract_z9hqe0ws/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── feature_request.md │ │ └── question.md │ ├── pull_request_template.md │ ├── stale.yml │ └── workflows/ │ └── stanza-tests.yaml ├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── demo/ │ ├── CONLL_Dependency_Visualizer_Example.ipynb │ ├── Dependency_Visualization_Testing.ipynb │ ├── NER_Visualization.ipynb │ ├── Stanza_Beginners_Guide.ipynb │ ├── Stanza_CoreNLP_Interface.ipynb │ ├── arabic_test.conllu.txt │ ├── corenlp.py │ ├── en_test.conllu.txt │ ├── japanese_test.conllu.txt │ ├── pipeline_demo.py │ ├── scenegraph.py │ ├── semgrex visualization.ipynb │ ├── semgrex.py │ └── ssurgeon_script.txt ├── doc/ │ └── CoreNLP.proto ├── scripts/ │ ├── config.sh │ └── download_vectors.sh ├── setup.py └── stanza/ ├── __init__.py ├── _version.py ├── models/ │ ├── __init__.py │ ├── _training_logging.py │ ├── charlm.py │ ├── classifier.py │ ├── classifiers/ │ │ ├── __init__.py │ │ ├── base_classifier.py │ │ ├── cnn_classifier.py │ │ ├── config.py │ │ ├── constituency_classifier.py │ │ ├── data.py │ │ ├── iterate_test.py │ │ ├── trainer.py │ │ └── utils.py │ ├── common/ │ │ ├── __init__.py │ │ ├── beam.py │ │ ├── bert_embedding.py │ │ ├── biaffine.py │ │ ├── build_short_name_to_treebank.py │ │ ├── char_model.py │ │ ├── chuliu_edmonds.py │ │ ├── constant.py │ │ ├── convert_pretrain.py │ │ ├── count_ner_coverage.py │ │ ├── count_pretrain_coverage.py │ │ ├── crf.py │ │ ├── data.py │ │ ├── doc.py │ │ ├── dropout.py │ │ ├── exceptions.py │ │ ├── foundation_cache.py │ │ ├── hlstm.py │ │ ├── large_margin_loss.py │ │ ├── loss.py │ │ ├── maxout_linear.py │ │ ├── packed_lstm.py │ │ ├── peft_config.py │ │ ├── pretrain.py │ │ ├── relative_attn.py │ │ ├── seq2seq_constant.py │ │ ├── seq2seq_model.py │ │ ├── seq2seq_modules.py │ │ ├── seq2seq_utils.py │ │ ├── 
short_name_to_treebank.py │ │ ├── stanza_object.py │ │ ├── trainer.py │ │ ├── utils.py │ │ └── vocab.py │ ├── constituency/ │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── base_trainer.py │ │ ├── dynamic_oracle.py │ │ ├── ensemble.py │ │ ├── error_analysis_in_order.py │ │ ├── evaluate_treebanks.py │ │ ├── in_order_compound_oracle.py │ │ ├── in_order_oracle.py │ │ ├── label_attention.py │ │ ├── lstm_model.py │ │ ├── lstm_tree_stack.py │ │ ├── parse_transitions.py │ │ ├── parse_tree.py │ │ ├── parser_training.py │ │ ├── partitioned_transformer.py │ │ ├── positional_encoding.py │ │ ├── retagging.py │ │ ├── score_converted_dependencies.py │ │ ├── state.py │ │ ├── text_processing.py │ │ ├── top_down_oracle.py │ │ ├── trainer.py │ │ ├── transformer_tree_stack.py │ │ ├── transition_sequence.py │ │ ├── tree_embedding.py │ │ ├── tree_reader.py │ │ ├── tree_stack.py │ │ └── utils.py │ ├── constituency_parser.py │ ├── coref/ │ │ ├── __init__.py │ │ ├── anaphoricity_scorer.py │ │ ├── bert.py │ │ ├── cluster_checker.py │ │ ├── config.py │ │ ├── conll.py │ │ ├── const.py │ │ ├── coref_chain.py │ │ ├── coref_config.toml │ │ ├── dataset.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── pairwise_encoder.py │ │ ├── predict.py │ │ ├── rough_scorer.py │ │ ├── span_predictor.py │ │ ├── tokenizer_customization.py │ │ ├── utils.py │ │ └── word_encoder.py │ ├── depparse/ │ │ ├── __init__.py │ │ ├── data.py │ │ ├── model.py │ │ ├── scorer.py │ │ └── trainer.py │ ├── identity_lemmatizer.py │ ├── lang_identifier.py │ ├── langid/ │ │ ├── __init__.py │ │ ├── create_ud_data.py │ │ ├── data.py │ │ ├── model.py │ │ └── trainer.py │ ├── lemma/ │ │ ├── __init__.py │ │ ├── attach_lemma_classifier.py │ │ ├── data.py │ │ ├── edit.py │ │ ├── scorer.py │ │ ├── trainer.py │ │ └── vocab.py │ ├── lemma_classifier/ │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── base_trainer.py │ │ ├── baseline_model.py │ │ ├── constants.py │ │ ├── evaluate_many.py │ │ ├── evaluate_models.py │ │ ├── lstm_model.py │ │ ├── 
prepare_dataset.py │ │ ├── train_lstm_model.py │ │ ├── train_many.py │ │ ├── train_transformer_model.py │ │ ├── transformer_model.py │ │ └── utils.py │ ├── lemmatizer.py │ ├── mwt/ │ │ ├── __init__.py │ │ ├── character_classifier.py │ │ ├── data.py │ │ ├── scorer.py │ │ ├── trainer.py │ │ ├── utils.py │ │ └── vocab.py │ ├── mwt_expander.py │ ├── ner/ │ │ ├── __init__.py │ │ ├── data.py │ │ ├── model.py │ │ ├── scorer.py │ │ ├── trainer.py │ │ ├── utils.py │ │ └── vocab.py │ ├── ner_tagger.py │ ├── parser.py │ ├── pos/ │ │ ├── __init__.py │ │ ├── build_xpos_vocab_factory.py │ │ ├── data.py │ │ ├── model.py │ │ ├── scorer.py │ │ ├── trainer.py │ │ ├── vocab.py │ │ ├── xpos_vocab_factory.py │ │ └── xpos_vocab_utils.py │ ├── tagger.py │ ├── tokenization/ │ │ ├── __init__.py │ │ ├── data.py │ │ ├── model.py │ │ ├── tokenize_files.py │ │ ├── trainer.py │ │ ├── utils.py │ │ └── vocab.py │ ├── tokenizer.py │ └── wl_coref.py ├── pipeline/ │ ├── __init__.py │ ├── _constants.py │ ├── constituency_processor.py │ ├── core.py │ ├── coref_processor.py │ ├── demo/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── demo_server.py │ │ ├── stanza-brat.css │ │ ├── stanza-brat.html │ │ ├── stanza-brat.js │ │ └── stanza-parseviewer.js │ ├── depparse_processor.py │ ├── external/ │ │ ├── __init__.py │ │ ├── corenlp_converter_depparse.py │ │ ├── jieba.py │ │ ├── pythainlp.py │ │ ├── spacy.py │ │ └── sudachipy.py │ ├── langid_processor.py │ ├── lemma_processor.py │ ├── morphseg_processor.py │ ├── multilingual.py │ ├── mwt_processor.py │ ├── ner_processor.py │ ├── pos_processor.py │ ├── processor.py │ ├── registry.py │ ├── sentiment_processor.py │ └── tokenize_processor.py ├── protobuf/ │ ├── CoreNLP_pb2.py │ └── __init__.py ├── resources/ │ ├── __init__.py │ ├── common.py │ ├── default_packages.py │ ├── installation.py │ ├── prepare_resources.py │ └── print_charlm_depparse.py ├── server/ │ ├── __init__.py │ ├── annotator.py │ ├── client.py │ ├── dependency_converter.py │ ├── 
java_protobuf_requests.py │ ├── main.py │ ├── morphology.py │ ├── parser_eval.py │ ├── semgrex.py │ ├── ssurgeon.py │ ├── tokensregex.py │ ├── tsurgeon.py │ └── ud_enhancer.py ├── tests/ │ ├── __init__.py │ ├── classifiers/ │ │ ├── __init__.py │ │ ├── test_classifier.py │ │ ├── test_constituency_classifier.py │ │ ├── test_data.py │ │ └── test_process_utils.py │ ├── common/ │ │ ├── __init__.py │ │ ├── test_bert_embedding.py │ │ ├── test_char_model.py │ │ ├── test_chuliu_edmonds.py │ │ ├── test_common_data.py │ │ ├── test_confusion.py │ │ ├── test_constant.py │ │ ├── test_data_conversion.py │ │ ├── test_data_objects.py │ │ ├── test_doc.py │ │ ├── test_dropout.py │ │ ├── test_foundation_cache.py │ │ ├── test_pretrain.py │ │ ├── test_relative_attn.py │ │ ├── test_short_name_to_treebank.py │ │ └── test_utils.py │ ├── constituency/ │ │ ├── __init__.py │ │ ├── test_convert_arboretum.py │ │ ├── test_convert_it_vit.py │ │ ├── test_convert_starlang.py │ │ ├── test_ensemble.py │ │ ├── test_in_order_compound_oracle.py │ │ ├── test_in_order_oracle.py │ │ ├── test_lstm_model.py │ │ ├── test_parse_transitions.py │ │ ├── test_parse_tree.py │ │ ├── test_positional_encoding.py │ │ ├── test_selftrain_vi_quad.py │ │ ├── test_text_processing.py │ │ ├── test_top_down_oracle.py │ │ ├── test_trainer.py │ │ ├── test_transformer_tree_stack.py │ │ ├── test_transition_sequence.py │ │ ├── test_tree_reader.py │ │ ├── test_tree_stack.py │ │ ├── test_utils.py │ │ └── test_vietnamese.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── coref/ │ │ │ ├── __init__.py │ │ │ └── test_hebrew_iahlt.py │ │ ├── ner/ │ │ │ ├── __init__.py │ │ │ ├── test_prepare_ner_file.py │ │ │ └── test_utils.py │ │ ├── test_common.py │ │ └── test_vietnamese_renormalization.py │ ├── depparse/ │ │ ├── __init__.py │ │ ├── test_depparse_data.py │ │ └── test_parser.py │ ├── langid/ │ │ ├── __init__.py │ │ ├── test_langid.py │ │ └── test_multilingual.py │ ├── lemma/ │ │ ├── __init__.py │ │ ├── test_data.py │ │ ├── 
test_lemma_trainer.py │ │ └── test_lowercase.py │ ├── lemma_classifier/ │ │ ├── __init__.py │ │ ├── test_data_preparation.py │ │ └── test_training.py │ ├── morphseg/ │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_integration.py │ │ ├── test_morpheme_segmenter.py │ │ └── test_stanza_integration.py │ ├── mwt/ │ │ ├── __init__.py │ │ ├── test_character_classifier.py │ │ ├── test_english_corner_cases.py │ │ ├── test_prepare_mwt.py │ │ └── test_utils.py │ ├── ner/ │ │ ├── __init__.py │ │ ├── test_bsf_2_beios.py │ │ ├── test_bsf_2_iob.py │ │ ├── test_combine_ner_datasets.py │ │ ├── test_convert_amt.py │ │ ├── test_convert_nkjp.py │ │ ├── test_convert_starlang_ner.py │ │ ├── test_data.py │ │ ├── test_from_conllu.py │ │ ├── test_models_ner_scorer.py │ │ ├── test_ner_tagger.py │ │ ├── test_ner_trainer.py │ │ ├── test_ner_training.py │ │ ├── test_ner_utils.py │ │ ├── test_pay_amt_annotators.py │ │ ├── test_split_wikiner.py │ │ └── test_suc3.py │ ├── pipeline/ │ │ ├── __init__.py │ │ ├── pipeline_device_tests.py │ │ ├── test_arabic_pipeline.py │ │ ├── test_core.py │ │ ├── test_decorators.py │ │ ├── test_depparse.py │ │ ├── test_english_pipeline.py │ │ ├── test_french_pipeline.py │ │ ├── test_lemmatizer.py │ │ ├── test_pipeline_constituency_processor.py │ │ ├── test_pipeline_depparse_processor.py │ │ ├── test_pipeline_mwt_expander.py │ │ ├── test_pipeline_ner_processor.py │ │ ├── test_pipeline_pos_processor.py │ │ ├── test_pipeline_sentiment_processor.py │ │ ├── test_requirements.py │ │ └── test_tokenizer.py │ ├── pos/ │ │ ├── __init__.py │ │ ├── test_data.py │ │ ├── test_tagger.py │ │ └── test_xpos_vocab_factory.py │ ├── pytest.ini │ ├── resources/ │ │ ├── __init__.py │ │ ├── test_charlm_depparse.py │ │ ├── test_common.py │ │ ├── test_default_packages.py │ │ ├── test_installation.py │ │ └── test_prepare_resources.py │ ├── server/ │ │ ├── __init__.py │ │ ├── test_client.py │ │ ├── test_java_protobuf_requests.py │ │ ├── test_morphology.py │ │ ├── test_parser_eval.py │ │ 
├── test_protobuf.py │ │ ├── test_semgrex.py │ │ ├── test_server_misc.py │ │ ├── test_server_pretokenized.py │ │ ├── test_server_request.py │ │ ├── test_server_start.py │ │ ├── test_ssurgeon.py │ │ ├── test_tokensregex.py │ │ ├── test_tsurgeon.py │ │ └── test_ud_enhancer.py │ ├── setup.py │ └── tokenization/ │ ├── __init__.py │ ├── test_prepare_tokenizer_treebank.py │ ├── test_replace_long_tokens.py │ ├── test_spaces.py │ ├── test_tokenization_lst20.py │ ├── test_tokenization_orchid.py │ ├── test_tokenize_data.py │ ├── test_tokenize_files.py │ ├── test_tokenize_utils.py │ └── test_vocab.py └── utils/ ├── __init__.py ├── avg_sent_len.py ├── charlm/ │ ├── __init__.py │ ├── conll17_to_text.py │ ├── dump_oscar.py │ ├── make_lm_data.py │ └── oscar_to_text.py ├── confusion.py ├── conll.py ├── constituency/ │ ├── __init__.py │ ├── check_transitions.py │ ├── grep_dev_logs.py │ ├── grep_test_logs.py │ └── list_tensors.py ├── datasets/ │ ├── __init__.py │ ├── common.py │ ├── conllu_to_text.py │ ├── constituency/ │ │ ├── __init__.py │ │ ├── build_silver_dataset.py │ │ ├── common_trees.py │ │ ├── convert_alt.py │ │ ├── convert_arboretum.py │ │ ├── convert_cintil.py │ │ ├── convert_ctb.py │ │ ├── convert_icepahc.py │ │ ├── convert_it_turin.py │ │ ├── convert_it_vit.py │ │ ├── convert_spmrl.py │ │ ├── convert_starlang.py │ │ ├── count_common_words.py │ │ ├── extract_all_silver_dataset.py │ │ ├── extract_silver_dataset.py │ │ ├── prepare_con_dataset.py │ │ ├── reduce_dataset.py │ │ ├── relabel_tags.py │ │ ├── selftrain.py │ │ ├── selftrain_it.py │ │ ├── selftrain_single_file.py │ │ ├── selftrain_vi_quad.py │ │ ├── selftrain_wiki.py │ │ ├── silver_variance.py │ │ ├── split_holdout.py │ │ ├── split_weighted_ensemble.py │ │ ├── tokenize_wiki.py │ │ ├── treebank_to_labeled_brackets.py │ │ ├── utils.py │ │ ├── vtb_convert.py │ │ └── vtb_split.py │ ├── contract_mwt.py │ ├── coref/ │ │ ├── __init__.py │ │ ├── balance_languages.py │ │ ├── convert_hebrew_iahlt.py │ │ ├── 
convert_hebrew_mixed.py │ │ ├── convert_hindi.py │ │ ├── convert_ontonotes.py │ │ ├── convert_tamil.py │ │ ├── convert_udcoref.py │ │ ├── convert_udcoref_1.2.py │ │ └── utils.py │ ├── corenlp_segmenter_dataset.py │ ├── depparse/ │ │ └── check_results.py │ ├── ner/ │ │ ├── __init__.py │ │ ├── build_en_combined.py │ │ ├── check_for_duplicates.py │ │ ├── combine_ner_datasets.py │ │ ├── compare_entities.py │ │ ├── conll_to_iob.py │ │ ├── convert_amt.py │ │ ├── convert_ar_aqmar.py │ │ ├── convert_bn_daffodil.py │ │ ├── convert_bsf_to_beios.py │ │ ├── convert_bsnlp.py │ │ ├── convert_en_conll03.py │ │ ├── convert_fire_2013.py │ │ ├── convert_he_iahlt.py │ │ ├── convert_hy_armtdp.py │ │ ├── convert_ijc.py │ │ ├── convert_kk_kazNERD.py │ │ ├── convert_lst20.py │ │ ├── convert_mr_l3cube.py │ │ ├── convert_my_ucsy.py │ │ ├── convert_nkjp.py │ │ ├── convert_nner22.py │ │ ├── convert_nytk.py │ │ ├── convert_ontonotes.py │ │ ├── convert_rgai.py │ │ ├── convert_sindhi_siner.py │ │ ├── convert_starlang_ner.py │ │ ├── count_entities.py │ │ ├── json_to_bio.py │ │ ├── misc_to_date.py │ │ ├── ontonotes_multitag.py │ │ ├── prepare_ner_dataset.py │ │ ├── prepare_ner_file.py │ │ ├── preprocess_wikiner.py │ │ ├── simplify_en_worldwide.py │ │ ├── simplify_ontonotes_to_worldwide.py │ │ ├── split_wikiner.py │ │ ├── suc_conll_to_iob.py │ │ ├── suc_to_iob.py │ │ └── utils.py │ ├── pos/ │ │ ├── __init__.py │ │ ├── convert_trees_to_pos.py │ │ └── remove_columns.py │ ├── prepare_depparse_treebank.py │ ├── prepare_lemma_classifier.py │ ├── prepare_lemma_treebank.py │ ├── prepare_mwt_treebank.py │ ├── prepare_pos_treebank.py │ ├── prepare_tokenizer_data.py │ ├── prepare_tokenizer_treebank.py │ ├── pretrain/ │ │ ├── __init__.py │ │ └── word_in_pretrain.py │ ├── random_split_conllu.py │ ├── sentiment/ │ │ ├── __init__.py │ │ ├── add_constituency.py │ │ ├── convert_italian_poetry_classification.py │ │ ├── convert_italian_sentence_classification.py │ │ ├── prepare_sentiment_dataset.py │ │ ├── 
process_MELD.py │ │ ├── process_airline.py │ │ ├── process_arguana_xml.py │ │ ├── process_corona.py │ │ ├── process_es_tass2020.py │ │ ├── process_it_sentipolc16.py │ │ ├── process_ren_chinese.py │ │ ├── process_sb10k.py │ │ ├── process_scare.py │ │ ├── process_slsd.py │ │ ├── process_sst.py │ │ ├── process_usage_german.py │ │ ├── process_utils.py │ │ └── process_vsfc_vietnamese.py │ ├── thai_syllable_dict_generator.py │ ├── tokenization/ │ │ ├── __init__.py │ │ ├── convert_ml_cochin.py │ │ ├── convert_my_alt.py │ │ ├── convert_text_files.py │ │ ├── convert_th_best.py │ │ ├── convert_th_lst20.py │ │ ├── convert_th_orchid.py │ │ ├── convert_vi_vlsp.py │ │ └── process_thai_tokenization.py │ └── vietnamese/ │ ├── __init__.py │ └── renormalize.py ├── default_paths.py ├── get_tqdm.py ├── helper_func.py ├── languages/ │ ├── __init__.py │ └── kazakh_transliteration.py ├── lemma/ │ ├── __init__.py │ └── count_ambiguous_lemmas.py ├── max_mwt_length.py ├── ner/ │ ├── __init__.py │ ├── flair_ner_tag_dataset.py │ ├── paying_annotators.py │ └── spacy_ner_tag_dataset.py ├── pretrain/ │ ├── __init__.py │ └── compare_pretrains.py ├── select_backoff.py ├── training/ │ ├── __init__.py │ ├── common.py │ ├── compose_ete_results.py │ ├── remove_constituency_optimizer.py │ ├── run_charlm.py │ ├── run_constituency.py │ ├── run_depparse.py │ ├── run_ete.py │ ├── run_lemma.py │ ├── run_lemma_classifier.py │ ├── run_mwt.py │ ├── run_ner.py │ ├── run_pos.py │ ├── run_sentiment.py │ ├── run_tokenizer.py │ └── separate_ner_pretrain.py └── visualization/ ├── README ├── __init__.py ├── conll_deprel_visualization.py ├── constants.py ├── dependency_visualization.py ├── ner_visualization.py ├── semgrex_app.py ├── semgrex_visualizer.py ├── ssurgeon_visualizer.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md 
================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: bug assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Environment (please complete the following information):** - OS: [e.g. Windows, Ubuntu, CentOS, MacOS] - Python version: [e.g. Python 3.6.8 from Anaconda] - Stanza version: [e.g., 1.0.0] **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: enhancement assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/ISSUE_TEMPLATE/question.md ================================================ --- name: Question about: 'Question about general usage. ' title: "[QUESTION]" labels: question assignees: '' --- Before you start, make sure to check out: * Our documentation: https://stanfordnlp.github.io/stanza/ * Our FAQ: https://stanfordnlp.github.io/stanza/faq.html * Github issues (especially closed ones) Your question might have an answer in these places! 
If you still couldn't find the answer to your question, feel free to delete this text and write down your question. The more information you provide with your question, the faster we will be able to help you! If you have a question about an issue you're facing when using Stanza, please try to provide a detailed step-by-step guide to reproduce the issue you're facing. Try to at least provide a minimal code sample to reproduce the problem you are facing, instead of just describing it. That will greatly help us locate the issue faster and help you resolve it! ================================================ FILE: .github/pull_request_template.md ================================================ **BEFORE YOU START**: please make sure your pull request is against the `dev` branch. We cannot accept pull requests against the `main` branch. See our [contributing guide](https://github.com/stanfordnlp/stanza/blob/main/CONTRIBUTING.md) for details. ## Description A brief and concise description of what your pull request is trying to accomplish. ## Fixes Issues A list of issues/bugs with # references. (e.g., #123) ## Unit test coverage Are there unit tests in place to make sure your code is functioning correctly? (see [here](https://github.com/stanfordnlp/stanza/blob/main/stanza/tests/pos/test_tagger.py) for a simple example) ## Known breaking changes/behaviors Does this break anything in Stanza's existing user interface? If so, what is it and how is it addressed? ================================================ FILE: .github/stale.yml ================================================ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed daysUntilClose: 7 # Issues with these labels will never be considered stale exemptLabels: - pinned - security - fixed on dev - bug - enhancement # Label to use when marking an issue as stale staleLabel: stale # Comment to post when marking an issue as stale.
Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable closeComment: > This issue has been automatically closed due to inactivity. ================================================ FILE: .github/workflows/stanza-tests.yaml ================================================ name: Run Stanza Tests on: [push] jobs: Run-Stanza-Tests: runs-on: self-hosted steps: - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event." - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!" - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}." - name: Check out repository code uses: actions/checkout@v2 - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner." - run: echo "🖥️ The workflow is now ready to test your code on the runner." - name: Run Stanza Tests run: | # set up environment echo "Setting up environment..." bash #. $CONDA_PREFIX/etc/profile.d/conda.sh . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh conda activate stanza export STANZA_TEST_HOME=/scr/stanza_test export CORENLP_HOME=$STANZA_TEST_HOME/corenlp_dir export CLASSPATH=$CORENLP_HOME/*: echo CORENLP_HOME=$CORENLP_HOME echo CLASSPATH=$CLASSPATH # install from stanza repo being evaluated echo PWD: $pwd echo PATH: $PATH pip3 install -e . pip3 install -e .[test] pip3 install -e .[transformers] pip3 install -e .[tokenizers] pip3 install -e .[morphseg] # set up for tests echo "Running stanza test set up..." rm -rf $STANZA_TEST_HOME python3 stanza/tests/setup.py # run tests echo "Running tests..." export CUDA_VISIBLE_DEVICES=2 pytest stanza/tests - run: echo "🍏 This job's status is ${{ job.status }}." 
================================================ FILE: .gitignore ================================================ # kept from original .DS_Store *.tmp *.pkl *.conllu *.lem *.toklabels # also data w/o any slash to account for symlinks data data/ stanza_resources/ stanza_test/ saved_models/ logs/ log/ *_test_treebanks wandb/ params/*/*.json !params/*/default.json # emacs backup files *~ # VI backup files? *py.swp # standard github python project gitignore # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. 
#Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # IDE-related .vscode/ .idea/vcs.xml .idea/inspectionProfiles/profiles_settings.xml .idea/workspace.xml # Jekyll stuff, triggered by running the docs locally .jekyll-cache/ .jekyll-metadata _site/ # symlink / directory for data files extern_data ================================================ FILE: .travis.yml ================================================ language: python python: - 3.6.5 notifications: email: false install: - pip install --quiet . - export CORENLP_HOME=~/corenlp-latest CORENLP_VERSION=stanford-corenlp-latest - export CORENLP_URL="http://nlp.stanford.edu/software/${CORENLP_VERSION}.zip" - wget $CORENLP_URL -O corenlp-latest.zip - unzip corenlp-latest.zip > unzip.log - export CORENLP_UNZIP=`grep creating unzip.log | head -n 1 | cut -d ":" -f 2` - mv $CORENLP_UNZIP $CORENLP_HOME - mkdir ~/stanza_test - mkdir ~/stanza_test/in - mkdir ~/stanza_test/out - mkdir ~/stanza_test/scripts - cp tests/data/external_server.properties ~/stanza_test/scripts - cp tests/data/example_french.json ~/stanza_test/out - cp tests/data/tiny_emb.* ~/stanza_test/in - export STANZA_TEST_HOME=~/stanza_test script: - python -m pytest -m travis tests/ ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Stanza We would love to see contributions to Stanza from the community! Contributions that we welcome include bugfixes and enhancements. 
If you want to report a bug or suggest a feature but don't intend to fix or implement it by yourself, please create a corresponding issue on [our issues page](https://github.com/stanfordnlp/stanza/issues). If you plan to contribute a bugfix or enhancement, please read the following. ## 🛠️ Bugfixes For bugfixes, please follow these steps: - Make sure a fix does not already exist by searching through existing [issues](https://github.com/stanfordnlp/stanza/issues) (including closed ones) and [pull requests](https://github.com/stanfordnlp/stanza/pulls). - Confirm the bug with us by creating a bug-report issue. In your issue, you should at least include the platform and environment that you are running with, and a minimal code snippet that will reproduce the bug. - Once the bug is confirmed, you can go ahead with implementing the bugfix, and create a pull request **against the `dev` branch**. ## 💡 Enhancements For enhancements, please follow these steps: - Make sure a similar enhancement suggestion does not already exist by searching through existing [issues](https://github.com/stanfordnlp/stanza/issues). - Create a feature-request issue and discuss this enhancement with us. We'll need to make sure this enhancement won't break the existing user interface or functionality. - Once the enhancement is confirmed with us, you can go ahead with implementing it, and create a pull request **against the `dev` branch**. ================================================ FILE: LICENSE ================================================ Copyright 2019 The Board of Trustees of The Leland Stanford Junior University Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

Stanza: A Python NLP Library for Many Human Languages

Run Tests PyPI Version Conda Versions Python Versions
The Stanford NLP Group's official Python NLP library. It contains support for running various accurate natural language processing tools on 60+ languages and for accessing the Java Stanford CoreNLP software from Python. For detailed information, please visit our [official website](https://stanfordnlp.github.io/stanza/). 🔥  A new collection of **biomedical** and **clinical** English model packages is now available, offering a seamless experience for syntactic analysis and named entity recognition (NER) from biomedical literature text and clinical notes. For more information, check out our [Biomedical models documentation page](https://stanfordnlp.github.io/stanza/biomed.html). ### References If you use this library in your research, please kindly cite our [ACL2020 Stanza system demo paper](https://arxiv.org/abs/2003.07082): ```bibtex @inproceedings{qi2020stanza, title={Stanza: A {Python} Natural Language Processing Toolkit for Many Human Languages}, author={Qi, Peng and Zhang, Yuhao and Zhang, Yuhui and Bolton, Jason and Manning, Christopher D.}, booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations", year={2020} } ``` If you use our biomedical and clinical models, please also cite our [Stanza Biomedical Models description paper](https://arxiv.org/abs/2007.14640): ```bibtex @article{zhang2021biomedical, author = {Zhang, Yuhao and Zhang, Yuhui and Qi, Peng and Manning, Christopher D and Langlotz, Curtis P}, title = {Biomedical and clinical {E}nglish model packages for the {S}tanza {P}ython {NLP} library}, journal = {Journal of the American Medical Informatics Association}, year = {2021}, month = {06}, issn = {1527-974X} } ``` The PyTorch implementation of the neural pipeline in this repository is due to [Peng Qi](http://qipeng.me) (@qipeng), [Yuhao Zhang](http://yuhao.im) (@yuhaozhang), and [Yuhui Zhang](https://cs.stanford.edu/~yuhuiz/) (@yuhui-zh15), with help from [Jason
Bolton](mailto:jebolton@stanford.edu) (@j38), [Tim Dozat](https://web.stanford.edu/~tdozat/) (@tdozat) and [John Bauer](https://www.linkedin.com/in/john-bauer-b3883b60/) (@AngledLuffa). Maintenance of this repo is currently led by [John Bauer](https://www.linkedin.com/in/john-bauer-b3883b60/). If you use the CoreNLP software through Stanza, please cite the CoreNLP software package and the respective modules as described [here](https://stanfordnlp.github.io/CoreNLP/#citing-stanford-corenlp-in-papers) ("Citing Stanford CoreNLP in papers"). The CoreNLP client is mostly written by [Arun Chaganty](http://arun.chagantys.org/), and [Jason Bolton](mailto:jebolton@stanford.edu) spearheaded merging the two projects together. If you use the Semgrex or Ssurgeon part of CoreNLP, please cite [our GURT paper on Semgrex and Ssurgeon](https://aclanthology.org/2023.tlt-1.7/): ```bibtex @inproceedings{bauer-etal-2023-semgrex, title = "Semgrex and Ssurgeon, Searching and Manipulating Dependency Graphs", author = "Bauer, John and Kiddon, Chlo{\'e} and Yeh, Eric and Shan, Alex and D. Manning, Christopher", booktitle = "Proceedings of the 21st International Workshop on Treebanks and Linguistic Theories (TLT, GURT/SyntaxFest 2023)", month = mar, year = "2023", address = "Washington, D.C.", publisher = "Association for Computational Linguistics", url = "https://aclanthology.org/2023.tlt-1.7", pages = "67--73", abstract = "Searching dependency graphs and manipulating them can be a time consuming and challenging task to get right. We document Semgrex, a system for searching dependency graphs, and introduce Ssurgeon, a system for manipulating the output of Semgrex. The compact language used by these systems allows for easy command line or API processing of dependencies. 
Additionally, integration with publicly released toolkits in Java and Python allows for searching text relations and attributes over natural text.", } ``` ## Issues and Usage Q&A To ask questions, report issues or request features 🤔, please use the [GitHub Issue Tracker](https://github.com/stanfordnlp/stanza/issues). Before creating a new issue, please make sure to search for existing issues that may solve your problem, or visit the [Frequently Asked Questions (FAQ) page](https://stanfordnlp.github.io/stanza/faq.html) on our website. ## Contributing to Stanza We welcome community contributions to Stanza in the form of bugfixes 🛠️ and enhancements 💡! If you want to contribute, please first read [our contribution guideline](CONTRIBUTING.md). ## Installation ### pip Stanza supports Python 3.6 or later. We recommend that you install Stanza via [pip](https://pip.pypa.io/en/stable/installing/), the Python package manager. To install, simply run: ```bash pip install stanza ``` This should also help resolve all of the dependencies of Stanza, for instance [PyTorch](https://pytorch.org/) 1.3.0 or above. If you currently have a previous version of `stanza` installed, use: ```bash pip install stanza -U ``` ### Anaconda To install Stanza via Anaconda, use the following conda command: ```bash conda install -c stanfordnlp stanza ``` Note that for now installing Stanza via Anaconda does not work for Python 3.10. For Python 3.10 please use pip installation. ### From Source Alternatively, you can also install from source of this git repository, which will give you more flexibility in developing on top of Stanza. For this option, run ```bash git clone https://github.com/stanfordnlp/stanza.git cd stanza pip install -e . 
``` ## Running Stanza ### Getting Started with the neural pipeline To run your first Stanza pipeline, simply follow these steps in your Python interactive interpreter: ```python >>> import stanza >>> stanza.download('en') # Optional: pre-download English models (Pipeline can auto-download if needed) >>> nlp = stanza.Pipeline('en') # This sets up a default neural pipeline in English >>> doc = nlp("Barack Obama was born in Hawaii. He was elected president in 2008.") >>> doc.sentences[0].print_dependencies() ``` If you encounter `requests.exceptions.ConnectionError`, please try to use a proxy: ```python >>> import stanza >>> proxies = {'http': 'http://ip:port', 'https': 'http://ip:port'} >>> stanza.download('en', proxies=proxies) # Optional: pre-download English models (Pipeline can auto-download if needed) >>> nlp = stanza.Pipeline('en') # This sets up a default neural pipeline in English >>> doc = nlp("Barack Obama was born in Hawaii. He was elected president in 2008.") >>> doc.sentences[0].print_dependencies() ``` The last command will print out the words in the first sentence in the input string (or [`Document`](https://stanfordnlp.github.io/stanza/data_objects.html#document), as it is represented in Stanza), as well as the indices for the word that governs it in the Universal Dependencies parse of that sentence (its "head"), along with the dependency relation between the words. The output should look like: ``` ('Barack', '4', 'nsubj:pass') ('Obama', '1', 'flat') ('was', '4', 'aux:pass') ('born', '0', 'root') ('in', '6', 'case') ('Hawaii', '4', 'obl') ('.', '4', 'punct') ``` See [our getting started guide](https://stanfordnlp.github.io/stanza/installation_usage.html#getting-started) for more details. ### Accessing Java Stanford CoreNLP software Aside from the neural pipeline, this package also includes an official wrapper for accessing the Java Stanford CoreNLP software with Python code. There are a few initial setup steps. 
* Download [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/) and models for the language you wish to use * Put the model jars in the distribution folder * Tell the Python code where Stanford CoreNLP is located by setting the `CORENLP_HOME` environment variable (e.g., in *nix): `export CORENLP_HOME=/path/to/stanford-corenlp-4.5.3` We provide [comprehensive examples](https://stanfordnlp.github.io/stanza/corenlp_client.html) in our documentation that show how one can use CoreNLP through Stanza and extract various annotations from it. ### Online Colab Notebooks To get you started, we also provide interactive Jupyter notebooks in the `demo` folder. You can also open these notebooks and run them interactively on [Google Colab](https://colab.research.google.com). To view all available notebooks, follow these steps: * Go to the [Google Colab website](https://colab.research.google.com) * Navigate to `File` -> `Open notebook`, and choose `GitHub` in the pop-up menu * Note that you do **not** need to give Colab access permission to your GitHub account * Type `stanfordnlp/stanza` in the search bar, and press Enter ### Trained Models for the Neural Pipeline We currently provide models for all of the [Universal Dependencies](https://universaldependencies.org/) treebanks v2.8, as well as NER models for a few widely-spoken languages. You can find instructions for downloading and using these models [here](https://stanfordnlp.github.io/stanza/models.html). ### Batching To Maximize Pipeline Speed To maximize speed performance, it is essential to run the pipeline on batches of documents. Running a for loop on one sentence at a time will be very slow. The best approach at this time is to concatenate documents together, with each document separated by a blank line (i.e., two line breaks `\n\n`). The tokenizer will recognize blank lines as sentence breaks. We are actively working on improving multi-document processing.
## Training your own neural pipelines All neural modules in this library can be trained with your own data. The tokenizer, the multi-word token (MWT) expander, the POS/morphological features tagger, the lemmatizer and the dependency parser require [CoNLL-U](https://universaldependencies.org/format.html) formatted data, while the NER model requires the BIOES format. Currently, we do not support model training via the `Pipeline` interface. Therefore, to train your own models, you need to clone this git repository and run training from the source. For detailed step-by-step guidance on how to train and evaluate your own models, please visit our [training documentation](https://stanfordnlp.github.io/stanza/training.html). ## LICENSE Stanza is released under the Apache License, Version 2.0. See the [LICENSE](https://github.com/stanfordnlp/stanza/blob/master/LICENSE) file for more details. ================================================ FILE: demo/CONLL_Dependency_Visualizer_Example.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "c0fd86c8", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\n", "\n", "# load necessary conllu files - expected to be in the demo directory along with the notebook\n", "en_file = \"en_test.conllu.txt\"\n", "\n", "# testing left to right languages\n", "conll_to_visual(en_file, \"en\", sent_count=2)\n", "conll_to_visual(en_file, \"en\", sent_count=10)\n", "#conll_to_visual(en_file, \"en\", display_all=True)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fc4b3f9b", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\n", "\n", "jp_file = \"japanese_test.conllu.txt\"\n", "conll_to_visual(jp_file, \"ja\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6852b8e8", "metadata": {}, "outputs": [], 
"source": [ "from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\n", "\n", "# testing right to left languages\n", "ar_file = \"arabic_test.conllu.txt\"\n", "conll_to_visual(ar_file, \"ar\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.22" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: demo/Dependency_Visualization_Testing.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "64b2a9e0", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.dependency_visualization import visualize_strings\n", "\n", "ar_strings = ['برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة \"ليوبارد\" الالمانية', \"هل بإمكاني مساعدتك؟\", \n", " \"أراك في مابعد\", \"لحظة من فضلك\"]\n", "# Testing with right to left language\n", "visualize_strings(ar_strings, \"ar\")" ] }, { "cell_type": "code", "execution_count": null, "id": "35ef521b", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.dependency_visualization import visualize_strings\n", "\n", "en_strings = [\"This is a sentence.\", \n", " \"He is wearing a red shirt\",\n", " \"Barack Obama was born in Hawaii. 
He was elected President of the United States in 2008.\"]\n", "# Testing with left to right languages\n", "visualize_strings(en_strings, \"en\")" ] }, { "cell_type": "code", "execution_count": null, "id": "f3cf10ba", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.dependency_visualization import visualize_strings\n", "\n", "zh_strings = [\"中国是一个很有意思的国家。\"]\n", "# Testing with right to left language\n", "visualize_strings(zh_strings, \"zh\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d2b9b574", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.22" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: demo/NER_Visualization.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "abf300bb", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.ner_visualization import visualize_strings\n", "\n", "en_strings = ['''Samuel Jackson, a Christian man from Utah, went to the JFK Airport for a flight to New York.\n", " He was thinking of attending the US Open, his favorite tennis tournament besides Wimbledon.\n", " That would be a dream trip, certainly not possible since it is $5000 attendance and 5000 miles away.\n", " On the way there, he watched the Super Bowl for 2 hours and read War and Piece by Tolstoy for 1 hour.\n", " In New York, he crossed the Brooklyn Bridge and listened to the 5th symphony of Beethoven as well as\n", " \"All I want for Christmas is You\" by Mariah Carey.''', \n", " \"Barack Obama was born in Hawaii. 
He was elected President of the United States in 2008\"]\n", " \n", "visualize_strings(en_strings, \"en\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "5670921a", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.ner_visualization import visualize_strings\n", "\n", "zh_strings = ['''来自犹他州的基督徒塞缪尔杰克逊前往肯尼迪机场搭乘航班飞往纽约。\n", " 他正在考虑参加美国公开赛,这是除了温布尔登之外他最喜欢的网球赛事。\n", " 那将是一次梦想之旅,当然不可能,因为它的出勤费为 5000 美元,距离 5000 英里。\n", " 在去的路上,他看了 2 个小时的超级碗比赛,看了 1 个小时的托尔斯泰的《战争与碎片》。\n", " 在纽约,他穿过布鲁克林大桥,聆听了贝多芬的第五交响曲以及 玛丽亚凯莉的“圣诞节我想要的就是你”。''',\n", " \"我觉得罗家费德勒住在加州, 在美国里面。\"]\n", "visualize_strings(zh_strings, \"zh\", colors={\"PERSON\": \"yellow\", \"DATE\": \"red\", \"GPE\": \"blue\"})\n", "visualize_strings(zh_strings, \"zh\", select=['PERSON', 'DATE'])" ] }, { "cell_type": "code", "execution_count": null, "id": "b8d96072", "metadata": {}, "outputs": [], "source": [ "from stanza.utils.visualization.ner_visualization import visualize_strings\n", "\n", "ar_strings = [\".أعيش في سان فرانسيسكو ، كاليفورنيا. اسمي أليكس وأنا ألتحق بجامعة ستانفورد. أنا أدرس علوم الكمبيوتر وأستاذي هو كريس مانينغ\"\n", " , \"اسمي أليكس ، أنا من الولايات المتحدة.\", \n", " '''صامويل جاكسون ، رجل مسيحي من ولاية يوتا ، ذهب إلى مطار جون كنيدي في رحلة إلى نيويورك. كان يفكر في حضور بطولة الولايات المتحدة المفتوحة للتنس ، بطولة التنس المفضلة لديه إلى جانب بطولة ويمبلدون. ستكون هذه رحلة الأحلام ، وبالتأكيد ليست ممكنة لأنها تبلغ 5000 دولار للحضور و 5000 ميل. في الطريق إلى هناك ، شاهد Super Bowl لمدة ساعتين وقرأ War and Piece by Tolstoy لمدة ساعة واحدة. 
في نيويورك ، عبر جسر بروكلين واستمع إلى السيمفونية الخامسة لبيتهوفن وكذلك \"كل ما أريده في عيد الميلاد هو أنت\" لماريا كاري.''']\n", "\n", "visualize_strings(ar_strings, \"ar\", colors={\"PER\": \"pink\", \"LOC\": \"linear-gradient(90deg, #aa9cfc, #fc9ce7)\", \"ORG\": \"yellow\"})" ] }, { "cell_type": "code", "execution_count": null, "id": "22489b27", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.22" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: demo/Stanza_Beginners_Guide.ipynb ================================================ { "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Stanza-Beginners-Guide.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "56LiYCkPM7V_", "colab_type": "text" }, "source": [ "# Welcome to Stanza!\n", "\n", "![Latest Version](https://img.shields.io/pypi/v/stanza.svg?colorB=bc4545)\n", "![Python Versions](https://img.shields.io/pypi/pyversions/stanza.svg?colorB=bc4545)\n", "\n", "Stanza is a Python NLP toolkit that supports 60+ human languages. It is built with highly accurate neural network components that enable efficient training and evaluation with your own annotated data, and offers pretrained models on 100 treebanks. Additionally, Stanza provides a stable, officially maintained Python interface to Java Stanford CoreNLP Toolkit.\n", "\n", "In this tutorial, we will demonstrate how to set up Stanza and annotate text with its native neural network NLP models. 
For the use of the Python CoreNLP interface, please see other tutorials." ] }, { "cell_type": "markdown", "metadata": { "id": "yQff4Di5Nnq0", "colab_type": "text" }, "source": [ "## 1. Installing Stanza\n", "\n", "Note that Stanza only supports Python 3.6 and above. Installing and importing Stanza are as simple as running the following commands:" ] }, { "cell_type": "code", "metadata": { "id": "owSj1UtdEvSU", "colab_type": "code", "colab": {} }, "source": [ "# Install; note that the prefix \"!\" is not needed if you are running in a terminal\n", "!pip install stanza\n", "\n", "# Import the package\n", "import stanza" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "4ixllwEKeCJg", "colab_type": "text" }, "source": [ "### More Information\n", "\n", "For common troubleshooting, please visit our [troubleshooting page](https://stanfordnlp.github.io/stanfordnlp/installation_usage.html#troubleshooting)." ] }, { "cell_type": "markdown", "metadata": { "id": "aeyPs5ARO79d", "colab_type": "text" }, "source": [ "## 2. Downloading Models\n", "\n", "You can download models with the `stanza.download` command. The language can be specified with either a full language name (e.g., \"english\"), or a short code (e.g., \"en\"). \n", "\n", "By default, models will be saved to your `~/stanza_resources` directory. 
If you want to specify your own path to save the model files, you can pass a `dir=your_path` argument.\n" ] }, { "cell_type": "code", "metadata": { "id": "HDwRm-KXGcYo", "colab_type": "code", "colab": {} }, "source": [ "# Download an English model into the default directory\n", "print(\"Downloading English model...\")\n", "stanza.download('en')\n", "\n", "# Similarly, download a (simplified) Chinese model\n", "# Note that you can use verbose=False to turn off all printed messages\n", "print(\"Downloading Chinese model...\")\n", "stanza.download('zh', verbose=False)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "7HCfQ0SfdmsU", "colab_type": "text" }, "source": [ "### More Information\n", "\n", "Pretrained models are provided for 60+ different languages. For all languages, available models and the corresponding short language codes, please check out the [models page](https://stanfordnlp.github.io/stanza/models.html).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "b3-WZJrzWD2o", "colab_type": "text" }, "source": [ "## 3. Processing Text\n" ] }, { "cell_type": "markdown", "metadata": { "id": "XrnKl2m3fq2f", "colab_type": "text" }, "source": [ "### Constructing Pipeline\n", "\n", "To process a piece of text, you'll need to first construct a `Pipeline` with different `Processor` units. The pipeline is language-specific, so again you'll need to first specify the language (see examples).\n", "\n", "- By default, the pipeline will include all processors, including tokenization, multi-word token expansion, part-of-speech tagging, lemmatization, dependency parsing and named entity recognition (for supported languages). However, you can always specify what processors you want to include with the `processors` argument.\n", "\n", "- Stanza's pipeline is CUDA-aware, meaning that a CUDA-device will be used whenever it is available, otherwise CPUs will be used when a GPU is not found. 
You can force the pipeline to use CPU regardless by setting `use_gpu=False`.\n", "\n", "- Again, you can suppress all printed messages by setting `verbose=False`." ] }, { "cell_type": "code", "metadata": { "id": "HbiTSBDPG53o", "colab_type": "code", "colab": {} }, "source": [ "# Build an English pipeline, with all processors by default\n", "print(\"Building an English pipeline...\")\n", "en_nlp = stanza.Pipeline('en')\n", "\n", "# Build a Chinese pipeline, with customized processor list and no logging, and force it to use CPU\n", "print(\"Building a Chinese pipeline...\")\n", "zh_nlp = stanza.Pipeline('zh', processors='tokenize,lemma,pos,depparse', verbose=False, use_gpu=False)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Go123Bx8e1wt", "colab_type": "text" }, "source": [ "### Annotating Text\n", "\n", "After a pipeline is successfully constructed, you can get annotations of a piece of text simply by passing the string into the pipeline object. The pipeline will return a `Document` object, which can be used to access detailed annotations from. For example:\n" ] }, { "cell_type": "code", "metadata": { "id": "k_p0h1UTHDMm", "colab_type": "code", "colab": {} }, "source": [ "# Processing English text\n", "en_doc = en_nlp(\"Barack Obama was born in Hawaii. He was elected president in 2008.\")\n", "print(type(en_doc))\n", "\n", "# Processing Chinese text\n", "zh_doc = zh_nlp(\"达沃斯世界经济论坛是每年全球政商界领袖聚在一起的年度盛事。\")\n", "print(type(zh_doc))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "DavwCP9egzNZ", "colab_type": "text" }, "source": [ "### More Information\n", "\n", "For more information on how to construct a pipeline and information on different processors, please visit our [pipeline page](https://stanfordnlp.github.io/stanfordnlp/pipeline.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "O_PYLEGziQWR", "colab_type": "text" }, "source": [ "## 4. 
Accessing Annotations\n", "\n", "Annotations can be accessed from the returned `Document` object. \n", "\n", "A `Document` contains a list of `Sentence`s, and a `Sentence` contains a list of `Token`s and `Word`s. For the most part `Token`s and `Word`s overlap, but some tokens can be divided into multiple words, for instance the French token `aux` is divided into the words `à` and `les`, while in English a word and a token are equivalent. Note that dependency parses are derived over `Word`s.\n", "\n", "Additionally, a `Span` object is used to represent annotations that are part of a document, such as named entity mentions.\n", "\n", "\n", "The following example iterates over all English sentences and words, and prints the word information one by one:" ] }, { "cell_type": "code", "metadata": { "id": "B5691SpFHFZ6", "colab_type": "code", "colab": {} }, "source": [ "for i, sent in enumerate(en_doc.sentences):\n", " print(\"[Sentence {}]\".format(i+1))\n", " for word in sent.words:\n", " print(\"{:12s}\\t{:12s}\\t{:6s}\\t{:d}\\t{:12s}\".format(\\\n", " word.text, word.lemma, word.pos, word.head, word.deprel))\n", " print(\"\")" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "-AUkCkNIrusq", "colab_type": "text" }, "source": [ "The following example iterates over all extracted named entity mentions and prints out their character spans and types."
] }, { "cell_type": "code", "metadata": { "id": "5Uu0-WmvsnlK", "colab_type": "code", "colab": {} }, "source": [ "print(\"Mention text\\tType\\tStart-End\")\n", "for ent in en_doc.ents:\n", " print(\"{}\\t{}\\t{}-{}\".format(ent.text, ent.type, ent.start_char, ent.end_char))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Ql1SZlZOnMLo", "colab_type": "text" }, "source": [ "And similarly for the Chinese text:" ] }, { "cell_type": "code", "metadata": { "id": "XsVcEO9tHKPG", "colab_type": "code", "colab": {} }, "source": [ "for i, sent in enumerate(zh_doc.sentences):\n", " print(\"[Sentence {}]\".format(i+1))\n", " for word in sent.words:\n", " print(\"{:12s}\\t{:12s}\\t{:6s}\\t{:d}\\t{:12s}\".format(\\\n", " word.text, word.lemma, word.pos, word.head, word.deprel))\n", " print(\"\")" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "dUhWAs8pnnHT", "colab_type": "text" }, "source": [ "Alternatively, you can directly print a `Word` object to view all its annotations as a Python dict:" ] }, { "cell_type": "code", "metadata": { "id": "6_UafNb7HHIg", "colab_type": "code", "colab": {} }, "source": [ "word = en_doc.sentences[0].words[0]\n", "print(word)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "TAQlOsuRoq2V", "colab_type": "text" }, "source": [ "### More Information\n", "\n", "For all information on different data objects, please visit our [data objects page](https://stanfordnlp.github.io/stanza/data_objects.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "hiiWHxYPpmhd", "colab_type": "text" }, "source": [ "## 5. Resources\n", "\n", "Apart from this interactive tutorial, we also provide tutorials on our website that cover a variety of use cases such as how to use different model \"packages\" for a language, how to use spaCy as a tokenizer, how to process pretokenized text without running the tokenizer, etc. 
For these tutorials, please visit [our Tutorials page](https://stanfordnlp.github.io/stanza/tutorials.html).\n", "\n", "Other resources that you may find helpful include:\n", "\n", "- [Stanza Homepage](https://stanfordnlp.github.io/stanza/index.html)\n", "- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\n", "- [GitHub Repo](https://github.com/stanfordnlp/stanza)\n", "- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\n", "- [Stanza System Description Paper](http://arxiv.org/abs/2003.07082)\n" ] } ] } ================================================ FILE: demo/Stanza_CoreNLP_Interface.ipynb ================================================ { "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Stanza-CoreNLP-Interface.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "2-4lzQTC9yxG", "colab_type": "text" }, "source": [ "# Stanza: A Tutorial on the Python CoreNLP Interface\n", "\n", "![Latest Version](https://img.shields.io/pypi/v/stanza.svg?colorB=bc4545)\n", "![Python Versions](https://img.shields.io/pypi/pyversions/stanza.svg?colorB=bc4545)\n", "\n", "While the Stanza library implements accurate neural network modules for basic functionalities such as part-of-speech tagging and dependency parsing, the [Stanford CoreNLP Java library](https://stanfordnlp.github.io/CoreNLP/) has been developed for years and offers more complementary features such as coreference resolution and relation extraction. To unlock these features, the Stanza library also offers an officially maintained Python interface to the CoreNLP Java library. This interface allows you to get NLP annotations from CoreNLP by writing native Python code.\n", "\n", "\n", "This tutorial walks you through the installation, setup and basic usage of this Python CoreNLP interface.
If you want to learn how to use the neural network components in Stanza, please refer to other tutorials." ] }, { "cell_type": "markdown", "metadata": { "id": "YpKwWeVkASGt", "colab_type": "text" }, "source": [ "## 1. Installation\n", "\n", "Before the installation starts, please make sure that you have Python 3 and Java installed on your computer. Since Colab already has them installed, we'll skip this procedure in this notebook." ] }, { "cell_type": "markdown", "metadata": { "id": "k1Az2ECuAfG8", "colab_type": "text" }, "source": [ "### Installing Stanza\n", "\n", "Installing and importing Stanza are as simple as running the following commands:" ] }, { "cell_type": "code", "metadata": { "id": "xiFwYAgW4Mss", "colab_type": "code", "colab": {} }, "source": [ "# Install stanza; note that the prefix \"!\" is not needed if you are running in a terminal\n", "!pip install stanza\n", "\n", "# Import stanza\n", "import stanza" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "2zFvaA8_A32_", "colab_type": "text" }, "source": [ "### Setting up Stanford CoreNLP\n", "\n", "In order for the interface to work, the Stanford CoreNLP library has to be installed and a `CORENLP_HOME` environment variable has to be pointed to the installation location.\n", "\n", "Here we are going to show you how to download and install the CoreNLP library on your machine, with Stanza's installation command:" ] }, { "cell_type": "code", "metadata": { "id": "MgK6-LPV-OdA", "colab_type": "code", "colab": {} }, "source": [ "# Download the Stanford CoreNLP package with Stanza's installation command\n", "# This'll take several minutes, depending on the network speed\n", "corenlp_dir = './corenlp'\n", "stanza.install_corenlp(dir=corenlp_dir)\n", "\n", "# Set the CORENLP_HOME environment variable to point to the installation location\n", "import os\n", "os.environ[\"CORENLP_HOME\"] = corenlp_dir" ], "execution_count": null, "outputs": [] }, { "cell_type": 
"markdown", "metadata": { "id": "Jdq8MT-NAhKj", "colab_type": "text" }, "source": [ "That's all for the installation! 🎉 We can now double check if the installation is successful by listing files in the CoreNLP directory. You should be able to see a number of `.jar` files by running the following command:" ] }, { "cell_type": "code", "metadata": { "id": "K5eIOaJp_tuo", "colab_type": "code", "colab": {} }, "source": [ "# Examine the CoreNLP installation folder to make sure the installation is successful\n", "!ls $CORENLP_HOME" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "S0xb9BHt__gx", "colab_type": "text" }, "source": [ "**Note 1**:\n", "If you are want to use the interface in a terminal (instead of a Colab notebook), you can properly set the `CORENLP_HOME` environment variable with:\n", "\n", "```bash\n", "export CORENLP_HOME=path_to_corenlp_dir\n", "```\n", "\n", "Here we instead set this variable with the Python `os` library, simply because `export` command is not well-supported in Colab notebook.\n", "\n", "\n", "**Note 2**:\n", "The `stanza.install_corenlp()` function is only available since Stanza v1.1.1. If you are using an earlier version of Stanza, please check out our [manual installation page](https://stanfordnlp.github.io/stanza/client_setup.html#manual-installation) for how to install CoreNLP on your computer.\n", "\n", "**Note 3**:\n", "Besides the installation function, we also provide a `stanza.download_corenlp_models()` function to help you download additional CoreNLP models for different languages that are not shipped with the default installation. Check out our [automatic installation website page](https://stanfordnlp.github.io/stanza/client_setup.html#automated-installation) for more information on how to use it." ] }, { "cell_type": "markdown", "metadata": { "id": "xJsuO6D8D05q", "colab_type": "text" }, "source": [ "## 2. 
Annotating Text with CoreNLP Interface" ] }, { "cell_type": "markdown", "metadata": { "id": "dZNHxXHkH1K2", "colab_type": "text" }, "source": [ "### Constructing CoreNLPClient\n", "\n", "At a high level, the CoreNLP Python interface works by first starting a background Java CoreNLP server process, and then initializing a client instance in Python which can pass the text to the background server process, and accept the returned annotation results.\n", "\n", "We wrap these functionalities in a `CoreNLPClient` class. Therefore, we need to start by importing this class from Stanza." ] }, { "cell_type": "code", "metadata": { "id": "LS4OKnqJ8wui", "colab_type": "code", "colab": {} }, "source": [ "# Import client module\n", "from stanza.server import CoreNLPClient" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "WP4Dz6PIJHeL", "colab_type": "text" }, "source": [ "After the import is done, we can construct a `CoreNLPClient` instance. The constructor method takes a Python list of annotator names as argument. Here let's explore some basic annotators including tokenization, sentence split, part-of-speech tagging, lemmatization and named entity recognition (NER). \n", "\n", "Additionally, the client constructor accepts a `memory` argument, which specifies how much memory will be allocated to the background Java process. An `endpoint` option can be used to specify a port number used by the communication between the server and the client. The default port is 9000. However, since this port is pre-occupied by a system process in Colab, we'll manually set it to 9001 in the following example.\n", "\n", "Also, here we manually set `be_quiet=True` to avoid an IO issue in colab notebook. 
You should be able to use `be_quiet=False` on your own computer, which will print detailed logging information from CoreNLP during usage.\n", "\n", "For more options in constructing the clients, please refer to the [CoreNLP Client Options List](https://stanfordnlp.github.io/stanza/corenlp_client.html#corenlp-client-options)." ] }, { "cell_type": "code", "metadata": { "id": "mbOBugvd9JaM", "colab_type": "code", "colab": {} }, "source": [ "# Construct a CoreNLPClient with some basic annotators, a memory allocation of 4GB, and port number 9001\n", "client = CoreNLPClient(\n", " annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \n", " memory='4G', \n", " endpoint='http://localhost:9001',\n", " be_quiet=True)\n", "print(client)\n", "\n", "# Start the background server and wait for some time\n", "# Note that in practice this is totally optional, as by default the server will be started when the first annotation is performed\n", "client.start()\n", "import time; time.sleep(10)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "kgTiVjNydmIW", "colab_type": "text" }, "source": [ "After the above code block finishes executing, if you print the background processes, you should be able to find the Java CoreNLP server running." ] }, { "cell_type": "code", "metadata": { "id": "spZrJ-oFdkdF", "colab_type": "code", "colab": {} }, "source": [ "# Print background processes and look for java\n", "# You should be able to see a StanfordCoreNLPServer java process running in the background\n", "!ps -o pid,cmd | grep java" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "KxJeJ0D2LoOs", "colab_type": "text" }, "source": [ "### Annotating Text\n", "\n", "Annotating a piece of text is as simple as passing the text into an `annotate` function of the client object. 
After the annotation is complete, a `Document` object will be returned with all annotations.\n", "\n", "Note that although in general annotations are very fast, the first annotation might take a while to complete in the notebook. Please stay patient." ] }, { "cell_type": "code", "metadata": { "id": "s194RnNg5z95", "colab_type": "code", "colab": {} }, "source": [ "# Annotate some text\n", "text = \"Albert Einstein was a German-born theoretical physicist. He developed the theory of relativity.\"\n", "document = client.annotate(text)\n", "print(type(document))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "semmA3e0TcM1", "colab_type": "text" }, "source": [ "## 3. Accessing Annotations\n", "\n", "Annotations can be accessed from the returned `Document` object.\n", "\n", "A `Document` contains a list of `Sentence`s, which contain a list of `Token`s. Here let's first explore the annotations stored in all tokens." ] }, { "cell_type": "code", "metadata": { "id": "lIO4B5d6Rk4I", "colab_type": "code", "colab": {} }, "source": [ "# Iterate over all tokens in all sentences, and print out the word, lemma, pos and ner tags\n", "print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(\"Word\", \"Lemma\", \"POS\", \"NER\"))\n", "\n", "for i, sent in enumerate(document.sentence):\n", " print(\"[Sentence {}]\".format(i+1))\n", " for t in sent.token:\n", " print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(t.word, t.lemma, t.pos, t.ner))\n", " print(\"\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "msrJfvu8VV9m", "colab_type": "text" }, "source": [ "Alternatively, you can also browse the NER results by iterating over entity mentions over the sentences. 
For example:" ] }, { "cell_type": "code", "metadata": { "id": "ezEjc9LeV2Xs", "colab_type": "code", "colab": {} }, "source": [ "# Iterate over all detected entity mentions\n", "print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n", "\n", "for sent in document.sentence:\n", " for m in sent.mentions:\n", " print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ueGzBZ3hWzkN", "colab_type": "text" }, "source": [ "To print all the annotations that a sentence, token, or mention has, you can simply print the corresponding object." ] }, { "cell_type": "code", "metadata": { "id": "4_S8o2BHXIed", "colab_type": "code", "colab": {} }, "source": [ "# Print annotations of a token\n", "print(document.sentence[0].token[0])\n", "\n", "# Print annotations of a mention\n", "print(document.sentence[0].mentions[0])" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Qp66wjZ10xia", "colab_type": "text" }, "source": [ "**Note**: Since the Stanza CoreNLP client interface simply ports the CoreNLP annotation results to native Python objects, for a comprehensive list of available annotators and how their annotation results can be accessed, you will need to visit the [Stanford CoreNLP website](https://stanfordnlp.github.io/CoreNLP/)." ] }, { "cell_type": "markdown", "metadata": { "id": "IPqzMK90X0w3", "colab_type": "text" }, "source": [ "## 4. Shutting Down the CoreNLP Server\n", "\n", "To shut down the background CoreNLP server process, simply call the `stop` function of the client. Note that once a server is shut down, you'll have to restart the server with the `start()` function before any annotation is requested."
] }, { "cell_type": "code", "metadata": { "id": "xrJq8lZ3Nw7b", "colab_type": "code", "colab": {} }, "source": [ "# Shut down the background CoreNLP server\n", "client.stop()\n", "\n", "time.sleep(10)\n", "!ps -o pid,cmd | grep java" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "23Vwa_ifYfF7", "colab_type": "text" }, "source": [ "### More Information\n", "\n", "For more information on how to use the `CoreNLPClient`, please go to the [CoreNLPClient documentation page](https://stanfordnlp.github.io/stanza/corenlp_client.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "YUrVT6kA_Bzx", "colab_type": "text" }, "source": [ "## 5. Simplifying Client Usage with the Python `with` statement\n", "\n", "In the above demo, we explicitly called the `client.start()` and `client.stop()` functions to start and stop a client-server connection. However, doing this in practice is usually suboptimal, since you may forget to call the `stop()` function at the end, resulting in an unused server process occupying your machine's memory.\n", "\n", "To solve this, a simple solution is to use the client interface with the [Python `with` statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement). The `with` statement provides an elegant way to automatically start and stop the server process in your Python program, without you needing to worry about this. The following code snippet demonstrates how to establish a client, annotate an example text and then stop the server with a simple `with` statement. Note that we **always recommend** using the `with` statement when working with the Stanza CoreNLP client interface."
] }, { "cell_type": "code", "metadata": { "id": "H0ct2-R4AvJh", "colab_type": "code", "colab": {} }, "source": [ "print(\"Starting a server with the Python \\\"with\\\" statement...\")\n", "with CoreNLPClient(annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \n", " memory='4G', endpoint='http://localhost:9001', be_quiet=True) as client:\n", " text = \"Albert Einstein was a German-born theoretical physicist.\"\n", " document = client.annotate(text)\n", "\n", " print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n", " for sent in document.sentence:\n", " for m in sent.mentions:\n", " print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))\n", "\n", "print(\"\\nThe server should be stopped upon exit from the \\\"with\\\" statement.\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "W435Lwc4YqKb", "colab_type": "text" }, "source": [ "## 6. Other Resources\n", "\n", "- [Stanza Homepage](https://stanfordnlp.github.io/stanza/)\n", "- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\n", "- [GitHub Repo](https://github.com/stanfordnlp/stanza)\n", "- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\n" ] } ] } ================================================ FILE: demo/arabic_test.conllu.txt ================================================ # newdoc id = assabah.20041005.0017 # newpar id = assabah.20041005.0017:p1 # sent_id = assabah.20041005.0017:p1u1 # text = سوريا: تعديل وزاري واسع يشمل 8 حقائب # orig_file_sentence ASB_ARB_20041005.0017#1 1 سوريا سُورِيَا X X--------- Foreign=Yes 0 root 0:root SpaceAfter=No|Vform=سُورِيَا|Gloss=Syria|Root=sUr|Translit=sūriyā|LTranslit=sūriyā 2 : : PUNCT G--------- _ 1 punct 1:punct Vform=:|Translit=: 3 تعديل تَعدِيل NOUN N------S1I Case=Nom|Definite=Ind|Number=Sing 6 nsubj 6:nsubj Vform=تَعدِيلٌ|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=taʿdīlun|LTranslit=taʿdīl 4 وزاري وِزَارِيّ ADJ A-----MS1I 
Case=Nom|Definite=Ind|Gender=Masc|Number=Sing 3 amod 3:amod Vform=وِزَارِيٌّ|Gloss=ministry,ministerial|Root=w_z_r|Translit=wizārīyun|LTranslit=wizārīy 5 واسع وَاسِع ADJ A-----MS1I Case=Nom|Definite=Ind|Gender=Masc|Number=Sing 3 amod 3:amod Vform=وَاسِعٌ|Gloss=wide,extensive,broad|Root=w_s_`|Translit=wāsiʿun|LTranslit=wāsiʿ 6 يشمل شَمِل VERB VIIA-3MS-- Aspect=Imp|Gender=Masc|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act 1 parataxis 1:parataxis Vform=يَشمَلُ|Gloss=comprise,include,contain|Root=^s_m_l|Translit=yašmalu|LTranslit=šamil 7 8 8 NUM Q--------- NumForm=Digit 6 obj 6:obj Vform=٨|Translit=8 8 حقائب حَقِيبَة NOUN N------P2I Case=Gen|Definite=Ind|Number=Plur 7 nmod 7:nmod:gen Vform=حَقَائِبَ|Gloss=briefcase,suitcase,portfolio,luggage|Root=.h_q_b|Translit=ḥaqāʾiba|LTranslit=ḥaqībat # newpar id = assabah.20041005.0017:p2 # sent_id = assabah.20041005.0017:p2u1 # text = دمشق (وكالات الانباء) - اجرى الرئيس السوري بشار الاسد تعديلا حكومياً واسعا تم بموجبه إقالة وزيري الداخلية والاعلام عن منصبيها في حين ظل محمد ناجي العطري رئيساً للحكومة. 
# orig_file_sentence ASB_ARB_20041005.0017#2 1 دمشق دمشق X U--------- _ 0 root 0:root Vform=دمشق|Root=OOV|Translit=dmšq 2 ( ( PUNCT G--------- _ 3 punct 3:punct SpaceAfter=No|Vform=(|Translit=( 3 وكالات وِكَالَة NOUN N------P1R Case=Nom|Definite=Cons|Number=Plur 1 dep 1:dep Vform=وِكَالَاتُ|Gloss=agency|Root=w_k_l|Translit=wikālātu|LTranslit=wikālat 4 الانباء نَبَأ NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 3 nmod 3:nmod:gen SpaceAfter=No|Vform=اَلأَنبَاءِ|Gloss=news_item,report|Root=n_b_'|Translit=al-ʾanbāʾi|LTranslit=nabaʾ 5 ) ) PUNCT G--------- _ 3 punct 3:punct Vform=)|Translit=) 6 - - PUNCT G--------- _ 1 punct 1:punct Vform=-|Translit=- 7 اجرى أَجرَى VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 1 advcl 1:advcl:فِي_حِينَ Vform=أَجرَى|Gloss=conduct,carry_out,perform|Root=^g_r_y|Translit=ʾaǧrā|LTranslit=ʾaǧrā 8 الرئيس رَئِيس NOUN N------S1D Case=Nom|Definite=Def|Number=Sing 7 nsubj 7:nsubj Vform=اَلرَّئِيسُ|Gloss=president,head,chairman|Root=r_'_s|Translit=ar-raʾīsu|LTranslit=raʾīs 9 السوري سُورِيّ ADJ A-----MS1D Case=Nom|Definite=Def|Gender=Masc|Number=Sing 8 amod 8:amod Vform=اَلسُّورِيُّ|Gloss=Syrian|Root=sUr|Translit=as-sūrīyu|LTranslit=sūrīy 10 بشار بشار X U--------- _ 11 nmod 11:nmod Vform=بشار|Root=OOV|Translit=bšār 11 الاسد الاسد X U--------- _ 8 nmod 8:nmod Vform=الاسد|Root=OOV|Translit=ālāsd 12 تعديلا تَعدِيل NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 7 obj 7:obj Vform=تَعدِيلًا|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=taʿdīlan|LTranslit=taʿdīl 13 حكومياً حُكُومِيّ ADJ A-----MS4I Case=Acc|Definite=Ind|Gender=Masc|Number=Sing 12 amod 12:amod Vform=حُكُومِيًّا|Gloss=governmental,state,official|Root=.h_k_m|Translit=ḥukūmīyan|LTranslit=ḥukūmīy 14 واسعا وَاسِع ADJ A-----MS4I Case=Acc|Definite=Ind|Gender=Masc|Number=Sing 12 amod 12:amod Vform=وَاسِعًا|Gloss=wide,extensive,broad|Root=w_s_`|Translit=wāsiʿan|LTranslit=wāsiʿ 15 تم تَمّ VERB VP-A-3MS-- 
Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 12 acl 12:acl Vform=تَمَّ|Gloss=conclude,take_place|Root=t_m_m|Translit=tamma|LTranslit=tamm 16-18 بموجبه _ _ _ _ _ _ _ _ 16 ب بِ ADP P--------- AdpType=Prep 18 case 18:case Vform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi 17 موجب مُوجِب NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 16 fixed 16:fixed Vform=مُوجِبِ|Gloss=reason,motive|Root=w_^g_b|Translit=mūǧibi|LTranslit=mūǧib 18 ه هُوَ PRON SP---3MS2- Case=Gen|Gender=Masc|Number=Sing|Person=3|PronType=Prs 15 nmod 15:nmod:بِ_مُوجِب:gen Vform=هِ|Gloss=he,she,it|Translit=hi|LTranslit=huwa 19 إقالة إِقَالَة NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 15 nsubj 15:nsubj Vform=إِقَالَةُ|Gloss=dismissal,discharge|Root=q_y_l|Translit=ʾiqālatu|LTranslit=ʾiqālat 20 وزيري وَزِير NOUN N------D2R Case=Gen|Definite=Cons|Number=Dual 19 nmod 19:nmod:gen Vform=وَزِيرَي|Gloss=minister|Root=w_z_r|Translit=wazīray|LTranslit=wazīr 21 الداخلية دَاخِلِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 20 amod 20:amod Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy 22-23 والاعلام _ _ _ _ _ _ _ _ 22 و وَ CCONJ C--------- _ 23 cc 23:cc Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 23 الإعلام إِعلَام NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 21 conj 20:amod|21:conj Vform=اَلإِعلَامِ|Gloss=information,media|Root=`_l_m|Translit=al-ʾiʿlāmi|LTranslit=ʾiʿlām 24 عن عَن ADP P--------- AdpType=Prep 25 case 25:case Vform=عَن|Gloss=about,from|Root=`an|Translit=ʿan|LTranslit=ʿan 25-26 منصبيها _ _ _ _ _ _ _ _ 25 منصبي مَنصِب NOUN N------D2R Case=Gen|Definite=Cons|Number=Dual 19 nmod 19:nmod:عَن:gen Vform=مَنصِبَي|Gloss=post,position,office|Root=n_.s_b|Translit=manṣibay|LTranslit=manṣib 26 ها هُوَ PRON SP---3FS2- Case=Gen|Gender=Fem|Number=Sing|Person=3|PronType=Prs 25 nmod 25:nmod:gen Vform=هَا|Gloss=he,she,it|Translit=hā|LTranslit=huwa 27 في فِي ADP P--------- AdpType=Prep 7 mark 
7:mark Vform=فِي|Gloss=in|Root=fI|Translit=fī|LTranslit=fī 28 حين حِينَ ADP PI------2- AdpType=Prep|Case=Gen 7 mark 7:mark Vform=حِينِ|Gloss=when|Root=.h_y_n|Translit=ḥīni|LTranslit=ḥīna 29 ظل ظَلّ VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 7 parataxis 7:parataxis Vform=ظَلَّ|Gloss=remain,continue|Root=.z_l_l|Translit=ẓalla|LTranslit=ẓall 30 محمد محمد X U--------- _ 32 nmod 32:nmod Vform=محمد|Root=OOV|Translit=mḥmd 31 ناجي ناجي X U--------- _ 32 nmod 32:nmod Vform=ناجي|Root=OOV|Translit=nāǧy 32 العطري العطري X U--------- _ 29 nsubj 29:nsubj Vform=العطري|Root=OOV|Translit=ālʿṭry 33 رئيساً رَئِيس NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 29 xcomp 29:xcomp Vform=رَئِيسًا|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsan|LTranslit=raʾīs 34-35 للحكومة _ _ _ _ _ _ _ SpaceAfter=No 34 ل لِ ADP P--------- AdpType=Prep 35 case 35:case Vform=لِ|Gloss=for,to|Root=l|Translit=li|LTranslit=li 35 الحكومة حُكُومَة NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 33 nmod 33:nmod:لِ:gen Vform=اَلحُكُومَةِ|Gloss=government,administration|Root=.h_k_m|Translit=al-ḥukūmati|LTranslit=ḥukūmat 36 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=. # newpar id = assabah.20041005.0017:p3 # sent_id = assabah.20041005.0017:p3u1 # text = واضافت المصادر ان مهدي دخل الله رئيس تحرير صحيفة الحزب الحاكم والليبرالي التوجهات تسلم منصب وزير الاعلام خلفا لاحمد الحسن فيما تسلم اللواء غازي كنعان رئيس شعبة الامن السياسي منصب وزير الداخلية. 
# orig_file_sentence ASB_ARB_20041005.0017#3 1-2 واضافت _ _ _ _ _ _ _ _ 1 و وَ CCONJ C--------- _ 0 root 0:root Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 2 أضافت أَضَاف VERB VP-A-3FS-- Aspect=Perf|Gender=Fem|Number=Sing|Person=3|Voice=Act 1 parataxis 1:parataxis Vform=أَضَافَت|Gloss=add,attach,receive_as_guest|Root=.d_y_f|Translit=ʾaḍāfat|LTranslit=ʾaḍāf 3 المصادر مَصدَر NOUN N------P1D Case=Nom|Definite=Def|Number=Plur 2 nsubj 2:nsubj Vform=اَلمَصَادِرُ|Gloss=source|Root=.s_d_r|Translit=al-maṣādiru|LTranslit=maṣdar 4 ان أَنَّ SCONJ C--------- _ 16 mark 16:mark Vform=أَنَّ|Gloss=that|Root='_n|Translit=ʾanna|LTranslit=ʾanna 5 مهدي مهدي X U--------- _ 6 nmod 6:nmod Vform=مهدي|Root=OOV|Translit=mhdy 6 دخل دخل X U--------- _ 16 nsubj 16:nsubj Vform=دخل|Root=OOV|Translit=dḫl 7 الله الله X U--------- _ 6 nmod 6:nmod Vform=الله|Root=OOV|Translit=āllh 8 رئيس رَئِيس NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 6 nmod 6:nmod:acc Vform=رَئِيسَ|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsa|LTranslit=raʾīs 9 تحرير تَحرِير NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 8 nmod 8:nmod:gen Vform=تَحرِيرِ|Gloss=liberation,liberating,editorship,editing|Root=.h_r_r|Translit=taḥrīri|LTranslit=taḥrīr 10 صحيفة صَحِيفَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 9 nmod 9:nmod:gen Vform=صَحِيفَةِ|Gloss=newspaper,sheet,leaf|Root=.s_.h_f|Translit=ṣaḥīfati|LTranslit=ṣaḥīfat 11 الحزب حِزب NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 10 nmod 10:nmod:gen Vform=اَلحِزبِ|Gloss=party,band|Root=.h_z_b|Translit=al-ḥizbi|LTranslit=ḥizb 12 الحاكم حَاكِم NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 11 nmod 11:nmod:gen Vform=اَلحَاكِمِ|Gloss=ruler,governor|Root=.h_k_m|Translit=al-ḥākimi|LTranslit=ḥākim 13-14 والليبرالي _ _ _ _ _ _ _ _ 13 و وَ CCONJ C--------- _ 6 cc 6:cc Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 14 الليبرالي لِيبِرَالِيّ ADJ A-----MS4D Case=Acc|Definite=Def|Gender=Masc|Number=Sing 6 amod 6:amod 
Vform=اَللِّيبِرَالِيَّ|Gloss=liberal|Root=lIbirAl|Translit=al-lībirālīya|LTranslit=lībirālīy 15 التوجهات تَوَجُّه NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 14 nmod 14:nmod:gen Vform=اَلتَّوَجُّهَاتِ|Gloss=attitude,approach|Root=w_^g_h|Translit=at-tawaǧǧuhāti|LTranslit=tawaǧǧuh 16 تسلم تَسَلَّم VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 2 ccomp 2:ccomp Vform=تَسَلَّمَ|Gloss=receive,assume|Root=s_l_m|Translit=tasallama|LTranslit=tasallam 17 منصب مَنصِب NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 16 obj 16:obj Vform=مَنصِبَ|Gloss=post,position,office|Root=n_.s_b|Translit=manṣiba|LTranslit=manṣib 18 وزير وَزِير NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 17 nmod 17:nmod:gen Vform=وَزِيرِ|Gloss=minister|Root=w_z_r|Translit=wazīri|LTranslit=wazīr 19 الاعلام عَلَم NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 18 nmod 18:nmod:gen Vform=اَلأَعلَامِ|Gloss=flag,banner,badge|Root=`_l_m|Translit=al-ʾaʿlāmi|LTranslit=ʿalam 20 خلفا خَلَف NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 16 obl 16:obl:acc Vform=خَلَفًا|Gloss=substitute,scion|Root=_h_l_f|Translit=ḫalafan|LTranslit=ḫalaf 21-22 لاحمد _ _ _ _ _ _ _ _ 21 ل لِ ADP P--------- AdpType=Prep 23 case 23:case Vform=لِ|Gloss=for,to|Root=l|Translit=li|LTranslit=li 22 أحمد أَحمَد NOUN N------S2I Case=Gen|Definite=Ind|Number=Sing 23 nmod 23:nmod:gen Vform=أَحمَدَ|Gloss=Ahmad|Root=.h_m_d|Translit=ʾaḥmada|LTranslit=ʾaḥmad 23 الحسن الحسن X U--------- _ 20 nmod 20:nmod:لِ Vform=الحسن|Root=OOV|Translit=ālḥsn 24 فيما فِيمَا CCONJ C--------- _ 25 cc 25:cc Vform=فِيمَا|Gloss=while,during_which|Root=fI|Translit=fīmā|LTranslit=fīmā 25 تسلم تَسَلَّم VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 16 conj 2:ccomp|16:conj Vform=تَسَلَّمَ|Gloss=receive,assume|Root=s_l_m|Translit=tasallama|LTranslit=tasallam 26 اللواء لِوَاء NOUN N------S1D Case=Nom|Definite=Def|Number=Sing 25 nsubj 25:nsubj 
Vform=اَللِّوَاءُ|Gloss=banner,flag|Root=l_w_y|Translit=al-liwāʾu|LTranslit=liwāʾ 27 غازي غازي X U--------- _ 28 nmod 28:nmod Vform=غازي|Root=OOV|Translit=ġāzy 28 كنعان كنعان X U--------- _ 26 nmod 26:nmod Vform=كنعان|Root=OOV|Translit=knʿān 29 رئيس رَئِيس NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 26 nmod 26:nmod:nom Vform=رَئِيسُ|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsu|LTranslit=raʾīs 30 شعبة شُعبَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 29 nmod 29:nmod:gen Vform=شُعبَةِ|Gloss=branch,subdivision|Root=^s_`_b|Translit=šuʿbati|LTranslit=šuʿbat 31 الامن أَمن NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 30 nmod 30:nmod:gen Vform=اَلأَمنِ|Gloss=security,safety|Root='_m_n|Translit=al-ʾamni|LTranslit=ʾamn 32 السياسي سِيَاسِيّ ADJ A-----MS2D Case=Gen|Definite=Def|Gender=Masc|Number=Sing 31 amod 31:amod Vform=اَلسِّيَاسِيِّ|Gloss=political|Root=s_w_s|Translit=as-siyāsīyi|LTranslit=siyāsīy 33 منصب مَنصِب NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 25 obj 25:obj Vform=مَنصِبَ|Gloss=post,position,office|Root=n_.s_b|Translit=manṣiba|LTranslit=manṣib 34 وزير وَزِير NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 33 nmod 33:nmod:gen Vform=وَزِيرِ|Gloss=minister|Root=w_z_r|Translit=wazīri|LTranslit=wazīr 35 الداخلية دَاخِلِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 34 amod 34:amod SpaceAfter=No|Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy 36 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=. # newpar id = assabah.20041005.0017:p4 # sent_id = assabah.20041005.0017:p4u1 # text = وذكرت وكالة الانباء السورية ان التعديل شمل ثماني حقائب بينها وزارتا الداخلية والاقتصاد. 
# orig_file_sentence ASB_ARB_20041005.0017#4 1-2 وذكرت _ _ _ _ _ _ _ _ 1 و وَ CCONJ C--------- _ 0 root 0:root Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 2 ذكرت ذَكَر VERB VP-A-3FS-- Aspect=Perf|Gender=Fem|Number=Sing|Person=3|Voice=Act 1 parataxis 1:parataxis Vform=ذَكَرَت|Gloss=mention,cite,remember|Root=_d_k_r|Translit=ḏakarat|LTranslit=ḏakar 3 وكالة وِكَالَة NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 2 nsubj 2:nsubj Vform=وِكَالَةُ|Gloss=agency|Root=w_k_l|Translit=wikālatu|LTranslit=wikālat 4 الانباء نَبَأ NOUN N------P2D Case=Gen|Definite=Def|Number=Plur 3 nmod 3:nmod:gen Vform=اَلأَنبَاءِ|Gloss=news_item,report|Root=n_b_'|Translit=al-ʾanbāʾi|LTranslit=nabaʾ 5 السورية سُورِيّ ADJ A-----FS1D Case=Nom|Definite=Def|Gender=Fem|Number=Sing 3 amod 3:amod Vform=اَلسُّورِيَّةُ|Gloss=Syrian|Root=sUr|Translit=as-sūrīyatu|LTranslit=sūrīy 6 ان أَنَّ SCONJ C--------- _ 8 mark 8:mark Vform=أَنَّ|Gloss=that|Root='_n|Translit=ʾanna|LTranslit=ʾanna 7 التعديل تَعدِيل NOUN N------S4D Case=Acc|Definite=Def|Number=Sing 8 obl 8:obl:acc Vform=اَلتَّعدِيلَ|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=at-taʿdīla|LTranslit=taʿdīl 8 شمل شَمِل VERB VP-A-3MS-- Aspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act 2 ccomp 2:ccomp Vform=شَمِلَ|Gloss=comprise,include,contain|Root=^s_m_l|Translit=šamila|LTranslit=šamil 9 ثماني ثَمَانُون NUM QL------4R Case=Acc|Definite=Cons|NumForm=Word 8 obj 8:obj Vform=ثَمَانِي|Gloss=eighty|Root=_t_m_n|Translit=ṯamānī|LTranslit=ṯamānūn 10 حقائب حَقِيبَة NOUN N------P2I Case=Gen|Definite=Ind|Number=Plur 9 nmod 9:nmod:gen Vform=حَقَائِبَ|Gloss=briefcase,suitcase,portfolio,luggage|Root=.h_q_b|Translit=ḥaqāʾiba|LTranslit=ḥaqībat 11-12 بينها _ _ _ _ _ _ _ _ 11 بين بَينَ ADP PI------4- AdpType=Prep|Case=Acc 12 case 12:case Vform=بَينَ|Gloss=between,among|Root=b_y_n|Translit=bayna|LTranslit=bayna 12 ها هُوَ PRON SP---3FS2- Case=Gen|Gender=Fem|Number=Sing|Person=3|PronType=Prs 10 obl 10:obl:بَينَ:gen 
Vform=هَا|Gloss=he,she,it|Translit=hā|LTranslit=huwa 13 وزارتا وِزَارَة NOUN N------D1R Case=Nom|Definite=Cons|Number=Dual 12 nsubj 12:nsubj Vform=وِزَارَتَا|Gloss=ministry|Root=w_z_r|Translit=wizāratā|LTranslit=wizārat 14 الداخلية دَاخِلِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 13 amod 13:amod Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy 15-16 والاقتصاد _ _ _ _ _ _ _ SpaceAfter=No 15 و وَ CCONJ C--------- _ 16 cc 16:cc Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 16 الاقتصاد اِقتِصَاد NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 14 conj 13:amod|14:conj Vform=اَلِاقتِصَادِ|Gloss=economy,saving|Root=q_.s_d|Translit=al-i-ʼqtiṣādi|LTranslit=iqtiṣād 17 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=. ================================================ FILE: demo/corenlp.py ================================================ from stanza.server import CoreNLPClient # example text print('---') print('input text') print('') text = "Chris Manning is a nice person. Chris wrote a simple sentence. He also gives oranges to people." 
print(text) # set up the client print('---') print('starting up Java Stanford CoreNLP Server...') # set up the client with CoreNLPClient(annotators=['tokenize','ssplit','pos','lemma','ner','parse','depparse','coref'], timeout=60000, memory='16G') as client: # submit the request to the server ann = client.annotate(text) # get the first sentence sentence = ann.sentence[0] # get the dependency parse of the first sentence print('---') print('dependency parse of first sentence') dependency_parse = sentence.basicDependencies print(dependency_parse) # get the constituency parse of the first sentence print('---') print('constituency parse of first sentence') constituency_parse = sentence.parseTree print(constituency_parse) # get the first subtree of the constituency parse print('---') print('first subtree of constituency parse') print(constituency_parse.child[0]) # get the value of the first subtree print('---') print('value of first subtree of constituency parse') print(constituency_parse.child[0].value) # get the first token of the first sentence print('---') print('first token of first sentence') token = sentence.token[0] print(token) # get the part-of-speech tag print('---') print('part of speech tag of token') token.pos print(token.pos) # get the named entity tag print('---') print('named entity tag of token') print(token.ner) # get an entity mention from the first sentence print('---') print('first entity mention in sentence') print(sentence.mentions[0]) # access the coref chain print('---') print('coref chains for the example') print(ann.corefChain) # Use tokensregex patterns to find who wrote a sentence. pattern = '([ner: PERSON]+) /wrote/ /an?/ []{0,3} /sentence|article/' matches = client.tokensregex(text, pattern) # sentences contains a list with matches for each sentence. 
assert len(matches["sentences"]) == 3 # length tells you whether or not there are any matches in this assert matches["sentences"][1]["length"] == 1 # You can access matches like most regex groups. matches["sentences"][1]["0"]["text"] == "Chris wrote a simple sentence" matches["sentences"][1]["0"]["1"]["text"] == "Chris" # Use semgrex patterns to directly find who wrote what. pattern = '{word:wrote} >nsubj {}=subject >obj {}=object' matches = client.semgrex(text, pattern) # sentences contains a list with matches for each sentence. assert len(matches["sentences"]) == 3 # length tells you whether or not there are any matches in this assert matches["sentences"][1]["length"] == 1 # You can access matches like most regex groups. matches["sentences"][1]["0"]["text"] == "wrote" matches["sentences"][1]["0"]["$subject"]["text"] == "Chris" matches["sentences"][1]["0"]["$object"]["text"] == "sentence" ================================================ FILE: demo/en_test.conllu.txt ================================================ # newdoc id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200 # sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0001 # newpar id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-p0001 # text = What if Google Morphed Into GoogleOS? 1 What what PRON WP PronType=Int 0 root 0:root _ 2 if if SCONJ IN _ 4 mark 4:mark _ 3 Google Google PROPN NNP Number=Sing 4 nsubj 4:nsubj _ 4 Morphed morph VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 advcl 1:advcl:if _ 5 Into into ADP IN _ 6 case 6:case _ 6 GoogleOS GoogleOS PROPN NNP Number=Sing 4 obl 4:obl:into SpaceAfter=No 7 ? ? PUNCT . _ 4 punct 4:punct _ # sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0002 # text = What if Google expanded on its search-engine (and now e-mail) wares into a full-fledged operating system? 
1 What what PRON WP PronType=Int 0 root 0:root _ 2 if if SCONJ IN _ 4 mark 4:mark _ 3 Google Google PROPN NNP Number=Sing 4 nsubj 4:nsubj _ 4 expanded expand VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 advcl 1:advcl:if _ 5 on on ADP IN _ 15 case 15:case _ 6 its its PRON PRP$ Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 15 nmod:poss 15:nmod:poss _ 7 search search NOUN NN Number=Sing 9 compound 9:compound SpaceAfter=No 8 - - PUNCT HYPH _ 9 punct 9:punct SpaceAfter=No 9 engine engine NOUN NN Number=Sing 15 compound 15:compound _ 10 ( ( PUNCT -LRB- _ 9 punct 9:punct SpaceAfter=No 11 and and CCONJ CC _ 13 cc 13:cc _ 12 now now ADV RB _ 13 advmod 13:advmod _ 13 e-mail e-mail NOUN NN Number=Sing 9 conj 9:conj:and|15:compound SpaceAfter=No 14 ) ) PUNCT -RRB- _ 15 punct 15:punct _ 15 wares wares NOUN NNS Number=Plur 4 obl 4:obl:on _ 16 into into ADP IN _ 22 case 22:case _ 17 a a DET DT Definite=Ind|PronType=Art 22 det 22:det _ 18 full full ADV RB _ 20 advmod 20:advmod SpaceAfter=No 19 - - PUNCT HYPH _ 20 punct 20:punct SpaceAfter=No 20 fledged fledged ADJ JJ Degree=Pos 22 amod 22:amod _ 21 operating operating NOUN NN Number=Sing 22 compound 22:compound _ 22 system system NOUN NN Number=Sing 4 obl 4:obl:into SpaceAfter=No 23 ? ? PUNCT . 
_ 4 punct 4:punct _ # sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0003 # text = [via Microsoft Watch from Mary Jo Foley ] 1 [ [ PUNCT -LRB- _ 4 punct 4:punct SpaceAfter=No 2 via via ADP IN _ 4 case 4:case _ 3 Microsoft Microsoft PROPN NNP Number=Sing 4 compound 4:compound _ 4 Watch Watch PROPN NNP Number=Sing 0 root 0:root _ 5 from from ADP IN _ 6 case 6:case _ 6 Mary Mary PROPN NNP Number=Sing 4 nmod 4:nmod:from _ 7 Jo Jo PROPN NNP Number=Sing 6 flat 6:flat _ 8 Foley Foley PROPN NNP Number=Sing 6 flat 6:flat _ 9 ] ] PUNCT -RRB- _ 4 punct 4:punct _ # newdoc id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700 # sent_id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-0001 # newpar id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-p0001 # text = (And, by the way, is anybody else just a little nostalgic for the days when that was a good thing?) 1 ( ( PUNCT -LRB- _ 14 punct 14:punct SpaceAfter=No 2 And and CCONJ CC _ 14 cc 14:cc SpaceAfter=No 3 , , PUNCT , _ 14 punct 14:punct _ 4 by by ADP IN _ 6 case 6:case _ 5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _ 6 way way NOUN NN Number=Sing 14 obl 14:obl:by SpaceAfter=No 7 , , PUNCT , _ 14 punct 14:punct _ 8 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 14 cop 14:cop _ 9 anybody anybody PRON NN Number=Sing 14 nsubj 14:nsubj _ 10 else else ADJ JJ Degree=Pos 9 amod 9:amod _ 11 just just ADV RB _ 13 advmod 13:advmod _ 12 a a DET DT Definite=Ind|PronType=Art 13 det 13:det _ 13 little little ADJ JJ Degree=Pos 14 obl:npmod 14:obl:npmod _ 14 nostalgic nostalgic NOUN NN Number=Sing 0 root 0:root _ 15 for for ADP IN _ 17 case 17:case _ 16 the the DET DT Definite=Def|PronType=Art 17 det 17:det _ 17 days day NOUN NNS Number=Plur 14 nmod 14:nmod:for|23:obl:npmod _ 18 when when ADV WRB PronType=Rel 23 advmod 17:ref _ 19 that that PRON DT Number=Sing|PronType=Dem 23 nsubj 23:nsubj _ 20 was be AUX VBD 
Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 23 cop 23:cop _ 21 a a DET DT Definite=Ind|PronType=Art 23 det 23:det _ 22 good good ADJ JJ Degree=Pos 23 amod 23:amod _ 23 thing thing NOUN NN Number=Sing 17 acl:relcl 17:acl:relcl SpaceAfter=No 24 ? ? PUNCT . _ 14 punct 14:punct SpaceAfter=No 25 ) ) PUNCT -RRB- _ 14 punct 14:punct _ ================================================ FILE: demo/japanese_test.conllu.txt ================================================ # newdoc id = test-s1 # sent_id = test-s1 # text = これに不快感を示す住民はいましたが,現在,表立って反対や抗議の声を挙げている住民はいないようです。 1 これ 此れ PRON 代名詞 _ 6 obl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=代名詞|SpaceAfter=No|UnidicInfo=,此れ,これ,これ,コレ,,,コレ,コレ,此れ 2 に に ADP 助詞-格助詞 _ 1 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,に,に,に,ニ,,,ニ,ニ,に 3 不快 不快 NOUN 名詞-普通名詞-形状詞可能 _ 4 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,不快,不快,不快,フカイ,,,フカイ,フカイカン,不快感 4 感 感 NOUN 名詞-普通名詞-一般 _ 6 obj _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,感,感,感,カン,,,カン,フカイカン,不快感 5 を を ADP 助詞-格助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,を,を,を,オ,,,ヲ,ヲ,を 6 示す 示す VERB 動詞-一般-五段-サ行 _ 7 acl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-五段-サ行|SpaceAfter=No|UnidicInfo=,示す,示す,示す,シメス,,,シメス,シメス,示す 7 住民 住民 NOUN 名詞-普通名詞-一般 _ 9 nsubj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,住民,住民,住民,ジューミン,,,ジュウミン,ジュウミン,住民 8 は は ADP 助詞-係助詞 _ 7 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は 9 い 居る VERB 動詞-非自立可能-上一段-ア行 _ 29 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,イル,居る 10 
まし ます AUX 助動詞-助動詞-マス _ 9 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-マス|SpaceAfter=No|UnidicInfo=,ます,まし,ます,マシ,,,マス,マス,ます 11 た た AUX 助動詞-助動詞-タ _ 9 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-タ|SpaceAfter=No|UnidicInfo=,た,た,た,タ,,,タ,タ,た 12 が が SCONJ 助詞-接続助詞 _ 9 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-接続助詞|SpaceAfter=No|UnidicInfo=,が,が,が,ガ,,,ガ,ガ,が 13 , , PUNCT 補助記号-読点 _ 9 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,, 14 現在 現在 ADV 名詞-普通名詞-副詞可能 _ 16 advmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=副詞|SpaceAfter=No|UnidicInfo=,現在,現在,現在,ゲンザイ,,,ゲンザイ,ゲンザイ,現在 15 , , PUNCT 補助記号-読点 _ 14 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,, 16 表立っ 表立つ VERB 動詞-一般-五段-タ行 _ 24 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-五段-タ行|SpaceAfter=No|UnidicInfo=,表立つ,表立っ,表立つ,オモテダッ,,,オモテダツ,オモテダツ,表立つ 17 て て SCONJ 助詞-接続助詞 _ 16 mark _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-接続助詞|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テ,て 18 反対 反対 NOUN 名詞-普通名詞-サ変形状詞可能 _ 20 nmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,反対,反対,反対,ハンタイ,,,ハンタイ,ハンタイ,反対 19 や や ADP 助詞-副助詞 _ 18 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-副助詞|SpaceAfter=No|UnidicInfo=,や,や,や,ヤ,,,ヤ,ヤ,や 20 抗議 抗議 NOUN 名詞-普通名詞-サ変可能 _ 22 nmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,抗議,抗議,抗議,コーギ,,,コウギ,コウギ,抗議 21 の の ADP 助詞-格助詞 _ 20 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,の,の,の,ノ,,,ノ,ノ,の 22 声 声 NOUN 名詞-普通名詞-一般 _ 24 obj _ 
BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,声,声,声,コエ,,,コエ,コエ,声 23 を を ADP 助詞-格助詞 _ 22 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,を,を,を,オ,,,ヲ,ヲ,を 24 挙げ 上げる VERB 動詞-非自立可能-下一段-ガ行 _ 27 acl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-下一段-ガ行|SpaceAfter=No|UnidicInfo=,上げる,挙げ,挙げる,アゲ,,,アゲル,アゲル,上げる 25 て て SCONJ 助詞-接続助詞 _ 24 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-上一段-ア行|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テイル,ている 26 いる 居る VERB 動詞-非自立可能-上一段-ア行 _ 25 fixed _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=助動詞-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,いる,いる,イル,,,イル,テイル,ている 27 住民 住民 NOUN 名詞-普通名詞-一般 _ 29 nsubj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,住民,住民,住民,ジューミン,,,ジュウミン,ジュウミン,住民 28 は は ADP 助詞-係助詞 _ 27 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は 29 い 居る VERB 動詞-非自立可能-上一段-ア行 _ 0 root _ BunsetuBILabel=B|BunsetuPositionType=ROOT|LUWBILabel=B|LUWPOS=動詞-一般-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,イル,居る 30 ない ない AUX 助動詞-助動詞-ナイ Polarity=Neg 29 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-ナイ|SpaceAfter=No|UnidicInfo=,ない,ない,ない,ナイ,,,ナイ,ナイ,ない 31 よう 様 AUX 形状詞-助動詞語幹 _ 29 aux _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=形状詞-助動詞語幹|PrevUDLemma=よう|SpaceAfter=No|UnidicInfo=,様,よう,よう,ヨー,,,ヨウ,ヨウ,様 32 です です AUX 助動詞-助動詞-デス _ 29 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-デス|PrevUDLemma=だ|SpaceAfter=No|UnidicInfo=,です,です,です,デス,,,デス,デス,です 33 。 。 PUNCT 補助記号-句点 _ 29 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。 # newdoc id = test-s2 # sent_id = test-s2 # text = 
幸福の科学側からは,特にどうしてほしいという要望はいただいていません。 1 幸福 幸福 NOUN 名詞-普通名詞-形状詞可能 _ 4 nmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,幸福,幸福,幸福,コーフク,,,コウフク,コウフクノカガクガワ,幸福の科学側 2 の の ADP 助詞-格助詞 _ 1 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,の,の,の,ノ,,,ノ,コウフクノカガクガワ,幸福の科学側 3 科学 科学 NOUN 名詞-普通名詞-サ変可能 _ 4 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,科学,科学,科学,カガク,,,カガク,コウフクノカガクガワ,幸福の科学側 4 側 側 NOUN 名詞-普通名詞-一般 _ 17 obl _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,側,側,側,ガワ,,,ガワ,コウフクノカガクガワ,幸福の科学側 5 から から ADP 助詞-格助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,から,から,から,カラ,,,カラ,カラ,から 6 は は ADP 助詞-係助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は 7 , , PUNCT 補助記号-読点 _ 4 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,, 8 特に 特に ADV 副詞 _ 17 advmod _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=副詞|SpaceAfter=No|UnidicInfo=,特に,特に,特に,トクニ,,,トクニ,トクニ,特に 9 どう どう ADV 副詞 _ 15 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,どう,どう,どう,ドー,,,ドウ,ドウスル,どうする 10 し 為る AUX 動詞-非自立可能-サ行変格 _ 9 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,し,する,シ,,,スル,ドウスル,どうする 11 て て SCONJ 助詞-接続助詞 _ 9 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-形容詞|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テホシイ,てほしい 12 ほしい 欲しい AUX 形容詞-非自立可能-形容詞 _ 11 fixed _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=助動詞-形容詞|PrevUDLemma=ほしい|SpaceAfter=No|UnidicInfo=,欲しい,ほしい,ほしい,ホシー,,,ホシイ,テホシイ,てほしい 13 と と ADP 助詞-格助詞 _ 9 case _ 
BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,と,と,と,ト,,,ト,トイウ,という 14 いう 言う VERB 動詞-一般-五段-ワア行 _ 13 fixed _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,言う,いう,いう,イウ,,,イウ,トイウ,という 15 要望 要望 NOUN 名詞-普通名詞-サ変可能 _ 17 nsubj _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,要望,要望,要望,ヨーボー,,,ヨウボウ,ヨウボウ,要望 16 は は ADP 助詞-係助詞 _ 15 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は 17 いただい 頂く VERB 動詞-非自立可能-五段-カ行 _ 0 root _ BunsetuBILabel=B|BunsetuPositionType=ROOT|LUWBILabel=B|LUWPOS=動詞-一般-五段-カ行|PrevUDLemma=いただく|SpaceAfter=No|UnidicInfo=,頂く,いただい,いただく,イタダイ,,,イタダク,イタダク,頂く 18 て て SCONJ 助詞-接続助詞 _ 17 mark _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-上一段-ア行|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テイル,ている 19 い 居る VERB 動詞-非自立可能-上一段-ア行 _ 18 fixed _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=助動詞-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,テイル,ている 20 ませ ます AUX 助動詞-助動詞-マス _ 17 aux _ BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-マス|SpaceAfter=No|UnidicInfo=,ます,ませ,ます,マセ,,,マス,マス,ます 21 ん ず AUX 助動詞-助動詞-ヌ Polarity=Neg 17 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-ヌ|PrevUDLemma=ぬ|SpaceAfter=No|UnidicInfo=,ず,ん,ぬ,ン,,,ヌ,ズ,ず 22 。 。 PUNCT 補助記号-句点 _ 17 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。 # newdoc id = test-s3 # sent_id = test-s3 # text = 星取り参加は当然とされ,不参加は白眼視される。 1 星取り 星取り NOUN 名詞-普通名詞-一般 _ 2 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,星取り,星取り,星取り,ホシトリ,,,ホシトリ,ホシトリサンカ,星取り参加 2 参加 参加 NOUN 名詞-普通名詞-サ変可能 _ 4 nsubj _ 
BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,参加,参加,参加,サンカ,,,サンカ,ホシトリサンカ,星取り参加 3 は は ADP 助詞-係助詞 _ 2 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は 4 当然 当然 ADJ 形状詞-一般 _ 6 advcl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=形状詞-一般|SpaceAfter=No|UnidicInfo=,当然,当然,当然,トーゼン,,,トウゼン,トウゼン,当然 5 と と ADP 助詞-格助詞 _ 4 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,と,と,と,ト,,,ト,ト,と 6 さ 為る VERB 動詞-非自立可能-サ行変格 _ 13 acl _ BunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,さ,する,サ,,,スル,スル,する 7 れ れる AUX 助動詞-助動詞-レル _ 6 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-レル|SpaceAfter=No|UnidicInfo=,れる,れ,れる,レ,,,レル,レル,れる 8 , , PUNCT 補助記号-読点 _ 6 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,,,,,,,,,,, 9 不 不 NOUN 接頭辞 Polarity=Neg 10 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,不,不,不,フ,,,フ,フサンカ,不参加 10 参加 参加 NOUN 名詞-普通名詞-サ変可能 _ 13 nsubj _ BunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,参加,参加,参加,サンカ,,,サンカ,フサンカ,不参加 11 は は ADP 助詞-係助詞 _ 10 case _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は 12 白眼 白眼 NOUN 名詞-普通名詞-一般 _ 13 compound _ BunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,白眼,白眼,白眼,ハクガン,,,ハクガン,ハクガンシスル,白眼視する 13 視 視 NOUN 接尾辞-名詞的-サ変可能 _ 0 root _ BunsetuBILabel=I|BunsetuPositionType=ROOT|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,視,視,視,シ,,,シ,ハクガンシスル,白眼視する 14 さ 為る AUX 動詞-非自立可能-サ行変格 _ 13 aux _ 
BunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,さ,する,サ,,,スル,ハクガンシスル,白眼視する 15 れる れる AUX 助動詞-助動詞-レル _ 13 aux _ BunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-レル|SpaceAfter=No|UnidicInfo=,れる,れる,れる,レル,,,レル,レル,れる 16 。 。 PUNCT 補助記号-句点 _ 13 punct _ BunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。 ================================================ FILE: demo/pipeline_demo.py ================================================ """ A basic demo of the Stanza neural pipeline. """ import sys import argparse import os import stanza from stanza.resources.common import DEFAULT_MODEL_DIR if __name__ == '__main__': # get arguments parser = argparse.ArgumentParser() parser.add_argument('-d', '--models_dir', help='location of models files | default: ~/stanza_resources', default=DEFAULT_MODEL_DIR) parser.add_argument('-l', '--lang', help='Demo language', default="en") parser.add_argument('-c', '--cpu', action='store_true', help='Use cpu as the device.') args = parser.parse_args() example_sentences = {"en": "Barack Obama was born in Hawaii. He was elected president in 2008.", "zh": "中国文化经历上千年的历史演变,是各区域、各民族古代文化长期相互交流、借鉴、融合的结果。", "fr": "Van Gogh grandit au sein d'une famille de l'ancienne bourgeoisie. Il tente d'abord de faire carrière comme marchand d'art chez Goupil & C.", "vi": "Trận Trân Châu Cảng (hay Chiến dịch Hawaii theo cách gọi của Bộ Tổng tư lệnh Đế quốc Nhật Bản) là một đòn tấn công quân sự bất ngờ được Hải quân Nhật Bản thực hiện nhằm vào căn cứ hải quân của Hoa Kỳ tại Trân Châu Cảng thuộc tiểu bang Hawaii vào sáng Chủ Nhật, ngày 7 tháng 12 năm 1941, dẫn đến việc Hoa Kỳ sau đó quyết định tham gia vào hoạt động quân sự trong Chiến tranh thế giới thứ hai."} if args.lang not in example_sentences: print(f'Sorry, but we don\'t have a demo sentence for "{args.lang}" for the moment. 
Try one of these languages: {list(example_sentences.keys())}') sys.exit(1) # download the models stanza.download(args.lang, dir=args.models_dir) # set up a pipeline print('---') print('Building pipeline...') pipeline = stanza.Pipeline(lang=args.lang, dir=args.models_dir, use_gpu=(not args.cpu)) # process the document doc = pipeline(example_sentences[args.lang]) # access nlp annotations print('') print('Input: {}'.format(example_sentences[args.lang])) print("The tokenizer split the input into {} sentences.".format(len(doc.sentences))) print('---') print('tokens of first sentence: ') doc.sentences[0].print_tokens() print('') print('---') print('dependency parse of first sentence: ') doc.sentences[0].print_dependencies() print('') ================================================ FILE: demo/scenegraph.py ================================================ """ Very short demo for the SceneGraph interface in the CoreNLP server Requires CoreNLP >= 4.5.5, Stanza >= 1.5.1 """ import json from stanza.server import CoreNLPClient # start_server=None if you have the server running in another process on the same host # you can start it with whatever normal options CoreNLPClient has # # preload=False avoids having the server unnecessarily load annotators # if you don't plan on using them with CoreNLPClient(preload=False) as client: result = client.scenegraph("Jennifer's antennae are on her head.") print(json.dumps(result, indent=2)) ================================================ FILE: demo/semgrex visualization.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2787d5f5", "metadata": {}, "outputs": [], "source": [ "import stanza\n", "from stanza.server.semgrex import Semgrex\n", "from stanza.models.common.constant import is_right_to_left\n", "import spacy\n", "from spacy import displacy\n", "from spacy.tokens import Doc\n", "from IPython.display import display, HTML\n", "\n", "\n", "\"\"\"\n", "IMPORTANT: For 
the code in this module to run, you must have corenlp and Java installed on your machine. Additionally,\n", "set an environment variable CLASSPATH equal to the path of your corenlp directory.\n", "\n", "Example: CLASSPATH=C:\\\\Users\\\\Alex\\\\PycharmProjects\\\\pythonProject\\\\stanford-corenlp-4.5.0\\\\stanford-corenlp-4.5.0\\\\*\n", "\"\"\"\n", "\n", "%env CLASSPATH=C:\\\\stanford-corenlp-4.5.2\\\\stanford-corenlp-4.5.2\\\\*\n", "def get_sentences_html(doc, language):\n", " \"\"\"\n", " Returns a list of the HTML strings of the dependency visualizations of a given stanza doc object.\n", "\n", " The 'language' arg is the two-letter language code for the document to be processed.\n", "\n", " First converts the stanza doc object to a spacy doc object and uses displacy to generate an HTML\n", " string for each sentence of the doc object.\n", " \"\"\"\n", " html_strings = []\n", "\n", " # blank model - we don't use any of the model features, just the visualization\n", " nlp = spacy.blank(\"en\")\n", " sentences_to_visualize = []\n", " for sentence in doc.sentences:\n", " words, lemmas, heads, deps, tags = [], [], [], [], []\n", " if is_right_to_left(language): # order of words displayed is reversed, dependency arcs remain intact\n", " sent_len = len(sentence.words)\n", " for word in reversed(sentence.words):\n", " words.append(word.text)\n", " lemmas.append(word.lemma)\n", " deps.append(word.deprel)\n", " tags.append(word.upos)\n", " if word.head == 0: # spaCy head indexes are formatted differently than that of Stanza\n", " heads.append(sent_len - word.id)\n", " else:\n", " heads.append(sent_len - word.head)\n", " else: # left to right rendering\n", " for word in sentence.words:\n", " words.append(word.text)\n", " lemmas.append(word.lemma)\n", " deps.append(word.deprel)\n", " tags.append(word.upos)\n", " if word.head == 0:\n", " heads.append(word.id - 1)\n", " else:\n", " heads.append(word.head - 1)\n", " document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, 
heads=heads, deps=deps, pos=tags)\n", " sentences_to_visualize.append(document_result)\n", "\n", " for line in sentences_to_visualize: # render all sentences through displaCy\n", " html_strings.append(displacy.render(line, style=\"dep\",\n", " options={\"compact\": True, \"word_spacing\": 30, \"distance\": 100,\n", " \"arrow_spacing\": 20}, jupyter=False))\n", " return html_strings\n", "\n", "\n", "def find_nth(haystack, needle, n):\n", " \"\"\"\n", " Returns the starting index of the nth occurrence of the substring 'needle' in the string 'haystack'.\n", " \"\"\"\n", " start = haystack.find(needle)\n", " while start >= 0 and n > 1:\n", " start = haystack.find(needle, start + len(needle))\n", " n -= 1\n", " return start\n", "\n", "\n", "def round_base(num, base=10):\n", " \"\"\"\n", " Rounding a number to its nearest multiple of the base. round_base(49.2, base=50) = 50.\n", " \"\"\"\n", " return base * round(num/base)\n", "\n", "\n", "def process_sentence_html(orig_html, semgrex_sentence):\n", " \"\"\"\n", " Takes a semgrex sentence object and modifies the HTML of the original sentence's deprel visualization,\n", " highlighting words involved in the search queries and adding the label of the word inside of the semgrex match.\n", "\n", " Returns the modified html string of the sentence's deprel visualization.\n", " \"\"\"\n", " tracker = {} # keep track of which words have multiple labels\n", " DEFAULT_TSPAN_COUNT = 2 # the original displacy html assigns two objects per object\n", " CLOSING_TSPAN_LEN = 8 # is 8 chars long\n", " colors = ['#4477AA', '#66CCEE', '#228833', '#CCBB44', '#EE6677', '#AA3377', '#BBBBBB']\n", " css_bolded_class = \"\\n\"\n", " found_index = orig_html.find(\"\\n\") # returns index where the opening ends\n", " # insert the new style class into html string\n", " orig_html = orig_html[: found_index + 1] + css_bolded_class + orig_html[found_index + 1:]\n", "\n", " # Add color to words in the match, bold words in the match\n", " for query in 
semgrex_sentence.result:\n", " for i, match in enumerate(query.match):\n", " color = colors[i]\n", " paired_dy = 2\n", " for node in match.node:\n", " name, match_index = node.name, node.matchIndex\n", " # edit existing to change color and bold the text\n", " start = find_nth(orig_html, \" of interest\n", " if match_index not in tracker: # if we've already bolded and colored, keep the first color\n", " tspan_start = orig_html.find(\" inside of the \n", " tspan_end = orig_html.find(\"\", start) # finds start of the end of the above \n", " tspan_substr = orig_html[tspan_start: tspan_end + CLOSING_TSPAN_LEN + 1] + \"\\n\"\n", " # color words in the hit and bold words in the hit\n", " edited_tspan = tspan_substr.replace('class=\"displacy-word\"', 'class=\"bolded\"').replace(\n", " 'fill=\"currentColor\"', f'fill=\"{color}\"')\n", " # insert edited object into html string\n", " orig_html = orig_html[: tspan_start] + edited_tspan + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2:]\n", " tracker[match_index] = DEFAULT_TSPAN_COUNT\n", "\n", " # next, we have to insert the new object for the label\n", " # Copy old to copy formatting when creating new later\n", " prev_tspan_start = find_nth(orig_html[start:], \" start index\n", " prev_tspan_end = find_nth(orig_html[start:], \"\",\n", " tracker[match_index] - 1) + start # find the prev start index\n", " prev_tspan = orig_html[prev_tspan_start: prev_tspan_end + CLOSING_TSPAN_LEN + 1]\n", "\n", " # Find spot to insert new tspan\n", " closing_tspan_start = find_nth(orig_html[start:], \"\", tracker[match_index]) + start\n", " up_to_new_tspan = orig_html[: closing_tspan_start + CLOSING_TSPAN_LEN + 1]\n", " rest_need_add_newline = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1:]\n", "\n", " # Calculate proper x value in svg\n", " x_value_start = prev_tspan.find('x=\"')\n", " x_value_end = prev_tspan[x_value_start + 3:].find('\"') + 3 # 3 is the length of the 'x=\"' substring\n", " x_value = prev_tspan[x_value_start + 3: 
x_value_end + x_value_start]\n", "\n", " # Calculate proper y value in svg\n", " DEFAULT_DY_VAL, dy = 2, 2\n", " if paired_dy != DEFAULT_DY_VAL and node == match.node[\n", " 1]: # we're on the second node and need to adjust height to match the paired node\n", " dy = paired_dy\n", " if node == match.node[0]:\n", " paired_node_level = 2\n", " if match.node[1].matchIndex in tracker: # check if we need to adjust heights of labels\n", " paired_node_level = tracker[match.node[1].matchIndex]\n", " dif = tracker[match_index] - paired_node_level\n", " if dif > 0: # current node has more labels\n", " paired_dy = DEFAULT_DY_VAL * dif + 1\n", " dy = DEFAULT_DY_VAL\n", " else: # paired node has more labels, adjust this label down\n", " dy = DEFAULT_DY_VAL * (abs(dif) + 1)\n", " paired_dy = DEFAULT_DY_VAL\n", "\n", " # Insert new object\n", " new_tspan = f' {name[: 3].title()}.\\n' # abbreviate label names to 3 chars\n", " orig_html = up_to_new_tspan + new_tspan + rest_need_add_newline\n", " tracker[match_index] += 1\n", " return orig_html\n", "\n", "\n", "def render_html_strings(edited_html_strings):\n", " \"\"\"\n", " Renders the HTML to make the edits visible\n", " \"\"\"\n", " for html_string in edited_html_strings:\n", " display(HTML(html_string))\n", "\n", "\n", "def visualize_search_doc(doc, semgrex_queries, lang_code, start_match=0, end_match=10):\n", " \"\"\"\n", " Visualizes the semgrex results of running semgrex search on a stanza doc object with the given list of\n", " semgrex queries. Returns a list of the edited HTML strings from the doc. Each element in the list represents\n", " the HTML to render one of the sentences in the document.\n", "\n", " 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n", "\n", "\n", " 'start_match' and 'end_match' determine which matches to visualize. 
Works similar to splices, so that\n", " start_match=0 and end_match=10 will display the first 10 semgrex matches.\n", " \"\"\"\n", " matches_count = 0 # Limits number of visualizations\n", " with Semgrex(classpath=\"$CLASSPATH\") as sem:\n", " edited_html_strings = []\n", " semgrex_results = sem.process(doc, *semgrex_queries)\n", " # one html string for each sentence\n", " unedited_html_strings = get_sentences_html(doc, lang_code)\n", " for i in range(len(unedited_html_strings)):\n", "\n", " if matches_count >= end_match: # we've collected enough matches, stop early\n", " break\n", "\n", " # check if sentence has matches, if not then do not visualize\n", " has_none = True\n", " for query in semgrex_results.result[i].result:\n", " for match in query.match:\n", " if match:\n", " has_none = False\n", "\n", " # Process HTML if queries have matches\n", " if not has_none:\n", " if start_match <= matches_count < end_match:\n", " edited_string = process_sentence_html(unedited_html_strings[i], semgrex_results.result[i])\n", " edited_string = adjust_dep_arrows(edited_string)\n", " edited_html_strings.append(edited_string)\n", " matches_count += 1\n", "\n", " render_html_strings(edited_html_strings)\n", " return edited_html_strings\n", "\n", "\n", "def visualize_search_str(text, semgrex_queries, lang_code):\n", " \"\"\"\n", " Visualizes the deprel of the semgrex results from running semgrex search on a string with the given list of\n", " semgrex queries. Returns a list of the edited HTML strings. 
Each element in the list represents\n", " the HTML to render one of the sentences in the document.\n", "\n", " Internally, this function converts the string into a stanza doc object before processing the doc object.\n", "\n", " 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n", " \"\"\"\n", " nlp = stanza.Pipeline(lang_code, processors=\"tokenize, pos, lemma, depparse\")\n", " doc = nlp(text)\n", " return visualize_search_doc(doc, semgrex_queries, lang_code)\n", "\n", "\n", "def adjust_dep_arrows(raw_html):\n", " \"\"\"\n", " The default spaCy dependency visualization has misaligned arrows.\n", " We fix arrows by aligning arrow ends and bodies to the word that they are directed to. If a word has an\n", " arrowhead that is pointing not directly on the word's center, align the arrowhead to match the center of the word.\n", "\n", " returns the edited html with fixed arrow placement\n", " \"\"\"\n", " HTML_ARROW_BEGINNING = ''\n", " HTML_ARROW_ENDING = \"\"\n", " HTML_ARROW_ENDING_LEN = 6 # there are 2 newline chars after the arrow ending\n", " arrows_start_idx = find_nth(haystack=raw_html, needle='', n=1)\n", " words_html, arrows_html = raw_html[: arrows_start_idx], raw_html[arrows_start_idx:] # separate html for words and arrows\n", " final_html = words_html # continually concatenate to this after processing each arrow\n", " arrow_number = 1 # which arrow we're editing (1-indexed)\n", " start_idx, end_of_class_idx = find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number), find_nth(arrows_html, HTML_ARROW_ENDING, arrow_number)\n", " while start_idx != -1: # edit every arrow\n", " arrow_section = arrows_html[start_idx: end_of_class_idx + HTML_ARROW_ENDING_LEN] # slice a single svg arrow object\n", " if arrow_section[-1] == \"<\": # this is the last arrow in the HTML, don't cut the splice early\n", " arrow_section = arrows_html[start_idx:]\n", " edited_arrow_section = 
edit_dep_arrow(arrow_section)\n", "\n", " final_html = final_html + edited_arrow_section # continually update html with new arrow html until done\n", "\n", " # Prepare for next iteration\n", " arrow_number += 1\n", " start_idx = find_nth(arrows_html, '', n=arrow_number)\n", " end_of_class_idx = find_nth(arrows_html, \"\", arrow_number)\n", " return final_html\n", "\n", "\n", "def edit_dep_arrow(arrow_html):\n", " \"\"\"\n", " The formatting of a displacy arrow in svg is the following:\n", " \n", " \n", " \n", " csubj\n", " \n", " \n", " \n", "\n", " We edit the 'd = ...' parts of the section to fix the arrow direction and length\n", "\n", " returns the arrow_html with distances fixed\n", " \"\"\"\n", " WORD_SPACING = 50 # words start at x=50 and are separated by 100s so their x values are multiples of 50\n", " M_OFFSET = 4 # length of 'd=\"M' that we search for to extract the number from d=\"M70, for instance\n", " ARROW_PIXEL_SIZE = 4\n", " first_d_idx, second_d_idx = find_nth(arrow_html, 'd=\"M', 1), find_nth(arrow_html, 'd=\"M', 2) # find where d=\"M starts\n", " first_d_cutoff, second_d_cutoff = arrow_html.find(\",\", first_d_idx), arrow_html.find(\",\", second_d_idx) # isolate the number after 'M' e.g. 'M70'\n", " # gives svg x values of arrow body starting position and arrowhead position\n", " arrow_position, arrowhead_position = float(arrow_html[first_d_idx + M_OFFSET: first_d_cutoff]), float(arrow_html[second_d_idx + M_OFFSET: second_d_cutoff])\n", " # gives starting index of where 'fill=\"none\"' or 'fill=\"currentColor\"' begin, reference points to end the d= section\n", " first_fill_start_idx, second_fill_start_idx = find_nth(arrow_html, \"fill\", n=1), find_nth(arrow_html, \"fill\", n=3)\n", "\n", " # isolate the d= ... 
section to edit\n", " first_d, second_d = arrow_html[first_d_idx: first_fill_start_idx], arrow_html[second_d_idx: second_fill_start_idx]\n", " first_d_split, second_d_split = first_d.split(\",\"), second_d.split(\",\")\n", "\n", " if arrow_position == arrowhead_position: # This arrow is incoming onto the word, center the arrow/head to word center\n", " corrected_arrow_pos = corrected_arrowhead_pos = round_base(arrow_position, base=WORD_SPACING)\n", "\n", " # edit first_d -- arrow body\n", " second_term = first_d_split[1].split(\" \")[0] + \" \" + str(corrected_arrow_pos)\n", " first_d = 'd=\"M' + str(corrected_arrow_pos) + \",\" + second_term + \",\" + \",\".join(first_d_split[2:])\n", "\n", " # edit second_d -- arrowhead\n", " second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n", " third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n", " second_d = 'd=\"M' + str(corrected_arrowhead_pos) + \",\" + second_term + \",\" + third_term + \",\" + \",\".join(second_d_split[3:])\n", " else: # This arrow is outgoing to another word, center the arrow/head to that word's center\n", " corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING)\n", "\n", " # edit first_d -- arrow body\n", " third_term = first_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n", " fourth_term = first_d_split[3].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n", " terms = [first_d_split[0], first_d_split[1], third_term, fourth_term] + first_d_split[4:]\n", " first_d = \",\".join(terms)\n", "\n", " # edit second_d -- arrow head\n", " first_term = f'd=\"M{corrected_arrowhead_pos}'\n", " second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n", " third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n", " terms = [first_term, second_term, 
third_term] + second_d_split[3:]\n", " second_d = \",\".join(terms)\n", " # rebuild and return html\n", " return arrow_html[:first_d_idx] + first_d + \" \" + arrow_html[first_fill_start_idx:second_d_idx] + second_d + \" \" + arrow_html[second_fill_start_idx:]\n", "\n", "\n", "def main():\n", " nlp = stanza.Pipeline(\"en\", processors=\"tokenize,pos,lemma,depparse\")\n", "\n", " # doc = nlp(\"This a dummy sentence. Banning opal removed all artifact decks from the meta. I miss playing lantern. This is a dummy sentence.\")\n", " doc = nlp(\"Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people.\")\n", " # A single result .result[i].result[j] is a list of matches for sentence i on semgrex query j.\n", " queries = [\"{pos:NN}=object en_pronouns.updated.conllu # This script updates the UD 2.11 version of UD_English-Pronouns to # better match punctuation attachments, MWT, and no double subjects. # This turns unwanted csubj into advcl {}=source >nsubj {} >csubj=bad {} relabelNamedEdge -edge bad -reln advcl # This detects punctuations which are not attached to the root and reattaches them {word:/[.]/}=punct =3.15.0', 'requests', 'networkx', 'tomli;python_version<"3.11"', 'torch>=1.13.0', 'tqdm', 'udtools>=0.2.4', ], # List required Python versions python_requires='>=3.9', # List additional groups of dependencies here (e.g. development # dependencies). You can install these using the following syntax, # for example: # $ pip install -e .[dev,test] extras_require={ 'dev': [ 'check-manifest', ], 'test': [ 'coverage', 'pytest', ], 'transformers': [ 'transformers>=3.0.0', 'peft>=0.6.1', ], 'datasets': [ 'datasets', ], 'tokenizers': [ 'jieba', 'pythainlp', 'python-crfsuite', 'spacy', 'sudachidict_core', 'sudachipy', ], 'visualization': [ 'spacy', 'streamlit', 'ipython', ], 'morphseg': [ 'morphseg>=0.2.0', ] }, # If there are data files included in your packages that need to be # installed, specify them here. 
If using Python 2.6 or less, then these # have to be included in MANIFEST.in as well. package_data={ "": ["pipeline/demo/*ttf", "pipeline/demo/*css", "pipeline/demo/*html", "pipeline/demo/*js", "pipeline/demo/*gif",], }, include_package_data=True, # Although 'package_data' is the preferred approach, in some case you may # need to place data files outside of your packages. See: # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa # In this case, 'data_file' will be installed into '/my_data' data_files=[], # To provide executable scripts, use entry points in preference to the # "scripts" keyword. Entry points provide cross-platform support and allow # pip to create the appropriate form of executable for the target platform. entry_points={ }, ) ================================================ FILE: stanza/__init__.py ================================================ from stanza.pipeline.core import DownloadMethod, Pipeline from stanza.pipeline.multilingual import MultilingualPipeline from stanza.models.common.doc import Document from stanza.resources.common import download from stanza.resources.installation import install_corenlp, download_corenlp_models from stanza._version import __version__, __resources_version__ from stanza.pipeline.morphseg_processor import MorphSegProcessor import logging logger = logging.getLogger('stanza') # if the client application hasn't set the log level, we set it # ourselves to INFO if logger.level == 0: logger.setLevel(logging.INFO) log_handler = logging.StreamHandler() log_formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s", datefmt='%Y-%m-%d %H:%M:%S') log_handler.setFormatter(log_formatter) # also, if the client hasn't added any handlers for this logger # (or a default handler), we add a handler of our own # # client can later do # logger.removeHandler(stanza.log_handler) if not logger.hasHandlers(): logger.addHandler(log_handler) 
================================================ FILE: stanza/_version.py ================================================ """ Single source of truth for version number """ __version__ = "1.11.1" __resources_version__ = '1.11.0' ================================================ FILE: stanza/models/__init__.py ================================================ ================================================ FILE: stanza/models/_training_logging.py ================================================ import logging logger = logging.getLogger('stanza') logger.setLevel(logging.DEBUG) ================================================ FILE: stanza/models/charlm.py ================================================ """ Entry point for training and evaluating a character-level neural language model. """ import argparse from copy import copy import logging import lzma import math import os import random import time from types import GeneratorType import numpy as np import torch from stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel, CharacterLanguageModelTrainer from stanza.models.common.vocab import CharVocab from stanza.models.common import utils from stanza.models import _training_logging logger = logging.getLogger('stanza') def repackage_hidden(h): """Wraps hidden states in new Tensors, to detach them from their history.""" if isinstance(h, torch.Tensor): return h.detach() else: return tuple(repackage_hidden(v) for v in h) def batchify(data, bsz, device): # Work out how cleanly we can divide the dataset into bsz parts. nbatch = data.size(0) // bsz # Trim off any extra elements that wouldn't cleanly fit (remainders). data = data.narrow(0, 0, nbatch * bsz) # Evenly divide the data across the bsz batches. 
data = data.view(bsz, -1) # batch_first is True data = data.to(device) return data def get_batch(source, i, seq_len): seq_len = min(seq_len, source.size(1) - 1 - i) data = source[:, i:i+seq_len] target = source[:, i+1:i+1+seq_len].reshape(-1) return data, target def load_file(filename, vocab, direction): with utils.open_read_text(filename) as fin: data = fin.read() idx = vocab['char'].map(data) if direction == 'backward': idx = idx[::-1] return torch.tensor(idx) def load_data(path, vocab, direction): if os.path.isdir(path): filenames = sorted(os.listdir(path)) for filename in filenames: logger.info('Loading data from {}'.format(filename)) data = load_file(os.path.join(path, filename), vocab, direction) yield data else: data = load_file(path, vocab, direction) yield data def build_argparse(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--train_file', type=str, help="Input plaintext file") parser.add_argument('--train_dir', type=str, help="If non-empty, load from directory with multiple training files") parser.add_argument('--eval_file', type=str, help="Input plaintext file for the dev/test set") parser.add_argument('--shorthand', type=str, help="UD treebank shorthand") parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model") parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help="Train a forward language model") parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help="Train a backward language model") parser.add_argument('--char_emb_dim', type=int, default=100, help="Dimension of unit embeddings") parser.add_argument('--char_hidden_dim', type=int, default=1024, help="Dimension of hidden units") parser.add_argument('--char_num_layers', type=int, default=1, help="Layers of RNN 
in the language model") parser.add_argument('--char_dropout', type=float, default=0.05, help="Dropout probability") parser.add_argument('--char_unit_dropout', type=float, default=1e-5, help="Randomly set an input char to UNK during training") parser.add_argument('--char_rec_dropout', type=float, default=0.0, help="Recurrent dropout probability") parser.add_argument('--batch_size', type=int, default=100, help="Batch size to use") parser.add_argument('--bptt_size', type=int, default=250, help="Sequence length to consider at a time") parser.add_argument('--epochs', type=int, default=50, help="Total epochs to train the model for") parser.add_argument('--max_grad_norm', type=float, default=0.25, help="Maximum gradient norm to clip to") parser.add_argument('--lr0', type=float, default=5, help="Initial learning rate") parser.add_argument('--anneal', type=float, default=0.25, help="Anneal the learning rate by this amount when dev performance deteriorate") parser.add_argument('--patience', type=int, default=1, help="Patience for annealing the learning rate") parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay") parser.add_argument('--momentum', type=float, default=0.0, help='Momentum for SGD.') parser.add_argument('--cutoff', type=int, default=1000, help="Frequency cutoff for char vocab. 
By default we assume a very large corpus.") parser.add_argument('--report_steps', type=int, default=50, help="Update step interval to report loss") parser.add_argument('--eval_steps', type=int, default=100000, help="Update step interval to run eval on dev; set to -1 to eval after each epoch") parser.add_argument('--save_name', type=str, default=None, help="File name to save the model") parser.add_argument('--vocab_save_name', type=str, default=None, help="File name to save the vocab") parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint") parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints") parser.add_argument('--save_dir', type=str, default='saved_models/charlm', help="Directory to save models in") parser.add_argument('--summary', action='store_true', help='Use summary writer to record progress.') utils.add_device_args(parser) parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. 
Will default to the dataset short name') return parser def build_model_filename(args): if args['save_name']: save_name = args['save_name'] else: save_name = '{}_{}_charlm.pt'.format(args['shorthand'], args['direction']) model_file = os.path.join(args['save_dir'], save_name) return model_file def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) if args.wandb_name: args.wandb = True args = vars(args) return args def main(args=None): args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running {} character-level language model in {} mode".format(args['direction'], args['mode'])) utils.ensure_dir(args['save_dir']) if args['mode'] == 'train': train(args) else: evaluate(args) def evaluate_epoch(args, vocab, data, model, criterion): """ Run an evaluation over entire dataset. """ model.eval() device = next(model.parameters()).device hidden = None total_loss = 0 if isinstance(data, GeneratorType): data = list(data) assert len(data) == 1, 'Only support single dev/test file' data = data[0] batches = batchify(data, args['batch_size'], device) with torch.no_grad(): for i in range(0, batches.size(1) - 1, args['bptt_size']): data, target = get_batch(batches, i, args['bptt_size']) lens = [data.size(1) for i in range(data.size(0))] output, hidden, decoded = model.forward(data, lens, hidden) loss = criterion(decoded.view(-1, len(vocab['char'])), target) hidden = repackage_hidden(hidden) total_loss += data.size(1) * loss.data.item() return total_loss / batches.size(1) def evaluate_and_save(args, vocab, data, trainer, best_loss, model_file, checkpoint_file, writer=None): """ Run an evaluation over entire dataset, print progress and save the model if necessary. 
""" start_time = time.time() loss = evaluate_epoch(args, vocab, data, trainer.model, trainer.criterion) ppl = math.exp(loss) elapsed = int(time.time() - start_time) # TODO: step the scheduler less often when the eval frequency is higher previous_lr = get_current_lr(trainer, args) trainer.scheduler.step(loss) current_lr = get_current_lr(trainer, args) if previous_lr != current_lr: logger.info("Updating learning rate to %f", current_lr) logger.info( "| eval checkpoint @ global step {:10d} | time elapsed {:6d}s | loss {:5.2f} | ppl {:8.2f}".format( trainer.global_step, elapsed, loss, ppl, ) ) if best_loss is None or loss < best_loss: best_loss = loss trainer.save(model_file, full=False) logger.info('new best model saved at step {:10d}'.format(trainer.global_step)) if writer: writer.add_scalar('dev_loss', loss, global_step=trainer.global_step) writer.add_scalar('dev_ppl', ppl, global_step=trainer.global_step) if checkpoint_file: trainer.save(checkpoint_file, full=True) logger.info('new checkpoint saved at step {:10d}'.format(trainer.global_step)) return loss, ppl, best_loss def get_current_lr(trainer, args): return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0] def load_char_vocab(vocab_file): return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))} def train(args): utils.log_training_args(args, logger) model_file = build_model_filename(args) vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \ else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand']) if args['checkpoint']: checkpoint_file = utils.checkpoint_name(args['save_dir'], model_file, args['checkpoint_save_name']) else: checkpoint_file = None if os.path.exists(vocab_file): logger.info('Loading existing vocab file') vocab = load_char_vocab(vocab_file) else: logger.info('Building and saving vocab') vocab = {'char': build_charlm_vocab(args['train_file'] if args['train_dir'] is None 
else args['train_dir'], cutoff=args['cutoff'])} torch.save(vocab['char'].state_dict(), vocab_file) logger.info("Training model with vocab size: {}".format(len(vocab['char']))) if checkpoint_file and os.path.exists(checkpoint_file): logger.info('Loading existing checkpoint: %s' % checkpoint_file) trainer = CharacterLanguageModelTrainer.load(args, checkpoint_file, finetune=True) else: trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab) writer = None if args['summary']: from torch.utils.tensorboard import SummaryWriter summary_dir = '{}/{}_summary'.format(args['save_dir'], args['save_name']) if args['save_name'] is not None \ else '{}/{}_{}_charlm_summary'.format(args['save_dir'], args['shorthand'], args['direction']) writer = SummaryWriter(log_dir=summary_dir) # evaluate model within epoch if eval_interval is set eval_within_epoch = False if args['eval_steps'] > 0: eval_within_epoch = True if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else '%s_%s_charlm' % (args['shorthand'], args['direction']) wandb.init(name=wandb_name, config=args) wandb.run.define_metric('best_loss', summary='min') wandb.run.define_metric('ppl', summary='min') device = next(trainer.model.parameters()).device best_loss = None start_epoch = trainer.epoch # will default to 1 for a new trainer for trainer.epoch in range(start_epoch, args['epochs']+1): # load train data from train_dir if not empty, otherwise load from file if args['train_dir'] is not None: train_path = args['train_dir'] else: train_path = args['train_file'] train_data = load_data(train_path, vocab, args['direction']) dev_data = load_file(args['eval_file'], vocab, args['direction']) # dev must be a single file # run over entire training set for data_chunk in train_data: batches = batchify(data_chunk, args['batch_size'], device) hidden = None total_loss = 0.0 total_batches = math.ceil((batches.size(1) - 1) / args['bptt_size']) iteration, i = 0, 0 # over the data chunk while i < 
batches.size(1) - 1 - 1: trainer.model.train() trainer.global_step += 1 start_time = time.time() bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size']/ 2. # prevent excessively small or negative sequence lengths seq_len = max(5, int(np.random.normal(bptt, 5))) # prevent very large sequence length, must be <= 1.2 x bptt seq_len = min(seq_len, int(args['bptt_size'] * 1.2)) data, target = get_batch(batches, i, seq_len) lens = [data.size(1) for i in range(data.size(0))] trainer.optimizer.zero_grad() output, hidden, decoded = trainer.model.forward(data, lens, hidden) loss = trainer.criterion(decoded.view(-1, len(vocab['char'])), target) total_loss += loss.data.item() loss.backward() torch.nn.utils.clip_grad_norm_(trainer.params, args['max_grad_norm']) trainer.optimizer.step() hidden = repackage_hidden(hidden) if (iteration + 1) % args['report_steps'] == 0: cur_loss = total_loss / args['report_steps'] elapsed = time.time() - start_time logger.info( "| epoch {:5d} | {:5d}/{:5d} batches | sec/batch {:.6f} | loss {:5.2f} | ppl {:8.2f}".format( trainer.epoch, iteration + 1, total_batches, elapsed / args['report_steps'], cur_loss, math.exp(cur_loss), ) ) if args['wandb']: wandb.log({'train_loss': cur_loss}, step=trainer.global_step) total_loss = 0.0 iteration += 1 i += seq_len # evaluate if necessary if eval_within_epoch and trainer.global_step % args['eval_steps'] == 0: _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer) if args['wandb']: wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step) # if eval_interval isn't provided, run evaluation after each epoch if not eval_within_epoch or trainer.epoch == args['epochs']: _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer) if args['wandb']: wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, 
args)}, step=trainer.global_step) if writer: writer.close() if args['wandb']: wandb.finish() return def evaluate(args): model_file = build_model_filename(args) model = CharacterLanguageModel.load(model_file).to(args['device']) vocab = model.vocab data = load_data(args['eval_file'], vocab, args['direction']) criterion = torch.nn.CrossEntropyLoss() loss = evaluate_epoch(args, vocab, data, model, criterion) logger.info( "| best model | loss {:5.2f} | ppl {:8.2f}".format( loss, math.exp(loss), ) ) return if __name__ == '__main__': main() ================================================ FILE: stanza/models/classifier.py ================================================ import argparse import ast import logging import os import random import re from enum import Enum import torch import torch.nn as nn from stanza.models.common import loss from stanza.models.common import utils from stanza.models.pos.vocab import CharVocab import stanza.models.classifiers.data as data from stanza.models.classifiers.trainer import Trainer from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType from stanza.models.common.peft_config import add_peft_args, resolve_peft_args from stanza.utils.confusion import format_confusion, confusion_to_accuracy, confusion_to_macro_f1 class Loss(Enum): CROSS = 1 WEIGHTED_CROSS = 2 LOG_CROSS = 3 FOCAL = 4 class DevScoring(Enum): ACCURACY = 'ACC' WEIGHTED_F1 = 'WF' logger = logging.getLogger('stanza') tlogger = logging.getLogger('stanza.classifiers.trainer') logging.getLogger('elmoformanylangs').setLevel(logging.WARNING) DEFAULT_TRAIN='data/sentiment/en_sstplus.train.txt' DEFAULT_DEV='data/sentiment/en_sst3roots.dev.txt' DEFAULT_TEST='data/sentiment/en_sst3roots.test.txt' """A script for training and testing classifier models, especially on the SST. If you run the script with no arguments, it will start trying to train a sentiment model. 
python3 -m stanza.models.classifier This requires the sentiment dataset to be in an `extern_data` directory, such as by symlinking it from somewhere else. The default model is a CNN where the word vectors are first mapped to channels with filters of a few different widths, those channels are maxpooled over the entire sentence, and then the resulting pools have fully connected layers until they reach the number of classes in the training data. You can see the defaults in the options below. https://arxiv.org/abs/1408.5882 (Currently the CNN is the only sentence classifier implemented.) To train with a more complicated CNN arch: nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 > FC41.out 2>&1 & You can train models with word vectors other than the default word2vec. For example: nohup python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --max_epochs 200 --filter_channels 1000 --fc_shapes 200,100 --base_name FC21_google > FC21_google.out 2>&1 & A model trained on the 5 class dataset can be tested on the 2 class dataset with a command line like this: python3 -u -m stanza.models.classifier --no_train --load_name saved_models/classifier/sst_en_ewt_FS_3_4_5_C_1000_FC_400_100_classifier.E0165-ACC41.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 1:0, 3:1, 4:1}" python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0189-ACC45.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 1:0, 3:1, 4:1}" A model trained on the 3 class dataset can be tested on the 2 class dataset with a command line like this: python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name 
saved_models/classifier/FC21_3C_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0101-ACC68.94.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 2:1}" To train models on combined 3 class datasets: nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file data/sentiment/en_sstplus.train.txt --dev_file data/sentiment/en_sst3roots.dev.txt --test_file data/sentiment/en_sst3roots.test.txt > FC41_3class.out 2>&1 & This tests that model: python3 -u -m stanza.models.classifier --no_train --load_name en_sstplus.pt --test_file data/sentiment/en_sst3roots.test.txt Here is an example for training a model in a different language: nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german --train_file data/sentiment/de_sb10k.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 & This uses more data, although that wound up being worse for the German model: nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german --train_file data/sentiment/de_sb10k.train.txt,data/sentiment/de_scare.train.txt,data/sentiment/de_usage.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 & nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_chinese --train_file data/sentiment/zh_ren.train.txt --dev_file data/sentiment/zh_ren.dev.txt --test_file data/sentiment/zh_ren.test.txt --shorthand zh_ren --wordvec_type 
fasttext --extra_wordvec_method SUM --wordvec_pretrain_file ../stanza_resources/zh-hans/pretrain/gsdsimp.pt > zh_ren.out 2>&1 & nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --save_name vi_vsfc.pt --train_file data/sentiment/vi_vsfc.train.json --dev_file data/sentiment/vi_vsfc.dev.json --test_file data/sentiment/vi_vsfc.test.json --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --extra_wordvec_method SUM --dev_eval_scoring WEIGHTED_F1 > vi_vsfc.out 2>&1 & python3 -u -m stanza.models.classifier --no_train --test_file extern_data/sentiment/vietnamese/_UIT-VSFC/test.txt --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --load_name vi_vsfc.pt """ def convert_fc_shapes(arg): """ Returns a tuple of sizes to use in FC layers. For examples, converts "100" -> (100,) "100,200" -> (100,200) """ arg = arg.strip() if not arg: return () arg = ast.literal_eval(arg) if isinstance(arg, int): return (arg,) if isinstance(arg, tuple): return arg return tuple(arg) # For the most part, these values are for the constituency parser. 
# Only the WD for adadelta is originally for sentiment # Also LR for adadelta and madgrad # madgrad learning rate experiment on sstplus # note that the hyperparameters are not cross-validated in tandem, so # later changes may make some earlier experiments slightly out of date # LR # 0.01 failed to converge # 0.004 failed to converge # 0.003 0.5572 # 0.002 failed to converge # 0.001 0.6857 # 0.0008 0.6799 # 0.0005 0.6849 # 0.00025 0.6749 # 0.0001 0.6746 # 0.00001 0.6536 # 0.000001 0.6267 # LR 0.001 produced the best model, but it does occasionally fail to # converge to a working model, so we set the default to 0.0005 instead DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0005, "sgd": 0.001 } DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 } DEFAULT_LEARNING_RHO = 0.9 DEFAULT_MOMENTUM = { "madgrad": 0.9, "sgd": 0.9 } DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.0001, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6 } def build_argparse(): """ Build the argparse for the classifier. Refactored so that other utility scripts can use the same parser if needed. 
""" parser = argparse.ArgumentParser() parser.add_argument('--train', dest='train', default=True, action='store_true', help='Train the model (default)') parser.add_argument('--no_train', dest='train', action='store_false', help="Don't train the model") parser.add_argument('--shorthand', type=str, default='en_ewt', help="Treebank shorthand, eg 'en' for English") parser.add_argument('--load_name', type=str, default=None, help='Name for loading an existing model') parser.add_argument('--save_dir', type=str, default='saved_models/classifier', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{bert_finetuning}_{classifier_type}_classifier.pt", help='Name for saving the model') parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint") parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints") parser.add_argument('--save_intermediate_models', default=False, action='store_true', help='Save all intermediate models - this can be a lot!') parser.add_argument('--train_file', type=str, default=DEFAULT_TRAIN, help='Input file(s) to train a model from. Each line is an example. Should go (S ... ) LATEX_TREE = 5 # \Tree [.S [.NP ... 
] ] class Tree(StanzaObject): """ A data structure to represent a parse tree """ def __init__(self, label=None, children=None): if children is None: self.children = EMPTY_CHILDREN elif isinstance(children, Tree): self.children = (children,) else: self.children = tuple(children) self.label = label def is_leaf(self): return len(self.children) == 0 def is_preterminal(self): return len(self.children) == 1 and len(self.children[0].children) == 0 def yield_preterminals(self): """ Yield the preterminals one at a time in order """ if self.is_preterminal(): yield self return if self.is_leaf(): raise ValueError("Attempted to iterate preterminals on non-internal node") iterator = iter(self.children) node = next(iterator, None) while node is not None: if node.is_preterminal(): yield node else: iterator = itertools.chain(node.children, iterator) node = next(iterator, None) def leaf_labels(self): """ Get the labels of the leaves """ if self.is_leaf(): return [self.label] words = [x.children[0].label for x in self.yield_preterminals()] return words def __len__(self): return len(self.leaf_labels()) def all_leaves_are_preterminals(self): """ Returns True if all leaves are under preterminals, False otherwise """ if self.is_leaf(): return False if self.is_preterminal(): return True return all(t.all_leaves_are_preterminals() for t in self.children) def pretty_print(self, normalize=None): """ Print with newlines & indentation on each line Preterminals and nodes with all preterminal children go on their own line You can pass in your own normalize() function. 
If you do, make sure the function updates the parens to be something other than () or the brackets will be broken """ if normalize is None: normalize = lambda x: x.replace("(", "-LRB-").replace(")", "-RRB-") indent = 0 with StringIO() as buf: stack = deque() stack.append(self) while len(stack) > 0: node = stack.pop() if node is CLOSE_PAREN: # if we're trying to pretty print trees, pop all off close parens # then write a newline while node is CLOSE_PAREN: indent -= 1 buf.write(CLOSE_PAREN) if len(stack) == 0: node = None break node = stack.pop() buf.write("\n") if node is None: break stack.append(node) elif node.is_preterminal(): buf.write(" " * indent) buf.write("%s%s %s%s" % (OPEN_PAREN, normalize(node.label), normalize(node.children[0].label), CLOSE_PAREN)) if len(stack) == 0 or stack[-1] is not CLOSE_PAREN: buf.write("\n") elif all(x.is_preterminal() for x in node.children): buf.write(" " * indent) buf.write("%s%s" % (OPEN_PAREN, normalize(node.label))) for child in node.children: buf.write(" %s%s %s%s" % (OPEN_PAREN, normalize(child.label), normalize(child.children[0].label), CLOSE_PAREN)) buf.write(CLOSE_PAREN) if len(stack) == 0 or stack[-1] is not CLOSE_PAREN: buf.write("\n") else: buf.write(" " * indent) buf.write("%s%s\n" % (OPEN_PAREN, normalize(node.label))) stack.append(CLOSE_PAREN) for child in reversed(node.children): stack.append(child) indent += 1 buf.seek(0) return buf.read() def __format__(self, spec): """ Turn the tree into a string representing the tree Note that this is not a recursive traversal Otherwise, a tree too deep might blow up the call stack There is a type specific format: O -> one line PTB format, which is the default anyway L -> open and close brackets are labeled, spaces in the tokens are replaced with _ P -> pretty print over multiple lines V -> surround lines with ..., don't print ROOT, and turn () into L/RBKT ? -> spaces in the tokens are replaced with ? for any value of ? 
other than OLP warning: this may be removed in the future ?{OLPV} -> specific format AND a custom space replacement Vi -> add an ID to the in the V format. Also works with ?Vi """ space_replacement = " " print_format = TreePrintMethod.ONE_LINE if spec == 'L': print_format = TreePrintMethod.LABELED_PARENS space_replacement = "_" elif spec and spec[-1] == 'L': print_format = TreePrintMethod.LABELED_PARENS space_replacement = spec[0] elif spec == 'O': print_format = TreePrintMethod.ONE_LINE elif spec and spec[-1] == 'O': print_format = TreePrintMethod.ONE_LINE space_replacement = spec[0] elif spec == 'P': print_format = TreePrintMethod.PRETTY elif spec and spec[-1] == 'P': print_format = TreePrintMethod.PRETTY space_replacement = spec[0] elif spec and spec[0] == 'V': print_format = TreePrintMethod.VLSP use_tree_id = spec[-1] == 'i' elif spec and len(spec) > 1 and spec[1] == 'V': print_format = TreePrintMethod.VLSP space_replacement = spec[0] use_tree_id = spec[-1] == 'i' elif spec == 'T': print_format = TreePrintMethod.LATEX_TREE elif spec and len(spec) > 1 and spec[1] == 'T': print_format = TreePrintMethod.LATEX_TREE space_replacement = spec[0] elif spec: space_replacement = spec[0] warnings.warn("Use of a custom replacement without a format specifier is deprecated. 
Please use {}O instead".format(space_replacement), stacklevel=2) LRB = "LBKT" if print_format == TreePrintMethod.VLSP else "-LRB-" RRB = "RBKT" if print_format == TreePrintMethod.VLSP else "-RRB-" def normalize(text): return text.replace(" ", space_replacement).replace("(", LRB).replace(")", RRB) if print_format is TreePrintMethod.PRETTY: return self.pretty_print(normalize) with StringIO() as buf: stack = deque() if print_format == TreePrintMethod.VLSP: if use_tree_id: buf.write("\n".format(self.tree_id)) else: buf.write("\n") if len(self.children) == 0: raise ValueError("Cannot print an empty tree with V format") elif len(self.children) > 1: raise ValueError("Cannot print a tree with %d branches with V format" % len(self.children)) stack.append(self.children[0]) elif print_format == TreePrintMethod.LATEX_TREE: buf.write("\\Tree ") if len(self.children) == 0: raise ValueError("Cannot print an empty tree with T format") elif len(self.children) == 1 and len(self.children[0].children) == 0: buf.write("[.? 
") buf.write(normalize(self.children[0].label)) buf.write(" ]") elif self.label == 'ROOT': stack.append(self.children[0]) else: stack.append(self) else: stack.append(self) while len(stack) > 0: node = stack.pop() if isinstance(node, str): buf.write(node) continue if len(node.children) == 0: if node.label is not None: buf.write(normalize(node.label)) continue if print_format is TreePrintMethod.LATEX_TREE: if node.is_preterminal(): buf.write(normalize(node.children[0].label)) continue buf.write("[.%s" % normalize(node.label)) stack.append(" ]") elif print_format is TreePrintMethod.ONE_LINE or print_format is TreePrintMethod.VLSP: buf.write(OPEN_PAREN) if node.label is not None: buf.write(normalize(node.label)) stack.append(CLOSE_PAREN) elif print_format is TreePrintMethod.LABELED_PARENS: buf.write("%s_%s" % (OPEN_PAREN, normalize(node.label))) stack.append(CLOSE_PAREN + "_" + normalize(node.label)) stack.append(SPACE_SEPARATOR) for child in reversed(node.children): stack.append(child) stack.append(SPACE_SEPARATOR) if print_format == TreePrintMethod.VLSP: buf.write("\n") buf.seek(0) return buf.read() def __repr__(self): return "{}".format(self) def __eq__(self, other): if self is other: return True if not isinstance(other, Tree): return False if self.label != other.label: return False if len(self.children) != len(other.children): return False if any(c1 != c2 for c1, c2 in zip(self.children, other.children)): return False return True def depth(self): if not self.children: return 0 return 1 + max(x.depth() for x in self.children) def visit_preorder(self, internal=None, preterminal=None, leaf=None): """ Visit the tree in a preorder order Applies the given functions to each node. internal: if not None, applies this function to each non-leaf, non-preterminal node preterminal: if not None, applies this functiion to each preterminal leaf: if not None, applies this function to each leaf The functions should *not* destructively alter the trees. 
There is no attempt to interpret the results of calling these functions. Rather, you can use visit_preorder to collect stats on trees, etc. """ if self.is_leaf(): if leaf: leaf(self) elif self.is_preterminal(): if preterminal: preterminal(self) else: if internal: internal(self) for child in self.children: child.visit_preorder(internal, preterminal, leaf) @staticmethod def get_unique_constituent_labels(trees): """ Walks over all of the trees and gets all of the unique constituent names from the trees """ if isinstance(trees, Tree): trees = [trees] constituents = Tree.get_constituent_counts(trees) return sorted(set(constituents.keys())) @staticmethod def get_constituent_counts(trees): """ Walks over all of the trees and gets the count of the unique constituent names from the trees """ if isinstance(trees, Tree): trees = [trees] constituents = Counter() for tree in trees: tree.visit_preorder(internal = lambda x: constituents.update([x.label])) return constituents @staticmethod def get_unique_tags(trees): """ Walks over all of the trees and gets all of the unique tags from the trees """ if isinstance(trees, Tree): trees = [trees] tags = set() for tree in trees: tree.visit_preorder(preterminal = lambda x: tags.add(x.label)) return sorted(tags) @staticmethod def get_unique_words(trees): """ Walks over all of the trees and gets all of the unique words from the trees """ if isinstance(trees, Tree): trees = [trees] words = set() for tree in trees: tree.visit_preorder(leaf = lambda x: words.add(x.label)) return sorted(words) @staticmethod def get_common_words(trees, num_words): """ Walks over all of the trees and gets the most frequently occurring words. 
""" if num_words == 0: return set() if isinstance(trees, Tree): trees = [trees] words = Counter() for tree in trees: tree.visit_preorder(leaf = lambda x: words.update([x.label])) return sorted(x[0] for x in words.most_common()[:num_words]) @staticmethod def get_rare_words(trees, threshold=0.05): """ Walks over all of the trees and gets the least frequently occurring words. threshold: choose the bottom X percent """ if isinstance(trees, Tree): trees = [trees] words = Counter() for tree in trees: tree.visit_preorder(leaf = lambda x: words.update([x.label])) threshold = max(int(len(words) * threshold), 1) return sorted(x[0] for x in words.most_common()[:-threshold-1:-1]) @staticmethod def get_root_labels(trees): return sorted(set(x.label for x in trees)) @staticmethod def get_compound_constituents(trees, separate_root=False): constituents = set() stack = deque() for tree in trees: if separate_root: constituents.add((tree.label,)) for child in tree.children: stack.append(child) else: stack.append(tree) while len(stack) > 0: node = stack.pop() if node.is_leaf() or node.is_preterminal(): continue labels = [node.label] while len(node.children) == 1 and not node.children[0].is_preterminal(): node = node.children[0] labels.append(node.label) constituents.add(tuple(labels)) for child in node.children: stack.append(child) return sorted(constituents) # TODO: test different pattern def simplify_labels(self, pattern=CONSTITUENT_SPLIT): """ Return a copy of the tree with the -=# removed Leaves the text of the leaves alone. 
""" new_label = self.label # check len(new_label) just in case it's a tag of - or = if new_label and not self.is_leaf() and len(new_label) > 1 and new_label not in ('-LRB-', '-RRB-'): new_label = pattern.split(new_label)[0] new_children = [child.simplify_labels(pattern) for child in self.children] return Tree(new_label, new_children) def reverse(self): """ Flip a tree backwards The intent is to train a parser backwards to see if the forward and backwards parsers can augment each other """ if self.is_leaf(): return Tree(self.label) new_children = [child.reverse() for child in reversed(self.children)] return Tree(self.label, new_children) def remap_constituent_labels(self, label_map): """ Copies the tree with some labels replaced. Labels in the map are replaced with the mapped value. Labels not in the map are unchanged. """ if self.is_leaf(): return Tree(self.label) if self.is_preterminal(): return Tree(self.label, Tree(self.children[0].label)) new_label = label_map.get(self.label, self.label) return Tree(new_label, [child.remap_constituent_labels(label_map) for child in self.children]) def remap_words(self, word_map): """ Copies the tree with some labels replaced. Labels in the map are replaced with the mapped value. Labels not in the map are unchanged. 
""" if self.is_leaf(): new_label = word_map.get(self.label, self.label) return Tree(new_label) if self.is_preterminal(): return Tree(self.label, self.children[0].remap_words(word_map)) return Tree(self.label, [child.remap_words(word_map) for child in self.children]) def replace_words(self, words): """ Replace all leaf words with the words in the given list (or iterable) Returns a new tree """ word_iterator = iter(words) def recursive_replace_words(subtree): if subtree.is_leaf(): word = next(word_iterator, None) if word is None: raise ValueError("Not enough words to replace all leaves") return Tree(word) return Tree(subtree.label, [recursive_replace_words(x) for x in subtree.children]) new_tree = recursive_replace_words(self) if any(True for _ in word_iterator): raise ValueError("Too many words for the given tree") return new_tree def replace_tags(self, tags): if self.is_leaf(): raise ValueError("Must call replace_tags with non-leaf") if isinstance(tags, Tree): tag_iterator = (x.label for x in tags.yield_preterminals()) else: tag_iterator = iter(tags) new_tree = copy.deepcopy(self) queue = deque() queue.append(new_tree) while len(queue) > 0: next_node = queue.pop() if next_node.is_preterminal(): try: label = next(tag_iterator) except StopIteration: raise ValueError("Not enough tags in sentence for given tree") next_node.label = label elif next_node.is_leaf(): raise ValueError("Got a badly structured tree: {}".format(self)) else: queue.extend(reversed(next_node.children)) if any(True for _ in tag_iterator): raise ValueError("Too many tags for the given tree") return new_tree def prune_none(self): """ Return a copy of the tree, eliminating all nodes which are in one of two categories: they are a preterminal -NONE-, such as appears in PTB *E* shows up in a VLSP dataset they have been pruned to 0 children by the recursive call """ if self.is_leaf(): return Tree(self.label) if self.is_preterminal(): if self.label == '-NONE-' or self.children[0].label in WORDS_TO_PRUNE: 
return None return Tree(self.label, Tree(self.children[0].label)) # must be internal node new_children = [child.prune_none() for child in self.children] new_children = [child for child in new_children if child is not None] if len(new_children) == 0: return None return Tree(self.label, new_children) def count_unary_depth(self): if self.is_preterminal() or self.is_leaf(): return 0 if len(self.children) == 1: t = self score = 0 while not t.is_preterminal() and not t.is_leaf() and len(t.children) == 1: score = score + 1 t = t.children[0] child_score = max(tc.count_unary_depth() for tc in t.children) score = max(score, child_score) return score score = max(t.count_unary_depth() for t in self.children) return score @staticmethod def write_treebank(trees, out_file, fmt="{}"): with open(out_file, "w", encoding="utf-8") as fout: for tree in trees: fout.write(fmt.format(tree)) fout.write("\n") ================================================ FILE: stanza/models/constituency/parser_training.py ================================================ from collections import Counter, namedtuple import copy import logging import os import random import re import torch from torch import nn #from stanza.models.common import pretrain from stanza.models.common import utils from stanza.models.common.foundation_cache import FoundationCache, NoTransformerFoundationCache from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss from stanza.models.common.utils import sort_with_indices, unsort from stanza.models.constituency import error_analysis_in_order from stanza.models.constituency import parse_transitions from stanza.models.constituency import transition_sequence from stanza.models.constituency import tree_reader from stanza.models.constituency.in_order_compound_oracle import InOrderCompoundOracle from stanza.models.constituency.in_order_oracle import InOrderOracle from stanza.models.constituency.lstm_model import LSTMModel from stanza.models.constituency.parse_transitions 
import TransitionScheme from stanza.models.constituency.parse_tree import Tree from stanza.models.constituency.top_down_oracle import TopDownOracle from stanza.models.constituency.trainer import Trainer from stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler, verify_transitions, get_open_nodes, check_constituents, check_root_labels, remove_duplicate_trees, remove_singleton_trees from stanza.server.parser_eval import EvaluateParser, ParseResult from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() tlogger = logging.getLogger('stanza.constituency.trainer') TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals']) class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): def __add__(self, other): transitions_correct = self.transitions_correct + other.transitions_correct transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect repairs_used = self.repairs_used + other.repairs_used fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used epoch_loss = self.epoch_loss + other.epoch_loss nans = self.nans + other.nans return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def evaluate(args, model_file, retag_pipeline): """ Loads the given model file and tests the eval_file treebank. 
May retag the trees using retag_pipeline Uses a subprocess to run the Java EvalB code """ # we create the Evaluator here because otherwise the transformers # library constantly complains about forking the process # note that this won't help in the event of training multiple # models in the same run, although since that would take hours # or days, that's not a very common problem if args['num_generate'] > 0: kbest = args['num_generate'] + 1 else: kbest = None with EvaluateParser(kbest=kbest) as evaluator: foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() load_args = { 'wordvec_pretrain_file': args['wordvec_pretrain_file'], 'charlm_forward_file': args['charlm_forward_file'], 'charlm_backward_file': args['charlm_backward_file'], 'device': args['device'], } trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache) if args['log_shapes']: trainer.log_shapes() treebank = tree_reader.read_treebank(args['eval_file']) tlogger.info("Read %d trees for evaluation", len(treebank)) retagged_treebank = treebank if retag_pipeline is not None: retag_method = trainer.model.retag_method retag_xpos = retag_method == 'xpos' tlogger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package']) retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos) tlogger.info("Retagging finished") if args['log_norms']: trainer.log_norms() f1, kbestF1, _ = run_dev_set(trainer.model, retagged_treebank, treebank, args, evaluator, analyze_first_errors=True) tlogger.info("F1 score on %s: %f", args['eval_file'], f1) if kbestF1 is not None: tlogger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1) def remove_optimizer(args, model_save_file, model_load_file): """ A utility method to remove the optimizer from a save file Will make the save file a lot smaller """ # TODO: kind of overkill to load in the pretrain rather than # change the load/save to work without it, but 
probably this # functionality isn't used that often anyway load_args = { 'wordvec_pretrain_file': args['wordvec_pretrain_file'], 'charlm_forward_file': args['charlm_forward_file'], 'charlm_backward_file': args['charlm_backward_file'], 'device': args['device'], } trainer = Trainer.load(model_load_file, args=load_args, load_optimizer=False) trainer.save(model_save_file) def add_grad_clipping(trainer, grad_clipping): """ Adds a torch.clamp hook on each parameter if grad_clipping is not None """ if grad_clipping is not None: for p in trainer.model.parameters(): if p.requires_grad: p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping)) def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file): """ Builds a Trainer (with model) and the train_sequences and transitions for the given trees. """ train_constituents = Tree.get_unique_constituent_labels(train_trees) tlogger.info("Unique constituents in training set: %s", train_constituents) if args['check_valid_states']: check_constituents(train_constituents, dev_trees, "dev", fail=args['strict_check_constituents']) check_constituents(train_constituents, silver_trees, "silver", fail=args['strict_check_constituents']) constituent_counts = Tree.get_constituent_counts(train_trees) tlogger.info("Constituent node counts: %s", constituent_counts) tags = Tree.get_unique_tags(train_trees) if None in tags: raise RuntimeError("Fatal problem: the tagger put None on some of the nodes!") tlogger.info("Unique tags in training set: %s", tags) # no need to fail for missing tags between train/dev set # the model has an unknown tag embedding for tag in Tree.get_unique_tags(dev_trees): if tag not in tags: tlogger.info("Found tag in dev set which does not exist in train set: %s Continuing...", tag) unary_limit = max(max(t.count_unary_depth() for t in train_trees), max(t.count_unary_depth() for t in dev_trees)) + 1 if silver_trees: unary_limit = max(unary_limit, 
max(t.count_unary_depth() for t in silver_trees)) tlogger.info("Unary limit: %d", unary_limit) train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], args['reversed']) dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], args['reversed']) silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], args['reversed']) tlogger.info("Total unique transitions in train set: %d", len(train_transitions)) tlogger.info("Unique transitions in training set:\n %s", "\n ".join(map(str, train_transitions))) expanded_train_transitions = set(train_transitions + [x for trans in train_transitions for x in trans.components()]) if args['check_valid_states']: parse_transitions.check_transitions(expanded_train_transitions, dev_transitions, "dev") # theoretically could just train based on the items in the silver dataset parse_transitions.check_transitions(expanded_train_transitions, silver_transitions, "silver") root_labels = Tree.get_root_labels(train_trees) check_root_labels(root_labels, dev_trees, "dev") check_root_labels(root_labels, silver_trees, "silver") tlogger.info("Root labels in treebank: %s", root_labels) verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], "train", root_labels) verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit, args['reversed'], "dev", root_labels) # we don't check against the words in the dev set as it is # expected there will be some UNK words words = Tree.get_unique_words(train_trees) rare_words = Tree.get_rare_words(train_trees, args['rare_word_threshold']) # rare/unknown silver words will just get UNK if they are not already known if silver_trees and args['use_silver_words']: tlogger.info("Getting silver words to add to the delta embedding") 
silver_words = Tree.get_common_words(tqdm(silver_trees, postfix='Silver words'), len(words)) words = sorted(set(words + silver_words)) # also, it's not actually an error if there is a pattern of # compound unary or compound open nodes which doesn't exist in the # train set. it just means we probably won't ever get that right open_nodes = get_open_nodes(train_trees, args['transition_scheme']) tlogger.info("Using the following open nodes:\n %s", "\n ".join(map(str, open_nodes))) # at this point we have: # pretrain # train_trees, dev_trees # lists of transitions, internal nodes, and root states the parser needs to be aware of trainer = Trainer.build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file) trainer.log_num_words_known(words) # grad clipping is not saved with the rest of the model, # so even in the case of a model we saved, # we now have to add the grad clipping add_grad_clipping(trainer, args['grad_clipping']) return trainer, train_sequences, silver_sequences, train_transitions def train(args, model_load_file, retag_pipeline): """ Build a model, train it using the requested train & dev files """ utils.log_training_args(args, tlogger) # we create the Evaluator here because otherwise the transformers # library constantly complains about forking the process # note that this won't help in the event of training multiple # models in the same run, although since that would take hours # or days, that's not a very common problem if args['num_generate'] > 0: kbest = args['num_generate'] + 1 else: kbest = None if args['wandb']: global wandb import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_constituency" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('dev_score', summary='max') with EvaluateParser(kbest=kbest) as evaluator: utils.ensure_dir(args['save_dir']) train_trees = 
tree_reader.read_treebank(args['train_file']) tlogger.info("Read %d trees for the training set", len(train_trees)) if args['train_remove_duplicates']: train_trees = remove_duplicate_trees(train_trees, "train") train_trees = remove_singleton_trees(train_trees) dev_trees = tree_reader.read_treebank(args['eval_file']) tlogger.info("Read %d trees for the dev set", len(dev_trees)) dev_trees = remove_duplicate_trees(dev_trees, "dev") silver_trees = [] if args['silver_file']: silver_trees = tree_reader.read_treebank(args['silver_file']) tlogger.info("Read %d trees for the silver training set", len(silver_trees)) if args['silver_remove_duplicates']: silver_trees = remove_duplicate_trees(silver_trees, "silver") if retag_pipeline is not None: tlogger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package']) train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos']) silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos']) tlogger.info("Retagging finished") foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file) if args['log_shapes']: trainer.log_shapes() trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator) if args['wandb']: wandb.finish() return trainer def compose_train_data(trees, sequences): preterminal_lists = [[Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label)) for preterminal in tree.yield_preterminals()] for tree in trees] data = [TrainItem(*x) for x in zip(trees, sequences, preterminal_lists)] return data def next_epoch_data(leftover_training_data, train_data, epoch_size): """ 
Return the next epoch_size trees from the training data, starting with leftover data from the previous epoch if there is any The training loop generally operates on a fixed number of trees, rather than going through all the trees in the training set exactly once, and keeping the leftover training data via this function ensures that each tree in the training set is touched once before beginning to iterate again. """ if not train_data: return [], [] epoch_data = leftover_training_data while len(epoch_data) < epoch_size: random.shuffle(train_data) epoch_data.extend(train_data) leftover_training_data = epoch_data[epoch_size:] epoch_data = epoch_data[:epoch_size] return leftover_training_data, epoch_data def update_bert_learning_rate(args, optimizer, epochs_trained): """ Update the learning rate for the bert finetuning, if applicable """ # would be nice to have a parameter group specific scheduler # however, there is an issue with the optimizer we had the most success with, madgrad # when the learning rate is 0 for a group, it still learns by some # small amount because of the eps parameter # in fact, that is enough to make the learning for the bert in the # second half broken for base_param_group in optimizer.param_groups: if base_param_group['param_group_name'] == 'base': break else: raise AssertionError("There should always be a base parameter group") for param_group in optimizer.param_groups: if param_group['param_group_name'] == 'bert': # Occasionally a model goes haywire and forgets how to use the transformer # So far we have only seen this happen with Electra on the non-NML version of PTB # We tried fixing that with an increasing transformer learning rate, but that # didn't fully resolve the problem # Switching to starting the finetuning after a few epochs seems to help a lot, though old_lr = param_group['lr'] if args['bert_finetune_begin_epoch'] is not None and epochs_trained < args['bert_finetune_begin_epoch']: param_group['lr'] = 0.0 elif 
args['bert_finetune_end_epoch'] is not None and epochs_trained >= args['bert_finetune_end_epoch']: param_group['lr'] = 0.0 elif args['multistage'] and epochs_trained < args['epochs'] // 2: param_group['lr'] = base_param_group['lr'] * args['stage1_bert_learning_rate'] else: param_group['lr'] = base_param_group['lr'] * args['bert_learning_rate'] if param_group['lr'] != old_lr: tlogger.info("Setting %s finetuning rate from %f to %f", param_group['param_group_name'], old_lr, param_group['lr']) def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator): """ Given an initialized model, a processed dataset, and a secondary dev dataset, train the model The training is iterated in the following loop: extract a batch of trees of the same length from the training set convert those trees into initial parsing states repeat until trees are done: batch predict the model's interpretation of the current states add the errors to the list of things to backprop advance the parsing state for each of the trees """ # Somewhat unusual, but possibly related to the extreme variability in length of trees # Various experiments generally show about 0.5 F1 loss on various # datasets when using 'mean' instead of 'sum' for reduction # (Remember to adjust the weight decay when rerunning that experiment) if args['loss'] == 'cross': tlogger.info("Building CrossEntropyLoss(sum)") process_outputs = lambda x: x model_loss_function = nn.CrossEntropyLoss(reduction='sum') elif args['loss'] == 'focal': try: from focal_loss.focal_loss import FocalLoss except ImportError: raise ImportError("focal_loss not installed. 
Must `pip install focal_loss_torch` to use the --loss=focal feature") tlogger.info("Building FocalLoss, gamma=%f", args['loss_focal_gamma']) process_outputs = lambda x: torch.softmax(x, dim=1) model_loss_function = FocalLoss(reduction='sum', gamma=args['loss_focal_gamma']) elif args['loss'] == 'large_margin': tlogger.info("Building LargeMarginInSoftmaxLoss(sum)") process_outputs = lambda x: x model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum') else: raise ValueError("Unexpected loss term: %s" % args['loss']) device = trainer.device model_loss_function.to(device) transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0) for (y, x) in enumerate(trainer.transitions)} trainer.train() train_data = compose_train_data(train_trees, train_sequences) silver_data = compose_train_data(silver_trees, silver_sequences) if not args['epoch_size']: args['epoch_size'] = len(train_data) if silver_data and not args['silver_epoch_size']: args['silver_epoch_size'] = args['epoch_size'] if args['multistage']: multistage_splits = {} # if we're halfway, only do pattn. save lattn for next time multistage_splits[args['epochs'] // 2] = (args['pattn_num_layers'], False) if LSTMModel.uses_lattn(args): multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True) # TODO: refactor the oracle choice into the transition scheme? 
oracle = None if args['transition_scheme'] is TransitionScheme.IN_ORDER: oracle = InOrderOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels']) elif args['transition_scheme'] is TransitionScheme.IN_ORDER_COMPOUND: oracle = InOrderCompoundOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels']) elif args['transition_scheme'] is TransitionScheme.TOP_DOWN: oracle = TopDownOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels']) leftover_training_data = [] leftover_silver_data = [] if trainer.best_epoch > 0: tlogger.info("Restarting trainer with a model trained for %d epochs. Best epoch %d, f1 %f", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1) # if we're training a new model, save the initial state so it can be inspected if args['save_each_start'] == 0 and trainer.epochs_trained == 0: trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=True) # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1 for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1): trainer.train() tlogger.info("Starting epoch %d", trainer.epochs_trained) update_bert_learning_rate(args, trainer.optimizer, trainer.epochs_trained) if args['log_norms']: trainer.log_norms() leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size']) leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['silver_epoch_size']) epoch_data = epoch_data + epoch_silver_data epoch_data.sort(key=lambda x: len(x[1])) epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args) # print statistics # by now we've forgotten about the original tags on the trees, 
# but it doesn't matter for hill climbing f1, _, _ = run_dev_set(trainer.model, dev_trees, dev_trees, args, evaluator) if f1 > trainer.best_f1 or (trainer.best_epoch == 0 and trainer.best_f1 == 0.0): # best_epoch == 0 to force a save of an initial model # useful for tests which expect something, even when a # very simple model didn't learn anything tlogger.info("New best dev score: %.5f > %.5f", f1, trainer.best_f1) trainer.best_f1 = f1 trainer.best_epoch = trainer.epochs_trained trainer.save(args['save_name'], save_optimizer=False) if epoch_stats.nans > 0: tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans) # TODO: refactor the logging? total_correct = sum(v for _, v in epoch_stats.transitions_correct.items()) correct_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct]) tlogger.info("Transitions correct: %d\n %s", total_correct, correct_transitions_str) total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items()) incorrect_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect]) tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, incorrect_transitions_str) if len(epoch_stats.repairs_used) > 0: tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common())) if epoch_stats.fake_transitions_used > 0: tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used) stats_log_lines = [ "Epoch %d finished" % trainer.epochs_trained, "Transitions correct: %d" % total_correct, "Transitions incorrect: %d" % total_incorrect, "Total loss for epoch: %.5f" % epoch_stats.epoch_loss, "Dev score (%5d): %8f" % (trainer.epochs_trained, f1), "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1) ] tlogger.info("\n ".join(stats_log_lines)) old_lr = trainer.optimizer.param_groups[0]['lr'] 
trainer.scheduler.step(f1) new_lr = trainer.optimizer.param_groups[0]['lr'] if old_lr != new_lr: tlogger.info("Updating learning rate from %f to %f", old_lr, new_lr) if args['wandb']: wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained) if args['wandb_norm_regex']: watch_regex = re.compile(args['wandb_norm_regex']) for n, p in trainer.model.named_parameters(): if watch_regex.search(n): wandb.log({n: torch.linalg.norm(p)}) if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']: if any(x > 0.0 for x in (trainer.model.word_dropout.p, trainer.model.predict_dropout.p, trainer.model.lstm_input_dropout.p)): tlogger.info("Setting dropout to 0.0 at epoch %d", trainer.epochs_trained) trainer.model.word_dropout.p = 0 trainer.model.predict_dropout.p = 0 trainer.model.lstm_input_dropout.p = 0 # recreate the optimizer and alter the model as needed if we hit a new multistage split if args['multistage'] and trainer.epochs_trained in multistage_splits: # we may be loading a save model from an earlier epoch if the scores stopped increasing epochs_trained = trainer.epochs_trained batches_trained = trainer.batches_trained stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained] # when loading the model, let the saved model determine whether it has pattn or lattn temp_args = copy.deepcopy(trainer.model.args) temp_args.pop('pattn_num_layers', None) temp_args.pop('lattn_d_proj', None) # overwriting the old trainer & model will hopefully free memory # load a new bert, even in PEFT mode, mostly so that the bert model # doesn't collect a whole bunch of PEFTs # for one thing, two PEFTs would mean 2x the optimizer parameters, # messing up saving and loading the optimizer without jumping # through more hoops # loading the trainer w/o the foundation_cache should create # the necessary bert_model and bert_tokenizer, and then we # can reuse those values when building out new LSTMModel trainer = 
Trainer.load(args['save_name'], temp_args, load_optimizer=False) model = trainer.model tlogger.info("Finished stage at epoch %d. Restarting optimizer", epochs_trained) tlogger.info("Previous best model was at epoch %d", trainer.epochs_trained) temp_args = dict(args) tlogger.info("Switching to a model with %d pattn layers and %slattn", stage_pattn_layers, "" if stage_uses_lattn else "NO ") temp_args['pattn_num_layers'] = stage_pattn_layers if not stage_uses_lattn: temp_args['lattn_d_proj'] = 0 pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file']) forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file']) backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file']) new_model = LSTMModel(pt, forward_charlm, backward_charlm, model.bert_model, model.bert_tokenizer, model.force_bert_saved, model.peft_name, model.transitions, model.constituents, model.tags, model.delta_words, model.rare_words, model.root_labels, model.constituent_opens, model.unary_limit(), temp_args) new_model.to(device) new_model.copy_with_new_structure(model) optimizer = build_optimizer(temp_args, new_model, False) scheduler = build_scheduler(temp_args, optimizer) trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch) add_grad_clipping(trainer, args['grad_clipping']) # checkpoint needs to be saved AFTER rebuilding the optimizer # so that assumptions about the optimizer in the checkpoint # can be made based on the end of the epoch if args['checkpoint'] and args['checkpoint_save_name']: trainer.save(args['checkpoint_save_name'], save_optimizer=True) # same with the "each filename", actually, in case those are # brought back for more training or even just for testing if args['save_each_start'] is not None and args['save_each_start'] <= trainer.epochs_trained and trainer.epochs_trained % args['save_each_frequency'] == 0: trainer.save(args['save_each_name'] % trainer.epochs_trained, 
save_optimizer=args['save_each_optimizer']) return trainer def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args): interval_starts = list(range(0, len(epoch_data), args['train_batch_size'])) random.shuffle(interval_starts) optimizer = trainer.optimizer epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0) for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)): batch = epoch_data[interval_start:interval_start+args['train_batch_size']] batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args) trainer.batches_trained += 1 # Early in the training, some trees will be degenerate in a # way that results in layers going up the tree amplifying the # weights until they overflow. Generally that problem # resolves itself in a few iterations, so for now we just # ignore those batches, but report how often it happens if batch_stats.nans == 0: optimizer.step() optimizer.zero_grad() epoch_stats = epoch_stats + batch_stats return epoch_stats def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args): """ Train the model for one batch The model itself will be updated, and a bunch of stats are returned It is unclear if this refactoring is useful in any way. Might not be ... 
although the indentation does get pretty ridiculous if this is merged into train_model_one_epoch and then iterate_training """ # now we add the state to the trees in the batch # the state is built as a bulk operation current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch], [x.tree for x in training_batch], [x.gold_sequence for x in training_batch]) transitions_correct = Counter() transitions_incorrect = Counter() repairs_used = Counter() fake_transitions_used = 0 all_errors = [] all_answers = [] # we iterate through the batch in the following sequence: # predict the logits and the applied transition for each tree in the batch # collect errors # - we always train to the desired one-hot vector # this was a noticeable improvement over training just the # incorrect transitions # determine whether the training can continue using the "student" transition # or if we need to use teacher forcing # update all states using either the gold or predicted transition # any trees which are now finished are removed from the training cycle while len(current_batch) > 0: outputs, pred_transitions, _ = model.predict(current_batch, is_legal=False) gold_transitions = [x.gold_sequence[x.num_transitions] for x in current_batch] trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions] all_errors.append(outputs) all_answers.extend(trans_tensor) new_batch = [] update_transitions = [] for pred_transition, gold_transition, state in zip(pred_transitions, gold_transitions, current_batch): # forget teacher forcing vs scheduled sampling # we're going with idiot forcing if pred_transition == gold_transition: transitions_correct[gold_transition.short_name()] += 1 if state.num_transitions + 1 < len(state.gold_sequence): if oracle is not None and epoch >= args['oracle_initial_epoch'] and random.random() < args['oracle_forced_errors']: # TODO: could randomly choose from the legal transitions # perhaps the second best scored 
transition fake_transition = random.choice(model.transitions) if fake_transition.is_legal(state, model): _, new_sequence = oracle.fix_error(fake_transition, model, state) if new_sequence is not None: new_batch.append(state._replace(gold_sequence=new_sequence)) update_transitions.append(fake_transition) fake_transitions_used = fake_transitions_used + 1 continue new_batch.append(state) update_transitions.append(gold_transition) continue transitions_incorrect[gold_transition.short_name(), pred_transition.short_name()] += 1 # if we are on the final operation, there are two choices: # - the parsing mode is IN_ORDER, and the final transition # is the close to end the sequence, which has no alternatives # - the parsing mode is something else, in which case # we have no oracle anyway if state.num_transitions + 1 >= len(state.gold_sequence): continue if oracle is None or epoch < args['oracle_initial_epoch'] or not pred_transition.is_legal(state, model): new_batch.append(state) update_transitions.append(gold_transition) continue repair_type, new_sequence = oracle.fix_error(pred_transition, model, state) # we can only reach here on an error assert not repair_type.is_correct repairs_used[repair_type] += 1 if new_sequence is not None and random.random() < args['oracle_frequency']: new_batch.append(state._replace(gold_sequence=new_sequence)) update_transitions.append(pred_transition) else: new_batch.append(state) update_transitions.append(gold_transition) if len(current_batch) > 0: # bulk update states - significantly faster current_batch = model.bulk_apply(new_batch, update_transitions, fail=True) errors = torch.cat(all_errors) answers = torch.cat(all_answers) errors = process_outputs(errors) tree_loss = model_loss_function(errors, answers) tree_loss.backward() if args['watch_regex']: matched = False tlogger.info("Watching %s ... 
epoch %d batch %d", args['watch_regex'], epoch, batch_idx) watch_regex = re.compile(args['watch_regex']) for n, p in trainer.model.named_parameters(): if watch_regex.search(n): matched = True if p.requires_grad and p.grad is not None: tlogger.info(" %s norm: %f grad: %f", n, torch.linalg.norm(p), torch.linalg.norm(p.grad)) elif p.requires_grad: tlogger.info(" %s norm: %f grad required, but is None!", n, torch.linalg.norm(p)) else: tlogger.info(" %s norm: %f grad not required", n, torch.linalg.norm(p)) if not matched: tlogger.info(" (none found!)") if torch.any(torch.isnan(tree_loss)): batch_loss = 0.0 nans = 1 else: batch_loss = tree_loss.item() nans = 0 return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None, analyze_first_errors=False): """ This reparses a treebank and executes the CoreNLP Java EvalB code. It only works if CoreNLP 4.3.0 or higher is in the classpath. 
""" tlogger.info("Processing %d trees from %s", len(retagged_trees), args['eval_file']) model.eval() num_generate = args.get('num_generate', 0) keep_scores = num_generate > 0 sorted_trees, original_indices = sort_with_indices(retagged_trees, key=len, reverse=True) tree_iterator = iter(tqdm(sorted_trees)) treebank = model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.predict, keep_scores=keep_scores) treebank = unsort(treebank, original_indices) full_results = treebank if num_generate > 0: tlogger.info("Generating %d random analyses", args['num_generate']) generated_treebanks = [treebank] for i in tqdm(range(num_generate)): tree_iterator = iter(tqdm(retagged_trees, leave=False, postfix="tb%03d" % i)) generated_treebanks.append(model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.weighted_choice, keep_scores=keep_scores)) #best_treebank = [ParseResult(parses[0].gold, [max([p.predictions[0] for p in parses], key=itemgetter(1))], None, None) # for parses in zip(*generated_treebanks)] #generated_treebanks = [best_treebank] + generated_treebanks # TODO: if the model is dropping trees, this will not work full_results = [ParseResult(parses[0].gold, [p.predictions[0] for p in parses], None, None) for parses in zip(*generated_treebanks)] if len(full_results) < len(retagged_trees): tlogger.warning("Only evaluating %d trees instead of %d", len(full_results), len(retagged_trees)) else: full_results = [x._replace(gold=gold) for x, gold in zip(full_results, original_trees)] if args.get('mode', None) == 'predict' and args['predict_file']: utils.ensure_dir(args['predict_dir'], verbose=False) pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".pred.mrg") orig_file = os.path.join(args['predict_dir'], args['predict_file'] + ".orig.mrg") if os.path.exists(pred_file): tlogger.warning("Cowardly refusing to overwrite {}".format(pred_file)) elif 
os.path.exists(orig_file): tlogger.warning("Cowardly refusing to overwrite {}".format(orig_file)) else: with open(pred_file, 'w') as fout: for tree in full_results: output_tree = tree.predictions[0].tree if args['predict_output_gold_tags']: output_tree = output_tree.replace_tags(tree.gold) fout.write(args['predict_format'].format(output_tree)) fout.write("\n") for i in range(num_generate): pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".%03d.pred.mrg" % i) with open(pred_file, 'w') as fout: for tree in generated_treebanks[-(i+1)]: output_tree = tree.predictions[0].tree if args['predict_output_gold_tags']: output_tree = output_tree.replace_tags(tree.gold) fout.write(args['predict_format'].format(output_tree)) fout.write("\n") with open(orig_file, 'w') as fout: for tree in full_results: fout.write(args['predict_format'].format(tree.gold)) fout.write("\n") if len(full_results) == 0: return 0.0, 0.0 if evaluator is None: if num_generate > 0: kbest = max(len(fr.predictions) for fr in full_results) else: kbest = None with EvaluateParser(kbest=kbest) as evaluator: response = evaluator.process(full_results) else: response = evaluator.process(full_results) if analyze_first_errors and args['transition_scheme'] is TransitionScheme.IN_ORDER: errors = Counter() for result in full_results: first_error = error_analysis_in_order.analyze_tree(result.gold, result.predictions[0].tree) errors[first_error] += 1 log_lines = ["%30s: %d" % (key.name, errors[key]) for key in error_analysis_in_order.FirstError] tlogger.info("First error frequency:\n %s", "\n ".join(log_lines)) kbestF1 = response.kbestF1 if response.HasField("kbestF1") else None return response.f1, kbestF1, response.treeF1 ================================================ FILE: stanza/models/constituency/partitioned_transformer.py ================================================ """ Transformer with partitioned content and position features. 
See section 3 of https://arxiv.org/pdf/1805.01052.pdf """ import copy import math import torch import torch.nn as nn import torch.nn.functional as F from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding class FeatureDropoutFunction(torch.autograd.function.InplaceFunction): @staticmethod def forward(ctx, input, p=0.5, train=False, inplace=False): if p < 0 or p > 1: raise ValueError( "dropout probability has to be between 0 and 1, but got {}".format(p) ) ctx.p = p ctx.train = train ctx.inplace = inplace if ctx.inplace: ctx.mark_dirty(input) output = input else: output = input.clone() if ctx.p > 0 and ctx.train: ctx.noise = torch.empty( (input.size(0), input.size(-1)), dtype=input.dtype, layout=input.layout, device=input.device, ) if ctx.p == 1: ctx.noise.fill_(0) else: ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p) ctx.noise = ctx.noise[:, None, :] output.mul_(ctx.noise) return output @staticmethod def backward(ctx, grad_output): if ctx.p > 0 and ctx.train: return grad_output.mul(ctx.noise), None, None, None else: return grad_output, None, None, None class FeatureDropout(nn.Dropout): """ Feature-level dropout: takes an input of size len x num_features and drops each feature with probabibility p. A feature is dropped across the full portion of the input that corresponds to a single batch element. 
""" def forward(self, x): if isinstance(x, tuple): x_c, x_p = x x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace) x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace) return x_c, x_p else: return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace) # TODO: this module apparently is not treated the same the built-in # nonlinearity modules, as multiple uses of the same relu on different # tensors winds up mixing the gradients See if there is a way to # resolve that other than creating a new nonlinearity for each layer class PartitionedReLU(nn.ReLU): def forward(self, x): if isinstance(x, tuple): x_c, x_p = x else: x_c, x_p = torch.chunk(x, 2, dim=-1) return super().forward(x_c), super().forward(x_p) class PartitionedLinear(nn.Module): def __init__(self, in_features, out_features, bias=True): super().__init__() self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias) self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias) def forward(self, x): if isinstance(x, tuple): x_c, x_p = x else: x_c, x_p = torch.chunk(x, 2, dim=-1) out_c = self.linear_c(x_c) out_p = self.linear_p(x_p) return out_c, out_p class PartitionedMultiHeadAttention(nn.Module): def __init__( self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02 ): super().__init__() self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2)) self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2)) self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2)) self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2)) bound = math.sqrt(3.0) * initializer_range for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]: nn.init.uniform_(param, -bound, bound) self.scaling_factor = 1 / d_qkv ** 0.5 self.dropout = nn.Dropout(attention_dropout) def forward(self, x, mask=None): if isinstance(x, tuple): x_c, x_p = x else: x_c, x_p = 
torch.chunk(x, 2, dim=-1) qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c) qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p) q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)] q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)] q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor k = torch.cat([k_c, k_p], dim=-1) v = torch.cat([v_c, v_p], dim=-1) dots = torch.einsum("bhqa,bhka->bhqk", q, k) if mask is not None: dots.data.masked_fill_(~mask[:, None, None, :], -float("inf")) probs = F.softmax(dots, dim=-1) probs = self.dropout(probs) o = torch.einsum("bhqk,bhka->bhqa", probs, v) o_c, o_p = torch.chunk(o, 2, dim=-1) out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c) out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p) return out_c, out_p class PartitionedTransformerEncoderLayer(nn.Module): def __init__(self, d_model, n_head, d_qkv, d_ff, ff_dropout, residual_dropout, attention_dropout, activation=PartitionedReLU(), ): super().__init__() self.self_attn = PartitionedMultiHeadAttention( d_model, n_head, d_qkv, attention_dropout=attention_dropout ) self.linear1 = PartitionedLinear(d_model, d_ff) self.ff_dropout = FeatureDropout(ff_dropout) self.linear2 = PartitionedLinear(d_ff, d_model) self.norm_attn = nn.LayerNorm(d_model) self.norm_ff = nn.LayerNorm(d_model) self.residual_dropout_attn = FeatureDropout(residual_dropout) self.residual_dropout_ff = FeatureDropout(residual_dropout) self.activation = activation def forward(self, x, mask=None): residual = self.self_attn(x, mask=mask) residual = torch.cat(residual, dim=-1) residual = self.residual_dropout_attn(residual) x = self.norm_attn(x + residual) residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x)))) residual = torch.cat(residual, dim=-1) residual = self.residual_dropout_ff(residual) x = self.norm_ff(x + residual) return x class PartitionedTransformerEncoder(nn.Module): def __init__(self, n_layers, d_model, n_head, d_qkv, d_ff, 
ff_dropout, residual_dropout, attention_dropout, activation=PartitionedReLU, ): super().__init__() self.layers = nn.ModuleList([PartitionedTransformerEncoderLayer(d_model=d_model, n_head=n_head, d_qkv=d_qkv, d_ff=d_ff, ff_dropout=ff_dropout, residual_dropout=residual_dropout, attention_dropout=attention_dropout, activation=activation()) for i in range(n_layers)]) def forward(self, x, mask=None): for layer in self.layers: x = layer(x, mask=mask) return x class ConcatPositionalEncoding(nn.Module): """ Learns a position embedding """ def __init__(self, d_model=256, max_len=512): super().__init__() self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model)) nn.init.normal_(self.timing_table) def forward(self, x): timing = self.timing_table[:x.shape[1], :] timing = timing.expand(x.shape[0], -1, -1) out = torch.cat([x, timing], dim=-1) return out # class PartitionedTransformerModule(nn.Module): def __init__(self, n_layers, d_model, n_head, d_qkv, d_ff, ff_dropout, residual_dropout, attention_dropout, word_input_size, bias, morpho_emb_dropout, timing, encoder_max_len, activation=PartitionedReLU() ): super().__init__() self.project_pretrained = nn.Linear( word_input_size, d_model // 2, bias=bias ) self.pattention_morpho_emb_dropout = FeatureDropout(morpho_emb_dropout) if timing == 'sin': self.add_timing = ConcatSinusoidalEncoding(d_model=d_model // 2, max_len=encoder_max_len) elif timing == 'learned': self.add_timing = ConcatPositionalEncoding(d_model=d_model // 2, max_len=encoder_max_len) else: raise ValueError("Unhandled timing type: %s" % timing) self.transformer_input_norm = nn.LayerNorm(d_model) self.pattn_encoder = PartitionedTransformerEncoder( n_layers, d_model=d_model, n_head=n_head, d_qkv=d_qkv, d_ff=d_ff, ff_dropout=ff_dropout, residual_dropout=residual_dropout, attention_dropout=attention_dropout, ) # def forward(self, attention_mask, bert_embeddings): # Prepares attention mask for feeding into the self-attention device = bert_embeddings[0].device if 
attention_mask: valid_token_mask = attention_mask else: valids = [] for sent in bert_embeddings: valids.append(torch.ones(len(sent), device=device)) padded_data = torch.nn.utils.rnn.pad_sequence( valids, batch_first=True, padding_value=-100 ) valid_token_mask = padded_data != -100 valid_token_mask = valid_token_mask.to(device=device) padded_embeddings = torch.nn.utils.rnn.pad_sequence( bert_embeddings, batch_first=True, padding_value=0 ) # Project the pretrained embedding onto the desired dimension extra_content_annotations = self.project_pretrained(padded_embeddings) # Add positional information through the table encoder_in = self.add_timing(self.pattention_morpho_emb_dropout(extra_content_annotations)) encoder_in = self.transformer_input_norm(encoder_in) # Put the partitioned input through the partitioned attention annotations = self.pattn_encoder(encoder_in, valid_token_mask) return annotations ================================================ FILE: stanza/models/constituency/positional_encoding.py ================================================ """ Based on https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model """ import math import torch from torch import nn class SinusoidalEncoding(nn.Module): """ Uses sine & cosine to represent position """ def __init__(self, model_dim, max_len): super().__init__() self.register_buffer('pe', self.build_position(model_dim, max_len)) @staticmethod def build_position(model_dim, max_len, device=None): position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim)) pe = torch.zeros(max_len, model_dim) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) if device is not None: pe = pe.to(device=device) return pe def forward(self, x): if max(x) >= self.pe.shape[0]: # try to drop the reference first before creating a new encoding # the goal being to save memory if we are close to the memory limit 
device = self.pe.device shape = self.pe.shape[1] self.register_buffer('pe', None) # TODO: this may result in very poor performance # in the event of a model that increases size one at a time self.register_buffer('pe', self.build_position(shape, max(x)+1, device=device)) return self.pe[x] def max_len(self): return self.pe.shape[0] class AddSinusoidalEncoding(nn.Module): """ Uses sine & cosine to represent position. Adds the position to the given matrix Default behavior is batch_first """ def __init__(self, d_model=256, max_len=512): super().__init__() self.encoding = SinusoidalEncoding(d_model, max_len) def forward(self, x, scale=1.0): """ Adds the positional encoding to the input tensor The tensor is expected to be of the shape B, N, D Properly masking the output tensor is up to the caller """ if len(x.shape) == 3: timing = self.encoding(torch.arange(x.shape[1], device=x.device)) timing = timing.expand(x.shape[0], -1, -1) elif len(x.shape) == 2: timing = self.encoding(torch.arange(x.shape[0], device=x.device)) return x + timing * scale class ConcatSinusoidalEncoding(nn.Module): """ Uses sine & cosine to represent position. 
Concats the position and returns a larger object Default behavior is batch_first """ def __init__(self, d_model=256, max_len=512): super().__init__() self.encoding = SinusoidalEncoding(d_model, max_len) def forward(self, x): if len(x.shape) == 3: timing = self.encoding(torch.arange(x.shape[1], device=x.device)) timing = timing.expand(x.shape[0], -1, -1) else: timing = self.encoding(torch.arange(x.shape[0], device=x.device)) out = torch.cat((x, timing), dim=-1) return out ================================================ FILE: stanza/models/constituency/retagging.py ================================================ """ Refactor a few functions specifically for retagging trees Retagging is important because the gold tags will not be available at runtime Note that the method which does the actual retagging is in utils.py so as to avoid unnecessary circular imports (eg, Pipeline imports constituency/trainer which imports this which imports Pipeline) """ import copy import logging from stanza import Pipeline from stanza.models.common.foundation_cache import FoundationCache from stanza.models.common.vocab import VOCAB_PREFIX from stanza.resources.common import download_resources_json, load_resources_json, get_language_resources tlogger = logging.getLogger('stanza.constituency.trainer') # xpos tagger doesn't produce PP tag on the turin treebank, # so instead we use upos to avoid unknown tag errors RETAG_METHOD = { "da": "upos", # the DDT has no xpos tags anyway "de": "upos", # DE GSD is also missing a few punctuation tags "es": "upos", # AnCora has half-finished xpos tags "id": "upos", # GSD is missing a few punctuation tags - fixed in 2.12, though "it": "upos", "pt": "upos", # default PT model has no xpos either "vi": "xpos", # the new version of UD can be merged with xpos from VLSP22 } def add_retag_args(parser): """ Arguments specifically for retagging treebanks """ parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when 
retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time') parser.add_argument('--retag_method', default=None, choices=['xpos', 'upos'], help='Which tags to use when retagging. Default depends on the language') parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default. Can specify multiple taggers with ; in which case the majority vote wins') parser.add_argument('--retag_pretrain_path', default=None, help='Use this for a pretrain path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom pretrain') parser.add_argument('--retag_charlm_forward_file', default=None, help='Use this for a forward charlm path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom charlm') parser.add_argument('--retag_charlm_backward_file', default=None, help='Use this for a backward charlm path for the retagging pipeline. 
Generally not needed unless using a custom POS model with a custom charlm') parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees") def postprocess_args(args): """ After parsing args, unify some settings """ # use a language specific default for retag_method if we know the language # otherwise, use xpos if args['retag_method'] is None and 'lang' in args and args['lang'] in RETAG_METHOD: args['retag_method'] = RETAG_METHOD[args['lang']] if args['retag_method'] is None: args['retag_method'] = 'xpos' if args['retag_method'] == 'xpos': args['retag_xpos'] = True elif args['retag_method'] == 'upos': args['retag_xpos'] = False else: raise ValueError("Unknown retag method {}".format(xpos)) def build_retag_pipeline(args): """ Builds retag pipelines based on the arguments May alter the arguments if the pipeline is incompatible, such as taggers with no xpos Will return a list of one or more retag pipelines. Multiple tagger models can be specified by having them semi-colon separated in retag_model_path. 
""" # some argument sets might not use 'mode' if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer': download_resources_json() resources = load_resources_json() if '_' in args['retag_package']: lang, package = args['retag_package'].split('_', 1) lang_resources = get_language_resources(resources, lang) if lang_resources is None and 'lang' in args: lang_resources = get_language_resources(resources, args['lang']) if lang_resources is not None and 'pos' in lang_resources and args['retag_package'] in lang_resources['pos']: lang = args['lang'] package = args['retag_package'] else: if 'lang' not in args: raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package']) lang = args.get('lang', None) package = args['retag_package'] foundation_cache = FoundationCache() retag_args = {"lang": lang, "processors": "tokenize, pos", "tokenize_pretokenized": True, "package": {"pos": package}} if args['retag_pretrain_path'] is not None: retag_args['pos_pretrain_path'] = args['retag_pretrain_path'] if args['retag_charlm_forward_file'] is not None: retag_args['pos_forward_charlm_path'] = args['retag_charlm_forward_file'] if args['retag_charlm_backward_file'] is not None: retag_args['pos_backward_charlm_path'] = args['retag_charlm_backward_file'] def build(retag_args, path): retag_args = copy.deepcopy(retag_args) # we just downloaded the resources a moment ago # no need to repeatedly download retag_args['download_method'] = 'reuse_resources' if path is not None: retag_args['allow_unknown_language'] = True retag_args['pos_model_path'] = path tlogger.debug('Creating retag pipeline using %s', path) else: tlogger.debug('Creating retag pipeline for %s package', package) retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args) if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX): tlogger.warning("XPOS for the %s tagger is empty. 
Switching to UPOS", package) args['retag_xpos'] = False args['retag_method'] = 'upos' return retag_pipeline if args['retag_model_path'] is None: return [build(retag_args, None)] paths = args['retag_model_path'].split(";") # can be length 1 if only one tagger to work with return [build(retag_args, path) for path in paths] return None ================================================ FILE: stanza/models/constituency/score_converted_dependencies.py ================================================ """ Script which processes a dependency file by using the constituency parser, then converting with the CoreNLP converter Currently this does not have the constituency parser as an option, although that is easy to add. Only English is supported, as only English is available in the CoreNLP converter """ import argparse import os import tempfile import stanza from stanza.models.constituency import retagging from stanza.models.depparse import scorer from stanza.utils.conll import CoNLL def score_converted_dependencies(args): if args['lang'] != 'en': raise ValueError("Converting and scoring dependencies is currently only supported for English") constituency_package = args['constituency_package'] pipeline_args = {'lang': args['lang'], 'tokenize_pretokenized': True, 'package': {'pos': args['retag_package'], 'depparse': 'converter', 'constituency': constituency_package}, 'processors': 'tokenize, pos, constituency, depparse'} pipeline = stanza.Pipeline(**pipeline_args) input_doc = CoNLL.conll2doc(args['eval_file']) output_doc = pipeline(input_doc) print("Processed %d sentences" % len(output_doc.sentences)) # reload - the pipeline clobbered the gold values input_doc = CoNLL.conll2doc(args['eval_file']) scorer.score_named_dependencies(output_doc, input_doc) with tempfile.TemporaryDirectory() as tempdir: output_path = os.path.join(tempdir, "converted.conll") CoNLL.write_doc2conll(output_doc, output_path) _, _, score = scorer.score(output_path, args['eval_file']) print("Parser score:") 
print("{} {:.2f}".format(constituency_package, score*100)) def main(): parser = argparse.ArgumentParser() parser.add_argument('--lang', default='en', type=str, help='Language') parser.add_argument('--eval_file', default="extern_data/ud2/ud-treebanks-v2.13/UD_English-EWT/en_ewt-ud-test.conllu", help='Input file for data loader.') parser.add_argument('--constituency_package', default="ptb3-revised_electra-large", help='Which constituency parser to use for converting') retagging.add_retag_args(parser) args = parser.parse_args() args = vars(args) retagging.postprocess_args(args) score_converted_dependencies(args) if __name__ == '__main__': main() ================================================ FILE: stanza/models/constituency/state.py ================================================ from collections import namedtuple class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence', 'sentence_length', 'num_opens', 'word_position', 'score', 'broken'])): """ Represents a partially completed transition parse Includes stack/buffers for unused words, already executed transitions, and partially build constituents At training time, also keeps track of the gold data we are reparsing num_opens is useful for tracking 1) if the parser is in a stuck state where it is making infinite opens 2) if a close transition is impossible because there are no previous opens sentence_length tracks how long the sentence is so we abort if we go infinite non-stack information such as sentence_length and num_opens will be copied from the original_state if possible, with the exact arguments overriding the values in the original_state gold_tree: the original tree, if made from a gold tree. might be None gold_sequence: the original transition sequence, if available Note that at runtime, gold values will not be available word_position tracks where in the word queue we are. cheaper than manipulating the list itself. 
this can be handled differently from transitions and constituents as it is processed once at the start of parsing The word_queue should have both a start and an end word. Those can be None in the case of the endpoints if they are unused. """ def empty_word_queue(self): # the first element of each stack is a sentinel with no value # and no parent return self.word_position == self.sentence_length def empty_transitions(self): # the first element of each stack is a sentinel with no value # and no parent return self.transitions.parent is None def has_one_constituent(self): # a length of 1 represents no constituents return self.constituents.length == 2 @property def empty_constituents(self): return self.constituents.parent is None def num_constituents(self): return self.constituents.length - 1 @property def num_transitions(self): # -1 for the sentinel value return self.transitions.length - 1 def get_word(self, pos): # +1 to handle the initial sentinel value # (which you can actually get with pos=-1) return self.word_queue[pos+1] def finished(self, model): return self.empty_word_queue() and self.has_one_constituent() and model.get_top_constituent(self.constituents).label in model.root_labels def get_tree(self, model): return model.get_top_constituent(self.constituents) def all_transitions(self, model): # TODO: rewrite this to be nicer / faster? or just refactor? all_transitions = [] transitions = self.transitions while transitions.parent is not None: all_transitions.append(model.get_top_transition(transitions)) transitions = transitions.parent return list(reversed(all_transitions)) def all_constituents(self, model): # TODO: rewrite this to be nicer / faster? 
all_constituents = [] constituents = self.constituents while constituents.parent is not None: all_constituents.append(model.get_top_constituent(constituents)) constituents = constituents.parent return list(reversed(all_constituents)) def all_words(self, model): return [model.get_word(x) for x in self.word_queue] def to_string(self, model): return "State(\n buffer:%s\n transitions:%s\n constituents:%s\n word_position:%d num_opens:%d)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)), self.word_position, self.num_opens) def __str__(self): return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.word_queue), str(self.transitions), str(self.constituents)) class MultiState(namedtuple('MultiState', ['states', 'gold_tree', 'gold_sequence', 'score'])): def finished(self, ensemble): return self.states[0].finished(ensemble.models[0]) def get_tree(self, ensemble): return self.states[0].get_tree(ensemble.models[0]) @property def empty_constituents(self): return self.states[0].empty_constituents def num_constituents(self): return len(self.states[0].constituents) - 1 @property def num_transitions(self): # -1 for the sentinel value return len(self.states[0].transitions) - 1 @property def num_opens(self): return self.states[0].num_opens @property def sentence_length(self): return self.states[0].sentence_length def empty_word_queue(self): return self.states[0].empty_word_queue() def empty_transitions(self): return self.states[0].empty_transitions() @property def constituents(self): # warning! if there is information in the constituents such as # the embedding of the constituent, this will only contain the # first such embedding # the other models' constituent states won't be returned return self.states[0].constituents @property def transitions(self): # warning! 
if there is information in the transitions such as # the embedding of the transition, this will only contain the # first such embedding # the other models' transition states won't be returned return self.states[0].transitions ================================================ FILE: stanza/models/constituency/text_processing.py ================================================ import os import logging from stanza.models.common import utils from stanza.models.constituency.utils import retag_tags from stanza.models.constituency.trainer import Trainer from stanza.models.constituency.tree_reader import read_trees from stanza.utils.get_tqdm import get_tqdm logger = logging.getLogger('stanza') tqdm = get_tqdm() def read_tokenized_file(tokenized_file): """ Read sentences from a tokenized file, potentially replacing _ with space for languages such as VI """ with open(tokenized_file, encoding='utf-8') as fin: lines = fin.readlines() lines = [x.strip() for x in lines] lines = [x for x in lines if x] docs = [[word if all(x == '_' for x in word) else word.replace("_", " ") for word in sentence.split()] for sentence in lines] ids = [None] * len(docs) return docs, ids def read_xml_tree_file(tree_file): """ Read sentences from a file of the format unique to VLSP test sets in particular, it should be multiple blocks of (tree ...) """ with open(tree_file, encoding='utf-8') as fin: lines = fin.readlines() lines = [x.strip() for x in lines] lines = [x for x in lines if x] docs = [] ids = [] tree_id = None tree_text = [] for line in lines: if line.startswith(" 1: tree_id = tree_id[1] if tree_id.endswith(">"): tree_id = tree_id[:-1] tree_id = int(tree_id) else: tree_id = None elif line.startswith("= len(gold_sequence): raise AssertionError("Found a sequence of OpenConstituent at the end of a TOP_DOWN sequence!") if not isinstance(gold_sequence[shift_index], Shift): raise AssertionError("Expected to find a Shift after a sequence of OpenConstituent. 
There should not be a %s" % gold_sequence[shift_index]) #print("Input sequence: %s\nIndex %d\nGold %s Pred %s" % (gold_sequence, gold_index, gold_transition, pred_transition)) updated_sequence = gold_sequence while shift_index > gold_index: close_index = advance_past_constituents(updated_sequence, shift_index) if close_index is None: raise AssertionError("Did not find a corresponding Close for this Open") # cut out the corresponding open and close updated_sequence = updated_sequence[:shift_index-1] + updated_sequence[shift_index:close_index] + updated_sequence[close_index+1:] shift_index -= 1 #print(" %s" % updated_sequence) #print("Final updated sequence: %s" % updated_sequence) return updated_sequence def fix_nested_open_constituent(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): """ We were supposed to predict Open(X), then Open(Y), but predicted Open(Y) instead We treat this as a single recall error. We could even go crazy and turn it into a Unary, such as Open(Y), Open(X), Open(Y)... 
presumably that would be very confusing to the parser
    not to mention ambiguous as to where to close the new constituent

    Returns the gold sequence with the skipped outer Open and the element at
    close_index (expected to be its matching Close) removed, or None if this
    repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None

    assert len(gold_sequence) > gold_index + 1
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    # This replacement works if we skipped exactly one level
    if gold_sequence[gold_index+1].label != pred_transition.label:
        return None

    close_index = advance_past_constituents(gold_sequence, gold_index+1)
    assert close_index is not None
    # drop gold_sequence[gold_index] (the missed Open) and
    # gold_sequence[close_index] (presumably its Close), keeping everything between
    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + gold_sequence[close_index+1:]
    return updated_sequence


def fix_shift_open_immediate_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The biggest problem with this type of error is that the Close of
    the Open is ambiguous.  We could put it immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    One unambiguous case would be if the proper sequence was Shift - Close.
    Then it is unambiguous that the only possible repair is Open - Shift - Close - Close.

    Returns the repaired sequence, or None if the gold transition after
    the Shift is not a Close (the ambiguous case) or the error is of a
    different kind.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if not isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the ambiguous case
        return None

    # Open (predicted), Shift (gold), a new Close for the predicted Open,
    # then the rest of the gold sequence starting with its Close
    return gold_sequence[:gold_index] + [pred_transition, gold_transition, CloseConstituent()] + gold_sequence[gold_index+1:]


def fix_shift_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The biggest problem with this type of error is that the Close of
    the Open is ambiguous.  We could put it immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    In this fix, we are testing what happens if we treat this Open as a
    Unary transition: the new constituent contains only the Shifted word.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the unambiguous case, which should already be handled
        return None

    return gold_sequence[:gold_index] + [pred_transition, gold_transition, CloseConstituent()] + gold_sequence[gold_index+1:]


def fix_shift_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The biggest problem with this type of error is that the Close of
    the Open is ambiguous.  We could put it immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    In this fix, we put the corresponding Close for this Open at the end of the enclosing bracket.

    Returns the repaired sequence, or None if this repair does not apply.
""" if not isinstance(pred_transition, OpenConstituent): return None if not isinstance(gold_transition, Shift): return None assert len(gold_sequence) > gold_index + 1 if isinstance(gold_sequence[gold_index+1], CloseConstituent): # this is the unambiguous case, which should already be handled return None outer_close_index = advance_past_constituents(gold_sequence, gold_index) return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:outer_close_index] + [CloseConstituent()] + gold_sequence[outer_close_index:] def fix_shift_open_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): if not isinstance(pred_transition, OpenConstituent): return None if not isinstance(gold_transition, Shift): return None assert len(gold_sequence) > gold_index + 1 if isinstance(gold_sequence[gold_index+1], CloseConstituent): # this is the unambiguous case, which should already be handled return None # at this point: have Opened a constituent which we don't want # need to figure out where to Close it # could close it after the shift or after any given block candidates = [] current_index = gold_index while not isinstance(gold_sequence[current_index], CloseConstituent): if isinstance(gold_sequence[current_index], Shift): end_index = current_index else: end_index = find_constituent_end(gold_sequence, current_index) candidates.append((gold_sequence[:gold_index], [pred_transition], gold_sequence[gold_index:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:])) current_index = end_index + 1 scores, best_idx, best_candidate = score_candidates_single_block(model, state, candidates, candidate_idx=3) if best_idx == len(candidates) - 1: best_idx = -1 repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.name, value="%d.%d" % (RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.value, best_idx), is_correct=False) return repair_type, best_candidate def fix_close_shift_ambiguous_immediate(gold_transition, 
pred_transition, gold_sequence, gold_index, root_labels, model, state): """ Instead of a Close, we predicted a Shift. This time, we immediately close no matter what comes after the next Shift. An alternate strategy would be to Close at the closing of the outer constituent. """ if not isinstance(pred_transition, Shift): return None if not isinstance(gold_transition, CloseConstituent): return None num_closes = 0 while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent): num_closes += 1 if not isinstance(gold_sequence[gold_index + num_closes], Shift): # TODO: we should be able to handle this case too (an Open) # however, it will be rare once the parser gets going and it # would cause a lot of errors, anyway return None if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent): # this one should just have been satisfied in the non-ambiguous version return None updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+1:] return updated_sequence def fix_close_shift_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): """ Instead of a Close, we predicted a Shift. This time, we close at the end of the outer bracket no matter what comes after the next Shift. An alternate strategy would be to Close as soon as possible after the Shift. 
""" if not isinstance(pred_transition, Shift): return None if not isinstance(gold_transition, CloseConstituent): return None num_closes = 0 while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent): num_closes += 1 if not isinstance(gold_sequence[gold_index + num_closes], Shift): # TODO: we should be able to handle this case too (an Open) # however, it will be rare once the parser gets going and it # would cause a lot of errors, anyway return None if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent): # this one should just have been satisfied in the non-ambiguous version return None # outer_close_index is now where the constituent which the broken constituent(s) reside inside gets closed outer_close_index = advance_past_constituents(gold_sequence, gold_index + num_closes) updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+num_closes:outer_close_index] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[outer_close_index:] return updated_sequence def fix_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, count_opens=False): """ We were supposed to Close, but instead did a Shift In most cases, this will be ambiguous. There is now a constituent which has been missed, no matter what we do, and we are on the hook for eventually closing this constituent, creating a precision error as well. The ambiguity arises because there will be multiple places where the Close could occur if there are more constituents created between now and when the outer constituent is Closed. 
The non-ambiguous case is if the proper sequence was Close - Shift - Close similar cases are also non-ambiguous, such as Close - Close - Shift - Close for that matter, so is the following, although the Opens will be lost Close - Open - Shift - Close - Close count_opens is an option to make it easy to count with or without Open as different oracle fixes """ if not isinstance(pred_transition, Shift): return None if not isinstance(gold_transition, CloseConstituent): return None num_closes = 0 while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent): num_closes += 1 # We may allow unary transitions here # the opens will be lost in the repaired sequence num_opens = 0 if count_opens: while isinstance(gold_sequence[gold_index + num_closes + num_opens], OpenConstituent): num_opens += 1 if not isinstance(gold_sequence[gold_index + num_closes + num_opens], Shift): if count_opens: raise AssertionError("Should have found a Shift after a sequence of Opens or a Close with no Open. Started counting at %d in sequence %s" % (gold_index, gold_sequence)) return None if not isinstance(gold_sequence[gold_index + num_closes + num_opens + 1], CloseConstituent): return None for idx in range(num_opens): if not isinstance(gold_sequence[gold_index + num_closes + num_opens + idx + 1], CloseConstituent): return None # Now we know it is Close x num_closes, Shift, Close # Since we have erroneously predicted a Shift now, the best we can # do is to follow that, then add num_closes Closes updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+num_opens*2+1:] return updated_sequence def fix_close_shift_with_opens(*args, **kwargs): return fix_close_shift(*args, **kwargs, count_opens=True) def fix_close_next_correct_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): """ We were supposed to Close, but instead predicted Shift when the next 
transition is Shift

    This differs from the previous Close-Shift in that this case does
    not have an unambiguous place to put the Close.  Instead, we let
    the model predict where to put the Close

    Note that this can also work for Close-Open with the next Open correct

    Not covered (yet?) is multiple Close in a row

    Returns (RepairEnum, repaired sequence), or None if this repair does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, (Shift, OpenConstituent)):
        return None
    if gold_sequence[gold_index+1] != pred_transition:
        return None

    # one candidate per possible Close position: after the next block
    # (a Shift, or a whole constituent ending at find_constituent_end),
    # after the block following it, ... up to the next Close in the gold sequence
    candidates = []
    current_index = gold_index + 1
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            end_index = current_index
        else:
            end_index = find_constituent_end(gold_sequence, current_index)
        # blocks: prefix, gold items kept inside the unclosed constituent, [moved Close], remainder
        candidates.append((gold_sequence[:gold_index], gold_sequence[gold_index+1:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1

    # NOTE(review): these candidates have 4 blocks, so candidate_idx=3 refers
    # to the trailing gold_sequence[end_index+1:] block, whereas the 5-block
    # candidates in fix_shift_open_ambiguous_predicted use candidate_idx=3
    # for their [CloseConstituent()] block.  Confirm this is intended.
    scores, best_idx, best_candidate = score_candidates_single_block(model, state, candidates, candidate_idx=3)
    # the last possible position is recorded as -1 in the repair type's value
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate


def fix_close_open_correct_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    We were supposed to Close, but instead did an Open

    In general this is ambiguous (like close/shift), as we need to know
    when to close the incorrect constituent

    A case that is not ambiguous is when exactly one constituent was
    supposed to come after the Close and it matches the Open we just
    created.  In that case, we treat that constituent as if it were
    part of the non-Closed constituent.  For example,
    "ate (NP spaghetti) (PP with a fork)" ->
    "ate (NP spaghetti (PP with a fork))"
    (delicious)

    There is also an option to not check for the Close after the first
    constituent, in which case any number of constituents could have
    been predicted.  This represents a solution of the ambiguous form
    of the Close/Open transition where the Close could occur in
    multiple places later in the sequence.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if gold_sequence[gold_index+1] != pred_transition:
        return None

    close_index = find_constituent_end(gold_sequence, gold_index+1)
    if check_close and not isinstance(gold_sequence[close_index+1], CloseConstituent):
        return None

    # at this point, we know we can put the Close at the end of the
    # Open which was accidentally added
    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index+1] + [gold_transition] + gold_sequence[close_index+1:]
    return updated_sequence


def fix_close_open_correct_open_ambiguous_immediate(*args, **kwargs):
    """
    fix_close_open_correct_open without requiring a Close right after the matched constituent
    """
    return fix_close_open_correct_open(*args, **kwargs, check_close=False)


def fix_close_open_correct_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    We were supposed to Close, but instead did an Open in an ambiguous context.
Here we resolve it later in the tree """ if not isinstance(pred_transition, OpenConstituent): return None if not isinstance(gold_transition, CloseConstituent): return None if gold_sequence[gold_index+1] != pred_transition: return None # this will be the index of the Close for the surrounding constituent close_index = advance_past_constituents(gold_sequence, gold_index+1) updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + [gold_transition] + gold_sequence[close_index:] return updated_sequence def fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): """ If there is an Open/Open error which is not covered by the unambiguous single recall error, we try fixing it as a Unary """ if not isinstance(pred_transition, OpenConstituent): return None if not isinstance(gold_transition, OpenConstituent): return None if pred_transition == gold_transition: return None if gold_sequence[gold_index+1] == pred_transition: # This case is covered by the nested open repair return None close_index = find_constituent_end(gold_sequence, gold_index) assert close_index is not None assert isinstance(gold_sequence[close_index], CloseConstituent) updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + [CloseConstituent()] + gold_sequence[close_index:] return updated_sequence def fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): """ If there is an Open/Open error which is not covered by the unambiguous single recall error, we try fixing it by putting the close at the end of the outer constituent """ if not isinstance(pred_transition, OpenConstituent): return None if not isinstance(gold_transition, OpenConstituent): return None if pred_transition == gold_transition: return None if gold_sequence[gold_index+1] == pred_transition: # This case is covered by the nested open repair 
return None close_index = advance_past_constituents(gold_sequence, gold_index) updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + [CloseConstituent()] + gold_sequence[close_index:] return updated_sequence def fix_open_open_ambiguous_random(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): """ If there is an Open/Open error which is not covered by the unambiguous single recall error, we try fixing it by putting the close at the end of the outer constituent """ if not isinstance(pred_transition, OpenConstituent): return None if not isinstance(gold_transition, OpenConstituent): return None if pred_transition == gold_transition: return None if gold_sequence[gold_index+1] == pred_transition: # This case is covered by the nested open repair return None if random.random() < 0.5: return fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels) else: return fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels) def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): if not isinstance(gold_transition, Shift): return None if not isinstance(pred_transition, OpenConstituent): return None return RepairType.OTHER_SHIFT_OPEN, None def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): if not isinstance(gold_transition, CloseConstituent): return None if not isinstance(pred_transition, Shift): return None return RepairType.OTHER_CLOSE_SHIFT, None def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state): if not isinstance(gold_transition, CloseConstituent): return None if not isinstance(pred_transition, OpenConstituent): return None return RepairType.OTHER_CLOSE_OPEN, None def report_open_open(gold_transition, pred_transition, 
gold_sequence, gold_index, root_labels, model, state):
    """
    Report, without repairing, a gold Open which was predicted as an Open

    Returns (RepairType.OTHER_OPEN_OPEN, None), or None for other error types
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    return RepairType.OTHER_OPEN_OPEN, None


class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    A test of the top-down oracle with no charlm or transformer
    (eg, word vectors only) on EN PTB3 goes as follows.
    3x training rounds, best training parameters as of Jan. 2024

    unambiguous transitions only:
        oracle scheme          dev      test
        no oracle              0.9230   0.9194
        +shift/close           0.9224   0.9180
        +open/close            0.9225   0.9193
        +open/shift (one)      0.9245   0.9207
        +open/shift (mult)     0.9243   0.9211
        +open/open nested      0.9258   0.9213
        +shift/open            0.9266   0.9229
        +close/shift (only)    0.9270   0.9230
        +close/shift w/ opens  0.9262   0.9221
        +close/open one con    0.9273   0.9230

    Potential solutions for various ambiguous transitions:

    close/open
      can close immediately after the corresponding constituent or after any number of constituents

    close/shift
      can close immediately
      can close anywhere up to the next close
      any number of missed Opens are treated as recall errors

    open/open
      could treat as unary
      could close at any number of positions after the next structures, up to the outer open's closing

    shift/open ambiguity resolutions:
      treat as unary
      treat as wrapper around the next full constituent to build
      treat as wrapper around everything to build until the next constituent

    testing one at a time in addition to the full set of unambiguous corrections:
        +close/open immediate     0.9259   0.9225
        +close/open later         0.9258   0.9257
        +close/shift immediate    0.9261   0.9219
        +close/shift later        0.9270   0.9230
        +open/open later          0.9269   0.9239
        +open/open unary          0.9275   0.9246
        +shift/open later         0.9263   0.9253
        +shift/open unary         0.9264   0.9243

    so there is some evidence that open/open or shift/open would be beneficial

    Training by randomly choosing between the open/open, 50/50
        +open/open random         0.9257   0.9235
    so that didn't work great compared to the individual transitions

    Testing deterministic resolutions of the ambiguous transitions
    vs predicting the appropriate transition to use:
      ambiguous:  SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR,CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR,CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR
      predicted:  SHIFT_OPEN_AMBIGUOUS_PREDICTED,CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED

        EN ambiguous (no charlm or transformer)   0.9268   0.9231
        EN predicted                              0.9270   0.9257
        EN none of the above                      0.9268   0.9229

        ZH ambiguous                              0.9137   0.9127
        ZH predicted                              0.9148   0.9141
        ZH none of the above                      0.9141   0.9143

        DE ambiguous                              0.9579   0.9408
        DE predicted                              0.9575   0.9406
        DE none of the above                      0.9581   0.9411

        ID ambiguous                              0.8889   0.8794
        ID predicted                              0.8911   0.8801
        ID none of the above                      0.8913   0.8822

        IT ambiguous                              0.8404   0.8380
        IT predicted                              0.8397   0.8398
        IT none of the above                      0.8400   0.8409

        VI ambiguous                              0.8290   0.7676
        VI predicted                              0.8287   0.7682
        VI none of the above                      0.8292   0.7691
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error

        fn: the repair (or report) function for this kind of error; None for CORRECT and UNKNOWN
        correct: exposed through is_correct; only CORRECT sets it to True
        debug: set to True only by the OTHER_* report-only members below
          (presumably consumed by DynamicOracle - confirm)
        """
        # each member's value is one more than the number of members
        # already defined, so the numbering follows declaration order
        # and reordering the members below changes their values
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn
        obj.correct = correct
        obj.debug = debug
        return obj

    @property
    def is_correct(self):
        return self.correct

    # The parser chose to close a bracket instead of shift something
    # into the bracket
    # This causes both a precision and a recall error as there is now
    # an incorrect bracket and a missing correct bracket
    # Any bracket creation here would cause more wrong brackets, though
    SHIFT_CLOSE_ERROR = (fix_shift_close,)

    OPEN_CLOSE_ERROR = (fix_open_close,)

    # open followed by shift was instead predicted to be shift
    ONE_OPEN_SHIFT_ERROR = (fix_one_open_shift,)

    # open followed by shift was instead predicted to be shift
    MULTIPLE_OPEN_SHIFT_ERROR = (fix_multiple_open_shift,)

    # should have done Open(X), Open(Y)
    # instead just did Open(Y)
    NESTED_OPEN_OPEN_ERROR = (fix_nested_open_constituent,)

    SHIFT_OPEN_ERROR = (fix_shift_open_immediate_close,)

    CLOSE_SHIFT_ERROR = (fix_close_shift,)

    CLOSE_SHIFT_WITH_OPENS_ERROR = (fix_close_shift_with_opens,)

    CLOSE_OPEN_ONE_CON_ERROR = (fix_close_open_correct_open,)

    # not an error: no repair function, correct=True
    CORRECT = (None, True)

    # no repair function
    UNKNOWN = None

    CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_open_correct_open_ambiguous_immediate,)

    CLOSE_OPEN_AMBIGUOUS_LATER_ERROR = (fix_close_open_correct_open_ambiguous_later,)

    CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_shift_ambiguous_immediate,)

    CLOSE_SHIFT_AMBIGUOUS_LATER_ERROR = (fix_close_shift_ambiguous_later,)

    # can potentially fix either close/shift or close/open
    # as long as the gold transition after the close
    # was the same as the transition we just predicted
    CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED = (fix_close_next_correct_predicted,)

    OPEN_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_open_open_ambiguous_unary,)

    OPEN_OPEN_AMBIGUOUS_LATER_ERROR = (fix_open_open_ambiguous_later,)

    OPEN_OPEN_AMBIGUOUS_RANDOM_ERROR = (fix_open_open_ambiguous_random,)

    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_shift_open_ambiguous_unary,)

    SHIFT_OPEN_AMBIGUOUS_LATER_ERROR = (fix_shift_open_ambiguous_later,)

    SHIFT_OPEN_AMBIGUOUS_PREDICTED = (fix_shift_open_ambiguous_predicted,)

    # report-only entries: their functions return (RepairType, None)
    # without a repaired sequence; correct=False, debug=True
    OTHER_SHIFT_OPEN = (report_shift_open, False, True)

    OTHER_CLOSE_SHIFT = (report_close_shift, False, True)

    OTHER_CLOSE_OPEN = (report_close_open, False, True)

    OTHER_OPEN_OPEN = (report_open_open, False, True)


class TopDownOracle(DynamicOracle):
    """
    Dynamic oracle for the top-down transition scheme, using the repairs enumerated in RepairType
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)

================================================
FILE: stanza/models/constituency/trainer.py
================================================
"""
This file includes a variety of methods needed to train new
constituency parsers.  It also includes a method to load an
already-trained parser.

See the `train` method for the code block which starts from
  raw treebank and returns a new parser.
`evaluate` reads a treebank and gives a score for those trees. """ import copy import logging import os import torch from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain, NoTransformerFoundationCache from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper, pop_peft_args from stanza.models.constituency.base_trainer import BaseTrainer, ModelType from stanza.models.constituency.lstm_model import LSTMModel, SentenceBoundary, StackHistory, ConstituencyComposition from stanza.models.constituency.parse_transitions import Transition, TransitionScheme from stanza.models.constituency.utils import build_optimizer, build_scheduler # TODO: could put find_wordvec_pretrain, choose_charlm, etc in a more central place if it becomes widely used from stanza.utils.training.common import find_wordvec_pretrain, choose_charlm, find_charlm_file from stanza.resources.default_packages import default_charlms, default_pretrains logger = logging.getLogger('stanza') tlogger = logging.getLogger('stanza.constituency.trainer') class Trainer(BaseTrainer): """ Stores a constituency model and its optimizer Not inheriting from common/trainer.py because there's no concept of change_lr (yet?) 
""" def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False): super().__init__(model, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer) def save(self, filename, save_optimizer=True): """ Save the model (and by default the optimizer) to the given path """ super().save(filename, save_optimizer) def get_peft_params(self): # Hide import so that peft dependency is optional if self.model.args.get('use_peft', False): from peft import get_peft_model_state_dict return get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name) return None @property def model_type(self): return ModelType.LSTM @staticmethod def find_and_load_pretrain(saved_args, foundation_cache): if 'wordvec_pretrain_file' not in saved_args: return None if os.path.exists(saved_args['wordvec_pretrain_file']): return load_pretrain(saved_args['wordvec_pretrain_file'], foundation_cache) logger.info("Unable to find pretrain in %s Will try to load from the default resources instead", saved_args['wordvec_pretrain_file']) language = saved_args['lang'] wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains) return load_pretrain(wordvec_pretrain, foundation_cache) @staticmethod def find_and_load_charlm(charlm_file, direction, saved_args, foundation_cache): try: return load_charlm(charlm_file, foundation_cache) except FileNotFoundError as e: logger.info("Unable to load charlm from %s Will try to load from the default resources instead", charlm_file) language = saved_args['lang'] dataset = saved_args['shorthand'].split("_")[1] charlm = choose_charlm(language, dataset, "default", default_charlms, {}) charlm_file = find_charlm_file(direction, language, charlm) return load_charlm(charlm_file, foundation_cache) def log_num_words_known(self, words): tlogger.info("Number of words in the training set found in the embedding: %d out of %d", 
self.model.num_words_known(words), len(words)) @staticmethod def load_optimizer(model, checkpoint, first_optimizer, filename): optimizer = build_optimizer(model.args, model, first_optimizer) if checkpoint.get('optimizer_state_dict', None) is not None: try: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) except ValueError as e: raise ValueError("Failed to load optimizer from %s" % filename) from e else: logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer") return optimizer @staticmethod def load_scheduler(model, optimizer, checkpoint, first_optimizer): scheduler = build_scheduler(model.args, optimizer, first_optimizer=first_optimizer) if 'scheduler_state_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_state_dict']) return scheduler @staticmethod def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None): """ Build a new model just from the saved params and some extra args Refactoring allows other processors to include a constituency parser as a module """ saved_args = dict(params['config']) if isinstance(saved_args['sentence_boundary_vectors'], str): saved_args['sentence_boundary_vectors'] = SentenceBoundary[saved_args['sentence_boundary_vectors']] if isinstance(saved_args['constituency_composition'], str): saved_args['constituency_composition'] = ConstituencyComposition[saved_args['constituency_composition']] if isinstance(saved_args['transition_stack'], str): saved_args['transition_stack'] = StackHistory[saved_args['transition_stack']] if isinstance(saved_args['constituent_stack'], str): saved_args['constituent_stack'] = StackHistory[saved_args['constituent_stack']] if isinstance(saved_args['transition_scheme'], str): saved_args['transition_scheme'] = TransitionScheme[saved_args['transition_scheme']] # some parameters which change the structure of a model have # to be ignored, or the model will not function when it is # reloaded from disk if 
args is None: args = {} update_args = copy.deepcopy(args) pop_peft_args(update_args) update_args.pop("bert_hidden_layers", None) update_args.pop("bert_model", None) update_args.pop("constituency_composition", None) update_args.pop("constituent_stack", None) update_args.pop("num_tree_lstm_layers", None) update_args.pop("transition_scheme", None) update_args.pop("transition_stack", None) update_args.pop("maxout_k", None) # if the pretrain or charlms are not specified, don't override the values in the model # (if any), since the model won't even work without loading the same charlm if 'wordvec_pretrain_file' in update_args and update_args['wordvec_pretrain_file'] is None: update_args.pop('wordvec_pretrain_file') if 'charlm_forward_file' in update_args and update_args['charlm_forward_file'] is None: update_args.pop('charlm_forward_file') if 'charlm_backward_file' in update_args and update_args['charlm_backward_file'] is None: update_args.pop('charlm_backward_file') # we don't pop bert_finetune, with the theory being that if # the saved model has bert_finetune==True we can load the bert # weights but then not further finetune if bert_finetune==False saved_args.update(update_args) # TODO: not needed if we rebuild the models if saved_args.get("bert_finetune", None) is None: saved_args["bert_finetune"] = False if saved_args.get("stage1_bert_finetune", None) is None: saved_args["stage1_bert_finetune"] = False model_type = params['model_type'] if model_type == 'LSTM': pt = Trainer.find_and_load_pretrain(saved_args, foundation_cache) if saved_args.get('use_peft', False): # if loading a peft model, we first load the base transformer # then we load the weights using the saved weights in the file if peft_name is None: bert_model, bert_tokenizer, peft_name = load_bert_with_peft(saved_args.get('bert_model', None), "constituency", foundation_cache) else: bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache) bert_model = 
load_peft_wrapper(bert_model, peft_params, saved_args, logger, peft_name) bert_saved = True elif saved_args['bert_finetune'] or saved_args['stage1_bert_finetune'] or any(x.startswith("bert_model.") for x in params['model'].keys()): # if bert_finetune is True, don't use the cached model! # otherwise, other uses of the cached model will be ruined bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None)) bert_saved = True else: bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache) bert_saved = False forward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_forward_file"], "forward", saved_args, foundation_cache) backward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_backward_file"], "backward", saved_args, foundation_cache) # TODO: the isinstance will be unnecessary after 1.10.0 transitions = params['transitions'] if all(isinstance(x, str) for x in transitions): transitions = [Transition.from_repr(x) for x in transitions] model = LSTMModel(pretrain=pt, forward_charlm=forward_charlm, backward_charlm=backward_charlm, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=bert_saved, peft_name=peft_name, transitions=transitions, constituents=params['constituents'], tags=params['tags'], words=params['words'], rare_words=set(params['rare_words']), root_labels=params['root_labels'], constituent_opens=params['constituent_opens'], unary_limit=params['unary_limit'], args=saved_args) else: raise ValueError("Unknown model type {}".format(model_type)) model.load_state_dict(params['model'], strict=False) # model will stay on CPU if device==None # can be moved elsewhere later, of course model = model.to(args.get('device', None)) return model @staticmethod def build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file): # TODO: turn finetune, relearn_structure, multistage into an enum? 
# finetune just means continue learning, so checkpoint is sufficient # relearn_structure is essentially a one stage multistage # multistage with a checkpoint will have the proper optimizer for that epoch # and no special learning mode means we are training a new model and should continue if args['checkpoint'] and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']): tlogger.info("Found checkpoint to continue training: %s", args['checkpoint_save_name']) trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache) return trainer # in the 'finetune' case, this will preload the models into foundation_cache, # so the effort is not wasted pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file']) forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file']) backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file']) if args['finetune']: tlogger.info("Loading model to finetune: %s", model_load_file) trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=NoTransformerFoundationCache(foundation_cache)) # a new finetuning will start with a new epochs_trained count trainer.epochs_trained = 0 return trainer if args['relearn_structure']: tlogger.info("Loading model to continue training with new structure from %s", model_load_file) temp_args = dict(args) # remove the pattn & lattn layers unless the saved model had them temp_args.pop('pattn_num_layers', None) temp_args.pop('lattn_d_proj', None) trainer = Trainer.load(model_load_file, temp_args, load_optimizer=False, foundation_cache=NoTransformerFoundationCache(foundation_cache)) # using the model's current values works for if the new # dataset is the same or smaller # TODO: handle a larger dataset as well model = LSTMModel(pt, forward_charlm, backward_charlm, trainer.model.bert_model, trainer.model.bert_tokenizer, trainer.model.force_bert_saved, trainer.model.peft_name, 
trainer.model.transitions, trainer.model.constituents, trainer.model.tags, trainer.model.delta_words, trainer.model.rare_words, trainer.model.root_labels, trainer.model.constituent_opens, trainer.model.unary_limit(), args) model = model.to(args['device']) model.copy_with_new_structure(trainer.model) optimizer = build_optimizer(args, model, False) scheduler = build_scheduler(args, optimizer) trainer = Trainer(model, optimizer, scheduler) return trainer if args['multistage']: # run adadelta over the model for half the time with no pattn or lattn # training then switches to a different optimizer for the rest # this works surprisingly well tlogger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2) temp_args = dict(args) # remove the attention layers for the temporary model temp_args['pattn_num_layers'] = 0 temp_args['lattn_d_proj'] = 0 args = temp_args peft_name = None if args['use_peft']: peft_name = "constituency" bert_model, bert_tokenizer = load_bert(args['bert_model']) bert_model = build_peft_wrapper(bert_model, temp_args, tlogger, adapter_name=peft_name) elif args['bert_finetune'] or args['stage1_bert_finetune']: bert_model, bert_tokenizer = load_bert(args['bert_model']) else: bert_model, bert_tokenizer = load_bert(args['bert_model'], foundation_cache) model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, False, peft_name, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args) model = model.to(args['device']) optimizer = build_optimizer(args, model, build_simple_adadelta=args['multistage']) scheduler = build_scheduler(args, optimizer, first_optimizer=args['multistage']) trainer = Trainer(model, optimizer, scheduler, first_optimizer=args['multistage']) return trainer ================================================ FILE: stanza/models/constituency/transformer_tree_stack.py ================================================ """ Based 
on Transition-based Parsing with Stack-Transformers Ramon Fernandez Astudillo, Miguel Ballesteros, Tahira Naseem, Austin Blodget, and Radu Florian https://aclanthology.org/2020.findings-emnlp.89.pdf """ from collections import namedtuple import torch import torch.nn as nn from stanza.models.constituency.positional_encoding import SinusoidalEncoding from stanza.models.constituency.tree_stack import TreeStack Node = namedtuple("Node", ['value', 'key_stack', 'value_stack', 'output']) class TransformerTreeStack(nn.Module): def __init__(self, input_size, output_size, input_dropout, length_limit=None, use_position=False, num_heads=1): """ Builds the internal matrices and start parameter TODO: currently only one attention head, implement MHA """ super().__init__() self.input_size = input_size self.output_size = output_size self.inv_sqrt_output_size = 1 / output_size ** 0.5 self.num_heads = num_heads self.w_query = nn.Linear(input_size, output_size) self.w_key = nn.Linear(input_size, output_size) self.w_value = nn.Linear(input_size, output_size) self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True))) if isinstance(input_dropout, nn.Module): self.input_dropout = input_dropout else: self.input_dropout = nn.Dropout(input_dropout) if length_limit is not None and length_limit < 1: raise ValueError("length_limit < 1 makes no sense") self.length_limit = length_limit self.use_position = use_position if use_position: self.position_encoding = SinusoidalEncoding(model_dim=self.input_size, max_len=512) def attention(self, key, query, value, mask=None): """ Calculate attention for the given key, query value Where B is the number of items stacked together, N is the length: The key should be BxNxD The query is BxD The value is BxNxD If mask is specified, it should be BxN of True/False values, where True means that location is masked out Reshapes and reorders are used to handle num_heads Return will be softmax(query x key^T) * 
value of size BxD """ B = key.shape[0] N = key.shape[1] D = key.shape[2] H = self.num_heads # query is now BxDx1 query = query.unsqueeze(2) # BxHxD/Hx1 query = query.reshape((B, H, -1, 1)) # BxNxHxD/H key = key.reshape((B, N, H, -1)) # BxHxNxD/H key = key.transpose(1, 2) # BxNxHxD/H value = value.reshape((B, N, H, -1)) # BxHxNxD/H value = value.transpose(1, 2) # BxHxNxD/H x BxHxD/Hx1 # result shape: BxHxN attn = torch.matmul(key, query).squeeze(3) * self.inv_sqrt_output_size if mask is not None: # mask goes from BxN -> Bx1xN mask = mask.unsqueeze(1) mask = mask.expand(-1, H, -1) attn.masked_fill_(mask, float('-inf')) # attn shape will now be BxHx1xN attn = torch.softmax(attn, dim=2).unsqueeze(2) # BxHx1xN x BxHxNxD/H -> BxHxD/H output = torch.matmul(attn, value).squeeze(2) output = output.reshape(B, -1) return output def initial_state(self, initial_value=None): """ Return an initial state based on a single layer of attention Running attention might be overkill, but it is the simplest way to put the Linears and start_embedding in the computation graph """ start = self.start_embedding if self.use_position: position = self.position_encoding([0]).squeeze(0) start = start + position # N=1 # shape: 1xD key = self.w_key(start).unsqueeze(0) # shape: D query = self.w_query(start) # shape: 1xD value = self.w_value(start).unsqueeze(0) # unsqueeze to make it look like we are part of a batch of size 1 output = self.attention(key.unsqueeze(0), query.unsqueeze(0), value.unsqueeze(0)).squeeze(0) return TreeStack(value=Node(initial_value, key, value, output), parent=None, length=1) def push_states(self, stacks, values, inputs): """ Push new inputs to the stacks and rerun attention on them Where B is the number of items stacked together, I is input_size stacks: B TreeStacks such as produced by initial_state and/or push_states values: the new items to push on the stacks such as tree nodes or anything inputs: BxI for the new input items Runs attention starting from the existing keys & 
values """ device = self.w_key.weight.device batch_len = len(stacks) # B positions = [x.value.key_stack.shape[0] for x in stacks] max_len = max(positions) # N if self.use_position: position_encodings = self.position_encoding(positions) inputs = inputs + position_encodings inputs = self.input_dropout(inputs) if len(inputs.shape) == 3: if inputs.shape[0] == 1: inputs = inputs.squeeze(0) else: raise ValueError("Expected the inputs to be of shape 1xBxI, got {}".format(inputs.shape)) new_keys = self.w_key(inputs) key_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device) key_stack[:, -1, :] = new_keys for stack_idx, stack in enumerate(stacks): key_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.key_stack new_values = self.w_value(inputs) value_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device) value_stack[:, -1, :] = new_values for stack_idx, stack in enumerate(stacks): value_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.value_stack query = self.w_query(inputs) mask = torch.zeros(batch_len, max_len+1, device=device, dtype=torch.bool) for stack_idx, stack in enumerate(stacks): if len(stack) < max_len: masked = max_len - positions[stack_idx] mask[stack_idx, :masked] = True batched_output = self.attention(key_stack, query, value_stack, mask) new_stacks = [] for stack_idx, (stack, node_value, new_key, new_value, output) in enumerate(zip(stacks, values, key_stack, value_stack, batched_output)): # max_len-len(stack) so that we ignore the padding at the start of shorter stacks new_key_stack = new_key[max_len-positions[stack_idx]:, :] new_value_stack = new_value[max_len-positions[stack_idx]:, :] if self.length_limit is not None and new_key_stack.shape[0] > self.length_limit + 1: new_key_stack = torch.cat([new_key_stack[:1, :], new_key_stack[2:, :]], axis=0) new_value_stack = torch.cat([new_value_stack[:1, :], new_value_stack[2:, :]], axis=0) new_stacks.append(stack.push(value=Node(node_value, 
new_key_stack, new_value_stack, output))) return new_stacks def output(self, stack): """ Return the last layer of the lstm_hx as the output from a stack Refactored so that alternate structures have an easy way of getting the output """ return stack.value.output ================================================ FILE: stanza/models/constituency/transition_sequence.py ================================================ """ Build a transition sequence from parse trees. Supports multiple transition schemes - TOP_DOWN and variants, IN_ORDER """ import logging from stanza.models.common import utils from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize from stanza.models.constituency.tree_reader import read_trees from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza.constituency.trainer') def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): """ For tree (X A B C D), yield Open(X) A B C D Close The details are in how to treat unary transitions Three possibilities handled by this method: TOP_DOWN_UNARY: (Y (X ...)) -> Open(X) ... Close Unary(Y) TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close TOP_DOWN: (Y (X ...)) -> Open(Y) Open(X) ... 
Close Close """ if tree.is_preterminal(): yield Shift() return if tree.is_leaf(): return if transition_scheme is TransitionScheme.TOP_DOWN_UNARY: if len(tree.children) == 1: labels = [] while not tree.is_preterminal() and len(tree.children) == 1: labels.append(tree.label) tree = tree.children[0] for transition in yield_top_down_sequence(tree, transition_scheme): yield transition yield CompoundUnary(*labels) return if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND: labels = [tree.label] while len(tree.children) == 1 and not tree.children[0].is_preterminal(): tree = tree.children[0] labels.append(tree.label) yield OpenConstituent(*labels) else: yield OpenConstituent(tree.label) for child in tree.children: for transition in yield_top_down_sequence(child, transition_scheme): yield transition yield CloseConstituent() def yield_in_order_sequence(tree): """ For tree (X A B C D), yield A Open(X) B C D Close """ if tree.is_preterminal(): yield Shift() return if tree.is_leaf(): return for transition in yield_in_order_sequence(tree.children[0]): yield transition yield OpenConstituent(tree.label) for child in tree.children[1:]: for transition in yield_in_order_sequence(child): yield transition yield CloseConstituent() def yield_in_order_compound_sequence(tree, transition_scheme): def helper(tree): if tree.is_leaf(): return labels = [] while len(tree.children) == 1 and not tree.is_preterminal(): labels.append(tree.label) tree = tree.children[0] if tree.is_preterminal(): yield Shift() if len(labels) > 0: yield CompoundUnary(*labels) return for transition in helper(tree.children[0]): yield transition if transition_scheme is TransitionScheme.IN_ORDER_UNARY: yield OpenConstituent(tree.label) else: labels.append(tree.label) yield OpenConstituent(*labels) for child in tree.children[1:]: for transition in helper(child): yield transition yield CloseConstituent() if transition_scheme is TransitionScheme.IN_ORDER_UNARY and len(labels) > 0: yield CompoundUnary(*labels) if 
len(tree.children) == 0: raise ValueError("Cannot build {} on an empty tree".format(transition_scheme)) if len(tree.children) != 1: raise ValueError("Cannot build {} with a tree that has two top level nodes: {}".format(transition_scheme, tree)) for t in helper(tree.children[0]): yield t yield Finalize(tree.label) def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): """ Turn a single tree into a list of transitions based on the TransitionScheme """ if transition_scheme is TransitionScheme.IN_ORDER: return list(yield_in_order_sequence(tree)) elif (transition_scheme is TransitionScheme.IN_ORDER_COMPOUND or transition_scheme is TransitionScheme.IN_ORDER_UNARY): return list(yield_in_order_compound_sequence(tree, transition_scheme)) else: return list(yield_top_down_sequence(tree, transition_scheme)) def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, reverse=False): """ Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme """ if reverse: return [build_sequence(tree.reverse(), transition_scheme) for tree in trees] else: return [build_sequence(tree, transition_scheme) for tree in trees] def all_transitions(transition_lists): """ Given a list of transition lists, combine them all into a list of unique transitions. 
""" transitions = set() for trans_list in transition_lists: transitions.update(trans_list) return sorted(transitions) def convert_trees_to_sequences(trees, treebank_name, transition_scheme, reverse=False): """ Wrap both build_treebank and all_transitions, possibly with a tqdm Converts trees to a list of sequences, then returns the list of known transitions """ if len(trees) == 0: return [], [] logger.info("Building %s transition sequences", treebank_name) if logger.getEffectiveLevel() <= logging.INFO: trees = tqdm(trees) sequences = build_treebank(trees, transition_scheme, reverse) transitions = all_transitions(sequences) return sequences, transitions def main(): """ Convert a sample tree and print its transitions """ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" #text = "(WP Who)" tree = read_trees(text)[0] print(tree) transitions = build_sequence(tree) print(transitions) if __name__ == '__main__': main() ================================================ FILE: stanza/models/constituency/tree_embedding.py ================================================ """ A module to use a Constituency Parser to make an embedding for a tree The embedding can be produced just from the words and the top of the tree, or it can be done with a form of attention over the nodes Can be done over an existing parse tree or unparsed text """ import torch import torch.nn as nn from stanza.models.constituency.trainer import Trainer class TreeEmbedding(nn.Module): def __init__(self, constituency_parser, args): super(TreeEmbedding, self).__init__() self.config = { "all_words": args["all_words"], "backprop": args["backprop"], #"batch_norm": args["batch_norm"], "node_attn": args["node_attn"], "top_layer": args["top_layer"], } self.constituency_parser = constituency_parser # word_lstm: hidden_size * num_tree_lstm_layers * 2 (start & end) # transition_stack: transition_hidden_size # constituent_stack: hidden_size self.hidden_size = 
self.constituency_parser.hidden_size + self.constituency_parser.transition_hidden_size if self.config["all_words"]: self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers else: self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers * 2 if self.config["node_attn"]: self.query = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size) self.key = nn.Linear(self.hidden_size, self.constituency_parser.hidden_size) self.value = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size) # TODO: cat transition and constituent hx as well? self.output_size = self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers else: self.output_size = self.hidden_size # TODO: maybe have batch_norm, maybe use Identity #if self.config["batch_norm"]: # self.input_norm = nn.BatchNorm1d(self.output_size) def embed_trees(self, inputs): if self.config["backprop"]: states = self.constituency_parser.analyze_trees(inputs) else: with torch.no_grad(): states = self.constituency_parser.analyze_trees(inputs) constituent_lists = [x.constituents for x in states] states = [x.state for x in states] word_begin_hx = torch.stack([state.word_queue[0].hx for state in states]) word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states]) transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states]) # go down one layer to get the embedding off the top of the S, not the ROOT # (in terms of the typical treebank) # the idea being that the ROOT has no additional information # and may even have 0s for the embedding in certain circumstances, # such as after learning UNTIED_MAX long enough if self.config["top_layer"]: constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states]) else: 
constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0) if self.config["all_words"]: # need B matrices of N x hidden_size key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0) for state, thx, chx in zip(states, transition_hx, constituent_hx)] else: key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), dim=1).unsqueeze(1) if not self.config["node_attn"]: return key key = [self.key(x) for x in key] node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists] queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx] values = [self.value(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx] # TODO: could pad to make faster here attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)] attn = [torch.softmax(x, dim=0) for x in attn] previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)] return previous_layer def forward(self, inputs): return embed_trees(self, inputs) def get_norms(self): lines = ["constituency_parser." 
+ x for x in self.constituency_parser.get_norms()] for name, param in self.named_parameters(): if param.requires_grad and not name.startswith('constituency_parser.'): lines.append("%s %.6g" % (name, torch.norm(param).item())) return lines def get_params(self, skip_modules=True): model_state = self.state_dict() # skip all of the constituency parameters here - # we will add them by calling the model's get_params() skipped = [k for k in model_state.keys() if k.startswith("constituency_parser.")] for k in skipped: del model_state[k] parser = self.constituency_parser.get_params(skip_modules) params = { 'model': model_state, 'constituency': parser, 'config': self.config, } return params @staticmethod def from_parser_file(args, foundation_cache=None): constituency_parser = Trainer.load(args['model'], args, foundation_cache) return TreeEmbedding(constituency_parser.model, args) @staticmethod def model_from_params(params, args, foundation_cache=None): # TODO: integrate with peft constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache) model = TreeEmbedding(constituency_parser, params['config']) model.load_state_dict(params['model'], strict=False) return model ================================================ FILE: stanza/models/constituency/tree_reader.py ================================================ """ Reads ParseTree objects from a file, string, or similar input Works by first splitting the input into (, ), and all other tokens, then recursively processing those tokens into trees. 
""" from collections import deque import logging import os import re from stanza.models.constituency.parse_tree import Tree from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() OPEN_PAREN = "(" CLOSE_PAREN = ")" logger = logging.getLogger('stanza.constituency') # A few specific exception types to clarify parsing errors # They store the line number where the error occurred class UnclosedTreeError(ValueError): """ A tree looked like (Foo """ def __init__(self, line_num): super().__init__("Found an unfinished tree (missing close brackets). Tree started on line %d" % line_num) self.line_num = line_num class ExtraCloseTreeError(ValueError): """ A tree looked like (Foo)) """ def __init__(self, line_num): super().__init__("Found a broken tree (extra close brackets). Tree started on line %d" % line_num) self.line_num = line_num class UnlabeledTreeError(ValueError): """ A tree had no label, such as ((Foo) (Bar)) This does not actually happen at the root, btw, as ROOT is silently added """ def __init__(self, line_num): super().__init__("Found a tree with no label on a node! Line number %d" % line_num) self.line_num = line_num class MixedTreeError(ValueError): """ Leaf and constituent children are mixed in the same node """ def __init__(self, line_num, child_label, children): super().__init__("Found a tree with both text children and bracketed children! 
Line number {} Child label {} Children {}".format(line_num, child_label, children)) self.line_num = line_num self.child_label = child_label self.children = children def normalize(text): return text.replace("-LRB-", "(").replace("-RRB-", ")") def read_single_tree(token_iterator, broken_ok): """ Build a tree from the tokens in the token_iterator """ # we were called here at a open paren, so start the stack of # children with one empty list already on it children_stack = deque() children_stack.append([]) text_stack = deque() text_stack.append([]) token = next(token_iterator, None) token_iterator.set_mark() while token is not None: if token == OPEN_PAREN: children_stack.append([]) text_stack.append([]) elif token == CLOSE_PAREN: text = text_stack.pop() children = children_stack.pop() if text: pieces = " ".join(text).split() if len(pieces) == 1: child = Tree(pieces[0], children) else: # the assumption here is that a language such as VI may # have spaces in the words, but it still represents # just one child label = pieces[0] child_label = " ".join(pieces[1:]) if children: if broken_ok: child = Tree(label, children + [Tree(normalize(child_label))]) else: raise MixedTreeError(token_iterator.line_num, child_label, children) else: child = Tree(label, Tree(normalize(child_label))) if not children_stack: return child else: if not children_stack: return Tree("ROOT", children) elif broken_ok: child = Tree(None, children) else: raise UnlabeledTreeError(token_iterator.line_num) children_stack[-1].append(child) else: text_stack[-1].append(token) token = next(token_iterator, None) raise UnclosedTreeError(token_iterator.get_mark()) LINE_SPLIT_RE = re.compile(r"([()])") class TokenIterator: """ A specific iterator for reading trees from a tree file The idea is that this will keep track of which line we are processing, so that an error can be logged from the correct line """ def __init__(self): self.token_iterator = iter([]) self.line_num = -1 self.mark = None def set_mark(self): """ 
The mark is used for determining where the start of a tree occurs for an error """ self.mark = self.line_num def get_mark(self): if self.mark is None: raise ValueError("No mark set!") return self.mark def __iter__(self): return self def __next__(self): n = next(self.token_iterator, None) while n is None: self.line_num = self.line_num + 1 line = next(self.line_iterator) if line is None: raise StopIteration line = line.strip() if not line: continue pieces = LINE_SPLIT_RE.split(line) pieces = [x.strip() for x in pieces] pieces = [x for x in pieces if x] self.token_iterator = iter(pieces) n = next(self.token_iterator, None) return n class TextTokenIterator(TokenIterator): def __init__(self, text, use_tqdm=True): super().__init__() self.lines = text.split("\n") self.num_lines = len(self.lines) if self.num_lines > 1000 and use_tqdm: self.line_iterator = iter(tqdm(self.lines)) else: self.line_iterator = iter(self.lines) class FileTokenIterator(TokenIterator): def __init__(self, filename): super().__init__() self.filename = filename def __enter__(self): # TODO: use the file_size instead of counting the lines # file_size = Path(self.filename).stat().st_size with open(self.filename) as fin: num_lines = sum(1 for _ in fin) self.file_obj = open(self.filename) if num_lines > 1000: self.line_iterator = iter(tqdm(self.file_obj, total=num_lines)) else: self.line_iterator = iter(self.file_obj) return self def __exit__(self, exc_type, exc_value, exc_tb): if self.file_obj: self.file_obj.close() def read_token_iterator(token_iterator, broken_ok, tree_callback): trees = [] token = next(token_iterator, None) while token: if token == OPEN_PAREN: next_tree = read_single_tree(token_iterator, broken_ok=broken_ok) if next_tree is None: raise ValueError("Tree reader somehow created a None tree! 
Line number %d" % token_iterator.line_num) if tree_callback is not None: transformed = tree_callback(next_tree) if transformed is not None: trees.append(transformed) else: trees.append(next_tree) token = next(token_iterator, None) elif token == CLOSE_PAREN: raise ExtraCloseTreeError(token_iterator.line_num) else: raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num) return trees def read_trees(text, broken_ok=False, tree_callback=None, use_tqdm=True): """ Reads multiple trees from the text TODO: some of the error cases we hit can be recovered from """ token_iterator = TextTokenIterator(text, use_tqdm) return read_token_iterator(token_iterator, broken_ok=broken_ok, tree_callback=tree_callback) def read_tree_file(filename, broken_ok=False, tree_callback=None): """ Read all of the trees in the given file """ with FileTokenIterator(filename) as token_iterator: trees = read_token_iterator(token_iterator, broken_ok=broken_ok, tree_callback=tree_callback) return trees def read_directory(dirname, broken_ok=False, tree_callback=None): """ Read all of the trees in all of the files in a directory """ trees = [] for filename in sorted(os.listdir(dirname)): full_name = os.path.join(dirname, filename) trees.extend(read_tree_file(full_name, broken_ok, tree_callback)) return trees def read_treebank(filename, tree_callback=None): """ Read a treebank and alter the trees to be a simpler format for learning to parse """ logger.info("Reading trees from %s", filename) trees = read_tree_file(filename, tree_callback=tree_callback) trees = [t.prune_none().simplify_labels() for t in trees] illegal_trees = [t for t in trees if len(t.children) > 1] if len(illegal_trees) > 0: raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. 
First illegal tree: {:P}".format(len(illegal_trees), illegal_trees[0])) return trees def main(): """ Reads a sample tree """ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = read_trees(text) print(trees) if __name__ == '__main__': main() ================================================ FILE: stanza/models/constituency/tree_stack.py ================================================ """ A utilitiy class for keeping track of intermediate parse states """ from collections import namedtuple class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])): """ A stack which can branch in several directions, as long as you keep track of the branching heads An example usage is when K constituents are removed at once to create a new constituent, and then the LSTM which tracks the values of the constituents is updated starting from the Kth output of the LSTM with the new value. We don't simply keep track of a single stack object using a deque because versions of the parser which use a beam will want to be able to branch in different directions from the same base stack Another possible usage is if an oracle is used for training in a manner where some fraction of steps are non-gold steps, but we also want to take a gold step from the same state. Eg, parser gets to state X, wants to make incorrect transition T instead of gold transition G, and so we continue training both X+G and X+T. If we only represent the state X with standard python stacks, it would not be possible to track both of these states at the same time without copying the entire thing. 
Value can be as transition, a word, or a partially built constituent Implemented as a namedtuple to make it a bit more efficient """ def pop(self): return self.parent def push(self, value): # returns a new stack node which points to this return TreeStack(value, self, self.length+1) def __iter__(self): stack = self while stack.parent is not None: yield stack.value stack = stack.parent yield stack.value def __reversed__(self): items = list(iter(self)) for item in reversed(items): yield item def __str__(self): return "TreeStack(%s)" % ", ".join([str(x) for x in self]) def __len__(self): return self.length ================================================ FILE: stanza/models/constituency/utils.py ================================================ """ Collects a few of the conparser utility methods which don't belong elsewhere """ from collections import Counter import logging import warnings import torch.nn as nn from torch import optim from stanza.models.common.doc import TEXT, Document from stanza.models.common.utils import get_optimizer from stanza.models.constituency.base_model import SimpleModel from stanza.models.constituency.parse_transitions import TransitionScheme from stanza.models.constituency.parse_tree import Tree from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007 , "mirror_madgrad": 0.00005 } DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 } DEFAULT_LEARNING_RHO = 0.9 DEFAULT_MOMENTUM = { "madgrad": 0.9, "mirror_madgrad": 0.9, "sgd": 0.9 } tlogger = logging.getLogger('stanza.constituency.trainer') # madgrad experiment for weight decay # with learning_rate set to 0.0000007 and momentum 0.9 # on en_wsj, with a baseline model trained on adadela for 200, # then madgrad used to further improve that model # 0.00000002.out: 0.9590347746438835 # 0.00000005.out: 0.9591378819960182 # 0.0000001.out: 
0.9595450596319405 # 0.0000002.out: 0.9594603134479271 # 0.0000005.out: 0.9591317672706594 # 0.000001.out: 0.9592548741021389 # 0.000002.out: 0.9598395477013945 # 0.000003.out: 0.9594974271553495 # 0.000004.out: 0.9596665982603754 # 0.000005.out: 0.9591620720706487 DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.02, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6, "mirror_madgrad": 2e-6 } def retag_tags(doc, pipelines, xpos): """ Returns a list of list of tags for the items in doc doc can be anything which feeds into the pipeline(s) pipelines are a list of 1 or more retag pipelines if multiple pipelines are given, majority vote wins """ tag_lists = [] for pipeline in pipelines: doc = pipeline(doc) tag_lists.append([[x.xpos if xpos else x.upos for x in sentence.words] for sentence in doc.sentences]) # tag_lists: for N pipeline, S sentences # we now have N lists of S sentences each # for sentence in zip(*tag_lists): N lists of |s| tags for this given sentence s # for tag in zip(*sentence): N predicted tags. 
# most common one in the Counter will be chosen tag_lists = [[Counter(tag).most_common(1)[0][0] for tag in zip(*sentence)] for sentence in zip(*tag_lists)] return tag_lists def retag_trees(trees, pipelines, xpos=True): """ Retag all of the trees using the given processor Returns a list of new trees """ if len(trees) == 0: return trees new_trees = [] chunk_size = 1000 with tqdm(total=len(trees)) as pbar: for chunk_start in range(0, len(trees), chunk_size): chunk_end = min(chunk_start + chunk_size, len(trees)) chunk = trees[chunk_start:chunk_end] sentences = [] try: for idx, tree in enumerate(chunk): tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()] sentences.append(tokens) except ValueError as e: raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e doc = Document(sentences) tag_lists = retag_tags(doc, pipelines, xpos) for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)): try: if any(tag is None for tag in tags): raise RuntimeError("Tagged tree #{} with a None tag!\n{}\n{}".format(tree_idx, tree, tags)) new_tree = tree.replace_tags(tags) new_trees.append(new_tree) pbar.update(1) except ValueError as e: raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e if len(new_trees) != len(trees): raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees))) return new_trees def build_optimizer(args, model, build_simple_adadelta=False): """ Build an optimizer based on the arguments given If we are "multistage" training and epochs_trained < epochs // 2, we build an AdaDelta optimizer instead of whatever was requested The build_simple_adadelta parameter controls this """ bert_learning_rate = 0.0 bert_weight_decay = args['bert_weight_decay'] if build_simple_adadelta: optim_type = 'adadelta' bert_finetune = args.get('stage1_bert_finetune', False) if bert_finetune: bert_learning_rate = args['stage1_bert_learning_rate'] learning_beta2 = 0.999 # 
doesn't matter for AdaDelta learning_eps = DEFAULT_LEARNING_EPS['adadelta'] learning_rate = args['stage1_learning_rate'] learning_rho = DEFAULT_LEARNING_RHO momentum = None # also doesn't matter for AdaDelta weight_decay = DEFAULT_WEIGHT_DECAY['adadelta'] else: optim_type = args['optim'].lower() bert_finetune = args.get('bert_finetune', False) if bert_finetune: bert_learning_rate = args['bert_learning_rate'] learning_beta2 = args['learning_beta2'] learning_eps = args['learning_eps'] learning_rate = args['learning_rate'] learning_rho = args['learning_rho'] momentum = args['learning_momentum'] weight_decay = args['learning_weight_decay'] # TODO: allow rho as an arg for AdaDelta return get_optimizer(name=optim_type, model=model, lr=learning_rate, betas=(0.9, learning_beta2), eps=learning_eps, momentum=momentum, weight_decay=weight_decay, bert_learning_rate=bert_learning_rate, bert_weight_decay=weight_decay*bert_weight_decay, is_peft=args.get('use_peft', False), bert_finetune_layers=args['bert_finetune_layers'], opt_logger=tlogger) def build_scheduler(args, optimizer, first_optimizer=False): """ Build the scheduler for the conparser based on its args Used to use a warmup for learning rate, but that wasn't working very well Now, we just use a ReduceLROnPlateau, which does quite well """ #if args.get('learning_rate_warmup', 0) <= 0: # # TODO: is there an easier way to make an empty scheduler? 
# lr_lambda = lambda x: 1.0 #else: # warmup_end = args['learning_rate_warmup'] # def lr_lambda(x): # if x >= warmup_end: # return 1.0 # return x / warmup_end #scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) if first_optimizer: scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args['learning_rate_factor'], patience=args['learning_rate_patience'], cooldown=args['learning_rate_cooldown'], min_lr=args['stage1_learning_rate_min_lr']) else: scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args['learning_rate_factor'], patience=args['learning_rate_patience'], cooldown=args['learning_rate_cooldown'], min_lr=args['learning_rate_min_lr']) return scheduler def initialize_linear(linear, nonlinearity, bias): """ Initializes the bias to a positive value, hopefully preventing dead neurons """ if nonlinearity in ('relu', 'leaky_relu'): nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity) nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5) def add_predict_output_args(parser): """ Args specifically for the output location of data """ parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. 
Writing the orig file is useful for removing None and retagging') parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions') parser.add_argument('--predict_format', type=str, default="{:_O}", help='Format to use when writing predictions') parser.add_argument('--predict_output_gold_tags', default=False, action='store_true', help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB') def postprocess_predict_output_args(args): if len(args['predict_format']) <= 2 or (len(args['predict_format']) <= 3 and args['predict_format'].endswith("Vi")): args['predict_format'] = "{:" + args['predict_format'] + "}" def get_open_nodes(trees, transition_scheme): """ Return a list of all open nodes in the given dataset. Depending on the parameters, may be single or compound open transitions. """ if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND: return Tree.get_compound_constituents(trees) elif transition_scheme is TransitionScheme.IN_ORDER_COMPOUND: return Tree.get_compound_constituents(trees, separate_root=True) else: return [(x,) for x in Tree.get_unique_constituent_labels(trees)] def verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels): """ Given a list of trees and their transition sequences, verify that the sequences rebuild the trees """ model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels) tlogger.info("Verifying the transition sequences for %d trees", len(trees)) data = zip(trees, sequences) if tlogger.getEffectiveLevel() <= logging.INFO: data = tqdm(zip(trees, sequences), total=len(trees)) for tree_idx, (tree, sequence) in enumerate(data): # TODO: make the SimpleModel have a parse operation? 
state = model.initial_state_from_gold_trees([tree])[0] for idx, trans in enumerate(sequence): if not trans.is_legal(state, model): raise RuntimeError("Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(tree_idx, name, idx, trans, tree, sequence)) state = trans.apply(state, model) result = model.get_top_constituent(state.constituents) if reverse: result = result.reverse() if tree != result: raise RuntimeError("Tree {} of {} failed: transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree_idx, name, tree, sequence, result)) def check_constituents(train_constituents, trees, treebank_name, fail=True): """ Check that all the constituents in the other dataset are known in the train set """ constituents = Tree.get_unique_constituent_labels(trees) for con in constituents: if con not in train_constituents: first_error = None num_errors = 0 for tree_idx, tree in enumerate(trees): constituents = Tree.get_unique_constituent_labels(tree) if con in constituents: num_errors += 1 if first_error is None: first_error = tree_idx error = "Found constituent label {} in the {} set which don't exist in the train set. 
This constituent label occurred in {} trees, with the first tree index at {} counting from 1\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error]) if fail: raise RuntimeError(error) else: warnings.warn(error) def check_root_labels(root_labels, other_trees, treebank_name): """ Check that all the root states in the other dataset are known in the train set """ for root_state in Tree.get_root_labels(other_trees): if root_state not in root_labels: raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(root_state, treebank_name)) def remove_duplicate_trees(trees, treebank_name): """ Filter duplicates from the given dataset """ new_trees = [] known_trees = set() for tree in trees: tree_str = "{}".format(tree) if tree_str in known_trees: continue known_trees.add(tree_str) new_trees.append(tree) if len(new_trees) < len(trees): tlogger.info("Filtered %d duplicates from %s dataset", (len(trees) - len(new_trees)), treebank_name) return new_trees def remove_singleton_trees(trees): """ remove trees which are just a root and a single word TODO: remove these trees in the conversion instead of here """ new_trees = [x for x in trees if len(x.children) > 1 or (len(x.children) == 1 and len(x.children[0].children) > 1) or (len(x.children) == 1 and len(x.children[0].children) == 1 and len(x.children[0].children[0].children) >= 1)] if len(trees) - len(new_trees) > 0: tlogger.info("Eliminated %d trees with missing structure", (len(trees) - len(new_trees))) return new_trees ================================================ FILE: stanza/models/constituency_parser.py ================================================ """A command line interface to a shift reduce constituency parser. 
This follows the work of
  Recurrent Neural Network Grammars by Dyer et al.
  In-Order Transition-based Constituent Parsing by Liu & Zhang

The general outline is:

  Train a model by taking a list of trees, converting them to
    transition sequences, and learning a model which can predict the
    next transition given a current state.
  Then, at inference time, repeatedly predict the next transition until
    parsing is complete.

The "transitions" are variations on shift/reduce as per an
intro-to-compilers class.  The idea is that you can treat all of the
words in a sentence as a buffer of tokens, then either "shift" them to
represent a new constituent, or "reduce" one or more constituents to
form a new constituent.

To keep the runtime competitive, the transitions are batched so that
multiple transitions are applied at once.  At train time, batches are
grouped together by length, and at inference time, new trees are added
to the batch as previous trees in the batch finish their inference.

There are a few minor differences in the model:
  - The word input is a bi-lstm, not a uni-lstm.
    This gave a small increase in accuracy.
  - The combination of several constituents into one constituent is done
    via a single bi-lstm rather than two separate lstms.  This increases
    speed without a noticeable effect on accuracy.
    - In fact, an even better (in terms of final model accuracy) method
      is to combine the constituents with torch.max, believe it or not.
      See lstm_model.py for more details.
  - Initializing the embeddings with smaller values than the pytorch default.
    For example, on a ja_alt dataset, scores went from 0.8980 to 0.8985
    at 200 iterations averaged over 5 trials.
  - Training with AdaDelta first, then AdamW or madgrad later improves
    results quite a bit.
See --multistage A couple experiments which have been tried with little noticeable impact: - Combining constituents using the method in the paper (only a trained vector at the start instead of both ends) did not affect results and is a little slower - Using multiple layers of LSTM hidden state for the input to the final classification layers didn't help - Initializing Linear layers with He initialization and a positive bias (to avoid dead connections) had no noticeable effect on accuracy 0.8396 on it_turin with the original initialization 0.8401 and 0.8427 on two runs with updated initialization (so maybe a small improvement...) - Initializing LSTM layers with different gates was slightly worse: forget gates of 1.0 forget gates of 1.0, input gates of -1.0 - Replacing the LSTMs that make up the Transition and Constituent LSTMs with Dynamic Skip LSTMs made no difference, but was slower - Highway LSTMs also made no difference - Putting labels on the shift transitions (the word or the tag shifted) or putting labels on the close transitions didn't help - Building larger constituents from the output of the constituent LSTM instead of the children constituents hurts scores For example, an experiment on ja_alt went from 0.8985 to 0.8964 when built that way - The initial transition scheme implemented was TOP_DOWN. We tried a compound unary option, since this worked so well in the CoreNLP constituency parser. Unfortunately, this is far less effective than IN_ORDER. Both specialized unary matrices and reusing the n-ary constituency combination fell short. On the ja_alt dataset: IN_ORDER, max combination method: 0.8985 TOP_DOWN_UNARY, specialized matrices: 0.8501 TOP_DOWN_UNARY, max combination method: 0.8508 - Adding multiple layers of MLP to combine inputs for words made no difference in the scores Tried both before the LSTM and after A simple single layer tensor multiply after the LSTM works well. 
Replacing that with a two layer MLP on the English PTB with roberta-base causes a notable drop in scores First experiment didn't use the fancy Linear weight init, but adding that barely made a difference 260 training iterations on en_wsj dev, roberta-base model as of bb983fd5e912f6706ad484bf819486971742c3d1 two layer MLP: 0.9409 two layer MLP, init weights: 0.9413 single layer: 0.9467 - There is code to rebuild models with a new structure in lstm_model.py As part of this, we tried to randomly reinitialize the transitions if the transition embedding had gone to 0, which often happens This didn't help at all - We tried something akin to attention with just the query vector over the bert embeddings as a way to mix them, but that did not improve scores. Example, with a self.bert_layer_mix of size bert_dim x 1: mixed_bert_embeddings = [] for feature in bert_embeddings: weighted_feature = self.bert_layer_mix(feature.transpose(1, 2)) weighted_feature = torch.softmax(weighted_feature, dim=1) weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2) mixed_bert_embeddings.append(weighted_feature) bert_embeddings = mixed_bert_embeddings It seems just finetuning the transformer is already enough (in general, no need to mix layers at all when finetuning bert embeddings) The code breakdown is as follows: this file: main interface for training or evaluating models constituency/trainer.py: contains the training & evaluation code constituency/ensemble.py: evaluation code specifically for letting multiple models vote on the correct next transition. a modest improvement. constituency/evaluate_treebanks.py: specifically to evaluate multiple parsed treebanks against a gold. 
    In particular, reports whether the theoretical best from those
    parsed treebanks is an improvement (eg, the k-best score as reported
    by CoreNLP).
  constituency/parse_tree.py: a data structure for representing a parse tree and utility methods
  constituency/tree_reader.py: a module which can read trees from a string or input file
  constituency/tree_stack.py: a linked list which can branch in
    different directions, which will be useful when implementing beam
    search or a dynamic oracle
  constituency/lstm_tree_stack.py: an LSTM over the elements of a TreeStack
  constituency/transformer_tree_stack.py: attempts to run attention over
    the nodes of a tree_stack.  Not as effective as the lstm_tree_stack
    in the initial experiments.  Perhaps it could be refined to work
    better, though.
  constituency/parse_transitions.py: transitions and a State data structure to store them
  constituency/transition_sequence.py: turns ParseTree objects into
    the transition sequences needed to make them
  constituency/base_model.py: operates on the transitions to turn them
    into constituents, eventually forming one final parse tree composed
    of all of the constituents
  constituency/lstm_model.py: adds LSTM features to the constituents to
    predict what the correct transition to make is, allowing for
    predictions on previously unseen text
  constituency/retagging.py: a couple of utility methods specifically for retagging
  constituency/utils.py: a couple of utility methods
  constituency/dynamic_oracle.py: a dynamic oracle which currently only
    operates for the in-order transition sequence.  Uses deterministic
    rules to redo the correct action sequence when the parser makes an
    error.
  constituency/partitioned_transformer.py: implementation of a
    transformer for self-attention.  Presumably this should help, but
    we have yet to find a model structure where this makes the scores
    go up.
constituency/label_attention.py: an even fancier form of transformer based on labeled attention: https://arxiv.org/abs/1911.03875 constituency/positional_encoding.py: so far, just the sinusoidal is here. a trained encoding is in partitioned_transformer.py. this should probably be refactored to common, especially if used elsewhere. stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline stanza/utils/datasets/constituency: various scripts and tools for processing constituency datasets Some alternate optimizer methods: adabelief: https://github.com/juntang-zhuang/Adabelief-Optimizer madgrad: https://github.com/facebookresearch/madgrad """ import argparse import logging import os import random import re import torch import stanza from stanza.models.common import constant from stanza.models.common import utils from stanza.models.common.peft_config import add_peft_args, resolve_peft_args from stanza.models.common.utils import NONLINEARITY from stanza.models.constituency import parser_training from stanza.models.constituency import retagging from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory from stanza.models.constituency.parse_transitions import TransitionScheme from stanza.models.constituency.text_processing import load_model_parse_text from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, add_predict_output_args, postprocess_predict_output_args from stanza.resources.common import DEFAULT_MODEL_DIR logger = logging.getLogger('stanza') tlogger = logging.getLogger('stanza.constituency.trainer') def build_argparse(): """ Adds the arguments for building the con parser For the most part, defaults are set to cross-validated values, at least for WSJ """ parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency 
data.') parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors') parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors') parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument('--pretrain_max_vocab', type=int, default=250000) parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") # BERT helps a lot and actually doesn't slow things down too much # for VI, for example, use vinai/phobert-base parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)") parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert") parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer") parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding') # BERT finetuning (or any transformer finetuning) # also helps quite a lot. 
# Experimentally, finetuning all of the layers is the most effective # On the id_icon dataset with the indolem transformer # In this experiment, we trained for 150 iterations with AdaDelta, # with the learning rate 0.01, # then trained for another 150 with madgrad and no finetuning # 1 layer 0.880753 (152) # 2 layers 0.880453 (174) # 3 layers 0.881774 (163) # 4 layers 0.886915 (194) # 5 layers 0.892064 (299) # 6 layers 0.891825 (224) # 7 layers 0.894373 (173) # 8 layers 0.894505 (233) # 9 layers 0.896676 (269) # 10 layers 0.897525 (269) # 11 layers 0.897348 (211) # 12 layers 0.898729 (270) # everything 0.898855 (252) # so the trend is clear that more finetuning is better # # We found that finetuning works very well on the AdaDelta portion # of a multistage training, but less well on a madgrad second # stage. The issue was that we literally could not set the # learning rate low enough because madgrad used epsilon in the LR: # https://github.com/facebookresearch/madgrad/issues/16 # # Possible values of the AdaDelta learning rate on the id_icon dataset # In this experiment, we finetuned the entire transformer 150 # iterations on AdaDelta, then trained with madgrad for another # 150 with no finetuning # 0.0005: 0.89122 (155) # 0.001: 0.889807 (241) # 0.002: 0.894874 (202) # 0.005: 0.896327 (270) # 0.006: 0.898989 (246) # 0.007: 0.896712 (167) # 0.008: 0.900136 (237) # 0.009: 0.898597 (169) # 0.01: 0.898665 (251) # 0.012: 0.89661 (274) # 0.014: 0.899149 (283) # 0.016: 0.896314 (230) # 0.018: 0.897753 (257) # 0.02: 0.893665 (256) # 0.05: 0.849274 (159) # 0.1: 0.850633 (183) # 0.2: 0.847332 (176) # # The peak is somewhere around 0.008 to 0.014, with the further # observation that at the 150 iteration mark, 0.09 was winning: # 0.007: 0.894589 (33) # 0.008: 0.894777 (53) # 0.009: 0.896466 (56) # 0.01: 0.895557 (71) # 0.012: 0.893479 (45) # 0.014: 0.89468 (116) # 0.016: 0.893053 (128) # 0.018: 0.893086 (48) # # Another option is to train for a few iterations with no # 
finetuning, then begin finetuning. However, that was not # beneficial at all. # Start iteration on id_icon, same setup as above: # 1: 0.898855 (252) # 5: 0.897885 (217) # 10: 0.895367 (215) # 25: 0.896781 (193) # 50: 0.895216 (193) # Using adamw instead of madgrad: # 1: 0.900594 (226) # 5: 0.898153 (267) # 10: 0.898756 (271) # 25: 0.896867 (256) # 50: 0.895025 (220) # # # With the observation that very low learning rate is currently # not working for madgrad, we tried to parameter sweep LR for # AdamW, and got the following, using a first stage LR of 0.009: # 0.0: 0.899706 (290) # 0.00005: 0.899631 (176) # 0.0001: 0.899851 (233) # 0.0002: 0.898601 (207) # 0.0003: 0.899258 (252) # 0.0004: 0.90033 (187) # 0.0005: 0.899091 (183) # 0.001: 0.899791 (268) # 0.002: 0.899453 (196) # 0.003: 0.897029 (173) # 0.004: 0.899566 (290) # 0.005: 0.899285 (289) # 0.01: 0.898938 (233) # 0.02: 0.898983 (248) # 0.03: 0.898571 (247) # 0.04: 0.898466 (180) # 0.05: 0.897448 (214) # It should be noted that in the 0.0001 range, the epoch to epoch # change of the Bert weights was almost negligible. Weights would # change in the 5th or 6th decimal place, if at all. # # The conclusion of all these experiments is that, if we are using # bert_finetuning, the best approach is probably a stage1 learning # rate of 0.009 or so and a second stage optimizer of adamw with # no LR or a very low LR. 
This behavior is what happens with the # --stage1_bert_finetune flag parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)') parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)") parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer') parser.add_argument('--bert_finetune_begin_epoch', default=None, type=int, help='Which epoch to start finetuning the transformer') parser.add_argument('--bert_finetune_end_epoch', default=None, type=int, help='Which epoch to stop finetuning the transformer') parser.add_argument('--bert_learning_rate', default=0.009, type=float, help='Scale the learning rate for transformer finetuning by this much') parser.add_argument('--stage1_bert_learning_rate', default=None, type=float, help="Scale the learning rate for transformer finetuning by this much only during an AdaDelta warmup") parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much') parser.add_argument('--stage1_bert_finetune', default=None, action='store_true', help="Finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune") parser.add_argument('--no_stage1_bert_finetune', dest='stage1_bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune") add_peft_args(parser) parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 
0 turns off the feature") # Smaller values also seem to work # For example, after 700 iterations: # 32: 0.9174 # 50: 0.9183 # 72: 0.9176 # 100: 0.9185 # not a huge difference regardless # (these numbers were without retagging) parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding") parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--no_train_remove_duplicates', default=True, action='store_false', dest="train_remove_duplicates", help="Do/don't remove duplicates from the training file. Could be useful for intentionally reweighting some trees") parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.') parser.add_argument('--silver_remove_duplicates', default=False, action='store_true', help="Do/don't remove duplicates from the silver training file. Could be useful for intentionally reweighting some trees") parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') # TODO: possibly refactor --tokenized_file / --tokenized_dir from here & ensemble parser.add_argument('--xml_tree_file', type=str, default=None, help='Input file of VLSP formatted trees for parsing with parse_text.') parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.') parser.add_argument('--tokenized_dir', type=str, default=None, help='Input directory of tokenized text for parsing with parse_text.') parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer']) parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one') add_predict_output_args(parser) parser.add_argument('--lang', type=str, help='Language') parser.add_argument('--shorthand', type=str, help="Treebank shorthand") 
parser.add_argument('--transition_embedding_dim', type=int, default=20, help="Embedding size for a transition") parser.add_argument('--transition_hidden_size', type=int, default=20, help="Embedding size for transition stack") parser.add_argument('--transition_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()], help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory))) parser.add_argument('--transition_heads', default=4, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention") parser.add_argument('--constituent_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()], help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory))) parser.add_argument('--constituent_heads', default=8, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention") # larger was more effective, up to a point # substantially smaller, such as 128, # is fine if bert & charlm are not available parser.add_argument('--hidden_size', type=int, default=512, help="Size of the output layers for constituency stack and word queue") parser.add_argument('--epochs', type=int, default=400) parser.add_argument('--epoch_size', type=int, default=5000, help="Runs this many trees in an 'epoch' instead of going through the training dataset exactly once. Set to 0 to do the whole training set") parser.add_argument('--silver_epoch_size', type=int, default=None, help="Runs this many trees in a silver 'epoch'. If not set, will match --epoch_size") # AdaDelta warmup for the conparser. Motivation: AdaDelta results in # higher scores overall, but learns 0s for the weights of the pattn and # lattn layers. AdamW learns weights for pattn, and the models are more # accurate than models trained without pattn using AdamW, but the models # are lower scores overall than the AdaDelta models. # # This improves that by first running AdaDelta, then switching. 
# # Now, if --multistage is set, run AdaDelta for half the epochs with no # pattn or lattn. Then start the specified optimizer for the rest of # the time with the full model. If pattn and lattn are both present, # the model is 1/2 no attn, 1/4 pattn, 1/4 pattn and lattn # # Improvement on the WSJ dev set can be seen from 94.8 to 95.3 # when 4 layers of pattn are trained this way. # More experiments to follow. parser.add_argument('--multistage', default=True, action='store_true', help='1/2 epochs with adadelta no pattn or lattn, 1/4 with chosen optim and no lattn, 1/4 full model') parser.add_argument('--no_multistage', dest='multistage', action='store_false', help="don't do the multistage learning") # 1 seems to be the most effective, but we should cross-validate parser.add_argument('--oracle_initial_epoch', type=int, default=1, help="Epoch where we start using the dynamic oracle to let the parser keep going with wrong decisions") parser.add_argument('--oracle_frequency', type=float, default=0.8, help="How often to use the oracle vs how often to force the correct transition") parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help="Occasionally have the model randomly walk through the state space to try to learn how to recover") parser.add_argument('--oracle_level', type=int, default=None, help='Restrict oracle transitions to this level or lower. 0 means off. None means use all oracle transitions.') parser.add_argument('--additional_oracle_levels', type=str, default=None, help='Add some additional experimental oracle transitions. Basically for A/B testing transitions we expect to be bad.') parser.add_argument('--deactivated_oracle_levels', type=str, default=None, help='Temporarily turn off a default oracle level. 
Basically for A/B testing transitions we expect to be bad.') # 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ # earlier version of the model (less accurate overall) had the following results with adadelta: # 30: 0.9085 # 50: 0.9070 # 75: 0.9010 # 150: 0.8985 # as another data point, running a newer version with better constituency lstm behavior had: # 30: 0.9111 # 50: 0.9094 # checking smaller batch sizes to see how this works, at 135 epochs, the values are # 10: 0.8919 # 20: 0.9072 # 30: 0.9121 # obviously these experiments aren't the complete story, but it # looks like 30 trees per batch is the best value for WSJ # note that these numbers are for adadelta and might not apply # to other optimizers # eval batch should generally be faster the bigger the batch, # up to a point, as it allows for more batching of the LSTM # operations and the prediction step parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step') parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval') parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_constituency.pt", help="File name to save the model") parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. 
Mostly for testing") parser.add_argument('--save_each_start', type=int, default=None, help="When to start saving each model") parser.add_argument('--save_each_frequency', type=int, default=1, help="How frequently to save each model") parser.add_argument('--no_save_each_optimizer', dest='save_each_optimizer', default=True, action='store_false', help="Don't save the optimizer when saving 'each' model") parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--no_seed', action='store_const', const=None, dest='seed', help='Remove the random seed, resulting in a randomly chosen random seed') parser.add_argument('--no_check_valid_states', default=True, action='store_false', dest='check_valid_states', help="Don't check the constituents or transitions in the dev set when starting a new parser. Warning: the parser will never guess unknown constituents") parser.add_argument('--no_strict_check_constituents', default=True, action='store_false', dest='strict_check_constituents', help="Don't check the constituents between the train & dev set. 
May result in untrainable transitions") utils.add_device_args(parser) # Numbers are on a VLSP dataset, before adding attn or other improvements # baseline is an 80.6 model that occurs when trained using adadelta, lr 1.0 # # adabelief 0.1: fails horribly # 0.02: converges very low scores # 0.01: very slow learning # 0.002: almost decent # 0.001: close, but about 1 f1 low on IT # 0.0005: 79.71 # 0.0002: 80.11 # 0.0001: 79.85 # 0.00005: 80.40 # 0.00002: 80.02 # 0.00001: 78.95 # # madgrad 0.005: fails horribly # 0.001: low scores # 0.0005: still somewhat low # 0.0002: close, but about 1 f1 low on IT # 0.0001: 80.04 # 0.00005: 79.91 # 0.00002: 80.15 # 0.00001: 80.44 # 0.000005: 80.34 # 0.000002: 80.39 # # adamw experiment on a TR dataset (not necessarily the best test case) # note that at that time, the expected best for adadelta was 0.816 # # 0.00005 - 0.7925 # 0.0001 - 0.7889 # 0.0002 - 0.8110 # 0.00025 - 0.8108 # 0.0003 - 0.8050 # 0.0005 - 0.8076 # 0.001 - 0.8069 # Numbers on the VLSP Dataset, with --multistage and default learning rates and adabelief optimizer # Gelu: 82.32 # Mish: 81.95 # ELU: 81.73 # Hardshrink: 0.3 # Hardsigmoid: 79.03 # Hardtanh: 81.44 # Hardswish: 81.67 # Logsigmoid: 80.91 # Prelu: 80.95 (terminated early) # Relu6: 81.91 # RReLU: 77.00 # Selu: 81.17 # Celu: 81.43 # Silu: 81.90 # Softplus: 80.94 # Softshrink: 0.3 # Softsign: 81.63 # Softshrink: 13.74 # # Tests with no_charlm, --multitstage # Gelu # 0.00002 0.819746 # 0.00005 0.818 # 0.0001 0.818566 # 0.0002 0.819111 # 0.001 0.815609 # # Mish # 0.00002 0.816898 # 0.00005 0.821085 # 0.0001 0.817821 # 0.0002 0.818806 # 0.001 0.816494 # # Relu # 0.00002 0.818402 # 0.00005 0.819019 # 0.0001 0.821625 # 0.0002 0.820633 # 0.001 0.814315 # # Relu6 # 0.00002 0.819719 # 0.00005 0.819871 # 0.0001 0.819018 # 0.0002 0.819506 # 0.001 0.819018 parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer. Reasonable values are 1.0 for adadelta or 0.001 for SGD. 
None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES)) parser.add_argument('--learning_eps', default=None, type=float, help='eps value to use in the optimizer. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_EPS)) parser.add_argument('--learning_momentum', default=None, type=float, help='Momentum. None uses a default for the given optimizer: {}'.format(DEFAULT_MOMENTUM)) # weight decay values other than adadelta have not been thoroughly tested. # When using adadelta, weight_decay of 0.01 to 0.001 had the best results. # 0.1 was very clearly too high. 0.0001 might have been okay. # Running a series of 5x experiments on a VI dataset: # 0.030: 0.8167018 # 0.025: 0.81659 # 0.020: 0.81722 # 0.015: 0.81721 # 0.010: 0.81474348 # 0.005: 0.81503 parser.add_argument('--learning_weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer') parser.add_argument('--learning_rho', default=DEFAULT_LEARNING_RHO, type=float, help='Rho parameter in Adadelta') # A few experiments on beta2 didn't show much benefit from changing it # On an experiment with training WSJ with default parameters # AdaDelta for 200 iterations, then training AdamW for 200 more, # 0.999, 0.997, 0.995 all wound up with 0.9588 # values lower than 0.995 all had a slight dropoff parser.add_argument('--learning_beta2', default=0.999, type=float, help='Beta2 argument for AdamW') parser.add_argument('--optim', default=None, help='Optimizer type: SGD, AdamW, Adadelta, AdaBelief, Madgrad') parser.add_argument('--stage1_learning_rate', default=None, type=float, help='Learning rate to use in the first stage of --multistage. None means use default: {}'.format(DEFAULT_LEARNING_RATES['adadelta'])) parser.add_argument('--learning_rate_warmup', default=0, type=int, help="Number of epochs to ramp up learning rate from 0 to full. Set to 0 to always use the chosen learning rate. 
Currently not functional, as it didn't do anything") parser.add_argument('--learning_rate_factor', default=0.6, type=float, help='Plateau learning rate decreate when plateaued') parser.add_argument('--learning_rate_patience', default=5, type=int, help='Plateau learning rate patience') parser.add_argument('--learning_rate_cooldown', default=10, type=int, help='Plateau learning rate cooldown') parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum') parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)') parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping') parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping') # Large Margin is from Large Margin In Softmax Cross-Entropy Loss # it did not help on an Italian VIT test # scores went from 0.8252 to 0.8248 parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. 
Focal requires `pip install focal_loss_torch`') parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss') # turn off dropout for word_dropout, predict_dropout, and lstm_input_dropout # this mechanism doesn't actually turn off lstm_layer_dropout (yet) # but that is set to a default of 0 anyway # this is reusing the idea presented in # https://arxiv.org/pdf/2303.01500v2 # "Dropout Reduces Underfitting" # Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell # Unfortunately, this does not consistently help results # Averaged of 5 models w/ transformer, dev / test # id_icon - improves a little # baseline 0.8823 0.8904 # early_dropout 40 0.8835 0.8919 # ja_alt - worsens a little # baseline 0.9308 0.9355 # early_dropout 40 0.9287 0.9345 # vi_vlsp23 - worsens a little # baseline 0.8262 0.8290 # early_dropout 40 0.8255 0.8286 # We keep this as an available option for further experiments, if needed parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout') # When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations: # 0.0: 0.9085 # 0.2: 0.9165 # 0.4: 0.9162 # 0.5: 0.9123 # Letting 0.2 and 0.4 run for longer, along with 0.3 as another # trial, continued to give extremely similar results over time. # No attempt has been made to test the different dropouts separately... 
parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding') parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer') # lstm_dropout has not been fully tested yet # one experiment after 200 iterations (after retagging, so scores are lower than some other experiments): # 0.0: 0.9093 # 0.1: 0.9094 # 0.2: 0.9094 # 0.3: 0.9076 # 0.4: 0.9077 parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers') # one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations # 0.0 0.9091 # 0.1 0.9095 # 0.2 0.9118 # 0.3 0.9123 # 0.4 0.9080 parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM') parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()], help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme))) parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed') # combining dummy and open node embeddings might be a slight improvement # for example, after 550 iterations, one experiment had # True: 0.9154 # False: 0.9150 # another (with a different structure) had 850 iterations # True: 0.9155 # False: 0.9149 parser.add_argument('--combined_dummy_embedding', default=True, action='store_true', help="Use the same embedding for dummy nodes and the vectors used when combining constituents") parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help="Don't use the same embedding for dummy nodes and the vectors used when combining constituents") # relu gave at least 1 F1 improvement over tanh in various experiments # relu & gelu seem roughly the same, but relu is clearly faster. 
# relu, 496 iterations: 0.9176 # gelu, 467 iterations: 0.9181 # after the same clock time on the same hardware. the two had been # trading places in terms of accuracy over those ~500 iterations. # leaky_relu was not an improvement - a full run on WSJ led to 0.9181 f1 instead of 0.919 # See constituency/utils.py for more extensive comments on nonlinearity options parser.add_argument('--nonlinearity', default='relu', choices=NONLINEARITY.keys(), help='Nonlinearity to use in the model. relu is a noticeable improvement over tanh') # In one experiment on an Italian dataset, VIT, we got the following: # 0.8254 with relu as the nonlinearity (10 trials) # 0.8265 with maxout, k = 2 (15) # 0.8253 with maxout, k = 3 (5) # The speed in terms of trees/second might be slightly slower with maxout. # 51.4 it/s on a Titan Xp with maxout 2 and 51.9 it/s with relu # It might also be worth running some experiments with bigger # output layers to see if that makes up for the difference in score. parser.add_argument('--maxout_k', default=None, type=int, help="Use maxout layers instead of a nonlinearity for the output layers") parser.add_argument('--use_silver_words', default=True, dest='use_silver_words', action='store_true', help="Train/don't train word vectors for words only in the silver dataset") parser.add_argument('--no_use_silver_words', default=True, dest='use_silver_words', action='store_false', help="Train/don't train word vectors for words only in the silver dataset") parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training') parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset') parser.add_argument('--tag_unknown_frequency', default=0.001, type=float, help='How often to replace a tag with UNK when training') parser.add_argument('--num_lstm_layers', default=2, type=int, help='How many layers to use in 
the LSTMs') parser.add_argument('--num_tree_lstm_layers', default=None, type=int, help='How many layers to use in the TREE_LSTMs, if used. This also increases the width of the word outputs to match the tree lstm inputs. Default 2 if TREE_LSTM or TREE_LSTM_CX, 1 otherwise') parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level') parser.add_argument('--sentence_boundary_vectors', default=SentenceBoundary.EVERYTHING, type=lambda x: SentenceBoundary[x.upper()], help='Vectors to learn at the start & end of sentences. {}'.format(", ".join(x.name for x in SentenceBoundary))) parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()], help='How to build a new composition from its children. {}'.format(", ".join(x.name for x in ConstituencyComposition))) parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)') parser.add_argument('--reduce_position', default=None, type=int, help="Dimension of position vector to use when reducing children. None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)") parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. 
One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw') parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path') parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint") parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints") parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file') parser.add_argument('--load_package', type=str, default=None, help='Download an existing stanza package & use this for tests, finetuning, etc') retagging.add_retag_args(parser) # Partitioned Attention parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality') parser.add_argument('--pattn_morpho_emb_dropout', default=0.2, type=float, help='Dropout rate for morphological features obtained from pretrained model') parser.add_argument('--pattn_encoder_max_len', default=512, type=int, help='Max length that can be put into the transformer attention layer') parser.add_argument('--pattn_num_heads', default=8, type=int, help='Partitioned attention model number of attention heads') parser.add_argument('--pattn_d_kv', default=64, type=int, help='Size of the query and key vector') parser.add_argument('--pattn_d_ff', default=2048, type=int, help='Size of the intermediate vectors in the feed-forward sublayer') parser.add_argument('--pattn_relu_dropout', default=0.1, type=float, help='ReLU dropout probability in feed-forward sublayer') parser.add_argument('--pattn_residual_dropout', default=0.2, type=float, help='Residual dropout probability for all residual connections') parser.add_argument('--pattn_attention_dropout', default=0.2, type=float, help='Attention dropout probability') parser.add_argument('--pattn_num_layers', 
default=0, type=int, help='Number of layers for the Partitioned Attention. Currently turned off') parser.add_argument('--pattn_bias', default=False, action='store_true', help='Whether or not to learn an additive bias') # Results seem relatively similar with learned position embeddings or sin/cos position embeddings parser.add_argument('--pattn_timing', default='sin', choices=['learned', 'sin'], help='Use a learned embedding or a sin embedding') # Label Attention parser.add_argument('--lattn_d_input_proj', default=None, type=int, help='If set, project the non-positional inputs down to this size before proceeding.') parser.add_argument('--lattn_d_kv', default=64, type=int, help='Dimension of the key/query vector') parser.add_argument('--lattn_d_proj', default=64, type=int, help='Dimension of the output vector from each label attention head') parser.add_argument('--lattn_resdrop', default=True, action='store_true', help='Whether or not to use Residual Dropout') parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer') parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does') parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned') parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned') parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. 
False means it does') # currently unused - always assume 1/2 of pattn #parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding') parser.add_argument('--lattn_d_l', default=32, type=int, help='Number of labels') parser.add_argument('--lattn_attention_dropout', default=0.2, type=float, help='Dropout for attention layer') parser.add_argument('--lattn_d_ff', default=2048, type=int, help='Dimension of the Feed-forward layer') parser.add_argument('--lattn_relu_dropout', default=0.2, type=float, help='Relu dropout for the label attention') parser.add_argument('--lattn_residual_dropout', default=0.2, type=float, help='Residual dropout for the label attention') parser.add_argument('--lattn_combined_input', default=True, action='store_true', help='Combine all inputs for the lattn, not just the pattn') parser.add_argument('--use_lattn', default=False, action='store_true', help='Use the lattn layers - currently turned off') parser.add_argument('--no_use_lattn', dest='use_lattn', action='store_false', help='Use the lattn layers - currently turned off') parser.add_argument('--no_lattn_combined_input', dest='lattn_combined_input', action='store_false', help="Don't combine all inputs for the lattn, not just the pattn") parser.add_argument('--use_rattn', default=False, action='store_true', help='Use a local attention layer') parser.add_argument('--rattn_window', default=16, type=int, help='Number of tokens to use for context in the local attention') # Ran an experiment on id_icon with in_order, peft, 200 epochs training # Equivalent experiment with no rattn had an average of 0.8922 dev # window 16, cat, dim 200, sinks 0 # head dev score # 1 0.8915 # 2 0.8933 # 3 0.8918 # 4 0.8934 # 5 0.8924 # 6 0.8936 # 8 0.8920 # 10 0.8909 # 12 0.8939 # 14 0.8949 # 16 0.8952 # 18 0.8915 # 20 0.8925 # 25 0.8913 # 30 0.8913 # 40 0.8943 # 50 0.8931 # 75 0.8940 # The average here is 0.8928, which is a tiny bit higher... 
parser.add_argument('--rattn_heads', default=16, type=int, help='Number of heads to use for context in the local attention') parser.add_argument('--no_rattn_forward', default=True, action='store_false', dest='rattn_forward', help="Use or don't use the forward relative attention") parser.add_argument('--no_rattn_reverse', default=True, action='store_false', dest='rattn_reverse', help="Use or don't use the reverse relative attention") parser.add_argument('--no_rattn_cat', action='store_false', dest='rattn_cat', help='Stack the rattn layers instead of adding them') parser.add_argument('--rattn_cat', default=True, action='store_true', help='Stack the rattn layers instead of adding them') parser.add_argument('--rattn_dim', default=200, type=int, help='Dimension of the rattn output when cat') parser.add_argument('--rattn_sinks', default=0, type=int, help='Number of attention sink tokens to learn') parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true', help='Use the endpoints of the sentences as sinks') parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training. A very noisy option') parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning') parser.add_argument('--watch_regex', default=None, help='regex to describe which weights and biases to output, if any') parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name') parser.add_argument('--wandb_norm_regex', default=None, help='Log on wandb any tensor whose norm matches this matrix. 
Might get cluttered?') return parser def build_model_filename(args): embedding = utils.embedding_name(args) maybe_finetune = "finetuned" if args['bert_finetune'] or args['stage1_bert_finetune'] else "" transformer_finetune_begin = "%d" % args['bert_finetune_begin_epoch'] if args['bert_finetune_begin_epoch'] is not None else "" rattn = "" if args['use_rattn']: if args['rattn_forward']: rattn = rattn + "F" if args['rattn_reverse']: rattn = rattn + "R" if rattn: if args['rattn_cat']: rattn += "c" rattn += "h%02d" % args['rattn_heads'] rattn += "w%02d" % args['rattn_window'] if args['rattn_sinks'] > 0: rattn += "s%d" % args['rattn_sinks'] model_save_file = args['save_name'].format(shorthand=args['shorthand'], oracle_level=args['oracle_level'], embedding=embedding, finetune=maybe_finetune, transformer_finetune_begin=transformer_finetune_begin, transition_scheme=args['transition_scheme'].name.lower().replace("_", ""), tscheme=args['transition_scheme'].short_name, trans_layers=args['bert_hidden_layers'], rattn=rattn, seed=args['seed']) model_save_file = re.sub("_+", "_", model_save_file) logger.info("Expanded save_name: %s", model_save_file) model_dir = os.path.split(model_save_file)[0] if model_dir != args['save_dir']: model_save_file = os.path.join(args['save_dir'], model_save_file) return model_save_file def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) resolve_peft_args(args, logger, check_bert_finetune=False) if not args.lang and args.shorthand and len(args.shorthand.split("_", maxsplit=1)) == 2: args.lang = args.shorthand.split("_")[0] if args.stage1_bert_learning_rate is None: args.stage1_bert_learning_rate = args.bert_learning_rate if args.optim is None and args.mode == 'train': if not args.multistage: # this seemed to work the best when not doing multistage args.optim = "adadelta" if args.use_peft and not args.bert_finetune: logger.info("--use_peft set. 
setting --bert_finetune as well") args.bert_finetune = True elif args.bert_finetune or args.stage1_bert_finetune: logger.info("Multistage training is set, optimizer is not chosen, and bert finetuning is active. Will use AdamW as the second stage optimizer.") args.optim = "adamw" else: # if MADGRAD exists, use it # otherwise, adamw try: import madgrad args.optim = "madgrad" logger.info("Multistage training is set, optimizer is not chosen, and MADGRAD is available. Will use MADGRAD as the second stage optimizer.") except ModuleNotFoundError as e: logger.warning("Multistage training is set. Best models are with MADGRAD, but it is not installed. Will use AdamW for the second stage optimizer. Consider installing MADGRAD") args.optim = "adamw" if args.mode == 'train': if args.learning_rate is None: args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None) if args.learning_eps is None: args.learning_eps = DEFAULT_LEARNING_EPS.get(args.optim.lower(), None) if args.learning_momentum is None: args.learning_momentum = DEFAULT_MOMENTUM.get(args.optim.lower(), None) if args.learning_weight_decay is None: args.learning_weight_decay = DEFAULT_WEIGHT_DECAY.get(args.optim.lower(), None) if args.stage1_learning_rate is None: args.stage1_learning_rate = DEFAULT_LEARNING_RATES["adadelta"] if args.stage1_bert_finetune is None: args.stage1_bert_finetune = args.bert_finetune if args.learning_rate_min_lr is None: args.learning_rate_min_lr = args.learning_rate * 0.02 if args.stage1_learning_rate_min_lr is None: args.stage1_learning_rate_min_lr = args.stage1_learning_rate * 0.02 if args.reduce_position is None: args.reduce_position = args.hidden_size // 4 if args.num_tree_lstm_layers is None: if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX): args.num_tree_lstm_layers = 2 else: args.num_tree_lstm_layers = 1 if args.wandb_name or args.wandb_norm_regex: args.wandb = True args = vars(args) 
retagging.postprocess_args(args) postprocess_predict_output_args(args) if args['seed'] is None: args['seed'] = random.randint(0, 1000000000) logger.info("Using random seed %d", args['seed']) model_save_file = build_model_filename(args) args['save_name'] = model_save_file if args['save_each_name']: model_save_each_file = os.path.join(args['save_dir'], args['save_each_name']) model_save_each_file = utils.build_save_each_filename(model_save_each_file) args['save_each_name'] = model_save_each_file else: # in the event that there is a start epoch setting, # this will make a reasonable default for the path pieces = os.path.splitext(args['save_name']) model_save_each_file = pieces[0] + "_%04d" + pieces[1] args['save_each_name'] = model_save_each_file if args['checkpoint']: args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name']) return args def main(args=None): """ Main function for building con parser Processes args, calls the appropriate function for the chosen --mode """ args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running constituency parser in %s mode", args['mode']) logger.debug("Using device: %s", args['device']) model_load_file = args['save_name'] if args['load_name']: if os.path.exists(args['load_name']): model_load_file = args['load_name'] else: model_load_file = os.path.join(args['save_dir'], args['load_name']) elif args['load_package']: if args['lang'] is None: lang_pieces = args['load_package'].split("_", maxsplit=1) try: lang = constant.lang_to_langcode(lang_pieces[0]) except ValueError as e: raise ValueError("--lang not specified, and the start of the --load_package name, %s, is not a known language. 
Please check the values of those parameters" % args['load_package']) from e args['lang'] = lang args['load_package'] = lang_pieces[1] stanza.download(args['lang'], processors="constituency", package={"constituency": args['load_package']}) model_load_file = os.path.join(DEFAULT_MODEL_DIR, args['lang'], 'constituency', args['load_package'] + ".pt") if not os.path.exists(model_load_file): raise FileNotFoundError("Expected the downloaded model file for language %s package %s to be in %s, but there is nothing there. Perhaps the package name doesn't exist?" % (args['lang'], args['load_package'], model_load_file)) else: logger.info("Model for language %s package %s is in %s", args['lang'], args['load_package'], model_load_file) # TODO: when loading a saved model, we should default to whatever # is in the model file for --retag_method, not the default for the language if args['mode'] == 'train': if tlogger.level == logging.NOTSET: tlogger.setLevel(logging.DEBUG) tlogger.debug("Set trainer logging level to DEBUG") retag_pipeline = retagging.build_retag_pipeline(args) if args['mode'] == 'train': parser_training.train(args, model_load_file, retag_pipeline) elif args['mode'] == 'predict': parser_training.evaluate(args, model_load_file, retag_pipeline) elif args['mode'] == 'parse_text': load_model_parse_text(args, model_load_file, retag_pipeline) elif args['mode'] == 'remove_optimizer': parser_training.remove_optimizer(args, args['save_name'], model_load_file) if __name__ == '__main__': main() ================================================ FILE: stanza/models/coref/__init__.py ================================================ ================================================ FILE: stanza/models/coref/anaphoricity_scorer.py ================================================ """ Describes AnaphicityScorer, a torch module that for a matrix of mentions produces their anaphoricity scores. 
""" import torch from stanza.models.coref import utils from stanza.models.coref.config import Config class AnaphoricityScorer(torch.nn.Module): """ Calculates anaphoricity scores by passing the inputs into a FFNN """ def __init__(self, in_features: int, config: Config): super().__init__() hidden_size = config.hidden_size if not config.n_hidden_layers: hidden_size = in_features layers = [] for i in range(config.n_hidden_layers): layers.extend([torch.nn.Linear(hidden_size if i else in_features, hidden_size), torch.nn.LeakyReLU(), torch.nn.Dropout(config.dropout_rate)]) self.hidden = torch.nn.Sequential(*layers) self.out = torch.nn.Linear(hidden_size, out_features=1) # are we going to predict singletons self.predict_singletons = config.singletons if self.predict_singletons: # map to whether or not this is a start of a coref given all the # antecedents; not used when config.singletons = False because # we only need to know this for predicting singletons self.start_map = torch.nn.Linear(config.rough_k, out_features=1, bias=False) def forward(self, *, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch top_mentions: torch.Tensor, mentions_batch: torch.Tensor, pw_batch: torch.Tensor, top_rough_scores_batch: torch.Tensor, ) -> torch.Tensor: """ Builds a pairwise matrix, scores the pairs and returns the scores. 
Args: all_mentions (torch.Tensor): [n_mentions, mention_emb] mentions_batch (torch.Tensor): [batch_size, mention_emb] pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb] top_indices_batch (torch.Tensor): [batch_size, n_ants] top_rough_scores_batch (torch.Tensor): [batch_size, n_ants] Returns: torch.Tensor [batch_size, n_ants + 1] anaphoricity scores for the pairs + a dummy column """ # [batch_size, n_ants, pair_emb] pair_matrix = self._get_pair_matrix(mentions_batch, pw_batch, top_mentions) # [batch_size, n_ants] vs [batch_size, 1] # first is coref scores, the second is whether its the start of a coref if self.predict_singletons: scores, start = self._ffnn(pair_matrix) scores = utils.add_dummy(scores+top_rough_scores_batch, eps=True) return torch.cat([start, scores], dim=1) else: scores = self._ffnn(pair_matrix) return utils.add_dummy(scores+top_rough_scores_batch, eps=True) def _ffnn(self, x: torch.Tensor) -> torch.Tensor: """ Calculates anaphoricity scores. Args: x: tensor of shape [batch_size, n_ants, n_features] Returns: tensor of shape [batch_size, n_ants] """ x = self.out(self.hidden(x)) x = x.squeeze(2) if not self.predict_singletons: return x # because sometimes we only have the first 49 anaphoricities start = x @ self.start_map.weight[:,:x.shape[1]].T return x, start @staticmethod def _get_pair_matrix(mentions_batch: torch.Tensor, pw_batch: torch.Tensor, top_mentions: torch.Tensor) -> torch.Tensor: """ Builds the matrix used as input for AnaphoricityScorer. 
Args: all_mentions (torch.Tensor): [n_mentions, mention_emb], all the valid mentions of the document, can be on a different device mentions_batch (torch.Tensor): [batch_size, mention_emb], the mentions of the current batch, is expected to be on the current device pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb], pairwise features of the current batch, is expected to be on the current device top_indices_batch (torch.Tensor): [batch_size, n_ants], indices of antecedents of each mention Returns: torch.Tensor: [batch_size, n_ants, pair_emb] """ emb_size = mentions_batch.shape[1] n_ants = pw_batch.shape[1] a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size) b_mentions = top_mentions similarity = a_mentions * b_mentions out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2) return out ================================================ FILE: stanza/models/coref/bert.py ================================================ """Functions related to BERT or similar models""" import logging from typing import List, Tuple import numpy as np # type: ignore from transformers import AutoModel, AutoTokenizer # type: ignore from stanza.models.coref.config import Config from stanza.models.coref.const import Doc logger = logging.getLogger('stanza') def get_subwords_batches(doc: Doc, config: Config, tok: AutoTokenizer ) -> np.ndarray: """ Turns a list of subwords to a list of lists of subword indices of max length == batch_size (or shorter, as batch boundaries should match sentence boundaries). Each batch is enclosed in cls and sep special tokens. 
Returns: batches of bert tokens [n_batches, batch_size] """ batch_size = config.bert_window_size - 2 # to save space for CLS and SEP subwords: List[str] = doc["subwords"] subwords_batches = [] start, end = 0, 0 while end < len(subwords): # to prevent the case where a batch_size step forward # doesn't capture more than 1 sentence, we will just cut # that sequence prev_end = end end = min(end + batch_size, len(subwords)) # Move back till we hit a sentence end if end < len(subwords): sent_id = doc["sent_id"][doc["word_id"][end]] while end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id: end -= 1 # this occurs IFF there was no sentence end found throughout # the forward scan; this means that our sentence was waay too # long (i.e. longer than the max length of the transformer. # # if so, we give up and just chop the sentence off at the max length # that was given if end == prev_end: end = min(end + batch_size, len(subwords)) length = end - start if tok.cls_token == None or tok.sep_token == None: batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token] else: batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token] # Padding to desired length batch += [tok.pad_token] * (batch_size - length) subwords_batches.append([tok.convert_tokens_to_ids(token) for token in batch]) start += length return np.array(subwords_batches) ================================================ FILE: stanza/models/coref/cluster_checker.py ================================================ """ Describes ClusterChecker, a class used to retrieve LEA scores. See aclweb.org/anthology/P16-1060.pdf. """ from typing import Hashable, List, Tuple from stanza.models.coref.const import EPSILON import numpy as np import math import logging logger = logging.getLogger('stanza') class ClusterChecker: """ Collects information on gold and predicted clusters across documents. Can be used to retrieve weighted LEA-score for them. 
""" def __init__(self): self._lea_precision = 0.0 self._lea_recall = 0.0 self._lea_precision_weighting = 0.0 self._lea_recall_weighting = 0.0 self._num_preds = 0.0 # muc self._muc_precision = 0.0 self._muc_recall = 0.0 # b3 self._b3_precision = 0.0 self._b3_recall = 0.0 # ceafe self._ceafe_precision = 0.0 self._ceafe_recall = 0.0 @staticmethod def _f1(p,r): return (p * r) / (p+r + EPSILON) * 2 def add_predictions(self, gold_clusters: List[List[Hashable]], pred_clusters: List[List[Hashable]]): """ Calculates LEA for the document's clusters and stores them to later output weighted LEA across documents. Returns: LEA score for the document as a tuple of (f1, precision, recall) """ # if len(gold_clusters) == 0: # breakpoint() self._num_preds += 1 recall, r_weight = ClusterChecker._lea(gold_clusters, pred_clusters) precision, p_weight = ClusterChecker._lea(pred_clusters, gold_clusters) self._muc_recall += ClusterChecker._muc(gold_clusters, pred_clusters) self._muc_precision += ClusterChecker._muc(pred_clusters, gold_clusters) self._b3_recall += ClusterChecker._b3(gold_clusters, pred_clusters) self._b3_precision += ClusterChecker._b3(pred_clusters, gold_clusters) ceafe_precision, ceafe_recall = ClusterChecker._ceafe(pred_clusters, gold_clusters) if math.isnan(ceafe_precision) and len(gold_clusters) > 0: # because our model predicted no clusters ceafe_precision = 0.0 self._ceafe_precision += ceafe_precision self._ceafe_recall += ceafe_recall self._lea_recall += recall self._lea_recall_weighting += r_weight self._lea_precision += precision self._lea_precision_weighting += p_weight doc_precision = precision / (p_weight + EPSILON) doc_recall = recall / (r_weight + EPSILON) doc_f1 = (doc_precision * doc_recall) \ / (doc_precision + doc_recall + EPSILON) * 2 return doc_f1, doc_precision, doc_recall @property def bakeoff(self): """ Get the F1 macroaverage score used by the bakeoff """ return sum(self.mbc)/3 @property def mbc(self): """ Get the F1 average score of (muc, b3, 
ceafe) over docs """ avg_precisions = [self._muc_precision, self._b3_precision, self._ceafe_precision] avg_precisions = [i/(self._num_preds + EPSILON) for i in avg_precisions] avg_recalls = [self._muc_recall, self._b3_recall, self._ceafe_recall] avg_recalls = [i/(self._num_preds + EPSILON) for i in avg_recalls] avg_f1s = [self._f1(p,r) for p,r in zip(avg_precisions, avg_recalls)] return avg_f1s @property def total_lea(self): """ Returns weighted LEA for all the documents as (f1, precision, recall) """ precision = self._lea_precision / (self._lea_precision_weighting + EPSILON) recall = self._lea_recall / (self._lea_recall_weighting + EPSILON) f1 = self._f1(precision, recall) return f1, precision, recall @staticmethod def _lea(key: List[List[Hashable]], response: List[List[Hashable]]) -> Tuple[float, float]: """ See aclweb.org/anthology/P16-1060.pdf. """ response_clusters = [set(cluster) for cluster in response] response_map = {mention: cluster for cluster in response_clusters for mention in cluster} importances = [] resolutions = [] for entity in key: size = len(entity) if size == 1: # entities of size 1 are not annotated continue importances.append(size) correct_links = 0 for i in range(size): for j in range(i + 1, size): correct_links += int(entity[i] in response_map.get(entity[j], {})) resolutions.append(correct_links / (size * (size - 1) / 2)) res = sum(imp * res for imp, res in zip(importances, resolutions)) weight = sum(importances) return res, weight @staticmethod def _muc(key: List[List[Hashable]], response: List[List[Hashable]]) -> float: """ See aclweb.org/anthology/P16-1060.pdf. 
""" response_clusters = [set(cluster) for cluster in response] response_map = {mention: cluster for cluster in response_clusters for mention in cluster} top = 0 # sum over k of |k_i| - response_partitions(|k_i|) bottom = 0 # sum over k of |k_i| - 1 for entity in key: S = len(entity) # we need to figure the number of DIFFERENT clusters # the response assigns to members of the entity; ideally # this number is 1 (i.e. they are all assigned the same # coref). response_clusters = [response_map.get(i, None) for i in entity] # and dedplicate deduped = [] for i in response_clusters: if i == None: deduped.append(i) elif i not in deduped: deduped.append(i) # the "partitions" will then be size of the deduped list p_k = len(deduped) top += (S - p_k) bottom += (S - 1) try: return top/bottom except ZeroDivisionError: logger.warning("muc got a zero division error because the model predicted no spans!") return 0 # +inf technically @staticmethod def _b3(key: List[List[Hashable]], response: List[List[Hashable]]) -> float: """ See aclweb.org/anthology/P16-1060.pdf. 
""" response_clusters = [set(cluster) for cluster in response] top = 0 # sum over key and response of (|k intersect response|^2/|k|) bottom = 0 # sum over k of |k_i| for entity in key: bottom += len(entity) entity = set(entity) for res_entity in response_clusters: top += (len(entity.intersection(res_entity))**2)/len(entity) try: return top/bottom except ZeroDivisionError: logger.warning("b3 got a zero division error because the model predicted no spans!") return 0 # +inf technically @staticmethod def _phi4(c1, c2): return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) @staticmethod def _ceafe(clusters: List[List[Hashable]], gold_clusters: List[List[Hashable]]): """ see https://github.com/ufal/corefud-scorer/blob/main/coval/eval/evaluator.py """ try: from scipy.optimize import linear_sum_assignment except ImportError: raise ImportError("To perform CEAF scoring, please install scipy via `pip install scipy` for the Kuhn-Munkres linear assignment scheme.") clusters = [c for c in clusters] scores = np.zeros((len(gold_clusters), len(clusters))) for i in range(len(gold_clusters)): for j in range(len(clusters)): scores[i, j] = ClusterChecker._phi4(gold_clusters[i], clusters[j]) row_ind, col_ind = linear_sum_assignment(-scores) similarity = scores[row_ind, col_ind].sum() # precision, recall try: prec = similarity/len(clusters) except ZeroDivisionError: logger.warning("ceafe got a zero division error because the model predicted no spans!") prec = 0 recc = similarity/len(gold_clusters) return prec, recc ================================================ FILE: stanza/models/coref/config.py ================================================ """ Describes Config, a simple namespace for config values. For description of all config values, refer to config.toml. 
""" from dataclasses import dataclass from typing import Dict, List @dataclass class Config: # pylint: disable=too-many-instance-attributes, too-few-public-methods """ Contains values needed to set up the coreference model. """ section: str # TODO: can either eliminate data_dir or use it for the train/dev/test data data_dir: str save_dir: str save_name: str train_data: str dev_data: str test_data: str device: str bert_model: str bert_window_size: int embedding_size: int sp_embedding_size: int a_scoring_batch_size: int hidden_size: int n_hidden_layers: int max_span_len: int rough_k: int lora: bool lora_alpha: int lora_rank: int lora_dropout: float full_pairwise: bool lora_target_modules: List[str] lora_modules_to_save: List[str] clusters_starts_are_singletons: bool bert_finetune: bool dropout_rate: float learning_rate: float bert_learning_rate: float # we find that setting this to a small but non-zero number # makes the model less likely to forget how to do anything bert_finetune_begin_epoch: float train_epochs: int # if plateaued for this many epochs, stop training plateau_epochs: int bce_loss_weight: float tokenizer_kwargs: Dict[str, dict] conll_log_dir: str save_each_checkpoint: bool log_norms: bool singletons: bool max_train_len: int use_zeros: bool lang_lr_attenuation: str lang_lr_weights: str ================================================ FILE: stanza/models/coref/conll.py ================================================ """ Contains functions to produce conll-formatted output files with predicted spans and their clustering """ from collections import defaultdict from contextlib import contextmanager import os from typing import List, TextIO from stanza.models.coref.config import Config from stanza.models.coref.const import Doc, Span # pylint: disable=too-many-locals def write_conll(doc: Doc, clusters: List[List[Span]], heads: List[int], f_obj: TextIO): """ Writes span/cluster information to f_obj, which is assumed to be a file object open for writing """ 
placeholder = list("\t_" * 7) # the nth token needs to be a number placeholder[9] = "0" placeholder = "".join(placeholder) doc_id = doc["document_id"].replace("-", "_").replace("/", "_").replace(".","_") words = doc["cased_words"] part_id = doc["part_id"] sents = doc["sent_id"] max_word_len = max(len(w) for w in words) starts = defaultdict(lambda: []) ends = defaultdict(lambda: []) single_word = defaultdict(lambda: []) for cluster_id, cluster in enumerate(clusters): if len(heads[cluster_id]) != len(cluster): # TODO debug this fact and why it occurs # print(f"cluster {cluster_id} doesn't have the same number of elements for word and span levels, skipping...") continue for cluster_part, (start, end) in enumerate(cluster): if end - start == 1: single_word[start].append((cluster_part, cluster_id)) else: starts[start].append((cluster_part, cluster_id)) ends[end - 1].append((cluster_part, cluster_id)) f_obj.write(f"# newdoc id = {doc_id}\n# global.Entity = eid-head\n") word_number = 0 sent_id = 0 for word_id, word in enumerate(words): cluster_info_lst = [] for part, cluster_marker in starts[word_id]: start, end = clusters[cluster_marker][part] cluster_info_lst.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)}") for part, cluster_marker in single_word[word_id]: start, end = clusters[cluster_marker][part] cluster_info_lst.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)})") for part, cluster_marker in ends[word_id]: cluster_info_lst.append(f"e{cluster_marker})") # we need our clusters to be ordered such that the one that is closest the first change # is listed last in the chains def compare_sort(x): split = x.split("-") if len(split) > 1: return int(split[-1].replace(")", "").strip()) else: # we want everything that's a closer to be first return float("inf") cluster_info_lst = sorted(cluster_info_lst, key=compare_sort, reverse=True) cluster_info = "".join(cluster_info_lst) if cluster_info_lst else "_" if word_id == 0 or 
sents[word_id] != sents[word_id - 1]: f_obj.write(f"# sent_id = {doc_id}-{sent_id}\n") word_number = 0 sent_id += 1 if cluster_info != "_": cluster_info = f"Entity={cluster_info}" f_obj.write(f"{word_id}\t{word}{placeholder}\t{cluster_info}\n") word_number += 1 f_obj.write("\n") @contextmanager def open_(config: Config, epochs: int, data_split: str): """ Opens conll log files for writing in a safe way. """ base_filename = f"{config.section}_{data_split}_e{epochs}" conll_dir = config.conll_log_dir kwargs = {"mode": "w", "encoding": "utf8"} os.makedirs(conll_dir, exist_ok=True) with open(os.path.join( # type: ignore conll_dir, f"{base_filename}.gold.conll"), **kwargs) as gold_f: with open(os.path.join( # type: ignore conll_dir, f"{base_filename}.pred.conll"), **kwargs) as pred_f: yield (gold_f, pred_f) ================================================ FILE: stanza/models/coref/const.py ================================================ """ Contains type aliases for coref module """ from dataclasses import dataclass from typing import Any, Dict, List, Tuple import torch EPSILON = 1e-7 LARGE_VALUE = 1000 # used instead of inf due to bug #16762 in pytorch Doc = Dict[str, Any] Span = Tuple[int, int] @dataclass class CorefResult: coref_scores: torch.Tensor = None # [n_words, k + 1] coref_y: torch.Tensor = None # [n_words, k + 1] rough_y: torch.Tensor = None # [n_words, n_words] word_clusters: List[List[int]] = None span_clusters: List[List[Span]] = None rough_scores: torch.Tensor = None # [n_words, n_words] span_scores: torch.Tensor = None # [n_heads, n_words, 2] span_y: Tuple[torch.Tensor, torch.Tensor] = None # [n_heads] x2 zero_scores: torch.Tensor = None ================================================ FILE: stanza/models/coref/coref_chain.py ================================================ """ Coref chain suitable for attaching to a Document after coref processing """ # by not using namedtuple, we can use this object as output from the json module # in the doc class as 
long as we wrap the encoder to print these out in dict() form # CorefMention = namedtuple('CorefMention', ['sentence', 'start_word', 'end_word']) class CorefMention: def __init__(self, sentence, start_word, end_word): self.sentence = sentence self.start_word = start_word self.end_word = end_word class CorefChain: def __init__(self, index, mentions, representative_text, representative_index): self.index = index self.mentions = mentions self.representative_text = representative_text self.representative_index = representative_index class CorefAttachment: def __init__(self, chain, is_start, is_end, is_representative): self.chain = chain self.is_start = is_start self.is_end = is_end self.is_representative = is_representative def to_json(self): j = { "index": self.chain.index, "representative_text": self.chain.representative_text } if self.is_start: j['is_start'] = True if self.is_end: j['is_end'] = True if self.is_representative: j['is_representative'] = True return j ================================================ FILE: stanza/models/coref/coref_config.toml ================================================ # ============================================================================= # Before you start changing anything here, read the comments. # All of them can be found below in the "DEFAULT" section [DEFAULT] # The directory that contains extracted files of everything you've downloaded. 
data_dir = "data/coref"

# where to put checkpoints and final models
save_dir = "saved_models/coref"
save_name = "bert-large-cased"

# Train, dev and test jsonlines
# train_data = "data/coref/en_gum-ud.train.nosgl.json"
# dev_data = "data/coref/en_gum-ud.dev.nosgl.json"
# test_data = "data/coref/en_gum-ud.test.nosgl.json"

train_data = "data/coref/corefud_concat_v1_0_langid.train.json"
dev_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
# NOTE(review): test_data points at the dev split; presumably intentional
# (e.g. no test split is available) -- confirm before reporting test scores
test_data = "data/coref/corefud_concat_v1_0_langid.dev.json"

#train_data = "data/coref/english_train_head.jsonlines"
#dev_data = "data/coref/english_development_head.jsonlines"
#test_data = "data/coref/english_test_head.jsonlines"

# do not use the full pairwise encoding scheme
full_pairwise = false

# The device where everything is to be placed. "cuda:N"/"cpu" are supported.
device = "cuda:0"

save_each_checkpoint = false
log_norms = false

# Bert settings ======================

# Base bert model architecture and tokenizer
bert_model = "bert-large-cased"

# Controls max length of sequences passed through bert to obtain its
# contextual embeddings, counting the two special tokens added to each window.
# Must not exceed the maximum input length of bert_model: 512 for BERT-style
# models (some sections below, e.g. t5_lora and longformer, use larger values)
bert_window_size = 512

# General model settings =============

# Controls the dimensionality of feature embeddings
embedding_size = 20

# Controls the dimensionality of distance embeddings used by SpanPredictor
sp_embedding_size = 64

# Controls the number of spans for which anaphoricity can be scored in one
# batch. Only affects final scoring; mention extraction and rough scoring
# are less memory intensive, so they are always done in just one batch.
a_scoring_batch_size = 128

# AnaphoricityScorer FFNN parameters
hidden_size = 1024
n_hidden_layers = 1

# Do you want to support singletons?
singletons = true

# Mention extraction settings ========

# Mention extractor will check spans up to max_span_len words
# The default value is chosen to be big enough to hold any dev data span
max_span_len = 64

# Pruning settings ===================

# Controls how many pairs should be preserved per mention
# after applying rough scoring.
rough_k = 50

# LoRA settings ======================
lora = false
lora_alpha = 128
lora_dropout = 0.1
lora_rank = 64
lora_target_modules = []
lora_modules_to_save = []

# Training settings ==================

# Controls whether the first dummy node predicts cluster starts or singletons
clusters_starts_are_singletons = true

# Controls whether to fine-tune bert_model
bert_finetune = true

# Controls the dropout rate throughout all models
dropout_rate = 0.3

# Bert learning rate (only used if bert_finetune is set)
bert_learning_rate = 1e-6
bert_finetune_begin_epoch = 0.5

# Task learning rate
learning_rate = 3e-4

# For how many epochs the training is done
train_epochs = 32

# Stop training early once training has plateaued for this many epochs
# (presumably 0 disables early stopping -- confirm in the trainer)
plateau_epochs = 0

# Controls the weight of binary cross entropy loss added to nlml loss
bce_loss_weight = 0.5

# The directory that will contain conll prediction files
conll_log_dir = "data/conll_logs"

# Skip any documents longer than this length
max_train_len = 5000

# if this is set to false, the model will set its zero_predictor to, well, 0
use_zeros = true

# two different methods for specifying how to weaken the LR for certain languages
# however, in their current forms, on an HE experiment, neither worked
# better than just mixing the two datasets together unweighted
# Starting from the HE IAHLT dataset, and possibly mixing in the ger/rom ud coref,
# averaging over 5 different seeds, we got the following results:
#   HE only: 0.497
#   Attenuated: 0.508
#   Scaled: 0.517
#   Mixed: 0.517
# the attenuation scheme for that experiment was 1/epoch
# These were the settings
#   --lang_lr_weights es=0.2,en=0.2,de=0.2,ca=0.2,fr=0.2,no=0.2
#   --lang_lr_attenuation es,en,de,ca,fr,no
lang_lr_attenuation = ""
lang_lr_weights = ""

# =============================================================================
# Extra keyword arguments to be passed to bert tokenizers of specified models
[DEFAULT.tokenizer_kwargs]
    [DEFAULT.tokenizer_kwargs.roberta-large]
        "add_prefix_space" = true

    [DEFAULT.tokenizer_kwargs.xlm-roberta-large]
        "add_prefix_space" = true

    [DEFAULT.tokenizer_kwargs.spanbert-large-cased]
        "do_lower_case" = false

    [DEFAULT.tokenizer_kwargs.bert-large-cased]
        "do_lower_case" = false

# =============================================================================
# The sections listed here do not need to make use of all config variables
# If a variable is omitted, its default value will be used instead

[roberta]
bert_model = "roberta-large"

[roberta_lora]
bert_model = "roberta-large"
bert_learning_rate = 0.00005
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[scandibert_lora]
bert_model = "vesteinn/ScandiBERT"
bert_learning_rate = 0.0002
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[xlm_roberta]
bert_model = "FacebookAI/xlm-roberta-large"
bert_learning_rate = 0.00001
bert_finetune = true

[xlm_roberta_lora]
bert_model = "FacebookAI/xlm-roberta-large"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[deeppavlov_slavic_bert_lora]
bert_model = "DeepPavlov/bert-base-bg-cs-pl-ru-cased"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[deberta_lora]
bert_model = "microsoft/deberta-v3-large"
bert_learning_rate = 0.00001
lora = true
lora_target_modules = [ "query_proj", "value_proj", "output.dense" ]
lora_modules_to_save = [ ]

[electra]
bert_model = "google/electra-large-discriminator"
bert_learning_rate = 0.00002

[electra_lora]
bert_model = "google/electra-large-discriminator"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ ]

[hungarian_electra_lora]
# TODO: experiment with tokenizer options for this to see if that's
# why the results are so low using this transformer
bert_model = "NYTK/electra-small-discriminator-hungarian"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ ]

[muril_large_cased_lora]
bert_model = "google/muril-large-cased"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[muril_base_cased_lora]
bert_model = "google/muril-base-cased"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[indic_bert_lora]
bert_model = "ai4bharat/indic-bert"
bert_learning_rate = 0.0005
lora = true
# indic-bert is an albert with repeating layers of different names
lora_target_modules = [ "query", "value", "dense", "ffn", "full_layer" ]
lora_modules_to_save = [ "pooler" ]

[alephbertgimmel_lora]
bert_model = "imvladikon/alephbertgimmel-base-512"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[alephbert_lora]
bert_model = "onlplab/alephbert-base"
bert_learning_rate = 0.000025
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[hero_lora]
# LR sweep on Hebrew IAHLT coref dev scores
# (although there may be tokenization problems)
#   0.000005  0.44202
#   0.00001   0.45271
#   0.000015  0.45771
#   0.00002   0.45877
#   0.000025  0.46076
#   0.00003   0.45957
#   0.000035  0.46187
#   0.00004   0.46066
#   0.000045  0.46132
#   0.00005   0.46238
#   0.000055  0.46084
#   0.00006   0.46047
#   0.000075  0.45772
#   0.0001    0.44910
bert_model = "HeNLP/HeRo"
bert_learning_rate = 0.00005
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[bert_multilingual_cased_lora]
# LR sweep on a Hindi dataset
#   0.00001:  0.53238
#   0.00002:  0.54012
#   0.000025: 0.54206
#   0.00003:  0.54050
#   0.00004:  0.55081
#   0.00005:  0.55135
#   0.000075: 0.54482
#   0.0001:   0.53888
bert_model = "google-bert/bert-base-multilingual-cased"
bert_learning_rate = 0.00005
lora = true
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_modules_to_save = [ "pooler" ]

[t5_lora]
bert_model = "google-t5/t5-large"
bert_learning_rate = 0.000025
bert_window_size = 1024
lora = true
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
lora_modules_to_save = [ ]

[mt5_lora]
bert_model = "google/mt5-base"
bert_learning_rate = 0.000025
lora_alpha = 64
lora_rank = 32
lora = true
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
lora_modules_to_save = [ ]

[deepnarrow_t5_xl_lora]
bert_model = "google/t5-efficient-xl"
bert_learning_rate = 0.00025
lora = true
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
lora_modules_to_save = [ ]

[roberta_no_finetune]
bert_model = "roberta-large"
bert_finetune = false

[roberta_no_bce]
bert_model = "roberta-large"
bce_loss_weight = 0.0

[spanbert]
bert_model = "SpanBERT/spanbert-large-cased"

[spanbert_no_bce]
bert_model = "SpanBERT/spanbert-large-cased"
bce_loss_weight = 0.0

[bert]
bert_model = "bert-large-cased"

[longformer]
bert_model = "allenai/longformer-large-4096"
bert_window_size = 2048

[debug]
bert_window_size = 384
bert_finetune = false
device = "cpu:0"

[debug_gpu]
bert_window_size = 384
bert_finetune = false


================================================
FILE:
stanza/models/coref/dataset.py ================================================ import json import logging from torch.utils.data import Dataset from stanza.models.coref.tokenizer_customization import TOKENIZER_FILTERS, TOKENIZER_MAPS logger = logging.getLogger('stanza') class CorefDataset(Dataset): def __init__(self, path, config, tokenizer): self.config = config self.tokenizer = tokenizer # by default, this doesn't filter anything (see lambda _ True); # however, there are some subword symbols which are standalone # tokens which we don't want on models like Albert; hence we # pass along a filter if needed. self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model, lambda _: True) self.__token_map = TOKENIZER_MAPS.get(self.config.bert_model, {}) try: with open(path, encoding="utf-8") as fin: data_f = json.load(fin) except json.decoder.JSONDecodeError: # read the old jsonlines format if necessary with open(path, encoding="utf-8") as fin: text = "[" + ",\n".join(fin) + "]" data_f = json.loads(text) logger.info("Processing %d docs from %s...", len(data_f), path) self.__raw = data_f self.__avg_span = sum(len(doc["head2span"]) for doc in self.__raw) / len(self.__raw) self.__out = [] for doc in self.__raw: doc["span_clusters"] = [[tuple(mention) for mention in cluster] for cluster in doc["span_clusters"]] word2subword = [] subwords = [] word_id = [] for i, word in enumerate(doc["cased_words"]): tokenized = self.tokenizer.tokenize(word) if len(tokenized) == 0: word = "_" doc["cased_words"][i] = word tokenized = self.tokenizer.tokenize(word) assert len(tokenized) > 0 tokenized_word = self.__token_map.get(word, tokenized) tokenized_word = list(filter(self.__filter_func, tokenized_word)) word2subword.append((len(subwords), len(subwords) + len(tokenized_word))) subwords.extend(tokenized_word) word_id.extend([i] * len(tokenized_word)) doc["word2subword"] = word2subword doc["subwords"] = subwords doc["word_id"] = word_id self.__out.append(doc) logger.info("Loaded %d docs 
from %s.", len(data_f), path) @property def avg_span(self): return self.__avg_span def __getitem__(self, x): return self.__out[x] def __len__(self): return len(self.__out) ================================================ FILE: stanza/models/coref/loss.py ================================================ """ Describes the loss function used to train the model, which is a weighted sum of NLML and BCE losses. """ import torch class CorefLoss(torch.nn.Module): """ See the rationale for using NLML in Lee et al. 2017 https://www.aclweb.org/anthology/D17-1018/ The added weighted summand of BCE helps the model learn even after converging on the NLML task. """ def __init__(self, bce_weight: float): assert 0 <= bce_weight <= 1 super().__init__() self._bce_module = torch.nn.BCEWithLogitsLoss() self._bce_weight = bce_weight def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Returns a weighted sum of two losses as a torch.Tensor """ return (self._nlml(input_, target) + self._bce(input_, target) * self._bce_weight) def _bce(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ For numerical stability, clamps the input before passing it to BCE. 
""" return self._bce_module(torch.clamp(input_, min=-50, max=50), target) @staticmethod def _nlml(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: gold = torch.logsumexp(input_ + torch.log(target), dim=1) input_ = torch.logsumexp(input_, dim=1) return (input_ - gold).mean() ================================================ FILE: stanza/models/coref/model.py ================================================ """ see __init__.py """ from datetime import datetime import dataclasses import json import logging import os import random import re from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np # type: ignore try: import tomllib except ImportError: import tomli as tomllib import torch import transformers # type: ignore from pickle import UnpicklingError import warnings from stanza.utils.get_tqdm import get_tqdm # type: ignore tqdm = get_tqdm() from stanza.models.coref import bert, conll, utils from stanza.models.coref.anaphoricity_scorer import AnaphoricityScorer from stanza.models.coref.cluster_checker import ClusterChecker from stanza.models.coref.config import Config from stanza.models.coref.const import CorefResult, Doc from stanza.models.coref.loss import CorefLoss from stanza.models.coref.pairwise_encoder import PairwiseEncoder from stanza.models.coref.rough_scorer import RoughScorer from stanza.models.coref.span_predictor import SpanPredictor from stanza.models.coref.utils import GraphNode from stanza.models.coref.utils import sigmoid_focal_loss from stanza.models.coref.word_encoder import WordEncoder from stanza.models.coref.dataset import CorefDataset from stanza.models.coref.tokenizer_customization import * from stanza.models.common.bert_embedding import load_tokenizer from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper import torch.nn as nn logger = logging.getLogger('stanza') class 
CorefModel:  # pylint: disable=too-many-instance-attributes
    """Combines all coref modules together to find coreferent spans.

    Attributes:
        config (coref.config.Config): the model's configuration,
            see config.toml for the details
        epochs_trained (int): number of epochs the model has been trained for
        trainable (Dict[str, torch.nn.Module]): trainable submodules with their
            names used as keys
        training (bool): used to toggle train/eval modes
        optimizers, schedulers (dict): keyed by name; left empty unless
            build_optimizers is set (presumably filled by _build_optimizers)

    Submodules (in the order of their usage in the pipeline):
        tokenizer (transformers.AutoTokenizer)
        bert (transformers.AutoModel)
        we (WordEncoder)
        rough_scorer (RoughScorer)
        pw (PairwiseEncoder)
        a_scorer (AnaphoricityScorer)
        sp (SpanPredictor)
    """
    def __init__(self,
                 epochs_trained: int = 0,
                 build_optimizers: bool = True,
                 config: Optional[Config] = None,
                 foundation_cache=None):
        """
        A newly created model is set to evaluation mode.

        Args:
            epochs_trained (int): the number of epochs finished
                (useful for warm start)
            build_optimizers (bool): if True, calls _build_optimizers so the
                model can be trained
            config (Config): the model configuration; required despite the
                None default.  Fields are read as attributes
                (e.g. config.bce_loss_weight), so a plain dict will not work
            foundation_cache: passed through to _build_model; presumably lets
                already-loaded transformer weights be shared -- confirm

        Raises:
            ValueError: if config is None
        """
        if config is None:
            raise ValueError("Cannot create a model without a config")
        self.config = config
        self.epochs_trained = epochs_trained
        # document cache; presumably keyed by data path (see _get_docs) -- confirm
        self._docs: Dict[str, List[Doc]] = {}
        self._build_model(foundation_cache)

        self.optimizers = {}
        self.schedulers = {}

        if build_optimizers:
            self._build_optimizers()
        self._set_training(False)

        # final coreference resolution score
        self._coref_criterion = CorefLoss(self.config.bce_loss_weight)
        # score simply for the top-k choices out of the rough scorer
        self._rough_criterion = CorefLoss(0)
        # exact span matches
        self._span_criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    @property
    def training(self) -> bool:
        """ Represents whether the model is in the training mode """
        return self._training

    @training.setter
    def training(self, new_value: bool):
        # only call _set_training when the mode actually changes
        if self._training is new_value:
            return
        self._set_training(new_value)

    # ========================================================== Public
methods @torch.no_grad() def evaluate(self, data_split: str = "dev", word_level_conll: bool = False, eval_lang: Optional[str] = None ) -> Tuple[float, Tuple[float, float, float]]: """ Evaluates the modes on the data split provided. Args: data_split (str): one of 'dev'/'test'/'train' word_level_conll (bool): if True, outputs conll files on word-level eval_lang (str): which language to evaluate Returns: mean loss span-level LEA: f1, precision, recal """ self.training = False w_checker = ClusterChecker() s_checker = ClusterChecker() try: data_split_data = f"{data_split}_data" data_path = self.config.__dict__[data_split_data] docs = self._get_docs(data_path) except FileNotFoundError as e: raise FileNotFoundError("Unable to find data split %s at file %s" % (data_split_data, data_path)) from e running_loss = 0.0 s_correct = 0 s_total = 0 z_correct = 0 z_total = 0 with conll.open_(self.config, self.epochs_trained, data_split) \ as (gold_f, pred_f): pbar = tqdm(docs, unit="docs", ncols=0) for doc in pbar: if eval_lang and doc.get("lang", "") != eval_lang: # skip that document, only used for ablation where we only # want to test evaluation on one language continue res = self.run(doc, True) # measure zero prediction accuracy zero_preds = (res.zero_scores > 0).view(-1).to(device=res.zero_scores.device) is_zero = doc.get("is_zero") if is_zero is None: zero_targets = torch.zeros_like(zero_preds, device=zero_preds.device) else: zero_targets = torch.tensor(is_zero, device=zero_preds.device) z_correct += (zero_preds == zero_targets).sum().item() z_total += zero_targets.numel() if (res.coref_y.argmax(dim=1) == 1).all(): logger.warning(f"EVAL: skipping document with no corefs...") continue running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item() if res.word_clusters is None or res.span_clusters is None: logger.warning(f"EVAL: skipping document with no clusters...") continue if res.span_y: pred_starts = res.span_scores[:, :, 0].argmax(dim=1) pred_ends = 
res.span_scores[:, :, 1].argmax(dim=1) s_correct += ((res.span_y[0] == pred_starts) * (res.span_y[1] == pred_ends)).sum().item() s_total += len(pred_starts) if word_level_conll: raise NotImplementedError("We now write Conll-U conforming to UDCoref, which means that the span_clusters annotations will have headword info. word_level option is meaningless.") else: conll.write_conll(doc, doc["span_clusters"], doc["word_clusters"], gold_f) conll.write_conll(doc, res.span_clusters, res.word_clusters, pred_f) w_checker.add_predictions(doc["word_clusters"], res.word_clusters) w_lea = w_checker.total_lea s_checker.add_predictions(doc["span_clusters"], res.span_clusters) s_lea = s_checker.total_lea del res pbar.set_description( f"{data_split}:" f" | WL: " f" loss: {running_loss / (pbar.n + 1):<.5f}," f" f1: {w_lea[0]:.5f}," f" p: {w_lea[1]:.5f}," f" r: {w_lea[2]:<.5f}" f" | SL: " f" sa: {s_correct / s_total:<.5f}," f" f1: {s_lea[0]:.5f}," f" p: {s_lea[1]:.5f}," f" r: {s_lea[2]:<.5f}" f" | ZA: {z_correct / z_total:<.5f}" ) logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}") logger.info(f"Zero prediction accuracy: {z_correct / z_total:.5f}") return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff) def load_weights(self, path: Optional[str] = None, ignore: Optional[Set[str]] = None, map_location: Optional[str] = None, noexception: bool = False) -> None: """ Loads pretrained weights of modules saved in a file located at path. If path is None, the last saved model with current configuration in save_dir is loaded. Assumes files are named like {configuration}_(e{epoch}_{time})*.pt. 
""" if path is None: # pattern = rf"{self.config.save_name}_\(e(\d+)_[^()]*\).*\.pt" # tries to load the last checkpoint in the same dir pattern = rf"{self.config.save_name}.*?\.checkpoint\.pt" files = [] os.makedirs(self.config.save_dir, exist_ok=True) for f in os.listdir(self.config.save_dir): match_obj = re.match(pattern, f) if match_obj: files.append(f) if not files: if noexception: logger.debug("No weights have been loaded", flush=True) return raise OSError(f"No weights found in {self.config.save_dir}!") path = sorted(files)[-1] path = os.path.join(self.config.save_dir, path) if map_location is None: map_location = self.config.device logger.debug(f"Loading from {path}...") try: state_dicts = torch.load(path, map_location=map_location, weights_only=True) except UnpicklingError: state_dicts = torch.load(path, map_location=map_location, weights_only=False) warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. 
Please resave the coref model using this version ASAP.") self.epochs_trained = state_dicts.pop("epochs_trained", 0) # just ignore a config in the model, since we should already have one # TODO: some config elements may be fixed parameters of the model, # such as the dimensions of the head, # so we would want to use the ones from the config even if the # user created a weird shaped model config = state_dicts.pop("config", {}) self.load_state_dicts(state_dicts, ignore) def load_state_dicts(self, state_dicts: dict, ignore: Optional[Set[str]] = None): """ Process the dictionaries from the save file Loads the weights into the tensors of this model May also have optimizer and/or schedule state """ for key, state_dict in state_dicts.items(): logger.debug("Loading state: %s", key) if not ignore or key not in ignore: if key.endswith("_optimizer"): self.optimizers[key].load_state_dict(state_dict) elif key.endswith("_scheduler"): self.schedulers[key].load_state_dict(state_dict) elif key == "bert_lora": assert self.config.lora, "Unable to load state dict of LoRA model into model initialized without LoRA!" 
self.bert = load_peft_wrapper(self.bert, state_dict, vars(self.config), logger, self.peft_name) else: self.trainable[key].load_state_dict(state_dict, strict=False) logger.debug(f"Loaded {key}") if self.config.log_norms: self.log_norms() def build_doc(self, doc: dict) -> dict: filter_func = TOKENIZER_FILTERS.get(self.config.bert_model, lambda _: True) token_map = TOKENIZER_MAPS.get(self.config.bert_model, {}) word2subword = [] subwords = [] word_id = [] for i, word in enumerate(doc["cased_words"]): tokenized_word = (token_map[word] if word in token_map else self.tokenizer.tokenize(word)) tokenized_word = list(filter(filter_func, tokenized_word)) word2subword.append((len(subwords), len(subwords) + len(tokenized_word))) subwords.extend(tokenized_word) word_id.extend([i] * len(tokenized_word)) doc["word2subword"] = word2subword doc["subwords"] = subwords doc["word_id"] = word_id doc["head2span"] = [] if "speaker" not in doc: doc["speaker"] = ["_" for _ in doc["cased_words"]] doc["word_clusters"] = [] doc["span_clusters"] = [] return doc @staticmethod def load_model(path: str, map_location: str = "cpu", ignore: Optional[Set[str]] = None, config_update: Optional[dict] = None, foundation_cache = None): if not path: raise FileNotFoundError("coref model got an invalid path |%s|" % path) if not os.path.exists(path): raise FileNotFoundError("coref model file %s not found" % path) try: state_dicts = torch.load(path, map_location=map_location, weights_only=True) except UnpicklingError: state_dicts = torch.load(path, map_location=map_location, weights_only=False) warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. 
Please resave the coref model using this version ASAP.") epochs_trained = state_dicts.pop("epochs_trained", 0) config = state_dicts.pop('config', None) if config is None: raise ValueError("Cannot load this format model without config in the dicts") if 'max_train_len' not in config: # TODO: this is to keep old models working. # Can get rid of it if those models are rebuilt config['max_train_len'] = 5000 if isinstance(config, dict): config = Config(**config) if config_update: for key, value in config_update.items(): setattr(config, key, value) model = CorefModel(config=config, build_optimizers=False, epochs_trained=epochs_trained, foundation_cache=foundation_cache) model.load_state_dicts(state_dicts, ignore) return model def run(self, # pylint: disable=too-many-locals doc: Doc, use_gold_spans_for_zeros = False ) -> CorefResult: """ This is a massive method, but it made sense to me to not split it into several ones to let one see the data flow. Args: doc (Doc): a dictionary with the document data. 
Returns: CorefResult (see const.py) """ # Encode words with bert # words [n_words, span_emb] # cluster_ids [n_words] words, cluster_ids = self.we(doc, self._bertify(doc)) # Obtain bilinear scores and leave only top-k antecedents for each word # top_rough_scores [n_words, n_ants] # top_indices [n_words, n_ants] top_rough_scores, top_indices, rough_scores = self.rough_scorer(words) # Get pairwise features [n_words, n_ants, n_pw_features] pw = self.pw(top_indices, doc) batch_size = self.config.a_scoring_batch_size a_scores_lst: List[torch.Tensor] = [] for i in range(0, len(words), batch_size): pw_batch = pw[i:i + batch_size] words_batch = words[i:i + batch_size] top_indices_batch = top_indices[i:i + batch_size] top_rough_scores_batch = top_rough_scores[i:i + batch_size] # a_scores_batch [batch_size, n_ants] a_scores_batch = self.a_scorer( top_mentions=words[top_indices_batch], mentions_batch=words_batch, pw_batch=pw_batch, top_rough_scores_batch=top_rough_scores_batch ) a_scores_lst.append(a_scores_batch) res = CorefResult() # coref_scores [n_spans, n_ants] res.coref_scores = torch.cat(a_scores_lst, dim=0) res.coref_y = self._get_ground_truth( cluster_ids, top_indices, (top_rough_scores > float("-inf")), self.config.clusters_starts_are_singletons, self.config.singletons ) res.word_clusters = self._clusterize( doc, res.coref_scores, top_indices, self.config.singletons ) res.span_scores, res.span_y = self.sp.get_training_data(doc, words) if not self.training: res.span_clusters = self.sp.predict(doc, words, res.word_clusters) if not self.training and not use_gold_spans_for_zeros: zero_words = words[[word_id for cluster in res.word_clusters for word_id in cluster]] else: zero_words = words[[i[0] for i in sorted(doc["head2span"])]] res.zero_scores = self.zeros_predictor(zero_words) return res def save_weights(self, save_path=None, save_optimizers=True): """ Saves trainable models as state dicts. 
""" to_save: List[Tuple[str, Any]] = \ [(key, value) for key, value in self.trainable.items() if (self.config.bert_finetune and not self.config.lora) or key != "bert"] if save_optimizers: to_save.extend(self.optimizers.items()) to_save.extend(self.schedulers.items()) time = datetime.strftime(datetime.now(), "%Y.%m.%d_%H.%M") if save_path is None: save_path = os.path.join(self.config.save_dir, f"{self.config.save_name}" f"_e{self.epochs_trained}_{time}.pt") savedict = {name: module.state_dict() for name, module in to_save} if self.config.lora: # so that this dependency remains optional from peft import get_peft_model_state_dict savedict["bert_lora"] = get_peft_model_state_dict(self.bert, adapter_name="coref") savedict["epochs_trained"] = self.epochs_trained # type: ignore # save as a dictionary because the weights_only=True load option # doesn't allow for arbitrary @dataclass configs savedict["config"] = dataclasses.asdict(self.config) save_dir = os.path.split(save_path)[0] if save_dir: os.makedirs(save_dir, exist_ok=True) torch.save(savedict, save_path) def log_norms(self): lines = ["NORMS FOR MODEL PARAMTERS"] for t_name, trainable in self.trainable.items(): for name, param in trainable.named_parameters(): if param.requires_grad: lines.append(" %s: %s %.6g (%d)" % (t_name, name, torch.norm(param).item(), param.numel())) logger.info("\n".join(lines)) def train(self, log=False): """ Trains all the trainable blocks in the model using the config provided. 
log: whether or not to log using wandb skip_lang: str if we want to skip training this language (used for ablation) """ if log: import wandb wandb.watch((self.bert, self.pw, self.a_scorer, self.we, self.rough_scorer, self.sp)) docs = self._get_docs(self.config.train_data) docs_ids = list(range(len(docs))) avg_spans = docs.avg_span # for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros training_has_zeros = any('is_zero' in doc for doc in docs) if not training_has_zeros: logger.info("No zeros found in the dataset. The zeros predictor will set to 0") if self.epochs_trained == 0: # new model, set it to always predict not-zero self.disable_zeros_predictor() attenuated_languages = set() if self.config.lang_lr_attenuation: attenuated_languages = self.config.lang_lr_attenuation.split(",") logger.info("Attenuating LR for the following languages: %s", attenuated_languages) lr_scaled_languages = dict() if self.config.lang_lr_weights: scaled_languages = self.config.lang_lr_weights.split(",") for piece in scaled_languages: pieces = piece.split("=") lr_scaled_languages[pieces[0]] = float(pieces[1]) logger.info("Scaling LR for the following languages: %s", lr_scaled_languages) best_f1 = None best_epoch = self.epochs_trained for epoch in range(self.epochs_trained, self.config.train_epochs): self.training = True if self.config.log_norms: self.log_norms() running_c_loss = 0.0 running_s_loss = 0.0 running_z_loss = 0.0 random.shuffle(docs_ids) pbar = tqdm(docs_ids, unit="docs", ncols=0) for doc_indx, doc_id in enumerate(pbar): doc = docs[doc_id] # skip very long documents during training time if len(doc["subwords"]) > self.config.max_train_len: continue for optim in self.optimizers.values(): optim.zero_grad() res = self.run(doc) if res.zero_scores.size(0) == 0 or not training_has_zeros: z_loss = 0.0 # since there are no corefs else: is_zero = doc.get("is_zero") if is_zero is None: is_zero = torch.zeros_like(res.zero_scores.squeeze(-1), 
device=res.zero_scores.device, dtype=torch.float) else: is_zero = torch.tensor(is_zero).to(res.zero_scores.device).float() z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), is_zero, reduction="mean") c_loss = self._coref_criterion(res.coref_scores, res.coref_y) if res.span_y: s_loss = (self._span_criterion(res.span_scores[:, :, 0], res.span_y[0]) + self._span_criterion(res.span_scores[:, :, 1], res.span_y[1])) / avg_spans / 2 else: s_loss = torch.zeros_like(c_loss) lr_scale = lr_scaled_languages.get(doc.get("lang"), 1.0) if doc.get("lang") in attenuated_languages: lr_scale = lr_scale / max(epoch, 1.0) c_loss = c_loss * lr_scale s_loss = s_loss * lr_scale z_loss = z_loss * lr_scale (c_loss + s_loss + z_loss).backward() running_c_loss += c_loss.item() running_s_loss += s_loss.item() if res.zero_scores.size(0) != 0 and training_has_zeros: running_z_loss += z_loss.item() # log every 100 docs if log and doc_indx % 100 == 0: logged = { 'train_c_loss': c_loss.item(), 'train_s_loss': s_loss.item(), } if res.zero_scores.size(0) != 0 and training_has_zeros: logged['train_z_loss'] = z_loss.item() wandb.log(logged) del c_loss, s_loss, z_loss, res for optim in self.optimizers.values(): optim.step() for scheduler in self.schedulers.values(): scheduler.step() pbar.set_description( f"Epoch {epoch + 1}:" f" {doc['document_id']:26}" f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}" f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}" f" z_loss: {running_z_loss / (pbar.n + 1):<.5f}" ) self.epochs_trained += 1 scores = self.evaluate() prev_best_f1 = best_f1 if log: wandb.log({'dev_score': scores[1]}) wandb.log({'dev_bakeoff': scores[-1]}) if best_f1 is None or scores[1] > best_f1: best_epoch = epoch if best_f1 is None: logger.info("Saving new best model: F1 %.4f", scores[1]) else: logger.info("Saving new best model: F1 %.4f > %.4f", scores[1], best_f1) best_f1 = scores[1] if self.config.save_name.endswith(".pt"): save_path = os.path.join(self.config.save_dir, 
f"{self.config.save_name}") else: save_path = os.path.join(self.config.save_dir, f"{self.config.save_name}.pt") self.save_weights(save_path, save_optimizers=False) if self.config.save_each_checkpoint: self.save_weights() else: if self.config.save_name.endswith(".pt"): checkpoint_path = os.path.join(self.config.save_dir, f"{self.config.save_name[:-3]}.checkpoint.pt") else: checkpoint_path = os.path.join(self.config.save_dir, f"{self.config.save_name}.checkpoint.pt") self.save_weights(checkpoint_path) if prev_best_f1 is not None and prev_best_f1 != best_f1: logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f\nPrevious best F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1, prev_best_f1) else: logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1) if self.config.plateau_epochs > 0 and best_epoch + self.config.plateau_epochs < epoch: logger.info("Have plateaued for too long (%d epochs). 
Will terminate training", self.config.plateau_epochs) break # ========================================================= Private methods def _bertify(self, doc: Doc) -> torch.Tensor: all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer) # we index the batches n at a time to prevent oom result = [] for i in range(0, all_batches.shape[0], 1024): subwords_batches = all_batches[i:i+1024] special_tokens = np.array([self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.pad_token_id, self.tokenizer.eos_token_id]) subword_mask = ~(np.isin(subwords_batches, special_tokens)) subwords_batches_tensor = torch.tensor(subwords_batches, device=self.config.device, dtype=torch.long) subword_mask_tensor = torch.tensor(subword_mask, device=self.config.device) # Obtain bert output for selected batches only attention_mask = (subwords_batches != self.tokenizer.pad_token_id) if "t5" in self.config.bert_model: out = self.bert.encoder( input_ids=subwords_batches_tensor, attention_mask=torch.tensor( attention_mask, device=self.config.device)) else: out = self.bert( subwords_batches_tensor, attention_mask=torch.tensor( attention_mask, device=self.config.device)) out = out['last_hidden_state'] # [n_subwords, bert_emb] result.append(out[subword_mask_tensor]) # stack returns and return return torch.cat(result) def _build_model(self, foundation_cache): if hasattr(self.config, 'lora') and self.config.lora: self.bert, self.tokenizer, peft_name = load_bert_with_peft(self.config.bert_model, "coref", foundation_cache) # vars() converts a dataclass to a dict, used for being able to index things like args["lora_*"] self.bert = build_peft_wrapper(self.bert, vars(self.config), logger, adapter_name=peft_name) self.peft_name = peft_name else: if self.config.bert_finetune: logger.debug("Coref model requested a finetuned transformer; we are not using the foundation model cache to prevent we accidentally leak the finetuning weights elsewhere.") foundation_cache = 
NoTransformerFoundationCache(foundation_cache) self.bert, self.tokenizer = load_bert(self.config.bert_model, foundation_cache) base_bert_name = self.config.bert_model.split("/")[-1] tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {}) if tokenizer_kwargs: logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}") # we just downloaded the tokenizer, so for simplicity, we don't make another request to HF self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True) if self.config.bert_finetune or (hasattr(self.config, 'lora') and self.config.lora): self.bert = self.bert.train() self.bert = self.bert.to(self.config.device) self.pw = PairwiseEncoder(self.config).to(self.config.device) bert_emb = self.bert.config.hidden_size pair_emb = bert_emb * 3 + self.pw.shape # pylint: disable=line-too-long self.a_scorer = AnaphoricityScorer(pair_emb, self.config).to(self.config.device) self.we = WordEncoder(bert_emb, self.config).to(self.config.device) self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device) self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device) self.zeros_predictor = nn.Sequential( nn.Linear(bert_emb, bert_emb), nn.ReLU(), nn.Linear(bert_emb, 1) ).to(self.config.device) if not hasattr(self.config, 'use_zeros') or not self.config.use_zeros: self.disable_zeros_predictor() self.trainable: Dict[str, torch.nn.Module] = { "bert": self.bert, "we": self.we, "rough_scorer": self.rough_scorer, "pw": self.pw, "a_scorer": self.a_scorer, "sp": self.sp, "zeros_predictor": self.zeros_predictor } def disable_zeros_predictor(self): nn.init.zeros_(self.zeros_predictor[-1].weight) nn.init.zeros_(self.zeros_predictor[-1].bias) def _build_optimizers(self): n_docs = len(self._get_docs(self.config.train_data)) self.optimizers: Dict[str, torch.optim.Optimizer] = {} self.schedulers: Dict[str, torch.optim.lr_scheduler.LRScheduler] = {} if not getattr(self.config, 'lora', False): 
for param in self.bert.parameters(): param.requires_grad = self.config.bert_finetune if self.config.bert_finetune: logger.debug("Making bert optimizer with LR of %f", self.config.bert_learning_rate) self.optimizers["bert_optimizer"] = torch.optim.Adam( self.bert.parameters(), lr=self.config.bert_learning_rate ) start_finetuning = int(n_docs * self.config.bert_finetune_begin_epoch) if start_finetuning > 0: logger.info("Will begin finetuning transformer at iteration %d", start_finetuning) zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizers["bert_optimizer"], factor=0, total_iters=start_finetuning) warmup_scheduler = transformers.get_linear_schedule_with_warmup( self.optimizers["bert_optimizer"], start_finetuning, n_docs * self.config.train_epochs - start_finetuning) self.schedulers["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR( self.optimizers["bert_optimizer"], schedulers=[zero_scheduler, warmup_scheduler], milestones=[start_finetuning]) # Must ensure the same ordering of parameters between launches modules = sorted((key, value) for key, value in self.trainable.items() if key != "bert") params = [] for _, module in modules: for param in module.parameters(): param.requires_grad = True params.append(param) self.optimizers["general_optimizer"] = torch.optim.Adam( params, lr=self.config.learning_rate) self.schedulers["general_scheduler"] = \ transformers.get_linear_schedule_with_warmup( self.optimizers["general_optimizer"], 0, n_docs * self.config.train_epochs ) def _clusterize(self, doc: Doc, scores: torch.Tensor, top_indices: torch.Tensor, singletons: bool = True): if singletons: antecedents = scores[:,1:].argmax(dim=1) - 1 # set the dummy values to -1, so that they are not coref to themselves is_start = (scores[:, :2].argmax(dim=1) == 0) else: antecedents = scores.argmax(dim=1) - 1 not_dummy = antecedents >= 0 coref_span_heads = torch.arange(0, len(scores), device=not_dummy.device)[not_dummy] antecedents = top_indices[coref_span_heads, 
antecedents[not_dummy]] nodes = [GraphNode(i) for i in range(len(doc["cased_words"]))] for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()): nodes[i].link(nodes[j]) assert nodes[i] is not nodes[j] visited = {} clusters = [] for node in nodes: if len(node.links) > 0 and not node.visited: cluster = [] stack = [node] while stack: current_node = stack.pop() current_node.visited = True cluster.append(current_node.id) stack.extend(link for link in current_node.links if not link.visited) assert len(cluster) > 1 for i in cluster: visited[i] = True clusters.append(sorted(cluster)) if singletons: # go through the is_start nodes; if no clusters contain that node # i.e. visited[i] == False, we add it as a singleton for indx, i in enumerate(is_start): if i and not visited.get(indx, False): clusters.append([indx]) return sorted(clusters) def _get_docs(self, path: str) -> List[Doc]: if path not in self._docs: self._docs[path] = CorefDataset(path, self.config, self.tokenizer) return self._docs[path] @staticmethod def _get_ground_truth(cluster_ids: torch.Tensor, top_indices: torch.Tensor, valid_pair_map: torch.Tensor, cluster_starts: bool, singletons:bool = True) -> torch.Tensor: """ Args: cluster_ids: tensor of shape [n_words], containing cluster indices for each word. Non-gold words have cluster id of zero. top_indices: tensor of shape [n_words, n_ants], indices of antecedents of each word valid_pair_map: boolean tensor of shape [n_words, n_ants], whether for pair at [i, j] (i-th word and j-th word) j < i is True Returns: tensor of shape [n_words, n_ants + 1] (dummy added), containing 1 at position [i, j] if i-th and j-th words corefer. 
""" y = cluster_ids[top_indices] * valid_pair_map # [n_words, n_ants] y[y == 0] = -1 # -1 for non-gold words y = utils.add_dummy(y) # [n_words, n_cands + 1] if singletons: if not cluster_starts: unique, counts = cluster_ids.unique(return_counts=True) singleton_clusters = unique[(counts == 1) & (unique != 0)] first_corefs = [(cluster_ids == i).nonzero().flatten()[0] for i in singleton_clusters] if len(first_corefs) > 0: first_coref = torch.stack(first_corefs) else: first_coref = torch.tensor([]).to(cluster_ids.device).long() else: # I apologize for this abuse of everything that's good about PyTorch. # in essence, this line finds the INDEX of FIRST OCCURENCE of each NON-ZERO value # from cluster_ids. We need this information because we use it to mark the # special "is-start-of-ref" marker used to detect singletons. first_coref = (cluster_ids == cluster_ids.unique().sort().values[1:].unsqueeze(1) ).float().topk(k=1, dim=1).indices.squeeze() y = (y == cluster_ids.unsqueeze(1)) # True if coreferent # For all rows with no gold antecedents setting dummy to True y[y.sum(dim=1) == 0, 0] = True if singletons: # add another dummy for first coref y = utils.add_dummy(y) # [n_words, n_cands + 2] # for all rows that's a first coref, setting its dummy to True and unset the # non-coref dummy to false y[first_coref, 0] = True y[first_coref, 1] = False return y.to(torch.float) @staticmethod def _load_config(config_path: str, section: str) -> Config: with open(config_path, "rb") as fin: config = tomllib.load(fin) default_section = config["DEFAULT"] current_section = config[section] unknown_keys = (set(current_section.keys()) - set(default_section.keys())) if unknown_keys: raise ValueError(f"Unexpected config keys: {unknown_keys}") return Config(section, **{**default_section, **current_section}) def _set_training(self, value: bool): self._training = value for module in self.trainable.values(): module.train(self._training) ================================================ FILE: 
stanza/models/coref/pairwise_encoder.py ================================================ """ Describes PairwiseEncodes, that transforms pairwise features, such as distance between the mentions, same/different speaker into feature embeddings """ from typing import List import torch from stanza.models.coref.config import Config from stanza.models.coref.const import Doc class PairwiseEncoder(torch.nn.Module): """ A Pytorch module to obtain feature embeddings for pairwise features Usage: encoder = PairwiseEncoder(config) pairwise_features = encoder(pair_indices, doc) """ def __init__(self, config: Config): super().__init__() emb_size = config.embedding_size self.genre2int = {g: gi for gi, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size) # each position corresponds to a bucket: # [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8), # (8, 16), (16, 32), (32, 64), (64, float("inf"))] self.distance_emb = torch.nn.Embedding(9, emb_size) # two possibilities: same vs different speaker self.speaker_emb = torch.nn.Embedding(2, emb_size) self.dropout = torch.nn.Dropout(config.dropout_rate) self.__full_pw = config.full_pairwise if self.__full_pw: self.shape = emb_size * 2 # distance, speaker else: self.shape = emb_size # distance only @property def device(self) -> torch.device: """ A workaround to get current device (which is assumed to be the device of the first parameter of one of the submodules) """ return next(self.genre_emb.parameters()).device def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch top_indices: torch.Tensor, doc: Doc) -> torch.Tensor: word_ids = torch.arange(0, len(doc["cased_words"]), device=self.device) # bucketing the distance (see __init__()) distance = (word_ids.unsqueeze(1) - word_ids[top_indices] ).clamp_min_(min=1) log_distance = distance.to(torch.float).log2().floor_() log_distance = log_distance.clamp_max_(max=6).to(torch.long) distance = 
torch.where(distance < 5, distance - 1, log_distance + 2) distance = self.distance_emb(distance) if not self.__full_pw: return self.dropout(distance) # calculate speaker embeddings speaker_map = torch.tensor(self._speaker_map(doc), device=self.device) same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1)) same_speaker = self.speaker_emb(same_speaker.to(torch.long)) return self.dropout(torch.cat((same_speaker, distance), dim=2)) @staticmethod def _speaker_map(doc: Doc) -> List[int]: """ Returns a tensor where i-th element is the speaker id of i-th word. """ # if speaker is not found in the doc, simply return "speaker#1" for all the speakers # and embed them using the same ID # speaker string -> speaker id str2int = {s: i for i, s in enumerate(set(doc.get("speaker", ["speaker#1" for _ in range(len(doc["cased_words"]))])))} # word id -> speaker id return [str2int[s] for s in doc.get("speaker", ["speaker#1" for _ in range(len(doc["cased_words"]))])] ================================================ FILE: stanza/models/coref/predict.py ================================================ import argparse import json import torch from tqdm import tqdm from stanza.models.coref.model import CorefModel if __name__ == "__main__": argparser = argparse.ArgumentParser() argparser.add_argument("experiment") argparser.add_argument("input_file") argparser.add_argument("output_file") argparser.add_argument("--config-file", default="config.toml") argparser.add_argument("--batch-size", type=int, help="Adjust to override the config value if you're" " experiencing out-of-memory issues") argparser.add_argument("--weights", help="Path to file with weights to load." 
" If not supplied, in the latest" " weights of the experiment will be loaded;" " if there aren't any, an error is raised.") args = argparser.parse_args() model = CorefModel.load_model(path=args.weights, map_location="cpu", ignore={"bert_optimizer", "general_optimizer", "bert_scheduler", "general_scheduler"}) if args.batch_size: model.config.a_scoring_batch_size = args.batch_size model.training = False try: with open(args.input_file, encoding="utf-8") as fin: input_data = json.load(fin) except json.decoder.JSONDecodeError: # read the old jsonlines format if necessary with open(args.input_file, encoding="utf-8") as fin: text = "[" + ",\n".join(fin) + "]" input_data = json.loads(text) docs = [model.build_doc(doc) for doc in input_data] with torch.no_grad(): for doc in tqdm(docs, unit="docs"): result = model.run(doc) doc["span_clusters"] = result.span_clusters doc["word_clusters"] = result.word_clusters for key in ("word2subword", "subwords", "word_id", "head2span"): del doc[key] with open(args.output_file, mode="w") as fout: for doc in docs: json.dump(doc, fout) ================================================ FILE: stanza/models/coref/rough_scorer.py ================================================ """ Describes RoughScorer, a simple bilinear module to calculate rough anaphoricity scores. """ from typing import Tuple import torch from stanza.models.coref.config import Config class RoughScorer(torch.nn.Module): """ Is needed to give a roughly estimate of the anaphoricity of two candidates, only top scoring candidates are considered on later steps to reduce computational complexity. 
""" def __init__(self, features: int, config: Config): super().__init__() self.dropout = torch.nn.Dropout(config.dropout_rate) self.bilinear = torch.nn.Linear(features, features) self.k = config.rough_k def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch mentions: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns rough anaphoricity scores for candidates, which consist of the bilinear output of the current model summed with mention scores. """ # [n_mentions, n_mentions] pair_mask = torch.arange(mentions.shape[0]) pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0) pair_mask = torch.log((pair_mask > 0).to(torch.float)) pair_mask = pair_mask.to(mentions.device) bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T) rough_scores = pair_mask + bilinear_scores return self._prune(rough_scores) def _prune(self, rough_scores: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Selects top-k rough antecedent scores for each mention. Args: rough_scores: tensor of shape [n_mentions, n_mentions], containing rough antecedent scores of each mention-antecedent pair. Returns: FloatTensor of shape [n_mentions, k], top rough scores LongTensor of shape [n_mentions, k], top indices """ top_scores, indices = torch.topk(rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False) return top_scores, indices, rough_scores ================================================ FILE: stanza/models/coref/span_predictor.py ================================================ """ Describes SpanPredictor which aims to predict spans by taking as input head word and context embeddings. 
""" from typing import List, Optional, Tuple from stanza.models.coref.const import Doc, Span import torch class SpanPredictor(torch.nn.Module): def __init__(self, input_size: int, distance_emb_size: int): super().__init__() self.ffnn = torch.nn.Sequential( torch.nn.Linear(input_size * 2 + 64, input_size), torch.nn.ReLU(), torch.nn.Dropout(0.3), torch.nn.Linear(input_size, 256), torch.nn.ReLU(), torch.nn.Dropout(0.3), torch.nn.Linear(256, 64), ) self.conv = torch.nn.Sequential( torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1) ) self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far @property def device(self) -> torch.device: """ A workaround to get current device (which is assumed to be the device of the first parameter of one of the submodules) """ return next(self.ffnn.parameters()).device def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch doc: Doc, words: torch.Tensor, heads_ids: torch.Tensor) -> torch.Tensor: """ Calculates span start/end scores of words for each span head in heads_ids Args: doc (Doc): the document data words (torch.Tensor): contextual embeddings for each word in the document, [n_words, emb_size] heads_ids (torch.Tensor): word indices of span heads Returns: torch.Tensor: span start/end scores, [n_heads, n_words, 2] """ # Obtain distance embedding indices, [n_heads, n_words] relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0)) emb_ids = relative_positions + 63 # make all valid distances positive emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127 # "too_far" # Obtain "same sentence" boolean mask, [n_heads, n_words] sent_id = torch.tensor(doc["sent_id"], device=words.device) same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)) # To save memory, only pass candidates from one sentence for each head # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb # for each candidate among the 
words in the same sentence as span_head # [n_heads, input_size * 2 + distance_emb_size] rows, cols = same_sent.nonzero(as_tuple=True) pair_matrix = torch.cat(( words[heads_ids[rows]], words[cols], self.emb(emb_ids[rows, cols]), ), dim=1) lengths = same_sent.sum(dim=1) padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0) padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len] # [n_heads, max_sent_len, input_size * 2 + distance_emb_size] # This is necessary to allow the convolution layer to look at several # word scores padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device) padded_pairs[padding_mask] = pair_matrix res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output] res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2] scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device) scores[rows, cols] = res[padding_mask] # Make sure that start <= head <= end during inference if not self.training: valid_starts = torch.log((relative_positions >= 0).to(torch.float)) valid_ends = torch.log((relative_positions <= 0).to(torch.float)) valid_positions = torch.stack((valid_starts, valid_ends), dim=2) return scores + valid_positions return scores def get_training_data(self, doc: Doc, words: torch.Tensor ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: """ Returns span starts/ends for gold mentions in the document. """ head2span = sorted(doc["head2span"]) if not head2span: return None, None heads, starts, ends = zip(*head2span) heads = torch.tensor(heads, device=self.device) starts = torch.tensor(starts, device=self.device) ends = torch.tensor(ends, device=self.device) - 1 return self(doc, words, heads), (starts, ends) def predict(self, doc: Doc, words: torch.Tensor, clusters: List[List[int]]) -> List[List[Span]]: """ Predicts span clusters based on the word clusters. 
Args: doc (Doc): the document data words (torch.Tensor): [n_words, emb_size] matrix containing embeddings for each of the words in the text clusters (List[List[int]]): a list of clusters where each cluster is a list of word indices Returns: List[List[Span]]: span clusters """ if not clusters: return [] heads_ids = torch.tensor( sorted(i for cluster in clusters for i in cluster), device=self.device ) scores = self(doc, words, heads_ids) starts = scores[:, :, 0].argmax(dim=1).tolist() ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist() head2span = { head: (start, end) for head, start, end in zip(heads_ids.tolist(), starts, ends) } return [[head2span[head] for head in cluster] for cluster in clusters] ================================================ FILE: stanza/models/coref/tokenizer_customization.py ================================================ """ This file defines functions used to modify the default behaviour of transformers.AutoTokenizer. These changes are necessary, because some tokenizers are meant to be used with raw text, while the OntoNotes documents have already been split into words. All the functions are used in coref_model.CorefModel._get_docs. 
""" # Filters out unwanted tokens produced by the tokenizer TOKENIZER_FILTERS = { "albert-xxlarge-v2": (lambda token: token != "▁"), # U+2581, not just "_" "albert-large-v2": (lambda token: token != "▁"), } # Maps some words to tokens directly, without a tokenizer TOKENIZER_MAPS = { "roberta-large": {".": ["."], ",": [","], "!": ["!"], "?": ["?"], ":":[":"], ";":[";"], "'s": ["'s"]} } ================================================ FILE: stanza/models/coref/utils.py ================================================ """ Contains functions not directly linked to coreference resolution """ from typing import List, Set import torch import torch.nn.functional as F from stanza.models.coref.const import EPSILON class GraphNode: def __init__(self, node_id: int): self.id = node_id self.links: Set[GraphNode] = set() self.visited = False def link(self, another: "GraphNode"): self.links.add(another) another.links.add(self) def __repr__(self) -> str: return str(self.id) def add_dummy(tensor: torch.Tensor, eps: bool = False): """ Prepends zeros (or a very small value if eps is True) to the first (not zeroth) dimension of tensor. """ kwargs = dict(device=tensor.device, dtype=tensor.dtype) shape: List[int] = list(tensor.shape) shape[1] = 1 if not eps: dummy = torch.zeros(shape, **kwargs) # type: ignore else: dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore return torch.cat((dummy, tensor), dim=1) def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2, reduction: str = "none", ) -> torch.Tensor: """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs (Tensor): A float tensor of arbitrary shape. The predictions for each example. targets (Tensor): A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). 
alpha (float): Weighting factor in range [0, 1] to balance positive vs negative examples or -1 for ignore. Default: ``0.25``. gamma (float): Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Default: ``2``. reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` ``'none'``: No reduction will be applied to the output. ``'mean'``: The output will be averaged. ``'sum'``: The output will be summed. Default: ``'none'``. Returns: Loss tensor with the reduction option applied. """ # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py if not (0 <= alpha <= 1) and alpha != -1: raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.") p = torch.sigmoid(inputs) ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = p * targets + (1 - p) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss # Check reduction option and return loss accordingly if reduction == "none": pass elif reduction == "mean": loss = loss.mean() elif reduction == "sum": loss = loss.sum() else: raise ValueError( f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" ) return loss ================================================ FILE: stanza/models/coref/word_encoder.py ================================================ """ Describes WordEncoder. Extracts mention vectors from bert-encoded text. """ from typing import Tuple import torch from stanza.models.coref.config import Config from stanza.models.coref.const import Doc class WordEncoder(torch.nn.Module): # pylint: disable=too-many-instance-attributes """ Receives bert contextual embeddings of a text, extracts all the possible mentions in that text. 
""" def __init__(self, features: int, config: Config): """ Args: features (int): the number of featues in the input embeddings config (Config): the configuration of the current session """ super().__init__() self.attn = torch.nn.Linear(in_features=features, out_features=1) self.dropout = torch.nn.Dropout(config.dropout_rate) @property def device(self) -> torch.device: """ A workaround to get current device (which is assumed to be the device of the first parameter of one of the submodules) """ return next(self.attn.parameters()).device def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch doc: Doc, x: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: """ Extracts word representations from text. Args: doc: the document data x: a tensor containing bert output, shape (n_subtokens, bert_dim) Returns: words: a Tensor of shape [n_words, mention_emb]; mention representations cluster_ids: tensor of shape [n_words], containing cluster indices for each word. Non-coreferent words have cluster id of zero. """ word_boundaries = torch.tensor(doc["word2subword"], device=self.device) starts = word_boundaries[:, 0] ends = word_boundaries[:, 1] # [n_mentions, features] words = self._attn_scores(x, starts, ends).mm(x) words = self.dropout(words) return (words, self._cluster_ids(doc)) def _attn_scores(self, bert_out: torch.Tensor, word_starts: torch.Tensor, word_ends: torch.Tensor) -> torch.Tensor: """ Calculates attention scores for each of the mentions. 
Args: bert_out (torch.Tensor): [n_subwords, bert_emb], bert embeddings for each of the subwords in the document word_starts (torch.Tensor): [n_words], start indices of words word_ends (torch.Tensor): [n_words], end indices of words Returns: torch.Tensor: [description] """ n_subtokens = len(bert_out) n_words = len(word_starts) # [n_mentions, n_subtokens] # with 0 at positions belonging to the words and -inf elsewhere attn_mask = torch.arange(0, n_subtokens, device=self.device).expand((n_words, n_subtokens)) attn_mask = ((attn_mask >= word_starts.unsqueeze(1)) * (attn_mask < word_ends.unsqueeze(1))) # if first row all False, set col 0 to True # otherwise, set the row to be the previous row? word_lengths = torch.sum(attn_mask, dim=1) if torch.any(word_lengths == 0): raise ValueError("Found a blank word in training data! This will break everything, starting with the attention masks, as some rows of the scoring table will be set to entirely -inf and then softmax to NaN.") attn_mask = torch.log(attn_mask.to(torch.float)) attn_scores = self.attn(bert_out).T # [1, n_subtokens] attn_scores = attn_scores.expand((n_words, n_subtokens)) attn_scores = attn_mask + attn_scores del attn_mask return torch.softmax(attn_scores, dim=1) # [n_words, n_subtokens] def _cluster_ids(self, doc: Doc) -> torch.Tensor: """ Args: doc: document information Returns: torch.Tensor of shape [n_word], containing cluster indices for each word. Non-coreferent words have cluster id of zero. 
""" word2cluster = {word_i: i for i, cluster in enumerate(doc["word_clusters"], start=1) for word_i in cluster} return torch.tensor( [word2cluster.get(word_i, 0) for word_i in range(len(doc["cased_words"]))], device=self.device ) ================================================ FILE: stanza/models/depparse/__init__.py ================================================ ================================================ FILE: stanza/models/depparse/data.py ================================================ import random import logging import torch from stanza.models.common.bert_embedding import filter_data, needs_length_filter from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all from stanza.models.common.utils import DEFAULT_WORD_CUTOFF, simplify_punct from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory from stanza.models.common.doc import * logger = logging.getLogger('stanza') def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately): """ Given a list of lists, where the first element of each sublist represents the sentence, group the sentences into batches. During training mode (not eval_mode) the sentences are sorted by length with a bit of random shuffling. During eval mode, the sentences are sorted by length if sort_during_eval is true. Refactored from the data structure in case other models could use it and for ease of testing. 
Returns (batches, original_order), where original_order is None when in train mode or when unsorted and represents the original location of each sentence in the sort """ res = [] if not eval_mode: # sort sentences (roughly) by length for better memory utilization data = sorted(data, key = lambda x: len(x[0]), reverse=random.random() > .5) data_orig_idx = None elif sort_during_eval: (data, ), data_orig_idx = sort_all([data], [len(x[0]) for x in data]) else: data_orig_idx = None current = [] currentlen = 0 for x in data: if min_length_to_batch_separately is not None and len(x[0]) > min_length_to_batch_separately: if currentlen > 0: res.append(current) current = [] currentlen = 0 res.append([x]) else: if len(x[0]) + currentlen > batch_size and currentlen > 0: res.append(current) current = [] currentlen = 0 current.append(x) currentlen += len(x[0]) if currentlen > 0: res.append(current) return res, data_orig_idx class DataLoader: def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None): self.batch_size = batch_size self.min_length_to_batch_separately=min_length_to_batch_separately self.args = args self.eval = evaluation self.shuffled = not self.eval self.sort_during_eval = sort_during_eval self.doc = doc data = self.load_doc(doc) # handle vocab if vocab is None: self.vocab = self.init_vocab(data) else: self.vocab = vocab # filter out the long sentences if bert is used if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']): data = filter_data(self.args['bert_model'], data, bert_tokenizer) # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None self.pretrain_vocab = None if pretrain is not None and args['pretrain']: self.pretrain_vocab = pretrain.vocab # filter and sample data if args.get('sample_train', 1.0) < 1.0 and not self.eval: keep = int(args['sample_train'] * len(data)) data = 
random.sample(data, keep) logger.debug("Subsample training set with rate {:g}".format(args['sample_train'])) data = self.preprocess(data, self.vocab, self.pretrain_vocab, args) # shuffle for training if self.shuffled: random.shuffle(data) self.num_examples = len(data) # chunk into batches self.data = self.chunk_batches(data) logger.debug("{} batches created.".format(len(self.data))) def init_vocab(self, data): assert self.eval == False # for eval vocab must exist cutoff = self.args['word_cutoff'] if self.args.get('word_cutoff') is not None else DEFAULT_WORD_CUTOFF charvocab = CharVocab(data, self.args['shorthand']) wordvocab = WordVocab(data, self.args['shorthand'], cutoff=cutoff, lower=True) uposvocab = WordVocab(data, self.args['shorthand'], idx=1) xposvocab = xpos_vocab_factory(data, self.args['shorthand']) featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3) lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=cutoff, idx=4, lower=True) deprelvocab = WordVocab(data, self.args['shorthand'], idx=6) vocab = MultiVocab({'char': charvocab, 'word': wordvocab, 'upos': uposvocab, 'xpos': xposvocab, 'feats': featsvocab, 'lemma': lemmavocab, 'deprel': deprelvocab}) return vocab def preprocess(self, data, vocab, pretrain_vocab, args): processed = [] xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID] feats_replacement = [[ROOT_ID] * len(vocab['feats'])] for sent in data: processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])] processed_sent += [[[ROOT_ID]] + [vocab['char'].map([x for x in w[0]]) for w in sent]] processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])] processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])] processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])] if pretrain_vocab is not None: # always use lowercase lookup in pretrained vocab processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() 
for w in sent])] else: processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)] processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])] processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]] processed_sent += [vocab['deprel'].map([w[6] for w in sent])] processed_sent.append([w[0] for w in sent]) processed.append(processed_sent) return processed def __len__(self): return len(self.data) def __getitem__(self, key): """ Get a batch with index. """ if not isinstance(key, int): raise TypeError if key < 0 or key >= len(self.data): raise IndexError batch = self.data[key] batch_size = len(batch) batch = list(zip(*batch)) assert len(batch) == 10 # sort sentences by lens for easy RNN operations lens = [len(x) for x in batch[0]] batch, orig_idx = sort_all(batch, lens) # sort words by lens for easy char-RNN operations batch_words = [w for sent in batch[1] for w in sent] word_lens = [len(x) for x in batch_words] batch_words, word_orig_idx = sort_all([batch_words], word_lens) batch_words = batch_words[0] word_lens = [len(x) for x in batch_words] # convert to tensors words = batch[0] words = get_long_tensor(words, batch_size) words_mask = torch.eq(words, PAD_ID) wordchars = get_long_tensor(batch_words, len(word_lens)) wordchars_mask = torch.eq(wordchars, PAD_ID) upos = get_long_tensor(batch[2], batch_size) xpos = get_long_tensor(batch[3], batch_size) ufeats = get_long_tensor(batch[4], batch_size) pretrained = get_long_tensor(batch[5], batch_size) sentlens = [len(x) for x in batch[0]] lemma = get_long_tensor(batch[6], batch_size) head = get_long_tensor(batch[7], batch_size) deprel = get_long_tensor(batch[8], batch_size) text = batch[9] return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text def load_doc(self, doc): data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True) data = self.resolve_none(data) data = 
simplify_punct(data) return data def resolve_none(self, data): # replace None to '_' for sent_idx in range(len(data)): for tok_idx in range(len(data[sent_idx])): for feat_idx in range(len(data[sent_idx][tok_idx])): if data[sent_idx][tok_idx][feat_idx] is None: data[sent_idx][tok_idx][feat_idx] = '_' return data def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) def set_batch_size(self, batch_size): self.batch_size = batch_size def reshuffle(self): data = [y for x in self.data for y in x] self.data = self.chunk_batches(data) random.shuffle(self.data) def chunk_batches(self, data): batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size, eval_mode=self.eval, sort_during_eval=self.sort_during_eval, min_length_to_batch_separately=self.min_length_to_batch_separately) # data_orig_idx might be None at train time, since we don't anticipate unsorting self.data_orig_idx = data_orig_idx return batches def to_int(string, ignore_error=False): try: res = int(string) except ValueError as err: if ignore_error: return 0 else: raise err return res ================================================ FILE: stanza/models/depparse/model.py ================================================ import logging import os import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence from stanza.models.common.bert_embedding import extract_bert_embeddings from stanza.models.common.biaffine import DeepBiaffineScorer from stanza.models.common.foundation_cache import load_charlm from stanza.models.common.hlstm import HighwayLSTM from stanza.models.common.dropout import WordDropout from stanza.models.common.utils import attach_bert_model from stanza.models.common.vocab import CompositeVocab from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel from stanza.models.common import utils logger 
= logging.getLogger('stanza')

class Parser(nn.Module):
    """
    Biaffine dependency parser.

    Each word is represented by concatenating whichever input features are
    enabled in args (word + lemma embeddings, UPOS/XPOS and UFeats tag
    embeddings, character-model or charlm representations, transformer
    embeddings, pretrained word vectors).  A HighwayLSTM encodes the sentence,
    and DeepBiaffineScorer modules score every (dependent, head) pair for
    attachment and for each dependency relation.  Optional 'linearization'
    and 'distance' scorers add auxiliary terms to the attachment scores.
    """
    def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        """
        Args:
            args: dict of hyperparameters.  Newer options are read with
                args.get(...) and a default so that older saved models,
                which lack those keys, still load.
            vocab: mapping with (at least) 'word', 'lemma', 'upos', 'xpos',
                'feats' and 'deprel' vocabularies
            emb_matrix: pretrained embedding matrix; used when args['pretrain'] is set
            foundation_cache: passed to load_charlm when a charlm is used
            bert_model, bert_tokenizer, force_bert_saved: passed to attach_bert_model
            peft_name: stored and later passed to extract_bert_embeddings
        """
        super().__init__()

        self.vocab = vocab
        self.args = args
        # names of modules that are not written to the model file (see add_unsaved_module)
        self.unsaved_modules = []

        # input layers
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            self.lemma_emb = nn.Embedding(len(vocab['lemma']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim'] * 2

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
            if self.args.get('use_xpos', True):
                if not isinstance(vocab['xpos'], CompositeVocab):
                    self.xpos_emb = nn.Embedding(len(vocab['xpos']), self.args['tag_emb_dim'], padding_idx=0)
                else:
                    # one embedding table per component of a composite XPOS tag
                    self.xpos_emb = nn.ModuleList()

                    for l in vocab['xpos'].lens():
                        self.xpos_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
            # UPOS and XPOS embeddings are summed in forward(), so they share one slot
            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                input_size += self.args['tag_emb_dim']

            if self.args.get('use_ufeats', True):
                self.ufeats_emb = nn.ModuleList()

                for l in vocab['feats'].lens():
                    self.ufeats_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
                input_size += self.args['tag_emb_dim']

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                if self.args['charlm_forward_file'] is None or not os.path.exists(self.args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(self.args['charlm_forward_file']))
                if self.args['charlm_backward_file'] is None or not os.path.exists(self.args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(self.args['charlm_backward_file']))
                logger.debug("Depparse model loading charmodels: %s and %s", self.args['charlm_forward_file'], self.args['charlm_backward_file'])
                self.add_unsaved_module('charmodel_forward', load_charlm(self.args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(self.args['charlm_backward_file'], foundation_cache=foundation_cache))
                input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                self.charmodel = CharacterModel(self.args, vocab)
                self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if self.args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(self.args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.parserlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        # learned vector that WordDropout substitutes for dropped words
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        # learned initial LSTM states: 2 directions * num_layers, broadcast over the batch in forward()
        self.parserlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.parserlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # dropout
        self.drop = nn.Dropout(self.args['dropout'])
        self.worddrop = WordDropout(self.args['word_dropout'])

        # classifiers
        # args.get to preserve old models, including models other people might have created
        if self.args.get('use_arc_embedding'):
            logger.debug("Using arc embedding enhancement")
            # a shared pairwise arc representation feeds both the attachment
            # and the relation classifiers
            self.arc_embedding = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], self.args['deep_biaff_output_dim'], pairwise=True, dropout=self.args['dropout'])
            self.unlabeled_linear = nn.Sequential(self.drop, nn.Linear(self.args['deep_biaff_output_dim'], 1))
            self.deprel_linear = nn.Sequential(self.drop, nn.Linear(self.args['deep_biaff_output_dim'], 2 * self.args['deep_biaff_output_dim']), nn.ReLU(), self.drop, nn.Linear(self.args['deep_biaff_output_dim'] * 2, len(vocab['deprel'])))
        else:
            logger.debug("Not using arc embedding enhancement")
            self.unlabeled = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])
            self.deprel = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], len(vocab['deprel']), pairwise=True, dropout=self.args['dropout'])
        if self.args['linearization']:
            self.linearization = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])
        if self.args['distance']:
            self.distance = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') # ignore padding

    def add_unsaved_module(self, name, module):
        """ Attaches module as attribute name and records name in self.unsaved_modules. """
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def log_norms(self):
        utils.log_norms(self)

    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
        """
        Scores heads and relations for a padded batch of sentences.

        The arguments follow the order of the values returned by
        DataLoader.__getitem__ in depparse/data.py with orig_idx left out
        (NOTE(review): presumably they arrive via unpack_batch in
        depparse/trainer.py; the call site is not visible here).  Position 0 of
        every sentence is the ROOT token that DataLoader.preprocess prepends.

        Returns:
            (loss, preds).  In training mode, loss is the summed cross-entropy
            of the head and relation predictions (plus the optional
            linearization and distance terms) divided by wordchars.size(0),
            and preds is empty.  Otherwise loss is 0 and preds holds two numpy
            arrays: log-softmax head scores and the argmax relation index for
            every (dependent, head) pair.
        """
        def pack(x):
            return pack_padded_sequence(x, sentlens, batch_first=True)

        inputs = []
        if self.args['pretrain']:
            pretrained_emb = self.pretrained_emb(pretrained)
            pretrained_emb = self.trans_pretrained(pretrained_emb)
            pretrained_emb = pack(pretrained_emb)
            inputs += [pretrained_emb]

        #def pad(x):
        #    return pad_packed_sequence(PackedSequence(x, pretrained_emb.batch_sizes), batch_first=True)[0]

        if self.args['word_emb_dim'] > 0:
            word_emb = self.word_emb(word)
            word_emb = pack(word_emb)
            lemma_emb = self.lemma_emb(lemma)
            lemma_emb = pack(lemma_emb)
            inputs += [word_emb, lemma_emb]

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                pos_emb = self.upos_emb(upos)
            else:
                # start from 0 so XPOS embeddings can still be added below
                pos_emb = 0

            if self.args.get('use_xpos', True):
                if isinstance(self.vocab['xpos'], CompositeVocab):
                    for i in range(len(self.vocab['xpos'])):
                        pos_emb += self.xpos_emb[i](xpos[:, :, i])
                else:
                    pos_emb += self.xpos_emb(xpos)

            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                pos_emb = pack(pos_emb)
                inputs += [pos_emb]

            if self.args.get('use_ufeats', True):
                feats_emb = 0
                for i in range(len(self.vocab['feats'])):
                    feats_emb += self.ufeats_emb[i](ufeats[:, :, i])
                feats_emb = pack(feats_emb)

                # NOTE(review): this appends pos_emb, not feats_emb, so the
                # UFeats embedding computed just above is never used and the
                # POS embedding fills both slots.  If use_upos and use_xpos
                # are both False, pos_emb is still the int 0 here and the
                # x.data access below will fail.  Switching to feats_emb
                # changes the inputs every trained checkpoint was fit on, so
                # it needs a compatibility plan before it is fixed.
                inputs += [pos_emb]

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                # \n is to add a somewhat neutral "word" for the ROOT
                charlm_text = [["\n"] + x for x in text]
                all_forward_chars = self.charmodel_forward.build_char_representation(charlm_text)
                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))
                all_backward_chars = self.charmodel_backward.build_char_representation(charlm_text)
                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))
                inputs += [all_forward_chars, all_backward_chars]
            else:
                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
                inputs += [char_reps]

        if self.bert_model is not None:
            device = next(self.parameters()).device
            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=True,
                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                     detach=not self.args.get('bert_finetune', False) or not self.training,
                                                     peft_name=self.peft_name)
            if self.bert_layer_mix is not None:
                # use a linear layer to weighted average the embedding dynamically
                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

            # we are using the first endpoint from the transformer as the "word" for ROOT
            processed_bert = [x[:-1, :] for x in processed_bert]
            processed_bert = pad_sequence(processed_bert, batch_first=True)
            inputs += [pack(processed_bert)]

        # all inputs are packed with the same sentlens, so their .data rows line up
        lstm_inputs = torch.cat([x.data for x in inputs], 1)

        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
        lstm_inputs = self.drop(lstm_inputs)

        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)

        lstm_outputs, _ = self.parserlstm(lstm_inputs, sentlens, hx=(self.parserlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.parserlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
        lstm_outputs, _ = pad_packed_sequence(lstm_outputs, batch_first=True)

        if self.args.get('use_arc_embedding'):
            arc_scores = self.arc_embedding(self.drop(lstm_outputs), self.drop(lstm_outputs))
            unlabeled_scores = self.unlabeled_linear(arc_scores).squeeze(3)
            deprel_scores = self.deprel_linear(arc_scores).squeeze(3)
        else:
            unlabeled_scores = self.unlabeled(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            deprel_scores = self.deprel(self.drop(lstm_outputs), self.drop(lstm_outputs))

        #goldmask = head.new_zeros(*head.size(), head.size(-1)+1, dtype=torch.uint8)
        #goldmask.scatter_(2, head.unsqueeze(2), 1)

        if self.args['linearization'] or self.args['distance']:
            # head_offset[b, i, j] = j - i: signed distance from dependent i to candidate head j
            head_offset = torch.arange(word.size(1), device=head.device).view(1, 1, -1).expand(word.size(0), -1, -1) - torch.arange(word.size(1), device=head.device).view(1, -1, 1).expand(word.size(0), -1, -1)

        if self.args['linearization']:
            lin_scores = self.linearization(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            # detached: this term only adjusts the attachment scores; the
            # linearization scorer is trained by its own loss term below
            unlabeled_scores += F.logsigmoid(lin_scores * torch.sign(head_offset).float()).detach()

        if self.args['distance']:
            dist_scores = self.distance(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            dist_pred = 1 + F.softplus(dist_scores)
            dist_target = torch.abs(head_offset)
            # penalty that grows as the predicted distance departs from |j - i|
            dist_kld = -torch.log((dist_target.float() - dist_pred)**2/2 + 1)
            unlabeled_scores += dist_kld.detach()

        # a word can never be its own head
        diag = torch.eye(head.size(-1)+1, dtype=torch.bool, device=head.device).unsqueeze(0)
        unlabeled_scores.masked_fill_(diag, -float('inf'))

        preds = []

        if self.training:
            unlabeled_scores = unlabeled_scores[:, 1:, :] # exclude attachment for the root symbol
            # padding positions (word_mask is True there) can never be chosen as heads
            unlabeled_scores = unlabeled_scores.masked_fill(word_mask.unsqueeze(1), -float('inf'))
            # -1 is self.crit's ignore_index, so padded dependents add no loss
            unlabeled_target = head.masked_fill(word_mask[:, 1:], -1)
            loss = self.crit(unlabeled_scores.contiguous().view(-1, unlabeled_scores.size(2)), unlabeled_target.view(-1))

            deprel_scores = deprel_scores[:, 1:] # exclude attachment for the root symbol
            #deprel_scores = deprel_scores.masked_select(goldmask.unsqueeze(3)).view(-1, len(self.vocab['deprel']))
            # keep only the relation scores for each word's gold head
            deprel_scores = torch.gather(deprel_scores, 2, head.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, len(self.vocab['deprel']))).view(-1, len(self.vocab['deprel']))
            deprel_target = deprel.masked_fill(word_mask[:, 1:], -1)
            loss += self.crit(deprel_scores.contiguous(), deprel_target.view(-1))

            if self.args['linearization']:
                #lin_scores = lin_scores[:, 1:].masked_select(goldmask)
                lin_scores = torch.gather(lin_scores[:, 1:], 2, head.unsqueeze(2)).view(-1)
                lin_scores = torch.cat([-lin_scores.unsqueeze(1)/2, lin_scores.unsqueeze(1)/2], 1)
                #lin_target = (head_offset[:, 1:] > 0).long().masked_select(goldmask)
                lin_target = torch.gather((head_offset[:, 1:] > 0).long(), 2, head.unsqueeze(2))
                loss += self.crit(lin_scores.contiguous(), lin_target.view(-1))

            if self.args['distance']:
                #dist_kld = dist_kld[:, 1:].masked_select(goldmask)
                # dist_kld[:, 1:] so that the root isn't included in the distance calculation
                dist_kld = torch.gather(dist_kld[:, 1:], 2, head.unsqueeze(2))
                loss -= dist_kld.sum()

            loss /= wordchars.size(0) # number of words
        else:
            loss = 0
            preds.append(F.log_softmax(unlabeled_scores, 2).detach().cpu().numpy())
            preds.append(deprel_scores.max(3)[1].detach().cpu().numpy())

        return loss, preds


================================================
FILE: stanza/models/depparse/scorer.py
================================================
"""
Utils and wrappers for scoring parsers.
""" from collections import Counter import logging from stanza.models.common.utils import ud_scores logger = logging.getLogger('stanza') def score_named_dependencies(pred_doc, gold_doc, output_latex=False): if len(pred_doc.sentences) != len(gold_doc.sentences): logger.warning("Not evaluating individual dependency F1 on accound of document length mismatch") return for sent_idx, (x, y) in enumerate(zip(pred_doc.sentences, gold_doc.sentences)): if len(x.words) != len(y.words): logger.warning("Not evaluating individual dependency F1 on accound of sentence length mismatch") return tp = Counter() fp = Counter() fn = Counter() for pred_sentence, gold_sentence in zip(pred_doc.sentences, gold_doc.sentences): for pred_word, gold_word in zip(pred_sentence.words, gold_sentence.words): if pred_word.head == gold_word.head and pred_word.deprel == gold_word.deprel: tp[gold_word.deprel] = tp[gold_word.deprel] + 1 else: fn[gold_word.deprel] = fn[gold_word.deprel] + 1 fp[pred_word.deprel] = fp[pred_word.deprel] + 1 labels = sorted(set(tp.keys()).union(fp.keys()).union(fn.keys())) max_len = max(len(x) for x in labels) log_lines = [] #log_line_fmt = "%" + str(max_len) + "s: p %.4f r %.4f f1 %.4f (%d actual)" if output_latex: log_lines.append(r"\begin{tabular}{lrr}") log_lines.append(r"Reln & F1 & Total \\") log_line_fmt = "{label} & {f1:0.4f} & {actual} \\\\" else: log_line_fmt = "{label:>" + str(max_len) + "s}: p {precision:0.4f} r {recall:0.4f} f1 {f1:0.4f} ({actual} actual)" for label in labels: if tp[label] == 0: precision = 0 recall = 0 f1 = 0 else: precision = tp[label] / (tp[label] + fp[label]) recall = tp[label] / (tp[label] + fn[label]) f1 = 2 * (precision * recall) / (precision + recall) actual = tp[label] + fn[label] template = { 'label': label, 'precision': precision, 'recall': recall, 'f1': f1, 'actual': actual } log_lines.append(log_line_fmt.format(**template)) if output_latex: log_lines.append(r"\end{tabular}") logger.info("F1 scores for each dependency:\n Note that 
unlabeled attachment errors hurt the labeled attachment scores\n%s" % "\n".join(log_lines)) def score(system_conllu_file, gold_conllu_file, verbose=True): """ Wrapper for UD parser scorer. """ evaluation = ud_scores(gold_conllu_file, system_conllu_file) el = evaluation['LAS'] p = el.precision r = el.recall f = el.f1 if verbose: scores = [evaluation[k].f1 * 100 for k in ['LAS', 'MLAS', 'BLEX']] logger.info("LAS\tMLAS\tBLEX") logger.info("{:.2f}\t{:.2f}\t{:.2f}".format(*scores)) return p, r, f ================================================ FILE: stanza/models/depparse/trainer.py ================================================ """ A trainer class to handle training and testing of models. """ import copy import sys import logging import torch from torch import nn try: import transformers except ImportError: pass from stanza.models.common.trainer import Trainer as BaseTrainer from stanza.models.common import utils, loss from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache from stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper from stanza.models.common.vocab import VOCAB_PREFIX_SIZE from stanza.models.depparse.model import Parser from stanza.models.pos.vocab import MultiVocab logger = logging.getLogger('stanza') def unpack_batch(batch, device): """ Unpack a batch from the data loader. """ inputs = [b.to(device) if b is not None else None for b in batch[:11]] orig_idx = batch[11] word_orig_idx = batch[12] sentlens = batch[13] wordlens = batch[14] text = batch[15] return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text class Trainer(BaseTrainer): """ A trainer for training models. 
""" def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None, ignore_model_config=False, reset_history=False): self.global_step = 0 self.last_best_step = 0 self.dev_score_history = [] orig_args = copy.deepcopy(args) # whether the training is in primary or secondary stage # during FT (loading weights), etc., the training is considered to be in "secondary stage" # during this time, we (optionally) use a different set of optimizers than that during "primary stage". # # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary if model_file is not None: # load everything from file self.load(model_file, pretrain, args, foundation_cache, device) if reset_history: self.global_step = 0 self.last_best_step = 0 self.dev_score_history = [] else: # build model from scratch self.args = args self.vocab = vocab bert_model, bert_tokenizer = load_bert(self.args['bert_model']) peft_name = None if self.args['use_peft']: # fine tune the bert if we're using peft self.args['bert_finetune'] = True peft_name = "depparse" bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name) self.model = Parser(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name) self.model = self.model.to(device) self.__init_optim() self.fallback = self.vocab['deprel'].unit2id('dep') if 'dep' in self.vocab['deprel'] else None if ignore_model_config: self.args = orig_args if self.args.get('wandb'): import wandb # track gradients! 
wandb.watch(self.model, log_freq=4, log="all", log_graph=True) def __init_optim(self): # TODO: can get rid of args.get when models are rebuilt if (self.args.get("second_stage", False) and self.args.get('second_optim')): self.optimizer = utils.get_split_optimizer(self.args['second_optim'], self.model, self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6, bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), is_peft=self.args.get('use_peft', False), bert_finetune_layers=self.args.get('bert_finetune_layers', None)) else: self.optimizer = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0), weight_decay=self.args.get('weight_decay', None), bert_weight_decay=self.args.get('bert_weight_decay', 0.0), is_peft=self.args.get('use_peft', False), bert_finetune_layers=self.args.get('bert_finetune_layers', None)) self.scheduler = {} if self.args.get("second_stage", False) and self.args.get('second_optim'): if self.args.get('second_warmup_steps', None): for name, optimizer in self.optimizer.items(): name = name + "_scheduler" warmup_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, self.args['second_warmup_steps']) self.scheduler[name] = warmup_scheduler else: if "bert_optimizer" in self.optimizer: zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer["bert_optimizer"], factor=0, total_iters=self.args['bert_start_finetuning']) warmup_scheduler = transformers.get_constant_schedule_with_warmup( self.optimizer["bert_optimizer"], self.args['bert_warmup_steps']) self.scheduler["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR( self.optimizer["bert_optimizer"], schedulers=[zero_scheduler, warmup_scheduler], milestones=[self.args['bert_start_finetuning']]) def update(self, batch, eval=False): device = next(self.model.parameters()).device inputs, orig_idx, word_orig_idx, sentlens, wordlens, text 
= unpack_batch(batch, device) word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs if eval: self.model.eval() else: self.model.train() for opt in self.optimizer.values(): opt.zero_grad() loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text) loss_val = loss.data.item() if eval: return loss_val loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) for opt in self.optimizer.values(): opt.step() for scheduler in self.scheduler.values(): scheduler.step() return loss_val def predict(self, batch, unsort=True): device = next(self.model.parameters()).device inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device) word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs self.model.eval() batch_size = word.size(0) _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text) # TODO: would be cleaner for the model to not have the capability to produce predictions < VOCAB_PREFIX_SIZE if self.fallback is not None: preds[1][preds[1] < VOCAB_PREFIX_SIZE] = self.fallback head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)] pred_tokens = [[[head_seqs[i][j], deprel_seqs[i][j]] for j in range(sentlens[i]-1)] for i in range(batch_size)] if unsort: pred_tokens = utils.unsort(pred_tokens, orig_idx) return pred_tokens def save(self, filename, skip_modules=True, save_optimizer=False): model_state = self.model.state_dict() # skip saving modules like pretrained embeddings, because they are large and will be saved in a 
separate file if skip_modules: skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] for k in skipped: del model_state[k] params = { 'model': model_state, 'vocab': self.vocab.state_dict(), 'config': self.args, 'global_step': self.global_step, 'last_best_step': self.last_best_step, 'dev_score_history': self.dev_score_history, } if self.args.get('use_peft', False): # Hide import so that peft dependency is optional from peft import get_peft_model_state_dict params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name) if save_optimizer and self.optimizer is not None: params['optimizer_state_dict'] = {k: opt.state_dict() for k, opt in self.optimizer.items()} params['scheduler_state_dict'] = {k: scheduler.state_dict() for k, scheduler in self.scheduler.items()} try: torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) except BaseException as e: logger.warning("Saving failed... continuing anyway. Error was: %s" % e) def load(self, filename, pretrain, args=None, foundation_cache=None, device=None): """ Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input, and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args. 
""" try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] if args is not None: self.args.update(args) # preserve old models which were created before transformers were added if 'bert_model' not in self.args: self.args['bert_model'] = None lora_weights = checkpoint.get('bert_lora') if lora_weights: logger.debug("Found peft weights for depparse; loading a peft adapter") self.args["use_peft"] = True self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) # load model emb_matrix = None if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None emb_matrix = pretrain.emb # TODO: refactor this common block of code with NER force_bert_saved = False peft_name = None if self.args.get('use_peft', False): force_bert_saved = True bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "depparse", foundation_cache) bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name) logger.debug("Loaded peft with name %s", peft_name) else: if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()): logger.debug("Model %s has a finetuned transformer. 
Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename) foundation_cache = NoTransformerFoundationCache(foundation_cache) force_bert_saved = True bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache) self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name) self.model.load_state_dict(checkpoint['model'], strict=False) if device is not None: self.model = self.model.to(device) self.__init_optim() optim_state_dict = checkpoint.get("optimizer_state_dict") if optim_state_dict: for k, state in optim_state_dict.items(): self.optimizer[k].load_state_dict(state) scheduler_state_dict = checkpoint.get("scheduler_state_dict") if scheduler_state_dict: for k, state in scheduler_state_dict.items(): self.scheduler[k].load_state_dict(state) self.global_step = checkpoint.get("global_step", 0) self.last_best_step = checkpoint.get("last_best_step", 0) self.dev_score_history = checkpoint.get("dev_score_history", list()) ================================================ FILE: stanza/models/identity_lemmatizer.py ================================================ """ An identity lemmatizer that mimics the behavior of a normal lemmatizer but directly uses word as lemma. 
""" import os import argparse import logging import random from stanza.models.lemma.data import DataLoader from stanza.models.lemma import scorer from stanza.models.common import utils from stanza.models.common.doc import * from stanza.utils.conll import CoNLL from stanza.models import _training_logging logger = logging.getLogger('stanza') def parse_args(args=None): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.') parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--shorthand', type=str, help='Shorthand') parser.add_argument('--batch_size', type=int, default=50) parser.add_argument('--seed', type=int, default=1234) args = parser.parse_args(args=args) return args def main(args=None): args = parse_args(args=args) random.seed(args.seed) args = vars(args) logger.info("[Launching identity lemmatizer...]") if args['mode'] == 'train': logger.info("[No training is required; will only generate evaluation output...]") document = CoNLL.conll2doc(input_file=args['eval_file']) batch = DataLoader(document, args['batch_size'], args, evaluation=True, conll_only=True) system_pred_file = args['output_file'] gold_file = args['gold_file'] # use identity mapping for prediction preds = batch.doc.get([TEXT]) # write to file and score batch.doc.set([LEMMA], preds) if system_pred_file is not None: CoNLL.write_doc2conll(batch.doc, system_pred_file) if gold_file is not None: system_pred_file = "{:C}\n\n".format(batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, score = 
scorer.score(system_pred_file, gold_file) logger.info("Lemma score:") logger.info("{} {:.2f}".format(args['shorthand'], score*100)) return None, batch.doc if __name__ == '__main__': main() ================================================ FILE: stanza/models/lang_identifier.py ================================================ """ Entry point for training and evaluating a Bi-LSTM language identifier """ import argparse import json import logging import os import random import torch from datetime import datetime from stanza.models.common import utils from stanza.models.langid.data import DataLoader from stanza.models.langid.trainer import Trainer from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza') def parse_args(args=None): parser = argparse.ArgumentParser() parser.add_argument("--batch_mode", help="custom settings when running in batch mode", action="store_true") parser.add_argument("--batch_size", help="batch size for training", type=int, default=64) parser.add_argument("--eval_length", help="length of strings to eval on", type=int, default=None) parser.add_argument("--eval_set", help="eval on dev or test", default="test") parser.add_argument("--data_dir", help="directory with train/dev/test data", default=None) parser.add_argument("--load_name", help="path to load model from", default=None) parser.add_argument("--mode", help="train or eval", default="train") parser.add_argument("--num_epochs", help="number of epochs for training", type=int, default=50) parser.add_argument("--randomize", help="take random substrings of samples", action="store_true") parser.add_argument("--randomize_lengths_range", help="range of lengths to use when random sampling text", type=randomize_lengths_range, default="5,20") parser.add_argument("--merge_labels_for_eval", help="merge some language labels for eval (e.g. 
\"zh-hans\" and \"zh-hant\" to \"zh\")", action="store_true") parser.add_argument("--save_best_epochs", help="save model for every epoch with new best score", action="store_true") parser.add_argument("--save_name", help="where to save model", default=None) utils.add_device_args(parser) args = parser.parse_args(args=args) return args def randomize_lengths_range(range_list): """ Range of lengths for random samples """ range_boundaries = [int(x) for x in range_list.split(",")] assert range_boundaries[0] < range_boundaries[1], f"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})" return range_boundaries def main(args=None): args = parse_args(args=args) torch.manual_seed(0) if args.mode == "train": train_model(args) else: eval_model(args) def build_indexes(args): tag_to_idx = {} char_to_idx = {} train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x] for train_file in train_files: with open(train_file) as curr_file: lines = curr_file.read().strip().split("\n") examples = [json.loads(line) for line in lines if line.strip()] for example in examples: label = example["label"] if label not in tag_to_idx: tag_to_idx[label] = len(tag_to_idx) sequence = example["text"] for char in list(sequence): if char not in char_to_idx: char_to_idx[char] = len(char_to_idx) char_to_idx["UNK"] = len(char_to_idx) char_to_idx[""] = len(char_to_idx) return tag_to_idx, char_to_idx def train_model(args): # set up indexes tag_to_idx, char_to_idx = build_indexes(args) # load training data train_data = DataLoader(args.device) train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x] train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize) # load dev data dev_data = DataLoader(args.device) dev_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "dev" in x] dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False, max_length=args.eval_length) # 
set up trainer trainer_config = { "model_path": args.save_name, "char_to_idx": char_to_idx, "tag_to_idx": tag_to_idx, "batch_size": args.batch_size, "lang_weights": train_data.lang_weights } if args.load_name: trainer_config["load_name"] = args.load_name logger.info(f"{datetime.now()}\tLoading model from: {args.load_name}") trainer = Trainer(trainer_config, load_model=args.load_name is not None, device=args.device) # run training best_accuracy = 0.0 for epoch in range(1, args.num_epochs+1): logger.info(f"{datetime.now()}\tEpoch {epoch}") logger.info(f"{datetime.now()}\tNum training batches: {len(train_data.batches)}") batches = train_data.batches if not args.batch_mode: batches = tqdm(batches) for train_batch in batches: inputs = (train_batch["sentences"], train_batch["targets"]) trainer.update(inputs) logger.info(f"{datetime.now()}\tEpoch complete. Evaluating on dev data.") curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \ eval_trainer(trainer, dev_data, batch_mode=args.batch_mode) logger.info(f"{datetime.now()}\tCurrent dev accuracy: {curr_dev_accuracy}") if curr_dev_accuracy > best_accuracy: logger.info(f"{datetime.now()}\tNew best score. Saving model.") model_label = f"epoch{epoch}" if args.save_best_epochs else None trainer.save(label=model_label) with open(score_log_path(args.save_name), "w") as score_log_file: for score_log in [{"dev_accuracy": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s]: score_log_file.write(json.dumps(score_log) + "\n") best_accuracy = curr_dev_accuracy # reload training data logger.info(f"{datetime.now()}\tResampling training data.") train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize) def score_log_path(file_path): """ Helper that will determine corresponding log file (e.g. 
/path/to/demo.pt to /path/to/demo.json """ model_suffix = os.path.splitext(file_path) if model_suffix[1]: score_log_path = f"{file_path[:-len(model_suffix[1])]}.json" else: score_log_path = f"{file_path}.json" return score_log_path def eval_model(args): # set up trainer trainer_config = { "model_path": None, "load_name": args.load_name, "batch_size": args.batch_size } trainer = Trainer(trainer_config, load_model=True, device=args.device) # load test data test_data = DataLoader(args.device) test_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if args.eval_set in x] test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx, randomize=False, max_length=args.eval_length) curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \ eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval) logger.info(f"{datetime.now()}\t{args.eval_set} accuracy: {curr_accuracy}") eval_save_path = args.save_name if args.save_name else score_log_path(args.load_name) if not os.path.exists(eval_save_path) or args.save_name: with open(eval_save_path, "w") as score_log_file: for score_log in [{"dev_accuracy": curr_accuracy}, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s]: score_log_file.write(json.dumps(score_log) + "\n") def eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True): """ Produce dev accuracy and confusion matrix for a trainer """ # set up confusion matrix tag_to_idx = dev_data.tag_to_idx idx_to_tag = dev_data.idx_to_tag confusion_matrix = {} for row_label in tag_to_idx: confusion_matrix[row_label] = {} for col_label in tag_to_idx: confusion_matrix[row_label][col_label] = 0 # process dev batches batches = dev_data.batches if not batch_mode: batches = tqdm(batches) for dev_batch in batches: inputs = (dev_batch["sentences"], dev_batch["targets"]) predictions = trainer.predict(inputs) for target_idx, prediction in 
zip(dev_batch["targets"], predictions): prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split("-")[0] confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1 # calculate dev accuracy total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix]) total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix]) dev_accuracy = float(total_correct) / float(total_examples) # calculate precision, recall, F1 precision_scores = {"type": "precision"} recall_scores = {"type": "recall"} f1_scores = {"type": "f1"} for prediction_label in tag_to_idx: total = sum([confusion_matrix[k][prediction_label] for k in tag_to_idx]) if total != 0.0: precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total) else: precision_scores[prediction_label] = 0.0 for target_label in tag_to_idx: total = sum([confusion_matrix[target_label][k] for k in tag_to_idx]) if total != 0: recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total) else: recall_scores[target_label] = 0.0 for label in tag_to_idx: if precision_scores[label] == 0.0 and recall_scores[label] == 0.0: f1_scores[label] = 0.0 else: f1_scores[label] = \ 2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label]) return dev_accuracy, confusion_matrix, precision_scores, recall_scores, f1_scores if __name__ == "__main__": main() ================================================ FILE: stanza/models/langid/__init__.py ================================================ ================================================ FILE: stanza/models/langid/create_ud_data.py ================================================ """ Script for producing training/dev/test data from UD data or sentences Example output data format (one example per line): {"text": "Hello world.", "label": "en"} This is an attempt to recreate data 
pre-processing in https://github.com/AU-DIS/LSTM_langid Specifically borrows methods from https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py Data format is same as LSTM_langid as well. """ import argparse import json import logging import os import re import sys from pathlib import Path from random import randint, random, shuffle from string import digits from tqdm import tqdm from stanza.models.common.constant import treebank_to_langid logger = logging.getLogger('stanza') DEFAULT_LANGUAGES = "af,ar,be,bg,bxr,ca,cop,cs,cu,da,de,el,en,es,et,eu,fa,fi,fr,fro,ga,gd,gl,got,grc,he,hi,hr,hsb,hu,hy,id,it,ja,kk,kmr,ko,la,lt,lv,lzh,mr,mt,nl,nn,no,olo,orv,pl,pt,ro,ru,sk,sl,sme,sr,sv,swl,ta,te,tr,ug,uk,ur,vi,wo,zh-hans,zh-hant".split(",") def parse_args(args=None): parser = argparse.ArgumentParser() parser.add_argument("--data-format", help="input data format", choices=["ud", "one-per-line"], default="ud") parser.add_argument("--eval-length", help="length of eval strings", type=int, default=10) parser.add_argument("--languages", help="list of languages to use, or \"all\"", default=DEFAULT_LANGUAGES) parser.add_argument("--min-window", help="minimal training example length", type=int, default=10) parser.add_argument("--max-window", help="maximum training example length", type=int, default=50) parser.add_argument("--ud-path", help="path to ud data") parser.add_argument("--save-path", help="path to save data", default=".") parser.add_argument("--splits", help="size of train/dev/test splits in percentages", type=splits_from_list, default="0.8,0.1,0.1") args = parser.parse_args(args=args) return args def splits_from_list(value_list): return [float(x) for x in value_list.split(",")] def main(args=None): args = parse_args(args=args) if isinstance(args.languages, str): args.languages = args.languages.split(",") data_paths = [f"{args.save_path}/{data_split}.jsonl" for data_split in ["train", "dev", "test"]] lang_to_files = collect_files(args.ud_path, args.languages, 
data_format=args.data_format) logger.info(f"Building UD data for languages: {','.join(args.languages)}") for lang_id in tqdm(lang_to_files): lang_examples = generate_examples(lang_id, lang_to_files[lang_id], splits=args.splits, min_window=args.min_window, max_window=args.max_window, eval_length=args.eval_length, data_format=args.data_format) for (data_set, save_path) in zip(lang_examples, data_paths): with open(save_path, "a") as json_file: for json_entry in data_set: json.dump(json_entry, json_file, ensure_ascii=False) json_file.write("\n") def collect_files(ud_path, languages, data_format="ud"): """ Given path to UD, collect files If data_format = "ud", expects files to be of form *.conllu If data_format = "one-per-line", expects files to be of form "*.sentences.txt" In all cases, the UD path should be a directory with subdirectories for each language """ data_format_to_search_path = {"ud": "*/*.conllu", "one-per-line": "*/*sentences.txt"} ud_files = Path(ud_path).glob(data_format_to_search_path[data_format]) lang_to_files = {} for ud_file in ud_files: if data_format == "ud": lang_id = treebank_to_langid(ud_file.parent.name) else: lang_id = ud_file.name.split("_")[0] if lang_id not in languages and "all" not in languages: continue if not lang_id in lang_to_files: lang_to_files[lang_id] = [] lang_to_files[lang_id].append(ud_file) return lang_to_files def generate_examples(lang_id, list_of_files, splits=(0.8,0.1,0.1), min_window=10, max_window=50, eval_length=10, data_format="ud"): """ Generate train/dev/test examples for a given language """ examples = [] for ud_file in list_of_files: sentences = sentences_from_file(ud_file, data_format=data_format) for sentence in sentences: sentence = clean_sentence(sentence) if validate_sentence(sentence, min_window): examples += sentence_to_windows(sentence, min_window=min_window, max_window=max_window) shuffle(examples) train_idx = int(splits[0] * len(examples)) train_set = [example_json(lang_id, example) for example in 
examples[:train_idx]] dev_idx = int(splits[1] * len(examples)) + train_idx dev_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[train_idx:dev_idx]] test_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[dev_idx:]] return train_set, dev_set, test_set def sentences_from_file(ud_file_path, data_format="ud"): """ Retrieve all sentences from a UD file """ if data_format == "ud": with open(ud_file_path) as ud_file: ud_file_contents = ud_file.read().strip() assert "# text = " in ud_file_contents, \ f"{ud_file_path} does not have expected format, \"# text =\" does not appear" sentences = [x[9:] for x in ud_file_contents.split("\n") if x.startswith("# text = ")] elif data_format == "one-per-line": with open(ud_file_path) as ud_file: sentences = [x for x in ud_file.read().strip().split("\n") if x] return sentences def sentence_to_windows(sentence, min_window, max_window): """ Create window size chunks from a sentence, always starting with a word """ windows = [] words = sentence.split(" ") curr_window = "" for idx, word in enumerate(words): curr_window += (" " + word) curr_window = curr_window.lstrip() next_word_len = len(words[idx+1]) + 1 if idx+1 < len(words) else 0 if len(curr_window) + next_word_len > max_window: curr_window = clean_sentence(curr_window) if validate_sentence(curr_window, min_window): windows.append(curr_window.strip()) curr_window = "" if len(curr_window) >= min_window: windows.append(curr_window) return windows def validate_sentence(current_window, min_window): """ Sentence validation from: LSTM-LID GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py """ if len(current_window) < min_window: return False return True def find(s, ch): """ Helper for clean_sentence from LSTM-LID GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py """ return [i for i, ltr in enumerate(s) if ltr == ch] def clean_sentence(line): """ Sentence cleaning 
from LSTM-LID GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py """ # We remove some special characters and fix small errors in the data, to improve the quality of the data line = line.replace("\n", '') #{"text": "- Mor.\n", "label": "da"} line = line.replace("- ", '') #{"text": "- Mor.", "label": "da"} line = line.replace("_", '') #{"text": "- Mor.", "label": "da"} line = line.replace("\\", '') line = line.replace("\"", '') line = line.replace(" ", " ") remove_digits = str.maketrans('', '', digits) line = line.translate(remove_digits) words = line.split() new_words = [] # Below fixes large I instead of l. Does not catch everything, but should also not really make any mistakes either for word in words: clean_word = word s = clean_word if clean_word[1:].__contains__("I"): indices = find(clean_word, "I") for indx in indices: if clean_word[indx-1].islower(): if len(clean_word) > indx + 1: if clean_word[indx+1].islower(): s = s[:indx] + "l" + s[indx + 1:] else: s = s[:indx] + "l" + s[indx + 1:] new_words.append(s) new_line = " ".join(new_words) return new_line def example_json(lang_id, text, eval_length=None): if eval_length is not None: text = text[:eval_length] return {"text": text.strip(), "label": lang_id} if __name__ == "__main__": main() ================================================ FILE: stanza/models/langid/data.py ================================================ import json import random import torch class DataLoader: """ Class for loading language id data and providing batches Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py Data format is same as LSTM_langid """ def __init__(self, device=None): self.batches = None self.batches_iter = None self.tag_to_idx = None self.idx_to_tag = None self.lang_weights = None self.device = device def load_data(self, batch_size, data_files, char_index, tag_index, 
randomize=False, randomize_range=(5,20), max_length=None): """ Load sequence data and labels, calculate weights for weighted cross entropy loss. Data is stored in a file, 1 example per line Example: {"text": "Hello world.", "label": "en"} """ # set up examples from data files examples = [] for data_file in data_files: examples += [x for x in open(data_file).read().split("\n") if x.strip()] random.shuffle(examples) examples = [json.loads(x) for x in examples] # add additional labels in this data set to tag index tag_index = dict(tag_index) new_labels = set([x["label"] for x in examples]) - set(tag_index.keys()) for new_label in new_labels: tag_index[new_label] = len(tag_index) self.tag_to_idx = tag_index self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])] # set up lang counts used for weights for cross entropy loss lang_counts = [0 for _ in tag_index] # optionally limit text to max length if max_length is not None: examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples] # randomize data if randomize: split_examples = [] for example in examples: sequence = example["text"] label = example["label"] sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1], lower_lim=randomize_range[0]) split_examples += [{"text": seq, "label": label} for seq in sequences] examples = split_examples random.shuffle(examples) # break into equal length batches batch_lengths = {} for example in examples: sequence = example["text"] label = example["label"] if len(sequence) not in batch_lengths: batch_lengths[len(sequence)] = [] sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)] batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label])) lang_counts[tag_index[label]] += 1 for length in batch_lengths: random.shuffle(batch_lengths[length]) # create final set of batches batches = [] for length in batch_lengths: for sublist in [batch_lengths[length][i:i + batch_size] for 
i in range(0, len(batch_lengths[length]), batch_size)]: batches.append(sublist) self.batches = [self.build_batch_tensors(batch) for batch in batches] # set up lang weights most_frequent = max(lang_counts) # set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts] self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float) # shuffle batches to mix up lengths random.shuffle(self.batches) self.batches_iter = iter(self.batches) @staticmethod def randomize_data(sentences, upper_lim=20, lower_lim=5): """ Takes the original data and creates random length examples with length between upper limit and lower limit From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py """ new_data = [] for sentence in sentences: remaining = sentence while lower_lim < len(remaining): lim = random.randint(lower_lim, upper_lim) m = min(len(remaining), lim) new_sentence = remaining[:m] new_data.append(new_sentence) split = remaining[m:].split(" ", 1) if len(split) <= 1: break remaining = split[1] random.shuffle(new_data) return new_data def build_batch_tensors(self, batch): """ Helper to turn batches into tensors """ batch_tensors = dict() batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long) batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long) return batch_tensors def next(self): return next(self.batches_iter) ================================================ FILE: stanza/models/langid/model.py ================================================ import os import torch import torch.nn as nn class LangIDBiLSTM(nn.Module): """ Multi-layer BiLSTM model for language detecting. A recreation of "A reproduction of Apple's bi-directional LSTM models for language identification in short strings." 
(Toftrup et al 2021) Arxiv: https://arxiv.org/abs/2102.06282 GitHub: https://github.com/AU-DIS/LSTM_langid This class is similar to https://github.com/AU-DIS/LSTM_langid/blob/main/src/LSTMLID.py """ def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None, dropout=0.0, lang_subset=None): super(LangIDBiLSTM, self).__init__() self.num_layers = num_layers self.embedding_dim = embedding_dim self.hidden_dim = hidden_dim self.char_to_idx = char_to_idx self.vocab_size = len(char_to_idx) self.tag_to_idx = tag_to_idx self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])] self.lang_subset = lang_subset self.padding_idx = char_to_idx[""] self.tagset_size = len(tag_to_idx) self.batch_size = batch_size self.loss_train = nn.CrossEntropyLoss(weight=weights) self.dropout_prob = dropout # embeddings for chars self.char_embeds = nn.Embedding( num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim, padding_idx=self.padding_idx ) # the bidirectional LSTM self.lstm = nn.LSTM( self.embedding_dim, self.hidden_dim, num_layers=self.num_layers, bidirectional=True, batch_first=True ) # convert output to tag space self.hidden_to_tag = nn.Linear( self.hidden_dim * 2, self.tagset_size ) # dropout layer self.dropout = nn.Dropout(p=self.dropout_prob) def build_lang_mask(self, device): """ Build language mask if a lang subset is specified (e.g. 
["en", "fr"]) The mask will be added to the results to set the prediction scores of illegal languages to -inf """ if self.lang_subset: lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag] self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float) else: self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float) def loss(self, Y_hat, Y): return self.loss_train(Y_hat, Y) def forward(self, x): # embed input x = self.char_embeds(x) # run through LSTM x, _ = self.lstm(x) # run through linear layer x = self.hidden_to_tag(x) # sum character outputs for each sequence x = torch.sum(x, dim=1) return x def prediction_scores(self, x): prediction_probs = self(x) if self.lang_subset: prediction_batch_size = prediction_probs.size()[0] batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)]) prediction_probs = prediction_probs + batch_mask return torch.argmax(prediction_probs, dim=1) def save(self, path): """ Save a model at path """ checkpoint = { "char_to_idx": self.char_to_idx, "tag_to_idx": self.tag_to_idx, "num_layers": self.num_layers, "embedding_dim": self.embedding_dim, "hidden_dim": self.hidden_dim, "model_state_dict": self.state_dict() } torch.save(checkpoint, path) @classmethod def load(cls, path, device=None, batch_size=64, lang_subset=None): """ Load a serialized model located at path """ if path is None: raise FileNotFoundError("Trying to load langid model, but path not specified! 
Try --load_name") if not os.path.exists(path): raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path) checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True) weights = checkpoint["model_state_dict"]["loss_train.weight"] model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"], checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights, lang_subset=lang_subset) model.load_state_dict(checkpoint["model_state_dict"]) model = model.to(device) model.build_lang_mask(device) return model ================================================ FILE: stanza/models/langid/trainer.py ================================================ import torch import torch.optim as optim from stanza.models.langid.model import LangIDBiLSTM class Trainer: DEFAULT_BATCH_SIZE = 64 DEFAULT_LAYERS = 2 DEFAULT_EMBEDDING_DIM = 150 DEFAULT_HIDDEN_DIM = 150 def __init__(self, config, load_model=False, device=None): self.model_path = config["model_path"] self.batch_size = config.get("batch_size", Trainer.DEFAULT_BATCH_SIZE) if load_model: self.load(config["load_name"], device) else: self.model = LangIDBiLSTM(config["char_to_idx"], config["tag_to_idx"], Trainer.DEFAULT_LAYERS, Trainer.DEFAULT_EMBEDDING_DIM, Trainer.DEFAULT_HIDDEN_DIM, batch_size=self.batch_size, weights=config["lang_weights"]).to(device) self.optimizer = optim.AdamW(self.model.parameters()) def update(self, inputs): self.model.train() sentences, targets = inputs self.optimizer.zero_grad() y_hat = self.model.forward(sentences) loss = self.model.loss(y_hat, targets) loss.backward() self.optimizer.step() def predict(self, inputs): self.model.eval() sentences, targets = inputs return torch.argmax(self.model(sentences), dim=1) def save(self, label=None): # save a copy of model with label if label: self.model.save(f"{self.model_path[:-3]}-{label}.pt") self.model.save(self.model_path) def load(self, model_path=None, 
device=None): if not model_path: model_path = self.model_path self.model = LangIDBiLSTM.load(model_path, device, self.batch_size) ================================================ FILE: stanza/models/lemma/__init__.py ================================================ ================================================ FILE: stanza/models/lemma/attach_lemma_classifier.py ================================================ import argparse from stanza.models.lemma.trainer import Trainer from stanza.models.lemma_classifier.base_model import LemmaClassifier def attach_classifier(input_filename, output_filename, classifiers): trainer = Trainer(model_file=input_filename) for classifier in classifiers: classifier = LemmaClassifier.load(classifier) trainer.contextual_lemmatizers.append(classifier) trainer.save(output_filename) def main(args=None): parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from') parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer') parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach') args = parser.parse_args(args) attach_classifier(args.input, args.output, args.classifier) if __name__ == '__main__': main() ================================================ FILE: stanza/models/lemma/data.py ================================================ import random import numpy as np import os from collections import Counter import logging import torch import stanza.models.common.seq2seq_constant as constant from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all from stanza.models.common.vocab import DeltaVocab from stanza.models.lemma.vocab import Vocab, MultiVocab from stanza.models.lemma import edit from stanza.models.common.doc import * logger = logging.getLogger('stanza') class DataLoader: def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, 
conll_only=False, skip=None, expand_unk_vocab=False): self.batch_size = batch_size self.args = args self.eval = evaluation self.shuffled = not self.eval self.doc = doc data = self.raw_data() if conll_only: # only load conll file return if skip is not None: assert len(data) == len(skip) data = [x for x, y in zip(data, skip) if not y] # handle vocab if vocab is not None: if expand_unk_vocab: pos_vocab = vocab['pos'] char_vocab = DeltaVocab(data, vocab['char']) self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab}) else: self.vocab = vocab else: self.vocab = dict() char_vocab, pos_vocab = self.init_vocab(data) self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab}) # filter and sample data if args.get('sample_train', 1.0) < 1.0 and not self.eval: keep = int(args['sample_train'] * len(data)) data = random.sample(data, keep) logger.debug("Subsample training set with rate {:g}".format(args['sample_train'])) data = self.preprocess(data, self.vocab['char'], self.vocab['pos'], args) # shuffle for training if self.shuffled: indices = list(range(len(data))) random.shuffle(indices) data = [data[i] for i in indices] self.num_examples = len(data) # chunk into batches data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)] self.data = data logger.debug("{} batches created.".format(len(data))) def init_vocab(self, data): assert self.eval is False, "Vocab file must exist for evaluation" char_data = "".join(d[0] + d[2] for d in data) char_vocab = Vocab(char_data, self.args['lang']) pos_data = [d[1] for d in data] pos_vocab = Vocab(pos_data, self.args['lang']) return char_vocab, pos_vocab def preprocess(self, data, char_vocab, pos_vocab, args): processed = [] for d in data: edit_type = edit.EDIT_TO_ID[edit.get_edit_type(d[0], d[2])] src = list(d[0]) src = [constant.SOS] + src + [constant.EOS] src = char_vocab.map(src) pos = d[1] pos = pos_vocab.unit2id(pos) tgt = list(d[2]) tgt_in = char_vocab.map([constant.SOS] + tgt) tgt_out = char_vocab.map(tgt + 
[constant.EOS]) processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]] return processed def __len__(self): return len(self.data) def __getitem__(self, key): """ Get a batch with index. """ if not isinstance(key, int): raise TypeError if key < 0 or key >= len(self.data): raise IndexError batch = self.data[key] batch_size = len(batch) batch = list(zip(*batch)) assert len(batch) == 6 # sort all fields by lens for easy RNN operations lens = [len(x) for x in batch[0]] batch, orig_idx = sort_all(batch, lens) # convert to tensors src = batch[0] src = get_long_tensor(src, batch_size) src_mask = torch.eq(src, constant.PAD_ID) tgt_in = get_long_tensor(batch[1], batch_size) tgt_out = get_long_tensor(batch[2], batch_size) pos = torch.LongTensor(batch[3]) edits = torch.LongTensor(batch[4]) text = batch[5] assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match." return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) def raw_data(self): return self.load_doc(self.doc, self.args.get('caseless', False), self.args.get('skip_blank_lemmas', False), self.eval) @staticmethod def load_doc(doc, caseless, skip_blank_lemmas, evaluation): if evaluation: data = doc.get([TEXT, UPOS, LEMMA]) else: data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True) data = DataLoader.remove_goeswith(data) data = DataLoader.extract_correct_forms(data) data = DataLoader.resolve_none(data) if not evaluation and skip_blank_lemmas: data = DataLoader.skip_blank_lemmas(data) if caseless: data = DataLoader.lowercase_data(data) return data @staticmethod def extract_correct_forms(data): """ Here we go through the raw data and use the CorrectForm of words tagged with CorrectForm In addition, if the incorrect form of the word is not present in the training data, we keep the incorrect form for the lemmatizer to learn from. 
This way, it can occasionally get things right in misspelled input text. We do check for and eliminate words where the incorrect form is already known as the lemma for a different word. For example, in the English datasets, there is a "busy" which was meant to be "buys", and we don't want the model to learn to lemmatize "busy" to "buy" """ new_data = [] incorrect_forms = [] for word in data: misc = word[-1] if not misc: new_data.append(word[:3]) continue misc = misc.split("|") for piece in misc: if piece.startswith("CorrectForm="): cf = piece.split("=", maxsplit=1)[1] # treat the CorrectForm as the desired word new_data.append((cf, word[1], word[2])) # and save the broken one for later in case it wasn't used anywhere else incorrect_forms.append((cf, word)) break else: # if no CorrectForm, just keep the word as normal new_data.append(word[:3]) known_words = {x[0] for x in new_data} for correct_form, word in incorrect_forms: if word[0] not in known_words: new_data.append(word[:3]) return new_data @staticmethod def remove_goeswith(data): """ This method specifically removes words that goeswith something else, along with the something else The purpose is to eliminate text such as 1 Ken kenrice@enroncommunications X GW Typo=Yes 0 root 0:root _ 2 Rice@ENRON _ X GW _ 1 goeswith 1:goeswith _ 3 COMMUNICATIONS _ X ADD _ 1 goeswith 1:goeswith _ """ filtered_data = [] remove_indices = set() for sentence in data: remove_indices.clear() for word_idx, word in enumerate(sentence): if word[4] == 'goeswith': remove_indices.add(word_idx) remove_indices.add(word[3]-1) filtered_data.extend([x for idx, x in enumerate(sentence) if idx not in remove_indices]) return filtered_data @staticmethod def lowercase_data(data): for token in data: token[0] = token[0].lower() return data @staticmethod def skip_blank_lemmas(data): data = [x for x in data if x[2] != '_'] return data @staticmethod def resolve_none(data): # replace None to '_' for tok_idx in range(len(data)): for feat_idx in 
range(len(data[tok_idx])): if data[tok_idx][feat_idx] is None: data[tok_idx][feat_idx] = '_' return data ================================================ FILE: stanza/models/lemma/edit.py ================================================ """ Utilities for calculating edits between word and lemma forms. """ EDIT_TO_ID = {'none': 0, 'identity': 1, 'lower': 2} def get_edit_type(word, lemma): """ Calculate edit types. """ if lemma == word: return 'identity' elif lemma == word.lower(): return 'lower' return 'none' def edit_word(word, pred, edit_id): """ Edit a word, given edit and seq2seq predictions. """ if edit_id == 1: return word elif edit_id == 2: return word.lower() elif edit_id == 0: return pred else: raise Exception("Unrecognized edit ID: {}".format(edit_id)) ================================================ FILE: stanza/models/lemma/scorer.py ================================================ """ Utils and wrappers for scoring lemmatizers. """ import logging from stanza.models.common.utils import ud_scores logger = logging.getLogger('stanza') def score(system_conllu_file, gold_conllu_file): """ Wrapper for lemma scorer. """ logger.debug("Evaluating system file %s vs gold file %s", system_conllu_file, gold_conllu_file) evaluation = ud_scores(gold_conllu_file, system_conllu_file) el = evaluation["Lemmas"] p, r, f = el.precision, el.recall, el.f1 return p, r, f ================================================ FILE: stanza/models/lemma/trainer.py ================================================ """ A trainer class to handle training and testing of models. 
""" import os import sys import numpy as np from collections import Counter import logging import torch from torch import nn import torch.nn.init as init import stanza.models.common.seq2seq_constant as constant from stanza.models.common.doc import TEXT, UPOS from stanza.models.common.foundation_cache import load_charlm from stanza.models.common.seq2seq_model import Seq2SeqModel from stanza.models.common.char_model import CharacterLanguageModelWordAdapter from stanza.models.common import utils, loss from stanza.models.lemma import edit from stanza.models.lemma.vocab import MultiVocab from stanza.models.lemma_classifier.base_model import LemmaClassifier logger = logging.getLogger('stanza') def unpack_batch(batch, device): """ Unpack a batch from the data loader. """ inputs = [b.to(device) if b is not None else None for b in batch[:6]] orig_idx = batch[6] text = batch[7] return inputs, orig_idx, text class Trainer(object): """ A trainer for training models. """ def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None, lemma_classifier_args=None): if model_file is not None: # load everything from file self.load(model_file, args, foundation_cache, lemma_classifier_args) else: # build model from scratch self.args = args if args['dict_only']: self.model = None else: self.model = self.build_seq2seq(args, emb_matrix, foundation_cache) self.vocab = vocab # dict-based components self.word_dict = dict() self.composite_dict = dict() self.contextual_lemmatizers = [] self.caseless = self.args.get('caseless', False) if not self.args['dict_only']: self.model = self.model.to(device) if self.args.get('edit', False): self.crit = loss.MixLoss(self.vocab['char'].size, self.args['alpha']).to(device) logger.debug("Running seq2seq lemmatizer with edit classifier...") else: self.crit = loss.SequenceLoss(self.vocab['char'].size).to(device) self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr']) def 
build_seq2seq(self, args, emb_matrix, foundation_cache): charmodel = None charlms = [] if args is not None and args.get('charlm_forward_file', None): charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache) charlms.append(charmodel_forward) if args is not None and args.get('charlm_backward_file', None): charmodel_backward = load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache) charlms.append(charmodel_backward) if len(charlms) > 0: charlms = nn.ModuleList(charlms) charmodel = CharacterLanguageModelWordAdapter(charlms) model = Seq2SeqModel(args, emb_matrix=emb_matrix, contextual_embedding=charmodel) return model def update(self, batch, eval=False): device = next(self.model.parameters()).device inputs, orig_idx, text = unpack_batch(batch, device) src, src_mask, tgt_in, tgt_out, pos, edits = inputs if eval: self.model.eval() else: self.model.train() self.optimizer.zero_grad() log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos, raw=text) if self.args.get('edit', False): assert edit_logits is not None loss = self.crit(log_probs.view(-1, self.vocab['char'].size), tgt_out.view(-1), \ edit_logits, edits) else: loss = self.crit(log_probs.view(-1, self.vocab['char'].size), tgt_out.view(-1)) loss_val = loss.data.item() if eval: return loss_val loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() return loss_val def predict(self, batch, beam_size=1, vocab=None): if vocab is None: vocab = self.vocab device = next(self.model.parameters()).device inputs, orig_idx, text = unpack_batch(batch, device) src, src_mask, tgt, tgt_mask, pos, edits = inputs self.model.eval() batch_size = src.size(0) preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size, raw=text) pred_seqs = [vocab['char'].unmap(ids) for ids in preds] # unmap to tokens pred_seqs = utils.prune_decoded_seqs(pred_seqs) pred_tokens = ["".join(seq) for seq 
in pred_seqs] # join chars to be tokens pred_tokens = utils.unsort(pred_tokens, orig_idx) if self.args.get('edit', False): assert edit_logits is not None edits = np.argmax(edit_logits.data.cpu().numpy(), axis=1).reshape([batch_size]).tolist() edits = utils.unsort(edits, orig_idx) else: edits = None return pred_tokens, edits def postprocess(self, words, preds, edits=None): """ Postprocess, mainly for handing edits. """ assert len(words) == len(preds), "Lemma predictions must have same length as words." edited = [] if self.args.get('edit', False): assert edits is not None and len(words) == len(edits) for w, p, e in zip(words, preds, edits): lem = edit.edit_word(w, p, e) edited += [lem] else: edited = preds # do not edit # final sanity check assert len(edited) == len(words) final = [] for lem, w in zip(edited, words): if len(lem) == 0 or constant.UNK in lem: final += [w] # invalid prediction, fall back to word else: final += [lem] return final def has_contextual_lemmatizers(self): return self.contextual_lemmatizers is not None and len(self.contextual_lemmatizers) > 0 def predict_contextual(self, sentence_words, sentence_tags, preds): if len(self.contextual_lemmatizers) == 0: return preds # reversed so that the first lemmatizer has priority for contextual in reversed(self.contextual_lemmatizers): pred_idx = [] pred_sent_words = [] pred_sent_tags = [] pred_sent_ids = [] for sent_id, (words, tags) in enumerate(zip(sentence_words, sentence_tags)): indices = contextual.target_indices(words, tags) for idx in indices: pred_idx.append(idx) pred_sent_words.append(words) pred_sent_tags.append(tags) pred_sent_ids.append(sent_id) if len(pred_idx) == 0: continue contextual_predictions = contextual.predict(pred_idx, pred_sent_words, pred_sent_tags) for sent_id, word_id, pred in zip(pred_sent_ids, pred_idx, contextual_predictions): preds[sent_id][word_id] = pred return preds def update_contextual_preds(self, doc, preds): """ Update a flat list of preds with the output of the 
contextual lemmatizers - First, it unflattens the preds based on the lengths of the sentences - Then it uses the contextual lemmatizers - Finally, it reflattens the preds into the format expected by the caller """ if len(self.contextual_lemmatizers) == 0: return preds sentence_words = doc.get([TEXT], as_sentences=True) sentence_tags = doc.get([UPOS], as_sentences=True) sentence_preds = [] start_index = 0 for sent in sentence_words: end_index = start_index + len(sent) sentence_preds.append(preds[start_index:end_index]) start_index += len(sent) preds = self.predict_contextual(sentence_words, sentence_tags, sentence_preds) preds = [lemma for sentence in preds for lemma in sentence] return preds def update_lr(self, new_lr): utils.change_lr(self.optimizer, new_lr) def train_dict(self, triples, update_word_dict=True): """ Train a dict lemmatizer given training (word, pos, lemma) triples. Can update only the composite_dict (word/pos) in situations where the data might be limited from the tags, such as when adding more words at pipeline time """ # accumulate counter ctr = Counter() ctr.update([(p[0], p[1], p[2]) for p in triples]) # find the most frequent mappings for p, _ in ctr.most_common(): w, pos, l = p if (w,pos) not in self.composite_dict: self.composite_dict[(w,pos)] = l if update_word_dict and w not in self.word_dict: self.word_dict[w] = l return def predict_dict(self, pairs): """ Predict a list of lemmas using the dict model given (word, pos) pairs. """ lemmas = [] for p in pairs: w, pos = p if self.caseless: w = w.lower() if (w,pos) in self.composite_dict: lemmas += [self.composite_dict[(w,pos)]] elif w in self.word_dict: lemmas += [self.word_dict[w]] else: lemmas += [w] return lemmas def skip_seq2seq(self, pairs): """ Determine if we can skip the seq2seq module when ensembling with the frequency lexicon. 
""" skip = [] for p in pairs: w, pos = p if self.caseless: w = w.lower() if (w,pos) in self.composite_dict: skip.append(True) elif w in self.word_dict: skip.append(True) else: skip.append(False) return skip def ensemble(self, pairs, other_preds): """ Ensemble the dict with statistical model predictions. """ lemmas = [] assert len(pairs) == len(other_preds) for p, pred in zip(pairs, other_preds): w, pos = p if self.caseless: w = w.lower() if (w,pos) in self.composite_dict: lemma = self.composite_dict[(w,pos)] elif w in self.word_dict: lemma = self.word_dict[w] else: lemma = pred if lemma is None: lemma = w lemmas.append(lemma) return lemmas def save(self, filename, skip_modules=True): model_state = None if self.model is not None: model_state = self.model.state_dict() # skip saving modules like the pretrained charlm if skip_modules: skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] for k in skipped: del model_state[k] params = { 'model': model_state, 'dicts': (self.word_dict, self.composite_dict), 'vocab': self.vocab.state_dict(), 'config': self.args, 'contextual': [], } for contextual in self.contextual_lemmatizers: params['contextual'].append(contextual.get_save_dict()) save_dir = os.path.split(filename)[0] if save_dir: os.makedirs(os.path.split(filename)[0], exist_ok=True) torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) def load(self, filename, args, foundation_cache, lemma_classifier_args=None): try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] if args is not None: self.args['charlm_forward_file'] = args.get('charlm_forward_file', self.args['charlm_forward_file']) self.args['charlm_backward_file'] = args.get('charlm_backward_file', self.args['charlm_backward_file']) self.word_dict, 
self.composite_dict = checkpoint['dicts'] if not self.args['dict_only']: self.model = self.build_seq2seq(self.args, None, foundation_cache) # could remove strict=False after rebuilding all models, # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False self.model.load_state_dict(checkpoint['model'], strict=False) else: self.model = None self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) self.contextual_lemmatizers = [] for contextual in checkpoint.get('contextual', []): self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual, args=lemma_classifier_args)) ================================================ FILE: stanza/models/lemma/vocab.py ================================================ from collections import Counter from stanza.models.common.vocab import BaseVocab, BaseMultiVocab from stanza.models.common.seq2seq_constant import VOCAB_PREFIX class Vocab(BaseVocab): def build_vocab(self): counter = Counter(self.data) self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True)) self._unit2id = {w:i for i, w in enumerate(self._id2unit)} class MultiVocab(BaseMultiVocab): @classmethod def load_state_dict(cls, state_dict): new = cls() for k,v in state_dict.items(): new[k] = Vocab.load_state_dict(v) return new ================================================ FILE: stanza/models/lemma_classifier/__init__.py ================================================ ================================================ FILE: stanza/models/lemma_classifier/base_model.py ================================================ """ Base class for the LemmaClassifier types. 
Versions include LSTM and Transformer varieties """ import logging from abc import ABC, abstractmethod import os import torch import torch.nn as nn from stanza.models.common.foundation_cache import load_pretrain from stanza.models.lemma_classifier.constants import ModelType from typing import List logger = logging.getLogger('stanza.lemmaclassifier') class LemmaClassifier(ABC, nn.Module): def __init__(self, label_decoder, target_words, target_upos, *args, **kwargs): super().__init__(*args, **kwargs) self.label_decoder = label_decoder self.label_encoder = {y: x for x, y in label_decoder.items()} self.target_words = target_words self.target_upos = target_upos self.unsaved_modules = [] def add_unsaved_module(self, name, module): self.unsaved_modules += [name] setattr(self, name, module) def is_unsaved_module(self, name): return name.split('.')[0] in self.unsaved_modules def save(self, save_name): """ Save the model to the given path, possibly with some args """ save_dir = os.path.split(save_name)[0] if save_dir: os.makedirs(save_dir, exist_ok=True) save_dict = self.get_save_dict() torch.save(save_dict, save_name) return save_dict @abstractmethod def model_type(self): """ return a ModelType """ def target_indices(self, words, tags): return [idx for idx, (word, tag) in enumerate(zip(words, tags)) if word.lower() in self.target_words and tag in self.target_upos] def predict(self, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[str]]=[]) -> torch.Tensor: upos_tags = self.convert_tags(upos_tags) with torch.no_grad(): logits = self.forward(position_indices, sentences, upos_tags) # should be size (batch_size, output_size) predicted_class = torch.argmax(logits, dim=1) # should be size (batch_size, 1) predicted_class = [self.label_encoder[x.item()] for x in predicted_class] return predicted_class @staticmethod def from_checkpoint(checkpoint, args=None): model_type = ModelType[checkpoint['model_type']] if model_type is ModelType.LSTM: # TODO: if 
anyone can suggest a way to avoid this circular import # (or better yet, avoid the load method knowing about subclasses) # please do so # maybe the subclassing is not necessary and we just put # save & load in the trainer from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM saved_args = checkpoint['args'] # other model args are part of the model and cannot be changed for evaluation or pipeline # the file paths might be relevant, though keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file'] for arg in keep_args: if args is not None and args.get(arg, None) is not None: saved_args[arg] = args[arg] # TODO: refactor loading the pretrain (also done in the trainer) pt = load_pretrain(saved_args['wordvec_pretrain_file']) use_charlm = saved_args['use_charlm'] charlm_forward_file = saved_args.get('charlm_forward_file', None) charlm_backward_file = saved_args.get('charlm_backward_file', None) model = LemmaClassifierLSTM(model_args=saved_args, output_dim=len(checkpoint['label_decoder']), pt_embedding=pt, label_decoder=checkpoint['label_decoder'], upos_to_id=checkpoint['upos_to_id'], known_words=checkpoint['known_words'], target_words=set(checkpoint['target_words']), target_upos=set(checkpoint['target_upos']), use_charlm=use_charlm, charlm_forward_file=charlm_forward_file, charlm_backward_file=charlm_backward_file) elif model_type is ModelType.TRANSFORMER: from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer output_dim = len(checkpoint['label_decoder']) saved_args = checkpoint['args'] bert_model = saved_args['bert_model'] model = LemmaClassifierWithTransformer(model_args=saved_args, output_dim=output_dim, transformer_name=bert_model, label_decoder=checkpoint['label_decoder'], target_words=set(checkpoint['target_words']), target_upos=set(checkpoint['target_upos'])) else: raise ValueError("Unknown model type %s" % model_type) # strict=False to accommodate missing parameters from the 
transformer or charlm model.load_state_dict(checkpoint['params'], strict=False) return model @staticmethod def load(filename, args=None): try: checkpoint = torch.load(filename, lambda storage, loc: storage) except BaseException: logger.exception("Cannot load model from %s", filename) raise logger.debug("Loading LemmaClassifier model from %s", filename) return LemmaClassifier.from_checkpoint(checkpoint) ================================================ FILE: stanza/models/lemma_classifier/base_trainer.py ================================================ from abc import ABC, abstractmethod import logging import os from typing import List, Tuple, Any, Mapping import torch import torch.nn as nn import torch.optim as optim from stanza.models.common.utils import default_device from stanza.models.lemma_classifier import utils from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE from stanza.models.lemma_classifier.evaluate_models import evaluate_model from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza.lemmaclassifier') class BaseLemmaClassifierTrainer(ABC): def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping): """ If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss. The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the frequency of the class in the set. E.g. classes with lower frequency will have higher weight. 
""" weights = [0 for _ in label_decoder.keys()] # each key in the label decoder is one class, we have one weight per class total_samples = sum(counts.values()) for class_idx in counts: weights[class_idx] = total_samples / (counts[class_idx] * len(counts)) # weight_i = total / (# examples in class i * num classes) weights = torch.tensor(weights) logger.info(f"Using weights {weights} for weighted loss.") self.criterion = nn.BCEWithLogitsLoss(weight=weights) @abstractmethod def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos): """ Build a model using pieces of the dataset to determine some of the model shape """ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None: """ Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token. Args: num_epochs (int): Number of training epochs save_name (str): Path to file where trained model should be saved. eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch. train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line. 
""" # Put model on GPU (if possible) device = default_device() if not train_file: raise ValueError("Cannot train model - no train_file supplied!") dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE)) label_decoder = dataset.label_decoder upos_to_id = dataset.upos_to_id self.output_dim = len(label_decoder) logger.info(f"Loaded dataset successfully from {train_file}") logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}") logger.info(f"Target words: {dataset.target_words}") self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words, set(dataset.target_upos)) self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) self.model.to(device) logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}") if os.path.exists(save_name) and not args.get('force', False): raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. 
Aborting...") if self.weighted_loss: self.configure_weighted_loss(label_decoder, dataset.counts) # Put the criterion on GPU too logger.debug(f"Criterion on {next(self.model.parameters()).device}") self.criterion = self.criterion.to(next(self.model.parameters()).device) best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model for epoch in range(num_epochs): # go over entire dataset with each epoch for sentences, positions, upos_tags, labels in tqdm(dataset): assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})" self.optimizer.zero_grad() outputs = self.model(positions, sentences, upos_tags) # Compute loss, which is different if using CE or BCEWithLogitsLoss if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others. # TODO: three classes? targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device) # should be shape size (batch_size, 2) else: # CELoss accepts target as just raw label targets = labels.to(device) loss = self.criterion(outputs, targets) loss.backward() self.optimizer.step() logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}") if eval_file: # Evaluate model on dev set to see if it should be saved. _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True) logger.info(f"Weighted f1 for model: {f1}") if f1 > best_f1: best_f1 = f1 self.model.save(save_name) logger.info(f"New best model: weighted f1 score of {f1}.") else: self.model.save(save_name) ================================================ FILE: stanza/models/lemma_classifier/baseline_model.py ================================================ """ Baseline model for the existing lemmatizer which always predicts "be" and never "have" on the "'s" token. 
The BaselineModel class can be updated to any arbitrary token and predicton lemma, not just "be" on the "s" token. """ import stanza import os from stanza.models.lemma_classifier.evaluate_models import evaluate_sequences from stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file class BaselineModel: def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos): self.token_to_lemmatize = token_to_lemmatize self.prediction_lemma = prediction_lemma self.prediction_upos = prediction_upos def predict(self, token): if token == self.token_to_lemmatize: return self.prediction_lemma def evaluate(self, conll_path): """ Evaluates the baseline model against the test set defined in conll_path. Returns a map where the keys are each class and the values are another map including the precision, recall and f1 scores for that class. Also returns confusion matrix. Keys are gold tags and inner keys are predicted tags """ doc = load_doc_from_conll_file(conll_path) gold_tag_sequences, pred_tag_sequences = [], [] for sentence in doc.sentences: gold_tags, pred_tags = [], [] for word in sentence.words: if word.upos in self.prediction_upos and word.text == self.token_to_lemmatize: pred = self.prediction_lemma gold = word.lemma gold_tags.append(gold) pred_tags.append(pred) gold_tag_sequences.append(gold_tags) pred_tag_sequences.append(pred_tags) multiclass_result, confusion_mtx, weighted_f1 = evaluate_sequences(gold_tag_sequences, pred_tag_sequences) return multiclass_result, confusion_mtx if __name__ == "__main__": bl_model = BaselineModel("'s", "be", ["AUX"]) coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu") bl_model.evaluate(coNLL_path) ================================================ FILE: stanza/models/lemma_classifier/constants.py ================================================ from enum import Enum UNKNOWN_TOKEN = "unk" # token name for unknown tokens UNKNOWN_TOKEN_IDX = -1 # custom index we apply to unknown tokens # 
TODO: ModelType could just be LSTM and TRANSFORMER # and then the transformer baseline would have the transformer as another argument class ModelType(Enum): LSTM = 1 TRANSFORMER = 2 BERT = 3 ROBERTA = 4 DEFAULT_BATCH_SIZE = 16 ================================================ FILE: stanza/models/lemma_classifier/evaluate_many.py ================================================ """ Utils to evaluate many models of the same type at once """ import argparse import os import logging from stanza.models.lemma_classifier.evaluate_models import main as evaluate_main logger = logging.getLogger('stanza.lemmaclassifier') def evaluate_n_models(path_to_models_dir, args): total_results = { "be": 0.0, "have": 0.0, "accuracy": 0.0, "weighted_f1": 0.0 } paths = os.listdir(path_to_models_dir) num_models = len(paths) for model_path in paths: full_path = os.path.join(path_to_models_dir, model_path) args.save_name = full_path mcc_results, confusion, acc, weighted_f1 = evaluate_main(predefined_args=args) for lemma in mcc_results: lemma_f1 = mcc_results.get(lemma, None).get("f1") * 100 total_results[lemma] += lemma_f1 total_results["accuracy"] += acc total_results["weighted_f1"] += weighted_f1 total_results["be"] /= num_models total_results["have"] /= num_models total_results["accuracy"] /= num_models total_results["weighted_f1"] /= num_models logger.info(f"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\nLemma 'be' had f1: {total_results['be']}\nLemma 'have' had f1: {total_results['have']}.\nAccuracy: {100 * total_results['accuracy']}.\n ({num_models} models evaluated).") return total_results def main(): parser = argparse.ArgumentParser() parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab") parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)") parser.add_argument("--hidden_dim", type=int, default=256, help="Size of 
hidden layer") parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file") parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')") parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") parser.add_argument("--eval_file", type=str, help="path to evaluation file") # Args specific to several model eval parser.add_argument("--base_path", type=str, default=None, help="path to dir for eval") args = parser.parse_args() evaluate_n_models(args.base_path, args) if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemma_classifier/evaluate_models.py ================================================ import os import sys parentdir = os.path.dirname(__file__) parentdir = os.path.dirname(parentdir) parentdir = os.path.dirname(parentdir) sys.path.append(parentdir) import logging import argparse import os from typing import Any, List, Tuple, Mapping from collections import defaultdict from numpy import random import torch import torch.nn as nn import stanza 
from stanza.models.common.utils import default_device from stanza.models.lemma_classifier import utils from stanza.models.lemma_classifier.base_model import LemmaClassifier from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer from stanza.utils.confusion import format_confusion from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza.lemmaclassifier') def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float: """ Computes the weighted F1 score across an evaluation set. The weight of a class's F1 score is equal to the number of examples in evaluation. This makes classes that have more examples in the evaluation more impactful to the weighted f1. """ num_total_examples = 0 weighted_f1 = 0 for class_id in mcc_results: class_f1 = mcc_results.get(class_id).get("f1") num_class_examples = sum(confusion.get(class_id).values()) weighted_f1 += class_f1 * num_class_examples num_total_examples += num_class_examples return weighted_f1 / num_total_examples def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True): """ Evaluates a model's predicted tags against a set of gold tags. Computes precision, recall, and f1 for all classes. Precision = true positives / true positives + false positives Recall = true positives / true positives + false negatives F1 = 2 * (Precision * Recall) / (Precision + Recall) Returns: 1. Multi class result dictionary, where each class is a key and maps to another map of its F1, precision, and recall scores. e.g. multiclass_results[0]["precision"] would give class 0's precision. 2. Confusion matrix, where each key is a gold tag and its value is another map with a key of the predicted tag with value of that (gold, pred) count. e.g. 
confusion[0][1] = 6 would mean that for gold tag 0, the model predicted tag 1 a total of 6 times. """ assert len(gold_tag_sequences) == len(pred_tag_sequences), \ f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}" confusion = defaultdict(lambda: defaultdict(int)) reverse_label_decoder = {y: x for x, y in label_decoder.items()} for gold, pred in zip(gold_tag_sequences, pred_tag_sequences): confusion[reverse_label_decoder[gold]][reverse_label_decoder[pred]] += 1 multi_class_result = defaultdict(lambda: defaultdict(float)) # compute precision, recall and f1 for each class and store inside of `multi_class_result` for gold_tag in confusion.keys(): try: prec = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum([confusion.get(k, {}).get(gold_tag, 0) for k in confusion.keys()]) except ZeroDivisionError: prec = 0.0 try: recall = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum(confusion.get(gold_tag, {}).values()) except ZeroDivisionError: recall = 0.0 try: f1 = 2 * (prec * recall) / (prec + recall) except ZeroDivisionError: f1 = 0.0 multi_class_result[gold_tag] = { "precision": prec, "recall": recall, "f1": f1 } if verbose: for lemma in multi_class_result: logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}") weighted_f1 = get_weighted_f1(multi_class_result, confusion) return multi_class_result, confusion, weighted_f1 def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]]=[]) -> torch.Tensor: """ A LemmaClassifierLSTM or LemmaClassifierWithTransformer is used to predict on a single text example, given the position index of the target token. Args: model (LemmaClassifier): A trained LemmaClassifier that is able to predict on a target token. 
position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in `text` for each example in the batch. sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences. Returns: (int): The index of the predicted class in `model`'s output. """ with torch.no_grad(): logits = model(position_indices, sentences, upos_tags) # should be size (batch_size, output_size) predicted_class = torch.argmax(logits, dim=1) # should be size (batch_size, 1) return predicted_class def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]: """ Helper function for model evaluation Args: model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): An instance of the LemmaClassifier class that has architecture initialized which matches the model saved in `model_path`. model_path (str): Path to the saved model weights that will be loaded into `model`. eval_path (str): Path to the saved evaluation dataset. verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True. is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode. Returns: 1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is another map with key of "f1", "precision", or "recall" with corresponding values. 2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the map with the key as the predicted tag and corresponding count of that (gold, pred) pair. 3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set. 
""" # load model device = default_device() model.to(device) if not is_training: model.eval() # set to eval mode # load in eval data dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False) logger.info(f"Evaluating on evaluation file {eval_path}") correct, total = 0, 0 gold_tags, pred_tags = dataset.labels, [] # run eval on each example from dataset for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"): pred = model_predict(model, pos_indices, sentences, upos_tags) # Pred should be size (batch_size, ) correct_preds = pred == labels.to(device) correct += torch.sum(correct_preds) total += len(correct_preds) pred_tags += pred.tolist() logger.info("Finished evaluating on dataset. Computing scores...") accuracy = correct / total mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose) # add brackets around batches of gold and pred tags because each batch is an element within the sequences in this helper if verbose: logger.info(f"Accuracy: {accuracy} ({correct}/{total})") logger.info(f"Label decoder: {dataset.label_decoder}") return mc_results, confusion, accuracy, weighted_f1 def main(args=None, predefined_args=None): # TODO: can unify this script with train_lstm_model.py? 
# TODO: can save the model type in the model .pt, then # automatically figure out what type of model we are using by # looking in the file parser = argparse.ArgumentParser() parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab") parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)") parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer") parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file") parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')") parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") parser.add_argument("--eval_file", type=str, help="path to evaluation file") args = parser.parse_args(args) if not predefined_args else predefined_args logger.info("Running training script with the following args:") args = vars(args) for arg in args: logger.info(f"{arg}: {args[arg]}") 
logger.info("------------------------------------------------------------") logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}") model = LemmaClassifier.load(args['save_name'], args) mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file']) logger.info(f"MCC Results: {dict(mcc_results)}") logger.info("______________________________________________") logger.info(f"Confusion:\n%s", format_confusion(confusion)) logger.info("______________________________________________") logger.info(f"Accuracy: {acc}") logger.info("______________________________________________") logger.info(f"Weighted f1: {weighted_f1}") return mcc_results, confusion, acc, weighted_f1 if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemma_classifier/lstm_model.py ================================================ import torch import torch.nn as nn import os import logging import math from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel from typing import List, Tuple from stanza.models.common.vocab import UNK_ID from stanza.models.lemma_classifier import utils from stanza.models.lemma_classifier.base_model import LemmaClassifier from stanza.models.lemma_classifier.constants import ModelType logger = logging.getLogger('stanza.lemmaclassifier') class LemmaClassifierLSTM(LemmaClassifier): """ Model architecture: Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding. From the LSTM output, we get the embedding of the specific token that we classify on. That embedding is fed into an MLP for classification. 
""" def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos, use_charlm=False, charlm_forward_file=None, charlm_backward_file=None): """ Args: vocab_size (int): Size of the vocab being used (if custom vocab) output_dim (int): Size of output vector from MLP layer upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs pt_embedding (Pretrain): pretrained embeddings known_words (list(str)): Words which are in the training data target_words (set(str)): a set of the words which might need lemmatization use_charlm (bool): Whether or not to use the charlm embeddings charlm_forward_file (str): The path to the forward pass model for the character language model charlm_backward_file (str): The path to the forward pass model for the character language model. Kwargs: upos_emb_dim (int): The size of the UPOS tag embeddings num_heads (int): The number of heads to use for attention. If there are more than 0 heads, attention will be used instead of the LSTM. Raises: FileNotFoundError: if the forward or backward charlm file cannot be found. 
""" super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words, target_upos) self.model_args = model_args self.hidden_dim = model_args['hidden_dim'] self.input_size = 0 self.num_heads = self.model_args['num_heads'] emb_matrix = pt_embedding.emb self.add_unsaved_module("embeddings", nn.Embedding.from_pretrained(emb_matrix, freeze=True)) self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt_embedding.vocab) } self.vocab_size = emb_matrix.shape[0] self.embedding_dim = emb_matrix.shape[1] self.known_words = known_words self.known_word_map = {word: idx for idx, word in enumerate(known_words)} self.delta_embedding = nn.Embedding(num_embeddings=len(known_words)+1, embedding_dim=self.embedding_dim, padding_idx=0) nn.init.normal_(self.delta_embedding.weight, std=0.01) self.input_size += self.embedding_dim # Optionally, include charlm embeddings self.use_charlm = use_charlm if self.use_charlm: if charlm_forward_file is None or not os.path.exists(charlm_forward_file): raise FileNotFoundError(f'Could not find forward character model: {charlm_forward_file}') if charlm_backward_file is None or not os.path.exists(charlm_backward_file): raise FileNotFoundError(f'Could not find backward character model: {charlm_backward_file}') self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(charlm_forward_file, finetune=False)) self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(charlm_backward_file, finetune=False)) self.input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim() self.upos_emb_dim = self.model_args["upos_emb_dim"] self.upos_to_id = upos_to_id if self.upos_emb_dim > 0 and self.upos_to_id is not None: # TODO: should leave space for unknown POS? 
self.upos_emb = nn.Embedding(num_embeddings=len(self.upos_to_id), embedding_dim=self.upos_emb_dim, padding_idx=0) self.input_size += self.upos_emb_dim device = next(self.parameters()).device # Determine if attn or LSTM should be used if self.num_heads > 0: self.input_size = utils.round_up_to_multiple(self.input_size, self.num_heads) self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_size, num_heads=self.num_heads, batch_first=True).to(device) logger.debug(f"Using attention mechanism with embed dim {self.input_size} and {self.num_heads} attention heads.") else: self.lstm = nn.LSTM(self.input_size, self.hidden_dim, batch_first=True, bidirectional=True) logger.debug(f"Using LSTM mechanism.") mlp_input_size = self.hidden_dim * 2 if self.num_heads == 0 else self.input_size self.mlp = nn.Sequential( nn.Linear(mlp_input_size, 64), nn.ReLU(), nn.Linear(64, output_dim) ) def get_save_dict(self): save_dict = { "params": self.state_dict(), "label_decoder": self.label_decoder, "model_type": self.model_type().name, "args": self.model_args, "upos_to_id": self.upos_to_id, "known_words": self.known_words, "target_words": list(self.target_words), "target_upos": list(self.target_upos), } skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)] for k in skipped: del save_dict["params"][k] return save_dict def convert_tags(self, upos_tags: List[List[str]]): if self.upos_to_id is not None: return [[self.upos_to_id[x] for x in sentence] for sentence in upos_tags] return None def forward(self, pos_indices: List[int], sentences: List[List[str]], upos_tags: List[List[int]]): """ Computes the forward pass of the neural net Args: pos_indices (List[int]): A list of the position index of the target token for lemmatization classification in each sentence. sentences (List[List[str]]): A list of the token-split sentences of the input data. upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence. 
Returns: torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences. """ device = next(self.parameters()).device batch_size = len(sentences) token_ids = [] delta_token_ids = [] for words in sentences: sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words] sentence_token_ids = torch.tensor(sentence_token_ids, device=device) token_ids.append(sentence_token_ids) sentence_delta_token_ids = [self.known_word_map.get(word.lower(), 0) for word in words] sentence_delta_token_ids = torch.tensor(sentence_delta_token_ids, device=device) delta_token_ids.append(sentence_delta_token_ids) token_ids = pad_sequence(token_ids, batch_first=True) delta_token_ids = pad_sequence(delta_token_ids, batch_first=True) embedded = self.embeddings(token_ids) + self.delta_embedding(delta_token_ids) if self.upos_emb_dim > 0: upos_tags = [torch.tensor(sentence_tags) for sentence_tags in upos_tags] # convert internal lists to tensors upos_tags = pad_sequence(upos_tags, batch_first=True, padding_value=0).to(device) pos_emb = self.upos_emb(upos_tags) embedded = torch.cat((embedded, pos_emb), 2).to(device) if self.use_charlm: char_reps_forward = self.charmodel_forward.build_char_representation(sentences) # takes [[str]] char_reps_backward = self.charmodel_backward.build_char_representation(sentences) char_reps_forward = pad_sequence(char_reps_forward, batch_first=True) char_reps_backward = pad_sequence(char_reps_backward, batch_first=True) embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2) if self.num_heads > 0: def positional_encoding(seq_len, d_model, device): encoding = torch.zeros(seq_len, d_model, device=device) position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device) encoding[:, 0::2] = torch.sin(position * div_term) encoding[:, 1::2] = torch.cos(position * 
div_term) # Add a new dimension to fit the batch size encoding = encoding.unsqueeze(0) return encoding seq_len, d_model = embedded.shape[1], embedded.shape[2] pos_enc = positional_encoding(seq_len, d_model, device=device) embedded += pos_enc.expand_as(embedded) padded_sequences = pad_sequence(embedded, batch_first=True) lengths = torch.tensor([len(seq) for seq in embedded]) if self.num_heads > 0: target_seq_length, src_seq_length = padded_sequences.size(1), padded_sequences.size(1) attn_mask = torch.triu(torch.ones(batch_size * self.num_heads, target_seq_length, src_seq_length, dtype=torch.bool), diagonal=1) attn_mask = attn_mask.view(batch_size, self.num_heads, target_seq_length, src_seq_length) attn_mask = attn_mask.repeat(1, 1, 1, 1).view(batch_size * self.num_heads, target_seq_length, src_seq_length).to(device) attn_output, attn_weights = self.multihead_attn(padded_sequences, padded_sequences, padded_sequences, attn_mask=attn_mask) # Extract the hidden state at the index of the token to classify token_reps = attn_output[torch.arange(attn_output.size(0)), pos_indices] else: packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True) lstm_out, (hidden, _) = self.lstm(packed_sequences) # Extract the hidden state at the index of the token to classify unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True) token_reps = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices] # MLP forward pass output = self.mlp(token_reps) return output def model_type(self): return ModelType.LSTM ================================================ FILE: stanza/models/lemma_classifier/prepare_dataset.py ================================================ import argparse import json import os import re import stanza from stanza.models.lemma_classifier import utils from typing import List, Tuple, Any """ The code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain 
the target token. Furthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma. """ def load_doc_from_conll_file(path: str): """" loads in a Stanza document object from a path to a CoNLL file containing annotated sentences. """ return stanza.utils.conll.CoNLL.conll2doc(path) class DataProcessor(): def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str): self.target_word = target_word self.target_word_regex = re.compile(target_word) self.target_upos = target_upos self.allowed_lemmas = re.compile(allowed_lemmas) def keep_sentence(self, sentence): for word in sentence.words: if self.target_word_regex.fullmatch(word.text) and word.upos in self.target_upos: return True return False def find_all_occurrences(self, sentence) -> List[int]: """ Finds all occurrences of self.target_word in tokens and returns the index(es) of such occurrences. """ occurrences = [] for idx, token in enumerate(sentence.words): if self.target_word_regex.fullmatch(token.text) and token.upos in self.target_upos: occurrences.append(idx) return occurrences @staticmethod def write_output_file(save_name, target_upos, sentences): with open(save_name, "w+", encoding="utf-8") as output_f: output_f.write("{\n") output_f.write(' "upos": %s,\n' % json.dumps(target_upos)) output_f.write(' "sentences": [') wrote_sentence = False for sentence in sentences: if not wrote_sentence: output_f.write("\n ") wrote_sentence = True else: output_f.write(",\n ") output_f.write(json.dumps(sentence)) output_f.write("\n ]\n}\n") def process_document(self, doc, save_name: str) -> None: """ Takes any sentence from `doc` that meets the condition of `keep_sentence` and writes its tokens, index of target word, and lemma to `save_name` Sentences that meet `keep_sentence` and contain `self.target_word` multiple times have each instance in a different example in the output file. 
Args: doc (Stanza.doc): Document object that represents the file to be analyzed save_name (str): Path to the file for storing output """ sentences = [] for sentence in doc.sentences: # for each sentence, we need to determine if it should be added to the output file. # if the sentence fulfills keep_sentence, then we will save it along with the target word's index and its corresponding lemma if self.keep_sentence(sentence): tokens = [token.text for token in sentence.words] indexes = self.find_all_occurrences(sentence) for idx in indexes: if self.allowed_lemmas.fullmatch(sentence.words[idx].lemma): # for each example found, we write the tokens, # their respective upos tags, the target token index, # and the target lemma upos_tags = [sentence.words[i].upos for i in range(len(sentence.words))] num_tokens = len(upos_tags) sentences.append({ "words": tokens, "upos_tags": upos_tags, "index": idx, "lemma": sentence.words[idx].lemma }) if save_name: self.write_output_file(save_name, self.target_upos, sentences) return sentences def main(args=None): parser = argparse.ArgumentParser() parser.add_argument("--conll_path", type=str, default=os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu"), help="path to the conll file to translate") parser.add_argument("--target_word", type=str, default="'s", help="Token to classify on, e.g. 's.") parser.add_argument("--target_upos", type=str, default="AUX", help="upos on target token") parser.add_argument("--output_path", type=str, default="test_output.txt", help="Path for output file") parser.add_argument("--allowed_lemmas", type=str, default=".*", help="A regex for allowed lemmas. 
If not set, all lemmas are allowed") args = parser.parse_args(args) conll_path = args.conll_path target_upos = args.target_upos output_path = args.output_path allowed_lemmas = args.allowed_lemmas args = vars(args) for arg in args: print(f"{arg}: {args[arg]}") doc = load_doc_from_conll_file(conll_path) processor = DataProcessor(target_word=args['target_word'], target_upos=[target_upos], allowed_lemmas=allowed_lemmas) return processor.process_document(doc, output_path) if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemma_classifier/train_lstm_model.py ================================================ """ The code in this file works to train a lemma classifier for 's """ import argparse import logging import os import torch import torch.nn as nn from stanza.models.common.foundation_cache import load_pretrain from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM logger = logging.getLogger('stanza.lemmaclassifier') class LemmaClassifierTrainer(BaseLemmaClassifierTrainer): """ Class to assist with training a LemmaClassifierLSTM """ def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None): """ Initializes the LemmaClassifierTrainer class. Args: model_args (dict): Various model shape parameters embedding_file (str): What word embeddings file to use. Use a Stanza pretrain .pt use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False. 
charlm_forward_file (str): Path to the forward pass embeddings for the charlm charlm_backward_file (str): Path to the backward pass embeddings for the charlm upos_emb_dim (int): The dimension size of UPOS tag embeddings num_heads (int): The number of attention heads to use. lr (float): Learning rate, defaults to 0.001. loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce') Raises: FileNotFoundError: If the forward charlm file is not present FileNotFoundError: If the backward charlm file is not present """ super().__init__() self.model_args = model_args # Load word embeddings pt = load_pretrain(embedding_file) self.pt_embedding = pt # Load CharLM embeddings if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file): raise FileNotFoundError(f"Could not find forward charlm file: {charlm_forward_file}") if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file): raise FileNotFoundError(f"Could not find backward charlm file: {charlm_backward_file}") # TODO: just pass around the args instead self.use_charlm = use_charlm self.charlm_forward_file = charlm_forward_file self.charlm_backward_file = charlm_backward_file self.lr = lr # Find loss function if loss_func == "ce": self.criterion = nn.CrossEntropyLoss() self.weighted_loss = False logger.debug("Using CE loss") elif loss_func == "weighted_bce": self.criterion = nn.BCEWithLogitsLoss() self.weighted_loss = True # used to add weights during train time. logger.debug("Using Weighted BCE loss") else: raise ValueError("Must enter a valid loss function (e.g. 
'ce' or 'weighted_bce')") def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos): return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos, use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file) def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer") parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read') parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.") parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.') parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.") parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file") parser.add_argument("--lr", type=float, default=0.001, help="learning rate") 
parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs") parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file") parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.") parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file') return parser def main(args=None, predefined_args=None): parser = build_argparse() args = parser.parse_args(args) if predefined_args is None else predefined_args wordvec_pretrain_file = args.wordvec_pretrain_file use_charlm = args.use_charlm charlm_forward_file = args.charlm_forward_file charlm_backward_file = args.charlm_backward_file upos_emb_dim = args.upos_emb_dim use_attention = args.attn num_heads = args.num_heads save_name = args.save_name lr = args.lr num_epochs = args.num_epochs train_file = args.train_file weighted_loss = args.weighted_loss eval_file = args.eval_file args = vars(args) if os.path.exists(save_name) and not args.get('force', False): raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...") if not os.path.exists(train_file): raise FileNotFoundError(f"Training file {train_file} not found. 
Try again with a valid path.") logger.info("Running training script with the following args:") for arg in args: logger.info(f"{arg}: {args[arg]}") logger.info("------------------------------------------------------------") trainer = LemmaClassifierTrainer(model_args=args, embedding_file=wordvec_pretrain_file, use_charlm=use_charlm, charlm_forward_file=charlm_forward_file, charlm_backward_file=charlm_backward_file, lr=lr, loss_func="weighted_bce" if weighted_loss else "ce", ) trainer.train( num_epochs=num_epochs, save_name=save_name, args=args, eval_file=eval_file, train_file=train_file ) return trainer if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemma_classifier/train_many.py ================================================ """ Utils for training and evaluating multiple models simultaneously """ import argparse import os from stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main from stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE change_params_map = { "lstm_layer": [16, 32, 64, 128, 256, 512], "upos_emb_dim": [5, 10, 20, 30], "training_size": [150, 300, 450, 600, 'full'], } # TODO: Add attention def train_n_models(num_models: int, base_path: str, args): if args.change_param == "lstm_layer": for num_layers in change_params_map.get("lstm_layer", None): for i in range(num_models): new_save_name = os.path.join(base_path, f"{num_layers}_{i}.pt") args.save_name = new_save_name args.hidden_dim = num_layers train_lstm_main(predefined_args=args) if args.change_param == "upos_emb_dim": for upos_dim in change_params_map("upos_emb_dim", None): for i in range(num_models): new_save_name = os.path.join(base_path, f"dim_{upos_dim}_{i}.pt") args.save_name = new_save_name args.upos_emb_dim = upos_dim train_lstm_main(predefined_args=args) if args.change_param == "training_size": for size in 
change_params_map.get("training_size", None): for i in range(num_models): new_save_name = os.path.join(base_path, f"{size}_examples_{i}.pt") new_train_file = os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt") args.save_name = new_save_name args.train_file = new_train_file train_lstm_main(predefined_args=args) if args.change_param == "base": for i in range(num_models): new_save_name = os.path.join(base_path, f"lstm_model_{i}.pt") args.save_name = new_save_name args.weighted_loss = False train_lstm_main(predefined_args=args) if not args.weighted_loss: args.weighted_loss = True new_save_name = os.path.join(base_path, f"lstm_model_wloss_{i}.pt") args.save_name = new_save_name train_lstm_main(predefined_args=args) if args.change_param == "base_charlm": for i in range(num_models): new_save_name = os.path.join(base_path, f"lstm_charlm_{i}.pt") args.save_name = new_save_name train_lstm_main(predefined_args=args) if args.change_param == "base_charlm_upos": for i in range(num_models): new_save_name = os.path.join(base_path, f"lstm_charlm_upos_{i}.pt") args.save_name = new_save_name train_lstm_main(predefined_args=args) if args.change_param == "base_upos": for i in range(num_models): new_save_name = os.path.join(base_path, f"lstm_upos_{i}.pt") args.save_name = new_save_name train_lstm_main(predefined_args=args) if args.change_param == "attn_model": for i in range(num_models): new_save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt") args.save_name = new_save_name train_lstm_main(predefined_args=args) def train_n_tfmrs(num_models: int, base_path: str, args): if args.multi_train_type == "tfmr": for i in range(num_models): if args.change_param == "bert": new_save_name = os.path.join(base_path, f"bert_{i}.pt") args.save_name = new_save_name args.loss_fn = "ce" train_tfmr_main(predefined_args=args) new_save_name = os.path.join(base_path, f"bert_wloss_{i}.pt") args.save_name = new_save_name args.loss_fn = 
"weighted_bce" train_tfmr_main(predefined_args=args) elif args.change_param == "roberta": new_save_name = os.path.join(base_path, f"roberta_{i}.pt") args.save_name = new_save_name args.loss_fn = "ce" train_tfmr_main(predefined_args=args) new_save_name = os.path.join(base_path, f"roberta_wloss_{i}.pt") args.save_name = new_save_name args.loss_fn = "weighted_bce" train_tfmr_main(predefined_args=args) def main(): parser = argparse.ArgumentParser() parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer") parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read') parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.") parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.') parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.") parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file") parser.add_argument("--lr", type=float, default=0.001, help="learning rate") 
parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs") parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file") parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.") parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") # Tfmr-specific args parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')") parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 
'ce' or 'weighted_bce')") # Multi-model train args parser.add_argument("--multi_train_type", type=str, default="lstm", help="Whether you are attempting to multi-train an LSTM or transformer") parser.add_argument("--multi_train_count", type=int, default=5, help="Number of each model to build") parser.add_argument("--base_path", type=str, default=None, help="Path to start generating model type for.") parser.add_argument("--change_param", type=str, default=None, help="Which hyperparameter to change when training") args = parser.parse_args() if args.multi_train_type == "lstm": train_n_models(num_models=args.multi_train_count, base_path=args.base_path, args=args) elif args.multi_train_type == "tfmr": train_n_tfmrs(num_models=args.multi_train_count, base_path=args.base_path, args=args) else: raise ValueError(f"Improper input {args.multi_train_type}") if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemma_classifier/train_transformer_model.py ================================================ """ This file contains code used to train a baseline transformer model to classify on a lemma of a particular token. """ import argparse import os import sys import logging import torch import torch.nn as nn import torch.optim as optim from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer from stanza.models.common.utils import default_device logger = logging.getLogger('stanza.lemmaclassifier') class TransformerBaselineTrainer(BaseLemmaClassifierTrainer): """ Class to assist with training a baseline transformer model to classify on token lemmas. To find the model spec, refer to `model.py` in this directory. 
""" def __init__(self, model_args: dict, transformer_name: str = "roberta", loss_func: str = "ce", lr: int = 0.001): """ Creates the Trainer object Args: transformer_name (str, optional): What kind of transformer to use for embeddings. Defaults to "roberta". loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to "ce". lr (int, optional): learning rate for the optimizer. Defaults to 0.001. """ super().__init__() self.model_args = model_args # Find loss function if loss_func == "ce": self.criterion = nn.CrossEntropyLoss() self.weighted_loss = False elif loss_func == "weighted_bce": self.criterion = nn.BCEWithLogitsLoss() self.weighted_loss = True # used to add weights during train time. else: raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')") self.transformer_name = transformer_name self.lr = lr def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torch.optim: """ Sets learning rates for each layer of the model. Currently, the model has the transformer layer and the MLP layer, so these are tweakable. Returns (torch.optim): An Adam optimizer with the learning rates adjusted per layer. 
Currently unused - could be refactored into the parent class's train method, or the parent class could call a build_optimizer and this subclass would use the optimizer """ transformer_params, mlp_params = [], [] for name, param in self.model.named_parameters(): if 'transformer' in name: transformer_params.append(param) elif 'mlp' in name: mlp_params.append(param) optimizer = optim.Adam([ {"params": transformer_params, "lr": transformer_lr}, {"params": mlp_params, "lr": mlp_lr} ]) return optimizer def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos): return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words, target_upos=target_upos) def main(args=None, predefined_args=None): parser = argparse.ArgumentParser() parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file") parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs") parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_train.txt"), help="Full path to training file") parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')") parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 
'ce' or 'weighted_bce')") parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.") parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file') args = parser.parse_args(args) if predefined_args is None else predefined_args save_name = args.save_name num_epochs = args.num_epochs train_file = args.train_file loss_fn = args.loss_fn eval_file = args.eval_file lr = args.lr args = vars(args) if args['model_type'] == 'bert': args['bert_model'] = 'bert-base-uncased' elif args['model_type'] == 'roberta': args['bert_model'] = 'roberta-base' elif args['model_type'] == 'transformer': if args['bert_model'] is None: raise ValueError("Need to specify a bert_model for model_type transformer!") else: raise ValueError("Unknown model type " + args['model_type']) if os.path.exists(save_name) and not args.get('force', False): raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...") if not os.path.exists(train_file): raise FileNotFoundError(f"Training file {train_file} not found. 
Try again with a valid path.") logger.info("Running training script with the following args:") for arg in args: logger.info(f"{arg}: {args[arg]}") logger.info("------------------------------------------------------------") trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr) trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file) return trainer if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemma_classifier/transformer_model.py ================================================ import torch import torch.nn as nn import os import sys import logging from transformers import AutoTokenizer, AutoModel from typing import Mapping, List, Tuple, Any from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence from stanza.models.common.bert_embedding import extract_bert_embeddings from stanza.models.lemma_classifier.base_model import LemmaClassifier from stanza.models.lemma_classifier.constants import ModelType logger = logging.getLogger('stanza.lemmaclassifier') class LemmaClassifierWithTransformer(LemmaClassifier): def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set, target_upos: set): """ Model architecture: Use a transformer (BERT or RoBERTa) to extract contextual embedding over a sentence. Get the embedding for the word that is to be classified on, and feed the embedding as input to an MLP classifier that has 2 linear layers, and a prediction head. 
Args: model_args (dict): args for the model output_dim (int): Dimension of the output from the MLP transformer_name (str): name of the HF transformer to use label_decoder (dict): a map of the labels available to the model target_words (set(str)): a set of the words which might need lemmatization """ super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words, target_upos) self.model_args = model_args # Choose transformer self.transformer_name = transformer_name self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True) self.add_unsaved_module("transformer", AutoModel.from_pretrained(transformer_name)) config = self.transformer.config embedding_size = config.hidden_size # define an MLP layer self.mlp = nn.Sequential( nn.Linear(embedding_size, 64), nn.ReLU(), nn.Linear(64, output_dim) ) def get_save_dict(self): save_dict = { "params": self.state_dict(), "label_decoder": self.label_decoder, "target_words": list(self.target_words), "target_upos": list(self.target_upos), "model_type": self.model_type().name, "args": self.model_args, } skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)] for k in skipped: del save_dict["params"][k] return save_dict def convert_tags(self, upos_tags: List[List[str]]): return None def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]): """ Computes the forward pass of the transformer baselines Args: idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence. sentences (List[List[str]]): A list of the token-split sentences of the input data. upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility Returns: torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences. 
""" device = next(self.transformer.parameters()).device bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device, keep_endpoints=False, num_layers=1, detach=True) embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)] embeddings = torch.stack(embeddings, dim=0)[:, :, 0] # pass to the MLP output = self.mlp(embeddings) return output def model_type(self): return ModelType.TRANSFORMER ================================================ FILE: stanza/models/lemma_classifier/utils.py ================================================ from collections import Counter import json import logging import os import random from typing import List, Tuple, Any, Mapping import stanza import torch from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE logger = logging.getLogger('stanza.lemmaclassifier') class Dataset: def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True): """ Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence. Args: data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line. batch_size (int): Size of each batch of examples get_counts (optional, bool): Whether there should be a map of the label index to counts Returns: 1. List[List[List[str]]]: Batches of sentences, where each token is a separate entry in each sentence 2. List[torch.tensor[int]]: A batch of indexes for the target token corresponding to its sentence 3. List[torch.tensor[int]]: A batch of labels for the target token's lemma 4. List[List[int]]: A batch of UPOS IDs for the target token (this is a List of Lists, not a tensor. It should be padded later.) 5 (Optional): A mapping of label ID to counts in the dataset. 6. Mapping[str, int]: A map between the labels and their indexes 7. 
Mapping[str, int]: A map between the UPOS tags and their corresponding IDs found in the UPOS batches """ if data_path is None or not os.path.exists(data_path): raise FileNotFoundError(f"Data file {data_path} could not be found.") if label_decoder is None: label_decoder = {} else: # if labels in the test set aren't in the original model, # the model will never predict those labels, # but we can still use those labels in a confusion matrix label_decoder = dict(label_decoder) logger.debug("Final label decoder: %s Should be strings to ints", label_decoder) # words which we are analyzing target_words = set() # all known words in the dataset, not just target words known_words = set() with open(data_path, "r+", encoding="utf-8") as fin: sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {} input_json = json.load(fin) sentences_data = input_json['sentences'] self.target_upos = input_json['upos'] for idx, sentence in enumerate(sentences_data): # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatability reasons words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma") if None in [words, target_idx, upos_tags, label]: raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}") label_id = label_decoder.get(label, None) if label_id is None: label_decoder[label] = len(label_decoder) # create a new ID for the unknown label converted_upos_tags = [] # convert upos tags to upos IDs for upos_tag in upos_tags: if upos_tag not in upos_to_id: upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag converted_upos_tags.append(upos_to_id[upos_tag]) sentences.append(words) indices.append(target_idx) upos_ids.append(converted_upos_tags) labels.append(label_decoder[label]) if get_counts: counts[label_decoder[label]] += 1 
target_words.add(words[target_idx]) known_words.update(words) self.sentences = sentences self.indices = indices self.upos_ids = upos_ids self.labels = labels self.counts = counts self.label_decoder = label_decoder self.upos_to_id = upos_to_id self.batch_size = batch_size self.shuffle = shuffle self.known_words = [x.lower() for x in sorted(known_words)] self.target_words = set(x.lower() for x in target_words) def __len__(self): """ Number of batches, rounded up to nearest batch """ return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0) def __iter__(self): num_sentences = len(self.sentences) indices = list(range(num_sentences)) if self.shuffle: random.shuffle(indices) for i in range(self.__len__()): batch_start = self.batch_size * i batch_end = min(batch_start + self.batch_size, num_sentences) batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]] batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]]) batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]] batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]]) yield batch_sentences, batch_indices, batch_upos_ids, batch_labels def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]: """ Extracts the indices within `tokenized_indices` which match `unknown_token_idx` Args: tokenized_indices (torch.tensor): A tensor filled with tokenized indices of words that have been mapped to vector indices. unknown_token_idx (int): The special index for which unknown tokens are marked in the word vectors. 
Returns: List[int]: A list of indices in `tokenized_indices` which match `unknown_token_index` """ return [idx for idx, token_index in enumerate(tokenized_indices) if token_index == unknown_token_idx] def get_device(): """ Get the device to run computations on """ if torch.cuda.is_available: device = torch.device("cuda") if torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") return device def round_up_to_multiple(number, multiple): if multiple == 0: return "Error: The second number (multiple) cannot be zero." # Calculate the remainder when dividing the number by the multiple remainder = number % multiple # If remainder is non-zero, round up to the next multiple if remainder != 0: rounded_number = number + (multiple - remainder) else: rounded_number = number # No rounding needed return rounded_number def main(): default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt") # get the GUM stuff sentence_batches, indices_batches, upos_batches, _, counts, _, upos_to_id = load_dataset(default_test_path, get_counts=True) if __name__ == "__main__": main() ================================================ FILE: stanza/models/lemmatizer.py ================================================ """ Entry point for training and evaluating a lemmatizer. This lemmatizer combines a neural sequence-to-sequence architecture with an `edit` classifier and two dictionaries to produce robust lemmas from word forms. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. 
""" import logging import sys import os import shutil import time from datetime import datetime import argparse import numpy as np import random import torch from torch import nn, optim from stanza.models.lemma.data import DataLoader from stanza.models.lemma.vocab import Vocab from stanza.models.lemma.trainer import Trainer from stanza.models.lemma import scorer, edit from stanza.models.common import utils import stanza.models.common.seq2seq_constant as constant from stanza.models.common.doc import * from stanza.utils.conll import CoNLL from stanza.models import _training_logging logger = logging.getLogger('stanza') def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.') parser.add_argument('--train_file', type=str, default=None, help='Training input file for data loader.') parser.add_argument('--eval_file', type=str, default=None, help='Evaluation input file for data loader.') parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--shorthand', type=str, help='Shorthand for the dataset to use. lang_dataset') parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. 
By default use ensemble.') parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based lemmatizer.') parser.add_argument('--hidden_dim', type=int, default=200) parser.add_argument('--emb_dim', type=int, default=50) parser.add_argument('--num_layers', type=int, default=1) parser.add_argument('--emb_dropout', type=float, default=0.5) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--max_dec_len', type=int, default=50) parser.add_argument('--beam_size', type=int, default=1) parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type') parser.add_argument('--pos_dim', type=int, default=50) parser.add_argument('--pos_dropout', type=float, default=0.5) parser.add_argument('--no_edit', dest='edit', action='store_false', help='Do not use edit classifier in lemmatization. By default use an edit classifier.') parser.add_argument('--num_edit', type=int, default=len(edit.EDIT_TO_ID)) parser.add_argument('--alpha', type=float, default=1.0) parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.') parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. 
By default copy mechanism is used to improve generalization.') parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.') parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') parser.add_argument('--lr_decay', type=float, default=0.9) parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.") parser.add_argument('--num_epoch', type=int, default=60) parser.add_argument('--batch_size', type=int, default=50) parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.') parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.') parser.add_argument('--save_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_lemmatizer.pt", help="File name to save the model") parser.add_argument('--caseless', default=False, action='store_true', help='Lowercase everything first before processing. This will happen automatically if 100%% of the data is caseless') parser.add_argument('--skip_blank_lemmas', default=False, action='store_true', help='Skip blank entries in the data files. 
Useful for training a lemmatizer from a partially annotated dataset') parser.add_argument('--seed', type=int, default=1234) utils.add_device_args(parser) parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name') return parser def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) if args.wandb_name: args.wandb = True args = vars(args) # when building the vocab, we keep track of the original language name lang = args['shorthand'].split("_")[0] if args['shorthand'] else "" args['lang'] = lang return args def main(args=None): args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running lemmatizer in {} mode".format(args['mode'])) if args['mode'] == 'train': train(args) else: evaluate(args) def all_lowercase(doc): for sentence in doc.sentences: for word in sentence.words: if word.text.lower() != word.text: return False return True def build_model_filename(args): embedding = "nocharlm" if args['charlm'] and args['charlm_forward_file']: embedding = "charlm" model_file = args['save_name'].format(shorthand=args['shorthand'], embedding=embedding) model_dir = os.path.split(model_file)[0] if not model_dir.startswith(args['save_dir']): model_file = os.path.join(args['save_dir'], model_file) return model_file def train(args): # load data logger.info("[Loading data with batch size {}...]".format(args['batch_size'])) train_doc = CoNLL.conll2doc(input_file=args['train_file']) train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False) vocab = train_batch.vocab args['vocab_size'] = vocab['char'].size args['pos_vocab_size'] = vocab['pos'].size dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_batch = DataLoader(dev_doc, 
args['batch_size'], args, vocab=vocab, evaluation=True) utils.ensure_dir(args['save_dir']) model_file = build_model_filename(args) logger.info("Using full savename: %s", model_file) # gold path gold_file = args['eval_file'] utils.print_config(args) # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: logger.warning("[Skip training because no training data available...]") return if not args['caseless'] and all_lowercase(train_doc): logger.info("Building a caseless model, as all of the training data is caseless") args['caseless'] = True # start training # train a dictionary-based lemmatizer logger.info("Building lemmatizer in %s", model_file) trainer = Trainer(args=args, vocab=vocab, device=args['device']) logger.info("[Training dictionary-based lemmatizer...]") trainer.train_dict(train_batch.raw_data()) logger.info("Evaluating on dev set...") dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS])) dev_batch.doc.set([LEMMA], dev_preds) system_pred_file = "{:C}\n\n".format(dev_batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, dev_f = scorer.score(system_pred_file, gold_file) logger.info("Dev F1 = {:.2f}".format(dev_f * 100)) if args.get('dict_only', False): # save dictionaries trainer.save(model_file) else: if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_lemmatizer" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('train_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') # train a seq2seq model logger.info("[Training seq2seq-based lemmatizer...]") global_step = 0 max_steps = len(train_batch) * args['num_epoch'] dev_score_history = [] best_dev_preds = [] current_lr = args['lr'] global_start_time = time.time() format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}' # start training for epoch in range(1, args['num_epoch']+1): train_loss = 0 for i, 
batch in enumerate(train_batch): start_time = time.time() global_step += 1 loss = trainer.update(batch, eval=False) # update step train_loss += loss if global_step % args['log_step'] == 0: duration = time.time() - start_time logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step, max_steps, epoch, args['num_epoch'], loss, duration, current_lr)) # eval on dev logger.info("Evaluating on dev set...") dev_preds = [] dev_edits = [] for i, batch in enumerate(dev_batch): preds, edits = trainer.predict(batch, args['beam_size']) dev_preds += preds if edits is not None: dev_edits += edits dev_preds = trainer.postprocess(dev_batch.doc.get([TEXT]), dev_preds, edits=dev_edits) # try ensembling with dict if necessary if args.get('ensemble_dict', False): logger.info("[Ensembling dict with seq2seq model...]") dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds) dev_batch.doc.set([LEMMA], dev_preds) system_pred_file = "{:C}\n\n".format(dev_batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, dev_score = scorer.score(system_pred_file, gold_file) train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score)) if args['wandb']: wandb.log({'train_loss': train_loss, 'dev_score': dev_score}) # save best model if epoch == 1 or dev_score > max(dev_score_history): trainer.save(model_file) logger.info("new best model saved.") best_dev_preds = dev_preds # lr schedule if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1] and \ args['optim'] in ['sgd', 'adagrad']: current_lr *= args['lr_decay'] trainer.update_lr(current_lr) dev_score_history += [dev_score] logger.info("") logger.info("Training ended with {} epochs.".format(epoch)) if args['wandb']: wandb.finish() best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1 logger.info("Best dev F1 = {:.2f}, at epoch 
= {}".format(best_f, best_epoch)) def evaluate(args): # file paths system_pred_file = args['output_file'] model_file = build_model_filename(args) # load model trainer = Trainer(model_file=model_file, device=args['device'], args=args) loaded_args, vocab = trainer.args, trainer.vocab for k in args: if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']: loaded_args[k] = args[k] # load data logger.info("Loading data with batch size {}...".format(args['batch_size'])) doc = CoNLL.conll2doc(input_file=args['eval_file']) batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True) # skip eval if dev data does not exist if len(batch) == 0: logger.warning("Skip evaluation because no dev data is available...\nLemma score:\n{} ".format(args['shorthand'])) return dict_preds = trainer.predict_dict(batch.doc.get([TEXT, UPOS])) if loaded_args.get('dict_only', False): preds = dict_preds else: logger.info("Running the seq2seq model...") preds = [] edits = [] for i, b in enumerate(batch): ps, es = trainer.predict(b, args['beam_size']) preds += ps if es is not None: edits += es preds = trainer.postprocess(batch.doc.get([TEXT]), preds, edits=edits) if loaded_args.get('ensemble_dict', False): logger.info("[Ensembling dict with seq2seq lemmatizer...]") preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds) if trainer.has_contextual_lemmatizers(): preds = trainer.update_contextual_preds(batch.doc, preds) # write to file and score batch.doc.set([LEMMA], preds) if system_pred_file: CoNLL.write_doc2conll(batch.doc, system_pred_file) system_pred_file = "{:C}\n\n".format(batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, score = scorer.score(system_pred_file, args['eval_file']) logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100)) if __name__ == '__main__': main() ================================================ FILE: stanza/models/mwt/__init__.py 
================================================ ================================================ FILE: stanza/models/mwt/character_classifier.py ================================================ """ Classify characters based on an LSTM with learned character representations """ import logging import torch from torch import nn import stanza.models.common.seq2seq_constant as constant logger = logging.getLogger('stanza') class CharacterClassifier(nn.Module): def __init__(self, args): super().__init__() self.vocab_size = args['vocab_size'] self.emb_dim = args['emb_dim'] self.hidden_dim = args['hidden_dim'] self.nlayers = args['num_layers'] # lstm encoder layers self.pad_token = constant.PAD_ID self.enc_hidden_dim = self.hidden_dim // 2 # since it is bidirectional self.num_outputs = 2 self.args = args self.emb_dropout = args.get('emb_dropout', 0.0) self.emb_drop = nn.Dropout(self.emb_dropout) self.dropout = args['dropout'] self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token) self.input_dim = self.emb_dim self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \ bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0) self.output_layer = nn.Sequential( nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, self.num_outputs)) def encode(self, enc_inputs, lens): """ Encode source sequence. 
""" packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True) packed_h_in, (hn, cn) = self.encoder(packed_inputs) return packed_h_in def embed(self, src, src_mask): # the input data could have characters outside the known range # of characters in cases where the vocabulary was temporarily # expanded (note that this model does nothing with those chars) embed_src = src.clone() embed_src[embed_src >= self.vocab_size] = constant.UNK_ID enc_inputs = self.emb_drop(self.embedding(embed_src)) batch_size = enc_inputs.size(0) src_lens = list(src_mask.data.eq(self.pad_token).long().sum(1)) return enc_inputs, batch_size, src_lens, src_mask def forward(self, src, src_mask): enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask) encoded = self.encode(enc_inputs, src_lens) encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True) logits = self.output_layer(encoded) return logits ================================================ FILE: stanza/models/mwt/data.py ================================================ import random import numpy as np import os from collections import Counter, namedtuple import logging import torch from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader as DL import stanza.models.common.seq2seq_constant as constant from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all from stanza.models.common.vocab import DeltaVocab from stanza.models.mwt.vocab import Vocab from stanza.models.common.doc import Document logger = logging.getLogger('stanza') DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text") DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx") # enforce that the MWT splitter knows about a couple different alternate apostrophes # including covering some potential " typos # setting the augmentation to a very low value should be enough to teach it # about the unknown characters without 
messing up the predictions for other text # # 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07 APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', ''') class DataLoader: def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False): self.batch_size = batch_size self.args = args self.augment_apos = args.get('augment_apos', 0.0) self.evaluation = evaluation self.doc = doc data = self.load_doc(self.doc, evaluation=self.evaluation) # handle vocab if vocab is None: assert self.evaluation == False # for eval vocab must exist self.vocab = self.init_vocab(data) if self.augment_apos > 0 and any(x in self.vocab for x in APOS): for apos in APOS: self.vocab.add_unit(apos) elif expand_unk_vocab: self.vocab = DeltaVocab(data, vocab) else: self.vocab = vocab # filter and sample data if args.get('sample_train', 1.0) < 1.0 and not self.evaluation: keep = int(args['sample_train'] * len(data)) data = random.sample(data, keep) logger.debug("Subsample training set with rate {:g}".format(args['sample_train'])) # shuffle for training if not self.evaluation: indices = list(range(len(data))) random.shuffle(indices) data = [data[i] for i in indices] self.data = data self.num_examples = len(data) def init_vocab(self, data): assert self.evaluation == False # for eval vocab must exist vocab = Vocab(data, self.args['shorthand']) return vocab def maybe_augment_apos(self, datum): for original in APOS: if original in datum[0]: if random.uniform(0,1) < self.augment_apos: replacement = random.choice(APOS) datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement)) break return datum def process(self, sample): if not self.evaluation and self.augment_apos > 0: sample = self.maybe_augment_apos(sample) src = list(sample[0]) src = [constant.SOS] + src + [constant.EOS] tgt_in, tgt_out = self.prepare_target(self.vocab, sample) src = self.vocab.map(src) processed = [src, tgt_in, tgt_out, sample[0]] return processed def prepare_target(self, 
vocab, datum): if self.evaluation: tgt = list(datum[0]) # as a placeholder else: tgt = list(datum[1]) tgt_in = vocab.map([constant.SOS] + tgt) tgt_out = vocab.map(tgt + [constant.EOS]) return tgt_in, tgt_out def __len__(self): return len(self.data) def __getitem__(self, key): """ Get a batch with index. """ if not isinstance(key, int): raise TypeError if key < 0 or key >= len(self.data): raise IndexError sample = self.data[key] sample = self.process(sample) assert len(sample) == 4 src = torch.tensor(sample[0]) tgt_in = torch.tensor(sample[1]) tgt_out = torch.tensor(sample[2]) orig_text = sample[3] result = DataSample(src, tgt_in, tgt_out, orig_text), key return result @staticmethod def __collate_fn(data): (data, idx) = zip(*data) (src, tgt_in, tgt_out, orig_text) = zip(*data) # collate_fn is given a list of length batch size batch_size = len(data) # need to sort by length of src to properly handle # the batching in the model itself lens = [len(x) for x in src] (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens) lens = [len(x) for x in src] # convert to tensors src = pad_sequence(src, True, constant.PAD_ID) src_mask = torch.eq(src, constant.PAD_ID) tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID) tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID) assert tgt_in.size(1) == tgt_out.size(1), \ "Target input and output sequence sizes do not match." 
return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx) def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) def to_loader(self): """Converts self to a DataLoader """ batch_size = self.batch_size shuffle = not self.evaluation return DL(self, collate_fn=self.__collate_fn, batch_size=batch_size, shuffle=shuffle) def load_doc(self, doc, evaluation=False): data = doc.get_mwt_expansions(evaluation) if evaluation: data = [[e] for e in data] return data class BinaryDataLoader(DataLoader): """ This version of the DataLoader performs the same tasks as the regular DataLoader, except the targets are arrays of 0/1 indicating if the character is the location of an MWT split """ def prepare_target(self, vocab, datum): src = datum[0] if self.evaluation else datum[1] binary = [0] has_space = False for char in src: if char == ' ': has_space = True elif has_space: has_space = False binary.append(1) else: binary.append(0) binary.append(0) return binary, binary ================================================ FILE: stanza/models/mwt/scorer.py ================================================ """ Utils and wrappers for scoring MWT """ from stanza.models.common.utils import ud_scores def score(system_conllu_file, gold_conllu_file): """ Wrapper for word segmenter scorer. """ evaluation = ud_scores(gold_conllu_file, system_conllu_file) el = evaluation["Words"] p, r, f = el.precision, el.recall, el.f1 return p, r, f ================================================ FILE: stanza/models/mwt/trainer.py ================================================ """ A trainer class to handle training and testing of models. 
""" import sys import numpy as np from collections import Counter import logging import torch from torch import nn import torch.nn.init as init import stanza.models.common.seq2seq_constant as constant from stanza.models.common.trainer import Trainer as BaseTrainer from stanza.models.common.seq2seq_model import Seq2SeqModel from stanza.models.common import utils, loss from stanza.models.mwt.character_classifier import CharacterClassifier from stanza.models.mwt.vocab import Vocab logger = logging.getLogger('stanza') def unpack_batch(batch, device): """ Unpack a batch from the data loader. """ inputs = [b.to(device) if b is not None else None for b in batch[:4]] orig_text = batch[4] orig_idx = batch[5] return inputs, orig_text, orig_idx class Trainer(BaseTrainer): """ A trainer for training models. """ def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None): if model_file is not None: # load from file self.load(model_file) else: self.args = args if args['dict_only']: self.model = None elif args.get('force_exact_pieces', False): self.model = CharacterClassifier(args) else: self.model = Seq2SeqModel(args, emb_matrix=emb_matrix) self.vocab = vocab self.expansion_dict = dict() if not self.args['dict_only']: self.model = self.model.to(device) if self.args.get('force_exact_pieces', False): self.crit = nn.CrossEntropyLoss() else: self.crit = loss.SequenceLoss(self.vocab.size).to(device) self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr']) def update(self, batch, eval=False): device = next(self.model.parameters()).device # ignore the original text when training # can try to learn the correct values, even if we eventually # copy directly from the original text inputs, _, orig_idx = unpack_batch(batch, device) src, src_mask, tgt_in, tgt_out = inputs if eval: self.model.eval() else: self.model.train() self.optimizer.zero_grad() if self.args.get('force_exact_pieces', False): log_probs = self.model(src, src_mask) 
src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1)) packed_output = nn.utils.rnn.pack_padded_sequence(log_probs, src_lens, batch_first=True) packed_tgt = nn.utils.rnn.pack_padded_sequence(tgt_in, src_lens, batch_first=True) loss = self.crit(packed_output.data, packed_tgt.data) else: log_probs, _ = self.model(src, src_mask, tgt_in) loss = self.crit(log_probs.view(-1, self.vocab.size), tgt_out.view(-1)) loss_val = loss.data.item() if eval: return loss_val loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() return loss_val def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None): if vocab is None: vocab = self.vocab device = next(self.model.parameters()).device inputs, orig_text, orig_idx = unpack_batch(batch, device) src, src_mask, tgt, tgt_mask = inputs self.model.eval() batch_size = src.size(0) if self.args.get('force_exact_pieces', False): log_probs = self.model(src, src_mask) cuts = log_probs[:, :, 1] > log_probs[:, :, 0] src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1)) pred_tokens = [] for src_ids, cut, src_len in zip(src, cuts, src_lens): src_chars = vocab.unmap(src_ids) pred_seq = [] for char_idx in range(1, src_len-1): if cut[char_idx]: pred_seq.append(' ') pred_seq.append(src_chars[char_idx]) pred_seq = "".join(pred_seq).strip() pred_tokens.append(pred_seq) else: preds, _ = self.model.predict(src, src_mask, self.args['beam_size'], never_decode_unk=never_decode_unk) pred_seqs = [vocab.unmap(ids) for ids in preds] # unmap to tokens pred_seqs = utils.prune_decoded_seqs(pred_seqs) pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens # if any tokens are predicted to expand to blank, # that is likely an error. 
use the original text # this originally came up with the Spanish model turning 's' into a blank # furthermore, if there are no spaces predicted by the seq2seq, # might as well use the original in case the seq2seq went crazy # this particular error came up training a Hebrew MWT pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)] if unsort: pred_tokens = utils.unsort(pred_tokens, orig_idx) return pred_tokens def train_dict(self, pairs): """ Train a MWT expander given training word-expansion pairs. """ # accumulate counter ctr = Counter() ctr.update([(p[0], p[1]) for p in pairs]) seen = set() # find the most frequent mappings for p, _ in ctr.most_common(): w, l = p if w not in seen and w != l: self.expansion_dict[w] = l seen.add(w) return def dict_expansion(self, word): """ Check the expansion dictionary for the word along with a couple common lowercasings of the word (Leadingcase and UPPERCASE) """ expansion = self.expansion_dict.get(word) if expansion is not None: return expansion if word.isupper(): expansion = self.expansion_dict.get(word.lower()) if expansion is not None: return expansion.upper() if word[0].isupper() and word[1:].islower(): expansion = self.expansion_dict.get(word.lower()) if expansion is not None: return expansion[0].upper() + expansion[1:] # could build a truecasing model of some kind to handle cRaZyCaSe... # but that's probably too much effort return None def predict_dict(self, words): """ Predict a list of expansions given words. """ expansions = [] for w in words: expansion = self.dict_expansion(w) if expansion is not None: expansions.append(expansion) else: expansions.append(w) return expansions def ensemble(self, cands, other_preds): """ Ensemble the dict with statistical model predictions. 
""" expansions = [] assert len(cands) == len(other_preds) for c, pred in zip(cands, other_preds): expansion = self.dict_expansion(c) if expansion is not None: expansions.append(expansion) else: expansions.append(pred) return expansions def save(self, filename): params = { 'model': self.model.state_dict() if self.model is not None else None, 'dict': self.expansion_dict, 'vocab': self.vocab.state_dict(), 'config': self.args } try: torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) except BaseException: logger.warning("Saving failed... continuing anyway.") def load(self, filename): try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] self.expansion_dict = checkpoint['dict'] if not self.args['dict_only']: if self.args.get('force_exact_pieces', False): self.model = CharacterClassifier(self.args) else: self.model = Seq2SeqModel(self.args) # could remove strict=False after rebuilding all models, # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False self.model.load_state_dict(checkpoint['model'], strict=False) else: self.model = None self.vocab = Vocab.load_state_dict(checkpoint['vocab']) ================================================ FILE: stanza/models/mwt/utils.py ================================================ import stanza from stanza.models.common import doc from stanza.models.tokenization.data import TokenizationDataset from stanza.models.tokenization.utils import predict, decode_predictions def mwts_composed_of_words(doc): """ Return True/False if the MWTs in the doc are all exactly composed of the text in their words """ for sent_idx, sentence in enumerate(doc.sentences): for token_idx, token in enumerate(sentence.tokens): if len(token.words) > 1: expected = "".join(x.text for x in token.words) if token.text != expected: 
return False return True def resplit_mwt(tokens, pipeline, keep_tokens=True): """ Uses the tokenize processor and the mwt processor in the pipeline to resplit tokens into MWT tokens: a list of list of string pipeline: a Stanza pipeline which contains, at a minimum, tokenize and mwt keep_tokens: if True, enforce the old token boundaries by modify the results of the tokenize inference. Otherwise, use whatever new boundaries the model comes up with. between running the tokenize model and breaking the text into tokens, we can update all_preds to use the original token boundaries (if and only if keep_tokens == True) This method returns a Document with just the tokens and words annotated. """ if "tokenize" not in pipeline.processors: raise ValueError("Need a Pipeline with a valid tokenize processor") if "mwt" not in pipeline.processors: raise ValueError("Need a Pipeline with a valid mwt processor") tokenize_processor = pipeline.processors["tokenize"] mwt_processor = pipeline.processors["mwt"] fake_text = "\n\n".join(" ".join(sentence) for sentence in tokens) # set up batches batches = TokenizationDataset(tokenize_processor.config, input_text=fake_text, vocab=tokenize_processor.vocab, evaluation=True, dictionary=tokenize_processor.trainer.dictionary) all_preds, all_raw = predict(trainer=tokenize_processor.trainer, data_generator=batches, batch_size=tokenize_processor.trainer.args['batch_size'], max_seqlen=tokenize_processor.config.get('max_seqlen', tokenize_processor.MAX_SEQ_LENGTH_DEFAULT), use_regex_tokens=True, num_workers=tokenize_processor.config.get('num_workers', 0)) if keep_tokens: for sentence, pred in zip(tokens, all_preds): char_idx = 0 for word in sentence: if len(word) > 0: pred[char_idx:char_idx+len(word)-1] = 0 if pred[char_idx+len(word)-1] == 0: pred[char_idx+len(word)-1] = 1 char_idx += len(word) + 1 _, _, document = decode_predictions(vocab=tokenize_processor.vocab, mwt_dict=None, orig_text=fake_text, all_raw=all_raw, all_preds=all_preds, no_ssplit=True, 
skip_newline=tokenize_processor.trainer.args['skip_newline'], use_la_ittb_shorthand=tokenize_processor.trainer.args['shorthand'] == 'la_ittb') document = doc.Document(document, fake_text) mwt_processor.process(document) return document def main(): pipe = stanza.Pipeline("en", processors="tokenize,mwt", package="gum") tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]] doc = resplit_mwt(tokens, pipe) print(doc) doc = resplit_mwt(tokens, pipe, keep_tokens=False) print(doc) if __name__ == '__main__': main() ================================================ FILE: stanza/models/mwt/vocab.py ================================================ from collections import Counter from stanza.models.common.vocab import BaseVocab import stanza.models.common.seq2seq_constant as constant class Vocab(BaseVocab): def build_vocab(self): pairs = self.data allchars = "".join([src + tgt for src, tgt in pairs]) counter = Counter(allchars) self._id2unit = constant.VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True)) self._unit2id = {w:i for i, w in enumerate(self._id2unit)} def add_unit(self, unit): if unit in self._unit2id: return self._unit2id[unit] = len(self._id2unit) self._id2unit.append(unit) ================================================ FILE: stanza/models/mwt_expander.py ================================================ """ Entry point for training and evaluating a multi-word token (MWT) expander. This MWT expander combines a neural sequence-to-sequence architecture with a dictionary to decode the token into multiple words. 
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf In the case of a dataset where all of the MWT exactly split into the words composing the MWT, a classifier over the characters is used instead of the seq2seq """ import io import sys import os import shutil import time from datetime import datetime import argparse import logging import math import numpy as np import random import torch from torch import nn, optim import copy from stanza.models.mwt.data import DataLoader, BinaryDataLoader from stanza.models.mwt.utils import mwts_composed_of_words from stanza.models.mwt.vocab import Vocab from stanza.models.mwt.trainer import Trainer from stanza.models.mwt import scorer from stanza.models.common import utils import stanza.models.common.seq2seq_constant as constant from stanza.models.common.doc import Document from stanza.utils.conll import CoNLL from stanza.models import _training_logging logger = logging.getLogger('stanza') def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/mwt', help='Root dir for saving models.') parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--lang', type=str, help='Language') parser.add_argument('--shorthand', type=str, help="Treebank shorthand") parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. 
By default ensemble a dict.') parser.add_argument('--ensemble_early_stop', action='store_true', help='Early stopping based on ensemble performance.') parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based MWT expander.') parser.add_argument('--hidden_dim', type=int, default=100) parser.add_argument('--emb_dim', type=int, default=50) parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier') parser.add_argument('--emb_dropout', type=float, default=0.5) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--max_dec_len', type=int, default=50) parser.add_argument('--beam_size', type=int, default=1) parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type') parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.') parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|') parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)') parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. 
(By default, this is determined from the dataset.)") parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.') parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') parser.add_argument('--lr_decay', type=float, default=0.9) parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.") parser.add_argument('--num_epoch', type=int, default=30) parser.add_argument('--batch_size', type=int, default=50) parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.') parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.') parser.add_argument('--save_dir', type=str, default='saved_models/mwt', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default=None, help="File name to save the model") parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing") parser.add_argument('--seed', type=int, default=1234) utils.add_device_args(parser) parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. 
Will default to the dataset short name') return parser def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) if args.wandb_name: args.wandb = True return args def main(args=None): args = parse_args(args=args) utils.set_random_seed(args.seed) args = vars(args) logger.info("Running MWT expander in {} mode".format(args['mode'])) if args['mode'] == 'train': return train(args) else: return evaluate(args) def train(args): # load data logger.debug('max_dec_len: %d' % args['max_dec_len']) logger.debug("Loading data with batch size {}...".format(args['batch_size'])) train_doc = CoNLL.conll2doc(input_file=args['train_file']) train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False) vocab = train_batch.vocab args['vocab_size'] = vocab.size dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True) utils.ensure_dir(args['save_dir']) save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand']) model_file = os.path.join(args['save_dir'], save_name) save_each_name = None if args['save_each_name']: save_each_name = os.path.join(args['save_dir'], args['save_each_name']) save_each_name = utils.build_save_each_filename(save_each_name) # pred and gold path gold_file = args['gold_file'] # skip training if the language does not have training or dev data if len(train_batch) == 0: logger.warning("Skip training because no data available...") return dev_mwt = dev_doc.get_mwt_expansions(False) if len(dev_batch) == 0 and args.get('dict_only', False): logger.warning("Training data available, but dev data has no MWTs. 
Only training a dict based MWT") args['dict_only'] = True if args['force_exact_pieces'] and not mwts_composed_of_words(train_doc): raise ValueError("Cannot train model with --force_exact_pieces, as the MWT in this dataset are not entirely composed of their subwords") if args['force_exact_pieces'] is None and mwts_composed_of_words(train_doc): # the force_exact_pieces mechanism trains a separate version of the MWT expander in the Trainer # (the training loop here does not need to change) # in this model, a classifier distinguishes whether or not a location is a split # and the text is copied exactly from the input rather than created via seq2seq # this behavior can be turned off at training time with --no_force_exact_pieces logger.info("Train MWTs entirely composed of their subwords. Training the MWT to match that paradigm as closely as possible") args['force_exact_pieces'] = True if args['force_exact_pieces']: logger.info("Reconverting to BinaryDataLoader") train_batch = BinaryDataLoader(train_doc, args['batch_size'], args, evaluation=False) vocab = train_batch.vocab args['vocab_size'] = vocab.size dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True) if args['num_layers'] is None: if args['force_exact_pieces']: args['num_layers'] = 2 else: args['num_layers'] = 1 # train a dictionary-based MWT expander trainer = Trainer(args=args, vocab=vocab, device=args['device']) logger.info("Training dictionary-based MWT expander...") trainer.train_dict(train_batch.doc.get_mwt_expansions(evaluation=False)) logger.info("Evaluating on dev set...") dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True)) doc = copy.deepcopy(dev_batch.doc) doc.set_mwt_expansions(dev_preds, fake_dependencies=True) system_preds = "{:C}\n\n".format(doc) system_preds = io.StringIO(system_preds) _, _, dev_f = scorer.score(system_preds, gold_file) logger.info("Dev F1 = {:.2f}".format(dev_f * 100)) if args.get('dict_only', False): # save 
dictionaries trainer.save(model_file) else: # train a seq2seq model logger.info("Training seq2seq-based MWT expander...") global_step = 0 steps_per_epoch = math.ceil(len(train_batch) / args['batch_size']) max_steps = steps_per_epoch * args['num_epoch'] dev_score_history = [] best_dev_preds = [] current_lr = args['lr'] global_start_time = time.time() format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}' if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_mwt" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('train_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') # start training for epoch in range(1, args['num_epoch']+1): train_loss = 0 for i, batch in enumerate(train_batch.to_loader()): start_time = time.time() global_step += 1 loss = trainer.update(batch, eval=False) # update step train_loss += loss if global_step % args['log_step'] == 0: duration = time.time() - start_time logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,\ max_steps, epoch, args['num_epoch'], loss, duration, current_lr)) if save_each_name: trainer.save(save_each_name % epoch) logger.info("Saved epoch %d model to %s" % (epoch, save_each_name % epoch)) # eval on dev logger.info("Evaluating on dev set...") dev_preds = [] for i, batch in enumerate(dev_batch.to_loader()): preds = trainer.predict(batch) dev_preds += preds if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False): logger.info("[Ensembling dict with seq2seq model...]") dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds) doc = copy.deepcopy(dev_batch.doc) doc.set_mwt_expansions(dev_preds, fake_dependencies=True) system_preds = "{:C}\n\n".format(doc) system_preds = io.StringIO(system_preds) _, _, dev_score = scorer.score(system_preds, gold_file) train_loss = train_loss / train_batch.num_examples * 
args['batch_size'] # avg loss per batch logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score)) if args['wandb']: wandb.log({'train_loss': train_loss, 'dev_score': dev_score}) # save best model if epoch == 1 or dev_score > max(dev_score_history): trainer.save(model_file) logger.info("new best model saved.") best_dev_preds = dev_preds # lr schedule if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1]: current_lr *= args['lr_decay'] trainer.change_lr(current_lr) dev_score_history += [dev_score] logger.info("Training ended with {} epochs.".format(epoch)) if args['wandb']: wandb.finish() best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1 logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch)) # try ensembling with dict if necessary if args.get('ensemble_dict', False): logger.info("[Ensembling dict with seq2seq model...]") dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds) doc = copy.deepcopy(dev_batch.doc) doc.set_mwt_expansions(dev_preds, fake_dependencies=True) system_preds = "{:C}\n\n".format(doc) system_preds = io.StringIO(system_preds) _, _, dev_score = scorer.score(system_preds, gold_file) logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100)) best_f = max(best_f, dev_score) return trainer, _ def evaluate(args): # file paths system_pred_file = args['output_file'] gold_file = args['gold_file'] model_file = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand']) if args['save_dir'] and not model_file.startswith(args['save_dir']) and not os.path.exists(model_file): model_file = os.path.join(args['save_dir'], model_file) # load model trainer = Trainer(model_file=model_file, device=args['device']) loaded_args, vocab = trainer.args, trainer.vocab for k in args: if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']: loaded_args[k] = args[k] 
logger.debug('max_dec_len: %d' % loaded_args['max_dec_len']) # load data logger.debug("Loading data with batch size {}...".format(args['batch_size'])) doc = CoNLL.conll2doc(input_file=args['eval_file']) batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True) if len(batch) > 0: dict_preds = trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True)) # decide trainer type and run eval if loaded_args['dict_only']: preds = dict_preds else: logger.info("Running the seq2seq model...") preds = [] for i, b in enumerate(batch.to_loader()): preds += trainer.predict(b) if loaded_args.get('ensemble_dict', False): preds = trainer.ensemble(batch.doc.get_mwt_expansions(evaluation=True), preds) else: # skip eval if dev data does not exist preds = [] # write to file and score doc = copy.deepcopy(batch.doc) doc.set_mwt_expansions(preds, fake_dependencies=True) if system_pred_file is not None: CoNLL.write_doc2conll(doc, system_pred_file) else: system_pred_file = "{:C}\n\n".format(doc) system_pred_file = io.StringIO(system_pred_file) if gold_file is not None: _, _, score = scorer.score(system_pred_file, gold_file) logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100)) return trainer, doc if __name__ == '__main__': main() ================================================ FILE: stanza/models/ner/__init__.py ================================================ ================================================ FILE: stanza/models/ner/data.py ================================================ import random import logging import torch from stanza.models.common.bert_embedding import filter_data, needs_length_filter from stanza.models.common.data import map_to_ids, get_long_tensor, sort_all from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX from stanza.models.pos.vocab import CharVocab, CompositeVocab, WordVocab from stanza.models.ner.vocab import MultiVocab from stanza.models.common.doc import * from stanza.models.ner.utils 
import process_tags, normalize_empty_tags logger = logging.getLogger('stanza') class DataLoader: def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=False, preprocess_tags=True, bert_tokenizer=None, scheme=None, max_batch_words=None): self.max_batch_words = max_batch_words self.batch_size = batch_size self.args = args self.eval = evaluation self.shuffled = not self.eval self.doc = doc self.preprocess_tags = preprocess_tags data = self._load_doc(self.doc, scheme) # filter out the long sentences if bert is used if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']): data = filter_data(self.args['bert_model'], data, bert_tokenizer) self.tags = [[w[1] for w in sent] for sent in data] # handle vocab self.pretrain = pretrain if vocab is None: self.vocab = self.init_vocab(data) else: self.vocab = vocab # filter and sample data if args.get('sample_train', 1.0) < 1.0 and not self.eval: keep = int(args['sample_train'] * len(data)) data = random.sample(data, keep) logger.debug("Subsample training set with rate {:g}".format(args['sample_train'])) data = self.preprocess(data, self.vocab, args) # shuffle for training if self.shuffled: random.shuffle(data) self.num_examples = len(data) # chunk into batches self.data = self.chunk_batches(data) logger.debug("{} batches created.".format(len(self.data))) def init_vocab(self, data): def from_model(model_filename): """ Try loading vocab from charLM model file. 
""" state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True) if 'vocab' in state_dict: return state_dict['vocab'] if 'model' in state_dict and 'vocab' in state_dict['model']: return state_dict['model']['vocab'] raise ValueError("Cannot find vocab in charLM model file %s" % model_filename) if self.eval: raise AssertionError("Vocab must exist for evaluation.") if self.args['charlm']: charvocab = CharVocab.load_state_dict(from_model(self.args['charlm_forward_file'])) else: charvocab = CharVocab(data, self.args['shorthand']) wordvocab = self.pretrain.vocab if self.pretrain is not None else None tag_data = [[(x[1],) for x in sentence] for sentence in data] tagvocab = CompositeVocab(tag_data, self.args['shorthand'], idx=0, sep=None) ignore = None if self.args['emb_finetune_known_only']: if self.pretrain is None: raise ValueError("Cannot train emb_finetune_known_only with no pretrain of known words") if self.args['lowercase']: ignore = set([w[0].lower() for sent in data for w in sent if w[0] not in wordvocab and w[0].lower() not in wordvocab]) else: ignore = set([w[0] for sent in data for w in sent if w[0] not in wordvocab]) logger.debug("Ignoring %d in the delta vocab as they did not appear in the original embedding", len(ignore)) deltavocab = WordVocab(data, self.args['shorthand'], cutoff=1, lower=self.args['lowercase'], ignore=ignore) logger.debug("Creating delta vocab of size %s", len(deltavocab)) vocabs = {'char': charvocab, 'delta': deltavocab, 'tag': tagvocab} if wordvocab is not None: vocabs['word'] = wordvocab vocab = MultiVocab(vocabs) return vocab def preprocess(self, data, vocab, args): processed = [] if args.get('char_lowercase', False): # handle character case char_case = lambda x: x.lower() else: char_case = lambda x: x for sent_idx, sent in enumerate(data): processed_sent = [[w[0] for w in sent]] processed_sent += [[vocab['char'].map([char_case(x) for x in w[0]]) for w in sent]] processed_sent += [vocab['tag'].map([w[1] for 
w in sent])] processed.append(processed_sent) return processed def __len__(self): return len(self.data) def __getitem__(self, key): """ Get a batch with index. """ if not isinstance(key, int): raise TypeError if key < 0 or key >= len(self.data): raise IndexError batch = self.data[key] batch_size = len(batch) batch = list(zip(*batch)) assert len(batch) == 3 # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[List[int]]] # sort sentences by lens for easy RNN operations sentlens = [len(x) for x in batch[0]] batch, orig_idx = sort_all(batch, sentlens) sentlens = [len(x) for x in batch[0]] # sort chars by lens for easy char-LM operations chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1]) chars_sorted, char_orig_idx = sort_all([chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens) chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted charlens = [len(sent) for sent in chars_forward] # sort words by lens for easy char-RNN operations batch_words = [w for sent in batch[1] for w in sent] wordlens = [len(x) for x in batch_words] batch_words, word_orig_idx = sort_all([batch_words], wordlens) batch_words = batch_words[0] wordlens = [len(x) for x in batch_words] words = batch[0] wordchars = get_long_tensor(batch_words, len(wordlens)) wordchars_mask = torch.eq(wordchars, PAD_ID) chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' ')) chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' ')) chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)]) # padded forward and backward char idx charoffsets = [charoffsets_forward, charoffsets_backward] # idx for forward and backward lm to get word representation tags = get_long_tensor(batch[2], batch_size) return words, wordchars, wordchars_mask, chars, tags, orig_idx, word_orig_idx, 
char_orig_idx, sentlens, wordlens, charlens, charoffsets def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) def _load_doc(self, doc, scheme): # preferentially load the MULTI_NER in case we are training / # testing a model with multiple layers of tags data = doc.get([TEXT, NER, MULTI_NER], as_sentences=True, from_token=True) data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)] for token in sentence] for sentence in data] if self.preprocess_tags: # preprocess tags if scheme is None: data = process_tags(data, self.args.get('scheme', 'bio')) data = normalize_empty_tags(data) return data def process_chars(self, sents): start_id, end_id = self.vocab['char'].unit2id('\n'), self.vocab['char'].unit2id(' ') # special token start_offset, end_offset = 1, 1 chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = [], [], [], [] # get char representation for each sentence for sent in sents: chars_forward_sent, chars_backward_sent, charoffsets_forward_sent, charoffsets_backward_sent = [start_id], [start_id], [], [] # forward lm for word in sent: chars_forward_sent += word charoffsets_forward_sent = charoffsets_forward_sent + [len(chars_forward_sent)] # add each token offset in the last for forward lm chars_forward_sent += [end_id] # backward lm for word in sent[::-1]: chars_backward_sent += word[::-1] charoffsets_backward_sent = [len(chars_backward_sent)] + charoffsets_backward_sent # add each offset in the first for backward lm chars_backward_sent += [end_id] # store each sentence chars_forward.append(chars_forward_sent) chars_backward.append(chars_backward_sent) charoffsets_forward.append(charoffsets_forward_sent) charoffsets_backward.append(charoffsets_backward_sent) charlens = [len(sent) for sent in chars_forward] # forward lm and backward lm should have the same lengths return chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens def reshuffle(self): data = [y for x in self.data for y 
in x] random.shuffle(data) self.data = self.chunk_batches(data) def chunk_batches(self, data): if self.max_batch_words is None: return [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)] batches = [] next_batch = [] for item in data: next_batch.append(item) if len(next_batch) >= self.batch_size: batches.append(next_batch) next_batch = [] if sum(len(x[0]) for x in next_batch) >= self.max_batch_words: batches.append(next_batch) next_batch = [] if len(next_batch) > 0: batches.append(next_batch) return batches ================================================ FILE: stanza/models/ner/model.py ================================================ import os import logging import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence from stanza.models.common.data import map_to_ids, get_long_tensor from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError from stanza.models.common.packed_lstm import PackedLSTM from stanza.models.common.dropout import WordDropout, LockedDropout from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel from stanza.models.common.crf import CRFLoss from stanza.models.common.foundation_cache import load_bert from stanza.models.common.utils import attach_bert_model from stanza.models.common.vocab import PAD_ID, UNK_ID, EMPTY_ID from stanza.models.common.bert_embedding import extract_bert_embeddings logger = logging.getLogger('stanza') # this gets created in two places in trainer # in both places, pass in the bert model & tokenizer class NERTagger(nn.Module): def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None): super().__init__() self.vocab = vocab self.args = args self.unsaved_modules = [] # input layers input_size = 0 if 
self.args['word_emb_dim'] > 0: emb_finetune = self.args.get('emb_finetune', True) if 'word' in self.vocab: # load pretrained embeddings if specified word_emb = nn.Embedding(len(self.vocab['word']), self.args['word_emb_dim'], PAD_ID) # if a model trained with no 'delta' vocab is loaded, and # emb_finetune is off, any resaving of the model will need # the updated vectors. this is accounted for in load() if not emb_finetune or 'delta' in self.vocab: # if emb_finetune is off # or if the delta embedding is present # then we won't fine tune the original embedding self.add_unsaved_module('word_emb', word_emb) self.word_emb.weight.detach_() else: self.word_emb = word_emb if emb_matrix is not None: self.init_emb(emb_matrix) # TODO: allow for expansion of delta embedding if new # training data has new words in it? self.delta_emb = None if 'delta' in self.vocab: # zero inits seems to work better # note that the gradient will flow to the bottom and then adjust the 0 weights # as opposed to a 0 matrix cutting off the gradient if higher up in the model self.delta_emb = nn.Embedding(len(self.vocab['delta']), self.args['word_emb_dim'], PAD_ID) nn.init.zeros_(self.delta_emb.weight) # if the model was trained with a delta embedding, but emb_finetune is off now, # then we will detach the delta embedding if not emb_finetune: self.delta_emb.weight.detach_() input_size += self.args['word_emb_dim'] self.peft_name = peft_name attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved) if self.args.get('bert_model', None): # TODO: refactor bert_hidden_layers between the different models if args.get('bert_hidden_layers', False): # The average will be offset by 1/N so that the default zeros # represents an average of the N layers self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False) nn.init.zeros_(self.bert_layer_mix.weight) else: # an average of layers 2, 3, 4 will be used # (for historic reasons) self.bert_layer_mix = None 
input_size += self.bert_model.config.hidden_size if self.args['char'] and self.args['char_emb_dim'] > 0: if self.args['charlm']: if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']): raise ForwardCharlmNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']), args['charlm_forward_file']) if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']): raise BackwardCharlmNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']), args['charlm_backward_file']) self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False)) self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False)) input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim() else: self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False) input_size += self.args['char_hidden_dim'] * 2 # optionally add a input transformation layer if self.args.get('input_transform', False): self.input_transform = nn.Linear(input_size, input_size) else: self.input_transform = None # recurrent layers self.taggerlstm = PackedLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, \ bidirectional=True, dropout=0 if self.args['num_layers'] == 1 else self.args['dropout']) # self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size)) self.drop_replacement = None self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False) self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False) # tag classifier tag_lengths = self.vocab['tag'].lens() 
self.num_output_layers = len(tag_lengths) if self.args.get('connect_output_layers'): tag_clfs = [nn.Linear(self.args['hidden_dim']*2, tag_lengths[0])] for prev_length, next_length in zip(tag_lengths[:-1], tag_lengths[1:]): tag_clfs.append(nn.Linear(self.args['hidden_dim']*2 + prev_length, next_length)) self.tag_clfs = nn.ModuleList(tag_clfs) else: self.tag_clfs = nn.ModuleList([nn.Linear(self.args['hidden_dim']*2, num_tag) for num_tag in tag_lengths]) for tag_clf in self.tag_clfs: tag_clf.bias.data.zero_() self.crits = nn.ModuleList([CRFLoss(num_tag) for num_tag in tag_lengths]) self.drop = nn.Dropout(args['dropout']) self.worddrop = WordDropout(args['word_dropout']) self.lockeddrop = LockedDropout(args['locked_dropout']) def init_emb(self, emb_matrix): if isinstance(emb_matrix, np.ndarray): emb_matrix = torch.from_numpy(emb_matrix) vocab_size = len(self.vocab['word']) dim = self.args['word_emb_dim'] assert emb_matrix.size() == (vocab_size, dim), \ "Input embedding matrix must match size: {} x {}, found {}".format(vocab_size, dim, emb_matrix.size()) self.word_emb.weight.data.copy_(emb_matrix) def add_unsaved_module(self, name, module): self.unsaved_modules += [name] setattr(self, name, module) def log_norms(self): lines = ["NORMS FOR MODEL PARAMTERS"] for name, param in self.named_parameters(): if param.requires_grad and name.split(".")[0] not in ('charmodel_forward', 'charmodel_backward'): lines.append(" %s %.6g" % (name, torch.norm(param).item())) logger.info("\n".join(lines)) def forward(self, sentences, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx): device = next(self.parameters()).device def pack(x): return pack_padded_sequence(x, sentlens, batch_first=True) inputs = [] batch_size = len(sentences) has_embedding = False if self.args['word_emb_dim'] > 0: #extract static embeddings if 'word' in self.vocab: static_words, word_mask = self.extract_static_embeddings(self.args, sentences, 
self.vocab['word']) word_mask = word_mask.to(device) static_words = static_words.to(device) word_static_emb = self.word_emb(static_words) has_embedding = True if 'delta' in self.vocab and self.delta_emb is not None: # masks should be the same delta_words, delta_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['delta']) delta_words = delta_words.to(device) # unclear whether to treat words in the main embedding # but not in delta as unknown # simple heuristic though - treating them as not # unknown keeps existing models the same when # separating models into the base WV and delta WV # also, note that at training time, words like this # did not show up in the training data, but are # not exactly UNK, so it makes sense if has_embedding: delta_unk_mask = torch.eq(delta_words, UNK_ID) static_unk_mask = torch.not_equal(static_words, UNK_ID) unk_mask = delta_unk_mask * static_unk_mask delta_words[unk_mask] = PAD_ID else: word_mask = delta_mask.to(device) delta_emb = self.delta_emb(delta_words) if has_embedding: word_static_emb = word_static_emb + delta_emb else: has_embedding = True word_static_emb = delta_emb if has_embedding: word_emb = pack(word_static_emb) inputs += [word_emb] if self.bert_model is not None: device = next(self.parameters()).device processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, sentences, device, keep_endpoints=False, num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None, detach=not self.args.get('bert_finetune', False), peft_name=self.peft_name) if self.bert_layer_mix is not None: # use a linear layer to weighted average the embedding dynamically processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert] processed_bert = pad_sequence(processed_bert, batch_first=True) inputs += [pack(processed_bert)] if self.args['char'] and self.args['char_emb_dim'] > 0: if 
self.args.get('charlm', None): char_reps_forward = self.charmodel_forward.get_representation(chars[0], charoffsets[0], charlens, char_orig_idx) char_reps_forward = PackedSequence(char_reps_forward.data, char_reps_forward.batch_sizes) char_reps_backward = self.charmodel_backward.get_representation(chars[1], charoffsets[1], charlens, char_orig_idx) char_reps_backward = PackedSequence(char_reps_backward.data, char_reps_backward.batch_sizes) inputs += [char_reps_forward, char_reps_backward] else: char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens) char_reps = PackedSequence(char_reps.data, char_reps.batch_sizes) inputs += [char_reps] batch_sizes = inputs[0].batch_sizes def pad(x): return pad_packed_sequence(PackedSequence(x, batch_sizes), batch_first=True)[0] lstm_inputs = torch.cat([x.data for x in inputs], 1) if self.args['word_dropout'] > 0: lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement) lstm_inputs = self.drop(lstm_inputs) lstm_inputs = pad(lstm_inputs) lstm_inputs = self.lockeddrop(lstm_inputs) lstm_inputs = pack(lstm_inputs).data if self.input_transform: lstm_inputs = self.input_transform(lstm_inputs) lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes) lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(\ self.taggerlstm_h_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous(), \ self.taggerlstm_c_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous())) lstm_outputs = lstm_outputs.data # prediction layer lstm_outputs = self.drop(lstm_outputs) lstm_outputs = pad(lstm_outputs) lstm_outputs = self.lockeddrop(lstm_outputs) lstm_outputs = pack(lstm_outputs).data loss = 0 logits = [] trans = [] for idx, (tag_clf, crit) in enumerate(zip(self.tag_clfs, self.crits)): if not self.args.get('connect_output_layers') or idx == 0: next_logits = pad(tag_clf(lstm_outputs)).contiguous() else: # here we pack the output of the previous 
round, then append it packed_logits = pack(next_logits).data input_logits = torch.cat([lstm_outputs, packed_logits], axis=1) next_logits = pad(tag_clf(input_logits)).contiguous() # the tag_mask lets us avoid backprop on a blank tag tag_mask = torch.eq(tags[:, :, idx], EMPTY_ID) if has_embedding: tag_mask = torch.bitwise_or(tag_mask, word_mask) else: tag_mask = torch.bitwise_or(tag_mask, torch.eq(tags[:, :, idx], PAD_ID)) next_loss, next_trans = crit(next_logits, tag_mask, tags[:, :, idx]) loss = loss + next_loss logits.append(next_logits) trans.append(next_trans) return loss, logits, trans @staticmethod def extract_static_embeddings(args, sents, vocab): processed = [] if args.get('lowercase', True): # handle word case case = lambda x: x.lower() else: case = lambda x: x for idx, sent in enumerate(sents): processed_sent = [vocab.map([case(w) for w in sent])] processed.append(processed_sent[0]) words = get_long_tensor(processed, len(sents)) words_mask = torch.eq(words, PAD_ID) return words, words_mask ================================================ FILE: stanza/models/ner/scorer.py ================================================ """ An NER scorer that calculates F1 score given gold and predicted tags. """ import sys import os import logging from collections import Counter, defaultdict from stanza.models.ner.utils import decode_from_bioes logger = logging.getLogger('stanza') def score_by_entity(pred_tag_sequences, gold_tag_sequences, verbose=True, ignore_tags=None): """ Score predicted tags at the entity level. Args: pred_tags_sequences: a list of list of predicted tags for each word gold_tags_sequences: a list of list of gold tags for each word verbose: print log with results ignore_tags: a list or a string with a comma-separated list of tags to ignore Returns: Precision, recall and F1 scores. """ assert(len(gold_tag_sequences) == len(pred_tag_sequences)), \ "Number of predicted tag sequences does not match gold sequences." 
def decode_all(tag_sequences): # decode from all sequences, each sequence with a unique id ents = [] for sent_id, tags in enumerate(tag_sequences): for ent in decode_from_bioes(tags): ent['sent_id'] = sent_id ents += [ent] return ents ignore_tag_set = set() if ignore_tags: if isinstance(ignore_tags, str): ignore_tag_set.update(ignore_tags.split(",")) else: ignore_tag_set.update(ignore_tags) gold_ents = decode_all(gold_tag_sequences) gold_ents = [x for x in gold_ents if x['type'] not in ignore_tag_set] pred_ents = decode_all(pred_tag_sequences) pred_ents = [x for x in pred_ents if x['type'] not in ignore_tag_set] # scoring true_positive_by_type = Counter() false_positive_by_type = Counter() false_negative_by_type = Counter() guessed_by_type = Counter() gold_by_type = Counter() for p in pred_ents: guessed_by_type[p['type']] += 1 if p in gold_ents: true_positive_by_type[p['type']] += 1 else: false_positive_by_type[p['type']] += 1 for g in gold_ents: gold_by_type[g['type']] += 1 if g not in pred_ents: false_negative_by_type[g['type']] += 1 entities = sorted(set(list(true_positive_by_type.keys()) + list(false_positive_by_type.keys()) + list(false_negative_by_type.keys()))) entity_f1 = {} for entity in entities: entity_f1[entity] = 2 * true_positive_by_type[entity] / (2 * true_positive_by_type[entity] + false_positive_by_type[entity] + false_negative_by_type[entity]) prec_micro = 0.0 if sum(guessed_by_type.values()) > 0: prec_micro = sum(true_positive_by_type.values()) * 1.0 / sum(guessed_by_type.values()) rec_micro = 0.0 if sum(gold_by_type.values()) > 0: rec_micro = sum(true_positive_by_type.values()) * 1.0 / sum(gold_by_type.values()) f_micro = 0.0 if prec_micro + rec_micro > 0: f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro) if verbose: logger.info("Score by entity:\nPrec.\tRec.\tF1\n{:.2f}\t{:.2f}\t{:.2f}".format( prec_micro*100, rec_micro*100, f_micro*100)) return prec_micro, rec_micro, f_micro, entity_f1 def score_by_token(pred_tag_sequences, 
gold_tag_sequences, verbose=True, ignore_tags=None): """ Score predicted tags at the token level. Args: pred_tags_sequences: a list of list of predicted tags for each word gold_tags_sequences: a list of list of gold tags for each word verbose: print log with results ignore_tags: a list or a string with a comma-separated list of tags to ignore Returns: Precision, recall and F1 scores, along with a confusion matrix """ assert(len(gold_tag_sequences) == len(pred_tag_sequences)), \ "Number of predicted tag sequences does not match gold sequences." ignore_tag_set = set() if ignore_tags: if isinstance(ignore_tags, str): ignore_tag_set.update(ignore_tags.split(",")) else: ignore_tag_set.update(ignore_tags) def ignore_tag(tag): if tag == 'O': return True if len(tag) > 2 and (tag[1] == '_' or tag[1] == '-'): tag = tag[2:] if tag in ignore_tag_set: return True return False correct_by_tag = Counter() guessed_by_tag = Counter() gold_by_tag = Counter() confusion = defaultdict(lambda: defaultdict(int)) for gold_tags, pred_tags in zip(gold_tag_sequences, pred_tag_sequences): assert(len(gold_tags) == len(pred_tags)), \ "Number of predicted tags does not match gold." 
for g, p in zip(gold_tags, pred_tags): if ignore_tag(g): g = 'O' if ignore_tag(p): p = 'O' confusion[g][p] = confusion[g][p] + 1 if g == 'O' and p == 'O': continue elif g == 'O' and p != 'O': guessed_by_tag[p] += 1 elif g != 'O' and p == 'O': gold_by_tag[g] += 1 else: guessed_by_tag[p] += 1 gold_by_tag[p] += 1 if g == p: correct_by_tag[p] += 1 prec_micro = 0.0 if sum(guessed_by_tag.values()) > 0: prec_micro = sum(correct_by_tag.values()) * 1.0 / sum(guessed_by_tag.values()) rec_micro = 0.0 if sum(gold_by_tag.values()) > 0: rec_micro = sum(correct_by_tag.values()) * 1.0 / sum(gold_by_tag.values()) f_micro = 0.0 if prec_micro + rec_micro > 0: f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro) if verbose: logger.info("Score by token:\nPrec.\tRec.\tF1\n{:.2f}\t{:.2f}\t{:.2f}".format( prec_micro*100, rec_micro*100, f_micro*100)) return prec_micro, rec_micro, f_micro, confusion def test(): pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'], ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']] gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'], ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']] print(score_by_token(pred_sequences, gold_sequences)) print(score_by_entity(pred_sequences, gold_sequences)) if __name__ == '__main__': test() ================================================ FILE: stanza/models/ner/trainer.py ================================================ """ A trainer class to handle training and testing of models. 
""" import sys import logging import torch from torch import nn from stanza.models.common.foundation_cache import NoTransformerFoundationCache, load_bert, load_bert_with_peft from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper from stanza.models.common.trainer import Trainer as BaseTrainer from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE from stanza.models.common import utils, loss from stanza.models.ner.model import NERTagger from stanza.models.ner.vocab import MultiVocab from stanza.models.common.crf import viterbi_decode logger = logging.getLogger('stanza') def unpack_batch(batch, device): """ Unpack a batch from the data loader. """ inputs = [batch[0]] inputs += [b.to(device) if b is not None else None for b in batch[1:5]] orig_idx = batch[5] word_orig_idx = batch[6] char_orig_idx = batch[7] sentlens = batch[8] wordlens = batch[9] charlens = batch[10] charoffsets = batch[11] return inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets def fix_singleton_tags(tags): """ If there are any singleton B- or E- tags, convert them to S- """ new_tags = list(tags) # first update all I- tags at the start or end of sequence to B- or E- as appropriate for idx, tag in enumerate(new_tags): if (tag.startswith("I-") and (idx == len(new_tags) - 1 or (new_tags[idx+1] != "I-" + tag[2:] and new_tags[idx+1] != "E-" + tag[2:]))): new_tags[idx] = "E-" + tag[2:] if (tag.startswith("I-") and (idx == 0 or (new_tags[idx-1] != "B-" + tag[2:] and new_tags[idx-1] != "I-" + tag[2:]))): new_tags[idx] = "B-" + tag[2:] # now make another pass through the data to update any singleton tags, # including ones which were turned into singletons by the previous operation for idx, tag in enumerate(new_tags): if (tag.startswith("B-") and (idx == len(new_tags) - 1 or (new_tags[idx+1] != "I-" + tag[2:] and new_tags[idx+1] != "E-" + tag[2:]))): new_tags[idx] = "S-" + tag[2:] if (tag.startswith("E-") and (idx == 0 or 
(new_tags[idx-1] != "B-" + tag[2:] and new_tags[idx-1] != "I-" + tag[2:]))): new_tags[idx] = "S-" + tag[2:] return new_tags class Trainer(BaseTrainer): """ A trainer for training models. """ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, train_classifier_only=False, foundation_cache=None, second_optim=False): if model_file is not None: # load everything from file self.load(model_file, pretrain, args, foundation_cache) else: assert args is not None assert vocab is not None # build model from scratch self.args = args self.vocab = vocab bert_model, bert_tokenizer = load_bert(self.args['bert_model']) peft_name = None if self.args['use_peft']: # fine tune the bert if we're using peft self.args['bert_finetune'] = True peft_name = "ner" # peft the lovely model bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name) emb_matrix=None if pretrain is not None: emb_matrix = pretrain.emb self.model = NERTagger(args, vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name) # IMPORTANT: gradient checkpointing BREAKS peft if applied before # 1. Apply PEFT FIRST (looksie! it's above this line) # 2. 
Run gradient checkpointing # https://github.com/huggingface/peft/issues/742 if self.args.get("gradient_checkpointing", False) and self.args.get("bert_finetune", False): self.model.bert_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) # if this wasn't set anywhere, we use a default of the 0th tagset # we don't set this as a default in the options so that # we can distinguish "intentionally set to 0" and "not set at all" if self.args.get('predict_tagset', None) is None: self.args['predict_tagset'] = 0 if train_classifier_only: logger.info('Disabling gradient for non-classifier layers') exclude = ['tag_clf', 'crit'] for pname, p in self.model.named_parameters(): if pname.split('.')[0] not in exclude: p.requires_grad = False self.model = self.model.to(device) if not second_optim: self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("use_peft")) else: self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model, self.args['second_lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), is_peft=self.args.get("use_peft")) def update(self, batch, eval=False): device = next(self.model.parameters()).device inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device) word, wordchars, wordchars_mask, chars, tags = inputs if eval: self.model.eval() else: self.model.train() self.optimizer.zero_grad() loss, _, _ = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx) loss_val = loss.data.item() if eval: return loss_val loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() return loss_val def predict(self, batch, unsort=True): 
device = next(self.model.parameters()).device inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device) word, wordchars, wordchars_mask, chars, tags = inputs self.model.eval() #batch_size = word.size(0) _, logits, trans = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx) # decode # TODO: might need to decode multiple columns of output for # models with multiple layers trans = [x.data.cpu().numpy() for x in trans] logits = [x.data.cpu().numpy() for x in logits] batch_size = logits[0].shape[0] if any(x.shape[0] != batch_size for x in logits): raise AssertionError("Expected all of the logits to have the same size") tag_seqs = [] predict_tagset = self.args['predict_tagset'] for i in range(batch_size): # for each tag column in the output, decode the tag assignments tags = [viterbi_decode(x[i, :sentlens[i]], y)[0] for x, y in zip(logits, trans)] # TODO: this is to patch that the model can sometimes predict < "O" tags = [[x if x >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for x in y] for y in tags] # that gives us N lists of |sent| tags, whereas we want |sent| lists of N tags tags = list(zip(*tags)) # now unmap that to the tags in the vocab tags = self.vocab['tag'].unmap(tags) # for now, allow either TagVocab or CompositeVocab # TODO: we might want to return all of the predictions # rather than a single column tags = [x[predict_tagset] if isinstance(x, list) else x for x in tags] tags = fix_singleton_tags(tags) tag_seqs += [tags] if unsort: tag_seqs = utils.unsort(tag_seqs, orig_idx) return tag_seqs def save(self, filename, skip_modules=True): model_state = self.model.state_dict() # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file if skip_modules: skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] for k in skipped: del model_state[k] 
params = { 'model': model_state, 'vocab': self.vocab.state_dict(), 'config': self.args } if self.args["use_peft"]: from peft import get_peft_model_state_dict params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name) try: torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) except (KeyboardInterrupt, SystemExit): raise except: logger.warning("Saving failed... continuing anyway.") def load(self, filename, pretrain=None, args=None, foundation_cache=None): try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] if args: self.args.update(args) # if predict_tagset was not explicitly set in the args, # we use the value the model was trained with for keep_arg in ('predict_tagset', 'train_scheme', 'scheme'): if self.args.get(keep_arg, None) is None: self.args[keep_arg] = checkpoint['config'].get(keep_arg, None) lora_weights = checkpoint.get('bert_lora') if lora_weights: logger.debug("Found peft weights for NER; loading a peft adapter") self.args["use_peft"] = True self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) emb_matrix=None if pretrain is not None: emb_matrix = pretrain.emb force_bert_saved = False peft_name = None if self.args.get('use_peft', False): force_bert_saved = True bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "ner", foundation_cache) bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name) logger.debug("Loaded peft with name %s", peft_name) else: if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()): logger.debug("Model %s has a finetuned transformer. 
Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename) foundation_cache = NoTransformerFoundationCache(foundation_cache) force_bert_saved = True bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache) if any(x.startswith("crit.") for x in checkpoint['model'].keys()): logger.debug("Old model format detected. Updating to the new format with one column of tags") checkpoint['model']['crits.0._transitions'] = checkpoint['model'].pop('crit._transitions') checkpoint['model']['tag_clfs.0.weight'] = checkpoint['model'].pop('tag_clf.weight') checkpoint['model']['tag_clfs.0.bias'] = checkpoint['model'].pop('tag_clf.bias') self.model = NERTagger(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name) self.model.load_state_dict(checkpoint['model'], strict=False) # there is a possible issue with the delta embeddings. 
# specifically, with older models trained without the delta # embedding matrix # if those models have been trained with the embedding # modifications saved as part of the base embedding, # we need to resave the model with the updated embedding # otherwise the resulting model will be broken if 'delta' not in self.model.vocab and 'word_emb.weight' in checkpoint['model'].keys() and 'word_emb' in self.model.unsaved_modules: logger.debug("Removing word_emb from unsaved_modules so that resaving %s will keep the saved embedding", filename) self.model.unsaved_modules.remove('word_emb') def get_known_tags(self): """ Return the tags known by this model Removes the S-, B-, etc, and does not include O """ tags = set() for tag in self.vocab['tag'].items(0): if tag in VOCAB_PREFIX: continue if tag == 'O': continue if len(tag) > 2 and tag[:2] in ('S-', 'B-', 'I-', 'E-'): tag = tag[2:] tags.add(tag) return sorted(tags) ================================================ FILE: stanza/models/ner/utils.py ================================================ """ Utility functions for dealing with NER tagging. """ import logging from stanza.models.common.vocab import EMPTY logger = logging.getLogger('stanza') EMPTY_TAG = ('_', '-', '', None) EMPTY_OR_O_TAG = tuple(list(EMPTY_TAG) + ['O']) def is_basic_scheme(all_tags): """ Check if a basic tagging scheme is used. Return True if so. Args: all_tags: a list of NER tags Returns: True if the tagging scheme does not use B-, I-, etc, otherwise False """ for tag in all_tags: if len(tag) > 2 and tag[:2] in ('B-', 'I-', 'S-', 'E-', 'B_', 'I_', 'S_', 'E_'): return False return True def is_bio_scheme(all_tags): """ Check if BIO tagging scheme is used. Return True if so. 
Args: all_tags: a list of NER tags Returns: True if the tagging scheme is BIO, otherwise False """ for tag in all_tags: if tag in EMPTY_OR_O_TAG: continue elif len(tag) > 2 and tag[:2] in ('B-', 'I-', 'B_', 'I_'): continue else: return False return True def to_bio2(tags): """ Convert the original tag sequence to BIO2 format. If the input is already in BIO2 format, the original input is returned. Args: tags: a list of tags in either BIO or BIO2 format Returns: new_tags: a list of tags in BIO2 format """ new_tags = [] for i, tag in enumerate(tags): if tag in EMPTY_OR_O_TAG: new_tags.append(tag) elif tag[0] == 'I': if i == 0 or tags[i-1] == 'O' or tags[i-1][1:] != tag[1:]: new_tags.append('B' + tag[1:]) else: new_tags.append(tag) else: new_tags.append(tag) return new_tags def basic_to_bio(tags): """ Convert a basic tag sequence into a BIO sequence. You can compose this with bio2_to_bioes to convert to bioes Args: tags: a list of tags in basic (no B-, I-, etc) format Returns: new_tags: a list of tags in BIO format """ new_tags = [] for i, tag in enumerate(tags): if tag in EMPTY_OR_O_TAG: new_tags.append(tag) elif i == 0 or tags[i-1] == 'O' or tags[i-1] != tag: new_tags.append('B-' + tag) else: new_tags.append('I-' + tag) return new_tags def bio2_to_bioes(tags): """ Convert the BIO2 tag sequence into a BIOES sequence. 
Args: tags: a list of tags in BIO2 format Returns: new_tags: a list of tags in BIOES format """ new_tags = [] for i, tag in enumerate(tags): if tag in EMPTY_OR_O_TAG: new_tags.append(tag) else: if len(tag) < 2: raise Exception(f"Invalid BIO2 tag found: {tag}") else: if tag[:2] in ('I-', 'I_'): # convert to E- if next tag is not I- if i+1 < len(tags) and tags[i+1][:2] in ('I-', 'I_'): new_tags.append('I-' + tag[2:]) # compensate for underscores else: new_tags.append('E-' + tag[2:]) elif tag[:2] in ('B-', 'B_'): # convert to S- if next tag is not I- if i+1 < len(tags) and tags[i+1][:2] in ('I-', 'I_'): new_tags.append('B-' + tag[2:]) else: new_tags.append('S-' + tag[2:]) else: raise Exception(f"Invalid IOB tag found: {tag}") return new_tags def normalize_empty_tags(sentences): """ If any tags are None, _, -, or blank, turn them into EMPTY The input should be a list(sentence) of list(word) of tuple(text, list(tag)) which is the typical format for the data at the time data.py is preprocessing the tags """ new_sentences = [[(word[0], tuple(EMPTY if x in EMPTY_TAG else x for x in word[1])) for word in sentence] for sentence in sentences] return new_sentences def process_tags(sentences, scheme): """ Convert tags in these sentences to bioes We allow empty tags ('_', '-', None), which will represent tags that do not get any gradient when training """ all_words = [] all_tags = [] converted_tuples = False for sent_idx, sent in enumerate(sentences): words, tags = zip(*sent) all_words.append(words) # if we got one dimension tags w/o tuples or lists, make them tuples # but we also check that the format is consistent, # as otherwise the result being converted might be confusing if not converted_tuples and any(tag is None or isinstance(tag, str) for tag in tags): if sent_idx > 0: raise ValueError("Got a mix of tags and lists of tags. 
First non-list was in sentence %d" % sent_idx) converted_tuples = True if converted_tuples: if not all(tag is None or isinstance(tag, str) for tag in tags): raise ValueError("Got a mix of tags and lists of tags. First tag as a list was in sentence %d" % sent_idx) tags = [(tag,) for tag in tags] all_tags.append(tags) max_columns = max(len(x) for tags in all_tags for x in tags) for sent_idx, tags in enumerate(all_tags): if any(len(x) < max_columns for x in tags): raise ValueError("NER tags not uniform in length at sentence %d. TODO: extend those columns with O" % sent_idx) all_convert_bio_to_bioes = [] all_convert_basic_to_bioes = [] for column_idx in range(max_columns): # check if tag conversion is needed for each column # we treat each column separately, although practically # speaking it would be pretty weird for a dataset to have BIO # in one column and basic in another, for example convert_bio_to_bioes = False convert_basic_to_bioes = False tag_column = [x[column_idx] for sent in all_tags for x in sent] is_bio = is_bio_scheme(tag_column) is_basic = not is_bio and is_basic_scheme(tag_column) if is_bio and scheme.lower() == 'bioes': convert_bio_to_bioes = True logger.debug("BIO tagging scheme found in input at column %d; converting into BIOES scheme..." % column_idx) elif is_basic and scheme.lower() == 'bioes': convert_basic_to_bioes = True logger.debug("Basic tagging scheme found in input at column %d; converting into BIOES scheme..." 
% column_idx) all_convert_bio_to_bioes.append(convert_bio_to_bioes) all_convert_basic_to_bioes.append(convert_basic_to_bioes) result = [] for words, tags in zip(all_words, all_tags): # TODO: add a convert_basic_to_bio option as well # process tags # tags is a list of each column of tags for each word in this sentence # copy the tags to a list so we can edit them tags = [[x for x in sentence_tags] for sentence_tags in tags] for column_idx, (convert_bio_to_bioes, convert_basic_to_bioes) in enumerate(zip(all_convert_bio_to_bioes, all_convert_basic_to_bioes)): tag_column = [x[column_idx] for x in tags] if convert_basic_to_bioes: # if basic, convert tags -> bio -> bioes tag_column = bio2_to_bioes(basic_to_bio(tag_column)) else: # first ensure BIO2 scheme tag_column = to_bio2(tag_column) # then convert to BIOES if convert_bio_to_bioes: tag_column = bio2_to_bioes(tag_column) for tag_idx, tag in enumerate(tag_column): tags[tag_idx][column_idx] = tag result.append([(w,tuple(t)) for w,t in zip(words, tags)]) if converted_tuples: result = [[(word[0], word[1][0]) for word in sentence] for sentence in result] return result def decode_from_bioes(tags): """ Decode from a sequence of BIOES tags, assuming default tag is 'O'. Args: tags: a list of BIOES tags Returns: A list of dict with start_idx, end_idx, and type values. 
""" res = [] ent_idxs = [] cur_type = None def flush(): if len(ent_idxs) > 0: res.append({ 'start': ent_idxs[0], 'end': ent_idxs[-1], 'type': cur_type}) for idx, tag in enumerate(tags): if tag is None: tag = 'O' if tag == 'O': flush() ent_idxs = [] elif tag.startswith('B-'): # start of new ent flush() ent_idxs = [idx] cur_type = tag[2:] elif tag.startswith('I-'): # continue last ent ent_idxs.append(idx) cur_type = tag[2:] elif tag.startswith('E-'): # end last ent ent_idxs.append(idx) cur_type = tag[2:] flush() ent_idxs = [] elif tag.startswith('S-'): # start single word ent flush() ent_idxs = [idx] cur_type = tag[2:] flush() ent_idxs = [] # flush after whole sentence flush() return res def merge_tags(*sequences): """ Merge multiple sequences of NER tags into one sequence Only O is replaced, and the earlier tags have precedence """ tags = list(sequences[0]) for sequence in sequences[1:]: idx = 0 while idx < len(sequence): # skip empty tags in the later sequences if sequence[idx] == 'O': idx += 1 continue # check for singletons. copy if not O in the original if sequence[idx].startswith("S-"): if tags[idx] == 'O': tags[idx] = sequence[idx] idx += 1 continue # at this point, we know we have a B-... 
sequence if not sequence[idx].startswith("B-"): raise ValueError("Got unexpected tag sequence at idx {}: {}".format(idx, sequence)) # take the block of tags which are B- through E- start_idx = idx end_idx = start_idx + 1 while end_idx < len(sequence): if sequence[end_idx][2:] != sequence[start_idx][2:]: raise ValueError("Unexpected tag sequence at idx {}: {}".format(end_idx, sequence)) if sequence[end_idx].startswith("E-"): break if not sequence[end_idx].startswith("I-"): raise ValueError("Unexpected tag sequence at idx {}: {}".format(end_idx, sequence)) end_idx += 1 if end_idx == len(sequence): raise ValueError("Got a sequence with an unclosed tag: {}".format(sequence)) end_idx = end_idx + 1 # if all tags in the original are O, we can overwrite # otherwise, keep the originals if all(x == 'O' for x in tags[start_idx:end_idx]): tags[start_idx:end_idx] = sequence[start_idx:end_idx] idx = end_idx return tags ================================================ FILE: stanza/models/ner/vocab.py ================================================ from collections import Counter, OrderedDict from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab, CompositeVocab from stanza.models.common.vocab import VOCAB_PREFIX from stanza.models.common.pretrain import PretrainedWordVocab from stanza.models.pos.vocab import WordVocab class TagVocab(BaseVocab): """ A vocab for the output tag sequence. 
""" def build_vocab(self): counter = Counter([w[self.idx] for sent in self.data for w in sent]) self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True)) self._unit2id = {w:i for i, w in enumerate(self._id2unit)} def convert_tag_vocab(state_dict): if state_dict['lower']: raise AssertionError("Did not expect an NER vocab with 'lower' set to True") items = state_dict['_id2unit'][len(VOCAB_PREFIX):] # this looks silly, but the vocab builder treats this as words with multiple fields # (we set it to look for field 0 with idx=0) # and then the label field is expected to be a list or tuple of items items = [[[[x]]] for x in items] vocab = CompositeVocab(data=items, lang=state_dict['lang'], idx=0, sep=None) if len(vocab._id2unit[0]) != len(state_dict['_id2unit']): raise AssertionError("Failed to construct a new vocab of the same length as the original") if vocab._id2unit[0] != state_dict['_id2unit']: raise AssertionError("Failed to construct a new vocab in the same order as the original") return vocab class MultiVocab(BaseMultiVocab): def state_dict(self): """ Also save a vocab name to class name mapping in state dict. """ state = OrderedDict() key2class = OrderedDict() for k, v in self._vocabs.items(): state[k] = v.state_dict() key2class[k] = type(v).__name__ state['_key2class'] = key2class return state @classmethod def load_state_dict(cls, state_dict): class_dict = {'CharVocab': CharVocab.load_state_dict, 'PretrainedWordVocab': PretrainedWordVocab.load_state_dict, 'TagVocab': convert_tag_vocab, 'CompositeVocab': CompositeVocab.load_state_dict, 'WordVocab': WordVocab.load_state_dict} new = cls() assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!" 
key2class = state_dict['_key2class'] for k,v in state_dict.items(): if k == '_key2class': continue classname = key2class[k] new[k] = class_dict[classname](v) return new ================================================ FILE: stanza/models/ner_tagger.py ================================================ """ Entry point for training and evaluating an NER tagger. This tagger uses BiLSTM layers with character and word-level representations, and a CRF decoding layer to produce NER predictions. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. """ import sys import os import time from datetime import datetime import argparse import logging import numpy as np import random import re import json import torch from torch import nn, optim from stanza.models.ner.data import DataLoader from stanza.models.ner.trainer import Trainer from stanza.models.ner import scorer from stanza.models.common import utils from stanza.models.common.pretrain import Pretrain from stanza.utils.conll import CoNLL from stanza.models.common.doc import * from stanza.models import _training_logging from stanza.models.common.peft_config import add_peft_args, resolve_peft_args from stanza.utils.confusion import confusion_to_weighted_f1, format_confusion logger = logging.getLogger('stanza') def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/ner', help='Directory of NER data.') parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors') parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors') parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') 
parser.add_argument('--eval_output_file', type=str, default=None, help='Where to write results: text, gold, pred. If None, no results file printed') parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path') parser.add_argument('--finetune_load_name', type=str, default=None, help='Model to load when finetuning') parser.add_argument('--train_classifier_only', action='store_true', help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.') parser.add_argument('--shorthand', type=str, help="Treebank shorthand") parser.add_argument('--hidden_dim', type=int, default=256) parser.add_argument('--char_hidden_dim', type=int, default=100) parser.add_argument('--word_emb_dim', type=int, default=100) parser.add_argument('--char_emb_dim', type=int, default=100) parser.add_argument('--num_layers', type=int, default=1) parser.add_argument('--char_num_layers', type=int, default=1) parser.add_argument('--pretrain_max_vocab', type=int, default=100000) parser.add_argument('--word_dropout', type=float, default=0.01, help="How often to remove a word at training time. 
Set to a small value to train unk when finetuning word embeddings") parser.add_argument('--locked_dropout', type=float, default=0.0) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--rec_dropout', type=float, default=0, help="Word recurrent dropout") parser.add_argument('--char_rec_dropout', type=float, default=0, help="Character recurrent dropout") parser.add_argument('--char_dropout', type=float, default=0, help="Character-level language model dropout") parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off training a character model.") parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.") parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.") parser.add_argument('--no_lowercase', dest='lowercase', action='store_false', help="Use cased word vectors.") parser.add_argument('--no_emb_finetune', dest='emb_finetune', action='store_false', help="Turn off finetuning of the embedding matrix.") parser.add_argument('--emb_finetune_known_only', dest='emb_finetune_known_only', action='store_true', help="Finetune the embedding matrix only for words in the embedding. 
(Default: finetune words not in the embedding as well) This may be useful for very large datasets where obscure words are only trained once in a while, such as French-WikiNER") parser.add_argument('--no_input_transform', dest='input_transform', action='store_false', help="Do not use input transformation layer before tagger lstm.") parser.add_argument('--scheme', type=str, default='bioes', help="The tagging scheme to use: bio or bioes.") parser.add_argument('--train_scheme', type=str, default=None, help="The tagging scheme to use when training: bio or bioes. Overrides --scheme for the training set") parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)") parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert") parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer") parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)') parser.add_argument('--gradient_checkpointing', default=False, action='store_true', help='Checkpoint intermediate gradients between layers to save memory at the cost of training steps') parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)") parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much') parser.add_argument('--second_optim', type=str, default=None, help='once first optimizer converged, tune the model again. 
with: sgd, adagrad, adam or adamax.') parser.add_argument('--second_bert_learning_rate', default=0, type=float, help='Secondary stage transformer finetuning learning rate scale') parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.") parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='sgd', help='sgd, adagrad, adam or adamax.') parser.add_argument('--lr', type=float, default=0.1, help='Learning rate.') parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate to stop training.') parser.add_argument('--second_lr', type=float, default=5e-3, help='Secondary learning rate') parser.add_argument('--momentum', type=float, default=0, help='Momentum for SGD.') parser.add_argument('--lr_decay', type=float, default=0.5, help="LR decay rate.") parser.add_argument('--patience', type=int, default=3, help="Patience for LR decay.") parser.add_argument('--connect_output_layers', action='store_true', default=False, help='Connect one output layer to the input of the next output layer. By default, those layers are all separate') parser.add_argument('--predict_tagset', type=int, default=None, help='Which tagset to predict if there are multiple tagsets. Will default to 0. 
Default of None allows the model to remember the value from training time, but be overridden at test time') parser.add_argument('--ignore_tag_scores', type=str, default=None, help="Which tags to ignore, if any, when scoring dev & test sets") parser.add_argument('--max_steps', type=int, default=200000) parser.add_argument('--max_steps_no_improve', type=int, default=2500, help='if the model doesn\'t improve after this many steps, give up or switch to new optimizer.') parser.add_argument('--eval_interval', type=int, default=500) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--max_batch_words', type=int, default=800, help='Long sentences can overwhelm even a large GPU when finetuning a transformer on otherwise reasonable batch sizes. This cuts off those batches early') parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.') parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.') parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)') parser.add_argument('--save_dir', type=str, default='saved_models/ner', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_nertagger.pt", help="File name to save the model") parser.add_argument('--seed', type=int, default=1234) utils.add_device_args(parser) parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. 
Will default to the dataset short name') return parser def parse_args(args=None): parser = build_argparse() add_peft_args(parser) args = parser.parse_args(args=args) resolve_peft_args(args, logger) if args.wandb_name: args.wandb = True args = vars(args) return args def main(args=None): args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running NER tagger in {} mode".format(args['mode'])) if args['mode'] == 'train': return train(args) else: evaluate(args) def load_pretrain(args): # load pretrained vectors if not args['pretrain']: return None if args['wordvec_pretrain_file']: pretrain_file = args['wordvec_pretrain_file'] pretrain = Pretrain(pretrain_file, None, args['pretrain_max_vocab'], save_to_file=False) else: if len(args['wordvec_file']) == 0: vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand']) else: vec_file = args['wordvec_file'] # do not save pretrained embeddings individually pretrain = Pretrain(None, vec_file, args['pretrain_max_vocab'], save_to_file=False) return pretrain def model_file_name(args): return utils.standard_model_file_name(args, "nertagger") def get_known_tags(tags): """ Tags are stored in the dataset as a list of list of tags This returns a sorted list for each column of tags in the dataset """ max_columns = max(len(word) for sent in tags for word in sent) known_tags = [set() for _ in range(max_columns)] for sent in tags: for word in sent: for tag_idx, tag in enumerate(word): known_tags[tag_idx].add(tag) return [sorted(x) for x in known_tags] def warn_missing_tags(tag_vocab, data_tags, error_msg, bioes_to_bio=False): """ Check for tags missing from the tag_vocab. Given a tag_vocab and the known tags in the format used by ner.data, go through the tags in the dataset and look for any which aren't in the tag_vocab. error_msg is something like "training set" or "eval file" to indicate where the missing tags came from. 
""" tag_depth = max(max(len(tags) for tags in sentence) for sentence in data_tags) if tag_depth != len(tag_vocab.lens()): logger.warning("Test dataset has a different number of tag types compared to the model: %d vs %d", tag_depth, len(tag_vocab.lens())) for tag_set_idx in range(min(tag_depth, len(tag_vocab.lens()))): tag_set = tag_vocab.items(tag_set_idx) if len(tag_vocab.lens()) > 1: current_error_msg = error_msg + " tag set %d" % tag_set_idx else: current_error_msg = error_msg current_tags = set([word[tag_set_idx] for sentence in data_tags for word in sentence]) if bioes_to_bio: current_tags = set([re.sub("^E-", "I-", re.sub("^S-", "B-", x)) for x in current_tags]) utils.warn_missing_tags(tag_set, current_tags, current_error_msg) def train(args): model_file = model_file_name(args) save_dir, save_name = os.path.split(model_file) utils.ensure_dir(save_dir) if args['save_dir'] is None: args['save_dir'] = save_dir args['save_name'] = save_name utils.log_training_args(args, logger) pretrain = None vocab = None trainer = None if args['finetune'] and args['finetune_load_name']: logger.warning('Finetune is ON. Using model from "{}"'.format(args['finetune_load_name'])) _, trainer, vocab = load_model(args, args['finetune_load_name']) elif args['finetune'] and os.path.exists(model_file): logger.warning('Finetune is ON. Using model from "{}"'.format(model_file)) _, trainer, vocab = load_model(args, model_file) else: if args['finetune']: raise FileNotFoundError('Finetune is set to true but model file is not found: {}'.format(model_file)) pretrain = load_pretrain(args) if pretrain is not None: word_emb_dim = pretrain.emb.shape[1] if args['word_emb_dim'] and args['word_emb_dim'] != word_emb_dim: logger.warning("Embedding file has a dimension of {}. 
Model will be built with that size instead of {}".format(word_emb_dim, args['word_emb_dim'])) args['word_emb_dim'] = word_emb_dim if args['charlm']: if args['charlm_shorthand'] is None: raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...") logger.info('Using pretrained contextualized char embedding') if not args['charlm_forward_file']: args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand']) if not args['charlm_backward_file']: args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand']) # load data logger.info("Loading training data with batch size %d from %s", args['batch_size'], args['train_file']) with open(args['train_file']) as fin: train_doc = Document(json.load(fin)) logger.info("Loaded %d sentences of training data", len(train_doc.sentences)) if len(train_doc.sentences) == 0: raise ValueError("File %s exists but has no usable training data" % args['train_file']) train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words']) vocab = train_batch.vocab logger.info("Loading dev data from %s", args['eval_file']) with open(args['eval_file']) as fin: dev_doc = Document(json.load(fin)) logger.info("Loaded %d sentences of dev data", len(dev_doc.sentences)) if len(dev_doc.sentences) == 0: raise ValueError("File %s exists but has no usable dev data" % args['train_file']) dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True) train_tags = get_known_tags(train_batch.tags) logger.info("Training data has %d columns of tags", len(train_tags)) for tag_idx, tags in enumerate(train_tags): logger.info("Tags present in training set at column %d:\n Tags without BIES markers: %s\n Tags with B-, I-, E-, or S-: %s", tag_idx, " ".join(sorted(set(i for i in tags if i[:2] not in ('B-', 
'I-', 'E-', 'S-')))), " ".join(sorted(set(i[2:] for i in tags if i[:2] in ('B-', 'I-', 'E-', 'S-'))))) # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: logger.info("Skip training because no data available...") return logger.info("Training tagger...") if trainer is None: # init if model was not loaded previously from file trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], train_classifier_only=args['train_classifier_only']) if args['finetune']: warn_missing_tags(trainer.vocab['tag'], train_batch.tags, "training set") # the evaluation will coerce the tags to the proper scheme, # so we won't need to alert for not having S- or E- tags bioes_to_bio = args['train_scheme'] == 'bio' and args['scheme'] == 'bioes' warn_missing_tags(trainer.vocab['tag'], dev_batch.tags, "dev set", bioes_to_bio=bioes_to_bio) # TODO: might still want to add multiple layers of tag evaluation to the scorer dev_gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in dev_batch.tags] logger.info(trainer.model) global_step = 0 max_steps = args['max_steps'] dev_score_history = [] best_dev_preds = [] current_lr = trainer.optimizer.param_groups[0]['lr'] global_start_time = time.time() format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}' # LR scheduling if args['lr_decay'] > 0: # learning rate changes on plateau -- no improvement on model for patience number of epochs # change is made as a factor of the learning rate decay scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, mode='max', factor=args['lr_decay'], patience=args['patience'], min_lr=args['min_lr']) else: scheduler = None if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_ner" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('train_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') # 
track gradients! wandb.watch(trainer.model, log_freq=4, log="gradients") # start training last_best_step = 0 train_loss = 0 is_second_optim = False while True: should_stop = False for i, batch in enumerate(train_batch): start_time = time.time() global_step += 1 loss = trainer.update(batch, eval=False) # update step train_loss += loss if global_step % args['log_step'] == 0: duration = time.time() - start_time logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step, max_steps, loss, duration, current_lr)) if global_step % args['eval_interval'] == 0: # eval on dev logger.info("Evaluating on dev set...") dev_preds = [] for batch in dev_batch: preds = trainer.predict(batch) dev_preds += preds _, _, dev_score, _ = scorer.score_by_entity(dev_preds, dev_gold_tags, ignore_tags=args['ignore_tag_scores']) train_loss = train_loss / args['eval_interval'] # avg loss per batch logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score)) if args['wandb']: wandb.log({'train_loss': train_loss, 'dev_score': dev_score}) train_loss = 0 # save best model if len(dev_score_history) == 0 or dev_score > max(dev_score_history): trainer.save(model_file) last_best_step = global_step logger.info("New best model saved.") best_dev_preds = dev_preds dev_score_history += [dev_score] logger.info("") # lr schedule if scheduler is not None: scheduler.step(dev_score) if args['log_norms']: trainer.model.log_norms() # check stopping current_lr = trainer.optimizer.param_groups[0]['lr'] if (global_step - last_best_step) >= args['max_steps_no_improve'] or global_step >= args['max_steps'] or current_lr <= args['min_lr']: if (global_step - last_best_step) >= args['max_steps_no_improve']: logger.info("{} steps without improvement...".format((global_step - last_best_step))) if not is_second_optim and args['second_optim'] is not None: logger.info("Switching to second optimizer: {}".format(args['second_optim'])) logger.info('Reloading 
best model to continue from current local optimum') trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], train_classifier_only=args['train_classifier_only'], model_file=model_file, second_optim=True) is_second_optim = True last_best_step = global_step current_lr = trainer.optimizer.param_groups[0]['lr'] else: logger.info("stopping...") should_stop = True break if should_stop: break train_batch.reshuffle() logger.info("Training ended with {} steps.".format(global_step)) if args['wandb']: wandb.finish() if len(dev_score_history) > 0: best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1 logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval'])) else: logger.info("Dev set never evaluated. Saving final model.") trainer.save(model_file) return trainer def write_ner_results(filename, batch, preds, predict_tagset): if len(batch.tags) != len(preds): raise ValueError("Unexpected batch vs pred lengths: %d vs %d" % (len(batch.tags), len(preds))) with open(filename, "w", encoding="utf-8") as fout: tag_idx = 0 for b in batch: # b[0] is words, b[5] is orig_idx # a namedtuple would make this cleaner without being much slower text = utils.unsort(b[0], b[5]) for sentence in text: # TODO: if we change the predict_tagset mechanism, will have to change this sentence_gold = [x[predict_tagset] for x in batch.tags[tag_idx]] sentence_pred = preds[tag_idx] tag_idx += 1 for word, gold, pred in zip(sentence, sentence_gold, sentence_pred): fout.write("%s\t%s\t%s\n" % (word, gold, pred)) fout.write("\n") def evaluate(args): # file paths model_file = model_file_name(args) loaded_args, trainer, vocab = load_model(args, model_file) return evaluate_model(loaded_args, trainer, vocab, args['eval_file']) def evaluate_model(loaded_args, trainer, vocab, eval_file): if loaded_args['log_norms']: trainer.model.log_norms() model_file = os.path.join(loaded_args['save_dir'], loaded_args['save_name']) 
logger.debug("Loaded model for eval from %s", model_file) logger.debug("Using the %d tagset for evaluation", loaded_args['predict_tagset']) # load data logger.info("Loading data with batch size {}...".format(loaded_args['batch_size'])) with open(eval_file) as fin: doc = Document(json.load(fin)) batch = DataLoader(doc, loaded_args['batch_size'], loaded_args, vocab=vocab, evaluation=True, bert_tokenizer=trainer.model.bert_tokenizer) bioes_to_bio = loaded_args['train_scheme'] == 'bio' and loaded_args['scheme'] == 'bioes' warn_missing_tags(trainer.vocab['tag'], batch.tags, "eval_file", bioes_to_bio=bioes_to_bio) logger.info("Start evaluation...") preds = [] for i, b in enumerate(batch): preds += trainer.predict(b) gold_tags = batch.tags # TODO: might still want to add multiple layers of tag evaluation to the scorer gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in gold_tags] _, _, score, entity_f1 = scorer.score_by_entity(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores']) _, _, _, confusion = scorer.score_by_token(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores']) logger.info("Weighted f1 for non-O tokens: %5f", confusion_to_weighted_f1(confusion, exclude=["O"])) logger.info("NER tagger score: %s %s %s %.2f", loaded_args['shorthand'], model_file, eval_file, score*100) entity_f1_lines = ["%s: %.2f" % (x, y*100) for x, y in entity_f1.items()] logger.info("NER Entity F1 scores:\n %s", "\n ".join(entity_f1_lines)) logger.info("NER token confusion matrix:\n{}".format(format_confusion(confusion))) if loaded_args['eval_output_file']: write_ner_results(loaded_args['eval_output_file'], batch, preds, trainer.args['predict_tagset']) return confusion def load_model(args, model_file): # load model charlm_args = {} if 'charlm_forward_file' in args: charlm_args['charlm_forward_file'] = args['charlm_forward_file'] if 'charlm_backward_file' in args: charlm_args['charlm_backward_file'] = args['charlm_backward_file'] if 
args['predict_tagset'] is not None: charlm_args['predict_tagset'] = args['predict_tagset'] pretrain = load_pretrain(args) trainer = Trainer(args=charlm_args, model_file=model_file, pretrain=pretrain, device=args['device'], train_classifier_only=args['train_classifier_only']) loaded_args, vocab = trainer.args, trainer.vocab # load config for k in args: if k.endswith('_dir') or k.endswith('_file') or k in ['batch_size', 'ignore_tag_scores', 'log_norms', 'mode', 'scheme', 'shorthand']: loaded_args[k] = args[k] save_dir, save_name = os.path.split(model_file) loaded_args['save_dir'] = save_dir loaded_args['save_name'] = save_name return loaded_args, trainer, vocab if __name__ == '__main__': main() ================================================ FILE: stanza/models/parser.py ================================================ """ Entry point for training and evaluating a dependency parser. This implementation combines a deep biaffine graph-based parser with linearization and distance features. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. """ """ Training and evaluation for the parser. 
""" import io import sys import os import copy import shutil import time import argparse import logging import numpy as np import random import zipfile import torch from torch import nn, optim import stanza.models.depparse.data as data from stanza.models.depparse.data import DataLoader from stanza.models.depparse.trainer import Trainer from stanza.models.depparse import scorer from stanza.models.common import utils from stanza.models.common import pretrain from stanza.models.common.data import augment_punct from stanza.models.common.doc import * from stanza.models.common.peft_config import add_peft_args, resolve_peft_args from stanza.models.common.utils import log_training_args from stanza.utils.conll import CoNLL from stanza.models import _training_logging logger = logging.getLogger('stanza') def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/depparse', help='Root dir for saving models.') parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors.') parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.') parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help="Don't score the eval file - perhaps it has no gold labels, for example. 
Cannot be used at training time") parser.add_argument('--output_latex', default=False, action='store_true', help='Output the per-relation table in Latex form') parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--lang', type=str, help='Language') parser.add_argument('--shorthand', type=str, help="Treebank shorthand") parser.add_argument('--hidden_dim', type=int, default=400) parser.add_argument('--char_hidden_dim', type=int, default=400) parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400) parser.add_argument('--deep_biaff_output_dim', type=int, default=160) # As an additional option, we implement arc embeddings # described in https://arxiv.org/pdf/2501.09451 # Scaling Graph-Based Dependency Parsing with Arc Vectorization and Attention-Based Refinement # Nicolas Floquet, Joseph Le Roux, Nadi Tomeh, Thierry Charnois # Unfortunately, the current implementation and hyperparameters do not seem to help # when combined with a transformer as the input embedding # LAS Scores on a few dev sets, UD 2.17, averaged over 5 seeds # This is with a version where the arc -> unlabeled is one layer, arc -> label is two layers # Using two layers for the arc -> unlabeled hurts scores a bit more # treebank w/ w/o # en_ewt 93.46 93.47 # de_gsd 89.02 89.12 # it_vit 90.15 90.19 # However, this is without the transformer over the arcs, which is # an important component of making the arcs more useful parser.add_argument('--use_arc_embedding', action='store_true', default=False, help='Use arc embeddings, as per Scaling Graph-Based Dependency Parsing') parser.add_argument('--no_use_arc_embedding', dest='use_arc_embedding', action='store_false', help="Don't use arc embeddings") parser.add_argument('--word_emb_dim', type=int, default=75) parser.add_argument('--word_cutoff', type=int, default=None, help='How common a word must be to include it in the finetuned word embedding. 
If not set, small word vector files will be 0, larger will be %d' % utils.DEFAULT_WORD_CUTOFF) parser.add_argument('--char_emb_dim', type=int, default=100) parser.add_argument('--tag_emb_dim', type=int, default=50) parser.add_argument('--no_upos', dest='use_upos', action='store_false', default=True, help="Don't use upos tags as part of the tag embedding") parser.add_argument('--no_xpos', dest='use_xpos', action='store_false', default=True, help="Don't use xpos tags as part of the tag embedding") parser.add_argument('--no_ufeats', dest='use_ufeats', action='store_false', default=True, help="Don't use ufeats as part of the tag embedding") parser.add_argument('--transformed_dim', type=int, default=125) parser.add_argument('--num_layers', type=int, default=3) parser.add_argument('--char_num_layers', type=int, default=1) parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint") parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints") parser.add_argument('--pretrain_max_vocab', type=int, default=250000) parser.add_argument('--word_dropout', type=float, default=0.33) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout") parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout") parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.") parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.") parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") 
parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)") parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert") parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer") parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding') parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)') parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)") parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer') parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much') parser.add_argument('--second_bert_learning_rate', default=1e-3, type=float, help='Secondary stage transformer finetuning learning rate scale') parser.add_argument('--bert_start_finetuning', default=200, type=int, help='When to start finetuning the transformer') parser.add_argument('--bert_warmup_steps', default=200, type=int, help='How many steps for a linear warmup when finetuning the transformer') parser.add_argument('--bert_weight_decay', default=0.0, type=float, help='Weight decay bert parameters by this much') parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.") parser.add_argument('--no_linearization', 
dest='linearization', action='store_false', help="Turn off linearization term.") parser.add_argument('--no_distance', dest='distance', action='store_false', help="Turn off distance term.") # Originally, we used a single adam optimizer, stopping after 1000 stalled iterations, # with a couple other hyperparameters corresponding to: TODO # --max_steps_before_stop 1000 # --beta2 0.95 # --lr 3e-3 # --weight_decay 0.0 # --optim adam # --no_second_optim # Later experiments found the current defaults helped the results # on several different datasets (using a transformer as the input embedding) # These experiements are averaged across 5 models, # with multiple early stopping values as well # 5 model dev avg LAS 1 stage 1 stage 2k 1 stage 4k 2 stage # de_gsd 89.03 89.50 89.71 89.83 # en_ewt 93.47 93.69 93.74 93.89 # fi_tdt 92.16 92.56 92.69 93.15 # it_vit 90.12 90.37 90.44 90.60 # ta_ttb 71.26 71.39 71.45 72.19 # zh-hans_gsdsimp 85.47 85.69 85.76 85.89 # # 5 model test avg LAS 1 stage 1 stage 2k 1 stage 4k 2 stage # de_gsd 86.60 86.96 87.04 87.09 # en_ewt 93.37 93.51 93.55 93.72 # fi_tdt 92.56 92.92 93.10 93.47 # it_vit 90.51 90.74 90.75 90.88 # ta_ttb 68.22 68.27 68.42 69.06 # zh-hans_gsdsimp 85.66 85.92 86.04 86.34 # # In addition to these experiments, we ran multiple alternate optimizer combinations, none of which # were a clear improvement over AdaDelta+Adam # # rmsprop --weight_decay 1e-5 --lr 0.0001 # adamw --second_lr 0.0001 # madgrad --second_lr 0.00008 # 5 model dev avg LAS ada+adam rms+adam ada+adamw ada+madgrad # de_gsd 89.83 89.80 89.67 89.55 # en_ewt 93.89 93.97 93.92 93.90 # fi_tdt 93.15 92.95 93.03 93.08 # it_vit 90.60 90.64 90.58 90.54 # ta_ttb 72.19 71.86 72.18 72.24 # zh-hans_gsdsimp 85.89 85.60 85.97 85.92 # # 5 model test avg LAS ada+adam rms+adam ada+adamw ada+madgrad # de_gsd 87.09 87.26 87.06 87.08 # en_ewt 93.72 93.73 93.75 93.73 # fi_tdt 93.47 93.30 93.43 93.44 # it_vit 90.88 90.95 90.90 90.85 # ta_ttb 69.06 68.45 69.05 69.26 # zh-hans_gsdsimp 86.34 
85.86 86.27 86.23 parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='adadelta', help='sgd, adagrad, adam or adamax.') parser.add_argument('--second_optim', type=str, default="adam", help='sgd, adagrad, adam or adamax.') parser.add_argument('--no_second_optim', dest='second_optim', action='store_const', const=None, help="Don't use the second optimizer") parser.add_argument('--lr', type=float, default=2.0, help='Learning rate') parser.add_argument('--second_lr', type=float, default=0.0002, help='Secondary stage learning rate') parser.add_argument('--weight_decay', type=float, default=0.00001, help='Weight decay for the first optimizer') parser.add_argument('--beta2', type=float, default=0.999) parser.add_argument('--second_optim_start_step', type=int, default=10000, help='If set, switch to the second optimizer when stalled or at this step regardless of performance. Normally, the optimizer only switches when the dev scores have stalled for --max_steps_before_stop steps') parser.add_argument('--second_warmup_steps', type=int, default=200, help="If set, give the 2nd optimizer a linear warmup. Idea being that the optimizer won't have a good grasp on the initial gradients and square gradients when it first starts") parser.add_argument('--max_steps', type=int, default=50000) parser.add_argument('--eval_interval', type=int, default=100) parser.add_argument('--checkpoint_interval', type=int, default=500) parser.add_argument('--max_steps_before_stop', type=int, default=2000) parser.add_argument('--batch_size', type=int, default=5000) parser.add_argument('--second_batch_size', type=int, default=None, help='Use a different batch size for the second optimizer. 
Can be relevant for models with different transformer finetuning settings between optimizers, for example, where the larger batch size is impossible for FT the transformer"') parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.') parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.') parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)') parser.add_argument('--save_dir', type=str, default='saved_models/depparse', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model") parser.add_argument('--continue_from', type=str, default=None, help="File name to preload the model to continue training from") parser.add_argument('--seed', type=int, default=1234) add_peft_args(parser) utils.add_device_args(parser) parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct. Default of None will aim for roughly 10%%') parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. 
Will default to the dataset short name') parser.add_argument('--train_size', type=int, default=None, help='If specified, randomly select this many sentences from the training data') return parser def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) resolve_peft_args(args, logger) if args.wandb_name: args.wandb = True args = vars(args) return args def main(args=None): args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running parser in {} mode".format(args['mode'])) if args['mode'] == 'train': return train(args) else: return evaluate(args) def model_file_name(args): return utils.standard_model_file_name(args, "parser") # TODO: refactor with everywhere def load_pretrain(args): pt = None if args['pretrain']: pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang']) if os.path.exists(pretrain_file): vec_file = None else: vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand']) pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab']) return pt def predict_dataset(trainer, dev_batch): dev_preds = [] if len(dev_batch) > 0: for batch in dev_batch: preds = trainer.predict(batch) dev_preds += preds dev_preds = utils.unsort(dev_preds, dev_batch.data_orig_idx) return dev_preds def train(args): model_file = model_file_name(args) utils.ensure_dir(os.path.split(model_file)[0]) # load pretrained vectors if needed pretrain = load_pretrain(args) args['word_cutoff'] = utils.update_word_cutoff(pretrain, args['word_cutoff']) # TODO: refactor. 
the exact same thing is done in the tagger if args['charlm']: if args['charlm_shorthand'] is None: raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...") logger.info('Using pretrained contextualized char embedding') if not args['charlm_forward_file']: args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand']) if not args['charlm_backward_file']: args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand']) utils.log_training_args(args, logger) # load data logger.info("Loading data with batch size {}...".format(args['batch_size'])) train_file = args['train_file'] if zipfile.is_zipfile(train_file): logger.info("Decompressing %s" % train_file) train_data = [] with zipfile.ZipFile(train_file) as zin: for zipped_train_file in zin.namelist(): with zin.open(zipped_train_file) as fin: logger.info("Reading %s from %s" % (zipped_train_file, train_file)) train_str = fin.read() train_str = train_str.decode("utf-8") train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str) logger.info("Train File {} from {}, Data Size: {}".format(zipped_train_file, train_file, len(train_file_data))) train_data.extend(train_file_data) else: train_data, _, _ = CoNLL.conll2dict(input_file=args['train_file']) logger.info("Train File {}, Data Size: {}".format(train_file, len(train_data))) # possibly augment the training data with some amount of fake data # based on the options chosen logger.info("Original data size: {}".format(len(train_data))) if args['train_size']: if len(train_data) < args['train_size']: random.shuffle(train_data) train_data = train_data[:args['train_size']] logger.info("Limiting training data to %d entries", len(train_data)) else: logger.info("Train data less than %d already, not limiting train data", args['train_size']) # build the training data once, before augmentation, so that random variation # (which might be different based on 
the random seed) # doesn't have an effect on the vocab being cut off at the word limit # otherwise different models will have different vocabs # based on how often the words were duplicated in the augmentation # TODO: put the augmentation into the dataloader, # such as is done with the POS or the tokenizer train_doc = Document(train_data) train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, evaluation=False) vocab = train_batch.vocab train_data.extend(augment_punct(train_data, args['augment_nopunct'], keep_original_sentences=False)) logger.info("Augmented data size: {}".format(len(train_data))) train_doc = Document(train_data) train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False) dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True) # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: logger.info("Skip training because no data available...") sys.exit(0) if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_depparse" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('train_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') logger.info("Training parser...") checkpoint_file = None if args.get("checkpoint"): # calculate checkpoint file name from the save filename checkpoint_file = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name")) args["checkpoint_save_name"] = checkpoint_file if args.get("checkpoint") and os.path.exists(args["checkpoint_save_name"]): trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["checkpoint_save_name"], device=args['device'], ignore_model_config=True) if len(trainer.dev_score_history) > 0: logger.info("Continuing from checkpoint %s Model 
was previously trained for %d steps, with a best dev score of %.4f", args["checkpoint_save_name"], trainer.global_step, max(trainer.dev_score_history)) elif args["continue_from"]: if not os.path.exists(args["continue_from"]): raise FileNotFoundError("--continue_from specified, but the file %s does not exist" % args["continue_from"]) trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["continue_from"], device=args['device'], ignore_model_config=True, reset_history=True) else: trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device']) max_steps = args['max_steps'] current_lr = args['lr'] global_start_time = time.time() format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}' is_second_stage = False # start training train_loss = 0 if args['log_norms']: trainer.model.log_norms() while True: do_break = False for i, batch in enumerate(train_batch): start_time = time.time() trainer.global_step += 1 loss = trainer.update(batch, eval=False) # update step train_loss += loss # will checkpoint if we switch optimizers or score a new best score force_checkpoint = False if trainer.global_step % args['log_step'] == 0: duration = time.time() - start_time logger.info(format_str.format(trainer.global_step, max_steps, loss, duration, current_lr)) if trainer.global_step % args['eval_interval'] == 0: # eval on dev logger.info("Evaluating on dev set...") dev_preds = predict_dataset(trainer, dev_batch) dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x]) system_pred_file = "{:C}\n\n".format(dev_batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, dev_score = scorer.score(system_pred_file, args['eval_file']) train_loss = train_loss / args['eval_interval'] # avg loss per batch logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(trainer.global_step, train_loss, dev_score)) if args['wandb']: wandb.log({'train_loss': train_loss, 'dev_score': dev_score}) train_loss = 0 # save 
best model trainer.dev_score_history += [dev_score] if dev_score >= max(trainer.dev_score_history): trainer.last_best_step = trainer.global_step trainer.save(model_file) logger.info("new best model saved.") force_checkpoint = True for scheduler_name, scheduler in trainer.scheduler.items(): logger.info('scheduler %s learning rate: %s', scheduler_name, scheduler.get_last_lr()) if args['log_norms']: trainer.model.log_norms() if not is_second_stage and args.get('second_optim', None) is not None: if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and trainer.global_step >= args['second_optim_start_step']): logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None))) global_step = trainer.global_step args["second_stage"] = True # if the loader gets a model file, it uses secondary optimizer # (because of the second_stage = True argument) trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain, model_file=model_file, device=args['device']) logger.info('Reloading best model to continue from current local optimum') dev_preds = predict_dataset(trainer, dev_batch) dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x]) system_pred_file = "{:C}\n\n".format(dev_batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, dev_score = scorer.score(system_pred_file, args['eval_file']) logger.info("Reloaded model with dev score %.4f", dev_score) is_second_stage = True trainer.global_step = global_step trainer.last_best_step = global_step if args['second_batch_size'] is not None: train_batch.set_batch_size(args['second_batch_size']) force_checkpoint = True else: if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop']: do_break = True break if trainer.global_step % args['eval_interval'] == 0 or force_checkpoint: # if we need to save checkpoint, do so # (save after switching the optimizer, if applicable, so that # the new 
optimizer is the optimizer used if a restart happens) if checkpoint_file is not None: trainer.save(checkpoint_file, save_optimizer=True) logger.info("new model checkpoint saved.") if trainer.global_step >= args['max_steps']: do_break = True break if do_break: break train_batch.reshuffle() logger.info("Training ended with {} steps.".format(trainer.global_step)) if args['wandb']: wandb.finish() if len(trainer.dev_score_history) > 0: # TODO: technically the iteration position will be wrong if # the eval_interval changed when running from a checkpoint # could fix this by saving step & score instead of just score best_f, best_eval = max(trainer.dev_score_history)*100, np.argmax(trainer.dev_score_history)+1 logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval'])) else: logger.info("Dev set never evaluated. Saving final model.") trainer.save(model_file) return trainer, _ def evaluate(args): model_file = model_file_name(args) # load pretrained vectors if needed pretrain = load_pretrain(args) load_args = {'charlm_forward_file': args.get('charlm_forward_file', None), 'charlm_backward_file': args.get('charlm_backward_file', None)} # load model logger.info("Loading model from: {}".format(model_file)) trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args) if args['log_norms']: trainer.model.log_norms() return trainer, evaluate_trainer(args, trainer, pretrain) def evaluate_trainer(args, trainer, pretrain): system_pred_file = args['output_file'] loaded_args, vocab = trainer.args, trainer.vocab # load config for k in args: if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand'] or k == 'mode': loaded_args[k] = args[k] # load data logger.info("Loading data with batch size {}...".format(args['batch_size'])) doc = CoNLL.conll2doc(input_file=args['eval_file']) batch = DataLoader(doc, args['batch_size'], loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True) preds 
= predict_dataset(trainer, batch) # write to file and score batch.doc.set([HEAD, DEPREL], [y for x in preds for y in x]) if system_pred_file: CoNLL.write_doc2conll(batch.doc, system_pred_file) if args['gold_labels']: gold_doc = CoNLL.conll2doc(input_file=args['eval_file']) # Check for None ... otherwise an inscrutable error occurs later in the scorer for sent_idx, sentence in enumerate(gold_doc.sentences): for word_idx, word in enumerate(sentence.words): if word.deprel is None: raise ValueError("Gold document {} has a None at sentence {} word {}\n{:C}".format(args['eval_file'], sent_idx, word_idx, sentence)) scorer.score_named_dependencies(batch.doc, gold_doc, args['output_latex']) system_pred_file = "{:C}\n\n".format(batch.doc) system_pred_file = io.StringIO(system_pred_file) _, _, score = scorer.score(system_pred_file, args['eval_file']) logger.info("Parser score on %s file %s: %.2f", args['shorthand'], args['eval_file'], score*100) return batch.doc if __name__ == '__main__': main() ================================================ FILE: stanza/models/pos/__init__.py ================================================ ================================================ FILE: stanza/models/pos/build_xpos_vocab_factory.py ================================================ import argparse from collections import defaultdict import logging import os import re import sys from zipfile import ZipFile from stanza.models.common.constant import treebank_to_short_name from stanza.models.pos.xpos_vocab_utils import DEFAULT_KEY, choose_simplest_factory, XPOSType from stanza.models.common.doc import * from stanza.utils.conll import CoNLL from stanza.utils import default_paths SHORTNAME_RE = re.compile("[a-z-]+_[a-z0-9]+") DATA_DIR = default_paths.get_default_paths()['POS_DATA_DIR'] logger = logging.getLogger('stanza') def get_xpos_factory(shorthand, fn): logger.info('Resolving vocab option for {}...'.format(shorthand)) doc = None train_file = os.path.join(DATA_DIR, 
'{}.train.in.conllu'.format(shorthand)) if os.path.exists(train_file): doc = CoNLL.conll2doc(input_file=train_file) else: zip_file = os.path.join(DATA_DIR, '{}.train.in.zip'.format(shorthand)) if os.path.exists(zip_file): with ZipFile(zip_file) as zin: for train_file in zin.namelist(): doc = CoNLL.conll2doc(input_file=train_file, zip_file=zip_file) if any(word.xpos for sentence in doc.sentences for word in sentence.words): break else: raise ValueError('Found training data in {}, but none of the files contained had xpos'.format(zip_file)) if doc is None: raise FileNotFoundError('Training data for {} not found. To generate the XPOS vocabulary ' 'for this treebank properly, please run the following command first:\n' ' python3 stanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn)) # without the training file, there's not much we can do key = DEFAULT_KEY return key data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True) return choose_simplest_factory(data, shorthand) def main(): parser = argparse.ArgumentParser() parser.add_argument('--treebanks', type=str, default=DATA_DIR, help="Treebanks to process - directory with processed datasets or a file with a list") parser.add_argument('--output_file', type=str, default="stanza/models/pos/xpos_vocab_factory.py", help="Where to write the results") args = parser.parse_args() output_file = args.output_file if os.path.isdir(args.treebanks): # if the path is a directory of datasets (which is the default if --treebanks is not set) # we use those datasets to prepare the xpos factories treebanks = os.listdir(args.treebanks) treebanks = [x.split(".", maxsplit=1)[0] for x in treebanks] treebanks = sorted(set(treebanks)) elif os.path.exists(args.treebanks): # maybe it's a file with a list of names with open(args.treebanks) as fin: treebanks = sorted(set([x.strip() for x in fin.readlines() if x.strip()])) else: raise ValueError("Cannot figure out which treebanks to use. 
Please set the --treebanks parameter") logger.info("Processing the following treebanks: %s" % " ".join(treebanks)) shorthands = [] fullnames = [] for treebank in treebanks: fullnames.append(treebank) if SHORTNAME_RE.match(treebank): shorthands.append(treebank) else: shorthands.append(treebank_to_short_name(treebank)) # For each treebank, we would like to find the XPOS Vocab configuration that minimizes # the number of total classes needed to predict by all tagger classifiers. This is # achieved by enumerating different options of separators that different treebanks might # use, and comparing that to treating the XPOS tags as separate categories (using a # WordVocab). mapping = defaultdict(list) for sh, fn in zip(shorthands, fullnames): factory = get_xpos_factory(sh, fn) mapping[factory].append(sh) if sh == 'zh-hans_gsdsimp': mapping[factory].append('zh_gsdsimp') elif sh == 'no_bokmaal': mapping[factory].append('nb_bokmaal') mapping[DEFAULT_KEY].append('en_test') # Generate code. This takes the XPOS vocabulary classes selected above, and generates the # actual factory class as seen in models.pos.xpos_vocab_factory. first = True with open(output_file, 'w') as f: max_len = max(max(len(x) for x in mapping[key]) for key in mapping) print('''# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory. # Please don't edit it! 
import logging from stanza.models.pos.vocab import WordVocab, XPOSVocab from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory # using a sublogger makes it easier to test in the unittests logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory') XPOS_DESCRIPTIONS = {''', file=f) for key_idx, key in enumerate(mapping): if key_idx > 0: print(file=f) for shorthand in sorted(mapping[key]): # +2 to max_len for the '' # this format string is left justified (either would be okay, probably) if key.sep is None: sep = 'None' else: sep = "'%s'" % key.sep print((" {:%ds}: XPOSDescription({}, {})," % (max_len+2)).format("'%s'" % shorthand, key.xpos_type, sep), file=f) print('''} def xpos_vocab_factory(data, shorthand): if shorthand not in XPOS_DESCRIPTIONS: logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand) desc = choose_simplest_factory(data, shorthand) if shorthand in XPOS_DESCRIPTIONS: if XPOS_DESCRIPTIONS[shorthand] != desc: # log instead of throw # otherwise, updating datasets would be unpleasant logger.error("XPOS tagset in %s has apparently changed! 
Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc) else: logger.warning("Chose %s for the xpos factory for %s", desc, shorthand) return build_xpos_vocab(desc, data, shorthand) ''', file=f) logger.info('Done!') if __name__ == "__main__": main() ================================================ FILE: stanza/models/pos/data.py ================================================ import random import logging import copy import torch from collections import namedtuple from torch.utils.data import DataLoader as DL from torch.utils.data.sampler import Sampler from torch.nn.utils.rnn import pad_sequence from stanza.models.common.bert_embedding import filter_data, needs_length_filter from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all from stanza.models.common.utils import DEFAULT_WORD_CUTOFF, simplify_punct from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, CharVocab from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory from stanza.models.common.doc import * logger = logging.getLogger('stanza') DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text") DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx") class Dataset: def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs): self.args = args self.eval = evaluation self.shuffled = not self.eval self.sort_during_eval = sort_during_eval self.doc = doc if vocab is None: self.vocab = Dataset.init_vocab([doc], args) else: self.vocab = vocab self.has_upos = not all(x is None or x == '_' for x in doc.get(UPOS, as_sentences=False)) self.has_xpos = not all(x is None or x == '_' for x in doc.get(XPOS, as_sentences=False)) self.has_feats = not all(x is None or x == '_' for x in 
doc.get(FEATS, as_sentences=False)) data = self.load_doc(self.doc) # filter out the long sentences if bert is used if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']): data = filter_data(self.args['bert_model'], data, bert_tokenizer) # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None self.pretrain_vocab = None if pretrain is not None and args['pretrain']: self.pretrain_vocab = pretrain.vocab # filter and sample data if args.get('sample_train', 1.0) < 1.0 and not self.eval: keep = int(args['sample_train'] * len(data)) data = random.sample(data, keep) logger.debug("Subsample training set with rate {:g}".format(args['sample_train'])) data = self.preprocess(data, self.vocab, self.pretrain_vocab, args) self.data = data self.num_examples = len(data) self.__punct_tags = self.vocab["upos"].map(["PUNCT"]) self.augment_nopunct = self.args.get("augment_nopunct", 0.0) @staticmethod def init_vocab(docs, args): cutoff = args['word_cutoff'] if args.get('word_cutoff') is not None else DEFAULT_WORD_CUTOFF data = [x for doc in docs for x in Dataset.load_doc(doc)] charvocab = CharVocab(data, args['shorthand']) wordvocab = WordVocab(data, args['shorthand'], cutoff=cutoff, lower=True) uposvocab = WordVocab(data, args['shorthand'], idx=1) xposvocab = xpos_vocab_factory(data, args['shorthand']) try: featsvocab = FeatureVocab(data, args['shorthand'], idx=3) except ValueError as e: raise ValueError("Unable to build features vocab. 
Please check the Features column of your data for an error which may match the following description.") from e vocab = MultiVocab({'char': charvocab, 'word': wordvocab, 'upos': uposvocab, 'xpos': xposvocab, 'feats': featsvocab}) return vocab def preprocess(self, data, vocab, pretrain_vocab, args): processed = [] for sent in data: processed_sent = DataSample( word = [vocab['word'].map([w[0] for w in sent])], char = [[vocab['char'].map([x for x in w[0]]) for w in sent]], upos = [vocab['upos'].map([w[1] for w in sent])], xpos = [vocab['xpos'].map([w[2] for w in sent])], feats = [vocab['feats'].map([w[3] for w in sent])], pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])] if pretrain_vocab is not None else [[PAD_ID] * len(sent)]), text = [w[0] for w in sent] ) processed.append(processed_sent) return processed def __len__(self): return len(self.data) def __mask(self, upos): """Returns a torch boolean about which elements should be masked out""" # creates all false mask mask = torch.zeros_like(upos, dtype=torch.bool) ### augmentation 1: punctuation augmentation ### # tags that needs to be checked, currently only PUNCT if random.uniform(0,1) < self.augment_nopunct: for i in self.__punct_tags: # generate a mask for the last element last_element = torch.zeros_like(upos, dtype=torch.bool) last_element[..., -1] = True # we or the bitmask against the existing mask # if it satisfies, we remove the word by masking it # to true # # if your input is just a lone punctuation, we perform # no masking if not torch.all(upos.eq(torch.tensor([[i]]))): mask |= ((upos == i) & (last_element)) return mask def __getitem__(self, key): """Retrieves a sample from the dataset. Retrieves a sample from the dataset. This function, for the most part, is spent performing ad-hoc data augmentation and restoration. 
It receives a DataSample object from the storage, and returns an almost-identical DataSample object that may have been augmented with /possibly/ (depending on augment_punct settings) PUNCT chopped. **Important Note** ------------------ If you would like to load the data into a model, please convert this Dataset object into a DataLoader via self.to_loader(). Then, you can use the resulting object like any other PyTorch data loader. As masks are calculated ad-hoc given the batch, the samples returned from this object doesn't have the appropriate masking. Motivation ---------- Why is this here? Every time you call next(iter(dataloader)), it calls this function. Therefore, if we augmented each sample on each iteration, the model will see dynamically generated augmentation. Furthermore, PyTorch dataloader handles shuffling natively. Parameters ---------- key : int the integer ID to from which to retrieve the key. Returns ------- DataSample The sample of data you requested, with augmentation. """ # get a sample of the input data sample = self.data[key] # some data augmentation requires constructing a mask based on upos. # For instance, sometimes we'd like to mask out ending sentence punctuation. # We copy the other items here so that any edits made because # of the mask don't clobber the version owned by the Dataset # convert to tensors # TODO: only store single lists per data entry? words = torch.tensor(sample.word[0]) # convert the rest to tensors upos = torch.tensor(sample.upos[0]) if self.has_upos else None xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None pretrained = torch.tensor(sample.pretrain[0]) # and deal with char & raw_text char = sample.char[0] raw_text = sample.text # some data augmentation requires constructing a mask based on # which upos. For instance, sometimes we'd like to mask out ending # sentence punctuation. 
The mask is True if we want to remove the element if self.has_upos and upos is not None and not self.eval: # perform actual masking mask = self.__mask(upos) else: # dummy mask that's all false mask = None if mask is not None: mask_index = mask.nonzero() # mask out the elements that we need to mask out for mask in mask_index: mask = mask.item() words[mask] = PAD_ID if upos is not None: upos[mask] = PAD_ID if xpos is not None: # TODO: test the multi-dimension xpos xpos[mask, ...] = PAD_ID if ufeats is not None: ufeats[mask, ...] = PAD_ID pretrained[mask] = PAD_ID char = char[:mask] + char[mask+1:] raw_text = raw_text[:mask] + raw_text[mask+1:] # get each character from the input sentnece # chars = [w for sent in char for w in sent] return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) def to_loader(self, **kwargs): """Converts self to a DataLoader """ return DL(self, collate_fn=Dataset.__collate_fn, **kwargs) def to_length_limited_loader(self, batch_size, maximum_tokens): sampler = LengthLimitedBatchSampler(self, batch_size, maximum_tokens) return DL(self, collate_fn=Dataset.__collate_fn, batch_sampler = sampler) @staticmethod def __collate_fn(data): """Function used by DataLoader to pack data""" (data, idx) = zip(*data) (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data) # collate_fn is given a list of length batch size batch_size = len(data) # sort sentences by lens for easy RNN operations lens = [torch.sum(x != PAD_ID) for x in words] (words, wordchars, upos, xpos, ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos, ufeats, pretrained, text), lens) lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN # combine all words into one large list, and sort for easy charRNN ops wordchars = [w for sent in wordchars for w in sent] word_lens = [len(x) for x in wordchars] (wordchars,), 
word_orig_idx = sort_all([wordchars], word_lens) word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN # We now pad everything words = pad_sequence(words, True, PAD_ID) if None not in upos: upos = pad_sequence(upos, True, PAD_ID) else: upos = None if None not in xpos: xpos = pad_sequence(xpos, True, PAD_ID) else: xpos = None if None not in ufeats: ufeats = pad_sequence(ufeats, True, PAD_ID) else: ufeats = None pretrained = pad_sequence(pretrained, True, PAD_ID) wordchars = get_long_tensor(wordchars, len(word_lens)) # and finally create masks for the padding indices words_mask = torch.eq(words, PAD_ID) wordchars_mask = torch.eq(wordchars, PAD_ID) return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx) @staticmethod def load_doc(doc): data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True) data = Dataset.resolve_none(data) data = simplify_punct(data) return data @staticmethod def resolve_none(data): # replace None to '_' for sent_idx in range(len(data)): for tok_idx in range(len(data[sent_idx])): for feat_idx in range(len(data[sent_idx][tok_idx])): if data[sent_idx][tok_idx][feat_idx] is None: data[sent_idx][tok_idx][feat_idx] = '_' return data class LengthLimitedBatchSampler(Sampler): """ Batches up the text in batches of batch_size, but cuts off each time a batch reaches maximum_tokens Intent is to avoid GPU OOM in situations where one sentence is significantly longer than expected, leaving a batch too large to fit in the GPU Sentences which are longer than maximum_tokens by themselves are put in their own batches """ def __init__(self, data, batch_size, maximum_tokens): """ Precalculate the batches, making it so len and iter just read off the precalculated batches """ self.data = data self.batch_size = batch_size self.maximum_tokens = maximum_tokens self.batches = [] current_batch = [] current_length = 0 for item, item_idx in data: 
item_len = len(item.word) if maximum_tokens and item_len > maximum_tokens: if len(current_batch) > 0: self.batches.append(current_batch) current_batch = [] current_length = 0 self.batches.append([item_idx]) continue if len(current_batch) + 1 > batch_size or (maximum_tokens and item_len + current_length > maximum_tokens): self.batches.append(current_batch) current_batch = [] current_length = 0 current_batch.append(item_idx) current_length += item_len if len(current_batch) > 0: self.batches.append(current_batch) def __len__(self): return len(self.batches) def __iter__(self): for batch in self.batches: current_batch = [] for idx in batch: current_batch.append(idx) yield current_batch class ShuffledDataset: """A wrapper around one or more datasets which shuffles the data in batch_size chunks This means that if multiple datasets are passed in, the batches from each dataset are shuffled together, with one batch being entirely members of the same dataset. The main use case of this is that in the tagger, there are cases where batches from different datasets will have different properties, such as having or not having UPOS tags. We found that it is actually somewhat tricky to make the model's loss function (in model.py) properly represent batches with mixed w/ and w/o property, whereas keeping one entire batch together makes it a lot easier to process. The mechanism for the shuffling is that the iterator first makes a list long enough to represent each batch from each dataset, tracking the index of the dataset it is coming from, then shuffles that list. Another alternative would be to use a weighted randomization approach, but this is very simple and the memory requirements are not too onerous. Note that the batch indices are wasteful in the case of only one underlying dataset, which is actually the most common use case, but the overhead is small enough that it probably isn't worth special casing the one dataset version. 
""" def __init__(self, datasets, batch_size): self.batch_size = batch_size self.datasets = datasets self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets] def __iter__(self): iterators = [iter(x) for x in self.loaders] lengths = [len(x) for x in self.loaders] indices = [[x] * y for x, y in enumerate(lengths)] indices = [idx for inner in indices for idx in inner] random.shuffle(indices) for idx in indices: yield(next(iterators[idx])) def __len__(self): return sum(len(x) for x in self.datasets) ================================================ FILE: stanza/models/pos/model.py ================================================ import logging import os import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence from stanza.models.common.bert_embedding import extract_bert_embeddings from stanza.models.common.biaffine import BiaffineScorer from stanza.models.common.foundation_cache import load_bert, load_charlm from stanza.models.common.hlstm import HighwayLSTM from stanza.models.common.dropout import WordDropout from stanza.models.common.utils import attach_bert_model from stanza.models.common.vocab import CompositeVocab from stanza.models.common.char_model import CharacterModel from stanza.models.common import utils logger = logging.getLogger('stanza') class Tagger(nn.Module): def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None): super().__init__() self.vocab = vocab self.args = args self.share_hid = share_hid self.unsaved_modules = [] # input layers input_size = 0 if self.args['word_emb_dim'] > 0: # frequent word embeddings self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0) input_size += self.args['word_emb_dim'] if not share_hid: # upos 
embeddings self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0) if self.args['char'] and self.args['char_emb_dim'] > 0: if self.args.get('charlm', None): if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']): raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file'])) if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']): raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file'])) logger.debug("POS model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file']) self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)) self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache)) # optionally add a input transformation layer if self.args.get('charlm_transform_dim', 0): self.charmodel_forward_transform = nn.Linear(self.charmodel_forward.hidden_dim(), self.args['charlm_transform_dim'], bias=False) self.charmodel_backward_transform = nn.Linear(self.charmodel_backward.hidden_dim(), self.args['charlm_transform_dim'], bias=False) input_size += self.args['charlm_transform_dim'] * 2 else: self.charmodel_forward_transform = None self.charmodel_backward_transform = None input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim() else: bidirectional = args.get('char_bidirectional', False) self.charmodel = CharacterModel(args, vocab, bidirectional=bidirectional) if bidirectional: self.trans_char = nn.Linear(self.args['char_hidden_dim'] * 2, self.args['transformed_dim'], bias=False) else: self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False) input_size += 
self.args['transformed_dim'] self.peft_name = peft_name attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved) if self.args.get('bert_model', None): # TODO: refactor bert_hidden_layers between the different models if args.get('bert_hidden_layers', False): # The average will be offset by 1/N so that the default zeros # represents an average of the N layers self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False) nn.init.zeros_(self.bert_layer_mix.weight) else: # an average of layers 2, 3, 4 will be used # (for historic reasons) self.bert_layer_mix = None input_size += self.bert_model.config.hidden_size if self.args['pretrain']: # pretrained embeddings, by default this won't be saved into model file self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True)) self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False) input_size += self.args['transformed_dim'] # recurrent layers self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh) self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size)) self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim'])) self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim'])) # classifiers self.upos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim']) self.upos_clf = nn.Linear(self.args['deep_biaff_hidden_dim'], len(vocab['upos'])) self.upos_clf.weight.data.zero_() self.upos_clf.bias.data.zero_() if share_hid: clf_constructor = lambda insize, outsize: nn.Linear(insize, outsize) else: self.xpos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'] if not 
isinstance(vocab['xpos'], CompositeVocab) else self.args['composite_deep_biaff_hidden_dim']) self.ufeats_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['composite_deep_biaff_hidden_dim']) clf_constructor = lambda insize, outsize: BiaffineScorer(insize, self.args['tag_emb_dim'], outsize) if isinstance(vocab['xpos'], CompositeVocab): self.xpos_clf = nn.ModuleList() for l in vocab['xpos'].lens(): self.xpos_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l)) else: self.xpos_clf = clf_constructor(self.args['deep_biaff_hidden_dim'], len(vocab['xpos'])) if share_hid: self.xpos_clf.weight.data.zero_() self.xpos_clf.bias.data.zero_() self.ufeats_clf = nn.ModuleList() for l in vocab['feats'].lens(): if share_hid: self.ufeats_clf.append(clf_constructor(self.args['deep_biaff_hidden_dim'], l)) self.ufeats_clf[-1].weight.data.zero_() self.ufeats_clf[-1].bias.data.zero_() else: self.ufeats_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l)) # criterion self.crit = nn.CrossEntropyLoss(ignore_index=0) # ignore padding self.drop = nn.Dropout(args['dropout']) self.worddrop = WordDropout(args['word_dropout']) def add_unsaved_module(self, name, module): self.unsaved_modules += [name] setattr(self, name, module) def log_norms(self): utils.log_norms(self) def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text): def pack(x): return pack_padded_sequence(x, sentlens, batch_first=True) inputs = [] if self.args['word_emb_dim'] > 0: word_emb = self.word_emb(word) word_emb = pack(word_emb) inputs += [word_emb] if self.args['pretrain']: pretrained_emb = self.pretrained_emb(pretrained) pretrained_emb = self.trans_pretrained(pretrained_emb) pretrained_emb = pack(pretrained_emb) inputs += [pretrained_emb] def pad(x): return pad_packed_sequence(PackedSequence(x, inputs[0].batch_sizes), batch_first=True)[0] if self.args['char'] and self.args['char_emb_dim'] > 0: if 
self.args.get('charlm', None): all_forward_chars = self.charmodel_forward.build_char_representation(text) assert isinstance(all_forward_chars, list) if self.charmodel_forward_transform is not None: all_forward_chars = [self.charmodel_forward_transform(x) for x in all_forward_chars] all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True)) all_backward_chars = self.charmodel_backward.build_char_representation(text) if self.charmodel_backward_transform is not None: all_backward_chars = [self.charmodel_backward_transform(x) for x in all_backward_chars] all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True)) inputs += [all_forward_chars, all_backward_chars] else: char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens) char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes) inputs += [char_reps] if self.bert_model is not None: device = next(self.parameters()).device processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=False, num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None, detach=not self.args.get('bert_finetune', False) or not self.training, peft_name=self.peft_name) if self.bert_layer_mix is not None: # add the average so that the default behavior is to # take an average of the N layers, and anything else # other than that needs to be learned # TODO: refactor this processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert] processed_bert = pad_sequence(processed_bert, batch_first=True) inputs += [pack(processed_bert)] lstm_inputs = torch.cat([x.data for x in inputs], 1) lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement) lstm_inputs = self.drop(lstm_inputs) lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes) lstm_outputs, _ = 
self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous())) lstm_outputs = lstm_outputs.data upos_hid = F.relu(self.upos_hid(self.drop(lstm_outputs))) upos_pred = self.upos_clf(self.drop(upos_hid)) preds = [pad(upos_pred).max(2)[1]] if upos is not None: upos = pack(upos).data loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1)) else: loss = 0.0 if self.share_hid: xpos_hid = upos_hid ufeats_hid = upos_hid clffunc = lambda clf, hid: clf(self.drop(hid)) else: xpos_hid = F.relu(self.xpos_hid(self.drop(lstm_outputs))) ufeats_hid = F.relu(self.ufeats_hid(self.drop(lstm_outputs))) if self.training and upos is not None: upos_emb = self.upos_emb(upos) else: upos_emb = self.upos_emb(upos_pred.max(1)[1]) clffunc = lambda clf, hid: clf(self.drop(hid), self.drop(upos_emb)) if xpos is not None: xpos = pack(xpos).data if isinstance(self.vocab['xpos'], CompositeVocab): xpos_preds = [] for i in range(len(self.vocab['xpos'])): xpos_pred = clffunc(self.xpos_clf[i], xpos_hid) if xpos is not None: loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1)) xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1]) preds.append(torch.cat(xpos_preds, 2)) else: xpos_pred = clffunc(self.xpos_clf, xpos_hid) if xpos is not None: loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1)) preds.append(pad(xpos_pred).max(2)[1]) ufeats_preds = [] if ufeats is not None: ufeats = pack(ufeats).data for i in range(len(self.vocab['feats'])): ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid) if ufeats is not None: loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1)) ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1]) preds.append(torch.cat(ufeats_preds, 2)) return loss, preds 
================================================ FILE: stanza/models/pos/scorer.py ================================================ """ Utils and wrappers for scoring taggers. """ import logging from stanza.models.common.utils import ud_scores logger = logging.getLogger('stanza') def score(system_conllu_file, gold_conllu_file, verbose=True, eval_type='AllTags'): """ Wrapper for tagger scorer. """ evaluation = ud_scores(gold_conllu_file, system_conllu_file) el = evaluation[eval_type] p = el.precision r = el.recall f = el.f1 if verbose: scores = [evaluation[k].f1 * 100 for k in ['UPOS', 'XPOS', 'UFeats', 'AllTags']] logger.info("UPOS\tXPOS\tUFeats\tAllTags") logger.info("{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}".format(*scores)) return p, r, f ================================================ FILE: stanza/models/pos/trainer.py ================================================ """ A trainer class to handle training and testing of models. """ import sys import logging import torch from torch import nn from stanza.models.common.trainer import Trainer as BaseTrainer from stanza.models.common import utils, loss from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper from stanza.models.pos.model import Tagger from stanza.models.pos.vocab import MultiVocab logger = logging.getLogger('stanza') def unpack_batch(batch, device): """ Unpack a batch from the data loader. """ inputs = [b.to(device) if b is not None else None for b in batch[:8]] orig_idx = batch[8] word_orig_idx = batch[9] sentlens = batch[10] wordlens = batch[11] text = batch[12] return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text class Trainer(BaseTrainer): """ A trainer for training models. 
""" def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None): if model_file is not None: # load everything from file self.load(model_file, pretrain, args=args, foundation_cache=foundation_cache) else: # build model from scratch self.args = args self.vocab = vocab bert_model, bert_tokenizer = load_bert(self.args['bert_model']) peft_name = None if self.args['use_peft']: # fine tune the bert if we're using peft self.args['bert_finetune'] = True peft_name = "pos" bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name) self.model = Tagger(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, share_hid=args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name) self.model = self.model.to(device) self.optimizers = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, weight_decay=self.args.get('initial_weight_decay', None), bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("peft", False)) self.schedulers = {} if self.args.get('bert_finetune', None): import transformers warmup_scheduler = transformers.get_linear_schedule_with_warmup( self.optimizers["bert_optimizer"], # todo late starting? 
0, self.args["max_steps"]) self.schedulers["bert_scheduler"] = warmup_scheduler def update(self, batch, eval=False): device = next(self.model.parameters()).device inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device) word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs if eval: self.model.eval() else: self.model.train() for optimizer in self.optimizers.values(): optimizer.zero_grad() loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text) if loss == 0.0: return loss loss_val = loss.data.item() if eval: return loss_val loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) for optimizer in self.optimizers.values(): optimizer.step() for scheduler in self.schedulers.values(): scheduler.step() return loss_val def predict(self, batch, unsort=True): device = next(self.model.parameters()).device inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device) word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs self.model.eval() batch_size = word.size(0) _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text) upos_seqs = [self.vocab['upos'].unmap(sent) for sent in preds[0].tolist()] xpos_seqs = [self.vocab['xpos'].unmap(sent) for sent in preds[1].tolist()] feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[2].tolist()] pred_tokens = [[[upos_seqs[i][j], xpos_seqs[i][j], feats_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)] if unsort: pred_tokens = utils.unsort(pred_tokens, orig_idx) return pred_tokens def save(self, filename, skip_modules=True): model_state = self.model.state_dict() # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file if skip_modules: skipped = 
[k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] for k in skipped: del model_state[k] params = { 'model': model_state, 'vocab': self.vocab.state_dict(), 'config': self.args } if self.args.get('use_peft', False): # Hide import so that peft dependency is optional from peft import get_peft_model_state_dict params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name) try: torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) except (KeyboardInterrupt, SystemExit): raise except Exception as e: logger.warning(f"Saving failed... {e} continuing anyway.") def load(self, filename, pretrain, args=None, foundation_cache=None): """ Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input, and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args. """ try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] if args is not None: self.args.update(args) # preserve old models which were created before transformers were added if 'bert_model' not in self.args: self.args['bert_model'] = None lora_weights = checkpoint.get('bert_lora') if lora_weights: logger.debug("Found peft weights for POS; loading a peft adapter") self.args["use_peft"] = True # TODO: refactor this common block of code with NER force_bert_saved = False peft_name = None if self.args.get('use_peft', False): force_bert_saved = True bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "pos", foundation_cache) bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name) logger.debug("Loaded peft with name %s", peft_name) else: if any(x.startswith("bert_model.") for x in 
checkpoint['model'].keys()): logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename) foundation_cache = NoTransformerFoundationCache(foundation_cache) force_bert_saved = True bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache) self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) # load model emb_matrix = None if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None emb_matrix = pretrain.emb if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()): logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename) foundation_cache = NoTransformerFoundationCache(foundation_cache) self.model = Tagger(self.args, self.vocab, emb_matrix=emb_matrix, share_hid=self.args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name) self.model.load_state_dict(checkpoint['model'], strict=False) ================================================ FILE: stanza/models/pos/vocab.py ================================================ from collections import Counter, OrderedDict from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab from stanza.models.common.vocab import CompositeVocab, VOCAB_PREFIX, EMPTY, EMPTY_ID class WordVocab(BaseVocab): def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False, ignore=None): self.ignore = ignore if ignore is not None else [] super().__init__(data, lang=lang, idx=idx, cutoff=cutoff, lower=lower) self.state_attrs += ['ignore'] def id2unit(self, id): if len(self.ignore) > 0 and id == EMPTY_ID: return '_' else: return super().id2unit(id) def unit2id(self, unit): if 
len(self.ignore) > 0 and unit in self.ignore: return self._unit2id[EMPTY] else: return super().unit2id(unit) def build_vocab(self): if self.lower: counter = Counter([w[self.idx].lower() for sent in self.data for w in sent]) else: counter = Counter([w[self.idx] for sent in self.data for w in sent]) for k in list(counter.keys()): if counter[k] < self.cutoff or k in self.ignore: del counter[k] self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True)) self._unit2id = {w:i for i, w in enumerate(self._id2unit)} def __iter__(self): # the EMPTY shenanigans above make list() look really weird # when using the __len__ / __getitem__ paradigm, # but yielding items like this works fine for x in self._id2unit: yield x def __str__(self): return "<{}: {}>".format(type(self), ",".join("|%s|" % x for x in self._id2unit)) class XPOSVocab(CompositeVocab): def __init__(self, data=None, lang="", idx=0, sep="", keyed=False): super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed) class FeatureVocab(CompositeVocab): def __init__(self, data=None, lang="", idx=0, sep="|", keyed=True): super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed) class MultiVocab(BaseMultiVocab): def state_dict(self): """ Also save a vocab name to class name mapping in state dict. """ state = OrderedDict() key2class = OrderedDict() for k, v in self._vocabs.items(): state[k] = v.state_dict() key2class[k] = type(v).__name__ state['_key2class'] = key2class return state @classmethod def load_state_dict(cls, state_dict): class_dict = {'CharVocab': CharVocab, 'WordVocab': WordVocab, 'XPOSVocab': XPOSVocab, 'FeatureVocab': FeatureVocab} new = cls() assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!" 
key2class = state_dict['_key2class'] for k,v in state_dict.items(): if k == '_key2class': continue classname = key2class[k] new[k] = class_dict[classname].load_state_dict(v) return new ================================================ FILE: stanza/models/pos/xpos_vocab_factory.py ================================================ # This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory. # Please don't edit it! import logging from stanza.models.pos.vocab import WordVocab, XPOSVocab from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory # using a sublogger makes it easier to test in the unittests logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory') XPOS_DESCRIPTIONS = { 'af_afribooms' : XPOSDescription(XPOSType.XPOS, ''), 'ar_padt' : XPOSDescription(XPOSType.XPOS, ''), 'bg_btb' : XPOSDescription(XPOSType.XPOS, ''), 'ca_ancora' : XPOSDescription(XPOSType.XPOS, ''), 'cs_cac' : XPOSDescription(XPOSType.XPOS, ''), 'cs_cltt' : XPOSDescription(XPOSType.XPOS, ''), 'cs_fictree' : XPOSDescription(XPOSType.XPOS, ''), 'cs_pdt' : XPOSDescription(XPOSType.XPOS, ''), 'en_partut' : XPOSDescription(XPOSType.XPOS, ''), 'es_ancora' : XPOSDescription(XPOSType.XPOS, ''), 'es_combined' : XPOSDescription(XPOSType.XPOS, ''), 'fr_partut' : XPOSDescription(XPOSType.XPOS, ''), 'gd_arcosg' : XPOSDescription(XPOSType.XPOS, ''), 'gl_ctg' : XPOSDescription(XPOSType.XPOS, ''), 'gl_treegal' : XPOSDescription(XPOSType.XPOS, ''), 'grc_perseus' : XPOSDescription(XPOSType.XPOS, ''), 'hr_set' : XPOSDescription(XPOSType.XPOS, ''), 'is_gc' : XPOSDescription(XPOSType.XPOS, ''), 'is_icepahc' : XPOSDescription(XPOSType.XPOS, ''), 'is_modern' : XPOSDescription(XPOSType.XPOS, ''), 'it_combined' : XPOSDescription(XPOSType.XPOS, ''), 'it_isdt' : XPOSDescription(XPOSType.XPOS, ''), 'it_markit' : XPOSDescription(XPOSType.XPOS, ''), 'it_parlamint' : XPOSDescription(XPOSType.XPOS, ''), 'it_partut' 
: XPOSDescription(XPOSType.XPOS, ''), 'it_postwita' : XPOSDescription(XPOSType.XPOS, ''), 'it_twittiro' : XPOSDescription(XPOSType.XPOS, ''), 'it_vit' : XPOSDescription(XPOSType.XPOS, ''), 'la_perseus' : XPOSDescription(XPOSType.XPOS, ''), 'la_udante' : XPOSDescription(XPOSType.XPOS, ''), 'lt_alksnis' : XPOSDescription(XPOSType.XPOS, ''), 'lv_lvtb' : XPOSDescription(XPOSType.XPOS, ''), 'ro_nonstandard' : XPOSDescription(XPOSType.XPOS, ''), 'ro_rrt' : XPOSDescription(XPOSType.XPOS, ''), 'ro_simonero' : XPOSDescription(XPOSType.XPOS, ''), 'sk_snk' : XPOSDescription(XPOSType.XPOS, ''), 'sl_ssj' : XPOSDescription(XPOSType.XPOS, ''), 'sl_sst' : XPOSDescription(XPOSType.XPOS, ''), 'sr_set' : XPOSDescription(XPOSType.XPOS, ''), 'ta_ttb' : XPOSDescription(XPOSType.XPOS, ''), 'uk_iu' : XPOSDescription(XPOSType.XPOS, ''), 'be_hse' : XPOSDescription(XPOSType.WORD, None), 'bxr_bdt' : XPOSDescription(XPOSType.WORD, None), 'cop_scriptorium': XPOSDescription(XPOSType.WORD, None), 'cu_proiel' : XPOSDescription(XPOSType.WORD, None), 'cy_ccg' : XPOSDescription(XPOSType.WORD, None), 'da_ddt' : XPOSDescription(XPOSType.WORD, None), 'de_gsd' : XPOSDescription(XPOSType.WORD, None), 'de_hdt' : XPOSDescription(XPOSType.WORD, None), 'el_gdt' : XPOSDescription(XPOSType.WORD, None), 'el_gud' : XPOSDescription(XPOSType.WORD, None), 'en_atis' : XPOSDescription(XPOSType.WORD, None), 'en_combined' : XPOSDescription(XPOSType.WORD, None), 'en_craft' : XPOSDescription(XPOSType.WORD, None), 'en_eslspok' : XPOSDescription(XPOSType.WORD, None), 'en_ewt' : XPOSDescription(XPOSType.WORD, None), 'en_genia' : XPOSDescription(XPOSType.WORD, None), 'en_gum' : XPOSDescription(XPOSType.WORD, None), 'en_gumreddit' : XPOSDescription(XPOSType.WORD, None), 'en_mimic' : XPOSDescription(XPOSType.WORD, None), 'en_test' : XPOSDescription(XPOSType.WORD, None), 'es_gsd' : XPOSDescription(XPOSType.WORD, None), 'et_edt' : XPOSDescription(XPOSType.WORD, None), 'et_ewt' : XPOSDescription(XPOSType.WORD, None), 'eu_bdt' : 
XPOSDescription(XPOSType.WORD, None), 'fa_perdt' : XPOSDescription(XPOSType.WORD, None), 'fa_seraji' : XPOSDescription(XPOSType.WORD, None), 'fi_tdt' : XPOSDescription(XPOSType.WORD, None), 'fr_combined' : XPOSDescription(XPOSType.WORD, None), 'fr_gsd' : XPOSDescription(XPOSType.WORD, None), 'fr_parisstories': XPOSDescription(XPOSType.WORD, None), 'fr_rhapsodie' : XPOSDescription(XPOSType.WORD, None), 'fr_sequoia' : XPOSDescription(XPOSType.WORD, None), 'fro_profiterole': XPOSDescription(XPOSType.WORD, None), 'ga_idt' : XPOSDescription(XPOSType.WORD, None), 'ga_twittirish' : XPOSDescription(XPOSType.WORD, None), 'got_proiel' : XPOSDescription(XPOSType.WORD, None), 'grc_proiel' : XPOSDescription(XPOSType.WORD, None), 'grc_ptnk' : XPOSDescription(XPOSType.WORD, None), 'gv_cadhan' : XPOSDescription(XPOSType.WORD, None), 'hbo_ptnk' : XPOSDescription(XPOSType.WORD, None), 'he_combined' : XPOSDescription(XPOSType.WORD, None), 'he_htb' : XPOSDescription(XPOSType.WORD, None), 'he_iahltknesset': XPOSDescription(XPOSType.WORD, None), 'he_iahltwiki' : XPOSDescription(XPOSType.WORD, None), 'hi_hdtb' : XPOSDescription(XPOSType.WORD, None), 'hsb_ufal' : XPOSDescription(XPOSType.WORD, None), 'hu_szeged' : XPOSDescription(XPOSType.WORD, None), 'hy_armtdp' : XPOSDescription(XPOSType.WORD, None), 'hy_bsut' : XPOSDescription(XPOSType.WORD, None), 'hyw_armtdp' : XPOSDescription(XPOSType.WORD, None), 'id_csui' : XPOSDescription(XPOSType.WORD, None), 'it_old' : XPOSDescription(XPOSType.WORD, None), 'ka_glc' : XPOSDescription(XPOSType.WORD, None), 'kk_ktb' : XPOSDescription(XPOSType.WORD, None), 'kmr_mg' : XPOSDescription(XPOSType.WORD, None), 'kpv_lattice' : XPOSDescription(XPOSType.WORD, None), 'ky_ktmu' : XPOSDescription(XPOSType.WORD, None), 'la_proiel' : XPOSDescription(XPOSType.WORD, None), 'lij_glt' : XPOSDescription(XPOSType.WORD, None), 'lt_hse' : XPOSDescription(XPOSType.WORD, None), 'lzh_kyoto' : XPOSDescription(XPOSType.WORD, None), 'mr_ufal' : XPOSDescription(XPOSType.WORD, 
None), 'mt_mudt' : XPOSDescription(XPOSType.WORD, None), 'myv_jr' : XPOSDescription(XPOSType.WORD, None), 'nb_bokmaal' : XPOSDescription(XPOSType.WORD, None), 'nds_lsdc' : XPOSDescription(XPOSType.WORD, None), 'nn_nynorsk' : XPOSDescription(XPOSType.WORD, None), 'nn_nynorsklia' : XPOSDescription(XPOSType.WORD, None), 'no_bokmaal' : XPOSDescription(XPOSType.WORD, None), 'orv_birchbark' : XPOSDescription(XPOSType.WORD, None), 'orv_rnc' : XPOSDescription(XPOSType.WORD, None), 'orv_torot' : XPOSDescription(XPOSType.WORD, None), 'ota_boun' : XPOSDescription(XPOSType.WORD, None), 'pcm_nsc' : XPOSDescription(XPOSType.WORD, None), 'pt_bosque' : XPOSDescription(XPOSType.WORD, None), 'pt_cintil' : XPOSDescription(XPOSType.WORD, None), 'pt_dantestocks' : XPOSDescription(XPOSType.WORD, None), 'pt_gsd' : XPOSDescription(XPOSType.WORD, None), 'pt_petrogold' : XPOSDescription(XPOSType.WORD, None), 'pt_porttinari' : XPOSDescription(XPOSType.WORD, None), 'qpm_philotis' : XPOSDescription(XPOSType.WORD, None), 'qtd_sagt' : XPOSDescription(XPOSType.WORD, None), 'ru_gsd' : XPOSDescription(XPOSType.WORD, None), 'ru_poetry' : XPOSDescription(XPOSType.WORD, None), 'ru_syntagrus' : XPOSDescription(XPOSType.WORD, None), 'ru_taiga' : XPOSDescription(XPOSType.WORD, None), 'sa_vedic' : XPOSDescription(XPOSType.WORD, None), 'sme_giella' : XPOSDescription(XPOSType.WORD, None), 'swl_sslc' : XPOSDescription(XPOSType.WORD, None), 'sq_staf' : XPOSDescription(XPOSType.WORD, None), 'te_mtg' : XPOSDescription(XPOSType.WORD, None), 'tr_atis' : XPOSDescription(XPOSType.WORD, None), 'tr_boun' : XPOSDescription(XPOSType.WORD, None), 'tr_framenet' : XPOSDescription(XPOSType.WORD, None), 'tr_imst' : XPOSDescription(XPOSType.WORD, None), 'tr_kenet' : XPOSDescription(XPOSType.WORD, None), 'tr_penn' : XPOSDescription(XPOSType.WORD, None), 'tr_tourism' : XPOSDescription(XPOSType.WORD, None), 'ug_udt' : XPOSDescription(XPOSType.WORD, None), 'uk_parlamint' : XPOSDescription(XPOSType.WORD, None), 'vi_vtb' : 
XPOSDescription(XPOSType.WORD, None), 'wo_wtb' : XPOSDescription(XPOSType.WORD, None), 'xcl_caval' : XPOSDescription(XPOSType.WORD, None), 'zh-hans_gsdsimp': XPOSDescription(XPOSType.WORD, None), 'zh-hant_gsd' : XPOSDescription(XPOSType.WORD, None), 'zh_gsdsimp' : XPOSDescription(XPOSType.WORD, None), 'en_lines' : XPOSDescription(XPOSType.XPOS, '-'), 'fo_farpahc' : XPOSDescription(XPOSType.XPOS, '-'), 'ja_gsd' : XPOSDescription(XPOSType.XPOS, '-'), 'ja_gsdluw' : XPOSDescription(XPOSType.XPOS, '-'), 'sv_lines' : XPOSDescription(XPOSType.XPOS, '-'), 'ur_udtb' : XPOSDescription(XPOSType.XPOS, '-'), 'fi_ftb' : XPOSDescription(XPOSType.XPOS, ','), 'orv_ruthenian' : XPOSDescription(XPOSType.XPOS, ','), 'id_gsd' : XPOSDescription(XPOSType.XPOS, '+'), 'ko_gsd' : XPOSDescription(XPOSType.XPOS, '+'), 'ko_kaist' : XPOSDescription(XPOSType.XPOS, '+'), 'ko_ksl' : XPOSDescription(XPOSType.XPOS, '+'), 'qaf_arabizi' : XPOSDescription(XPOSType.XPOS, '+'), 'la_ittb' : XPOSDescription(XPOSType.XPOS, '|'), 'la_llct' : XPOSDescription(XPOSType.XPOS, '|'), 'nl_alpino' : XPOSDescription(XPOSType.XPOS, '|'), 'nl_lassysmall' : XPOSDescription(XPOSType.XPOS, '|'), 'sv_talbanken' : XPOSDescription(XPOSType.XPOS, '|'), 'pl_lfg' : XPOSDescription(XPOSType.XPOS, ':'), 'pl_pdb' : XPOSDescription(XPOSType.XPOS, ':'), } def xpos_vocab_factory(data, shorthand): if shorthand not in XPOS_DESCRIPTIONS: logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand) desc = choose_simplest_factory(data, shorthand) if shorthand in XPOS_DESCRIPTIONS: if XPOS_DESCRIPTIONS[shorthand] != desc: # log instead of throw # otherwise, updating datasets would be unpleasant logger.error("XPOS tagset in %s has apparently changed! 
Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc) else: logger.warning("Chose %s for the xpos factory for %s", desc, shorthand) return build_xpos_vocab(desc, data, shorthand) ================================================ FILE: stanza/models/pos/xpos_vocab_utils.py ================================================ from collections import namedtuple from enum import Enum import logging import os from stanza.models.common.vocab import VOCAB_PREFIX from stanza.models.pos.vocab import XPOSVocab, WordVocab class XPOSType(Enum): XPOS = 1 WORD = 2 XPOSDescription = namedtuple('XPOSDescription', ['xpos_type', 'sep']) DEFAULT_KEY = XPOSDescription(XPOSType.WORD, None) logger = logging.getLogger('stanza') def filter_data(data, idx): data_filtered = [] for sentence in data: flag = True for token in sentence: if token[idx] is None: flag = False if flag: data_filtered.append(sentence) return data_filtered def choose_simplest_factory(data, shorthand): logger.info(f'Original length = {len(data)}') data = filter_data(data, idx=2) logger.info(f'Filtered length = {len(data)}') vocab = WordVocab(data, shorthand, idx=2, ignore=["_"]) key = DEFAULT_KEY best_size = len(vocab) - len(VOCAB_PREFIX) if best_size > 20: for sep in ['', '-', '+', '|', ',', ':']: # separators vocab = XPOSVocab(data, shorthand, idx=2, sep=sep) length = sum(len(x) - len(VOCAB_PREFIX) for x in vocab._id2unit.values()) if length < best_size: key = XPOSDescription(XPOSType.XPOS, sep) best_size = length return key def build_xpos_vocab(description, data, shorthand): if description.xpos_type is XPOSType.WORD: return WordVocab(data, shorthand, idx=2, ignore=["_"]) return XPOSVocab(data, shorthand, idx=2, sep=description.sep) ================================================ FILE: stanza/models/tagger.py ================================================ """ Entry point for training and evaluating a POS/morphological features tagger. 
This tagger uses highway BiLSTM layers with character and word-level representations, and biaffine classifiers to produce consistent POS and UFeats predictions. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. """ import argparse import logging import io import os import time import zipfile import numpy as np import torch from torch import nn, optim from stanza.models.pos.data import Dataset, ShuffledDataset from stanza.models.pos.trainer import Trainer from stanza.models.pos import scorer from stanza.models.common import utils from stanza.models.common import pretrain from stanza.models.common.doc import * from stanza.models.common.foundation_cache import FoundationCache from stanza.models.common.peft_config import add_peft_args, resolve_peft_args from stanza.models import _training_logging from stanza.utils.conll import CoNLL logger = logging.getLogger('stanza') def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='data/pos', help='Root dir for saving models.') parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.') parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.') parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument('--train_file', type=str, default=None, help='Input file for training.') parser.add_argument('--eval_file', type=str, default=None, help='Input file for scoring.') parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.') parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help="Don't score the eval file - perhaps it has no gold labels, for example. 
Cannot be used at training time") parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--lang', type=str, help='Language') parser.add_argument('--shorthand', type=str, help="Treebank shorthand") parser.add_argument('--hidden_dim', type=int, default=200) parser.add_argument('--char_hidden_dim', type=int, default=400) parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400) parser.add_argument('--composite_deep_biaff_hidden_dim', type=int, default=100) parser.add_argument('--word_emb_dim', type=int, default=75, help='Dimension of the finetuned word embedding. Set to 0 to turn off') parser.add_argument('--word_cutoff', type=int, default=None, help='How common a word must be to include it in the finetuned word embedding. If not set, small word vector files will be 0, larger will be %d' % utils.DEFAULT_WORD_CUTOFF) parser.add_argument('--char_emb_dim', type=int, default=100) parser.add_argument('--tag_emb_dim', type=int, default=50) parser.add_argument('--charlm_transform_dim', type=int, default=None, help='Transform the pretrained charlm to this dimension. If not set, no transform is used') parser.add_argument('--transformed_dim', type=int, default=125) parser.add_argument('--num_layers', type=int, default=2) parser.add_argument('--char_num_layers', type=int, default=1) parser.add_argument('--pretrain_max_vocab', type=int, default=250000) parser.add_argument('--word_dropout', type=float, default=0.33) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout") parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout") # TODO: refactor charlm arguments for models which use it? 
parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.") parser.add_argument('--char_bidirectional', dest='char_bidirectional', action='store_true', help="Use a bidirectional version of the non-pretrained charlm. Doesn't help much, makes the models larger") parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.") parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.") parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)") parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert") parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer") parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)') parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)") parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much') parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained 
embeddings.") parser.add_argument('--share_hid', action='store_true', help="Share hidden representations for UPOS, XPOS and UFeats.") parser.set_defaults(share_hid=False) parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamw, adamax, or adadelta. madgrad as an optional dependency') parser.add_argument('--second_optim', type=str, default='amsgrad', help='Optimizer for the second half of training. Default is Adam with AMSGrad') parser.add_argument('--second_optim_reload', default=False, action='store_true', help='Reload the best model instead of continuing from current model if the first optimizer stalls out. This does not seem to help, but might be useful for further experiments') parser.add_argument('--no_second_optim', action='store_const', const=None, dest='second_optim', help="Don't use a second optimizer - only use the first optimizer") parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate') parser.add_argument('--second_lr', type=float, default=None, help='Alternate learning rate for the second optimizer') parser.add_argument('--initial_weight_decay', type=float, default=None, help='Optimizer weight decay for the first optimizer') parser.add_argument('--second_weight_decay', type=float, default=None, help='Optimizer weight decay for the second optimizer') parser.add_argument('--beta2', type=float, default=0.95) parser.add_argument('--max_steps', type=int, default=50000) parser.add_argument('--eval_interval', type=int, default=100) parser.add_argument('--fix_eval_interval', dest='adapt_eval_interval', action='store_false', \ help="Use fixed evaluation interval for all treebanks, otherwise by default the interval will be increased for larger treebanks.") parser.add_argument('--max_steps_before_stop', type=int, default=3000, help='Changes learning method or early terminates after this many steps if the dev scores 
are not improving') parser.add_argument('--batch_size', type=int, default=250) parser.add_argument('--batch_maximum_tokens', type=int, default=5000, help='When run in a Pipeline, limit a batch to this many tokens to help avoid OOM for long sentences') parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.') parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.') parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)') parser.add_argument('--save_dir', type=str, default='saved_models/pos', help='Root dir for saving models.') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tagger.pt", help="File name to save the model") parser.add_argument('--save_each', default=False, action='store_true', help="Save each checkpoint to its own model. Will take up a bunch of space") parser.add_argument('--seed', type=int, default=1234) add_peft_args(parser) utils.add_device_args(parser) parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct. Default of None will aim for roughly 50%%') parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. 
Will default to the dataset short name') return parser def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) resolve_peft_args(args, logger) if args.augment_nopunct is None: args.augment_nopunct = 0.25 if args.wandb_name: args.wandb = True if not args.share_hid and args.tag_emb_dim == 0: raise ValueError("Cannot have tag_emb_dim==0 with share_hid==False, as the tags will be embedded for the next layer") args = vars(args) return args def main(args=None): args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running tagger in {} mode".format(args['mode'])) if args['mode'] == 'train': return train(args) else: return evaluate(args) def model_file_name(args): return utils.standard_model_file_name(args, "tagger") def save_each_file_name(args): model_file = model_file_name(args) pieces = os.path.splitext(model_file) return pieces[0] + "_%05d" + pieces[1] def load_pretrain(args): pt = None if args['pretrain']: pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang']) if os.path.exists(pretrain_file): vec_file = None else: vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand']) pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab']) return pt def get_eval_type(dev_batch): """ If there is only one column to score in the dev set, use that instead of AllTags """ if dev_batch.has_xpos and not dev_batch.has_upos and not dev_batch.has_feats: return "XPOS" elif dev_batch.has_upos and not dev_batch.has_xpos and not dev_batch.has_feats: return "UPOS" else: return "AllTags" def load_training_data(args, pretrain): train_docs = [] raw_train_files = args['train_file'].split(";") train_files = [] for train_file in raw_train_files: if zipfile.is_zipfile(train_file): logger.info("Decompressing %s" % train_file) with zipfile.ZipFile(train_file) as zin: for zipped_train_file in 
zin.namelist(): with zin.open(zipped_train_file) as fin: logger.info("Reading %s from %s" % (zipped_train_file, train_file)) train_str = fin.read() train_str = train_str.decode("utf-8") train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str) logger.info("Train File {} from {}, Data Size: {}".format(zipped_train_file, train_file, len(train_file_data))) train_docs.append(Document(train_file_data)) train_files.append("%s %s" % (train_file, zipped_train_file)) else: logger.info("Reading %s" % train_file) # train_data is now a list of sentences, where each sentence is a # list of words, in which each word is a dict of conll attributes train_file_data, _, _ = CoNLL.conll2dict(input_file=train_file) logger.info("Train File {}, Data Size: {}".format(train_file, len(train_file_data))) train_docs.append(Document(train_file_data)) train_files.append(train_file) if sum(len(x.sentences) for x in train_docs) == 0: raise RuntimeError("Training data for the tagger is empty: %s" % args['train_file']) # we want to ensure that the model is able te output _ for empty columns, # but create batches whereby if a doc has upos/xpos tags we include them all. # therefore, we create separate datasets and loaders for each input training file, # which will ensure the system be able to see batches with both upos available # and upos unavailable depending on what the availability in the file is. 
vocab = Dataset.init_vocab(train_docs, args) train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False) for i in train_docs] for train_file, td in zip(train_files, train_data): if not td.has_upos: logger.info("No UPOS in %s" % train_file) if not td.has_xpos: logger.info("No XPOS in %s" % train_file) if not td.has_feats: logger.info("No feats in %s" % train_file) # reject partially tagged upos or xpos documents # otherwise, the model will learn to output blanks for some words, # which is probably a confusing result # (and definitely throws off the depparse) # another option would be to treat those as masked out for td_idx, td in enumerate(train_data): if td.has_upos: upos_data = td.doc.get(UPOS, as_sentences=True) for sentence_idx, sentence in enumerate(upos_data): for word_idx, upos in enumerate(sentence): if upos == '_' or upos is None: conll = "{:C}".format(td.doc.sentences[sentence_idx]) raise RuntimeError("Found a blank tag in the UPOS at sentence %d word %d of %s.\n%s" % ((sentence_idx+1), (word_idx+1), train_files[td_idx], conll)) # here we make sure the model will learn to output _ for empty columns # if *any* dataset has data for the upos, xpos, or feature column, # we consider that data enough to train the model on that column # otherwise, we want to train the model to always output blanks if not any(td.has_upos for td in train_data): for td in train_data: td.has_upos = True if not any(td.has_xpos for td in train_data): for td in train_data: td.has_xpos = True if not any(td.has_feats for td in train_data): for td in train_data: td.has_feats = True # calculate the batches train_batches = ShuffledDataset(train_data, args["batch_size"]) return vocab, train_data, train_batches def train(args): model_file = model_file_name(args) utils.ensure_dir(os.path.split(model_file)[0]) if args['save_each']: # so models.pt -> models_0001.pt, etc model_save_each_file = save_each_file_name(args) logger.info("Saving each checkpoint to %s" % model_save_each_file) 
# load pretrained vectors if needed pretrain = load_pretrain(args) args['word_cutoff'] = utils.update_word_cutoff(pretrain, args['word_cutoff']) if args['charlm']: if args['charlm_shorthand'] is None: raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...") logger.info('Using pretrained contextualized char embedding') if not args['charlm_forward_file']: args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand']) if not args['charlm_backward_file']: args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand']) # load data logger.info("Loading data with batch size {}...".format(args['batch_size'])) vocab, train_data, train_batches = load_training_data(args, pretrain) dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True) dev_batch = dev_data.to_loader(batch_size=args["batch_size"]) eval_type = get_eval_type(dev_data) # skip training if the language does not have training or dev data # sum(...) 
to check if all of the training files are empty if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0: logger.info("Skip training because no data available...") return None, None if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tagger" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('train_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') logger.info("Training tagger...") foundation_cache = FoundationCache() trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], foundation_cache=foundation_cache) global_step = 0 max_steps = args['max_steps'] dev_score_history = [] best_dev_preds = [] current_lr = args['lr'] global_start_time = time.time() format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}' logger.debug("Training model on device %s", next(trainer.model.parameters()).device) if args['adapt_eval_interval']: args['eval_interval'] = utils.get_adaptive_eval_interval(dev_data.num_examples, 2000, args['eval_interval']) logger.info("Evaluating the model every {} steps...".format(args['eval_interval'])) if args['save_each']: logger.info("Saving initial checkpoint to %s" % (model_save_each_file % global_step)) trainer.save(model_save_each_file % global_step) using_amsgrad = False last_best_step = 0 # start training train_loss = 0 if args['log_norms']: trainer.model.log_norms() while True: do_break = False for i, batch in enumerate(train_batches): start_time = time.time() global_step += 1 loss = trainer.update(batch, eval=False) # update step train_loss += loss if global_step % args['log_step'] == 0: duration = time.time() - start_time logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr)) if args['log_norms']: trainer.model.log_norms() if global_step % args['eval_interval'] == 0: # eval on dev logger.info("Evaluating on dev set...") dev_preds = [] indices = [] for batch in 
dev_batch: preds = trainer.predict(batch) dev_preds += preds indices.extend(batch[-1]) dev_preds = utils.unsort(dev_preds, indices) dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x]) system_pred_file = "{:C}\n\n".format(dev_data.doc) system_pred_file = io.StringIO(system_pred_file) _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type) train_loss = train_loss / args['eval_interval'] # avg loss per batch logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score)) if args['wandb']: wandb.log({'train_loss': train_loss, 'dev_score': dev_score}) train_loss = 0 if args['save_each']: logger.info("Saving checkpoint to %s" % (model_save_each_file % global_step)) trainer.save(model_save_each_file % global_step) # save best model if len(dev_score_history) == 0 or dev_score > max(dev_score_history): last_best_step = global_step trainer.save(model_file) logger.info("new best model saved.") best_dev_preds = dev_preds dev_score_history += [dev_score] if global_step - last_best_step >= args['max_steps_before_stop']: if not using_amsgrad and args['second_optim'] is not None: logger.info("Switching to second optimizer: {}".format(args['second_optim'])) if args['second_optim_reload']: logger.info('Reloading best model to continue from current local optimum') trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain, model_file=model_file, device=args['device'], foundation_cache=foundation_cache) last_best_step = global_step using_amsgrad = True lr = args['second_lr'] if lr is None: lr = args['lr'] trainer.optimizer = utils.get_optimizer(args['second_optim'], trainer.model, lr=lr, betas=(.9, args['beta2']), eps=1e-6, weight_decay=args['second_weight_decay']) else: logger.info("Early termination: have not improved in {} steps".format(args['max_steps_before_stop'])) do_break = True break if global_step >= args['max_steps']: do_break = True break if do_break: break 
logger.info("Training ended with {} steps.".format(global_step)) if args['wandb']: wandb.finish() if len(dev_score_history) > 0: best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1 logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval'])) else: logger.info("Dev set never evaluated. Saving final model.") trainer.save(model_file) return trainer, _ def evaluate(args): # file paths model_file = model_file_name(args) pretrain = load_pretrain(args) load_args = {'charlm_forward_file': args.get('charlm_forward_file', None), 'charlm_backward_file': args.get('charlm_backward_file', None)} # load model logger.info("Loading model from: {}".format(model_file)) trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args) result_doc = evaluate_trainer(args, trainer, pretrain) return trainer, result_doc def evaluate_trainer(args, trainer, pretrain): system_pred_file = args['output_file'] loaded_args, vocab = trainer.args, trainer.vocab # load config for k in args: if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand'] or k == 'mode': loaded_args[k] = args[k] # load data logger.info("Loading data with batch size {}...".format(args['batch_size'])) doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_data = Dataset(doc, loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True) dev_batch = dev_data.to_loader(batch_size=args['batch_size']) eval_type = get_eval_type(dev_data) if len(dev_batch) > 0: logger.info("Start evaluation...") preds = [] indices = [] with torch.no_grad(): for b in dev_batch: preds += trainer.predict(b) indices.extend(b[-1]) else: # skip eval if dev data does not exist preds = [] preds = utils.unsort(preds, indices) # write to file and score dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x]) if system_pred_file: CoNLL.write_doc2conll(dev_data.doc, system_pred_file) if args['gold_labels']: system_pred_file 
= "{:C}\n\n".format(dev_data.doc) system_pred_file = io.StringIO(system_pred_file) _, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type) logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100) return dev_data.doc if __name__ == '__main__': main() ================================================ FILE: stanza/models/tokenization/__init__.py ================================================ ================================================ FILE: stanza/models/tokenization/data.py ================================================ from bisect import bisect_right from collections import defaultdict from copy import copy import numpy as np import random import logging import re import torch from torch.utils.data import Dataset from stanza.models.common.utils import sort_with_indices, unsort from stanza.models.tokenization.vocab import Vocab logger = logging.getLogger('stanza') def filter_consecutive_whitespaces(para): filtered = [] for i, (char, label) in enumerate(para): if i > 0: if char == ' ' and para[i-1][0] == ' ': continue filtered.append((char, label)) return filtered NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n') # this was (r'^([\d]+[,\.]*)+$') # but the runtime on that can explode exponentially # for example, on 111111111111111111111111a NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$') WHITESPACE_RE = re.compile(r'\s') class TokenizationDataset: def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs): super().__init__(*args, **kwargs) # forwards all unused arguments self.args = tokenizer_args self.eval = evaluation self.dictionary = dictionary self.vocab = vocab # get input files txt_file = input_files['txt'] label_file = input_files['label'] # Load data and process it # set up text from file or input string assert txt_file is not None or input_text is not None if input_text is None: with open(txt_file, 
encoding="utf-8") as f: text = ''.join(f.readlines()).rstrip() else: text = input_text text_chunks = NEWLINE_WHITESPACE_RE.split(text) text_chunks = [pt.rstrip() for pt in text_chunks] text_chunks = [pt for pt in text_chunks if pt] if label_file is not None: with open(label_file, encoding="utf-8") as f: labels = ''.join(f.readlines()).rstrip() labels = NEWLINE_WHITESPACE_RE.split(labels) labels = [pt.rstrip() for pt in labels] labels = [map(int, pt) for pt in labels if pt] else: labels = [[0 for _ in pt] for pt in text_chunks] skip_newline = self.args.get('skip_newline', False) self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten for pt, pc in zip(text_chunks, labels)] # remove consecutive whitespaces self.data = [filter_consecutive_whitespaces(x) for x in self.data] def labels(self): """ Returns a list of the labels for all of the sentences in this DataLoader Used at eval time to compare to the results, for example """ return [np.array(list(x[1] for x in sent)) for sent in self.data] def extract_dict_feat(self, para, idx): """ This function is to extract dictionary features for each character """ length = len(para) dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])] dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])] forward_word = para[idx][0] backward_word = para[idx][0] prefix = True suffix = True for window in range(1,self.args['num_dict_feat']+1): # concatenate each character and check if words found in dict not, stop if prefix not found #check if idx+t is out of bound and if the prefix is already not found if (idx + window) <= length-1 and prefix: forward_word += para[idx+window][0].lower() #check in json file if the word is present as prefix or word or None. 
feat = 1 if forward_word in self.dictionary["words"] else 0 #if the return value is not 2 or 3 then the checking word is not a valid word in dict. dict_forward_feats[window-1] = feat #if the dict return 0 means no prefixes found, thus, stop looking for forward. if forward_word not in self.dictionary["prefixes"]: prefix = False #backward check: similar to forward if (idx - window) >= 0 and suffix: backward_word = para[idx-window][0].lower() + backward_word feat = 1 if backward_word in self.dictionary["words"] else 0 dict_backward_feats[window-1] = feat if backward_word not in self.dictionary["suffixes"]: suffix = False #if cannot find both prefix and suffix, then exit the loop if not prefix and not suffix: break return dict_forward_feats + dict_backward_feats def para_to_sentences(self, para): """ Convert a paragraph to a list of processed sentences. """ res = [] funcs = [] for feat_func in self.args['feat_funcs']: if feat_func == 'end_of_para' or feat_func == 'start_of_para': # skip for position-dependent features continue if feat_func == 'space_before': func = lambda x: 1 if x.startswith(' ') else 0 elif feat_func == 'capitalized': func = lambda x: 1 if x[0].isupper() else 0 elif feat_func == 'numeric': func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0 else: raise ValueError('Feature function "{}" is undefined.'.format(feat_func)) funcs.append(func) # stacking all featurize functions composite_func = lambda x: [f(x) for f in funcs] def process_sentence(sent_units, sent_labels, sent_feats): return (np.array([self.vocab.unit2id(y) for y in sent_units]), np.array(sent_labels), np.array(sent_feats), list(sent_units)) use_end_of_para = 'end_of_para' in self.args['feat_funcs'] use_start_of_para = 'start_of_para' in self.args['feat_funcs'] use_dictionary = self.args['use_dictionary'] current_units = [] current_labels = [] current_feats = [] for i, (unit, label) in enumerate(para): feats = composite_func(unit) # position-dependent features if use_end_of_para: 
f = 1 if i == len(para)-1 else 0 feats.append(f) if use_start_of_para: f = 1 if i == 0 else 0 feats.append(f) #if dictionary feature is selected if use_dictionary: dict_feats = self.extract_dict_feat(para, i) feats = feats + dict_feats current_units.append(unit) current_labels.append(label) current_feats.append(feats) if not self.eval and (label == 2 or label == 4): # end of sentence if len(current_units) <= self.args['max_seqlen']: # get rid of sentences that are too long during training of the tokenizer res.append(process_sentence(current_units, current_labels, current_feats)) current_units.clear() current_labels.clear() current_feats.clear() if len(current_units) > 0: if self.eval or len(current_units) <= self.args['max_seqlen']: res.append(process_sentence(current_units, current_labels, current_feats)) return res def advance_old_batch(self, eval_offsets, old_batch): """ Advance to a new position in a batch where we have partially processed the batch If we have previously built a batch of data and made predictions on them, then when we are trying to make prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch and just (essentially) advance the indices/offsets from where we read converted data in this old batch. In this case, eval_offsets index within the old_batch to advance the strings to process. 
""" unkid = self.vocab.unit2id('') padid = self.vocab.unit2id('') ounits, olabels, ofeatures, oraw = old_batch feat_size = ofeatures.shape[-1] lens = (ounits != padid).sum(1).tolist() pad_len = max(l-i for i, l in zip(eval_offsets, lens)) units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64) labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32) features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32) raw_units = [] for i in range(len(ounits)): eval_offsets[i] = min(eval_offsets[i], lens[i]) units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]] labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]] features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]] raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + [''] * (pad_len - lens[i] + eval_offsets[i])) return units, labels, features, raw_units def build_move_punct_set(data, move_back_prob): move_punct = {',', ':', '!', '.', '?', '"', '(', ')'} for chunk in data: # ignore positions at the start and end of a chunk for idx in range(1, len(chunk)-1): if chunk[idx][0] not in move_punct: continue if chunk[idx][1] == 0: if chunk[idx+1][0].isspace() and not chunk[idx-1][0].isdigit(): # this check removes punct which isn't ending a word... # honestly that's a rather unusual situation # VI has |3, 5| as a complete token # so we also eliminate isdigit() move_punct.remove(chunk[idx][0]) continue # we skip isdigit() because we will intentionally not # create things that look like decimal numbers if not chunk[idx-1][0].isspace() and chunk[idx-1][0] not in move_punct and not chunk[idx-1][0].isdigit(): # this check eliminates things like '.' after 'Mr.' 
move_punct.remove(chunk[idx][0]) continue return move_punct def build_known_mwt(data, mwt_expansions): known_mwts = set() for chunk in data: for idx, unit in enumerate(chunk): if unit[1] != 3: continue # found an MWT prev_idx = idx - 1 while prev_idx >= 0 and chunk[prev_idx][1] == 0: prev_idx -= 1 prev_idx += 1 while chunk[prev_idx][0].isspace(): prev_idx += 1 if prev_idx == idx: continue mwt = "".join(x[0] for x in chunk[prev_idx:idx+1]) if mwt not in mwt_expansions: continue if len(mwt_expansions[mwt]) > 2: # TODO: could split 3 word tokens as well continue known_mwts.add(mwt) return known_mwts class DataLoader(TokenizationDataset): """ This is the training version of the dataset. """ def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None): super().__init__(args, input_files, input_text, vocab, evaluation, dictionary) self.vocab = vocab if vocab is not None else self.init_vocab() # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels. # At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where # sentence breaks occur. We make prediction from left to right for each paragraph and move forward to # the last predicted sentence break to start afresh. 
self.sentences = [self.para_to_sentences(para) for para in self.data] self.init_sent_ids() logger.debug(f"{len(self.sentence_ids)} sentences loaded.") punct_move_back_prob = args.get('punct_move_back_prob', 0.0) if punct_move_back_prob > 0.0: self.move_punct = build_move_punct_set(self.data, punct_move_back_prob) if len(self.move_punct) > 0: logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct)) else: logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace') split_mwt_prob = args.get('split_mwt_prob', 0.0) if split_mwt_prob > 0.0 and not evaluation: self.mwt_expansions = mwt_expansions self.known_mwt = build_known_mwt(self.data, mwt_expansions) if len(self.known_mwt) > 0: logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt)) else: logger.debug('Based on the training data, there are NO MWT to split at training time') augment_final_punct_prob = 0.0 if evaluation else args.get('augment_final_punct_prob', 0.0) if augment_final_punct_prob > 0: self.augmentations = defaultdict(list) AUGMENT_PAIRS = [("?", "?"), ("?", "︖"), ("?", "﹖"), ("?", "⁇"), ("!", "!"), ("!", "︕"), ("!", "﹗"), ("!", "‼"),] for orig, target in AUGMENT_PAIRS: if self.augment_vocab(self.vocab, self.data, orig, target): logger.debug('Based on the training data, augmenting |%s| to |%s|' % (orig, target)) self.augmentations[orig].append(target) if self.augment_vocab(self.vocab, self.data, target, orig): logger.debug('Based on the training data, augmenting |%s| to |%s|' % (target, orig)) self.augmentations[target].append(orig) def __len__(self): return len(self.sentence_ids) def init_vocab(self): vocab = Vocab(self.data, self.args['lang']) return vocab @staticmethod def augment_vocab(vocab, data, existing_unit, new_unit): if existing_unit not in vocab: return False new_unit_count = 0 existing_unit_count = 0 for sentence in data: unit = 
sentence[-1][0] if unit == new_unit: new_unit_count += 1 elif unit == existing_unit: existing_unit_count += 1 if existing_unit_count == 0: return False if new_unit_count > 0: return False if new_unit not in vocab: vocab.append(new_unit) logger.debug("Found %d |%s| and %d |%s|", new_unit_count, new_unit, existing_unit_count, existing_unit) return True def init_sent_ids(self): self.sentence_ids = [] self.cumlen = [0] for i, para in enumerate(self.sentences): for j in range(len(para)): self.sentence_ids += [(i, j)] self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])] def has_mwt(self): # presumably this only needs to be called either 0 or 1 times, # 1 when training and 0 any other time, so no effort is put # into caching the result for sentence in self.data: for word in sentence: if word[1] > 2: return True return False def shuffle(self): for para in self.sentences: random.shuffle(para) self.init_sent_ids() def move_last_char(self, sentence): if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen'] and sentence[1][-1] == 2 and sentence[1][-2] != 0: new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])] new_units.extend([(' ', 0), (sentence[3][-1], int(sentence[1][-1]))]) encoded = self.para_to_sentences(new_units) return encoded return None def split_mwt(self, sentence): if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']: return None # if we find a token in the sentence which ends with label 3, # eg it is an MWT, # with some probability we split it into two tokens # and treat the split tokens as both label 1 instead of 3 # in this manner, we teach the tokenizer not to treat the # entire sequence of characters with added spaces as an MWT, # which weirdly can happen in some corner cases mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3] if len(mwt_ends) == 0: return None random_end = random.randint(0, len(mwt_ends)-1) mwt_end = mwt_ends[random_end] mwt_start = mwt_end - 1 while 
mwt_start >= 0 and sentence[1][mwt_start] == 0: mwt_start -= 1 mwt_start += 1 while sentence[3][mwt_start].isspace(): mwt_start += 1 if mwt_start == mwt_end: return None mwt = "".join(x for x in sentence[3][mwt_start:mwt_end+1]) if mwt not in self.mwt_expansions: return None all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])] w0_units = [(x, 0) for x in self.mwt_expansions[mwt][0]] w0_units[-1] = (w0_units[-1][0], 1) w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]] w1_units[-1] = (w1_units[-1][0], 1) split_units = w0_units + [(' ', 0)] + w1_units new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:] encoded = self.para_to_sentences(new_units) return encoded def move_punct_back(self, sentence): if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']: return None # check that we are not accidentally creating decimal numbers # idx == 1 or not sentence[3][idx-2].isdigit() # one disadvantage of checking for sentence[1][idx] == 0 # would be that tokens of all punct, such as '...', # should move but would not move if this is eliminated commas = [idx for idx, c in enumerate(sentence[3]) if c in self.move_punct and idx > 0 and sentence[3][idx-1].isspace() and (idx == 1 or not sentence[3][idx-2].isdigit())] if len(commas) == 0: return None all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])] new_units = [] span_start = 0 for span_end in commas: new_units.extend(all_units[span_start:span_end-1]) span_start = span_end if span_end < len(sentence[3]): new_units.extend(all_units[span_end:]) encoded = self.para_to_sentences(new_units) return encoded def augment_final_punct(self, sentence): if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen']: if sentence[3][-1] in self.augmentations: augmented = random.choice(self.augmentations[sentence[3][-1]]) new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])] new_units.append((augmented, sentence[1][-1])) else: return None 
encoded = self.para_to_sentences(new_units) return encoded return None def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0): ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. ''' feat_size = len(self.sentences[0][0][2][0]) unkid = self.vocab.unit2id('') padid = self.vocab.unit2id('') def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']): # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences # from the entire dataset until we reach max_seqlen. drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0)) drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0)) move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0) move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0) split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0) augment_final_punct_prob = 0.0 if self.eval else self.args.get('augment_final_punct_prob', 0.0) pid, sid = id_pair if self.eval else random.choice(self.sentence_ids) sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])] total_len = len(sentences[0][0]) assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! 
{}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid])])) if self.eval: for sid1 in range(sid+1, len(self.sentences[pid])): total_len += len(self.sentences[pid][sid1][0]) sentences.append(self.sentences[pid][sid1]) if total_len >= self.args['max_seqlen']: break else: while True: pid1, sid1 = random.choice(self.sentence_ids) total_len += len(self.sentences[pid1][sid1][0]) sentences.append(self.sentences[pid1][sid1]) if total_len >= self.args['max_seqlen']: break if move_last_char_prob > 0.0: for sentence_idx, sentence in enumerate(sentences): if random.random() < move_last_char_prob: # the sentence might not be eligible, such as # already having a space or not having a sentence final punct, # so we need to do a two step checking process here new_sentence = self.move_last_char(sentence) if new_sentence is not None: sentences[sentence_idx] = new_sentence[0] total_len += 1 if move_punct_back_prob > 0.0: for sentence_idx, sentence in enumerate(sentences): if random.random() < move_punct_back_prob: # the sentence might not be eligible, such as # not having a space separated punct, # so we need to do a two step checking process here new_sentence = self.move_punct_back(sentence) if new_sentence is not None: total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3]) sentences[sentence_idx] = new_sentence[0] if split_mwt_prob > 0.0: for sentence_idx, sentence in enumerate(sentences): if random.random() < split_mwt_prob: new_sentence = self.split_mwt(sentence) if new_sentence is not None: total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3]) sentences[sentence_idx] = new_sentence[0] if augment_final_punct_prob > 0.0: for sentence_idx, sentence in enumerate(sentences): if random.random() < split_mwt_prob: new_sentence = self.augment_final_punct(sentence) if new_sentence is not None: sentences[sentence_idx] = new_sentence[0] if drop_sents and len(sentences) > 1: if 
total_len > self.args['max_seqlen']: sentences = sentences[:-1] if len(sentences) > 1: p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0] sentences = sentences[:cutoff+1] units = np.concatenate([s[0] for s in sentences]) labels = np.concatenate([s[1] for s in sentences]) feats = np.concatenate([s[2] for s in sentences]) raw_units = [x for s in sentences for x in s[3]] if not self.eval: cutoff = self.args['max_seqlen'] units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff] if drop_last_char: # can only happen in non-eval mode if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3): # training text ended with a sentence end position # and that word was a single character # and the previous character ended the word units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1] # word end -> sentence end, mwt end -> sentence mwt end labels[-1] = labels[-1] + 1 return units, labels, feats, raw_units if eval_offsets is not None: # find max padding length pad_len = 0 for eval_offset in eval_offsets: if eval_offset < self.cumlen[-1]: pair_id = bisect_right(self.cumlen, eval_offset) - 1 pair = self.sentence_ids[pair_id] pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0])) pad_len += 1 id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets] pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs] offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)] offsets_pairs = list(zip(offsets, pairs)) else: id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size'])) offsets_pairs = [(0, x) for x in id_pairs] pad_len = self.args['max_seqlen'] # put everything into padded and nicely shaped NumPy arrays and 
eventually convert to PyTorch tensors units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64) labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64) features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32) raw_units = [] for i, (offset, pair) in enumerate(offsets_pairs): u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len) units[i, :len(u_)] = u_ labels[i, :len(l_)] = l_ features[i, :len(f_), :] = f_ raw_units.append(r_ + [''] * (pad_len - len(r_))) if unit_dropout > 0 and not self.eval: # dropout characters/units at training time and replace them with UNKs mask = np.random.random_sample(units.shape) < unit_dropout mask[units == padid] = 0 units[mask] = unkid for i in range(len(raw_units)): for j in range(len(raw_units[i])): if mask[i, j]: raw_units[i][j] = '' # dropout unit feature vector in addition to only torch.dropout in the model. # experiments showed that only torch.dropout hurts the model # we believe it is because the dict feature vector is mostly scarse so it makes # more sense to drop out the whole vector instead of only single element. if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval: mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout mask_feat[units == padid] = 0 for i in range(len(raw_units)): for j in range(len(raw_units[i])): if mask_feat[i,j]: features[i,j,:] = 0 units = torch.from_numpy(units) labels = torch.from_numpy(labels) features = torch.from_numpy(features) return units, labels, features, raw_units class SortedDataset(Dataset): """ Holds a TokenizationDataset for use in a torch DataLoader The torch DataLoader is different from the DataLoader defined here and allows for cpu & gpu parallelism. Updating output_predictions to use this class as a wrapper to a TokenizationDataset means the calculation of features can happen in parallel, saving quite a bit of time. 
""" def __init__(self, dataset): super().__init__() self.dataset = dataset self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True) def __len__(self): return len(self.data) def __getitem__(self, index): # This will return a single sample # np: index in character map # np: tokenization label # np: features # list: original text as one length strings return self.dataset.para_to_sentences(self.data[index]) def unsort(self, arr): return unsort(arr, self.indices) def collate(self, samples): if any(len(x) > 1 for x in samples): raise ValueError("Expected all paragraphs to have no preset sentence splits!") feat_size = samples[0][0][2].shape[-1] padid = self.dataset.vocab.unit2id('') # +1 so that all samples end with at least one pad pad_len = max(len(x[0][3]) for x in samples) + 1 units = torch.full((len(samples), pad_len), padid, dtype=torch.int64) labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32) features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32) raw_units = [] for i, sample in enumerate(samples): u_, l_, f_, r_ = sample[0] units[i, :len(u_)] = torch.from_numpy(u_) labels[i, :len(l_)] = torch.from_numpy(l_) features[i, :len(f_), :] = torch.from_numpy(f_) raw_units.append(r_ + ['']) return units, labels, features, raw_units ================================================ FILE: stanza/models/tokenization/model.py ================================================ import torch import torch.nn.functional as F import torch.nn as nn from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence from stanza.models.common.char_model import CharacterLanguageModelWordAdapter from stanza.models.common.foundation_cache import load_charlm class Tokenizer(nn.Module): def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None): super().__init__() self.unsaved_modules = [] self.args = args feat_dim = args['feat_dim'] self.embeddings = 
nn.Embedding(nchars, emb_dim, padding_idx=0) self.input_dim = emb_dim + feat_dim charmodel = None if args is not None and args.get('charlm_forward_file', None): charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache) charmodels = nn.ModuleList([charmodel_forward]) charmodel = CharacterLanguageModelWordAdapter(charmodels) self.input_dim += charmodel.hidden_dim() self.add_unsaved_module("charmodel", charmodel) self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0) if self.args['conv_res'] is not None: self.conv_res = nn.ModuleList() self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')] for si, size in enumerate(self.conv_sizes): l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0)) self.conv_res.append(l) if self.args.get('hier_conv_res', False): self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1) self.tok_clf = nn.Linear(hidden_dim * 2, 1) self.sent_clf = nn.Linear(hidden_dim * 2, 1) if self.args['use_mwt']: self.mwt_clf = nn.Linear(hidden_dim * 2, 1) if args['hierarchical']: in_dim = hidden_dim * 2 self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True) self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) if self.args['use_mwt']: self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) self.dropout = nn.Dropout(dropout) self.dropout_feat = nn.Dropout(feat_dropout) self.toknoise = nn.Dropout(self.args['tok_noise']) def add_unsaved_module(self, name, module): self.unsaved_modules += [name] setattr(self, name, module) def forward(self, x, feats, lengths, raw=None): emb = self.embeddings(x) if self.charmodel is not None and raw is not None: char_emb = self.charmodel(raw, wrap=False) emb = 
torch.cat([emb, char_emb], axis=2) emb = self.dropout(emb) feats = self.dropout_feat(feats) emb = torch.cat([emb, feats], 2) emb = pack_padded_sequence(emb, lengths, batch_first=True) inp, _ = self.rnn(emb) inp, _ = pad_packed_sequence(inp, batch_first=True) if self.args['conv_res'] is not None: conv_input = emb.transpose(1, 2).contiguous() if not self.args.get('hier_conv_res', False): for l in self.conv_res: inp = inp + l(conv_input).transpose(1, 2).contiguous() else: hid = [] for l in self.conv_res: hid += [l(conv_input)] hid = torch.cat(hid, 1) hid = F.relu(hid) hid = self.dropout(hid) inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous() inp = self.dropout(inp) tok0 = self.tok_clf(inp) sent0 = self.sent_clf(inp) if self.args['use_mwt']: mwt0 = self.mwt_clf(inp) if self.args['hierarchical']: inp2 = inp if self.args['hier_invtemp'] > 0: inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp']))) inp2 = pack_padded_sequence(inp2, lengths, batch_first=True) inp2, _ = self.rnn2(inp2) inp2, _ = pad_packed_sequence(inp2, batch_first=True) inp2 = self.dropout(inp2) tok0 = tok0 + self.tok_clf2(inp2) sent0 = sent0 + self.sent_clf2(inp2) if self.args['use_mwt']: mwt0 = mwt0 + self.mwt_clf2(inp2) nontok = F.logsigmoid(-tok0) tok = F.logsigmoid(tok0) nonsent = F.logsigmoid(-sent0) sent = F.logsigmoid(sent0) if self.args['use_mwt']: nonmwt = F.logsigmoid(-mwt0) mwt = F.logsigmoid(mwt0) pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2) else: pred = torch.cat([nontok, tok+nonsent, tok+sent], 2) return pred ================================================ FILE: stanza/models/tokenization/tokenize_files.py ================================================ """Use a Stanza tokenizer to turn a text file into one tokenized paragraph per line For example, the output of this script is suitable for Glove Currently this *only* supports tokenization, no MWT splitting. 
It also would be beneficial to have an option to convert spaces into NBSP, underscore, or some other marker to make it easier to process languages such as VI which have spaces in them """ import argparse import io import os import time import re import zipfile import torch import stanza from stanza.models.common.utils import open_read_text, default_device from stanza.models.tokenization.data import TokenizationDataset from stanza.models.tokenization.utils import output_predictions from stanza.pipeline.tokenize_processor import TokenizeProcessor from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() NEWLINE_SPLIT_RE = re.compile(r"\n\s*\n") def tokenize_to_file(tokenizer, fin, fout, chunk_size=500): raw_text = fin.read() documents = NEWLINE_SPLIT_RE.split(raw_text) for chunk_start in tqdm(range(0, len(documents), chunk_size), leave=False): chunk_end = min(chunk_start + chunk_size, len(documents)) chunk = documents[chunk_start:chunk_end] in_docs = [stanza.Document([], text=d) for d in chunk] out_docs = tokenizer.bulk_process(in_docs) for document in out_docs: for sent_idx, sentence in enumerate(document.sentences): if sent_idx > 0: fout.write(" ") fout.write(" ".join(x.text for x in sentence.tokens)) fout.write("\n") def main(args=None): parser = argparse.ArgumentParser() parser.add_argument("--lang", type=str, default="sd", help="Which language to use for tokenization") parser.add_argument("--tokenize_model_path", type=str, default=None, help="Specific tokenizer model to use") parser.add_argument("input_files", type=str, nargs="+", help="Which input files to tokenize") parser.add_argument("--output_file", type=str, default="glove.txt", help="Where to write the tokenized output") parser.add_argument("--model_dir", type=str, default=None, help="Where to get models for a Pipeline (None => default models dir)") parser.add_argument("--chunk_size", type=int, default=500, help="How many 'documents' to use in a chunk when tokenizing. 
This is separate from the tokenizer batching - this limits how much memory gets used at once, since we don't need to store an entire file in memory at once") args = parser.parse_args(args=args) if os.path.exists(args.output_file): print("Cowardly refusing to overwrite existing output file %s" % args.output_file) return if args.tokenize_model_path: config = { "model_path": args.tokenize_model_path, "check_requirements": False } tokenizer = TokenizeProcessor(config, pipeline=None, device=default_device()) else: pipe = stanza.Pipeline(lang=args.lang, processors="tokenize", model_dir=args.model_dir) tokenizer = pipe.processors["tokenize"] with open(args.output_file, "w", encoding="utf-8") as fout: for filename in tqdm(args.input_files): if filename.endswith(".zip"): with zipfile.ZipFile(filename) as zin: input_names = zin.namelist() for input_name in tqdm(input_names, leave=False): with zin.open(input_names[0]) as fin: fin = io.TextIOWrapper(fin, encoding='utf-8') tokenize_to_file(tokenizer, fin, fout) else: with open_read_text(filename, encoding="utf-8") as fin: tokenize_to_file(tokenizer, fin, fout) if __name__ == '__main__': main() ================================================ FILE: stanza/models/tokenization/trainer.py ================================================ import sys import logging import torch import torch.nn as nn import torch.optim as optim from stanza.models.common import utils from stanza.models.common.trainer import Trainer as BaseTrainer from stanza.models.tokenization.utils import create_dictionary from .model import Tokenizer from .vocab import Vocab logger = logging.getLogger('stanza') class Trainer(BaseTrainer): def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None): # TODO: make a test of the training w/ and w/o charlm if model_file is not None: # load everything from file self.load(model_file, args, foundation_cache) else: # build model from scratch self.args = args 
self.vocab = vocab self.lexicon = list(lexicon) if lexicon is not None else None self.dictionary = dictionary self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) self.model = self.model.to(device) self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device) self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay']) self.feat_funcs = self.args.get('feat_funcs', None) self.lang = self.args['lang'] # language determines how token normalization is done def update(self, inputs): self.model.train() units, labels, features, text = inputs lengths = [len(x) for x in text] device = next(self.model.parameters()).device units = units.to(device) labels = labels.to(device) features = features.to(device) pred = self.model(units, features, lengths, text) self.optimizer.zero_grad() classes = pred.size(2) loss = self.criterion(pred.view(-1, classes), labels.view(-1)) loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() return loss.item() def predict(self, inputs): self.model.eval() units, _, features, text = inputs lengths = [len(x) for x in text] device = next(self.model.parameters()).device units = units.to(device) features = features.to(device) pred = self.model(units, features, lengths, text) return pred.data.cpu().numpy() def save(self, filename, skip_modules=True): model_state = None if self.model is not None: model_state = self.model.state_dict() # skip saving modules like the pretrained charlm if skip_modules: skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] for k in skipped: del model_state[k] params = { 'model': model_state, 'vocab': self.vocab.state_dict(), # save and load lexicon as list instead of set so # we can use weights_only=True 'lexicon': list(self.lexicon) if 
self.lexicon is not None else None, 'config': self.args } try: torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) except BaseException: logger.warning("Saving failed... continuing anyway.") def load(self, filename, args, foundation_cache): try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] if args is not None and args.get('charlm_forward_file', None) is not None: if checkpoint['config'].get('charlm_forward_file') is None: # if the saved model didn't use a charlm, we skip the charlm here # otherwise the loaded model weights won't fit in the newly created model self.args['charlm_forward_file'] = None else: self.args['charlm_forward_file'] = args['charlm_forward_file'] if self.args.get('use_mwt', None) is None: # Default to True as many currently saved models # were built with mwt layers self.args['use_mwt'] = True self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) self.model.load_state_dict(checkpoint['model'], strict=False) self.vocab = Vocab.load_state_dict(checkpoint['vocab']) self.lexicon = checkpoint['lexicon'] if self.lexicon is not None: self.lexicon = set(self.lexicon) self.dictionary = create_dictionary(self.lexicon) else: self.dictionary = None ================================================ FILE: stanza/models/tokenization/utils.py ================================================ from collections import Counter from copy import copy import json import numpy as np import re import logging import os from torch.utils.data import DataLoader as TorchDataLoader import stanza.utils.default_paths as default_paths from stanza.models.common.utils import ud_scores, harmonic_mean from stanza.models.common.doc import Document 
from stanza.utils.conll import CoNLL from stanza.models.common.doc import * from stanza.models.tokenization.data import SortedDataset logger = logging.getLogger('stanza') paths = default_paths.get_default_paths() def create_dictionary(lexicon): """ This function is to create a new dictionary used for improving tokenization model for multi-syllable words languages such as vi, zh or th. This function takes the lexicon as input and output a dictionary that contains three set: words, prefixes and suffixes where prefixes set should contains all the prefixes in the lexicon and similar for suffixes. The point of having prefixes/suffixes sets in the dictionary is just to make it easier to check during data preparation. :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp :param lexicon - set of words used to create dictionary :return a dictionary object that contains words and their prefixes and suffixes. """ dictionary = {"words":set(), "prefixes":set(), "suffixes":set()} def add_word(word): if word not in dictionary["words"]: dictionary["words"].add(word) prefix = "" suffix = "" for i in range(0,len(word)-1): prefix = prefix + word[i] suffix = word[len(word) - i - 1] + suffix dictionary["prefixes"].add(prefix) dictionary["suffixes"].add(suffix) for word in lexicon: if len(word)>1: add_word(word) return dictionary def create_lexicon(shorthand=None, train_path=None, external_path=None): """ This function is to create a lexicon to store all the words from the training set and external dictionary. This lexicon will be saved with the model and will be used to create dictionary when the model is loaded. The idea of separating lexicon and dictionary in two different phases is a good tradeoff between time and space. Note that we eliminate all the long words but less frequently appeared in the lexicon by only taking 95-percentile list of words. 
:param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp :param train_path - path to conllu train file :param external_path - path to extenral dict, expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt :return a set lexicon object that contains all distinct words """ lexicon = set() length_freq = [] #this regex is to check if a character is an actual Thai character as seems .isalpha() python method doesn't pick up Thai accent characters.. pattern_thai = re.compile(r"(?:[^\d\W]+)|\s") def check_valid_word(shorthand, word): """ This function is to check if the word are multi-syllable words and not numbers. For vi, whitespaces are syllabe-separator. """ if shorthand.startswith("vi_"): return True if len(word.split(" ")) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False elif shorthand.startswith("th_"): return True if len(word) > 1 and any(map(pattern_thai.match, word)) and not any(map(str.isdigit, word)) else False else: return True if len(word) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False #checking for words in the training set to add them to lexicon. if train_path is not None: if not os.path.isfile(train_path): raise FileNotFoundError(f"Cannot open train set at {train_path}") train_doc = CoNLL.conll2doc(input_file=train_path) for train_sent in train_doc.sentences: train_words = [x.text for x in train_sent.tokens if x.is_mwt()] + [x.text for x in train_sent.words] for word in train_words: word = word.lower() if check_valid_word(shorthand, word) and word not in lexicon: lexicon.add(word) length_freq.append(len(word)) count_word = len(lexicon) logger.info(f"Added {count_word} words from the training data to the lexicon.") #checking for external dictionary and add them to lexicon. 
if external_path is not None: if not os.path.isfile(external_path): raise FileNotFoundError(f"Cannot open external dictionary at {external_path}") with open(external_path, "r", encoding="utf-8") as external_file: lines = external_file.readlines() for line in lines: word = line.lower() word = word.replace("\n","") if check_valid_word(shorthand, word) and word not in lexicon: lexicon.add(word) length_freq.append(len(word)) logger.info(f"Added another {len(lexicon) - count_word} words from the external dict to dictionary.") #automatically calculate the number of dictionary features (window size to look for words) based on the frequency of word length #take the length at 95-percentile to eliminate all the longest (maybe) compounds words in the lexicon num_dict_feat = int(np.percentile(length_freq, 95)) lexicon = {word for word in lexicon if len(word) <= num_dict_feat } logger.info(f"Final lexicon consists of {len(lexicon)} words after getting rid of long words.") return lexicon, num_dict_feat def load_lexicon(args): """ This function is to create a new dictionary and load it to training. The external dictionary is expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt For example, vi_vlsp-externaldict.txt """ shorthand = args["shorthand"] tokenize_dir = paths["TOKENIZE_DATA_DIR"] train_path = f"{tokenize_dir}/{shorthand}.train.gold.conllu" external_dict_path = f"{tokenize_dir}/{shorthand}-externaldict.txt" if not os.path.exists(external_dict_path): logger.info(f"External dictionary not found! 
Looked in {external_dict_path} Checking training data...") external_dict_path = None if not os.path.exists(train_path): logger.info(f"Training dataset does not exist, thus cannot create dictionary {shorthand}") train_path = None if train_path is None and external_dict_path is None: raise FileNotFoundError(f"Cannot find training set / external dictionary at {train_path} and {external_dict_path}") return create_lexicon(shorthand, train_path, external_dict_path) def load_mwt_dict(filename): """ Returns a dict from an MWT to its most common expansion and count. Other less common expansions are discarded. """ if filename is None: return None with open(filename, 'r') as f: mwt_dict0 = json.load(f) mwt_dict = dict() for item in mwt_dict0: (key, expansion), count = item if key not in mwt_dict or mwt_dict[key][1] < count: mwt_dict[key] = (expansion, count) return mwt_dict def process_sentence(sentence, mwt_dict=None): sent = [] i = 0 for tok, p, position_info in sentence: expansion = None if (p == 3 or p == 4) and mwt_dict is not None: # MWT found, (attempt to) expand it! 
if tok in mwt_dict: expansion = mwt_dict[tok][0] elif tok.lower() in mwt_dict: expansion = mwt_dict[tok.lower()][0] if expansion is not None: sent.append({ID: (i+1, i+len(expansion)), TEXT: tok}) if position_info is not None: sent[-1][START_CHAR] = position_info[0] sent[-1][END_CHAR] = position_info[1] for etok in expansion: sent.append({ID: (i+1, ), TEXT: etok}) i += 1 else: if len(tok) <= 0: continue sent.append({ID: (i+1, ), TEXT: tok}) if position_info is not None: sent[-1][START_CHAR] = position_info[0] sent[-1][END_CHAR] = position_info[1] if p == 3 or p == 4:# MARK sent[-1][MISC] = 'MWT=Yes' i += 1 return sent # https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression EMAIL_RAW_RE = r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(?:2(?:5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(?:2(?:5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" # https://stackoverflow.com/questions/3809401/what-is-a-good-regular-expression-to-match-a-url # modification: disallow " as opposed to all ^\s URL_RAW_RE = r"""(?:https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s"]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s"]{2,}|https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s"]{2,}|www\.[a-zA-Z0-9]+\.[^\s"]{2,})|[a-zA-Z0-9]+\.(?:gov|org|edu|net|com|co)(?:\.[^\s"]{2,})""" MASK_RE = re.compile(f"(?:{EMAIL_RAW_RE}|{URL_RAW_RE})") def find_spans(raw): """ Return spans of text which don't contain and are split by """ pads = [idx for idx, char in enumerate(raw) if char == ''] if len(pads) == 0: spans = [(0, len(raw))] else: prev = 0 spans = [] for pad in pads: if pad != prev: spans.append( (prev, pad) ) prev = pad + 1 if prev 
< len(raw): spans.append( (prev, len(raw)) ) return spans def update_pred_regex(raw, pred): """ Update the results of a tokenization batch by checking the raw text against a couple regular expressions Currently, emails and urls are handled TODO: this might work better as a constraint on the inference for efficiency pred is modified in place """ spans = find_spans(raw) for span_begin, span_end in spans: text = "".join(raw[span_begin:span_end]) for match in MASK_RE.finditer(text): match_begin, match_end = match.span() # first, update all characters touched by the regex to not split # with the exception of the last character... for char in range(match_begin+span_begin, match_end+span_begin-1): pred[char] = 0 # if the last character is not currently a split, make it a word split if pred[match_end+span_begin-1] == 0: pred[match_end+span_begin-1] = 1 return pred SPACE_RE = re.compile(r'\s') SPACE_SPLIT_RE = re.compile(r'( *[^ ]+)') def predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, num_workers): """ The guts of the prediction method Calls trainer.predict() over and over until we have predictions for all of the text """ all_preds = [] all_raw = [] sorted_data = SortedDataset(data_generator) dataloader = TorchDataLoader(sorted_data, batch_size=batch_size, collate_fn=sorted_data.collate, num_workers=num_workers) for batch_idx, batch in enumerate(dataloader): num_sentences = len(batch[3]) # being sorted by descending length, we need to use 0 as the longest sentence N = len(batch[3][0]) for paragraph in batch[3]: all_raw.append(list(paragraph)) if N <= max_seqlen: pred = np.argmax(trainer.predict(batch), axis=2) else: # TODO: we could shortcircuit some processing of # long strings of PAD by tracking which rows are finished idx = [0] * num_sentences adv = [0] * num_sentences para_lengths = [x.index('') for x in batch[3]] pred = [[] for _ in range(num_sentences)] while True: ens = [min(N - idx1, max_seqlen) for idx1, N in zip(idx, para_lengths)] en = 
max(ens) batch1 = batch[0][:, :en], batch[1][:, :en], batch[2][:, :en], [x[:en] for x in batch[3]] pred1 = np.argmax(trainer.predict(batch1), axis=2) for j in range(num_sentences): sentbreaks = np.where((pred1[j] == 2) + (pred1[j] == 4))[0] if len(sentbreaks) <= 0 or idx[j] >= para_lengths[j] - max_seqlen: advance = ens[j] else: advance = np.max(sentbreaks) + 1 pred[j] += [pred1[j, :advance]] idx[j] += advance adv[j] = advance if all([idx1 >= N for idx1, N in zip(idx, para_lengths)]): break # once we've made predictions on a certain number of characters for each paragraph (recorded in `adv`), # we skip the first `adv` characters to make the updated batch batch = data_generator.advance_old_batch(adv, batch) pred = [np.concatenate(p, 0) for p in pred] for par_idx in range(num_sentences): offset = batch_idx * batch_size + par_idx raw = all_raw[offset] par_len = raw.index('') raw = raw[:par_len] all_raw[offset] = raw if pred[par_idx][par_len-1] < 2: pred[par_idx][par_len-1] = 2 elif pred[par_idx][par_len-1] > 2: pred[par_idx][par_len-1] = 4 if use_regex_tokens: all_preds.append(update_pred_regex(raw, pred[par_idx][:par_len])) else: all_preds.append(pred[par_idx][:par_len]) all_preds = sorted_data.unsort(all_preds) all_raw = sorted_data.unsort(all_raw) return all_preds, all_raw def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, max_seqlen=1000, orig_text=None, no_ssplit=False, use_regex_tokens=True, num_workers=0, postprocessor=None): batch_size = trainer.args['batch_size'] max_seqlen = max(1000, max_seqlen) all_preds, all_raw = predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, num_workers) use_la_ittb_shorthand = trainer.args['shorthand'] == 'la_ittb' skip_newline = trainer.args['skip_newline'] oov_count, offset, doc = decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand) # If we are provided a postprocessor, we prepare a list of pre-tokenized words and mwt 
flags and # call the postprocessor for analysis. if postprocessor: doc = postprocess_doc(doc, postprocessor, orig_text) if output_file: CoNLL.dict2conll(doc, output_file) return oov_count, offset, all_preds, doc def postprocess_doc(doc, postprocessor, orig_text=None): """Applies a postprocessor on the doc""" # get a list of all the words in the "draft" document to pass to the postprocessor # the words array looks like [["words, "words", "words"], ["words, ("i_am_a_mwt", True), "I_am_not"]] # and the postprocessor is expected to return in the same format words = [[((word["text"], True) if word.get("misc") == "MWT=Yes" else word["text"]) for word in sentence] for sentence in doc] if not orig_text: raw_text = "".join("".join(i) for i in all_raw) # template to compare the stitched text against else: raw_text = orig_text # perform correction with the postprocessor postprocessor_return = postprocessor(words) # collect the words and MWTs separately corrected_words = [] corrected_mwts = [] corrected_expansions = [] # for each word, if its just a string (without the ("word", mwt_bool) format) # we default that the word is not a MWT. for sent in postprocessor_return: sent_words = [] sent_mwts = [] sent_expansions = [] for word in sent: if isinstance(word, str): sent_words.append(word) sent_mwts.append(False) sent_expansions.append(None) else: if isinstance(word[1], bool): sent_words.append(word[0]) sent_mwts.append(word[1]) sent_expansions.append(None) else: sent_words.append(word[0]) sent_mwts.append(True) # expansions are marked in a space-separated list, which # `stanza.common.doc.set_mwt_expansions` reads and splits again # by splitting by spaces. Therefore, to serialize the users' supplied MWT # information, we join them by spaces to be split later by # `set_mwt_expansions`. 
sent_expansions.append(" ".join(word[1])) corrected_words.append(sent_words) corrected_mwts.append(sent_mwts) corrected_expansions.append(sent_expansions) # check postprocessor output token_lens = [len(i) for i in corrected_words] mwt_lens = [len(i) for i in corrected_mwts] assert token_lens == mwt_lens, "Postprocessor returned token and MWT lists of different length! Token list lengths %s, MWT list lengths %s" % (token_lens, mwt_lens) # reassemble document. offsets and oov shouldn't change doc = reassemble_doc_from_tokens(corrected_words, corrected_mwts, corrected_expansions, raw_text) return doc def reassemble_doc_from_tokens(tokens, mwts, expansions, raw_text): """Assemble a Stanza document list format from a list of string tokens, calculating offsets as needed. Parameters ---------- tokens : List[List[str]] A list of sentences, which includes string tokens. mwts : List[List[bool]] Whether or not each of the tokens are MWTs to be analyzed by the MWT system. expansions : List[List[Optional[List[str]]]] A list of possible expansions for MWTs, or None if no user-defined expansion is given. parser_text : str The raw text off of which we can compare offsets. Returns ------- List[List[Dict]] List of words and their offsets, used as `doc`. """ # oov count and offset stays the same; doc gets regenerated new_offset = 0 corrected_doc = [] for sent_words, sent_mwts, sent_expansions in zip(tokens, mwts, expansions): sentence_doc = [] for indx, (word, mwt, expansion) in enumerate(zip(sent_words, sent_mwts, sent_expansions)): try: offset_index = raw_text.index(word, new_offset) except ValueError as e: sub_start = max(0, new_offset - 20) sub_end = min(len(raw_text), new_offset + 20) sub = raw_text[sub_start:sub_end] raise ValueError("Could not find word |%s| starting from char_offset %d. Surrounding text: |%s|. \n Hint: did you accidentally add/subtract a symbol/character such as a space when combining tokens?" 
% (word, new_offset, sub)) from e wd = { "id": (indx+1,), "text": word, "start_char": offset_index, "end_char": offset_index+len(word) } if expansion: wd["manual_expansion"] = True elif mwt: wd["misc"] = "MWT=Yes" sentence_doc.append(wd) # start the next search after the previous word ended new_offset = offset_index+len(word) corrected_doc.append(sentence_doc) # use the built in MWT system to expand MWTs doc = Document(corrected_doc, raw_text) doc.set_mwt_expansions([j for i in expansions for j in i if j], process_manual_expanded=True) return doc.to_dict() def decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand): """ Decode the predictions into a document of words Once everything is fed through the tokenizer model, it's time to decode the predictions into actual tokens and sentences that the rest of the pipeline uses """ offset = 0 oov_count = 0 doc = [] text = SPACE_RE.sub(' ', orig_text) if orig_text is not None else None char_offset = 0 if vocab is not None: UNK_ID = vocab.unit2id('') for raw, pred in zip(all_raw, all_preds): current_tok = '' current_sent = [] for t, p in zip(raw, pred): if t == '': break # hack la_ittb if use_la_ittb_shorthand and t in (":", ";"): p = 2 offset += 1 if vocab is not None and vocab.unit2id(t) == UNK_ID: oov_count += 1 current_tok += t if p >= 1: if vocab is not None: tok = vocab.normalize_token(current_tok) else: tok = current_tok assert '\t' not in tok, tok if len(tok) <= 0: current_tok = '' continue if orig_text is not None: st = -1 tok_len = 0 for part in SPACE_SPLIT_RE.split(current_tok): if len(part) == 0: continue if skip_newline: part_pattern = re.compile(r'\s*'.join(re.escape(c) for c in part)) match = part_pattern.search(text, char_offset) st0 = match.start(0) - char_offset partlen = match.end(0) - match.start(0) lstripped = match.group(0).lstrip() else: try: st0 = text.index(part, char_offset) - char_offset except ValueError as e: sub_start = max(0, 
char_offset - 20) sub_end = min(len(text), char_offset + 20) sub = text[sub_start:sub_end] raise ValueError("Could not find |%s| starting from char_offset %d. Surrounding text: |%s|" % (part, char_offset, sub)) from e partlen = len(part) lstripped = part.lstrip() if st < 0: st = char_offset + st0 + (partlen - len(lstripped)) char_offset += st0 + partlen position_info = (st, char_offset) else: position_info = None current_sent.append((tok, p, position_info)) current_tok = '' if (p == 2 or p == 4) and not no_ssplit: doc.append(process_sentence(current_sent, mwt_dict)) current_sent = [] if len(current_tok) > 0: raise ValueError("Finished processing tokens, but there is still text left!") if len(current_sent): doc.append(process_sentence(current_sent, mwt_dict)) return oov_count, offset, doc def match_tokens_with_text(sentences, orig_text): """ Turns pretokenized text and the original text into a Doc object sentences: list of list of string orig_text: string, where the text must be exactly the sentences concatenated with 0 or more whitespace characters if orig_text deviates in any way, a ValueError will be thrown """ text = "".join(["".join(x) for x in sentences]) all_raw = list(text) all_preds = [0] * len(all_raw) offset = 0 for sentence in sentences: for word in sentence: offset += len(word) all_preds[offset-1] = 1 all_preds[offset-1] = 2 _, _, doc = decode_predictions(None, None, orig_text, [all_raw], [all_preds], False, False, False) doc = Document(doc, orig_text) # check that all the orig_text was used up by the tokens offset = doc.sentences[-1].tokens[-1].end_char remainder = orig_text[offset:].strip() if len(remainder) > 0: raise ValueError("Finished processing tokens, but there is still text left!") return doc def eval_model(args, trainer, batches, vocab, mwt_dict): oov_count, N, all_preds, doc = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen']) all_preds = np.concatenate(all_preds, 0) labels = 
np.concatenate(batches.labels()) counter = Counter(zip(all_preds, labels)) def f1(pred, gold, mapping): pred = [mapping[p] for p in pred] gold = [mapping[g] for g in gold] lastp = -1; lastg = -1 tp = 0; fp = 0; fn = 0 for i, (p, g) in enumerate(zip(pred, gold)): if p == g > 0 and lastp == lastg: lastp = i lastg = i tp += 1 elif p > 0 and g > 0: lastp = i lastg = i fp += 1 fn += 1 elif p > 0: # and g == 0 lastp = i fp += 1 elif g > 0: lastg = i fn += 1 if tp == 0: return 0 else: return 2 * tp / (2 * tp + fp + fn) f1tok = f1(all_preds, labels, {0:0, 1:1, 2:1, 3:1, 4:1}) f1sent = f1(all_preds, labels, {0:0, 1:0, 2:1, 3:0, 4:1}) f1mwt = f1(all_preds, labels, {0:0, 1:1, 2:1, 3:2, 4:2}) logger.info(f"{args['shorthand']}: token F1 = {f1tok*100:.2f}, sentence F1 = {f1sent*100:.2f}, mwt F1 = {f1mwt*100:.2f}") return harmonic_mean([f1tok, f1sent, f1mwt], [1, 1, .01]) ================================================ FILE: stanza/models/tokenization/vocab.py ================================================ from collections import Counter import re from stanza.models.common.vocab import BaseVocab from stanza.models.common.vocab import UNK, PAD SPACE_RE = re.compile(r'\s') class Vocab(BaseVocab): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lang_replaces_spaces = any([self.lang.startswith(x) for x in ['zh', 'ja', 'ko']]) def build_vocab(self): paras = self.data counter = Counter() for para in paras: for unit in para: normalized = self.normalize_unit(unit[0]) counter[normalized] += 1 self._id2unit = [PAD, UNK] + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True)) self._unit2id = {w:i for i, w in enumerate(self._id2unit)} def append(self, unit): self._id2unit.append(unit) idx = len(self._id2unit) - 1 self._unit2id[unit] = idx def normalize_unit(self, unit): # Normalize minimal units used by the tokenizer return unit def normalize_token(self, token): token = SPACE_RE.sub(' ', token.lstrip()) if self.lang_replaces_spaces: token = 
token.replace(' ', '') return token ================================================ FILE: stanza/models/tokenizer.py ================================================ """ Entry point for training and evaluating a neural tokenizer. This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of recurrent and convolutional architectures. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. Updated: This new version of tokenizer model incorporates the dictionary feature, especially useful for languages that have multi-syllable words such as Vietnamese, Chinese or Thai. In summary, a lexicon contains all unique words found in training dataset and external lexicon (if any) is created during training and saved alongside the model after training. Using this lexicon, a dictionary is created which includes "words", "prefixes" and "suffixes" sets. During data preparation, dictionary features are extracted at each character position, to "look ahead" and "look backward" to see if any words formed found in the dictionary. The window size (or the dictionary feature length) is defined at the 95-percentile among all the existing words in the lexicon, this is to eliminate the less frequent but long words (avoid having a high-dimension feat vector). Prefixes and suffixes are used to stop early during the window-dictionary checking process. """ import argparse from copy import copy import logging import random import numpy as np import os import torch import json from stanza.models.common import utils from stanza.models.tokenization.trainer import Trainer from stanza.models.tokenization.data import DataLoader, TokenizationDataset from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary from stanza.models import _training_logging logger = logging.getLogger('stanza') def build_argparse(): """ If args == None, the system args are used. 
""" parser = argparse.ArgumentParser() parser.add_argument('--txt_file', type=str, help="Input plaintext file") parser.add_argument('--label_file', type=str, default=None, help="Character-level label file") parser.add_argument('--mwt_json_file', type=str, default=None, help="JSON file for MWT expansions") parser.add_argument('--conll_file', type=str, default=None, help="CoNLL file for output") parser.add_argument('--dev_txt_file', type=str, help="(Train only) Input plaintext file for the dev set") parser.add_argument('--dev_label_file', type=str, default=None, help="(Train only) Character-level label file for the dev set") parser.add_argument('--dev_conll_gold', type=str, default=None, help="(Train only) CoNLL-U file for the dev set for early stopping") parser.add_argument('--lang', type=str, help="Language") parser.add_argument('--shorthand', type=str, help="UD treebank shorthand") parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.") parser.add_argument('--emb_dim', type=int, default=32, help="Dimension of unit embeddings") parser.add_argument('--hidden_dim', type=int, default=64, help="Dimension of hidden units") parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. 
,, separates layers and , separates filter sizes in the same layer.") parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections") parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer") parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers") parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer") parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. 
For example, data/tokenize/zh_gsdsimp-externaldict.txt") parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to") parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorate") parser.add_argument('--anneal_after', type=int, default=2000, help="Anneal the learning rate no earlier than this step") parser.add_argument('--lr0', type=float, default=2e-3, help="Initial learning rate") parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability") parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability") parser.add_argument('--feat_dropout', type=float, default=0.05, help="Features dropout probability for each element in feature vector") parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help="The whole feature of units dropout probability") parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability to induce noise to the input of the higher RNN") parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random. Idea is to fake paragraph endings.") parser.add_argument('--last_char_drop_prob', type=float, default=0.2, help="Probability to drop the last char of a block of text during training, uniformly at random. Idea is to fake a document ending w/o sentence final punctuation, hopefully to avoid the tokenizer learning to always tokenize the last character as a period") parser.add_argument('--last_char_move_prob', type=float, default=0.02, help="Probability to move the sentence final punctuation of a sentence during training, uniformly at random. 
Idea is to teach the tokenizer that a space separated sentence final punct still ends the sentence") parser.add_argument('--punct_move_back_prob', type=float, default=0.02, help="Probability to move a comma in the sentence one over, removing the previous space, during training. Idea is to teach the tokenizer that commas can appear next to words even in languages where the dataset doesn't allow it, such as Vietnamese") parser.add_argument('--split_mwt_prob', type=float, default=0.01, help="Probably to split an MWT into its component pieces and turn it into separate words") parser.add_argument('--augment_final_punct_prob', type=float, default=0.05, help="Probability to replace a ? with a ? or other similar augmentations") parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay") parser.add_argument('--max_seqlen', type=int, default=100, help="Maximum sequence length to consider at a time") parser.add_argument('--batch_size', type=int, default=32, help="Batch size to use") parser.add_argument('--epochs', type=int, default=10, help="Total epochs to train the model for") parser.add_argument('--steps', type=int, default=50000, help="Steps to train the model for, if unspecified use epochs") parser.add_argument('--report_steps', type=int, default=20, help="Update step interval to report loss") parser.add_argument('--shuffle_steps', type=int, default=100, help="Step interval to shuffle each paragraph in the generator") parser.add_argument('--eval_steps', type=int, default=200, help="Step interval to evaluate the model on the dev set for early stopping") parser.add_argument('--max_steps_before_stop', type=int, default=5000, help='Early terminates after this many steps if the dev scores are not improving') parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tokenizer.pt", help="File name to save the model") parser.add_argument('--load_name', type=str, default=None, help="File name to load a saved model") 
parser.add_argument('--save_dir', type=str, default='saved_models/tokenize', help="Directory to save models in") utils.add_device_args(parser) parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.") parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--use_mwt', dest='use_mwt', default=None, action='store_true', help='Whether or not to include mwt output layers. If set to None, this will be determined by examining the training data for MWTs') parser.add_argument('--no_use_mwt', dest='use_mwt', action='store_false', help='Whether or not to include mwt output layers') parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name') parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. 
Will default to the dataset short name') return parser def parse_args(args=None): parser = build_argparse() args = parser.parse_args(args=args) if args.wandb_name: args.wandb = True args = vars(args) return args def model_file_name(args): embedding = "nocharlm" if args['charlm'] and args['charlm_forward_file']: embedding = "charlm" save_name = args['save_name'].format(shorthand=args['shorthand'], embedding=embedding) logger.info("Saving to: %s", save_name) if not os.path.exists(os.path.join(args['save_dir'], save_name)) and os.path.exists(save_name): return save_name return os.path.join(args['save_dir'], save_name) def main(args=None): args = parse_args(args=args) utils.set_random_seed(args['seed']) logger.info("Running tokenizer in {} mode".format(args['mode'])) args['feat_funcs'] = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para'] args['feat_dim'] = len(args['feat_funcs']) args['save_name'] = model_file_name(args) utils.ensure_dir(os.path.split(args['save_name'])[0]) if args['mode'] == 'train': return train(args) else: return evaluate(args) def train(args): if args['use_dictionary']: #load lexicon lexicon, args['num_dict_feat'] = load_lexicon(args) #create the dictionary dictionary = create_dictionary(lexicon) #adjust the feat_dim args['feat_dim'] += args['num_dict_feat']*2 else: args['num_dict_feat'] = 0 lexicon=None dictionary=None mwt_dict = load_mwt_dict(args['mwt_json_file']) mwt_expansions = {x: y[0] for x, y in mwt_dict.items()} train_input_files = { 'txt': args['txt_file'], 'label': args['label_file'] } train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary, mwt_expansions=mwt_expansions) vocab = train_batches.vocab args['vocab_size'] = len(vocab) dev_input_files = { 'txt': args['dev_txt_file'], 'label': args['dev_label_file'] } dev_batches = TokenizationDataset(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary) if args['use_mwt'] is None: args['use_mwt'] = 
train_batches.has_mwt() logger.info("Found {}mwts in the training data. Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt'])) trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], foundation_cache=None) if args['load_name'] is not None: load_name = os.path.join(args['save_dir'], args['load_name']) trainer.load(load_name) trainer.change_lr(args['lr0']) N = len(train_batches) steps = args['steps'] if args['steps'] is not None else int(N * args['epochs'] / args['batch_size'] + .5) lr = args['lr0'] prev_dev_score = -1 best_dev_score = -1 best_dev_step = -1 if args['wandb']: import wandb wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tokenizer" % args['shorthand'] wandb.init(name=wandb_name, config=args) wandb.run.define_metric('train_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') for step in range(1, steps+1): batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout']) loss = trainer.update(batch) if step % args['report_steps'] == 0: logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss)) if args['wandb']: wandb.log({'train_loss': loss}, step=step) if args['shuffle_steps'] > 0 and step % args['shuffle_steps'] == 0: train_batches.shuffle() if step % args['eval_steps'] == 0: dev_score = eval_model(args, trainer, dev_batches, vocab, mwt_dict) if args['wandb']: wandb.log({'dev_score': dev_score}, step=step) reports = ['Dev score: {:6.3f}'.format(dev_score * 100)] if step >= args['anneal_after'] and dev_score < prev_dev_score: reports += ['lr: {:.6f} -> {:.6f}'.format(lr, lr * args['anneal'])] lr *= args['anneal'] trainer.change_lr(lr) prev_dev_score = dev_score if dev_score > best_dev_score: reports += ['New best dev score!'] best_dev_score = dev_score best_dev_step = step trainer.save(args['save_name']) elif best_dev_step > 0 and step - best_dev_step > args['max_steps_before_stop']: 
reports += ['Stopping training after {} steps with no improvement'.format(step - best_dev_step)] logger.info('\t'.join(reports)) break logger.info('\t'.join(reports)) if args['wandb']: wandb.finish() if best_dev_step > -1: logger.info('Best dev score={} at step {}'.format(best_dev_score, best_dev_step)) else: logger.info('Dev set never evaluated. Saving final model') trainer.save(args['save_name']) return trainer, None def evaluate(args): mwt_dict = load_mwt_dict(args['mwt_json_file']) trainer = Trainer(args=args, model_file=args['load_name'] or args['save_name'], device=args['device'], foundation_cache=None) loaded_args, vocab = trainer.args, trainer.vocab for k in loaded_args: if not k.endswith('_file') and k not in ['device', 'mode', 'save_dir', 'load_name', 'save_name']: args[k] = loaded_args[k] eval_input_files = { 'txt': args['txt_file'], 'label': args['label_file'] } batches = TokenizationDataset(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary) oov_count, N, _, doc = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen']) logger.info("OOV rate: {:6.3f}% ({:6d}/{:6d})".format(oov_count / N * 100, oov_count, N)) return trainer, doc if __name__ == '__main__': main() ================================================ FILE: stanza/models/wl_coref.py ================================================ """ Runs experiments with CorefModel. Try 'python wl_coref.py -h' for more details. 
Code based on https://github.com/KarelDO/wl-coref/tree/master https://arxiv.org/abs/2310.06165 This was a fork of https://github.com/vdobrovolskii/wl-coref https://aclanthology.org/2021.emnlp-main.605/ If you use Stanza's coref module in your work, please cite the following: @misc{doosterlinck2023cawcoref, title={CAW-coref: Conjunction-Aware Word-level Coreference Resolution}, author={Karel D'Oosterlinck and Semere Kiros Bitew and Brandon Papineau and Christopher Potts and Thomas Demeester and Chris Develder}, year={2023}, eprint={2310.06165}, archivePrefix={arXiv}, primaryClass={cs.CL}, url = "https://arxiv.org/abs/2310.06165", } @inproceedings{dobrovolskii-2021-word, title = "Word-Level Coreference Resolution", author = "Dobrovolskii, Vladimir", booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing", month = nov, year = "2021", address = "Online and Punta Cana, Dominican Republic", publisher = "Association for Computational Linguistics", url = "https://aclanthology.org/2021.emnlp-main.605", pages = "7670--7675" } """ import argparse from contextlib import contextmanager import datetime import logging import os import random import sys import dataclasses import time import numpy as np # type: ignore import torch # type: ignore from stanza.models.common.utils import set_random_seed from stanza.models.coref.model import CorefModel logger = logging.getLogger('stanza') @contextmanager def output_running_time(): """ Prints the time elapsed in the context """ start = int(time.time()) try: yield finally: end = int(time.time()) delta = datetime.timedelta(seconds=end - start) logger.info(f"Total running time: {delta}") def deterministic() -> None: torch.backends.cudnn.deterministic = True # type: ignore torch.backends.cudnn.benchmark = False # type: ignore if __name__ == "__main__": argparser = argparse.ArgumentParser() argparser.add_argument("mode", choices=("train", "eval")) argparser.add_argument("experiment") 
argparser.add_argument("--config_file", default="config.toml") argparser.add_argument("--data_split", choices=("train", "dev", "test"), default="test", help="Data split to be used for evaluation." " Defaults to 'test'." " Ignored in 'train' mode.") argparser.add_argument("--batch_size", type=int, help="Adjust to override the config value of anaphoricity " "batch size if you are experiencing out-of-memory " "issues") argparser.add_argument("--disable_singletons", action="store_true", help="don't predict singletons") argparser.add_argument("--full_pairwise", action="store_true", help="use speaker and document embeddings") argparser.add_argument("--hidden_size", type=int, help="Adjust the anaphoricity scorer hidden size") argparser.add_argument("--rough_k", type=int, help="Adjust the number of dummies to keep") argparser.add_argument("--n_hidden_layers", type=int, help="Adjust the anaphoricity scorer hidden layers") argparser.add_argument("--dummy_mix", type=float, help="Adjust the dummy mix") argparser.add_argument("--bert_finetune_begin_epoch", type=float, help="Adjust the bert finetune begin epoch") argparser.add_argument("--bert_model", type=str, help="Use this transformer for the given experiment") argparser.add_argument("--warm_start", action="store_true", help="If set, the training will resume from the" " last checkpoint saved if any. Ignored in" " evaluation modes." " Incompatible with '--weights'.") argparser.add_argument("--weights", help="Path to file with weights to load." " If not supplied, in 'eval' mode the latest" " weights of the experiment will be loaded;" " in 'train' mode no weights will be loaded.") argparser.add_argument("--word_level", action="store_true", help="If set, output word-level conll-formatted" " files in evaluation modes. 
Ignored in" " 'train' mode.") argparser.add_argument("--learning_rate", default=None, type=float, help="If set, update the learning rate for the model") argparser.add_argument("--bert_learning_rate", default=None, type=float, help="If set, update the learning rate for the transformer") argparser.add_argument("--save_dir", default=None, help="If set, update the save directory for writing models") argparser.add_argument("--save_name", default=None, help="If set, update the save name for writing models (otherwise, section name)") argparser.add_argument("--score_lang", default=None, help="only score a particular language for eval") argparser.add_argument("--log_norms", action="store_true", default=None, help="If set, log all of the trainable norms each epoch. Very noisy!") argparser.add_argument("--seed", type=int, default=2020, help="Random seed to set") argparser.add_argument("--lang_lr_attenuation", type=str, default=None, help="A comma-separated list of languages where the LR will be scaled by 1/epoch, such as --lang_lr_attenuation=es,en,de,...") argparser.add_argument("--lang_lr_weights", type=str, default=None, help="A comma-separated list of languages and their weights of LR scaling for different languages, such as es=0.5,en=1.0,...") argparser.add_argument("--max_train_len", type=int, default=5000, help="Skip any documents longer than this maximum length") argparser.add_argument("--no_max_train_len", action="store_const", const=float("inf"), dest="max_train_len", help="Do not skip any documents for being too long") argparser.add_argument("--train_epochs", type=int, default=None, help="Train this many epochs") argparser.add_argument("--plateau_epochs", type=int, default=None, help="Stop training if plateaued for this many epochs (only applies if positive)") argparser.add_argument("--train_data", default=None, help="File to use for train data") argparser.add_argument("--dev_data", default=None, help="File to use for dev data") 
argparser.add_argument("--test_data", default=None, help="File to use for test data") argparser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name', default=False) argparser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name') args = argparser.parse_args() if args.warm_start and args.weights is not None: raise ValueError("The following options are incompatible: '--warm_start' and '--weights'") set_random_seed(args.seed) deterministic() config = CorefModel._load_config(args.config_file, args.experiment) if args.batch_size: config.a_scoring_batch_size = args.batch_size if args.hidden_size: config.hidden_size = args.hidden_size if args.n_hidden_layers: config.n_hidden_layers = args.n_hidden_layers if args.learning_rate is not None: config.learning_rate = args.learning_rate if args.bert_model is not None: config.bert_model = args.bert_model if args.bert_learning_rate is not None: config.bert_learning_rate = args.bert_learning_rate if args.bert_finetune_begin_epoch is not None: config.bert_finetune_begin_epoch = args.bert_finetune_begin_epoch if args.dummy_mix is not None: config.dummy_mix = args.dummy_mix if args.save_dir is not None: config.save_dir = args.save_dir if args.save_name: config.save_name = args.save_name else: config.save_name = args.experiment if args.rough_k is not None: config.rough_k = args.rough_k if args.log_norms is not None: config.log_norms = args.log_norms if args.full_pairwise: config.full_pairwise = args.full_pairwise if args.disable_singletons: config.singletons = False if args.train_data: config.train_data = args.train_data if args.dev_data: config.dev_data = args.dev_data if args.test_data: config.test_data = args.test_data if args.max_train_len: config.max_train_len = args.max_train_len if args.train_epochs: config.train_epochs = 
args.train_epochs if args.plateau_epochs: config.plateau_epochs = args.plateau_epochs if args.lang_lr_attenuation: config.lang_lr_attenuation = args.lang_lr_attenuation if args.lang_lr_weights: config.lang_lr_weights = args.lang_lr_weights # if wandb, generate wandb configuration if args.mode == "train": if args.wandb: import wandb wandb_name = args.wandb_name if args.wandb_name else f"wl_coref_{args.experiment}" wandb.init(name=wandb_name, config=dataclasses.asdict(config), project="stanza") wandb.run.define_metric('train_c_loss', summary='min') wandb.run.define_metric('train_s_loss', summary='min') wandb.run.define_metric('dev_score', summary='max') model = CorefModel(config=config) if args.weights is not None or args.warm_start: model.load_weights(path=args.weights, map_location="cpu", noexception=args.warm_start) with output_running_time(): model.train(args.wandb) else: config_update = { 'log_norms': args.log_norms if args.log_norms is not None else False } if args.test_data: config_update['test_data'] = args.test_data if args.weights is None and config.save_name is not None: args.weights = config.save_name if not os.path.exists(args.weights) and os.path.exists(args.weights + ".pt"): args.weights = args.weights + ".pt" elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights)): args.weights = os.path.join(config.save_dir, args.weights) elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights + ".pt")): args.weights = os.path.join(config.save_dir, args.weights + ".pt") model = CorefModel.load_model(path=args.weights, map_location="cpu", ignore={"bert_optimizer", "general_optimizer", "bert_scheduler", "general_scheduler"}, config_update=config_update) results = model.evaluate(data_split=args.data_split, word_level_conll=args.word_level, eval_lang=args.score_lang) # logger.info(("mean loss", ")) print("\t".join([str(round(i, 3)) for i in 
results])) ================================================ FILE: stanza/pipeline/__init__.py ================================================ ================================================ FILE: stanza/pipeline/_constants.py ================================================ """ Module defining constants """ # string constants for processor names LANGID = 'langid' TOKENIZE = 'tokenize' MWT = 'mwt' POS = 'pos' LEMMA = 'lemma' DEPPARSE = 'depparse' NER = 'ner' SENTIMENT = 'sentiment' CONSTITUENCY = 'constituency' COREF = 'coref' MORPHSEG = 'morphseg' ================================================ FILE: stanza/pipeline/constituency_processor.py ================================================ """ Processor that attaches a constituency tree to a sentence """ from stanza.models.constituency.trainer import Trainer from stanza.models.common import doc from stanza.models.common.utils import sort_with_indices, unsort from stanza.utils.get_tqdm import get_tqdm from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor tqdm = get_tqdm() @register_processor(CONSTITUENCY) class ConstituencyProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([CONSTITUENCY]) # set of processor requirements for this processor REQUIRES_DEFAULT = set([TOKENIZE, POS]) # default batch size, measured in sentences DEFAULT_BATCH_SIZE = 50 def _set_up_requires(self): self._pretagged = self._config.get('pretagged') if self._pretagged: self._requires = set() else: self._requires = self.__class__.REQUIRES_DEFAULT def _set_up_model(self, config, pipeline, device): # set up model # pretrain and charlm paths are args from the config # bert (if used) will be chosen from the model save file args = { "wordvec_pretrain_file": config.get('pretrain_path', None), "charlm_forward_file": config.get('forward_charlm_path', None), "charlm_backward_file": config.get('backward_charlm_path', None), "device": device, } 
trainer = Trainer.load(filename=config['model_path'], args=args, foundation_cache=pipeline.foundation_cache) self._trainer = trainer self._model = trainer.model self._model.eval() # batch size counted as sentences self._batch_size = int(config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE)) self._tqdm = 'tqdm' in config and config['tqdm'] def _set_up_final_config(self, config): loaded_args = self._model.args loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)} loaded_args.update(config) self._config = loaded_args def process(self, document): sentences = document.sentences if self._model.uses_xpos(): words = [[(w.text, w.xpos) for w in s.words] for s in sentences] else: words = [[(w.text, w.upos) for w in s.words] for s in sentences] words, original_indices = sort_with_indices(words, key=len, reverse=True) if self._tqdm: words = tqdm(words) trees = self._model.parse_tagged_words(words, self._batch_size) trees = unsort(trees, original_indices) document.set(CONSTITUENCY, trees, to_sentence=True) return document def get_constituents(self): """ Return a set of the constituents known by this model For a pipeline, this can be queried with pipeline.processors["constituency"].get_constituents() """ return set(self._model.constituents) ================================================ FILE: stanza/pipeline/core.py ================================================ """ Pipeline that runs tokenize,mwt,pos,lemma,depparse """ import argparse import collections from enum import Enum import io import itertools import sys import logging import json import os from stanza.pipeline._constants import * from stanza.models.common.constant import langcode_to_lang from stanza.models.common.doc import Document from stanza.models.common.foundation_cache import FoundationCache from stanza.models.common.utils import default_device from stanza.pipeline.processor import Processor, ProcessorRequirementsException from stanza.pipeline.registry 
import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS from stanza.pipeline.langid_processor import LangIDProcessor from stanza.pipeline.tokenize_processor import TokenizeProcessor from stanza.pipeline.mwt_processor import MWTProcessor from stanza.pipeline.pos_processor import POSProcessor from stanza.pipeline.lemma_processor import LemmaProcessor from stanza.pipeline.constituency_processor import ConstituencyProcessor from stanza.pipeline.coref_processor import CorefProcessor from stanza.pipeline.depparse_processor import DepparseProcessor from stanza.pipeline.sentiment_processor import SentimentProcessor from stanza.pipeline.ner_processor import NERProcessor from stanza.resources.common import DEFAULT_MODEL_DIR, DEFAULT_RESOURCES_URL, DEFAULT_RESOURCES_VERSION, ModelSpecification, add_dependencies, add_mwt, download_models, download_resources_json, flatten_processor_list, load_resources_json, maintain_processor_list, process_pipeline_parameters, set_logging_level, sort_processors from stanza.resources.default_packages import PACKAGES from stanza.utils.conll import CoNLL, CoNLLError from stanza.utils.helper_func import make_table logger = logging.getLogger('stanza') class DownloadMethod(Enum): """ Determines a couple options on how to download resources for the pipeline. NONE will not download anything, including HF transformers, probably resulting in failure if the resources aren't already in place. REUSE_RESOURCES will reuse the existing resources.json and models, but will download any missing models. DOWNLOAD_RESOURCES will download a new resources.json and will overwrite any out of date models. """ NONE = 1 REUSE_RESOURCES = 2 DOWNLOAD_RESOURCES = 3 class LanguageNotDownloadedError(FileNotFoundError): def __init__(self, lang, lang_dir, model_path): super().__init__(f'Could not find the model file {model_path}. The expected model directory {lang_dir} is missing. 
Perhaps you need to run stanza.download("{lang}")') self.lang = lang self.lang_dir = lang_dir self.model_path = model_path class UnsupportedProcessorError(FileNotFoundError): def __init__(self, processor, lang): super().__init__(f'Processor {processor} is not known for language {lang}. If you have created your own model, please specify the {processor}_model_path parameter when creating the pipeline.') self.processor = processor self.lang = lang class IllegalPackageError(ValueError): def __init__(self, msg): super().__init__(msg) class PipelineRequirementsException(Exception): """ Exception indicating one or more requirements failures while attempting to build a pipeline. Contains a ProcessorRequirementsException list. """ def __init__(self, processor_req_fails): self._processor_req_fails = processor_req_fails self.build_message() @property def processor_req_fails(self): return self._processor_req_fails def build_message(self): err_msg = io.StringIO() print(*[req_fail.message for req_fail in self.processor_req_fails], sep='\n', file=err_msg) self.message = '\n\n' + err_msg.getvalue() def __str__(self): return self.message def build_default_config_option(model_specs): """ Build a config option for a couple situations: lemma=identity, processor is a variant Returns the option name and value Refactored from build_default_config so that we can reuse it when downloading all models """ # handle case when processor variants are used if any(model_spec.package in PROCESSOR_VARIANTS[model_spec.processor] for model_spec in model_specs): if len(model_specs) > 1: raise IllegalPackageError("Variant processor selected for {}, but multiple packages requested".format(model_spec.processor)) return f"{model_specs[0].processor}_with_{model_specs[0].package}", True # handle case when identity is specified as lemmatizer elif any(model_spec.processor == LEMMA and model_spec.package == 'identity' for model_spec in model_specs): if len(model_specs) > 1: raise IllegalPackageError("Identity 
processor selected for lemma, but multiple packages requested") return f"{LEMMA}_use_identity", True return None def filter_variants(model_specs): return [(key, value) for (key, value) in model_specs if build_default_config_option(value) is None] # given a language and models path, build a default configuration def build_default_config(resources, lang, model_dir, load_list): default_config = {} for processor, model_specs in load_list: option = build_default_config_option(model_specs) if option is not None: # if an option is set for the model_specs, keep that option and ignore # the rest of the model spec default_config[option[0]] = option[1] continue model_paths = [os.path.join(model_dir, lang, processor, model_spec.package + '.pt') for model_spec in model_specs] dependencies = [model_spec.dependencies for model_spec in model_specs] # Special case for NER: load multiple models at once # The pattern will be: # a list of ner_model_path # a list of ner_dependencies # where each item in ner_dependencies is a map # the map may contain forward_charlm_path, backward_charlm_path, or any other deps # The user will be able to override the defaults using a semicolon separated string # TODO: at least use the same config pattern for all other models if processor == NER: default_config[f"{processor}_model_path"] = model_paths dependency_paths = [] for dependency_block in dependencies: if not dependency_block: dependency_paths.append({}) continue dependency_paths.append({}) for dependency in dependency_block: dep_processor, dep_model = dependency dependency_paths[-1][f"{dep_processor}_path"] = os.path.join(model_dir, lang, dep_processor, dep_model + '.pt') default_config[f"{processor}_dependencies"] = dependency_paths continue if len(model_specs) > 1: raise IllegalPackageError("Specified multiple packages for {}, which currently only handles one package".format(processor)) default_config[f"{processor}_model_path"] = model_paths[0] if not dependencies[0]: continue for dependency 
in dependencies[0]: dep_processor, dep_model = dependency default_config[f"{processor}_{dep_processor}_path"] = os.path.join( model_dir, lang, dep_processor, dep_model + '.pt' ) return default_config def normalize_download_method(download_method): """ Turn None -> DownloadMethod.NONE, strings to the corresponding enum """ if download_method is None: return DownloadMethod.NONE elif isinstance(download_method, str): try: return DownloadMethod[download_method.upper()] except KeyError as e: raise ValueError("Unknown download method %s" % download_method) from e return download_method class Pipeline: def __init__(self, lang='en', dir=DEFAULT_MODEL_DIR, package='default', processors={}, logging_level=None, verbose=None, use_gpu=None, model_dir=None, download_method=DownloadMethod.DOWNLOAD_RESOURCES, resources_url=DEFAULT_RESOURCES_URL, resources_branch=None, resources_version=DEFAULT_RESOURCES_VERSION, resources_filepath=None, proxies=None, foundation_cache=None, device=None, allow_unknown_language=False, **kwargs): self.lang, self.dir, self.kwargs = lang, dir, kwargs if model_dir is not None and dir == DEFAULT_MODEL_DIR: self.dir = model_dir # set global logging level set_logging_level(logging_level, verbose) self.download_method = normalize_download_method(download_method) if (self.download_method is DownloadMethod.DOWNLOAD_RESOURCES or (self.download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, "resources.json")))): logger.info("Checking for updates to resources.json in case models have been updated. 
Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES") download_resources_json(self.dir, resources_url=resources_url, resources_branch=resources_branch, resources_version=resources_version, resources_filepath=resources_filepath, proxies=proxies) # processors can use this to save on the effort of loading # large sub-models, such as pretrained embeddings, bert, etc if foundation_cache is None: self.foundation_cache = FoundationCache(local_files_only=(self.download_method is DownloadMethod.NONE)) else: self.foundation_cache = FoundationCache(foundation_cache, local_files_only=(self.download_method is DownloadMethod.NONE)) # process different pipeline parameters lang, self.dir, package, processors = process_pipeline_parameters(lang, self.dir, package, processors) # Load resources.json to obtain latest packages. logger.debug('Loading resource file...') resources = load_resources_json(self.dir, resources_filepath) if lang in resources: if 'alias' in resources[lang]: logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"') lang = resources[lang]['alias'] lang_name = resources[lang]['lang_name'] if 'lang_name' in resources[lang] else '' elif allow_unknown_language: logger.warning("Trying to create pipeline for unsupported language: %s", lang) lang_name = langcode_to_lang(lang) else: logger.warning("Unsupported language: %s If trying to add a new language, consider using allow_unknown_language=True", lang) lang_name = langcode_to_lang(lang) # Maintain load list if lang in resources: self.load_list = maintain_processor_list(resources, lang, package, processors, maybe_add_mwt=(not kwargs.get("tokenize_pretokenized"))) self.load_list = add_dependencies(resources, lang, self.load_list) if self.download_method is not DownloadMethod.NONE: # skip processors which aren't downloaded from our collection download_list = [x for x in self.load_list if x[0] in resources.get(lang, {})] # skip variants download_list 
= filter_variants(download_list) # gather up the model list... download_list = flatten_processor_list(download_list) # download_models will skip models we already have download_models(download_list, resources=resources, lang=lang, model_dir=self.dir, resources_version=resources_version, proxies=proxies, log_info=False) elif allow_unknown_language: self.load_list = [(proc, [ModelSpecification(processor=proc, package='default', dependencies=None)]) for proc in list(processors.keys())] else: self.load_list = [] self.load_list = self.update_kwargs(kwargs, self.load_list) if len(self.load_list) == 0: if lang not in resources or PACKAGES not in resources[lang]: raise ValueError(f'No processors to load for language {lang}. Language {lang} is currently unsupported') else: raise ValueError('No processors to load for language {}. Please check if your language or package is correctly set.'.format(lang)) load_table = make_table(['Processor', 'Package'], [(row[0], ";".join(model_spec.package for model_spec in row[1])) for row in self.load_list]) logger.info(f'Loading these models for language: {lang} ({lang_name}):\n{load_table}') self.config = build_default_config(resources, lang, self.dir, self.load_list) self.config.update(kwargs) # Load processors self.processors = {} # configs that are the same for all processors pipeline_level_configs = {'lang': lang, 'mode': 'predict'} if device is None: if use_gpu is None or use_gpu == True: device = default_device() else: device = 'cpu' if use_gpu == True and device == 'cpu': logger.warning("GPU requested, but is not available!") self.device = device logger.info("Using device: {}".format(self.device)) # set up processors pipeline_reqs_exceptions = [] for item in self.load_list: processor_name, _ = item logger.info('Loading: ' + processor_name) curr_processor_config = self.filter_config(processor_name, self.config) curr_processor_config.update(pipeline_level_configs) # TODO: this is obviously a hack # a better solution overall would be 
to make a pretagged version of the pos annotator # and then subsequent modules can use those tags without knowing where those tags came from if "pretagged" in self.config and "pretagged" not in curr_processor_config: curr_processor_config["pretagged"] = self.config["pretagged"] logger.debug('With settings: ') logger.debug(curr_processor_config) try: # try to build processor, throw an exception if there is a requirements issue self.processors[processor_name] = NAME_TO_PROCESSOR_CLASS[processor_name](config=curr_processor_config, pipeline=self, device=self.device) except ProcessorRequirementsException as e: # if there was a requirements issue, add it to list which will be printed at end pipeline_reqs_exceptions.append(e) # add the broken processor to the loaded processors for the sake of analyzing the validity of the # entire proposed pipeline, but at this point the pipeline will not be built successfully self.processors[processor_name] = e.err_processor except FileNotFoundError as e: # For a FileNotFoundError, we try to guess if there's # a missing model directory or file. If so, we # suggest the user try to download the models if 'model_path' in curr_processor_config: model_path = curr_processor_config['model_path'] if e.filename == model_path or (isinstance(model_path, (tuple, list)) and e.filename in model_path): model_path = e.filename model_dir, model_name = os.path.split(model_path) lang_dir = os.path.dirname(model_dir) if lang_dir and not os.path.exists(lang_dir): # model files for this language can't be found in the expected directory raise LanguageNotDownloadedError(lang, lang_dir, model_path) from e if processor_name not in resources[lang]: # user asked for a model which doesn't exist for this language? raise UnsupportedProcessorError(processor_name, lang) from e if not os.path.exists(model_path): model_name, _ = os.path.splitext(model_name) # TODO: before recommending this, check that such a thing exists in resources.json. 
# currently that case is handled by ignoring the model, anyway raise FileNotFoundError('Could not find model file %s, although there are other models downloaded for language %s. Perhaps you need to download a specific model. Try: stanza.download(lang="%s",package=None,processors={"%s":"%s"})' % (model_path, lang, lang, processor_name, model_name)) from e # if we couldn't find a more suitable description of the # FileNotFoundError, just raise the old error raise # if there are any processor exceptions, throw an exception to indicate pipeline build failure if pipeline_reqs_exceptions: logger.info('\n') raise PipelineRequirementsException(pipeline_reqs_exceptions) logger.info("Done loading processors!") @staticmethod def update_kwargs(kwargs, processor_list): processor_dict = {processor: [{'package': model_spec.package, 'dependencies': model_spec.dependencies} for model_spec in model_specs] for (processor, model_specs) in processor_list} for key, value in kwargs.items(): pieces = key.split('_', 1) if len(pieces) == 1: continue k, v = pieces if v == 'model_path': package = value if len(value) < 25 else value[:10]+ '...' 
+ value[-10:] original_spec = processor_dict.get(k, []) if len(original_spec) > 0: dependencies = original_spec[0].get('dependencies') else: dependencies = None processor_dict[k] = [{'package': package, 'dependencies': dependencies}] processor_list = [(processor, [ModelSpecification(processor=processor, package=model_spec['package'], dependencies=model_spec['dependencies']) for model_spec in processor_dict[processor]]) for processor in processor_dict] processor_list = sort_processors(processor_list) return processor_list @staticmethod def filter_config(prefix, config_dict): filtered_dict = {} for key in config_dict.keys(): pieces = key.split('_', 1) # split tokenize_pretokenize to tokenize+pretokenize if len(pieces) == 1: continue k, v = pieces if k == prefix: filtered_dict[v] = config_dict[key] return filtered_dict @property def loaded_processors(self): """ Return all currently loaded processors in execution order. :return: list of Processor instances """ return [self.processors[processor_name] for processor_name in PIPELINE_NAMES if self.processors.get(processor_name)] def process(self, doc, processors=None): """ Run the pipeline processors: allow for a list of processors used by this pipeline action can be list, tuple, set, or comma separated string if None, use all the processors this pipeline knows about MWT is added if necessary otherwise, no care is taken to make sure prerequisites are followed... 
some of the annotators, such as depparse, will check, but others will fail in some unusual manner or just have really bad results """ assert any([isinstance(doc, str), isinstance(doc, list), isinstance(doc, Document)]), 'input should be either str, list or Document' # empty bulk process if isinstance(doc, list) and len(doc) == 0: return [] # determine whether we are in bulk processing mode for multiple documents bulk=(isinstance(doc, list) and len(doc) > 0 and isinstance(doc[0], Document)) # various options to limit the processors used by this pipeline action if processors is None: processors = PIPELINE_NAMES elif not isinstance(processors, (str, list, tuple, set)): raise ValueError("Cannot process {} as a list of processors to run".format(type(processors))) else: if isinstance(processors, str): processors = {x for x in processors.split(",")} else: processors = set(processors) if TOKENIZE in processors and MWT in self.processors and MWT not in processors: logger.debug("Requested processors for pipeline did not have mwt, but pipeline needs mwt, so mwt is added") processors.add(MWT) processors = [x for x in PIPELINE_NAMES if x in processors] for processor_name in processors: if self.processors.get(processor_name): process = self.processors[processor_name].bulk_process if bulk else self.processors[processor_name].process doc = process(doc) return doc def process_conllu(self, doc, ignore_gapping=True, processors=None): """ Convenience method: treat the doc as a conllu text, convert it, and process it accordingly """ if processors is None: processors = set(self.processors.keys()) if TOKENIZE in processors: processors.remove(TOKENIZE) if MWT in processors: processors.remove(MWT) doc = CoNLL.conll2doc(input_str=doc, ignore_gapping=ignore_gapping) return self.process(doc, processors=processors) def bulk_process(self, docs, *args, **kwargs): """ Run the pipeline in bulk processing mode Expects a list of str or a list of Docs """ # Wrap each text as a Document unless it is 
already such a document docs = [doc if isinstance(doc, Document) else Document([], text=doc) for doc in docs] return self.process(docs, *args, **kwargs) def stream(self, docs, batch_size=50, *args, **kwargs): """ Go through an iterator of documents in batches, yield processed documents sentence indices will be counted across the entire iterator """ if not isinstance(docs, collections.abc.Iterator): docs = iter(docs) def next_batch(): batch = [] for _ in range(batch_size): try: next_doc = next(docs) batch.append(next_doc) except StopIteration: return batch return batch sentence_start_index = 0 batch = next_batch() while batch: batch = self.bulk_process(batch, *args, **kwargs) for doc in batch: doc.reindex_sentences(sentence_start_index) sentence_start_index += len(doc.sentences) yield doc batch = next_batch() def __str__(self): """ Assemble the processors in order to make a simple description of the pipeline """ processors = ["%s=%s" % (x, str(self.processors[x])) for x in PIPELINE_NAMES if x in self.processors] return "" % ", ".join(processors) def __call__(self, doc, processors=None): return self.process(doc, processors) def main(): # TODO: can add a bunch more arguments parser = argparse.ArgumentParser() parser.add_argument('--lang', type=str, default='en', help='Language of the pipeline to use') parser.add_argument('--input_file', type=str, required=True, help='Input file to read') parser.add_argument('--processors', type=str, default='tokenize,pos,lemma,depparse', help='Processors to use') parser.add_argument('--package', type=str, default='default', help='Which package to use') parser.add_argument('--tokenize_no_ssplit', default=False, action='store_true', help="Don't ssplit") parser.add_argument('--tokenize_pretokenized', default=False, action='store_true', help="Text is pretokenized") args, extra_args = parser.parse_known_args() try: doc = CoNLL.conll2doc(args.input_file) extra_args = { "tokenize_pretokenized": True } except CoNLLError: logger.debug("Input 
file %s does not appear to be a conllu file. Will read it as a text file") with open(args.input_file, encoding="utf-8") as fin: doc = fin.read() extra_args = {} extra_args['package'] = args.package if args.tokenize_no_ssplit: extra_args['tokenize_no_ssplit'] = True if args.tokenize_pretokenized: extra_args['tokenize_pretokenized'] = True pipe = Pipeline(args.lang, processors=args.processors, **extra_args) doc = pipe(doc) print("{:C}".format(doc)) if __name__ == '__main__': main() ================================================ FILE: stanza/pipeline/coref_processor.py ================================================ """ Processor that attaches coref annotations to a document """ from stanza.models.common.utils import misc_to_space_after from stanza.models.coref.coref_chain import CorefMention, CorefChain from stanza.models.common.doc import Word from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor import torch def extract_text(document, sent_id, start_word, end_word): sentence = document.sentences[sent_id] tokens = [] # the coref model indexes the words from 0, # whereas the ids we are looking at on the tokens start from 1 # here we will switch to ID space start_word = start_word + 1 end_word = end_word + 1 # For each position between start and end word: # If a word is part of an MWT, and the entire token # is inside the range, we use that Token's text for that span # This will let us easily handle words which are split into pieces # Otherwise, we only take the text of the word itself next_idx = start_word while next_idx < end_word: word = sentence.words[next_idx-1] parent_token = word.parent if isinstance(parent_token.id, int) or len(parent_token.id) == 1: tokens.append(parent_token) next_idx += 1 elif parent_token.id[0] >= start_word and parent_token.id[1] < end_word: tokens.append(parent_token) next_idx = parent_token.id[1] + 1 else: tokens.append(word) next_idx += 1 # We use the SpaceAfter or SpacesAfter 
attribute on each Word or Token # we chose in the above loop to separate the text pieces text = [] for token in tokens: text.append(token.text) text.append(misc_to_space_after(token.misc)) # the last token space_after will be discarded # so that we don't have stray WS at the end of the mention text text = text[:-1] return "".join(text) @register_processor(COREF) class CorefProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([COREF]) # set of processor requirements for this processor REQUIRES_DEFAULT = set([TOKENIZE]) def _set_up_model(self, config, pipeline, device): try: from stanza.models.coref.model import CorefModel except ImportError: raise ImportError("Please install the transformers and peft libraries before using coref! Try `pip install -e .[transformers]`.") # set up model # currently, the model has everything packaged in it # (except its config) # TODO: separate any pretrains if possible # TODO: add device parameter to the load mechanism config_update = {'log_norms': config.get('log_norms', False), 'device': device} model = CorefModel.load_model(path=config['model_path'], ignore={"bert_optimizer", "general_optimizer", "bert_scheduler", "general_scheduler"}, config_update=config_update, foundation_cache=pipeline.foundation_cache) if config.get('batch_size', None): model.config.a_scoring_batch_size = int(config['batch_size']) model.training = False self._model = model # coref_use_zeros=False will turn off creating new nodes and attaching mentions to those zero nodes self._use_zeros = config.get('use_zeros', True) if isinstance(self._use_zeros, str): self._use_zeros = self._use_zeros.lower() != 'false' def process(self, document): sentences = document.sentences cased_words = [] sent_ids = [] word_pos = [] speaker = [] for sent_idx, sentence in enumerate(sentences): for word_idx, word in enumerate(sentence.words): cased_words.append(word.text) sent_ids.append(sent_idx) word_pos.append(word_idx) if 
sentence.speaker: speaker.append(sentence.speaker) else: speaker.append("_") coref_input = { "document_id": "wb_doc_1", "cased_words": cased_words, "sent_id": sent_ids, "speaker": speaker, } coref_input = self._model.build_doc(coref_input) results = self._model.run(coref_input) # Handle zero anaphora - zero_scores is always predicted zero_nodes_created = self._handle_zero_anaphora(document, results, sent_ids, word_pos) clusters = [] for cluster_idx, span_cluster in enumerate(results.span_clusters): if len(span_cluster) == 0: continue span_cluster = sorted(span_cluster) for span in span_cluster: # check there are no sentence crossings before # manipulating the spans, since we will expect it to # be this way for multiple usages of the spans sent_id = sent_ids[span[0]] if sent_ids[span[1]-1] != sent_id: raise ValueError("The coref model predicted a span that crossed two sentences! Please send this example to us on our github") # treat the longest span as the representative # break ties using the first one # IF there is the POS processor, and it adds upos tags # to the sentence, ties are broken first by maximum # number of UPOS and then earliest in the document max_len = 0 best_span = None max_propn = 0 for span_idx, span in enumerate(span_cluster): word_idx = results.word_clusters[cluster_idx][span_idx] is_zero = zero_nodes_created.get((cluster_idx, word_idx)) if is_zero: continue sent_id = sent_ids[span[0]] sentence = sentences[sent_id] start_word = word_pos[span[0]] # fiddle -1 / +1 so as to avoid problems with coref # clusters that end at exactly the end of a document end_word = word_pos[span[1]-1] + 1 # very UD specific test for most number of proper nouns in a mention # will do nothing if POS is not active (they will all be None) num_propn = sum(word.pos == 'PROPN' for word in sentence.words[start_word:end_word]) if ((span[1] - span[0] > max_len) or span[1] - span[0] == max_len and num_propn > max_propn): max_len = span[1] - span[0] best_span = span_idx max_propn 
= num_propn mentions = [] for span_idx, span in enumerate(span_cluster): word_idx = results.word_clusters[cluster_idx][span_idx] is_zero = zero_nodes_created.get((cluster_idx, word_idx)) if is_zero: (sent_id, zero_word_id) = is_zero # if the word id is a tuple, it will be attached # to the zero mentions.append( CorefMention( sent_id, zero_word_id, zero_word_id ) ) else: sent_id = sent_ids[span[0]] start_word = word_pos[span[0]] end_word = word_pos[span[1]-1] + 1 mentions.append(CorefMention(sent_id, start_word, end_word)) # if we ended up with no best span, then our "representative text" # is just underscore if best_span is not None: representative = mentions[best_span] representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word) else: representative_text = "_" chain = CorefChain(len(clusters), mentions, representative_text, best_span) clusters.append(chain) document.coref = clusters return document def _handle_zero_anaphora(self, document, results, sent_ids, word_pos): """Handle zero anaphora by creating zero nodes and updating coreference clusters.""" if results.zero_scores is None or results.word_clusters is None: return {} if not self._use_zeros: return {} zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores is_zero = [] # Flatten word_clusters to get the word indices that correspond to zero_scores cluster_word_ids = [] cluster_mapping = {} counter = 0 for indx, cluster in enumerate(results.word_clusters): for _ in range(len(cluster)): cluster_mapping[counter] = indx counter += 1 cluster_word_ids.extend(cluster) # Find indices where zero_scores > 0 zero_indices = (zero_scores > 0.0).nonzero() # this dict maps (cluster_id, word_id) to (cluster_id, start, end) # which overrides span_clusters zero_to_coref = {} for zero_idx in zero_indices: zero_idx = zero_idx.item() if zero_idx >= len(cluster_word_ids): continue word_idx = cluster_word_ids[zero_idx] 
sent_id = sent_ids[word_idx] word_id = word_pos[word_idx] # Create zero node - attach BEFORE the current word # This means the zero node comes after word_id-1 but before word_id zero_word_id = ( word_id, len(document.sentences[sent_id]._empty_words)+1 ) # attach after word_id-1, before word_id zero_word = Word(document.sentences[sent_id], { "text": "_", "lemma": "_", "id": zero_word_id }) document.sentences[sent_id]._empty_words.append(zero_word) # Track this zero node for adding to coreference clusters cluster_idx = cluster_mapping[zero_idx] zero_to_coref[(cluster_idx, word_idx)] = ( sent_id, zero_word_id ) return zero_to_coref ================================================ FILE: stanza/pipeline/demo/README.md ================================================ ## Interactive Demo for Stanza ### Requirements stanza, flask ### Run the demo locally 1. Make sure you know how to disable your browser's CORS rule. For Chrome, [this extension](https://mybrowseraddon.com/access-control-allow-origin.html) works pretty well. 2. From this directory, start the Stanza demo server ```bash export FLASK_APP=demo_server.py flask run ``` 3. In `stanza-brat.js`, uncomment the line at the top that declares `serverAddress` and point it to where your flask is serving the demo server (usually `http://localhost:5000`) 4. Open `stanza-brat.html` in your browser (with CORS disabled) and enjoy! ### Common issues Make sure you have the models corresponding to the language you want to test out locally before submitting requests to the server! (Models can be obtained by `import stanza; stanza.download()`. 
================================================ FILE: stanza/pipeline/demo/__init__.py ================================================ ================================================ FILE: stanza/pipeline/demo/demo_server.py ================================================ from flask import Flask, request, abort import json import stanza import os app = Flask(__name__, static_url_path='', static_folder=os.path.abspath(os.path.dirname(__file__))) pipelineCache = dict() def get_file(path): res = os.path.join(os.path.dirname(os.path.abspath(__file__)), path) print(res) return res @app.route('/') @app.route('/static/fonts/') def static_file(path): if path in ['stanza-brat.css', 'stanza-brat.js', 'stanza-parseviewer.js', 'loading.gif', 'favicon.png', 'stanza-logo.png', 'Astloch-Bold.ttf', 'Liberation_Sans-Regular.ttf', 'PT_Sans-Caption-Web-Regular.ttf']: return app.send_static_file(path) elif path in 'index.html': return app.send_static_file('stanza-brat.html') else: abort(403) @app.route('/', methods=['GET']) def index(): return static_file('index.html') @app.route('/', methods=['POST']) def annotate(): global pipelineCache properties = request.args.get('properties', '') lang = request.args.get('pipelineLanguage', '') text = list(request.form.keys())[0] if lang not in pipelineCache: pipelineCache[lang] = stanza.Pipeline(lang=lang, use_gpu=False) res = pipelineCache[lang](text) annotated_sentences = [] for sentence in res.sentences: tokens = [] deps = [] for word in sentence.words: tokens.append({'index': word.id, 'word': word.text, 'lemma': word.lemma, 'pos': word.xpos, 'upos': word.upos, 'feats': word.feats, 'ner': word.parent.ner if word.parent.ner is None or word.parent.ner == 'O' else word.parent.ner[2:]}) deps.append({'dep': word.deprel, 'governor': word.head, 'governorGloss': sentence.words[word.head-1].text, 'dependent': word.id, 'dependentGloss': word.text}) annotated_sentences.append({'basicDependencies': deps, 'tokens': tokens}) if hasattr(sentence, 
'constituency') and sentence.constituency is not None: annotated_sentences[-1]['parse'] = str(sentence.constituency) return json.dumps({'sentences': annotated_sentences}) def create_app(): return app if __name__ == "__main__": app.run(host='0.0.0.0', port=8080) ================================================ FILE: stanza/pipeline/demo/stanza-brat.css ================================================ .red { color:#990000 } #wrap { min-height: 100%; height: auto; margin: 0 auto -6ex; padding: 0 0 6ex; } .pattern_tab { margin: 1ex; } .pattern_brat { margin-top: 1ex; } .label { color: #777777; font-size: small; } .footer { bottom: 0; width: 100%; /* Set the fixed height of the footer here */ height: 5ex; padding-top: 1ex; margin-top: 1ex; background-color: #f5f5f5; } .corenlp_error { margin-top: 2ex; } /* Styling for parse graph */ .node rect { stroke: #333; fill: #fff; } .parse-RULE rect { fill: #C0D9AF; } .parse-TERMINAL rect { stroke: #333; fill: #EEE8AA; } .node.highlighted { stroke: #ffff00; } .edgePath path { stroke: #333; fill: #333; stroke-width: 1.5px; } .parse-EDGE path { stroke: DarkGray; fill: DarkGray; stroke-width: 1.5px; } .logo { font-family: "Lato", "Gill Sans MT", "Gill Sans", "Helvetica", "Arial", sans-serif; font-style: italic; } ================================================ FILE: stanza/pipeline/demo/stanza-brat.html ================================================
================================================ FILE: stanza/pipeline/demo/stanza-brat.js ================================================ // Takes Stanford CoreNLP JSON output (var data = ... in data.js) // and uses brat to render everything. //var serverAddress = 'http://localhost:5000'; // Load Brat libraries var bratLocation = 'https://nlp.stanford.edu/js/brat/'; head.js( // External libraries bratLocation + '/client/lib/jquery.svg.min.js', bratLocation + '/client/lib/jquery.svgdom.min.js', // brat helper modules bratLocation + '/client/src/configuration.js', bratLocation + '/client/src/util.js', bratLocation + '/client/src/annotation_log.js', bratLocation + '/client/lib/webfont.js', // brat modules bratLocation + '/client/src/dispatcher.js', bratLocation + '/client/src/url_monitor.js', bratLocation + '/client/src/visualizer.js', // parse viewer './stanza-parseviewer.js' ); // Uses Dagre (https://github.com/cpettitt/dagre) for constinuency parse // visualization. It works better than the brat visualization. var useDagre = true; var currentQuery = 'The quick brown fox jumped over the lazy dog.'; var currentSentences = ''; var currentText = ''; // ---------------------------------------------------------------------------- // HELPERS // ---------------------------------------------------------------------------- /** * Add the startsWith function to the String class */ if (typeof String.prototype.startsWith !== 'function') { // see below for better implementation! 
String.prototype.startsWith = function (str){ return this.indexOf(str) === 0; }; } function isInt(value) { return !isNaN(value) && (function(x) { return (x | 0) === x; })(parseFloat(value)) } /** * A reverse map of PTB tokens to their original gloss */ var tokensMap = { '-LRB-': '(', '-RRB-': ')', '-LSB-': '[', '-RSB-': ']', '-LCB-': '{', '-RCB-': '}', '``': '"', '\'\'': '"', }; /** * A mapping from part of speech tag to the associated * visualization color */ function posColor(posTag) { if (posTag === null) { return '#E3E3E3'; } else if (posTag.startsWith('N')) { return '#A4BCED'; } else if (posTag.startsWith('V') || posTag.startsWith('M')) { return '#ADF6A2'; } else if (posTag.startsWith('P')) { return '#CCDAF6'; } else if (posTag.startsWith('I')) { return '#FFE8BE'; } else if (posTag.startsWith('R') || posTag.startsWith('W')) { return '#FFFDA8'; } else if (posTag.startsWith('D') || posTag === 'CD') { return '#CCADF6'; } else if (posTag.startsWith('J')) { return '#FFFDA8'; } else if (posTag.startsWith('T')) { return '#FFE8BE'; } else if (posTag.startsWith('E') || posTag.startsWith('S')) { return '#E4CBF6'; } else if (posTag.startsWith('CC')) { return '#FFFFFF'; } else if (posTag === 'LS' || posTag === 'FW') { return '#FFFFFF'; } else { return '#E3E3E3'; } } /** * A mapping from part of speech tag to the associated * visualization color */ function uposColor(posTag) { if (posTag === null) { return '#E3E3E3'; } else if (posTag === 'NOUN' || posTag === 'PROPN') { return '#A4BCED'; } else if (posTag.startsWith('V') || posTag === 'AUX') { return '#ADF6A2'; } else if (posTag === 'PART') { return '#CCDAF6'; } else if (posTag === 'ADP') { return '#FFE8BE'; } else if (posTag === 'ADV' || posTag.startsWith('PRON')) { return '#FFFDA8'; } else if (posTag === 'NUM' || posTag === 'DET') { return '#CCADF6'; } else if (posTag === 'ADJ') { return '#FFFDA8'; } else if (posTag.startsWith('E') || posTag.startsWith('S')) { return '#E4CBF6'; } else if (posTag.startsWith('CC')) { 
return '#FFFFFF'; } else if (posTag === 'X' || posTag === 'FW') { return '#FFFFFF'; } else { return '#E3E3E3'; } } /** * A mapping from named entity tag to the associated * visualization color */ function nerColor(nerTag) { if (nerTag === null) { return '#E3E3E3'; } else if (nerTag === 'PERSON' || nerTag === 'PER') { return '#FFCCAA'; } else if (nerTag === 'ORGANIZATION' || nerTag === 'ORG') { return '#8FB2FF'; } else if (nerTag === 'MISC') { return '#F1F447'; } else if (nerTag === 'LOCATION' || nerTag == 'LOC') { return '#95DFFF'; } else if (nerTag === 'DATE' || nerTag === 'TIME' || nerTag === 'SET') { return '#9AFFE6'; } else if (nerTag === 'MONEY') { return '#FFFFFF'; } else if (nerTag === 'PERCENT') { return '#FFA22B'; } else { return '#E3E3E3'; } } /** * A mapping from sentiment value to the associated * visualization color */ function sentimentColor(sentiment) { if (sentiment === "VERY POSITIVE") { return '#00FF00'; } else if (sentiment === "POSITIVE") { return '#7FFF00'; } else if (sentiment === "NEUTRAL") { return '#FFFF00'; } else if (sentiment === "NEGATIVE") { return '#FF7F00'; } else if (sentiment === "VERY NEGATIVE") { return '#FF0000'; } else { return '#E3E3E3'; } } /** * Get a list of annotators, from the annotator option input. */ function annotators() { var annotators = "tokenize,ssplit"; $('#annotators').find('option:selected').each(function () { annotators += "," + $(this).val(); }); return annotators; } /** * Get the input date */ function date() { function f(n) { return n < 10 ? 
'0' + n : n; } var date = new Date(); var M = date.getMonth() + 1; var D = date.getDate(); var Y = date.getFullYear(); var h = date.getHours(); var m = date.getMinutes(); var s = date.getSeconds(); return "" + Y + "-" + f(M) + "-" + f(D) + "T" + f(h) + ':' + f(m) + ':' + f(s); } //----------------------------------------------------------------------------- // Constituency parser //----------------------------------------------------------------------------- function ConstituencyParseProcessor() { var parenthesize = function (input, list) { if (list === undefined) { return parenthesize(input, []); } else { var token = input.shift(); if (token === undefined) { return list.pop(); } else if (token === "(") { list.push(parenthesize(input, [])); return parenthesize(input, list); } else if (token === ")") { return list; } else { return parenthesize(input, list.concat(token)); } } }; var toTree = function (list) { if (list.length === 2 && typeof list[1] === 'string') { return {label: list[0], text: list[1], isTerminal: true}; } else if (list.length >= 2) { var label = list.shift(); var node = {label: label}; var rest = list.map(function (x) { var t = toTree(x); if (typeof t === 'object') { t.parent = node; } return t; }); node.children = rest; return node; } else { return list; } }; var indexTree = function (tree, tokens, index) { index = index || 0; if (tree.isTerminal) { tree.token = tokens[index]; tree.tokenIndex = index; tree.tokenStart = index; tree.tokenEnd = index + 1; return index + 1; } else if (tree.children) { tree.tokenStart = index; for (var i = 0; i < tree.children.length; i++) { var child = tree.children[i]; index = indexTree(child, tokens, index); } tree.tokenEnd = index; } return index; }; var tokenize = function (input) { return input.split('"') .map(function (x, i) { if (i % 2 === 0) { // not in string return x.replace(/\(/g, ' ( ') .replace(/\)/g, ' ) '); } else { // in string return x.replace(/ /g, "!whitespace!"); } }) .join('"') .trim() 
.split(/\s+/) .map(function (x) { return x.replace(/!whitespace!/g, " "); }); }; var convertParseStringToTree = function (input, tokens) { var p = parenthesize(tokenize(input)); if (Array.isArray(p)) { var tree = toTree(p); // Correlate tree with tokens indexTree(tree, tokens); return tree; } }; this.process = function(annotation) { for (var i = 0; i < annotation.sentences.length; i++) { var s = annotation.sentences[i]; if (s.parse) { s.parseTree = convertParseStringToTree(s.parse, s.tokens); } } } } // ---------------------------------------------------------------------------- // RENDER // ---------------------------------------------------------------------------- /** * Render a given JSON data structure */ function render(data, reverse) { // Tweak arguments if (typeof reverse !== 'boolean') { reverse = false; } // Error checks if (typeof data.sentences === 'undefined') { return; } /** * Register an entity type (a tag) for Brat */ var entityTypesSet = {}; var entityTypes = []; function addEntityType(name, type, coarseType) { if (typeof coarseType === "undefined") { coarseType = type; } // Don't add duplicates if (entityTypesSet[type]) return; entityTypesSet[type] = true; // Get the color of the entity type color = '#ffccaa'; if (name === 'POS') { color = posColor(type); } else if (name === 'UPOS') { color = uposColor(type); } else if (name === 'NER') { color = nerColor(coarseType); } else if (name === 'NNER') { color = nerColor(coarseType); } else if (name === 'COREF') { color = '#FFE000'; } else if (name === 'ENTITY') { color = posColor('NN'); } else if (name === 'RELATION') { color = posColor('VB'); } else if (name === 'LEMMA') { color = '#FFFFFF'; } else if (name === 'SENTIMENT') { color = sentimentColor(type); } else if (name === 'LINK') { color = '#FFFFFF'; } else if (name === 'KBP_ENTITY') { color = '#FFFFFF'; } // Register the type entityTypes.push({ type: type, labels : [type], bgColor: color, borderColor: 'darken' }); } /** * Register a relation type 
(an arc) for Brat */ var relationTypesSet = {}; var relationTypes = []; function addRelationType(type, symmetricEdge) { // Prevent adding duplicates if (relationTypesSet[type]) return; relationTypesSet[type] = true; // Default arguments if (typeof symmetricEdge === 'undefined') { symmetricEdge = false; } // Add the type relationTypes.push({ type: type, labels: [type], dashArray: (symmetricEdge ? '3,3' : undefined), arrowHead: (symmetricEdge ? 'none' : undefined), }); } // // Construct text of annotation // currentText = []; // GLOBAL currentSentences = data.sentences; // GLOBAL data.sentences.forEach(function(sentence) { for (var i = 0; i < sentence.tokens.length; ++i) { var token = sentence.tokens[i]; var word = token.word; if (!(typeof tokensMap[word] === "undefined")) { word = tokensMap[word]; } if (i > 0) { currentText.push(' '); } token.characterOffsetBegin = currentText.length; for (var j = 0; j < word.length; ++j) { currentText.push(word[j]); } token.characterOffsetEnd = currentText.length; } currentText.push('\n'); }); currentText = currentText.join(''); // // Shared variables // These are what we'll render in BRAT // // (pos) var posEntities = []; // (upos) var uposEntities = []; // (lemma) var lemmaEntities = []; // (ner) var nerEntities = []; var nerEntitiesNormalized = []; // (sentiment) var sentimentEntities = []; // (entitylinking) var linkEntities = []; // (dependencies) var depsRelations = []; var deps2Relations = []; // (openie) var openieEntities = []; var openieEntitiesSet = {}; var openieRelations = []; var openieRelationsSet = {}; // (kbp) var kbpEntities = []; var kbpEntitiesSet = []; var kbpRelations = []; var kbpRelationsSet = []; var cparseEntities = []; var cparseRelations = []; // // Loop over sentences. // This fills in the variables above. 
// for (var sentI = 0; sentI < data.sentences.length; ++sentI) { var sentence = data.sentences[sentI]; var index = sentence.index; var tokens = sentence.tokens; var deps = sentence['basicDependencies']; var deps2 = sentence['enhancedPlusPlusDependencies']; var parseTree = sentence['parseTree']; // POS tags /** * Generate a POS tagged token id */ function posID(i) { return 'POS_' + sentI + '_' + i; } var noXPOS = true; if (tokens.length > 0 && typeof tokens[0].pos !== 'undefined' && tokens[0].pos !== null) { noXPOS = false; for (var i = 0; i < tokens.length; i++) { var token = tokens[i]; var pos = token.pos; var begin = parseInt(token.characterOffsetBegin); var end = parseInt(token.characterOffsetEnd); addEntityType('POS', pos); posEntities.push([posID(i), pos, [[begin, end]]]); } } // Universal POS tags /** * Generate a POS tagged token id */ function uposID(i) { return 'UPOS_' + sentI + '_' + i; } if (tokens.length > 0 && typeof tokens[0].upos !== 'undefined') { for (var i = 0; i < tokens.length; i++) { var token = tokens[i]; var upos = token.upos; var begin = parseInt(token.characterOffsetBegin); var end = parseInt(token.characterOffsetEnd); addEntityType('UPOS', upos); uposEntities.push([uposID(i), upos, [[begin, end]]]); } } // Constituency parse // Carries the same assumption as NER if (parseTree && !useDagre) { var parseEntities = []; var parseRels = []; function processParseTree(tree, index) { tree.visitIndex = index; index++; if (tree.isTerminal) { parseEntities[tree.visitIndex] = uposEntities[tree.tokenIndex]; return index; } else if (tree.children) { addEntityType('PARSENODE', tree.label); parseEntities[tree.visitIndex] = ['PARSENODE_' + sentI + '_' + tree.visitIndex, tree.label, [[tokens[tree.tokenStart].characterOffsetBegin, tokens[tree.tokenEnd-1].characterOffsetEnd]]]; var parentEnt = parseEntities[tree.visitIndex]; for (var i = 0; i < tree.children.length; i++) { var child = tree.children[i]; index = processParseTree(child, index); var childEnt = 
parseEntities[child.visitIndex]; addRelationType('pc'); parseRels.push(['PARSEEDGE_' + sentI + '_' + parseRels.length, 'pc', [['parent', parentEnt[0]], ['child', childEnt[0]]]]); } } return index; } processParseTree(parseTree, 0); cparseEntities = cparseEntities.concat(cparseEntities, parseEntities); cparseRelations = cparseRelations.concat(parseRels); } // Dependency parsing /** * Process a dependency tree from JSON to Brat relations */ function processDeps(name, deps) { var relations = []; // Format: [${ID}, ${TYPE}, [[${ARGNAME}, ${TARGET}], [${ARGNAME}, ${TARGET}]]] for (var i = 0; i < deps.length; i++) { var dep = deps[i]; var governor = dep.governor - 1; var dependent = dep.dependent - 1; if (governor == -1) continue; addRelationType(dep.dep); relations.push([name + '_' + sentI + '_' + i, dep.dep, [['governor', uposID(governor)], ['dependent', uposID(dependent)]]]); } return relations; } // Actually add the dependencies if (typeof deps !== 'undefined') { depsRelations = depsRelations.concat(processDeps('dep', deps)); } if (typeof deps2 !== 'undefined') { deps2Relations = deps2Relations.concat(processDeps('dep2', deps2)); } // Lemmas if (tokens.length > 0 && typeof tokens[0].lemma !== 'undefined') { for (var i = 0; i < tokens.length; i++) { var token = tokens[i]; var lemma = token.lemma; var begin = parseInt(token.characterOffsetBegin); var end = parseInt(token.characterOffsetEnd); addEntityType('LEMMA', lemma); lemmaEntities.push(['LEMMA_' + sentI + '_' + i, lemma, [[begin, end]]]); } } // NER tags // Assumption: contiguous occurrence of one non-O is a single entity var noNER = true; if (tokens.some(function(token) { return token.ner; })) { noNER = false; for (var i = 0; i < tokens.length; i++) { var ner = tokens[i].ner || 'O'; var normalizedNER = tokens[i].normalizedNER; if (typeof normalizedNER === "undefined") { normalizedNER = ner; } if (ner == 'O') continue; var j = i; while (j < tokens.length - 1 && tokens[j+1].ner == ner) j++; addEntityType('NER', ner, 
ner); nerEntities.push(['NER_' + sentI + '_' + i, ner, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]); if (ner != normalizedNER) { addEntityType('NNER', normalizedNER, ner); nerEntities.push(['NNER_' + sentI + '_' + i, normalizedNER, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]); } i = j; } } // Sentiment if (typeof sentence.sentiment !== "undefined") { var sentiment = sentence.sentiment.toUpperCase().replace("VERY", "VERY "); addEntityType('SENTIMENT', sentiment); sentimentEntities.push(['SENTIMENT_' + sentI, sentiment, [[tokens[0].characterOffsetBegin, tokens[tokens.length - 1].characterOffsetEnd]]]); } // Entity Links // Carries the same assumption as NER if (tokens.length > 0) { for (var i = 0; i < tokens.length; i++) { var link = tokens[i].entitylink; if (link == 'O' || typeof link === 'undefined') continue; var j = i; while (j < tokens.length - 1 && tokens[j+1].entitylink == link) j++; addEntityType('LINK', link); linkEntities.push(['LINK_' + sentI + '_' + i, link, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]); i = j; } } // Open IE // Helper Functions function openieID(span) { return 'OPENIEENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1]; } function addEntity(span, role) { // Don't add duplicate entities if (openieEntitiesSet[[sentI, span, role]]) return; openieEntitiesSet[[sentI, span, role]] = true; // Add the entity openieEntities.push([openieID(span), role, [[tokens[span[0]].characterOffsetBegin, tokens[span[1] - 1].characterOffsetEnd ]] ]); } function addRelation(gov, dep, role) { // Don't add duplicate relations if (openieRelationsSet[[sentI, gov, dep, role]]) return; openieRelationsSet[[sentI, gov, dep, role]] = true; // Add the relation openieRelations.push(['OPENIESUBJREL_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1], role, [['governor', openieID(gov)], ['dependent', openieID(dep)] ] ]); } // Render OpenIE if (typeof sentence.openie !== 'undefined') { // 
Register the entities + relations we'll need addEntityType('ENTITY', 'Entity'); addEntityType('RELATION', 'Relation'); addRelationType('subject'); addRelationType('object'); // Loop over triples for (var i = 0; i < sentence.openie.length; ++i) { var subjectSpan = sentence.openie[i].subjectSpan; var relationSpan = sentence.openie[i].relationSpan; var objectSpan = sentence.openie[i].objectSpan; if (parseInt(relationSpan[0]) < 0 || parseInt(relationSpan[1]) < 0) { continue; // This is a phantom relation } var begin = parseInt(token.characterOffsetBegin); // Add the entities addEntity(subjectSpan, 'Entity'); addEntity(relationSpan, 'Relation'); addEntity(objectSpan, 'Entity'); // Add the relations addRelation(relationSpan, subjectSpan, 'subject'); addRelation(relationSpan, objectSpan, 'object'); } } // End OpenIE block // // KBP // // Helper Functions function kbpEntity(span) { return 'KBPENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1]; } function addKBPEntity(span, role) { // Don't add duplicate entities if (kbpEntitiesSet[[sentI, span, role]]) return; kbpEntitiesSet[[sentI, span, role]] = true; // Add the entity kbpEntities.push([kbpEntity(span), role, [[tokens[span[0]].characterOffsetBegin, tokens[span[1] - 1].characterOffsetEnd ]] ]); } function addKBPRelation(gov, dep, role) { // Don't add duplicate relations if (kbpRelationsSet[[sentI, gov, dep, role]]) return; kbpRelationsSet[[sentI, gov, dep, role]] = true; // Add the relation kbpRelations.push(['KBPRELATION_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1], role, [['governor', kbpEntity(gov)], ['dependent', kbpEntity(dep)] ] ]); } if (typeof sentence.kbp !== 'undefined') { // Register the entities + relations we'll need addRelationType('subject'); addRelationType('object'); // Loop over triples for (var i = 0; i < sentence.kbp.length; ++i) { var subjectSpan = sentence.kbp[i].subjectSpan; var subjectLink = 'Entity'; for (var k = subjectSpan[0]; k < subjectSpan[1]; ++k) { if 
(subjectLink == 'Entity' && typeof tokens[k] !== 'undefined' && tokens[k].entitylink != 'O' && typeof tokens[k].entitylink !== 'undefined') { subjectLink = tokens[k].entitylink } } addEntityType('KBP_ENTITY', subjectLink); var objectSpan = sentence.kbp[i].objectSpan; var objectLink = 'Entity'; for (var k = objectSpan[0]; k < objectSpan[1]; ++k) { if (objectLink == 'Entity' && typeof tokens[k] !== 'undefined' && tokens[k].entitylink != 'O' && typeof tokens[k].entitylink !== 'undefined') { objectLink = tokens[k].entitylink } } addEntityType('KBP_ENTITY', objectLink); var relation = sentence.kbp[i].relation; var begin = parseInt(token.characterOffsetBegin); // Add the entities addKBPEntity(subjectSpan, subjectLink); addKBPEntity(objectSpan, objectLink); // Add the relations addKBPRelation(subjectSpan, objectSpan, relation); } } // End KBP block } // End sentence loop // // Coreference // var corefEntities = []; var corefRelations = []; if (typeof data.corefs !== 'undefined') { addRelationType('coref', true); addEntityType('COREF', 'Mention'); var clusters = Object.keys(data.corefs); clusters.forEach( function (clusterId) { var chain = data.corefs[clusterId]; if (chain.length > 1) { for (var i = 0; i < chain.length; ++i) { var mention = chain[i]; var id = 'COREF' + mention.id; var tokens = data.sentences[mention.sentNum - 1].tokens; corefEntities.push([id, 'Mention', [[tokens[mention.startIndex - 1].characterOffsetBegin, tokens[mention.endIndex - 2].characterOffsetEnd ]] ]); if (i > 0) { var lastId = 'COREF' + chain[i - 1].id; corefRelations.push(['COREF' + chain[i-1].id + '_' + chain[i].id, 'coref', [['governor', lastId], ['dependent', id] ] ]); } } } }); } // End coreference block // // Actually render the elements // /** * Helper function to render a given set of entities / relations * to a Div, if it exists. 
*/ function embed(container, entities, relations, reverse) { var text = currentText; if (reverse) { var length = currentText.length; for (var i = 0; i < entities.length; ++i) { var offsets = entities[i][2][0]; var tmp = length - offsets[0]; offsets[0] = length - offsets[1]; offsets[1] = tmp; } text = text.split("").reverse().join(""); } if ($('#' + container).length > 0) { Util.embed(container, {entity_types: entityTypes, relation_types: relationTypes}, {text: text, entities: entities, relations: relations} ); } } function reportna(container, text) { $('#' + container).text(text); } // Render each annotation head.ready(function() { if (!noXPOS) { embed('pos', posEntities); } else { reportna('pos', 'XPOS is not available for this language at this time.') } embed('upos', uposEntities); embed('lemma', lemmaEntities); if (!noNER) { embed('ner', nerEntities); } else { reportna('ner', 'NER is not available for this language at this time.') } embed('entities', linkEntities); if (!useDagre) { embed('parse', cparseEntities, cparseRelations); } embed('deps', uposEntities, depsRelations); embed('deps2', posEntities, deps2Relations); embed('coref', corefEntities, corefRelations); embed('openie', openieEntities, openieRelations); embed('kbp', kbpEntities, kbpRelations); embed('sentiment', sentimentEntities); // Constituency parse // Uses d3 and dagre-d3 (not brat) if ($('#parse').length > 0 && useDagre) { var parseViewer = new ParseViewer({ selector: '#parse' }); parseViewer.showAnnotation(data); $('#parse').addClass('svg').css('display', 'block'); } }); } // End render function /** * Render a TokensRegex response */ function renderTokensregex(data) { /** * Register an entity type (a tag) for Brat */ var entityTypesSet = {}; var entityTypes = []; function addEntityType(type, color) { // Don't add duplicates if (entityTypesSet[type]) return; entityTypesSet[type] = true; // Set the color if (typeof color === 'undefined') { color = '#ADF6A2'; } // Register the type 
entityTypes.push({ type: type, labels : [type], bgColor: color, borderColor: 'darken' }); } var entities = []; for (var sentI = 0; sentI < data.sentences.length; ++sentI) { var tokens = currentSentences[sentI].tokens; for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) { var match = data.sentences[sentI][matchI]; // Add groups for (groupName in match) { if (groupName.startsWith("$") || isInt(groupName)) { addEntityType(groupName, '#FFFDA8'); var begin = parseInt(tokens[match[groupName].begin].characterOffsetBegin); var end = parseInt(tokens[match[groupName].end - 1].characterOffsetEnd); entities.push(['TOK_' + sentI + '_' + matchI + '_' + groupName, groupName, [[begin, end]]]); } } // Add match addEntityType('match', '#ADF6A2'); var begin = parseInt(tokens[match.begin].characterOffsetBegin); var end = parseInt(tokens[match.end - 1].characterOffsetEnd); entities.push(['TOK_' + sentI + '_' + matchI + '_match', 'match', [[begin, end]]]); } } Util.embed('tokensregex', {entity_types: entityTypes, relation_types: []}, {text: currentText, entities: entities, relations: []} ); } // END renderTokensregex() /** * Render a Semgrex response */ function renderSemgrex(data) { /** * Register an entity type (a tag) for Brat */ var entityTypesSet = {}; var entityTypes = []; function addEntityType(type, color) { // Don't add duplicates if (entityTypesSet[type]) return; entityTypesSet[type] = true; // Set the color if (typeof color === 'undefined') { color = '#ADF6A2'; } // Register the type entityTypes.push({ type: type, labels : [type], bgColor: color, borderColor: 'darken' }); } relationTypes = [{ type: 'semgrex', labels: ['-'], dashArray: '3,3', arrowHead: 'none', }]; var entities = []; var relations = []; for (var sentI = 0; sentI < data.sentences.length; ++sentI) { var tokens = currentSentences[sentI].tokens; for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) { var match = data.sentences[sentI][matchI]; // Add match addEntityType('match', 
'#ADF6A2'); var begin = parseInt(tokens[match.begin].characterOffsetBegin); var end = parseInt(tokens[match.end - 1].characterOffsetEnd); entities.push(['SEM_' + sentI + '_' + matchI + '_match', 'match', [[begin, end]]]); // Add groups for (groupName in match) { if (groupName.startsWith("$") || isInt(groupName)) { // (add node) group = match[groupName]; groupName = groupName.substring(1); addEntityType(groupName, '#FFFDA8'); var begin = parseInt(tokens[group.begin].characterOffsetBegin); var end = parseInt(tokens[group.end - 1].characterOffsetEnd); entities.push(['SEM_' + sentI + '_' + matchI + '_' + groupName, groupName, [[begin, end]]]); // (add relation) relations.push(['SEMGREX_' + sentI + '_' + matchI + '_' + groupName, 'semgrex', [['governor', 'SEM_' + sentI + '_' + matchI + '_match'], ['dependent', 'SEM_' + sentI + '_' + matchI + '_' + groupName] ] ]); } } } } Util.embed('semgrex', {entity_types: entityTypes, relation_types: relationTypes}, {text: currentText, entities: entities, relations: relations} ); } // END renderSemgrex /** * Render a Tregex response */ function renderTregex(data) { $('#tregex').empty(); $('#tregex').append('
' + JSON.stringify(data, null, 4) + '
'); } // END renderTregex // ---------------------------------------------------------------------------- // MAIN // ---------------------------------------------------------------------------- /** * MAIN() * * The entry point of the page */ $(document).ready(function() { // Some initial styling $('.chosen-select').chosen(); $('.chosen-container').css('width', '100%'); // Language-specific changes $('#language').on('change', function() { $('#text').attr('dir', ''); if ($('#language').val() === 'ar' || $('#language').val() === 'fa' || $('#language').val() === 'he' || $('#language').val() === 'ur') { $('#text').attr('dir', 'rtl'); } if ($('#language').val() === 'ar') { $('#text').attr('placeholder', 'على سبيل المثال، قفز الثعلب البني السريع فوق الكلب الكسول.'); } else if ($('#language').val() === 'en') { $('#text').attr('placeholder', 'e.g., The quick brown fox jumped over the lazy dog.'); } else if ($('#language').val() === 'zh') { $('#text').attr('placeholder', '例如,快速的棕色狐狸跳过了懒惰的狗。'); } else if ($('#language').val() === 'zh-Hant') { $('#text').attr('placeholder', '例如,快速的棕色狐狸跳過了懶惰的狗。'); } else if ($('#language').val() === 'fr') { $('#text').attr('placeholder', 'Par exemple, le renard brun rapide a sauté sur le chien paresseux.'); } else if ($('#language').val() === 'de') { $('#text').attr('placeholder', 'Z. B. 
sprang der schnelle braune Fuchs über den faulen Hund.'); } else if ($('#language').val() === 'es') { $('#text').attr('placeholder', 'Por ejemplo, el rápido zorro marrón saltó sobre el perro perezoso.'); } else if ($('#language').val() === 'ur') { $('#text').attr('placeholder', 'میرا نام علی ہے'); } else { $('#text').attr('placeholder', 'Unknown language for placeholder query: ' + $('#language').val()); } }); // Submit on shift-enter $('#text').keydown(function (event) { if (event.keyCode == 13) { if(event.shiftKey){ event.preventDefault(); // don't register the enter key when pressed return false; } } }); $('#text').keyup(function (event) { if (event.keyCode == 13) { if(event.shiftKey){ $('#submit').click(); // submit the form when the enter key is released event.stopPropagation(); return false; } } }); // Submit on clicking the 'submit' button $('#submit').click(function() { // Get the text to annotate currentQuery = $('#text').val(); if (currentQuery.trim() == '') { if ($('#language').val() === 'ar') { currentQuery = 'قفز الثعلب البني السريع فوق الكلب الكسول.'; } else if ($('#language').val() === 'en') { currentQuery = 'The quick brown fox jumped over the lazy dog.'; } else if ($('#language').val() === 'zh') { currentQuery = '快速的棕色狐狸跳过了懒惰的狗。'; } else if ($('#language').val() === 'zh-Hant') { currentQuery = '快速的棕色狐狸跳過了懶惰的狗。'; } else if ($('#language').val() === 'fr') { currentQuery = 'Le renard brun rapide a sauté sur le chien paresseux.'; } else if ($('#language').val() === 'de') { currentQuery = 'Sprang der schnelle braune Fuchs über den faulen Hund.'; } else if ($('#language').val() === 'es') { currentQuery = 'El rápido zorro marrón saltó sobre el perro perezoso.'; } else if ($('#language').val() === 'ur') { currentQuery = 'میرا نام علی ہے'; } else { currentQuery = 'Unknown language for default query: ' + $('#language').val(); } $('#text').val(currentQuery); } // Update the UI $('#submit').prop('disabled', true); $('#annotations').hide(); 
$('#patterns_row').hide(); $('#loading').show(); // Run query $.ajax({ type: 'POST', url: serverAddress + '?properties=' + encodeURIComponent( '{"annotators": "' + annotators() + '", "date": "' + date() + '"}') + '&pipelineLanguage=' + encodeURIComponent($('#language').val()), data: encodeURIComponent(currentQuery), //jQuery doesn't automatically URI encode strings dataType: 'json', contentType: "application/x-www-form-urlencoded;charset=UTF-8", responseType: "application/json", success: function(data) { $('#submit').prop('disabled', false); if (typeof data === 'undefined' || data.sentences == undefined) { alert("Failed to reach server!"); } else { // Process constituency parse var constituencyParseProcessor = new ConstituencyParseProcessor(); constituencyParseProcessor.process(data); // Empty divs $('#annotations').empty(); // Re-render divs function createAnnotationDiv(id, annotator, selector, label) { // (make sure we requested that element) if (annotators().split(",").indexOf(annotator) < 0) { return; } // (make sure the data contains that element) ok = false; if (typeof data[selector] !== 'undefined') { ok = true; } else if (typeof data.sentences !== 'undefined' && data.sentences.length > 0) { if (typeof data.sentences[0][selector] !== 'undefined') { ok = true; } else if (typeof data.sentences[0].tokens != 'undefined' && data.sentences[0].tokens.length > 0) { // (make sure the annotator select is in at least one of the tokens of any sentence) ok = data.sentences.some(function(sentence) { return sentence.tokens.some(function(token) { return typeof token[selector] !== 'undefined'; }); }); } } // (render the element) if (ok) { $('#annotations').append('

' + label + ':

'); } } // (create the divs) // div id annotator field_in_data label createAnnotationDiv('pos', 'pos', 'pos', 'Part-of-Speech (XPOS)' ); createAnnotationDiv('upos', 'upos', 'upos', 'Universal Part-of-Speech'); createAnnotationDiv('lemma', 'lemma', 'lemma', 'Lemmas' ); createAnnotationDiv('ner', 'ner', 'ner', 'Named Entity Recognition'); createAnnotationDiv('deps', 'depparse', 'basicDependencies', 'Universal Dependencies' ); createAnnotationDiv('parse', 'parse', 'parseTree', 'Constituency Parse' ); //createAnnotationDiv('deps2', 'depparse', 'enhancedPlusPlusDependencies', 'Enhanced++ Dependencies' ); //createAnnotationDiv('openie', 'openie', 'openie', 'Open IE' ); //createAnnotationDiv('coref', 'coref', 'corefs', 'Coreference' ); //createAnnotationDiv('entities', 'entitylink', 'entitylink', 'Wikidict Entities' ); //createAnnotationDiv('kbp', 'kbp', 'kbp', 'KBP Relations' ); //createAnnotationDiv('sentiment','sentiment', 'sentiment', 'Sentiment' ); // Update UI $('#loading').hide(); $('.corenlp_error').remove(); // Clear error messages $('#annotations').show(); // Render var reverse = ($('#language').val() === 'ar' || $('#language').val() === 'fa' || $('#language').val() === 'he' || $('#language').val() === 'ur'); render(data, reverse); // Render patterns //$('#annotations').append('

CoreNLP Tools:

'); // TODO(gabor) a strange place to add this header to //$('#patterns_row').show(); } }, error: function(data) { DATA = data; var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('corenlp_error').attr('role', 'alert') var button = $(''); var message = $('').text(data.responseText); button.appendTo(alertDiv); message.appendTo(alertDiv); $('#loading').hide(); alertDiv.appendTo($('#errors')); $('#submit').prop('disabled', false); } }); event.preventDefault(); event.stopPropagation(); return false; }); // Support passing parameters on page launch, via window.location.hash parameters. // Example: http://localhost:9000/#text=foo%20bar&annotators=pos,lemma,ner (function() { var rawParams = window.location.hash.slice(1).split("&"); var params = {}; rawParams.forEach(function(paramKV) { paramKV = paramKV.split("="); if (paramKV.length === 2) { var key = paramKV[0]; var value = paramKV[1]; params[key] = value; } }); if (params.text) { var text = decodeURIComponent(params.text); $('#text').val(text); } if (params.annotators) { var annotators = params.annotators.split(","); // De-select everything $('#annotators').find('option').each(function() { $(this).prop('selected', false); }); // Select the specified ones. annotators.forEach(function(a) { $('#annotators').find('option[value="'+a+'"]').prop('selected', true); }); // Refresh Chosen $('#annotators').trigger('chosen:updated'); } if (params.text || params.annotators) { // Finally, let's auto-submit. 
$('#submit').click(); } })(); $('#form_tokensregex').submit( function (e) { // Don't actually submit the form e.preventDefault(); // Get text if ($('#tokensregex_search').val().trim() == '') { $('#tokensregex_search').val('(?$foxtype [{pos:JJ}]+ ) fox'); } var pattern = $('#tokensregex_search').val(); // Remove existing annotation $('#tokensregex').remove(); // Make ajax call $.ajax({ type: 'POST', url: serverAddress + '/tokensregex?pattern=' + encodeURIComponent( pattern.replace("&", "\\&").replace('+', '\\+')) + '&properties=' + encodeURIComponent( '{"annotators": "' + annotators() + '", "date": "' + date() + '"}') + '&pipelineLanguage=' + encodeURIComponent($('#language').val()), data: encodeURIComponent(currentQuery), success: function(data) { $('.tokensregex_error').remove(); // Clear error messages $('
').appendTo($('#div_tokensregex')); renderTokensregex(data); }, error: function(data) { var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tokensregex_error').attr('role', 'alert') var button = $(''); var message = $('').text(data.responseText); button.appendTo(alertDiv); message.appendTo(alertDiv); alertDiv.appendTo($('#div_tokensregex')); } }); }); $('#form_semgrex').submit( function (e) { // Don't actually submit the form e.preventDefault(); // Get text if ($('#semgrex_search').val().trim() == '') { $('#semgrex_search').val('{pos:/VB.*/} >nsubj {}=subject >/nmod:.*/ {}=prep_phrase'); } var pattern = $('#semgrex_search').val(); // Remove existing annotation $('#semgrex').remove(); // Add missing required annotators var requiredAnnotators = annotators().split(','); if (requiredAnnotators.indexOf('depparse') < 0) { requiredAnnotators.push('depparse'); } // Make ajax call $.ajax({ type: 'POST', url: serverAddress + '/semgrex?pattern=' + encodeURIComponent( pattern.replace("&", "\\&").replace('+', '\\+')) + '&properties=' + encodeURIComponent( '{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') + '&pipelineLanguage=' + encodeURIComponent($('#language').val()), data: encodeURIComponent(currentQuery), success: function(data) { $('.semgrex_error').remove(); // Clear error messages $('
').appendTo($('#div_semgrex')); renderSemgrex(data); }, error: function(data) { var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('semgrex_error').attr('role', 'alert') var button = $(''); var message = $('').text(data.responseText); button.appendTo(alertDiv); message.appendTo(alertDiv); alertDiv.appendTo($('#div_semgrex')); } }); }); $('#form_tregex').submit( function (e) { // Don't actually submit the form e.preventDefault(); // Get text if ($('#tregex_search').val().trim() == '') { $('#tregex_search').val('NP < NN=animal'); } var pattern = $('#tregex_search').val(); // Remove existing annotation $('#tregex').remove(); // Add missing required annotators var requiredAnnotators = annotators().split(','); if (requiredAnnotators.indexOf('parse') < 0) { requiredAnnotators.push('parse'); } // Make ajax call $.ajax({ type: 'POST', url: serverAddress + '/tregex?pattern=' + encodeURIComponent( pattern.replace("&", "\\&").replace('+', '\\+')) + '&properties=' + encodeURIComponent( '{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') + '&pipelineLanguage=' + encodeURIComponent($('#language').val()), data: encodeURIComponent(currentQuery), success: function(data) { $('.tregex_error').remove(); // Clear error messages $('
').appendTo($('#div_tregex')); renderTregex(data); }, error: function(data) { var alertDiv = $('
').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tregex_error').attr('role', 'alert') var button = $(''); var message = $('').text(data.responseText); button.appendTo(alertDiv); message.appendTo(alertDiv); alertDiv.appendTo($('#div_tregex')); } }); }); }); ================================================ FILE: stanza/pipeline/demo/stanza-parseviewer.js ================================================ //'use strict'; //d3 || require('d3'); //var dagreD3 = require('dagre-d3'); //var jquery = require('jquery'); //var $ = jquery; var ParseViewer = function(params) { // Container in which the scene template is displayed this.selector = params.selector; this.container = $(this.selector); this.fitToGraph = true; this.onClickNodeCallback = params.onClickNodeCallback; this.onHoverNodeCallback = params.onHoverNodeCallback; this.init(); return this; }; ParseViewer.MIN_WIDTH = 100; ParseViewer.MIN_HEIGHT = 100; ParseViewer.prototype.constructor = ParseViewer; ParseViewer.prototype.getAutoWidth = function () { return Math.max(ParseViewer.MIN_WIDTH, this.container.width()); }; ParseViewer.prototype.getAutoHeight = function () { return Math.max(ParseViewer.MIN_HEIGHT, this.container.height() - 20); }; ParseViewer.prototype.init = function () { var canvasWidth = this.getAutoWidth(); var canvasHeight = this.getAutoHeight(); this.parseElem = d3.select(this.selector) .append('svg') .attr({'width': canvasWidth, 'height': canvasHeight}) .style({'width': canvasWidth, 'height': canvasHeight}); console.log(this.parseElem); this.graph = null; this.graphRendered = false; this.controls = $('
');
  this.container.append(this.controls);
};

// Builds a dagre-d3 graphlib.Graph from a list of parse-tree roots.
// Graph node ids are assigned in pre-order depth-first visit order
// (this.visitIndex is incremented before a node's children are visited).
var GraphBuilder = function(roots) {
  // Create the input graph
  this.graph = new dagreD3.graphlib.Graph()
    .setGraph({})
    .setDefaultEdgeLabel(function () { return {}; });
  this.visitIndex = 0;
  //console.log('building graph', roots);
  for (var i = 0; i < roots.length; i++) {
    this.build(roots[i]);
  }
};

// Adds `node` (and, recursively, its children) to this.graph.
// The fields read here (label, parent, isTerminal, text, children) suggest a
// parse-tree node object; that shape is inferred from usage - TODO confirm.
// Side effect: writes node.visitIndex onto the input node.
GraphBuilder.prototype.build = function(node) {
  // NOTE(review): logs every node to the console; looks like leftover debugging.
  console.log(node);
  // Track my visit index
  this.visitIndex++;
  node.visitIndex = this.visitIndex;
  // Add a node
  var nodeData = node;
  // TODO: replace with semantic data
  var nodeLabel = node.label;
  var nodeIndex = node.visitIndex;
  var nodeClass = 'parse-RULE';
  this.graph.setNode(nodeIndex, {
    label: nodeLabel,
    class: nodeClass,
    data: nodeData
  });
  // Connect to the parent; the parent was visited first, so its visitIndex is already set.
  if (node.parent) {
    this.graph.setEdge(node.parent.visitIndex, nodeIndex, {
      class: 'parse-EDGE'
    });
  }
  if (node.isTerminal) {
    // Terminals get an extra child node showing the token text.
    this.visitIndex++;
    nodeIndex = this.visitIndex;
    nodeLabel = node.text;
    nodeClass = 'parse-TERMINAL';
    this.graph.setNode(nodeIndex, {
      label: nodeLabel,
      class: nodeClass,
      data: nodeData
    });
    this.graph.setEdge(node.visitIndex, nodeIndex, {
      class: 'parse-EDGE'
    });
  } else if (node.children) {
    for (var i = 0; i < node.children.length; i++) {
      this.build(node.children[i]);
    }
  }
};

// Resizes/centers the svg; when fitToGraph is set, the width follows the graph
// and the height comes from getAutoHeight() instead of the passed minimums.
ParseViewer.prototype.updateGraphPosition = function (svg, g, minWidth, minHeight) {
  if (this.fitToGraph) {
    minWidth = g.graph().width;
    minHeight = this.getAutoHeight();
  }
  adjustGraphPositioning(svg, g, minWidth, minHeight);
};

// Sizes the svg to at least (minWidth, minHeight) and horizontally centers the first <g>.
function adjustGraphPositioning(svg, g, minWidth, minHeight) {
  // Resize svg
  var newWidth = Math.max(minWidth, g.graph().width);
  var newHeight = Math.max(minHeight, g.graph().height + 40);
  svg.attr({'width': newWidth, 'height': newHeight});
  svg.style({'width': newWidth, 'height': newHeight});
  // Center the graph
  var svgGroup = svg.select('g');
  var xCenterOffset = (svg.attr('width') - g.graph().width) / 2;
  svgGroup.attr('transform', 'translate(' + xCenterOffset + ', 20)');
  // NOTE(review): these two lines reset the height to graph height + 40,
  // discarding the Math.max(minHeight, ...) applied above - confirm intent.
  svg.attr('height', g.graph().height + 40);
  svg.style('height', g.graph().height + 40);
}

// Draws graph g into svg's first <g>, wires click/hover callbacks, then repositions.
// NOTE(review): the `parse` argument is not used in this body.
ParseViewer.prototype.renderGraph = function (svg, g, parse) {
  // Create the renderer
  var render = new dagreD3.render();
  // Run the renderer. This is what draws the final graph.
  var svgGroup = svg.select('g');
  render(svgGroup, g);
  var scope = this;
  var nodes = svgGroup.selectAll('g.node');
  // d is the graph node id; node.data is the original parse node stored by GraphBuilder.
  nodes.on('click',
    function (d) {
      var v = d;
      var node = g.node(v);
      if (scope.onClickNodeCallback) {
        scope.onClickNodeCallback(node.data);
      }
      console.log(g.node(v));
    }
  );
  nodes.on('mouseover',
    function (d) {
      var v = d;
      var node = g.node(v);
      if (scope.onHoverNodeCallback) {
        scope.onHoverNodeCallback(node.data);
      }
    }
  );
  this.updateGraphPosition(svg, g, svg.attr('width'), svg.attr('height'));
  this.graphRendered = true;
};

// Convenience wrapper: show a single parse tree.
ParseViewer.prototype.showParse = function (root) {
  this.showParses([root]);
};

// Builds the graph for `roots` and renders it now if the container is visible;
// otherwise marks it unrendered so onResize renders it later.
ParseViewer.prototype.showParses = function (roots) {
  // Take parse and create a graph
  var gb = new GraphBuilder(roots);
  var g = gb.graph;
  g.nodes().forEach(function (v) {
    var node = g.node(v);
    // Round the corners of the nodes
    node.rx = node.ry = 5;
  });
  var svg = this.parseElem;
  svg.selectAll('*').remove();
  // The appended <g> is the render target; the variable itself is not reused here.
  var svgGroup = svg.append('g');
  this.graph = g;
  this.parse = roots;
  if (this.container.is(':visible')) {
    if (roots.length > 0) {
      this.renderGraph(svg, this.graph, this.parse);
    }
  } else {
    this.graphRendered = false;
  }
};

// Collects the parseTree of every sentence in a CoreNLP-style annotation and shows them.
ParseViewer.prototype.showAnnotation = function (annotation) {
  var parses = [];
  for (var i = 0; i < annotation.sentences.length; i++) {
    var s = annotation.sentences[i];
    if (s && s.parseTree) {
      parses.push(s.parseTree);
    }
  }
  this.showParses(parses);
};

// Resize handler: renders a pending graph, or just repositions an existing one.
ParseViewer.prototype.onResize = function () {
  var canvasWidth = this.getAutoWidth();
  var canvasHeight = this.getAutoHeight();
  var svg = this.parseElem;
  // Center the graph
  var svgGroup = svg.select('g');
  if (svgGroup && this.graph) {
    if (!this.graphRendered) {
      svg.attr({'width': canvasWidth, 'height': canvasHeight});
      svg.style({'width': canvasWidth, 'height': canvasHeight});
      this.renderGraph(svg, this.graph, this.parse);
    } else {
      this.updateGraphPosition(svg, this.graph, canvasWidth, canvasHeight);
    }
  } else {
    svg.attr({'width': canvasWidth, 'height': canvasHeight});
    svg.style({'width': canvasWidth, 'height': canvasHeight});
  }
};

// Exports
//module.exports = ParseViewer;

================================================
FILE: stanza/pipeline/depparse_processor.py
================================================
"""
Processor for performing dependency parsing
"""

import torch

from stanza.models.common import doc
from stanza.models.common.utils import unsort
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.depparse.data import DataLoader
from stanza.models.depparse.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

# these imports trigger the "register_variant" decorations
from stanza.pipeline.external.corenlp_converter_depparse import ConverterDepparse

# default for min_length_to_batch_separately when the config does not set it
DEFAULT_SEPARATE_BATCH=150

@register_processor(name=DEPPARSE)
class DepparseProcessor(UDProcessor):
    """
    Pipeline processor that predicts HEAD and DEPREL for every word and builds
    each sentence's dependency structure.
    """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([DEPPARSE])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE, POS, LEMMA])

    def __init__(self, config, pipeline, device):
        # set before super().__init__, which is expected to call _set_up_requires
        self._pretagged = None
        super().__init__(config, pipeline, device)

    def _set_up_requires(self):
        """
        With config['pretagged'], the input is assumed to be already tagged,
        so no upstream processors are required.
        """
        self._pretagged = self._config.get('pretagged')
        if self._pretagged:
            self._requires = set()
        else:
            self._requires = self.__class__.REQUIRES_DEFAULT

    def _set_up_model(self, config, pipeline, device):
        """
        Use config['trainer'] if one was passed in; otherwise build a Trainer from
        model_path, the optional pretrain_path, and the optional charlm paths.
        """
        self._trainer = config.get('trainer')
        if self._trainer is not None:
            return

        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
        args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                'charlm_backward_file': config.get('backward_charlm_path', None)}
        # NOTE(review): reads self.pretrain, presumably a UDProcessor property backed by self._pretrain - confirm
        self._trainer = Trainer(args=args,
                                pretrain=self.pretrain, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)

    def get_known_relations(self):
        """
        Return a list of relations which this processor can produce
        (the deprel vocab entries, excluding the VOCAB_PREFIX special entries)
        """
        keys = [k for k in self.vocab['deprel']._unit2id.keys() if k not in VOCAB_PREFIX]
        return keys

    def process(self, document):
        """
        Parse `document` in place (via the DataLoader's doc) and return it.

        Delegates to a registered variant if one is set. Raises ValueError if any
        word has neither upos nor xpos. A CUDA OOM RuntimeError is re-raised with
        a hint about depparse batching options.
        """
        if hasattr(self, '_variant'):
            return self._variant.process(document)

        if any(word.upos is None and word.xpos is None for sentence in document.sentences for word in sentence.words):
            raise ValueError("POS not run before depparse!")
        try:
            batch = DataLoader(document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True,
                               sort_during_eval=self.config.get('sort_during_eval', True),
                               min_length_to_batch_separately=self.config.get('min_length_to_batch_separately', DEFAULT_SEPARATE_BATCH))
            with torch.no_grad():
                preds = []
                for i, b in enumerate(batch):
                    preds += self.trainer.predict(b)
            # undo the length-sorting done by the DataLoader so preds line up with the document
            if batch.data_orig_idx is not None:
                preds = unsort(preds, batch.data_orig_idx)
            batch.doc.set((doc.HEAD, doc.DEPREL), [y for x in preds for y in x])
            # build dependencies based on predictions
            for sentence in batch.doc.sentences:
                sentence.build_dependencies()
            return batch.doc
        except RuntimeError as e:
            if str(e).startswith("CUDA out of memory. Tried to allocate"):
                new_message = str(e) + " ... You may be able to compensate for this by separating long sentences into their own batch with a parameter such as depparse_min_length_to_batch_separately=150 or by limiting the overall batch size with depparse_batch_size=400."
                raise RuntimeError(new_message) from e
            else:
                raise

================================================
FILE: stanza/pipeline/external/__init__.py
================================================


================================================
FILE: stanza/pipeline/external/corenlp_converter_depparse.py
================================================
"""
A depparse processor which converts constituency trees using CoreNLP
"""

from stanza.pipeline._constants import TOKENIZE, CONSTITUENCY, DEPPARSE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
from stanza.server.dependency_converter import DependencyConverter

@register_processor_variant(DEPPARSE, 'converter')
class ConverterDepparse(ProcessorVariant):
    """
    depparse variant 'converter': produces dependencies by passing constituency
    parses through CoreNLP's DependencyConverter. English only.
    """
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE, CONSTITUENCY])

    def __init__(self, config):
        if config['lang'] != 'en':
            raise ValueError("Constituency to dependency converter only works for English")

        # TODO: get classpath from config
        # TODO: close this when finished?
        # a more involved approach would be to turn the Pipeline into
        # a context with __enter__ and __exit__
        # __exit__ would try to free all resources, although some
        # might linger such as GPU allocations
        # maybe it isn't worth even trying to clean things up on account of that
        self.converter = DependencyConverter(classpath="$CLASSPATH")
        self.converter.open_pipe()

    def process(self, document):
        """Run the document through the open converter pipe and return its result."""
        return self.converter.process(document)

================================================
FILE: stanza/pipeline/external/jieba.py
================================================
"""
Processors related to Jieba in the pipeline.
"""

import re
import warnings

from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant

def check_jieba():
    """
    Import necessary components from Jieba to perform tokenization.

    Returns True if jieba is importable; raises ImportError otherwise.
    """
    try:
        import jieba
    except ImportError:
        raise ImportError(
            "Jieba is used but not installed on your machine. Go to https://pypi.org/project/jieba/ for installation instructions."
        )
    return True

@register_processor_variant(TOKENIZE, 'jieba')
class JiebaTokenizer(ProcessorVariant):
    def __init__(self, config):
        """ Construct a Jieba-based tokenizer by loading the Jieba pipeline.

        Note that this tokenizer uses regex for sentence segmentation.
        (NOTE(review): process() actually splits on a fixed list of
        sentence-final punctuation tokens, not a regex.)
        """
        if config['lang'] not in ['zh', 'zh-hans', 'zh-hant']:
            raise Exception("Jieba tokenizer is currently only allowed in Chinese (simplified or traditional) pipelines.")

        # Suppress a DeprecationWarning about pkg_resource from jieba.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning, module="jieba")
            check_jieba()
            import jieba
        self.nlp = jieba
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the Jieba tokenizer and wrap the results into a Doc object.

        Character offsets are tracked by summing token lengths, so they assume
        jieba's tokens concatenate back to the original text.
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the Jieba tokenizer.")
        tokens = self.nlp.cut(text, cut_all=False)

        sentences = []
        current_sentence = []
        offset = 0
        for token in tokens:
            # skip whitespace tokens but keep the offset in sync
            # NOTE(review): re.match only checks the token's start, so a token that
            # merely begins with whitespace is skipped as well - confirm intent.
            if re.match(r'\s+', token):
                offset += len(token)
                continue

            token_entry = {
                doc.TEXT: token,
                doc.MISC: f"{doc.START_CHAR}={offset}|{doc.END_CHAR}={offset+len(token)}"
            }
            current_sentence.append(token_entry)
            offset += len(token)

            # close the sentence after sentence-final punctuation
            if not self.no_ssplit and token in ['。', '!', '?', '!', '?']:
                sentences.append(current_sentence)
                current_sentence = []

        if len(current_sentence) > 0:
            sentences.append(current_sentence)

        return doc.Document(sentences, text)

================================================
FILE: stanza/pipeline/external/pythainlp.py
================================================
"""
Processors related to PyThaiNLP in the pipeline.

GitHub Home: https://github.com/PyThaiNLP/pythainlp
"""

from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant

def check_pythainlp():
    """
    Import necessary components from pythainlp to perform tokenization.

    Returns True if pythainlp is importable; raises ImportError otherwise.
    """
    try:
        import pythainlp
    except ImportError:
        raise ImportError(
            "The pythainlp library is required. "
            "Try to install it with `pip install pythainlp`. "
            "Go to https://github.com/PyThaiNLP/pythainlp for more information."
        )
    return True

@register_processor_variant(TOKENIZE, 'pythainlp')
class PyThaiNLPTokenizer(ProcessorVariant):
    def __init__(self, config):
        """ Construct a PyThaiNLP-based tokenizer.

        Note that we always use the default tokenizer of PyThaiNLP for sentence and word segmentation.
        Currently this is a CRF model for sentence segmentation and a dictionary-based model (newmm) for word segmentation.
        """
        if config['lang'] != 'th':
            raise Exception("PyThaiNLP tokenizer is only allowed in Thai pipeline.")

        check_pythainlp()
        from pythainlp.tokenize import sent_tokenize as pythai_sent_tokenize
        from pythainlp.tokenize import word_tokenize as pythai_word_tokenize

        self.pythai_sent_tokenize = pythai_sent_tokenize
        self.pythai_word_tokenize = pythai_word_tokenize
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the PyThaiNLP tokenizer and wrap the results into a Doc object.

        Offsets are accumulated across sentences by summing token lengths.
        NOTE(review): this assumes the sentence strings concatenate back to the
        original text; if sent_tokenize drops characters between sentences,
        offsets would drift - verify.
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the PyThaiNLP tokenizer.")

        sentences = []
        current_sentence = []
        offset = 0

        if self.no_ssplit:
            # skip sentence segmentation
            sent_strs = [text]
        else:
            sent_strs = self.pythai_sent_tokenize(text, engine='crfcut')
        for sent_str in sent_strs:
            for token_str in self.pythai_word_tokenize(sent_str, engine='newmm'):
                # by default pythainlp will output whitespace as a token
                # we need to skip these tokens to be consistent with other tokenizers
                if token_str.isspace():
                    offset += len(token_str)
                    continue

                # create token entry
                token_entry = {
                    doc.TEXT: token_str,
                    doc.MISC: f"{doc.START_CHAR}={offset}|{doc.END_CHAR}={offset+len(token_str)}"
                }
                current_sentence.append(token_entry)
                offset += len(token_str)

            # finish sentence
            sentences.append(current_sentence)
            current_sentence = []

        # always empty after the loop above; kept as a safety net
        if len(current_sentence) > 0:
            sentences.append(current_sentence)

        return doc.Document(sentences, text)

================================================
FILE: stanza/pipeline/external/spacy.py
================================================
"""
Processors related to spaCy in the pipeline.
"""

from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant

def check_spacy():
    """
    Import necessary components from spaCy to perform tokenization.

    Returns True if spacy is importable; raises ImportError otherwise.
    """
    try:
        import spacy
    except ImportError:
        raise ImportError(
            "spaCy is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
        )
    return True

@register_processor_variant(TOKENIZE, 'spacy')
class SpacyTokenizer(ProcessorVariant):
    def __init__(self, config):
        """ Construct a spaCy-based tokenizer by loading the spaCy pipeline.

        English only; adds a rule-based sentencizer to a blank English pipeline.
        """
        if config['lang'] != 'en':
            raise Exception("spaCy tokenizer is currently only allowed in English pipeline.")

        try:
            import spacy
            from spacy.lang.en import English
        except ImportError:
            raise ImportError(
                "spaCy 2.0+ is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
            )

        # Create a Tokenizer with the default settings for English
        # including punctuation rules and exceptions
        self.nlp = English()
        # by default spacy uses dependency parser to do ssplit
        # we need to add a sentencizer for fast rule-based ssplit
        # (spaCy 2.x takes a component object; later versions take the component name)
        if spacy.__version__.startswith("2."):
            self.nlp.add_pipe(self.nlp.create_pipe("sentencizer"))
        else:
            self.nlp.add_pipe("sentencizer")
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the spaCy tokenizer and wrap the results into a Doc object.

        Character offsets come directly from spaCy's tok.idx.
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the spaCy tokenizer.")
        spacy_doc = self.nlp(text)

        sentences = []
        for sent in spacy_doc.sents:
            tokens = []
            for tok in sent:
                token_entry = {
                    doc.TEXT: tok.text,
                    doc.MISC: f"{doc.START_CHAR}={tok.idx}|{doc.END_CHAR}={tok.idx+len(tok.text)}"
                }
                tokens.append(token_entry)
            sentences.append(tokens)

        # if no_ssplit is set, flatten all the sentences into one sentence
        if self.no_ssplit:
            sentences = [[t for s in sentences for t in s]]

        return doc.Document(sentences, text)

================================================
FILE: stanza/pipeline/external/sudachipy.py
================================================
"""
Processors related to SudachiPy in the pipeline.

GitHub Home: https://github.com/WorksApplications/SudachiPy
"""

import re

from stanza.models.common import doc
from stanza.pipeline._constants import TOKENIZE
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant

def check_sudachipy():
    """
    Import necessary components from SudachiPy to perform tokenization.

    Returns True if both sudachipy and sudachidict_core are importable;
    raises ImportError otherwise.
    """
    try:
        import sudachipy
        import sudachidict_core
    except ImportError:
        raise ImportError(
            "Both sudachipy and sudachidict_core libraries are required. "
            "Try install them with `pip install sudachipy sudachidict_core`. "
            "Go to https://github.com/WorksApplications/SudachiPy for more information."
        )
    return True

@register_processor_variant(TOKENIZE, 'sudachipy')
class SudachiPyTokenizer(ProcessorVariant):
    def __init__(self, config):
        """ Construct a SudachiPy-based tokenizer.

        Note that this tokenizer uses regex for sentence segmentation.
        (NOTE(review): process() actually splits on a fixed list of
        sentence-final punctuation tokens, not a regex.)
        """
        if config['lang'] != 'ja':
            raise Exception("SudachiPy tokenizer is only allowed in Japanese pipelines.")

        check_sudachipy()
        # NOTE(review): `tokenizer` is imported but not used in this file section
        from sudachipy import tokenizer
        from sudachipy import dictionary

        self.tokenizer = dictionary.Dictionary().create()
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the SudachiPy tokenizer and wrap the results into a Doc object.

        Character offsets come from SudachiPy's token.begin()/token.end().
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the SudachiPy tokenizer.")

        # we use the default sudachipy tokenization mode (i.e., mode C)
        # more config needs to be added to support other modes

        tokens = self.tokenizer.tokenize(text)

        sentences = []
        current_sentence = []
        for token in tokens:
            token_text = token.surface()
            # by default sudachipy will output whitespace as a token
            # we need to skip these tokens to be consistent with other tokenizers
            if token_text.isspace():
                continue
            start = token.begin()
            end = token.end()

            token_entry = {
                doc.TEXT: token_text,
                doc.MISC: f"{doc.START_CHAR}={start}|{doc.END_CHAR}={end}"
            }
            current_sentence.append(token_entry)

            # close the sentence after sentence-final punctuation
            if not self.no_ssplit and token_text in ['。', '!', '?', '!', '?']:
                sentences.append(current_sentence)
                current_sentence = []

        if len(current_sentence) > 0:
            sentences.append(current_sentence)

        return doc.Document(sentences, text)

================================================
FILE: stanza/pipeline/langid_processor.py
================================================
"""
Processor for determining language of text.
"""

import emoji
import re
import stanza
import torch

from stanza.models.common.doc import Document
from stanza.models.langid.model import LangIDBiLSTM
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor


@register_processor(name=LANGID)
class LangIDProcessor(UDProcessor):
    """
    Class for detecting language of text.

    Sets `doc.lang` on each Document it processes.
    """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([LANGID])

    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([])

    # default max sequence length
    MAX_SEQ_LENGTH_DEFAULT = 1000

    def _set_up_model(self, config, pipeline, device):
        """Load the LangIDBiLSTM model (batch_size defaults to 64) and its char index."""
        batch_size = config.get("batch_size", 64)
        self._model = LangIDBiLSTM.load(path=config["model_path"], device=device,
                                        batch_size=batch_size, lang_subset=config.get("lang_subset"))
        self._char_index = self._model.char_to_idx
        self._clean_text = config.get("clean_text")

    def _text_to_tensor(self, docs):
        """
        Map list of strings to batch tensor. Assumed all docs are same length.

        Unknown characters map to the "UNK" index. The tensor is created on the
        same device as the model's parameters.
        """

        device = next(self._model.parameters()).device
        all_docs = []
        for doc in docs:
            doc_chars = [self._char_index.get(c, self._char_index["UNK"]) for c in list(doc)]
            all_docs.append(doc_chars)
        return torch.tensor(all_docs, device=device, dtype=torch.long)

    def _id_langs(self, batch_tensor):
        """
        Identify languages for each sequence in a batch tensor

        Returns one language tag per row, mapped through the model's idx_to_tag.
        """
        predictions = self._model.prediction_scores(batch_tensor)
        prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions]

        return prediction_labels

    # regexes for cleaning text
    http_regex = re.compile(r"https?:\/\/t\.co/[a-zA-Z0-9]+")
    handle_regex = re.compile("@[a-zA-Z0-9_]+")
    hashtag_regex = re.compile("#[a-zA-Z]+")
    punctuation_regex = re.compile("[!.]+")
    all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex]

    @staticmethod
    def clean_text(text):
        """
        Process text to improve language id performance. Main emphasis is on tweets, this method removes shortened
        urls, hashtags, handles, and punctuation and emoji.

        Matches are replaced by a space. The result is stripped unless stripping
        would leave an empty string.
        """

        for regex in LangIDProcessor.all_regexes:
            text = regex.sub(" ", text)

        # presumably emojize first so that :shortcode: aliases become emoji and
        # are then removed by replace_emoji - confirm
        text = emoji.emojize(text)
        text = emoji.replace_emoji(text, replace=' ')

        if text.strip():
            text = text.strip()

        return text

    def _process_list(self, docs):
        """
        Identify language of list of strings or Documents

        Strings are wrapped in Documents. Documents are grouped by (cleaned) text
        length so each batch tensor is rectangular; `lang` is set on each Document
        and the list of Documents is returned.
        NOTE(review): returns None for an empty list, so bulk_process([]) returns None.
        """

        if len(docs) == 0:
            # TO DO: what standard do we want for bad input, such as empty list?
            # TO DO: more handling of bad input
            return

        if isinstance(docs[0], str):
            docs = [Document([], text) for text in docs]

        docs_by_length = {}
        for doc in docs:
            text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text
            doc_length = len(text)
            if doc_length not in docs_by_length:
                docs_by_length[doc_length] = []
            docs_by_length[doc_length].append((doc, text))

        for doc_length in docs_by_length:
            inputs = [doc[1] for doc in docs_by_length[doc_length]]
            predictions = self._id_langs(self._text_to_tensor(inputs))
            # doc here is a (Document, text) pair
            for doc, lang in zip(docs_by_length[doc_length], predictions):
                doc[0].lang = lang

        return docs

    def process(self, doc):
        """
        Handle single str or Document
        """

        wrapped_doc = [doc]
        return self._process_list(wrapped_doc)[0]

    def bulk_process(self, docs):
        """
        Handle list of strings or Documents
        """

        return self._process_list(docs)

================================================
FILE: stanza/pipeline/lemma_processor.py
================================================
"""
Processor for performing lemmatization
"""

from itertools import compress

import torch

from stanza.models.common import doc
from stanza.models.lemma.data import DataLoader
from stanza.models.lemma.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

# (text, upos) fields fetched per word for the dictionary / ensemble paths
WORD_TAGS = [doc.TEXT, doc.UPOS]

@register_processor(name=LEMMA)
class LemmaProcessor(UDProcessor):
    """
    Pipeline processor that assigns a lemma to every word, either as an
    identity copy of the word text or from a trained lemmatizer.
    """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([LEMMA])
    # set of processor requirements for this processor
    # pos will be added later for non-identity lemmatizer
    REQUIRES_DEFAULT = set([TOKENIZE])
    # default batch size
    DEFAULT_BATCH_SIZE = 5000

    def __init__(self, config, pipeline, device):
        # run lemmatizer in identity mode
        self._use_identity = None
        self._pretagged = None
        super().__init__(config, pipeline, device)

    @property
    def use_identity(self):
        """True when lemmas are simply copied from the word text."""
        return self._use_identity

    def _set_up_model(self, config, pipeline, device):
        """
        Identity mode (use_identity True or 'True') loads no model and uses
        DEFAULT_BATCH_SIZE; otherwise builds a lemma Trainer.
        """
        if config.get('use_identity') in ['True', True]:
            self._use_identity = True
            self._config = config
            self.config['batch_size'] = LemmaProcessor.DEFAULT_BATCH_SIZE
        else:
            # the lemmatizer only looks at one word when making
            # decisions, not the surrounding context
            # therefore, we can save some time by remembering what
            # we did the last time we saw any given word,pos
            # since a long running program will remember everything
            # (unless we go back and make it smarter)
            # we make this an option, not the default
            # TODO: need to update the cache to skip the contextual lemmatizer
            # NOTE: store_results is only set on this (non-identity) path
            self.store_results = config.get('store_results', False)
            self._use_identity = False
            args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                    'charlm_backward_file': config.get('backward_charlm_path', None)}
            lemma_classifier_args = dict(args)
            lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)

    def _set_up_requires(self):
        """Pretagged input requires nothing; a non-identity model with config['pos'] also requires POS."""
        self._pretagged = self._config.get('pretagged', None)
        if self._pretagged:
            self._requires = set()
        elif self.config.get('pos') and not self.use_identity:
            self._requires = LemmaProcessor.REQUIRES_DEFAULT.union(set([POS]))
        else:
            self._requires = LemmaProcessor.REQUIRES_DEFAULT

    def process(self, document):
        """
        Lemmatize every word and return the DataLoader's Document.

        Modes: identity (copy text), dict_only (dictionary lookup), or seq2seq,
        optionally ensembled with the dictionary (ensemble_dict) so the seq2seq
        model only runs on words the dictionary cannot handle.
        """
        if not self.use_identity:
            batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)
        else:
            batch = DataLoader(document, self.config['batch_size'],
                               self.config, evaluation=True, conll_only=True)
        if self.use_identity:
            preds = [word.text for sent in batch.doc.sentences for word in sent.words]
        elif self.config.get('dict_only', False):
            preds = self.trainer.predict_dict(batch.doc.get([doc.TEXT, doc.UPOS]))
        else:
            if self.config.get('ensemble_dict', False):
                # skip the seq2seq model when we can
                skip = self.trainer.skip_seq2seq(batch.doc.get([doc.TEXT, doc.UPOS]))
                # although there is no explicit use of caseless or lemma_caseless in this processor,
                # it shows up in the config which gets passed to the DataLoader,
                # possibly affecting its results
                seq2seq_batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab,
                                           evaluation=True, skip=skip, expand_unk_vocab=True)
            else:
                seq2seq_batch = batch

            with torch.no_grad():
                preds = []
                edits = []
                for i, b in enumerate(seq2seq_batch):
                    ps, es = self.trainer.predict(b, self.config['beam_size'], seq2seq_batch.vocab)
                    preds += ps
                    if es is not None:
                        edits += es

            if self.config.get('ensemble_dict', False):
                word_tags = batch.doc.get(WORD_TAGS)
                words = [x[0] for x in word_tags]
                # preds only covers the non-skipped words at this point
                preds = self.trainer.postprocess([x for x, y in zip(words, skip) if not y], preds, edits=edits)
                if self.store_results:
                    new_word_tags = compress(word_tags, map(lambda x: not x, skip))
                    new_predictions = [(x[0], x[1], y) for x, y in zip(new_word_tags, preds)]
                    self.trainer.train_dict(new_predictions, update_word_dict=False)
                # expand seq2seq predictions to the same size as all words
                i = 0
                preds1 = []
                for s in skip:
                    if s:
                        preds1.append('')
                    else:
                        preds1.append(preds[i])
                        i += 1
                preds = self.trainer.ensemble(word_tags, preds1)
            else:
                preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)

            if self.trainer.has_contextual_lemmatizers():
                preds = self.trainer.update_contextual_preds(batch.doc, preds)

        # map empty string lemmas to '_'
        # ((0, '') < (0, '_'), while any non-empty x has len(x) > 0 and wins the max)
        preds = [max([(len(x), x), (0, '_')])[1] for x in preds]
        batch.doc.set([doc.LEMMA], preds)
        return batch.doc

================================================
FILE: stanza/pipeline/morphseg_processor.py ================================================ from stanza.pipeline.core import UnsupportedProcessorError from stanza.pipeline.processor import UDProcessor, register_processor from stanza.pipeline._constants import MORPHSEG, TOKENIZE @register_processor(name=MORPHSEG) class MorphSegProcessor(UDProcessor): PROVIDES_DEFAULT = {MORPHSEG} REQUIRES_DEFAULT = {TOKENIZE} def __init__(self, config, pipeline, device): self._config = config self._pipeline = pipeline self._set_up_requires() self._set_up_provides() self._set_up_model(config, pipeline, device) def _set_up_model(self, config, pipeline, device): try: from morphseg import MorphemeSegmenter except ImportError: raise ImportError( "morphseg is required for morpheme segmentation. " "Install it with: pip install morphseg" ) lang = config.get('lang', 'en') model_path = config.get('morphseg_model_path', None) if model_path: self._segmenter = MorphemeSegmenter( lang=lang, load_pretrained=False, model_filepath=model_path, is_local=True ) else: self._segmenter = MorphemeSegmenter( lang=lang, load_pretrained=True ) if self._segmenter.sequence_labeller is None: raise UnsupportedProcessorError("morphseg", lang) def process(self, document): # Collect all words from all sentences all_words = [] word_mapping = [] # Track which sentence and word index each prediction belongs to for sent_idx, sent in enumerate(document.sentences): if not sent.words: continue for word_idx, word in enumerate(sent.words): all_words.append(word.text) word_mapping.append((sent_idx, word_idx)) if not all_words: return document # Prepare input for morphseg (it expects normalized, lowercased character lists) word_char_lists = [ list(self._segmenter.normalize_for_morphology(word)) for word in all_words ] # Batch predict using the internal sequence_labeller predictions = self._segmenter.sequence_labeller.predict(sources=word_char_lists) # Extract segmentations from predictions from morphseg.training.oracle import 
rules2sent segmentations = [ rules2sent( source=[align_pos.symbol for align_pos in pred.alignment], actions=pred.prediction ).split(' @@') # Split by morphseg's default delimiter for pred in predictions ] # Assign segmentations back to words for (sent_idx, word_idx), seg in zip(word_mapping, segmentations): document.sentences[sent_idx].words[word_idx].morphemes = seg return document ================================================ FILE: stanza/pipeline/multilingual.py ================================================ """ Class for running multilingual pipelines """ from collections import OrderedDict import copy import logging from typing import Union from stanza.models.common.doc import Document from stanza.models.common.utils import default_device from stanza.pipeline.core import Pipeline, DownloadMethod from stanza.pipeline._constants import * from stanza.resources.common import DEFAULT_MODEL_DIR, get_language_resources, load_resources_json logger = logging.getLogger('stanza') class MultilingualPipeline: """ Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that language. You can specify options to individual language pipelines with the lang_configs field. For example, if you want English pipelines to have NER, but want to turn that off for French, you can do: lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse,ner"}, "fr": {"processors": "tokenize,pos,lemma,depparse"}} pipeline = MultilingualPipeline(lang_configs=lang_configs) You can also pass in a defaultdict created in such a way that it provides default parameters for each language. 
For example, in order to only get tokenization for each language: (remembering that the Pipeline will automagically add MWT to a language which uses MWT): from collections import defaultdict lang_configs = defaultdict(lambda: dict(processors="tokenize")) pipeline = MultilingualPipeline(lang_configs=lang_configs) download_method can be set as in Pipeline to turn off downloading of the .json config or turn off downloading of everything """ def __init__(self, model_dir: str = DEFAULT_MODEL_DIR, lang_id_config: dict = None, lang_configs: dict = None, ld_batch_size: int = 64, max_cache_size: int = 10, use_gpu: bool = None, restrict: bool = False, device: str = None, download_method: DownloadMethod = DownloadMethod.DOWNLOAD_RESOURCES, # python 3.6 compatibility - maybe want to update to 3.7 at some point processors: Union[str, list] = None, ): # set up configs and cache for various language pipelines self.model_dir = model_dir self.lang_id_config = {} if lang_id_config is None else copy.deepcopy(lang_id_config) self.lang_configs = {} if lang_configs is None else copy.deepcopy(lang_configs) self.max_cache_size = max_cache_size # OrderedDict so we can use it as a LRU cache # most recent Pipeline goes to the end, pop the oldest one # when we run out of space self.pipeline_cache = OrderedDict() if processors is None: self.default_processors = None elif isinstance(processors, str): self.default_processors = [x.strip() for x in processors.split(",")] else: self.default_processors = list(processors) self.download_method = download_method if 'download_method' not in self.lang_id_config: self.lang_id_config['download_method'] = self.download_method # if lang is not in any of the lang_configs, update them to # include the lang parameter. otherwise, the default language # will always be used... 
for lang in self.lang_configs: if 'lang' not in self.lang_configs[lang]: self.lang_configs[lang]['lang'] = lang if restrict and 'langid_lang_subset' not in self.lang_id_config: known_langs = sorted(self.lang_configs.keys()) if known_langs == 0: logger.warning("MultilingualPipeline asked to restrict to lang_configs, but lang_configs was empty. Ignoring...") else: logger.debug("Restricting MultilingualPipeline to %s", known_langs) self.lang_id_config['langid_lang_subset'] = known_langs # set use_gpu if device is None: if use_gpu is None or use_gpu == True: device = default_device() else: device = 'cpu' self.device = device # build language id pipeline self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors="langid", device=self.device, **self.lang_id_config) # load the resources so that we can refer to it later when building a new pipeline # note that it was either downloaded or not based on download_method when building the lang_id_pipeline self.resources = load_resources_json(self.model_dir) def _update_pipeline_cache(self, lang): """ Do any necessary updates to the pipeline cache for this language. This includes building a new pipeline for the lang, and possibly clearing out a language with the old last access date. 
""" # update request history if lang in self.pipeline_cache: self.pipeline_cache.move_to_end(lang, last=True) # update language configs # try/except to allow for a defaultdict try: lang_config = self.lang_configs[lang] except KeyError: lang_config = {'lang': lang} self.lang_configs[lang] = lang_config # if a defaultdict is passed in, the defaultdict might not contain 'lang' # so even though we tried adding 'lang' in the constructor, we'll check again here if 'lang' not in lang_config: lang_config['lang'] = lang if 'download_method' not in lang_config: lang_config['download_method'] = self.download_method if 'processors' not in lang_config: if self.default_processors: lang_resources = get_language_resources(self.resources, lang) lang_processors = [x for x in self.default_processors if x in lang_resources] if lang_processors != self.default_processors: logger.info("Not all requested processors %s available for %s. Loading %s instead", self.default_processors, lang, lang_processors) lang_config['processors'] = ",".join(lang_processors) if 'device' not in lang_config: lang_config['device'] = self.device # update pipeline cache if lang not in self.pipeline_cache: logger.debug("Loading unknown language in MultilingualPipeline: %s", lang) # clear least recently used lang from pipeline cache if len(self.pipeline_cache) == self.max_cache_size: self.pipeline_cache.popitem(last=False) self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang]) def process(self, doc): """ Run language detection on a string, a Document, or a list of either, route to language specific pipeline """ # only return a list if given a list singleton_input = not isinstance(doc, list) if singleton_input: docs = [doc] else: docs = doc if docs and isinstance(docs[0], str): docs = [Document([], text=text) for text in docs] # run language identification docs_w_langid = self.lang_id_pipeline.process(docs) # create language specific batches, store global idx with each doc lang_batches 
= {} for doc_idx, doc in enumerate(docs_w_langid): logger.debug("Language for document %d: %s", doc_idx, doc.lang) if doc.lang not in lang_batches: lang_batches[doc.lang] = [] lang_batches[doc.lang].append(doc) # run through each language, submit a batch to the language specific pipeline for lang in lang_batches.keys(): self._update_pipeline_cache(lang) self.pipeline_cache[lang](lang_batches[lang]) # only return a list if given a list if singleton_input: return docs_w_langid[0] else: return docs_w_langid def __call__(self, doc): doc = self.process(doc) return doc ================================================ FILE: stanza/pipeline/mwt_processor.py ================================================ """ Processor for performing multi-word-token expansion """ import io import torch from stanza.models.mwt.data import DataLoader from stanza.models.mwt.trainer import Trainer from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor @register_processor(MWT) class MWTProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([MWT]) # set of processor requirements for this processor REQUIRES_DEFAULT = set([TOKENIZE]) def _set_up_model(self, config, pipeline, device): self._trainer = Trainer(model_file=config['model_path'], device=device) def build_batch(self, document): return DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True) def process(self, document): batch = self.build_batch(document) # process the rest expansions = batch.doc.get_mwt_expansions(evaluation=True) if len(batch) > 0: # decide trainer type and run eval if self.config['dict_only']: preds = self.trainer.predict_dict(expansions) else: with torch.no_grad(): preds = [] for i, b in enumerate(batch.to_loader()): preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab) if self.config.get('ensemble_dict', False): preds = 
self.trainer.ensemble(expansions, preds) else: # skip eval if dev data does not exist preds = [] batch.doc.set_mwt_expansions(preds, process_manual_expanded=False) return batch.doc def bulk_process(self, docs): """ MWT processor counts some statistics on the individual docs, so we need to separately redo those stats """ docs = super().bulk_process(docs) for doc in docs: doc._count_words() return docs ================================================ FILE: stanza/pipeline/ner_processor.py ================================================ """ Processor for performing named entity tagging. """ import torch import logging from stanza.models.common import doc from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError from stanza.models.common.utils import unsort from stanza.models.ner.data import DataLoader from stanza.models.ner.trainer import Trainer from stanza.models.ner.utils import merge_tags from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor logger = logging.getLogger('stanza') @register_processor(name=NER) class NERProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([NER]) # set of processor requirements for this processor REQUIRES_DEFAULT = set([TOKENIZE]) def _get_dependencies(self, config, dep_name): dependencies = config.get(dep_name, None) if dependencies is not None: dependencies = dependencies.split(";") dependencies = [x if x else None for x in dependencies] else: dependencies = [x.get(dep_name) for x in config.get('dependencies', [])] return dependencies def _set_up_model(self, config, pipeline, device): # set up trainer model_paths = config.get('model_path') if isinstance(model_paths, str): model_paths = model_paths.split(";") charlm_forward_files = self._get_dependencies(config, 'forward_charlm_path') charlm_backward_files = self._get_dependencies(config, 'backward_charlm_path') pretrain_files = 
self._get_dependencies(config, 'pretrain_path') # allow predict_tagset to be specified as an int # (which only applies to the first model) # or as a string ";" separated list of ints self._predict_tagset = {} predict_tagset = config.get('predict_tagset', None) if predict_tagset: if isinstance(predict_tagset, int): self._predict_tagset[0] = predict_tagset else: predict_tagset = predict_tagset.split(";") for piece_idx, piece in enumerate(predict_tagset): if piece: self._predict_tagset[piece_idx] = int(piece) self.trainers = [] for (model_path, pretrain_path, charlm_forward, charlm_backward) in zip(model_paths, pretrain_files, charlm_forward_files, charlm_backward_files): logger.debug("Loading %s with pretrain %s, forward charlm %s, backward charlm %s", model_path, pretrain_path, charlm_forward, charlm_backward) pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None args = {'charlm_forward_file': charlm_forward, 'charlm_backward_file': charlm_backward} predict_tagset = self._predict_tagset.get(len(self.trainers), None) if predict_tagset is not None: args['predict_tagset'] = predict_tagset try: trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache) except ForwardCharlmNotFoundError as e: raise ForwardCharlmNotFoundError("Could not find the forward charlm %s. Please specify the correct path with ner_forward_charlm_path" % e.filename, e.filename) from None except BackwardCharlmNotFoundError as e: raise BackwardCharlmNotFoundError("Could not find the backward charlm %s. Please specify the correct path with ner_backward_charlm_path" % e.filename, e.filename) from None self.trainers.append(trainer) self._trainer = self.trainers[0] self.model_paths = model_paths def _set_up_final_config(self, config): """ Finalize the configurations for this processor, based off of values from a UD model. 
""" # set configurations from loaded model if len(self.trainers) == 0: raise RuntimeError("Somehow there are no models loaded!") self._vocab = self.trainers[0].vocab self.configs = [] for trainer in self.trainers: loaded_args = trainer.args # filter out unneeded args from model loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)} loaded_args.update(config) self.configs.append(loaded_args) self._config = self.configs[0] def __str__(self): return "NERProcessor(%s)" % ";".join(self.model_paths) def mark_inactive(self): """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """ super().mark_inactive() self.trainers = None def process(self, document): with torch.no_grad(): all_preds = [] for trainer, config in zip(self.trainers, self.configs): # set up a eval-only data loader and skip tag preprocessing batch = DataLoader(document, config['batch_size'], config, vocab=trainer.vocab, evaluation=True, preprocess_tags=False, bert_tokenizer=trainer.model.bert_tokenizer) preds = [] for i, b in enumerate(batch): preds += trainer.predict(b) all_preds.append(preds) # for each sentence, gather a list of predictions # merge those predictions into a single list # earlier models will have precedence preds = [merge_tags(*x) for x in zip(*all_preds)] batch.doc.set([doc.NER], [y for x in preds for y in x], to_token=True) batch.doc.set([doc.MULTI_NER], [tuple(y) for x in zip(*all_preds) for y in zip(*x)], to_token=True) # collect entities into document attribute total = len(batch.doc.build_ents()) logger.debug(f'{total} entities found in document.') return batch.doc def bulk_process(self, docs): """ NER processor has a collation step after running inference """ docs = super().bulk_process(docs) for doc in docs: doc.build_ents() return docs def get_known_tags(self, model_idx=0): """ Return the tags known by this model Removes the S-, B-, etc, and does not include O Specify model_idx if the processor 
has more than one model """ return self.trainers[model_idx].get_known_tags() ================================================ FILE: stanza/pipeline/pos_processor.py ================================================ """ Processor for performing part-of-speech tagging """ import torch from stanza.models.common import doc from stanza.models.common.utils import unsort from stanza.models.common.vocab import VOCAB_PREFIX, CompositeVocab from stanza.models.pos.data import Dataset from stanza.models.pos.trainer import Trainer from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() @register_processor(name=POS) class POSProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([POS]) # set of processor requirements for this processor REQUIRES_DEFAULT = set([TOKENIZE]) def _set_up_model(self, config, pipeline, device): # get pretrained word vectors self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None args = {'charlm_forward_file': config.get('forward_charlm_path', None), 'charlm_backward_file': config.get('backward_charlm_path', None)} # set up trainer self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], device=device, args=args, foundation_cache=pipeline.foundation_cache) self._tqdm = 'tqdm' in config and config['tqdm'] def __str__(self): return "POSProcessor(%s)" % self.config['model_path'] def get_known_xpos(self): """ Returns the xpos tags known by this model """ if isinstance(self.vocab['xpos'], CompositeVocab): if len(self.vocab['xpos']) == 1: return [k for k in self.vocab['xpos'][0]._unit2id.keys() if k not in VOCAB_PREFIX] else: return {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['xpos']._unit2id.items()} return [k for k in self.vocab['xpos']._unit2id.keys() if k not in VOCAB_PREFIX] def 
is_composite_xpos(self): """ Returns if the xpos tags are part of a composite vocab """ return isinstance(self.vocab['xpos'], CompositeVocab) def get_known_upos(self): """ Returns the upos tags known by this model """ keys = [k for k in self.vocab['upos']._unit2id.keys() if k not in VOCAB_PREFIX] return keys def get_known_feats(self): """ Returns the features known by this model """ values = {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['feats']._unit2id.items()} return values def process(self, document): # currently, POS models are saved w/o the batch_maximum_tokens flag maximum_tokens = self.config.get('batch_maximum_tokens', 5000) dataset = Dataset( document, self.config, self.pretrain, vocab=self.vocab, evaluation=True, sort_during_eval=True) batch = iter(dataset.to_length_limited_loader(batch_size=self.config['batch_size'], maximum_tokens=maximum_tokens)) preds = [] idx = [] with torch.no_grad(): if self._tqdm: batch = tqdm(batch) for i, b in enumerate(batch): idx.extend(b[-1]) preds += self.trainer.predict(b) preds = unsort(preds, idx) dataset.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x]) return dataset.doc ================================================ FILE: stanza/pipeline/processor.py ================================================ """ Base classes for processors """ from abc import ABC, abstractmethod from stanza.models.common.doc import Document from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS class ProcessorRequirementsException(Exception): """ Exception indicating a processor's requirements will not be met """ def __init__(self, processors_list, err_processor, provided_reqs): self._err_processor = err_processor # mark the broken processor as inactive, drop resources self.err_processor.mark_inactive() self._processors_list = processors_list self._provided_reqs = provided_reqs self.build_message() @property def err_processor(self): """ The processor that raised the 
exception """ return self._err_processor @property def processor_type(self): return type(self.err_processor).__name__ @property def processors_list(self): return self._processors_list @property def provided_reqs(self): return self._provided_reqs def build_message(self): self.message = (f"---\nPipeline Requirements Error!\n" f"\tProcessor: {self.processor_type}\n" f"\tPipeline processors list: {','.join(self.processors_list)}\n" f"\tProcessor Requirements: {self.err_processor.requires}\n" f"\t\t- fulfilled: {self.err_processor.requires.intersection(self.provided_reqs)}\n" f"\t\t- missing: {self.err_processor.requires - self.provided_reqs}\n" f"\nThe processors list provided for this pipeline is invalid. Please make sure all " f"prerequisites are met for every processor.\n\n") def __str__(self): return self.message class Processor(ABC): """ Base class for all processors """ def __init__(self, config, pipeline, device): # overall config for the processor self._config = config # pipeline building this processor (presently processors are only meant to exist in one pipeline) self._pipeline = pipeline self._set_up_variants(config, device) # run set up process # set up what annotations are required based on config if not self._set_up_variant_requires(): self._set_up_requires() # set up what annotations are provided based on config self._set_up_provides() # given pipeline constructing this processor, check if requirements are met, throw exception if not self._check_requirements() if hasattr(self, '_variant') and self._variant.OVERRIDE: self.process = self._variant.process def __str__(self): """ Simple description of the processor: name(model) """ name = self.__class__.__name__ model = None if self._config is not None: model = self._config.get('model_path') if model is None: return name else: return "{}({})".format(name, model) @abstractmethod def process(self, doc): """ Process a Document. This is the main method of a processor. 
""" pass def bulk_process(self, docs): """ Process a list of Documents. This should be replaced with a more efficient implementation if possible. """ if hasattr(self, '_variant'): return self._variant.bulk_process(docs) return [self.process(doc) for doc in docs] def _set_up_provides(self): """ Set up what processor requirements this processor fulfills. Default is to use a class defined list. """ self._provides = self.__class__.PROVIDES_DEFAULT def _set_up_requires(self): """ Set up requirements for this processor. Default is to use a class defined list. """ self._requires = self.__class__.REQUIRES_DEFAULT def _set_up_variant_requires(self): """ If this has a variant with its own requirements, use those instead Returns True iff the _requires is set from the _variant """ if not hasattr(self, '_variant'): return False if hasattr(self._variant, '_set_up_requires'): self._variant._set_up_requires() self._requires = self._variant._requires return True if hasattr(self._variant.__class__, 'REQUIRES_DEFAULT'): self._requires = self._variant.__class__.REQUIRES_DEFAULT return True return False def _set_up_variants(self, config, device): processor_name = list(self.__class__.PROVIDES_DEFAULT)[0] if any(config.get(f'with_{variant}', False) for variant in PROCESSOR_VARIANTS[processor_name]): self._trainer = None variant_name = [variant for variant in PROCESSOR_VARIANTS[processor_name] if config.get(f'with_{variant}', False)][0] self._variant = PROCESSOR_VARIANTS[processor_name][variant_name](config) @property def config(self): """ Configurations for the processor """ return self._config @property def pipeline(self): """ The pipeline that this processor belongs to """ return self._pipeline @property def provides(self): return self._provides @property def requires(self): return self._requires def _check_requirements(self): """ Given a list of fulfilled requirements, check if all of this processor's requirements are met or not. 
""" if not self.config.get("check_requirements", True): return provided_reqs = set.union(*[processor.provides for processor in self.pipeline.loaded_processors]+[set([])]) if self.requires - provided_reqs: load_names = [item[0] for item in self.pipeline.load_list] raise ProcessorRequirementsException(load_names, self, provided_reqs) class ProcessorVariant(ABC): """ Base class for all processor variants """ OVERRIDE = False # Set to true to override all the processing from the processor @abstractmethod def process(self, doc): """ Process a document that is potentially preprocessed by the processor. This is the main method of a processor variant. If `OVERRIDE` is set to True, all preprocessing by the processor would be bypassed, and the processor variant would serve as a drop-in replacement of the entire processor, and has to be able to interpret all the configs that are typically handled by the processor it replaces. """ pass def bulk_process(self, docs): """ Process a list of Documents. This should be replaced with a more efficient implementation if possible. """ return [self.process(doc) for doc in docs] class UDProcessor(Processor): """ Base class for the neural UD Processors (tokenize,mwt,pos,lemma,depparse,sentiment,constituency) """ def __init__(self, config, pipeline, device): super().__init__(config, pipeline, device) # UD model resources, set up is processor specific self._pretrain = None self._trainer = None self._vocab = None if not hasattr(self, '_variant'): self._set_up_model(config, pipeline, device) # build the final config for the processor self._set_up_final_config(config) @abstractmethod def _set_up_model(self, config, pipeline, device): pass def _set_up_final_config(self, config): """ Finalize the configurations for this processor, based off of values from a UD model. 
""" # set configurations from loaded model if self._trainer is not None: loaded_args, self._vocab = self._trainer.args, self._trainer.vocab # filter out unneeded args from model loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)} else: loaded_args = {} loaded_args.update(config) self._config = loaded_args def mark_inactive(self): """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """ self._trainer = None self._vocab = None @property def pretrain(self): return self._pretrain @property def trainer(self): return self._trainer @property def vocab(self): return self._vocab @staticmethod def filter_out_option(option): """ Filter out non-processor configurations """ options_to_filter = ['device', 'cpu', 'cuda', 'dev_conll_gold', 'epochs', 'lang', 'mode', 'save_name', 'shorthand'] if option.endswith('_file') or option.endswith('_dir'): return True elif option in options_to_filter: return True else: return False def bulk_process(self, docs): """ Most processors operate on the sentence level, where each sentence is processed independently and processors can benefit a lot from the ability to combine sentences from multiple documents for faster batched processing. This is a transparent implementation that allows these processors to batch process a list of Documents as if they were from a single Document. 
""" if hasattr(self, '_variant'): return self._variant.bulk_process(docs) combined_sents = [sent for doc in docs for sent in doc.sentences] combined_doc = Document([]) combined_doc.sentences = combined_sents combined_doc.num_tokens = sum(doc.num_tokens for doc in docs) combined_doc.num_words = sum(doc.num_words for doc in docs) self.process(combined_doc) # annotations are attached to sentence objects return docs class ProcessorRegisterException(Exception): """ Exception indicating processor or processor registration failure """ def __init__(self, processor_class, expected_parent): self._processor_class = processor_class self._expected_parent = expected_parent self.build_message() def build_message(self): self.message = f"Failed to register '{self._processor_class}'. It must be a subclass of '{self._expected_parent}'." def __str__(self): return self.message def register_processor(name): def wrapper(Cls): if not issubclass(Cls, Processor): raise ProcessorRegisterException(Cls, Processor) NAME_TO_PROCESSOR_CLASS[name] = Cls PIPELINE_NAMES.append(name) return Cls return wrapper def register_processor_variant(name, variant): def wrapper(Cls): if not issubclass(Cls, ProcessorVariant): raise ProcessorRegisterException(Cls, ProcessorVariant) PROCESSOR_VARIANTS[name][variant] = Cls return Cls return wrapper ================================================ FILE: stanza/pipeline/registry.py ================================================ from collections import defaultdict # these two get filled by register_processor NAME_TO_PROCESSOR_CLASS = dict() PIPELINE_NAMES = [] # this gets filled by register_processor_variant PROCESSOR_VARIANTS = defaultdict(dict) ================================================ FILE: stanza/pipeline/sentiment_processor.py ================================================ """Processor that attaches a sentiment score to a sentence The model used is a generally a model trained on the Stanford Sentiment Treebank or some similar dataset. 
When run, this processor attaches a score in the form of a string to each sentence in the document. TODO: a possible way to generalize this would be to make it a ClassifierProcessor and have "sentiment" be an option. """ import dataclasses import torch from types import SimpleNamespace from stanza.models.classifiers.trainer import Trainer from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor @register_processor(SENTIMENT) class SentimentProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([SENTIMENT]) # set of processor requirements for this processor # TODO: a constituency based model needs CONSTITUENCY as well # issue: by the time we load the model in Processor.__init__, # the requirements are already prepared REQUIRES_DEFAULT = set([TOKENIZE]) # default batch size, measured in words per batch DEFAULT_BATCH_SIZE = 5000 def _set_up_model(self, config, pipeline, device): # get pretrained word vectors pretrain_path = config.get('pretrain_path', None) forward_charlm_path = config.get('forward_charlm_path', None) backward_charlm_path = config.get('backward_charlm_path', None) # elmo does not have a convenient way to download intermediate # models the way stanza downloads charlms & pretrains or # transformers downloads bert etc # however, elmo in general is not as good as using a # transformer, so it is unlikely we will ever fix this args = SimpleNamespace(device = device, charlm_forward_file = forward_charlm_path, charlm_backward_file = backward_charlm_path, wordvec_pretrain_file = pretrain_path, elmo_model = None, use_elmo = False, save_dir = None) filename = config['model_path'] if filename is None: raise FileNotFoundError("No model specified for the sentiment processor. Perhaps it is not supported for the language. 
{}".format(config)) # set up model trainer = Trainer.load(filename=filename, args=args, foundation_cache=pipeline.foundation_cache) self._trainer = trainer self._model = trainer.model self._model_type = self._model.config.model_type # batch size counted as words self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE) def _set_up_final_config(self, config): loaded_args = dataclasses.asdict(self._model.config) loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)} loaded_args.update(config) self._config = loaded_args def process(self, document): sentences = self._model.extract_sentences(document) with torch.no_grad(): labels = self._model.label_sentences(sentences, batch_size=self._batch_size) # TODO: allow a classifier processor for any attribute, not just sentiment document.set(SENTIMENT, labels, to_sentence=True) return document ================================================ FILE: stanza/pipeline/tokenize_processor.py ================================================ """ Processor for performing tokenization """ import copy import io import logging import torch from stanza.models.tokenization.data import TokenizationDataset from stanza.models.tokenization.trainer import Trainer from stanza.models.tokenization.utils import output_predictions from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor from stanza.pipeline.registry import PROCESSOR_VARIANTS from stanza.models.common import doc # these imports trigger the "register_variant" decorations from stanza.pipeline.external.jieba import JiebaTokenizer from stanza.pipeline.external.spacy import SpacyTokenizer from stanza.pipeline.external.sudachipy import SudachiPyTokenizer from stanza.pipeline.external.pythainlp import PyThaiNLPTokenizer logger = logging.getLogger('stanza') TOKEN_TOO_LONG_REPLACEMENT = "" # class for running the tokenizer @register_processor(name=TOKENIZE) class 
TokenizeProcessor(UDProcessor): # set of processor requirements this processor fulfills PROVIDES_DEFAULT = set([TOKENIZE]) # set of processor requirements for this processor REQUIRES_DEFAULT = set([]) # default max sequence length MAX_SEQ_LENGTH_DEFAULT = 1000 def _set_up_model(self, config, pipeline, device): # set up trainer if config.get('pretokenized'): self._trainer = None else: args = {'charlm_forward_file': config.get('forward_charlm_path', None)} self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache) # get and typecheck the postprocessor postprocessor = config.get('postprocessor') if postprocessor and callable(postprocessor): self._postprocessor = postprocessor elif not postprocessor: self._postprocessor = None else: raise ValueError("Tokenizer received 'postprocessor' option of unrecognized type; postprocessor must be callable. Got %s" % postprocessor) def process_pre_tokenized_text(self, input_src): """ Pretokenized text can be provided in 2 manners: 1.) str, tokenized by whitespace, sentence split by newline 2.) 
list of token lists, each token list represents a sentence generate dictionary data structure """ document = [] if isinstance(input_src, str): sentences = [sent.strip().split() for sent in input_src.strip().split('\n') if len(sent.strip()) > 0] elif isinstance(input_src, list): sentences = input_src idx = 0 for sentence in sentences: sent = [] for token_id, token in enumerate(sentence): sent.append({doc.ID: (token_id + 1, ), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'}) idx += len(token) + 1 document.append(sent) raw_text = ' '.join([' '.join(sentence) for sentence in sentences]) return raw_text, document def process(self, document): if not (isinstance(document, str) or isinstance(document, doc.Document) or (self.config.get('pretokenized') or self.config.get('no_ssplit', False))): raise ValueError("If neither 'pretokenized' or 'no_ssplit' option is enabled, the input to the TokenizerProcessor must be a string or a Document object. Got %s" % str(type(document))) if isinstance(document, doc.Document): if self.config.get('pretokenized'): return document document = document.text if self.config.get('pretokenized'): raw_text, document = self.process_pre_tokenized_text(document) return doc.Document(document, raw_text) if hasattr(self, '_variant'): return self._variant.process(document) raw_text = '\n\n'.join(document) if isinstance(document, list) else document max_seq_len = self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT) # set up batches batches = TokenizationDataset(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary) # get dict data with torch.no_grad(): _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None, max_seq_len, orig_text=raw_text, no_ssplit=self.config.get('no_ssplit', False), num_workers = self.config.get('num_workers', 0), postprocessor = self._postprocessor) # replace excessively long tokens with to avoid downstream GPU 
memory issues in POS for sentence in document: for token in sentence: if len(token['text']) > max_seq_len: token['text'] = TOKEN_TOO_LONG_REPLACEMENT return doc.Document(document, raw_text) def bulk_process(self, docs): """ The tokenizer cannot use UDProcessor's sentence-level cross-document batching interface, and requires special handling. Essentially, this method concatenates the text of multiple documents with "\n\n", tokenizes it with the neural tokenizer, then splits the result into the original Documents and recovers the original character offsets. """ if hasattr(self, '_variant'): return self._variant.bulk_process(docs) if self.config.get('pretokenized'): res = [] for document in docs: if len(document.sentences) > 0: # perhaps this is a document already tokenized, # being sent back in for more analysis / reparsing / etc? # in that case, no need to try to tokenize it # based on whitespace tokenizing the document text # which, interestingly, may not even exist depending on # how the document was created) # by making a whole deepcopy, the original Document is unchanged res.append(copy.deepcopy(document)) else: raw_text, document = self.process_pre_tokenized_text(document.text) res.append(doc.Document(document, raw_text)) return res combined_text = '\n\n'.join([thisdoc.text for thisdoc in docs]) processed_combined = self.process(doc.Document([], text=combined_text)) # postprocess sentences and tokens to reset back pointers and char offsets charoffset = 0 sentst = senten = 0 for thisdoc in docs: while senten < len(processed_combined.sentences) and processed_combined.sentences[senten].tokens[-1].end_char - charoffset <= len(thisdoc.text): senten += 1 sentences = processed_combined.sentences[sentst:senten] thisdoc.sentences = sentences for sent in sentences: # fix doc back pointers for sentences sent._doc = thisdoc # fix char offsets for tokens and words for token in sent.tokens: token._start_char -= charoffset token._end_char -= charoffset if token.words: # 
not-yet-processed MWT can leave empty tokens for word in token.words: word._start_char -= charoffset word._end_char -= charoffset # Here we need to fix up the SpacesAfter for the very last token # and the SpacesBefore for the first token of the next doc # After all, we had connected the text with \n\n # Need to be careful about this - in a case such as # " -text one- " # " -text two- " # We want the SpacesBefore for the second document to reflect # the extra space at the start of its text # and the SpacesAfter for the first document to reflect # the whitespace after its text if len(sentences) > 0: last_token = sentences[-1].tokens[-1] last_whitespace = thisdoc.text[last_token.end_char:] last_token.spaces_after = last_whitespace first_token = sentences[0].tokens[0] first_whitespace = thisdoc.text[:first_token.start_char] first_token.spaces_before = first_whitespace thisdoc.num_tokens = sum(len(sent.tokens) for sent in sentences) thisdoc.num_words = sum(len(sent.words) for sent in sentences) sentst = senten charoffset += len(thisdoc.text) + 2 return docs ================================================ FILE: stanza/protobuf/CoreNLP_pb2.py ================================================ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: CoreNLP.proto # Protobuf Python Version: 4.25.5 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 
\x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! \x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 
\x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xf6\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! 
\x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x0f\n\x07mwtMisc\x18N \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r\x12\r\n\x05index\x18O \x01(\r\x12\x12\n\nemptyIndex\x18P \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 
\x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x9b\x04\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x12/\n\x05token\x18\x04 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x14\n\x08rootNode\x18\x05 \x03(\rB\x02\x10\x01\x1aX\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x12\x12\n\nemptyIndex\x18\x04 \x01(\r\x1a\xd6\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12\x13\n\x0bsourceEmpty\x18\x08 \x01(\r\x12\x13\n\x0btargetEmpty\x18\t \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 
\x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDependency\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 
\x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n \x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 
\x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xfb\x06\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 
\x02(\t\x1a\x80\x01\n\tNamedEdge\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0e\n\x06source\x18\x02 \x02(\x05\x12\x0e\n\x06target\x18\x03 \x02(\x05\x12\x0c\n\x04reln\x18\x04 \x01(\t\x12\x0f\n\x07isExtra\x18\x05 \x01(\x08\x12\x12\n\nsourceCopy\x18\x06 \x01(\r\x12\x12\n\ntargetCopy\x18\x07 \x01(\r\x1a-\n\x0eVariableString\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\r\n\x05value\x18\x02 \x02(\t\x1a\xe6\x02\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x12\x42\n\x04\x65\x64ge\x18\x06 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge\x12L\n\tvarstring\x18\x07 \x03(\x0b\x32\x39.edu.stanford.nlp.pipeline.SemgrexResponse.VariableString\x12\x15\n\rsentenceIndex\x18\x04 \x01(\x05\x12\x14\n\x0csemgrexIndex\x18\x05 \x01(\x05\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"\xf0\x01\n\x0fSsurgeonRequest\x12\x45\n\x08ssurgeon\x18\x01 \x03(\x0b\x32\x33.edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon\x12\x39\n\x05graph\x18\x02 \x03(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x1a[\n\x08Ssurgeon\x12\x0f\n\x07semgrex\x18\x01 \x01(\t\x12\x11\n\toperation\x18\x02 \x03(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\r\n\x05notes\x18\x04 \x01(\t\x12\x10\n\x08language\x18\x05 \x01(\t\"\xbc\x01\n\x10SsurgeonResponse\x12J\n\x06result\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult\x1a\\\n\x0eSsurgeonResult\x12\x39\n\x05graph\x18\x01 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0f\n\x07\x63hanged\x18\x02 \x01(\x08\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 
\x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref\"\xb4\x01\n\x12\x46lattenedParseTree\x12\x41\n\x05nodes\x18\x01 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\x1a[\n\x04Node\x12\x12\n\x08openNode\x18\x01 \x01(\x08H\x00\x12\x13\n\tcloseNode\x18\x02 \x01(\x08H\x00\x12\x0f\n\x05value\x18\x03 \x01(\tH\x00\x12\r\n\x05score\x18\x04 \x01(\x01\x42\n\n\x08\x63ontents\"\xf6\x01\n\x15\x45valuateParserRequest\x12N\n\x08treebank\x18\x01 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\x1a\x8c\x01\n\x0bParseResult\x12;\n\x04gold\x18\x01 \x02(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x12@\n\tpredicted\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"E\n\x16\x45valuateParserResponse\x12\n\n\x02\x66\x31\x18\x01 \x02(\x01\x12\x0f\n\x07kbestF1\x18\x02 \x01(\x01\x12\x0e\n\x06treeF1\x18\x03 \x03(\x01\"\xc8\x01\n\x0fTsurgeonRequest\x12H\n\noperations\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TsurgeonRequest.Operation\x12<\n\x05trees\x18\x02 
\x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x1a-\n\tOperation\x12\x0e\n\x06tregex\x18\x01 \x02(\t\x12\x10\n\x08tsurgeon\x18\x02 \x03(\t\"P\n\x10TsurgeonResponse\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x85\x01\n\x11MorphologyRequest\x12\x46\n\x05words\x18\x01 \x03(\x0b\x32\x37.edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord\x1a(\n\nTaggedWord\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\"\x9a\x01\n\x12MorphologyResponse\x12I\n\x05words\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma\x1a\x39\n\x0cWordTagLemma\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\x12\r\n\x05lemma\x18\x03 \x02(\t\"Z\n\x1a\x44\x65pendencyConverterRequest\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x90\x02\n\x1b\x44\x65pendencyConverterResponse\x12`\n\x0b\x63onversions\x18\x01 \x03(\x0b\x32K.edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion\x1a\x8e\x01\n\x14\x44\x65pendencyConversion\x12\x39\n\x05graph\x18\x01 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12;\n\x04tree\x18\x02 
\x01(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'CoreNLP_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: _globals['DESCRIPTOR']._options = None _globals['DESCRIPTOR']._serialized_options = b'\n\031edu.stanford.nlp.pipelineB\rCoreNLPProtos' _globals['_DEPENDENCYGRAPH'].fields_by_name['root']._options = None _globals['_DEPENDENCYGRAPH'].fields_by_name['root']._serialized_options = b'\020\001' _globals['_DEPENDENCYGRAPH'].fields_by_name['rootNode']._options = None _globals['_DEPENDENCYGRAPH'].fields_by_name['rootNode']._serialized_options = b'\020\001' _globals['_LANGUAGE']._serialized_start=13585 _globals['_LANGUAGE']._serialized_end=13748 _globals['_SENTIMENT']._serialized_start=13750 _globals['_SENTIMENT']._serialized_end=13854 _globals['_NATURALLOGICRELATION']._serialized_start=13857 _globals['_NATURALLOGICRELATION']._serialized_end=14004 _globals['_DOCUMENT']._serialized_start=45 _globals['_DOCUMENT']._serialized_end=782 
_globals['_SENTENCE']._serialized_start=785 _globals['_SENTENCE']._serialized_end=2820 _globals['_TOKEN']._serialized_start=2823 _globals['_TOKEN']._serialized_end=4477 _globals['_QUOTE']._serialized_start=4480 _globals['_QUOTE']._serialized_end=4964 _globals['_PARSETREE']._serialized_start=4967 _globals['_PARSETREE']._serialized_end=5166 _globals['_DEPENDENCYGRAPH']._serialized_start=5169 _globals['_DEPENDENCYGRAPH']._serialized_end=5708 _globals['_DEPENDENCYGRAPH_NODE']._serialized_start=5403 _globals['_DEPENDENCYGRAPH_NODE']._serialized_end=5491 _globals['_DEPENDENCYGRAPH_EDGE']._serialized_start=5494 _globals['_DEPENDENCYGRAPH_EDGE']._serialized_end=5708 _globals['_COREFCHAIN']._serialized_start=5711 _globals['_COREFCHAIN']._serialized_end=6037 _globals['_COREFCHAIN_COREFMENTION']._serialized_start=5836 _globals['_COREFCHAIN_COREFMENTION']._serialized_end=6037 _globals['_MENTION']._serialized_start=6040 _globals['_MENTION']._serialized_end=7175 _globals['_INDEXEDWORD']._serialized_start=7177 _globals['_INDEXEDWORD']._serialized_end=7265 _globals['_SPEAKERINFO']._serialized_start=7267 _globals['_SPEAKERINFO']._serialized_end=7319 _globals['_SPAN']._serialized_start=7321 _globals['_SPAN']._serialized_end=7355 _globals['_TIMEX']._serialized_start=7357 _globals['_TIMEX']._serialized_end=7476 _globals['_ENTITY']._serialized_start=7479 _globals['_ENTITY']._serialized_end=7698 _globals['_RELATION']._serialized_start=7701 _globals['_RELATION']._serialized_end=7884 _globals['_OPERATOR']._serialized_start=7887 _globals['_OPERATOR']._serialized_end=8065 _globals['_POLARITY']._serialized_start=8068 _globals['_POLARITY']._serialized_end=8621 _globals['_NERMENTION']._serialized_start=8624 _globals['_NERMENTION']._serialized_end=8973 _globals['_SENTENCEFRAGMENT']._serialized_start=8975 _globals['_SENTENCEFRAGMENT']._serialized_end=9064 _globals['_TOKENLOCATION']._serialized_start=9066 _globals['_TOKENLOCATION']._serialized_end=9124 
_globals['_RELATIONTRIPLE']._serialized_start=9127 _globals['_RELATIONTRIPLE']._serialized_end=9537 _globals['_MAPSTRINGSTRING']._serialized_start=9539 _globals['_MAPSTRINGSTRING']._serialized_end=9584 _globals['_MAPINTSTRING']._serialized_start=9586 _globals['_MAPINTSTRING']._serialized_end=9628 _globals['_SECTION']._serialized_start=9631 _globals['_SECTION']._serialized_end=9883 _globals['_SEMGREXREQUEST']._serialized_start=9886 _globals['_SEMGREXREQUEST']._serialized_end=10114 _globals['_SEMGREXREQUEST_DEPENDENCIES']._serialized_start=9992 _globals['_SEMGREXREQUEST_DEPENDENCIES']._serialized_end=10114 _globals['_SEMGREXRESPONSE']._serialized_start=10117 _globals['_SEMGREXRESPONSE']._serialized_end=11008 _globals['_SEMGREXRESPONSE_NAMEDNODE']._serialized_start=10208 _globals['_SEMGREXRESPONSE_NAMEDNODE']._serialized_end=10253 _globals['_SEMGREXRESPONSE_NAMEDRELATION']._serialized_start=10255 _globals['_SEMGREXRESPONSE_NAMEDRELATION']._serialized_end=10298 _globals['_SEMGREXRESPONSE_NAMEDEDGE']._serialized_start=10301 _globals['_SEMGREXRESPONSE_NAMEDEDGE']._serialized_end=10429 _globals['_SEMGREXRESPONSE_VARIABLESTRING']._serialized_start=10431 _globals['_SEMGREXRESPONSE_VARIABLESTRING']._serialized_end=10476 _globals['_SEMGREXRESPONSE_MATCH']._serialized_start=10479 _globals['_SEMGREXRESPONSE_MATCH']._serialized_end=10837 _globals['_SEMGREXRESPONSE_SEMGREXRESULT']._serialized_start=10839 _globals['_SEMGREXRESPONSE_SEMGREXRESULT']._serialized_end=10919 _globals['_SEMGREXRESPONSE_GRAPHRESULT']._serialized_start=10921 _globals['_SEMGREXRESPONSE_GRAPHRESULT']._serialized_end=11008 _globals['_SSURGEONREQUEST']._serialized_start=11011 _globals['_SSURGEONREQUEST']._serialized_end=11251 _globals['_SSURGEONREQUEST_SSURGEON']._serialized_start=11160 _globals['_SSURGEONREQUEST_SSURGEON']._serialized_end=11251 _globals['_SSURGEONRESPONSE']._serialized_start=11254 _globals['_SSURGEONRESPONSE']._serialized_end=11442 
_globals['_SSURGEONRESPONSE_SSURGEONRESULT']._serialized_start=11350 _globals['_SSURGEONRESPONSE_SSURGEONRESULT']._serialized_end=11442 _globals['_TOKENSREGEXREQUEST']._serialized_start=11444 _globals['_TOKENSREGEXREQUEST']._serialized_end=11531 _globals['_TOKENSREGEXRESPONSE']._serialized_start=11534 _globals['_TOKENSREGEXRESPONSE']._serialized_end=11957 _globals['_TOKENSREGEXRESPONSE_MATCHLOCATION']._serialized_start=11633 _globals['_TOKENSREGEXRESPONSE_MATCHLOCATION']._serialized_end=11690 _globals['_TOKENSREGEXRESPONSE_MATCH']._serialized_start=11693 _globals['_TOKENSREGEXRESPONSE_MATCH']._serialized_end=11872 _globals['_TOKENSREGEXRESPONSE_PATTERNMATCH']._serialized_start=11874 _globals['_TOKENSREGEXRESPONSE_PATTERNMATCH']._serialized_end=11957 _globals['_DEPENDENCYENHANCERREQUEST']._serialized_start=11960 _globals['_DEPENDENCYENHANCERREQUEST']._serialized_end=12134 _globals['_FLATTENEDPARSETREE']._serialized_start=12137 _globals['_FLATTENEDPARSETREE']._serialized_end=12317 _globals['_FLATTENEDPARSETREE_NODE']._serialized_start=12226 _globals['_FLATTENEDPARSETREE_NODE']._serialized_end=12317 _globals['_EVALUATEPARSERREQUEST']._serialized_start=12320 _globals['_EVALUATEPARSERREQUEST']._serialized_end=12566 _globals['_EVALUATEPARSERREQUEST_PARSERESULT']._serialized_start=12426 _globals['_EVALUATEPARSERREQUEST_PARSERESULT']._serialized_end=12566 _globals['_EVALUATEPARSERRESPONSE']._serialized_start=12568 _globals['_EVALUATEPARSERRESPONSE']._serialized_end=12637 _globals['_TSURGEONREQUEST']._serialized_start=12640 _globals['_TSURGEONREQUEST']._serialized_end=12840 _globals['_TSURGEONREQUEST_OPERATION']._serialized_start=12795 _globals['_TSURGEONREQUEST_OPERATION']._serialized_end=12840 _globals['_TSURGEONRESPONSE']._serialized_start=12842 _globals['_TSURGEONRESPONSE']._serialized_end=12922 _globals['_MORPHOLOGYREQUEST']._serialized_start=12925 _globals['_MORPHOLOGYREQUEST']._serialized_end=13058 _globals['_MORPHOLOGYREQUEST_TAGGEDWORD']._serialized_start=13018 
_globals['_MORPHOLOGYREQUEST_TAGGEDWORD']._serialized_end=13058 _globals['_MORPHOLOGYRESPONSE']._serialized_start=13061 _globals['_MORPHOLOGYRESPONSE']._serialized_end=13215 _globals['_MORPHOLOGYRESPONSE_WORDTAGLEMMA']._serialized_start=13158 _globals['_MORPHOLOGYRESPONSE_WORDTAGLEMMA']._serialized_end=13215 _globals['_DEPENDENCYCONVERTERREQUEST']._serialized_start=13217 _globals['_DEPENDENCYCONVERTERREQUEST']._serialized_end=13307 _globals['_DEPENDENCYCONVERTERRESPONSE']._serialized_start=13310 _globals['_DEPENDENCYCONVERTERRESPONSE']._serialized_end=13582 _globals['_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION']._serialized_start=13440 _globals['_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION']._serialized_end=13582 # @@protoc_insertion_point(module_scope) ================================================ FILE: stanza/protobuf/__init__.py ================================================ from __future__ import absolute_import from io import BytesIO import warnings from google.protobuf.internal.encoder import _EncodeVarint from google.protobuf.internal.decoder import _DecodeVarint from google.protobuf.message import DecodeError from .CoreNLP_pb2 import * def parseFromDelimitedString(obj, buf, offset=0): """ Stanford CoreNLP uses the Java "writeDelimitedTo" function, which writes the size (and offset) of the buffer before writing the object. This function handles parsing this message starting from offset 0. @returns how many bytes of @buf were consumed. """ size, pos = _DecodeVarint(buf, offset) try: obj.ParseFromString(buf[offset+pos:offset+pos+size]) except DecodeError as e: warnings.warn("Failed to decode a serialized output from CoreNLP server. An incomplete or empty object will be returned.", \ RuntimeWarning) return pos+size def writeToDelimitedString(obj, stream=None): """ Stanford CoreNLP uses the Java "writeDelimitedTo" function, which writes the size (and offset) of the buffer before writing the object. 
This function handles parsing this message starting from offset 0. @returns how many bytes of @buf were consumed. """ if stream is None: stream = BytesIO() _EncodeVarint(stream.write, obj.ByteSize(), True) stream.write(obj.SerializeToString()) return stream def to_text(sentence): """ Helper routine that converts a Sentence protobuf to a string from its tokens. """ text = "" for i, tok in enumerate(sentence.token): if i != 0: text += tok.before text += tok.word return text ================================================ FILE: stanza/resources/__init__.py ================================================ ================================================ FILE: stanza/resources/common.py ================================================ """ Common utilities for Stanza resources. """ from collections import defaultdict, namedtuple import errno import hashlib import json import logging import os from pathlib import Path import requests import shutil import tempfile import zipfile from platformdirs import user_cache_dir from tqdm.auto import tqdm from stanza.utils.helper_func import make_table from stanza.pipeline._constants import TOKENIZE, MWT, POS, LEMMA, DEPPARSE, NER, SENTIMENT from stanza.pipeline.registry import PIPELINE_NAMES, PROCESSOR_VARIANTS from stanza.resources.default_packages import PACKAGES from stanza._version import __resources_version__ logger = logging.getLogger('stanza') # set home dir for default USER_CACHE_DIR = user_cache_dir('stanza', 'StanfordNLP', __resources_version__) STANFORDNLP_RESOURCES_URL = 'https://nlp.stanford.edu/software/stanza/stanza-resources/' STANZA_RESOURCES_GITHUB = 'https://raw.githubusercontent.com/stanfordnlp/stanza-resources/' DEFAULT_RESOURCES_URL = os.getenv('STANZA_RESOURCES_URL', STANZA_RESOURCES_GITHUB + 'main') DEFAULT_RESOURCES_VERSION = os.getenv( 'STANZA_RESOURCES_VERSION', __resources_version__ ) DEFAULT_MODEL_URL = os.getenv('STANZA_MODEL_URL', 'default') DEFAULT_MODEL_DIR = os.getenv( 'STANZA_RESOURCES_DIR', 
os.path.join(USER_CACHE_DIR, 'resources') ) PRETRAIN_NAMES = ("pretrain", "forward_charlm", "backward_charlm") class ResourcesFileNotFoundError(FileNotFoundError): def __init__(self, resources_filepath): super().__init__(f"Resources file not found at: {resources_filepath} Try to download the model again.") self.resources_filepath = resources_filepath class UnknownLanguageError(ValueError): def __init__(self, unknown): super().__init__(f"Unknown language requested: {unknown}") self.unknown_language = unknown class UnknownProcessorError(ValueError): def __init__(self, unknown): super().__init__(f"Unknown processor type requested: {unknown}") self.unknown_processor = unknown ModelSpecification = namedtuple('ModelSpecification', ['processor', 'package', 'dependencies']) def ensure_dir(path): """ Create dir in case it does not exist. """ Path(path).mkdir(parents=True, exist_ok=True) def get_md5(path): """ Get the MD5 value of a path. """ try: with open(path, 'rb') as fin: data = fin.read() except OSError as e: if not e.filename: e.filename = path raise return hashlib.md5(data).hexdigest() def unzip(path, filename): """ Fully unzip a file `filename` that's in a directory `dir`. """ logger.debug(f'Unzip: {path}/{filename}...') with zipfile.ZipFile(os.path.join(path, filename)) as f: f.extractall(path) def get_root_from_zipfile(filename): """ Get the root directory from a archived zip file. """ zf = zipfile.ZipFile(filename, "r") assert len(zf.filelist) > 0, \ f"Zip file at f{filename} seems to be corrupted. Please check it." return os.path.dirname(zf.filelist[0].filename) def file_exists(path, md5): """ Check if the file at `path` exists and match the provided md5 value. 
""" return os.path.exists(path) and get_md5(path) == md5 def assert_file_exists(path, md5=None, alternate_md5=None): if not os.path.exists(path): raise FileNotFoundError(errno.ENOENT, "Cannot find expected file", path) if md5: file_md5 = get_md5(path) if file_md5 != md5: if file_md5 == alternate_md5: logger.debug("Found a possibly older version of file %s, md5 %s instead of %s", path, alternate_md5, md5) else: raise ValueError("md5 for %s is %s, expected %s" % (path, file_md5, md5)) def download_file(url, path, proxies, raise_for_status=False): """ Download a URL into a file as specified by `path`. """ verbose = logger.level in [0, 10, 20] r = requests.get(url, stream=True, proxies=proxies) if raise_for_status: r.raise_for_status() with open(path, 'wb') as f: file_size = r.headers.get('content-length', None) if file_size: file_size = int(file_size) default_chunk_size = 131072 desc = 'Downloading ' + url with tqdm(total=file_size, unit='B', unit_scale=True, \ disable=not verbose, desc=desc) as pbar: for chunk in r.iter_content(chunk_size=default_chunk_size): if chunk: f.write(chunk) f.flush() pbar.update(len(chunk)) return r.status_code def request_file(url, path, proxies=None, md5=None, raise_for_status=False, log_info=True, alternate_md5=None): """ A complete wrapper over download_file() that also make sure the directory of `path` exists, and that a file matching the md5 value does not exist. 
alternate_md5 allows for an alternate md5 that is acceptable (such as if an older version of a file is okay) """ basedir = Path(path).parent ensure_dir(basedir) if file_exists(path, md5): if log_info: logger.info(f'File exists: {path}') else: logger.debug(f'File exists: {path}') return # We write data first to a temporary directory, # then use os.replace() so that multiple processes # running at the same time don't clobber each other # with partially downloaded files # This was especially common with resources.json with tempfile.TemporaryDirectory(dir=basedir) as temp: temppath = os.path.join(temp, os.path.split(path)[-1]) download_file(url, temppath, proxies, raise_for_status) os.replace(temppath, path) assert_file_exists(path, md5, alternate_md5) if log_info: logger.info(f'Downloaded file to {path}') else: logger.debug(f'Downloaded file to {path}') def sort_processors(processor_list): sorted_list = [] for processor in PIPELINE_NAMES: for item in processor_list: if item[0] == processor: sorted_list.append(item) # going just by processors in PIPELINE_NAMES, this drops any names # which are not an official processor but might still be useful # check the list and append them to the end # this is especially useful when downloading pretrain or charlm models for processor in processor_list: for item in sorted_list: if processor[0] == item[0]: break else: sorted_list.append(item) return sorted_list def add_mwt(processors, resources, lang): """Add mwt if tokenize is passed without mwt. If tokenize is in the list, but mwt is not, and there is a corresponding tokenize and mwt pair in the resources file, mwt is added so no missing mwt errors are raised. 
""" value = processors[TOKENIZE] if value in resources[lang][PACKAGES] and MWT in resources[lang][PACKAGES][value]: logger.warning("Language %s package %s expects mwt, which has been added", lang, value) processors[MWT] = value elif (value in resources[lang][TOKENIZE] and MWT in resources[lang] and value in resources[lang][MWT]): logger.warning("Language %s package %s expects mwt, which has been added", lang, value) processors[MWT] = value def maintain_processor_list(resources, lang, package, processors, allow_pretrain=False, maybe_add_mwt=True): """ Given a parsed resources file, language, and possible package and/or processors, expands the package to the list of processors Returns a list of processors Each item in the list of processors is a pair: name, then a list of ModelSpecification so, for example: [['pos', [ModelSpecification(processor='pos', package='gsd', dependencies=None)]], ['depparse', [ModelSpecification(processor='depparse', package='gsd', dependencies=None)]]] """ processor_list = defaultdict(list) # resolve processor models if processors: logger.debug(f'Processing parameter "processors"...') if maybe_add_mwt and TOKENIZE in processors and MWT not in processors: add_mwt(processors, resources, lang) for key, plist in processors.items(): if not isinstance(key, str): raise ValueError("Processor names must be strings") if not isinstance(plist, (tuple, list, str)): raise ValueError("Processor values must be strings") if isinstance(plist, str): plist = [plist] if key not in PIPELINE_NAMES: if not allow_pretrain or key not in PRETRAIN_NAMES: raise UnknownProcessorError(key) for value in plist: # check if keys and values can be found if key in resources[lang] and value in resources[lang][key]: logger.debug(f'Found {key}: {value}.') processor_list[key].append(value) # allow values to be default in some cases elif value in resources[lang][PACKAGES] and key in resources[lang][PACKAGES][value]: logger.debug( f'Found {key}: 
{resources[lang][PACKAGES][value][key]}.' ) processor_list[key].append(resources[lang][PACKAGES][value][key]) # optional defaults will be activated if specifically turned on elif value in resources[lang][PACKAGES] and 'optional' in resources[lang][PACKAGES][value] and key in resources[lang][PACKAGES][value]['optional']: logger.debug( f"Found {key}: {resources[lang][PACKAGES][value]['optional'][key]}." ) processor_list[key].append(resources[lang][PACKAGES][value]['optional'][key]) # allow processors to be set to variants that we didn't implement elif value in PROCESSOR_VARIANTS[key]: logger.debug( f'Found {key}: {value}. ' f'Using external {value} variant for the {key} processor.' ) processor_list[key].append(value) # allow lemma to be set to "identity" elif key == LEMMA and value == 'identity': logger.debug( f'Found {key}: {value}. Using identity lemmatizer.' ) processor_list[key].append(value) # not a processor in the officially supported processor list elif key not in resources[lang]: logger.debug( f'{key}: {value} is not officially supported by Stanza, ' f'loading it anyway.' ) processor_list[key].append(value) # cannot find the package for a processor and warn user else: logger.warning( f'Can not find {key}: {value} from official model list. ' f'Ignoring it.' ) # resolve package if package: logger.debug(f'Processing parameter "package"...') if PACKAGES in resources[lang] and package in resources[lang][PACKAGES]: for key, value in resources[lang][PACKAGES][package].items(): if key != 'optional' and key not in processor_list: logger.debug(f'Found {key}: {value}.') processor_list[key].append(value) else: flag = False for key in PIPELINE_NAMES: if key not in resources[lang]: continue if package in resources[lang][key]: flag = True if key not in processor_list: logger.debug(f'Found {key}: {package}.') processor_list[key].append(package) else: logger.debug( f'{key}: {package} is overwritten by ' f'{key}: {processors[key]}.' 
) if not flag: logger.warning((f'Can not find package: {package}.')) processor_list = [[key, [ModelSpecification(processor=key, package=value, dependencies=None) for value in plist]] for key, plist in processor_list.items()] processor_list = sort_processors(processor_list) return processor_list def add_dependencies(resources, lang, processor_list): """ Expand the processor_list as given in maintain_processor_list to have the dependencies Still a list of model types to ModelSpecifications the dependencies are tuples: name and package for example: [['pos', (ModelSpecification(processor='pos', package='gsd', dependencies=(('pretrain', 'gsd'),)),)], ['depparse', (ModelSpecification(processor='depparse', package='gsd', dependencies=(('pretrain', 'gsd'),)),)]] """ lang_resources = resources[lang] for item in processor_list: processor, model_specs = item new_model_specs = [] for model_spec in model_specs: # skip dependency checking for external variants of processors and identity lemmatizer if not any([ model_spec.package in PROCESSOR_VARIANTS[processor], processor == LEMMA and model_spec.package == 'identity' ]): dependencies = lang_resources.get(processor, {}).get(model_spec.package, {}).get('dependencies', []) dependencies = [(dependency['model'], dependency['package']) for dependency in dependencies] model_spec = model_spec._replace(dependencies=tuple(dependencies)) logger.debug("Found dependencies %s for processor %s model %s", dependencies, processor, model_spec.package) new_model_specs.append(model_spec) item[1] = tuple(new_model_specs) return processor_list def flatten_processor_list(processor_list): """ The flattened processor list is just a list of types & packages For example: [['pos', 'gsd'], ['depparse', 'gsd'], ['pretrain', 'gsd']] """ flattened_processor_list = [] dependencies_list = [] for item in processor_list: processor, model_specs = item for model_spec in model_specs: package = model_spec.package dependencies = model_spec.dependencies 
flattened_processor_list.append([processor, package]) if dependencies: dependencies_list += [tuple(dependency) for dependency in dependencies] dependencies_list = [list(item) for item in set(dependencies_list)] for processor, package in dependencies_list: logger.debug(f'Find dependency {processor}: {package}.') flattened_processor_list += dependencies_list return flattened_processor_list def set_logging_level(logging_level, verbose): # Check verbose for easy logging control if verbose == False: logging_level = 'ERROR' elif verbose == True: logging_level = 'INFO' if logging_level is None: # default logging level of INFO is set in stanza.__init__ # but the user may have set it via the logging API # it should NOT be 0, but let's check to be sure... if logger.level == 0: logger.setLevel('INFO') return logger.level # Set logging level logging_level = logging_level.upper() all_levels = ['DEBUG', 'INFO', 'WARNING', 'WARN', 'ERROR', 'CRITICAL', 'FATAL'] if logging_level not in all_levels: raise ValueError( f"Unrecognized logging level for pipeline: " f"{logging_level}. Must be one of {', '.join(all_levels)}." ) logger.setLevel(logging_level) return logger.level def process_pipeline_parameters(lang, model_dir, package, processors): # Check parameter types and convert values to lower case if isinstance(lang, str): lang = lang.strip().lower() elif lang is not None: raise TypeError( f"The parameter 'lang' should be str, " f"but got {type(lang).__name__} instead." ) if isinstance(model_dir, str): model_dir = model_dir.strip() elif model_dir is not None: raise TypeError( f"The parameter 'model_dir' should be str, " f"but got {type(model_dir).__name__} instead." 
) if isinstance(processors, (str, list, tuple)): # Special case: processors is str, compatible with older version # also allow for setting alternate packages for these processors # via the package argument if package is None: # each processor will be 'default' for this language package = defaultdict(lambda: 'default') elif isinstance(package, str): # same, but now the named package will be the default instead default = package package = defaultdict(lambda: default) elif isinstance(package, dict): # the dictionary of packages will be used to build the processors dict # any processor not specified in package will be 'default' package = defaultdict(lambda: 'default', package) else: raise TypeError( f"The parameter 'package' should be None, str, or dict, " f"but got {type(package).__name__} instead." ) if isinstance(processors, str): processors = [x.strip().lower() for x in processors.split(",")] processors = { processor: package[processor] for processor in processors } package = None elif isinstance(processors, dict): processors = { k.strip().lower(): ([v_i.strip().lower() for v_i in v] if isinstance(v, (tuple, list)) else v.strip().lower()) for k, v in processors.items() } elif processors is not None: raise TypeError( f"The parameter 'processors' should be dict or str, " f"but got {type(processors).__name__} instead." ) if isinstance(package, str): package = package.strip().lower() elif package is not None: raise TypeError( f"The parameter 'package' should be str, or a dict if 'processors' is a str, " f"but got {type(package).__name__} instead." ) return lang, model_dir, package, processors def download_resources_json(model_dir=DEFAULT_MODEL_DIR, resources_url=DEFAULT_RESOURCES_URL, resources_branch=None, resources_version=DEFAULT_RESOURCES_VERSION, resources_filepath=None, proxies=None): """ Downloads resources.json to obtain latest packages. 
""" if resources_url == DEFAULT_RESOURCES_URL and resources_branch is not None: resources_url = STANZA_RESOURCES_GITHUB + resources_branch # handle short name for resources urls; otherwise treat it as url if resources_url.lower() in ('stanford', 'stanfordnlp'): resources_url = STANFORDNLP_RESOURCES_URL resources_url = f'{resources_url}/resources_{resources_version}.json' logger.debug('Downloading resource file from %s', resources_url) if resources_filepath is None: resources_filepath = os.path.join(model_dir, 'resources.json') # make request request_file( resources_url, resources_filepath, proxies, raise_for_status=True ) def load_resources_json(model_dir=DEFAULT_MODEL_DIR, resources_filepath=None): """ Unpack the resources json file from the given model_dir """ if resources_filepath is None: resources_filepath = os.path.join(model_dir, 'resources.json') if not os.path.exists(resources_filepath): raise ResourcesFileNotFoundError(resources_filepath) with open(resources_filepath, encoding="utf-8") as fin: resources = json.load(fin) return resources def get_language_resources(resources, lang): """ Get the resources for a lang from an already loaded resources json, following 'alias' if needed """ if lang not in resources: return None lang_resources = resources[lang] while 'alias' in lang_resources: lang = lang_resources['alias'] lang_resources = resources[lang] return lang_resources def list_available_languages(model_dir=DEFAULT_MODEL_DIR, resources_url=DEFAULT_RESOURCES_URL, resources_branch=None, resources_version=DEFAULT_RESOURCES_VERSION, proxies=None): """ List the non-alias languages in the resources file """ download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies) resources = load_resources_json(model_dir) # isinstance(str) is because of fields such as "url" # 'alias' is because we want to skip German, alias of de, for example languages = [lang for lang in resources if not 
isinstance(resources[lang], str) and 'alias' not in resources[lang]] languages = sorted(languages) return languages def expand_model_url(resources, model_url): """ Returns the url in the resources dict if model_url is default, or returns the model_url """ return resources['url'] if model_url.lower() == 'default' else model_url def download_models(download_list, resources, lang, model_dir=DEFAULT_MODEL_DIR, resources_version=DEFAULT_RESOURCES_VERSION, model_url=DEFAULT_MODEL_URL, proxies=None, log_info=True): lang_name = resources.get(lang, {}).get('lang_name', lang) download_table = make_table(['Processor', 'Package'], download_list) if log_info: log_msg = logger.info else: log_msg = logger.debug log_msg( f'Downloading these customized packages for language: ' f'{lang} ({lang_name})...\n{download_table}' ) url = expand_model_url(resources, model_url) # Download packages for key, value in download_list: try: request_file( url.format(resources_version=resources_version, lang=lang, filename=f"{key}/{value}.pt"), os.path.join(model_dir, lang, key, f'{value}.pt'), proxies, md5=resources[lang][key][value]['md5'], log_info=log_info, alternate_md5=resources[lang][key][value].get('alternate_md5', None) ) except KeyError as e: raise ValueError( f'Cannot find the following processor and model name combination: ' f'{key}, {value}. Please check if you have provided the correct model name.' 
) from e # main download function def download( lang='en', model_dir=DEFAULT_MODEL_DIR, package='default', processors={}, logging_level=None, verbose=None, resources_url=DEFAULT_RESOURCES_URL, resources_branch=None, resources_version=DEFAULT_RESOURCES_VERSION, model_url=DEFAULT_MODEL_URL, proxies=None, download_json=True ): # set global logging level set_logging_level(logging_level, verbose) # process different pipeline parameters lang, model_dir, package, processors = process_pipeline_parameters( lang, model_dir, package, processors ) if download_json or not os.path.exists(os.path.join(model_dir, 'resources.json')): if not download_json: logger.warning("Asked to skip downloading resources.json, but the file does not exist. Downloading anyway") download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies) resources = load_resources_json(model_dir) if lang not in resources: raise UnknownLanguageError(lang) if 'alias' in resources[lang]: logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"') lang = resources[lang]['alias'] lang_name = resources.get(lang, {}).get('lang_name', lang) url = expand_model_url(resources, model_url) # Default: download zipfile and unzip if package == 'default' and (processors is None or len(processors) == 0): logger.info( f'Downloading default packages for language: {lang} ({lang_name}) ...' 
) # want the URL to become, for example: # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip # so we hopefully start from # https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename} request_file( url.format(resources_version=resources_version, lang=lang, filename="default.zip"), os.path.join(model_dir, lang, f'default.zip'), proxies, md5=resources[lang]['default_md5'], ) unzip(os.path.join(model_dir, lang), 'default.zip') download_list = [['zip', 'default.zip']] # Customize: maintain download list else: download_list = maintain_processor_list(resources, lang, package, processors, allow_pretrain=True) download_list = add_dependencies(resources, lang, download_list) download_list = flatten_processor_list(download_list) download_models(download_list=download_list, resources=resources, lang=lang, model_dir=model_dir, resources_version=resources_version, model_url=model_url, proxies=proxies, log_info=True) logger.info(f'Finished downloading models and saved to {model_dir}') return download_list ================================================ FILE: stanza/resources/default_packages.py ================================================ """ Constants for default packages, default pretrains, charlms, etc Separated from prepare_resources.py so that other modules can use the same lists / maps without importing the resources script and possibly causing a circular import """ import copy # all languages will have a map which represents the available packages PACKAGES = "packages" # default treebank for languages default_treebanks = { "ab": "abnc", "af": "afribooms", # currently not publicly released! 
sent to us from the group developing this resource "ang": "nerthus", "ar": "padt", "be": "hse", "bg": "btb", "bxr": "bdt", "ca": "ancora", "cop": "scriptorium", "cs": "pdt", "cu": "proiel", "cy": "ccg", "da": "ddt", "de": "combined", "el": "gdt", "en": "combined", "es": "combined", "et": "edt", "eu": "bdt", "fa": "perdt", "fi": "tdt", "fo": "farpahc", "fr": "combined", "fro": "profiterole", "ga": "idt", "gd": "arcosg", "gl": "ctg", "got": "proiel", "grc": "perseus", "gv": "cadhan", "hbo": "ptnk", "he": "combined", "hi": "hdtb", "hr": "set", "hsb": "ufal", "hu": "szeged", "hy": "armtdp", "hyw": "armtdp", "id": "gsd", "is": "icepahc", "it": "combined", "ja": "combined", "ka": "glc", "kk": "ktb", "kmr": "mg", "ko": "kaist", "kpv": "lattice", "ky": "ktmu", "la": "ittb", "lij": "glt", "lt": "alksnis", "lv": "lvtb", "lzh": "kyoto", "mr": "ufal", "mt": "mudt", "my": "ucsy", "myv": "jr", "nb": "bokmaal", "nds": "lsdc", "nl": "alpino", "nn": "nynorsk", "olo": "kkpp", "orv": "torot", "ota": "boun", "pcm": "nsc", "pl": "pdb", "pt": "bosque", "qaf": "arabizi", "qpm": "philotis", "qtd": "sagt", "ro": "rrt", "ru": "syntagrus", "sa": "vedic", "sd": "isra", "sk": "snk", "sl": "ssj", "sme": "giella", "sq": "combined", "sr": "set", "sv": "talbanken", "swl": "sslc", "ta": "ttb", "te": "mtg", "th": "tud", "tr": "imst", "ug": "udt", "uk": "iu", "ur": "udtb", "vi": "vtb", "wo": "wtb", "xcl": "caval", "zh-hans": "gsdsimp", "zh-hant": "gsd", "multilingual": "ud" } no_pretrain_languages = set([ "cop", "olo", "orv", "pcm", "qaf", # the QAF treebank is code switched and Romanized, so not easy to reuse existing resources "qpm", # have talked about deriving this from a language neighborinig to Pomak, but that hasn't happened yet "qtd", "swl", "multilingual", # special case so that all languages with a default treebank are represented somewhere ]) # in some cases, we give the pretrain a name other than the original # name for the UD dataset # we will eventually do this for all of the pretrains 
# language code -> pretrain name.  build_default_pretrains applies these on
# top of the default_treebanks names, so each entry overrides (or adds) the
# default pretrain for its language.
specific_default_pretrains = {
    "ab": "fasttextwiki",
    "af": "fasttextwiki",
    "ang": "nerthus",
    "ar": "conll17",
    "be": "fasttextwiki",
    "bg": "conll17",
    "bxr": "fasttextwiki",
    "ca": "conll17",
    "cs": "conll17",
    "cu": "conll17",
    "cy": "fasttext157",
    "da": "conll17",
    "de": "conll17",
    "el": "conll17",
    "en": "conll17",
    "es": "conll17",
    "et": "conll17",
    "eu": "conll17",
    "fa": "conll17",
    "fi": "conll17",
    "fo": "fasttextwiki",
    "fr": "conll17",
    "fro": "conll17",
    "ga": "conll17",
    "gd": "fasttextwiki",
    "gl": "conll17",
    "got": "fasttextwiki",
    "grc": "conll17",
    "gv": "fasttext157",
    "hbo": "utah",
    "he": "conll17",
    "hi": "conll17",
    "hr": "conll17",
    "hsb": "fasttextwiki",
    "hu": "conll17",
    "hy": "isprasglove",
    "hyw": "isprasglove",
    "id": "conll17",
    "is": "fasttext157",
    "it": "conll17",
    "ja": "conll17",
    "ka": "fasttext157",
    "kk": "fasttext157",
    "kmr": "fasttextwiki",
    "ko": "conll17",
    "kpv": "fasttextwiki",
    "ky": "fasttext157",
    "la": "conll17",
    "lij": "fasttextwiki",
    "lt": "fasttextwiki",
    "lv": "conll17",
    "lzh": "fasttextwiki",
    "mr": "fasttextwiki",
    "mt": "fasttextwiki",
    "my": "ucsy",
    "myv": "mokha",
    "nb": "conll17",
    "nds": "fasttext157",
    "nl": "conll17",
    "nn": "conll17",
    "or": "fasttext157",
    "ota": "conll17",
    "pl": "conll17",
    "pt": "conll17",
    "ro": "conll17",
    "ru": "conll17",
    "sa": "fasttext157",
    "sd": "isra",
    "sk": "conll17",
    "sl": "conll17",
    "sme": "fasttextwiki",
    "sq": "fasttext157",
    "sr": "fasttextwiki",
    "sv": "conll17",
    "ta": "fasttextwiki",
    "te": "fasttextwiki",
    "th": "fasttext157",
    "tr": "conll17",
    "ug": "conll17",
    "uk": "conll17",
    "ur": "conll17",
    "vi": "conll17",
    "wo": "fasttextwiki",
    "xcl": "caval",
    "zh-hans": "fasttext157",
    "zh-hant": "conll17",
}

def build_default_pretrains(default_treebanks):
    """
    Build the map of language code -> default pretrain name.

    Starts from a shallow copy of default_treebanks, so a language's pretrain
    defaults to the name of its default treebank.  It then removes every
    language in the module-level no_pretrain_languages and finally applies the
    module-level specific_default_pretrains on top.  The input dict is not
    modified.

    Because specific_default_pretrains is applied last, it can also add
    languages that have no default treebank.
    """
    default_pretrains = dict(default_treebanks)
    for lang in no_pretrain_languages:
        default_pretrains.pop(lang, None)
    for lang in specific_default_pretrains.keys():
        default_pretrains[lang] = specific_default_pretrains[lang]
    return default_pretrains

default_pretrains = build_default_pretrains(default_treebanks)

# language -> {package name: pretrain name} exceptions to the language default
pos_pretrains = {
    "en": {
        "craft": "biomed",
        "genia": "biomed",
        "mimic": "mimic",
    },
}

# NOTE: this is the same dict object as pos_pretrains, not a copy (the
# charlm maps below use copy.deepcopy instead), so a change made through
# either name is visible through both
depparse_pretrains = pos_pretrains

# language -> {package name: pretrain name} for the NER packages listed here
ner_pretrains = {
    "ar": {
        "aqmar": "fasttextwiki",
    },
    "de": {
        "conll03": "fasttextwiki",
        # the bert version of germeval uses the smaller vector file
        "germeval2014": "fasttextwiki",
    },
    "en": {
        "anatem": "biomed",
        "bc4chemd": "biomed",
        "bc5cdr": "biomed",
        "bionlp13cg": "biomed",
        "jnlpba": "biomed",
        "linnaeus": "biomed",
        "ncbi_disease": "biomed",
        "s800": "biomed",

        "ontonotes": "fasttextcrawl",
        # the stanza-train sample NER model should use the default NER pretrain
        # for English, that is the same as ontonotes
        "sample": "fasttextcrawl",

        "conll03": "glove",

        "i2b2": "mimic",
        "radiology": "mimic",
    },
    "es": {
        "ancora": "fasttextwiki",
        "conll02": "fasttextwiki",
    },
    "nl": {
        "conll02": "fasttextwiki",
        "wikiner": "fasttextwiki",
    },
    "ru": {
        "wikiner": "fasttextwiki",
    },
    "th": {
        "lst20": "fasttext157",
    },
}

# default charlms for languages
default_charlms = {
    "af": "oscar",
    "ang": "nerthus1024",
    "ar": "ccwiki",
    "bg": "conll17",
    "da": "oscar",
    "de": "newswiki",
    "en": "1billion",
    "es": "newswiki",
    "fa": "conll17",
    "fi": "conll17",
    "fr": "newswiki",
    "he": "oscar",
    "hi": "oscar",
    "id": "oscar2023",
    "it": "conll17",
    "ja": "conll17",
    "kk": "oscar",
    "mr": "l3cube",
    "my": "oscar",
    "nb": "conll17",
    "nl": "ccwiki",
    "pl": "oscar",
    "pt": "oscar2023",
    "ru": "newswiki",
    "sd": "isra",
    "sv": "conll17",
    "te": "oscar2022",
    "th": "oscar",
    "tr": "conll17",
    "uk": "conll17",
    "vi": "conll17",
    "zh-hans": "gigaword"
}

# language -> {package name: charlm name} exceptions to default_charlms.
# Judging by the comments below, a value of None marks packages where no
# charlm helped - TODO confirm how consumers of this map treat None
pos_charlms = {
    "en": {
        # none of the English charlms help with craft or genia
        "craft": None,
        "genia": None,
        "mimic": "mimic",
    },
    "tr": {   # no idea why, but this particular one goes down in dev score
        "boun": None,
    },
}

# independent deep copies, so each map can be changed without affecting the others
depparse_charlms = copy.deepcopy(pos_charlms)

lemma_charlms = copy.deepcopy(pos_charlms)

tokenizer_charlms = copy.deepcopy(pos_charlms)

ner_charlms = {
    "en": {
        "conll03": "1billion",
        "ontonotes": "1billion",
        "anatem": "pubmed",
        "bc4chemd": "pubmed",
        "bc5cdr": "pubmed",
        "bionlp13cg": "pubmed",
        "i2b2": "mimic",
        "jnlpba": "pubmed",
        "linnaeus": "pubmed",
        "ncbi_disease": "pubmed",
        "radiology": "mimic",
        "s800": "pubmed",
    },
    "hu": {
        "combined": None,
    },
    "nn": {
        "norne": None,
    },
}

# default ner for languages
default_ners = {
    "af": "nchlt",
    "ang": "oedt_charlm",
    "ar": "aqmar_charlm",
    "bg": "bsnlp19",
    "da": "ddt",
    "de": "germeval2014",
    "en": "ontonotes-ww-multi_charlm",
    "es": "conll02",
    "fa": "arman",
    "fi": "turku",
    "fr": "wikinergold_charlm",
    "he": "iahlt_charlm",
    "hi": "ilner_charlm",
    "hu": "combined",
    "hy": "armtdp",
    "it": "fbk",
    "ja": "gsd",
    "kk": "kazNERD",
    "mr": "l3cube",
    "my": "ucsy",
    "nb": "norne",
    "nl": "conll02",
    "nn": "norne",
    "pl": "nkjp",
    "ru": "wikiner",
    "sd": "siner",
    "sv": "suc3shuffle",
    "te": "ilner_charlm",
    "th": "lst20",
    "tr": "starlang",
    "uk": "languk",
    "ur": "ilner_nocharlm",
    "vi": "vlsp",
    "zh-hans": "ontonotes",
}

# a few languages have sentiment classifier models
default_sentiment = {
    "en": "sstplus_charlm",
    "de": "sb10k_charlm",
    "es": "tass2020_charlm",
    "mr": "l3cube_charlm",
    "vi": "vsfc_charlm",
    "zh-hans": "ren_charlm",
}

# also, a few languages (very few, currently) have constituency parser models
default_constituency = {
    "da": "arboretum_charlm",
    "de": "spmrl_charlm",
    "en": "ptb3-revised_charlm",
    "es": "combined_charlm",
    "id": "icon_charlm",
    "it": "vit_charlm",
    "ja": "alt_charlm",
    "pt": "cintil_charlm",
    #"tr": "starlang_charlm",
    "vi": "vlsp22_charlm",
    "zh-hans": "ctb-51_charlm",
}

# constituency models that exist but are not part of the language default
optional_constituency = {
    "tr": "starlang_charlm",
}

# an alternate tokenizer for languages which aren't trained from a base UD source
default_tokenizer = {
    "my": "alt",
}

# ideally we would have a less expensive model as the base model
#default_coref = {
#    "en": "ontonotes_roberta-large_finetuned",
#}

optional_coref = {
    "ca": "udcoref_xlm-roberta-lora",
    "cs": "udcoref_xlm-roberta-lora",
    "de": "udcoref_xlm-roberta-lora",
    "en": "udcoref_xlm-roberta-lora",
    "es": "udcoref_xlm-roberta-lora",
    "fr": "udcoref_xlm-roberta-lora",
    "he": "iahlt_xlm-roberta-lora",
    "hi": "deeph_muril-large-cased-lora",
    # UD Coref has both nb and nn datasets for Norwegian
    "nb": "udcoref_xlm-roberta-lora",
    "nn": "udcoref_xlm-roberta-lora",
    "pl": "udcoref_xlm-roberta-lora",
    "ru": "udcoref_xlm-roberta-lora",
    "ta": "kbc_muril-large-cased-lora",
}

""" default transformers to use for various languages we try to document why we choose a particular model in each case """
# language code -> HuggingFace model name; the comments before each entry
# record the experiments behind the choice
TRANSFORMERS = {
    # We tested three candidate AR models on POS, Depparse, and NER
    #
    # POS: padt dev set scores, AllTags
    # depparse: padt dev set scores, LAS
    # NER: dev scores on a random split of AQMAR, entity scores
    #
    # pos depparse ner
    # none (pt & charlm only) 94.08 83.49 84.19
    # asafaya/bert-base-arabic 95.10 84.96 85.98
    # aubmindlab/bert-base-arabertv2 95.33 85.28 84.93
    # aubmindlab/araelectra-base-discriminator 95.66 85.83 86.10
    "ar": "aubmindlab/araelectra-base-discriminator",

    # https://huggingface.co/Maltehb/danish-bert-botxo
    # contrary to normal expectations, this hurts F1
    # on a dev split by about 1 F1
    # "da": "Maltehb/danish-bert-botxo",
    #
    # the multilingual bert is a marginal improvement for conparse
    #
    # December 2022 update:
    # there are quite a few Danish transformers available on HuggingFace
    # here are the results of training a constituency parser with adadelta/adamw
    # on each of them:
    #
    # no bert 0.8245 0.8230
    # alexanderfalk/danbert-small-cased 0.8236 0.8286
    # Geotrend/distilbert-base-da-cased 0.8268 0.8306
    # sarnikowski/convbert-small-da-cased 0.8322 0.8341
    # bert-base-multilingual-cased 0.8341 0.8342
    # vesteinn/ScandiBERT-no-faroese 0.8373 0.8408
    # Maltehb/danish-bert-botxo 0.8383 0.8408
    # vesteinn/ScandiBERT 0.8421 0.8475
    #
    # Also, two models have token windows too short for use with the
    # Danish dataset:
    # jonfd/electra-small-nordic
    # Maltehb/aelaectra-danish-electra-small-cased
    #
    "da": "vesteinn/ScandiBERT",

    # As of April 2022, the bert models available have a weird
    # tokenizer issue where soft hyphen causes it to crash.
    # We attempt to compensate for that in the dev branch
    #
    # NER scores
    # model dev test
    # xlm-roberta-large 86.56 85.23
    # bert-base-german-cased 87.59 86.95
    # dbmdz/bert-base-german-cased 88.27 87.47
    # german-nlp-group/electra-base-german-uncased 88.60 87.09
    #
    # constituency scores w/ peft, March 2024 model, in-order
    # model dev test
    # xlm-roberta-base 95.17 93.34
    # xlm-roberta-large 95.86 94.46 (!!!)
    # bert-base 95.24 93.24
    # dbmdz/bert 95.32 93.33
    # german/electra 95.72 94.05
    #
    # POS scores
    # model dev test
    # None 88.65 87.28
    # xlm-roberta-large 89.21 88.11
    # bert-base 89.52 88.42
    # dbmdz/bert 89.67 88.54
    # german/electra 89.98 88.66
    #
    # depparse scores, LAS
    # model dev test
    # None 87.76 84.37
    # xlm-roberta-large 89.00 85.79
    # bert-base 88.72 85.40
    # dbmdz/bert 88.70 85.14
    # german/electra 89.21 86.06
    "de": "german-nlp-group/electra-base-german-uncased",

    # experiments on various forms of roberta & electra
    # https://huggingface.co/roberta-base
    # https://huggingface.co/roberta-large
    # https://huggingface.co/google/electra-small-discriminator
    # https://huggingface.co/google/electra-base-discriminator
    # https://huggingface.co/google/electra-large-discriminator
    #
    # experiments using the different models for POS tagging,
    # dev set, including WV and charlm, AllTags score:
    # roberta-base: 95.67
    # roberta-large: 95.98
    # electra-small: 95.31
    # electra-base: 95.90
    # electra-large: 96.01
    #
    # depparse scores, dev set, no finetuning, with WV and charlm
    # UAS LAS CLAS MLAS BLEX
    # roberta-base: 93.16 91.20 89.87 89.38 89.87
    # roberta-large: 93.47 91.56 90.13 89.71 90.13
    # electra-small: 92.17 90.02 88.25 87.66 88.25
    # electra-base: 93.42 91.44 90.10 89.67 90.10
    # electra-large: 94.07 92.17 90.99 90.53 90.99
    #
    # conparse scores, dev & test set, with WV and charlm
    # roberta_base: 96.05 95.60
    # roberta_large: 95.95 95.60
    # electra-small: 95.33 95.04
    # electra-base: 96.09 95.98
    # electra-large: 96.25 96.14
    #
    # conparse scores w/ finetune, dev & test set, with WV and charlm
    # roberta_base: 96.07 95.81
    # roberta_large: 96.37 96.41 (!!!)
    # electra-small: 95.62 95.36
    # electra-base: 96.21 95.94
    # electra-large: 96.40 96.32
    #
    "en": "google/electra-large-discriminator",

    # TODO need to test, possibly compare with others
    "es": "bertin-project/bertin-roberta-base-spanish",

    # NER scores for a couple Persian options:
    # none:
    # dev: 2022-04-23 01:44:53 INFO: fa_arman 79.46
    # test: 2022-04-23 01:45:03 INFO: fa_arman 80.06
    #
    # HooshvareLab/bert-fa-zwnj-base
    # dev: 2022-04-23 02:43:44 INFO: fa_arman 80.87
    # test: 2022-04-23 02:44:07 INFO: fa_arman 80.81
    #
    # HooshvareLab/roberta-fa-zwnj-base
    # dev: 2022-04-23 16:23:25 INFO: fa_arman 81.23
    # test: 2022-04-23 16:23:48 INFO: fa_arman 81.11
    #
    # HooshvareLab/bert-base-parsbert-uncased
    # dev: 2022-04-26 10:42:09 INFO: fa_arman 82.49
    # test: 2022-04-26 10:42:31 INFO: fa_arman 83.16
    "fa": 'HooshvareLab/bert-base-parsbert-uncased',

    # NER scores for a couple options:
    # none:
    # dev: 2022-03-04 INFO: fi_turku 83.45
    # test: 2022-03-04 INFO: fi_turku 86.25
    #
    # bert-base-multilingual-cased
    # dev: 2022-03-04 INFO: fi_turku 85.23
    # test: 2022-03-04 INFO: fi_turku 89.00
    #
    # TurkuNLP/bert-base-finnish-cased-v1:
    # dev: 2022-03-04 INFO: fi_turku 88.41
    # test: 2022-03-04 INFO: fi_turku 91.36
    "fi": "TurkuNLP/bert-base-finnish-cased-v1",

    # POS dev set tagging results for French:
    # No bert:
    # 98.60 100.00 98.55 98.04
    # dbmdz/electra-base-french-europeana-cased-discriminator
    # 98.70 100.00 98.69 98.24
    # benjamin/roberta-base-wechsel-french
    # 98.71 100.00 98.75 98.26
    # camembert/camembert-large
    # 98.75 100.00 98.75 98.30
    # camembert-base
    # 98.78 100.00 98.77 98.33
    #
    # GSD depparse dev set results for French:
    # No bert:
    # 95.83 94.52 91.34 91.10 91.34
    # camembert/camembert-large
    # 96.80 95.71 93.37 93.13 93.37
    # TODO: the rest of the chart
    "fr": "camembert/camembert-large",

    # Ancient Greek has a surprising number of transformers, considering
    # Model POS Depparse LAS
    # None 0.8812 0.7684
    # Microbert M 0.8883 0.7706
    # Microbert MX 0.8910 0.7755
    # Microbert MXP 0.8916 0.7742
    # Pranaydeeps Bert 0.9139 0.7987
    "grc": "pranaydeeps/Ancient-Greek-BERT",

    # a couple possibilities to experiment with for Hebrew
    # dev scores for POS and depparse
    # https://huggingface.co/imvladikon/alephbertgimmel-base-512
    # UPOS XPOS UFeats AllTags
    # 97.25 97.25 92.84 91.81
    # UAS LAS CLAS MLAS BLEX
    # 94.42 92.47 89.49 88.82 89.49
    #
    # https://huggingface.co/onlplab/alephbert-base
    # UPOS XPOS UFeats AllTags
    # 97.37 97.37 92.50 91.55
    # UAS LAS CLAS MLAS BLEX
    # 94.06 92.12 88.80 88.13 88.80
    #
    # https://huggingface.co/avichr/heBERT
    # UPOS XPOS UFeats AllTags
    # 97.09 97.09 92.36 91.28
    # UAS LAS CLAS MLAS BLEX
    # 94.29 92.30 88.99 88.38 88.99
    "he": "imvladikon/alephbertgimmel-base-512",

    # can also experiment with xlm-roberta
    # on a coref dataset from IITH, span F1:
    # dev test
    # xlm-roberta-large 0.63635 0.66579
    # muril-large 0.65369 0.68290
    "hi": "google/muril-large-cased",

    # https://huggingface.co/xlm-roberta-base
    # Scores by entity for armtdp NER on 18 labels:
    # no bert : 86.68
    # xlm-roberta-base : 89.31
    "hy": "xlm-roberta-base",

    # Indonesian POS experiments: dev set of GSD
    # python3 stanza/utils/training/run_pos.py id_gsd --no_bert
    # python3 stanza/utils/training/run_pos.py id_gsd --bert_model ...
    # also ran on the ICON constituency dataset
    # model POS CON
    # no_bert 89.95 84.74
    # flax-community/indonesian-roberta-large 89.78 (!) xxx
    # flax-community/indonesian-roberta-base 90.14 xxx
    # indobenchmark/indobert-base-p2 90.09
    # indobenchmark/indobert-base-p1 90.14
    # indobenchmark/indobert-large-p1 90.19
    # indolem/indobert-base-uncased 90.21 88.60
    # cahya/bert-base-indonesian-1.5G 90.32 88.15
    # cahya/roberta-base-indonesian-1.5G 90.40 87.27
    "id": "indolem/indobert-base-uncased",

    # from https://github.com/idb-ita/GilBERTo
    # annoyingly, it doesn't handle cased text
    # supposedly there is an argument "do_lower_case"
    # but that still leaves a lot of unk tokens
    # "it": "idb-ita/gilberto-uncased-from-camembert",
    #
    # from https://github.com/musixmatchresearch/umberto
    # on NER, this gets 88.37 dev and 91.02 test
    # another option is dbmdz/bert-base-italian-cased,
    # which gets 87.27 dev and 90.32 test
    #
    # in-order constituency parser on the VIT dev set:
    # dbmdz/bert-base-italian-cased 0.8079
    # dbmdz/bert-base-italian-xxl-cased: 0.8195
    # Musixmatch/umberto-commoncrawl-cased-v1: 0.8256
    # dbmdz/electra-base-italian-xxl-cased-discriminator: 0.8314
    #
    # FBK NER dev set:
    # dbmdz/bert-base-italian-cased: 87.76
    # Musixmatch/umberto-commoncrawl-cased-v1: 88.62
    # dbmdz/bert-base-italian-xxl-cased: 88.84
    # dbmdz/electra-base-italian-xxl-cased-discriminator: 89.91
    #
    # combined UD POS dev set: UPOS XPOS UFeats AllTags
    # dbmdz/bert-base-italian-cased: 98.62 98.53 98.06 97.49
    # dbmdz/bert-base-italian-xxl-cased: 98.61 98.54 98.07 97.58
    # dbmdz/electra-base-italian-xxl-cased-discriminator: 98.64 98.54 98.14 97.61
    # Musixmatch/umberto-commoncrawl-cased-v1: 98.56 98.45 98.13 97.62
    "it": "dbmdz/electra-base-italian-xxl-cased-discriminator",

    # for Japanese
    # there are others that would also work,
    # but they require different tokenizers instead of being
    # plug & play
    #
    # Constituency scores on ALT (in-order)
    # no bert: 90.68 dev, 91.40 test
    # rinna: 91.54 dev, 91.89 test
    "ja": "rinna/japanese-roberta-base",

    # could also try:
    # l3cube-pune/marathi-bert-v2
    # or
    # https://huggingface.co/l3cube-pune/hindi-marathi-dev-roberta
    #
l3cube-pune/hindi-marathi-dev-roberta # # depparse ufal dev scores: # no transformer 74.89 63.70 57.43 53.01 57.43 # l3cube-pune/marathi-roberta 76.48 66.21 61.20 57.60 61.20 "mr": "l3cube-pune/marathi-roberta", "or": "google/muril-large-cased", # https://huggingface.co/allegro/herbert-base-cased # Scores by entity on the NKJP NER task: # no bert (dev/test): 88.64/88.75 # herbert-base-cased (dev/test): 91.48/91.02, # herbert-large-cased (dev/test): 92.25/91.62 # sdadas/polish-roberta-large-v2 (dev/test): 92.66/91.22 "pl": "allegro/herbert-base-cased", # experiments on the cintil conparse dataset # ran a variety of transformer settings # found the following dev set scores after 400 iterations: # Geotrend/distilbert-base-pt-cased : not plug & play # no bert: 0.9082 # xlm-roberta-base: 0.9109 # xlm-roberta-large: 0.9254 # adalbertojunior/distilbert-portuguese-cased: 0.9300 # neuralmind/bert-base-portuguese-cased: 0.9307 # neuralmind/bert-large-portuguese-cased: 0.9343 "pt": "neuralmind/bert-large-portuguese-cased", # hope is actually to build our own using a large text collection "sd": "google/muril-large-cased", # Tamil options: quite a few, need to run a bunch of experiments # dev pos dev depparse las # no transformer 82.82 69.12 # ai4bharat/indic-bert 82.98 70.47 # lgessler/microbert-tamil-mxp 83.21 69.28 # monsoon-nlp/tamillion 83.37 69.28 # l3cube-pune/tamil-bert 85.27 72.53 # d42kw01f/Tamil-RoBERTa 85.59 70.55 # google/muril-base-cased 85.67 72.68 # google/muril-large-cased 86.30 72.45 # # should also consider xlm-roberta-large # updated on UD 2.16 data: dev pos ner # google/muril-large-cased 86.86 65.08 # xlm-roberta-large 66.28 "ta": "google/muril-large-cased", "te": "google/muril-large-cased", # https://huggingface.co/airesearch/wangchanberta-base-att-spm-uncased # this is clearly better than no transformer on a couple datasets: # # TUD dev upos TUD dev depparse LAS # no transformer 91.26 73.57 # wangchanberta 92.21 76.65 "th": 
"airesearch/wangchanberta-base-att-spm-uncased", # https://huggingface.co/dbmdz/bert-base-turkish-128k-cased # helps the Turkish model quite a bit "tr": "dbmdz/bert-base-turkish-128k-cased", "ur": "google/muril-large-cased", # from https://github.com/VinAIResearch/PhoBERT # "vi": "vinai/phobert-base", # using 6 or 7 layers of phobert-large is slightly # more effective for constituency parsing than # using 4 layers of phobert-base # ... going beyond 4 layers of phobert-base # does not help the scores "vi": "vinai/phobert-large", # https://github.com/ymcui/Chinese-BERT-wwm # there's also hfl/chinese-roberta-wwm-ext-large # or hfl/chinese-electra-base-discriminator # or hfl/chinese-electra-180g-large-discriminator, # which works better than the below roberta on constituency # "zh-hans": "hfl/chinese-roberta-wwm-ext", # conparse dev scores (averaged over 5): # google bert: 0.9422 # hfl bert: 0.9469 # hfl roberta: 0.9459 # hfl electra: 0.9515 # hfl macbert: 0.9530 # There is also a ShannonAI model, but our current codebase is # somehow not compatible # further comparing HFL: # POS dev Depparse dev LAS NER dev # HFL Electra 96.90 85.66 77.90 # HFL Macbert 96.53 84.72 78.46 # "zh-hans": "hfl/chinese-macbert-large", "zh-hans": "hfl/chinese-electra-180g-large-discriminator", } TRANSFORMER_LAYERS = { # not clear what the best number is without more experiments, # but more than 4 is working better than just 4 "vi": 7, } TRANSFORMER_NICKNAMES = { # ar "asafaya/bert-base-arabic": "asafaya-bert", "aubmindlab/araelectra-base-discriminator": "aubmind-electra", "aubmindlab/bert-base-arabertv2": "aubmind-bert", # da "vesteinn/ScandiBERT": "scandibert", # de "bert-base-german-cased": "bert-base-german-cased", "dbmdz/bert-base-german-cased": "dbmdz-bert-german-cased", "german-nlp-group/electra-base-german-uncased": "german-nlp-electra", # en "bert-base-multilingual-cased": "mbert", "xlm-roberta-large": "xlm-roberta-large", "google/electra-large-discriminator": "electra-large", 
"microsoft/deberta-v3-large": "deberta-v3-large", "princeton-nlp/Sheared-LLaMA-1.3B": "sheared-llama-1b3", # es "bertin-project/bertin-roberta-base-spanish": "bertin-roberta", # fa "HooshvareLab/bert-base-parsbert-uncased": "parsbert", # fi "TurkuNLP/bert-base-finnish-cased-v1": "bert", # fr "benjamin/roberta-base-wechsel-french": "wechsel-roberta", "camembert-base": "camembert-base", "camembert/camembert-large": "camembert-large", "dbmdz/electra-base-french-europeana-cased-discriminator": "dbmdz-electra", # grc "pranaydeeps/Ancient-Greek-BERT": "grc-pranaydeeps", "lgessler/microbert-ancient-greek-m": "grc-microbert-m", "lgessler/microbert-ancient-greek-mx": "grc-microbert-mx", "lgessler/microbert-ancient-greek-mxp": "grc-microbert-mxp", "altsoph/bert-base-ancientgreek-uncased": "grc-altsoph", # he "HeNLP/HeRo": "hero-roberta", "imvladikon/alephbertgimmel-base-512": "alephbertgimmel", "onlplab/alephbert-base": "alephbert", # hy "xlm-roberta-base": "xlm-roberta-base", # id "indolem/indobert-base-uncased": "indobert", "indobenchmark/indobert-large-p1": "indobenchmark-large-p1", "indobenchmark/indobert-base-p1": "indobenchmark-base-p1", "indobenchmark/indobert-lite-large-p1": "indobenchmark-lite-large-p1", "indobenchmark/indobert-lite-base-p1": "indobenchmark-lite-base-p1", "indobenchmark/indobert-large-p2": "indobenchmark-large-p2", "indobenchmark/indobert-base-p2": "indobenchmark-base-p2", "indobenchmark/indobert-lite-large-p2": "indobenchmark-lite-large-p2", "indobenchmark/indobert-lite-base-p2": "indobenchmark-lite-base-p2", # it "dbmdz/electra-base-italian-xxl-cased-discriminator": "electra", # ja "rinna/japanese-roberta-base": "rinna-roberta", # mr "l3cube-pune/marathi-roberta": "l3cube-marathi-roberta", # pl "allegro/herbert-base-cased": "herbert", # pt "neuralmind/bert-large-portuguese-cased": "bertimbau", # ta: tamil "monsoon-nlp/tamillion": "tamillion", "lgessler/microbert-tamil-m": "ta-microbert-m", "lgessler/microbert-tamil-mxp": "ta-microbert-mxp", 
"l3cube-pune/tamil-bert": "l3cube-tamil-bert", "d42kw01f/Tamil-RoBERTa": "ta-d42kw01f-roberta", # th "airesearch/wangchanberta-base-att-spm-uncased": "wangchanberta", # tr "dbmdz/bert-base-turkish-128k-cased": "bert", # vi "vinai/phobert-base": "phobert-base", "vinai/phobert-large": "phobert-large", # zh "google-bert/bert-base-chinese": "google-bert-chinese", "hfl/chinese-bert-wwm": "hfl-bert-chinese", "hfl/chinese-macbert-large": "hfl-macbert-chinese", "hfl/chinese-roberta-wwm-ext": "hfl-roberta-chinese", "hfl/chinese-electra-180g-large-discriminator": "electra-large", "ShannonAI/ChineseBERT-base": "shannonai-chinese-bert", # multi-lingual Indic "ai4bharat/indic-bert": "indic-bert", "google/muril-base-cased": "muril-base-cased", "google/muril-large-cased": "muril-large-cased", # multi-lingual "FacebookAI/xlm-roberta-large": "xlm-roberta-large", } def known_nicknames(): """ Return a list of all the transformer nicknames We return a list so that we can sort them in decreasing key length """ nicknames = list(value for key, value in TRANSFORMER_NICKNAMES.items()) # previously unspecific transformers get "transformer" as the nickname nicknames.append("transformer") nicknames = sorted(nicknames, key=lambda x: -len(x)) return nicknames ================================================ FILE: stanza/resources/installation.py ================================================ """ Functions for setting up the environments. 
""" import os import logging import zipfile import shutil from stanza.resources.common import USER_CACHE_DIR, request_file, unzip, \ get_root_from_zipfile, set_logging_level logger = logging.getLogger('stanza') DEFAULT_CORENLP_MODEL_URL = os.getenv( 'CORENLP_MODEL_URL', 'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar' ) BACKUP_CORENLP_MODEL_URL = "http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar" DEFAULT_CORENLP_URL = os.getenv( 'CORENLP_MODEL_URL', 'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip' ) DEFAULT_CORENLP_DIR = os.getenv( 'CORENLP_HOME', os.path.join(USER_CACHE_DIR, 'corenlp') ) AVAILABLE_MODELS = set(['arabic', 'chinese', 'english-extra', 'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish']) def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None, force=True): """ A automatic way to download the CoreNLP models. Args: model: the name of the model, can be one of 'arabic', 'chinese', 'english', 'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish' version: the version of the model dir: the directory to download CoreNLP model into; alternatively can be set up with environment variable $CORENLP_HOME url: The link to download CoreNLP models. It will need {model} and either {version} or {tag} to properly format the URL logging_level: logging level to use during installation force: Download model anyway, no matter model file exists or not """ dir = os.path.expanduser(dir) if not model or not version: raise ValueError( "Both model and model version should be specified." ) logger.info(f"Downloading {model} models (version {version}) into directory {dir}") model = model.strip().lower() if model not in AVAILABLE_MODELS: raise KeyError( f'{model} is currently not supported. ' f'Must be one of: {list(AVAILABLE_MODELS)}.' 
) # for example: # https://huggingface.co/stanfordnlp/CoreNLP/resolve/v4.2.2/stanford-corenlp-models-french.jar tag = version if version == 'main' else 'v' + version download_url = url.format(tag=tag, model=model, version=version) model_path = os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar') if os.path.exists(model_path) and not force: logger.warn( f"Model file {model_path} already exists. " f"Please download this model to a new directory.") return try: request_file( download_url, model_path, proxies ) except (KeyboardInterrupt, SystemExit): raise except Exception as e: raise RuntimeError( "Downloading CoreNLP model file failed. " "Please try manual downloading at: https://stanfordnlp.github.io/CoreNLP/." ) from e def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version="main"): """ A fully automatic way to install and setting up the CoreNLP library to use the client functionality. Args: dir: the directory to download CoreNLP model into; alternatively can be set up with environment variable $CORENLP_HOME url: The link to download CoreNLP models Needs a {version} or {tag} parameter to specify the version logging_level: logging level to use during installation """ dir = os.path.expanduser(dir) set_logging_level(logging_level=logging_level, verbose=None) if os.path.exists(dir) and len(os.listdir(dir)) > 0: logger.warn( f"Directory {dir} already exists. " f"Please install CoreNLP to a new directory.") return logger.info(f"Installing CoreNLP package into {dir}") # First download the URL package logger.debug(f"Download to destination file: {os.path.join(dir, 'corenlp.zip')}") tag = version if version == 'main' else 'v' + version url = url.format(version=version, tag=tag) try: request_file(url, os.path.join(dir, 'corenlp.zip'), proxies) except (KeyboardInterrupt, SystemExit): raise except Exception as e: raise RuntimeError( "Downloading CoreNLP zip file failed. 
" "Please try manual installation: https://stanfordnlp.github.io/CoreNLP/." ) from e # Unzip corenlp into dir logger.debug("Unzipping downloaded zip file...") unzip(dir, 'corenlp.zip') # By default CoreNLP will be unzipped into a version-dependent folder, # e.g., stanford-corenlp-4.0.0. We need some hack around that and move # files back into our designated folder logger.debug(f"Moving files into the designated folder at: {dir}") corenlp_dirname = get_root_from_zipfile(os.path.join(dir, 'corenlp.zip')) corenlp_dirname = os.path.join(dir, corenlp_dirname) for f in os.listdir(corenlp_dirname): shutil.move(os.path.join(corenlp_dirname, f), dir) # Remove original zip and folder logger.debug("Removing downloaded zip file...") os.remove(os.path.join(dir, 'corenlp.zip')) shutil.rmtree(corenlp_dirname) # Warn user to set up env if dir != DEFAULT_CORENLP_DIR: logger.warning( f"For customized installation location, please set the `CORENLP_HOME` " f"environment variable to the location of the installation. " f"In Unix, this is done with `export CORENLP_HOME={dir}`.") ================================================ FILE: stanza/resources/prepare_resources.py ================================================ """ Converts a directory of models organized by type into a directory organized by language. Also produces the resources.json file. 
For example, on the cluster, you can do this: python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0 > resources.out 2>&1 nlprun -a stanza-1.2 -q john "python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0" -o resources.out """ import argparse from collections import defaultdict import json import os from pathlib import Path import hashlib import shutil import zipfile from stanza import __resources_version__ from stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters, extra_lang_to_lcodes from stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES from stanza.resources.default_packages import * from stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/models/current-models-%s" % __resources_version__, help='Input dir for various models. Defaults to the recommended home on the nlp cluster') parser.add_argument('--output_dir', type=str, default="/u/nlp/software/stanza/models/%s" % __resources_version__, help='Output dir for various models.') parser.add_argument('--packages_only', action='store_true', default=False, help='Only build the package maps instead of rebuilding everything') parser.add_argument('--lang', type=str, default=None, help='Only process this language or a comma-separated list of languages. If left blank, will prepare all languages. 
To use this argument, a previous prepared resources with all of the languages is necessary.') args = parser.parse_args() args.input_dir = os.path.abspath(args.input_dir) args.output_dir = os.path.abspath(args.output_dir) if args.lang is not None: args.lang = ",".join(args.lang.strip().split()) return args allowed_empty_languages = [ # only tokenize and NER for Myanmar right now (soon...) "my", # currently only an NER, not even a tokenizer, for Oriya "or", ] # map processor name to file ending # the order of this dict determines the order in which default.zip files are built # changing it will necessitate rebuilding all of the default.zip files # not a disaster, but it would involve a bunch of uploading processor_to_ending = { "tokenize": "tokenizer", "mwt": "mwt_expander", "lemma": "lemmatizer", "pos": "tagger", "depparse": "parser", "pretrain": "pretrain", "ner": "nertagger", "forward_charlm": "forward_charlm", "backward_charlm": "backward_charlm", "sentiment": "sentiment", "constituency": "constituency", "coref": "coref", "langid": "langid", } ending_to_processor = {j: i for i, j in processor_to_ending.items()} PROCESSORS = list(processor_to_ending.keys()) def ensure_dir(dir): Path(dir).mkdir(parents=True, exist_ok=True) def copy_file(src, dst): ensure_dir(Path(dst).parent) shutil.copy2(src, dst) def get_md5(path): data = open(path, 'rb').read() return hashlib.md5(data).hexdigest() def split_model_name(model): """ Split model names by _ Takes into account packages with _ and processor types with _ """ model = model[:-3].replace('.', '_') # sort by key length so that nertagger is checked before tagger, for example for processor in sorted(ending_to_processor.keys(), key=lambda x: -len(x)): if model.endswith(processor): model = model[:-(len(processor)+1)] processor = ending_to_processor[processor] break else: raise AssertionError(f"Could not find a processor type in {model}") lang, package = model.split('_', 1) return lang, package, processor def 
split_package(package, default_use_charlm=True): if package.endswith("_finetuned"): package = package[:-10] if package.endswith("_nopretrain"): package = package[:-11] return package, False, False if package.endswith("_nocharlm"): package = package[:-9] return package, True, False if package.endswith("_charlm"): package = package[:-7] return package, True, True underscore = package.rfind("_") if underscore >= 0: # +1 to skip the underscore nickname = package[underscore+1:] if nickname in known_nicknames(): return package[:underscore], True, True # guess it was a model which wasn't built with the new naming convention of putting the pretrain type at the end # assume WV and charlm... if the language / package doesn't allow for one, that should be caught later return package, True, default_use_charlm def get_pretrain_package(lang, package, model_pretrains, default_pretrains): package, uses_pretrain, _ = split_package(package) if not uses_pretrain or lang in no_pretrain_languages: return None elif model_pretrains is not None and lang in model_pretrains and package in model_pretrains[lang]: return model_pretrains[lang][package] elif lang in default_pretrains: return default_pretrains[lang] raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package)) def get_charlm_package(lang, package, model_charlms, default_charlms, default_use_charlm=True): package, _, uses_charlm = split_package(package, default_use_charlm) if not uses_charlm: return None if model_charlms is not None and lang in model_charlms and package in model_charlms[lang]: return model_charlms[lang][package] else: return default_charlms.get(lang, None) def get_con_dependencies(lang, package): # so far, this invariant is true: # constituency models use the default pretrain and charlm for the language # sometimes there is no charlm for a language that has constituency, though pretrain_package = get_pretrain_package(lang, package, None, default_pretrains) dependencies = [{'model': 
'pretrain', 'package': pretrain_package}] charlm_package = default_charlms.get(lang, None) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) return dependencies def get_pos_charlm_package(lang, package): return get_charlm_package(lang, package, pos_charlms, default_charlms) def get_pos_dependencies(lang, package): dependencies = [] pretrain_package = get_pretrain_package(lang, package, pos_pretrains, default_pretrains) if pretrain_package is not None: dependencies.append({'model': 'pretrain', 'package': pretrain_package}) charlm_package = get_pos_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) return dependencies def get_lemma_pretrain_package(lang, package): package, uses_pretrain, uses_charlm = split_package(package) if not uses_pretrain: return None if not uses_charlm: # currently the contextual lemma classifier is only active # for the charlm lemmatizers return None if "%s_%s" % (lang, package) not in LEMMA_CLASSIFIER_DATASETS: return None return get_pretrain_package(lang, package, {}, default_pretrains) def get_lemma_charlm_package(lang, package): return get_charlm_package(lang, package, lemma_charlms, default_charlms) def get_lemma_dependencies(lang, package): dependencies = [] pretrain_package = get_lemma_pretrain_package(lang, package) if pretrain_package is not None: dependencies.append({'model': 'pretrain', 'package': pretrain_package}) charlm_package = get_lemma_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) return dependencies def get_tokenizer_charlm_package(lang, package): return 
get_charlm_package(lang, package, tokenizer_charlms, default_charlms, default_use_charlm=False) def get_tokenizer_dependencies(lang, package): dependencies = [] charlm_package = get_tokenizer_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) return dependencies def get_depparse_charlm_package(lang, package): return get_charlm_package(lang, package, depparse_charlms, default_charlms) def get_depparse_dependencies(lang, package): dependencies = [] pretrain_package = get_pretrain_package(lang, package, depparse_pretrains, default_pretrains) if pretrain_package is not None: dependencies.append({'model': 'pretrain', 'package': pretrain_package}) charlm_package = get_depparse_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) return dependencies def get_ner_charlm_package(lang, package): return get_charlm_package(lang, package, ner_charlms, default_charlms) def get_ner_pretrain_package(lang, package): return get_pretrain_package(lang, package, ner_pretrains, default_pretrains) def get_ner_dependencies(lang, package): dependencies = [] pretrain_package = get_ner_pretrain_package(lang, package) if pretrain_package is not None: dependencies.append({'model': 'pretrain', 'package': pretrain_package}) charlm_package = get_ner_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) return dependencies def get_sentiment_dependencies(lang, package): """ Return a list of dependencies for the sentiment model Generally this will be pretrain, forward & backward charlm So far, this invariant is true: sentiment models use the default pretrain for the language also, they all use the 
default charlm for a language """ pretrain_package = get_pretrain_package(lang, package, None, default_pretrains) dependencies = [{'model': 'pretrain', 'package': pretrain_package}] charlm_package = default_charlms.get(lang, None) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) return dependencies def get_dependencies(processor, lang, package): """ Get the dependencies for a particular lang/package based on the package name The package can include descriptors such as _nopretrain, _nocharlm, _charlm which inform whether or not this particular model uses charlm or pretrain """ if processor == 'depparse': return get_depparse_dependencies(lang, package) elif processor == 'lemma': return get_lemma_dependencies(lang, package) elif processor == 'pos': return get_pos_dependencies(lang, package) elif processor == 'ner': return get_ner_dependencies(lang, package) elif processor == 'sentiment': return get_sentiment_dependencies(lang, package) elif processor == 'constituency': return get_con_dependencies(lang, package) elif processor == 'tokenize': return get_tokenizer_dependencies(lang, package) return {} def process_dirs(args): dirs = sorted(os.listdir(args.input_dir)) resources = {} if args.lang: resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) # this one language gets overridden # if this is not done, and we reuse the old resources, # any models which were deleted will still be in the resources for lang in args.lang.split(","): resources[lang] = {} for model_dir in dirs: print(f"Processing models in {model_dir}") models = sorted(os.listdir(os.path.join(args.input_dir, model_dir))) for model in tqdm(models): if not model.endswith('.pt'): continue # get processor lang, package, processor = split_model_name(model) if args.lang and lang not in args.lang.split(","): continue # copy file input_path = 
os.path.join(args.input_dir, model_dir, model) output_path = os.path.join(args.output_dir, lang, "models", processor, package + '.pt') copy_file(input_path, output_path) # maintain md5 md5 = get_md5(output_path) # maintain dependencies dependencies = get_dependencies(processor, lang, package) # maintain resources if lang not in resources: resources[lang] = {} if processor not in resources[lang]: resources[lang][processor] = {} if dependencies: resources[lang][processor][package] = {'md5': md5, 'dependencies': dependencies} else: resources[lang][processor][package] = {'md5': md5} print("Processed initial model directories. Writing preliminary resources.json") json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2) def get_default_pos_package(lang, ud_package): charlm_package = get_pos_charlm_package(lang, ud_package) if charlm_package is not None: return ud_package + "_charlm" if lang in no_pretrain_languages: return ud_package + "_nopretrain" return ud_package + "_nocharlm" def get_default_depparse_package(lang, ud_package): charlm_package = get_depparse_charlm_package(lang, ud_package) if charlm_package is not None: return ud_package + "_charlm" if lang in no_pretrain_languages: return ud_package + "_nopretrain" return ud_package + "_nocharlm" def process_default_zips(args): resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) for lang in resources: # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json if lang == 'url': continue if 'alias' in resources[lang]: continue if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()): continue if lang in allowed_empty_languages and lang not in default_treebanks: continue if lang not in default_treebanks: raise AssertionError(f'{lang} not in default treebanks!!!') if args.lang and lang not in args.lang.split(","): continue print(f'Preparing default models for language 
{lang}') models_needed = defaultdict(set) packages = resources[lang][PACKAGES]["default"] for processor, package in packages.items(): if processor == 'lemma' and package == 'identity': continue if processor == 'optional': continue models_needed[processor].add(package) dependencies = get_dependencies(processor, lang, package) for dependency in dependencies: models_needed[dependency['model']].add(dependency['package']) model_files = [] for processor in PROCESSORS: if processor in models_needed: for package in sorted(models_needed[processor]): filename = os.path.join(args.output_dir, lang, "models", processor, package + '.pt') if os.path.exists(filename): print(" Model {} package {}: file {}".format(processor, package, filename)) model_files.append((filename, processor, package)) else: raise FileNotFoundError(f"Processor {processor} package {package} needed for {lang} but cannot be found at {filename}") with zipfile.ZipFile(os.path.join(args.output_dir, lang, 'models', 'default.zip'), 'w', zipfile.ZIP_DEFLATED) as zipf: for filename, processor, package in model_files: zipf.write(filename=filename, arcname=os.path.join(processor, package + '.pt')) default_md5 = get_md5(os.path.join(args.output_dir, lang, 'models', 'default.zip')) resources[lang]['default_md5'] = default_md5 print("Processed default model zips. 
Writing resources.json") json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2) def get_default_processors(resources, lang): """ Build a default package for this language Will add each of pos, lemma, depparse, etc if those are available Uses the existing models scraped from the language directories into resources.json, as relevant """ if lang == "multilingual": return {"langid": "ud"} default_package = default_treebanks[lang] default_processors = {} if lang in default_tokenizer: default_processors['tokenize'] = default_tokenizer[lang] else: tokenize_package = default_package if tokenize_package not in resources[lang]['tokenize']: tokenize_package = tokenize_package + "_nocharlm" if tokenize_package not in resources[lang]['tokenize']: raise AssertionError("Can't find a tokenizer package for %s! Tried %s and %s" % (lang, default_package, tokenize_package)) default_processors['tokenize'] = tokenize_package if 'mwt' in resources[lang] and default_package in resources[lang]['mwt']: # if this doesn't happen, we just skip MWT default_processors['mwt'] = default_package if 'lemma' in resources[lang]: expected_lemma = default_package + "_nocharlm" if expected_lemma in resources[lang]['lemma']: default_processors['lemma'] = expected_lemma else: expected_lemma = default_package + "_charlm" if expected_lemma in resources[lang]['lemma']: default_processors['lemma'] = expected_lemma print("WARNING: nocharlm lemmatizer for %s model does not exist, but %s does" % (default_package, expected_lemma)) elif lang not in allowed_empty_languages: default_processors['lemma'] = 'identity' if 'pos' in resources[lang]: default_processors['pos'] = get_default_pos_package(lang, default_package) if default_processors['pos'] not in resources[lang]['pos']: raise AssertionError("Expected POS model not in resources: %s" % default_processors['pos']) elif lang not in allowed_empty_languages: raise AssertionError("Expected to find POS models for language %s" % lang) 
if 'depparse' in resources[lang]: default_processors['depparse'] = get_default_depparse_package(lang, default_package) if default_processors['depparse'] not in resources[lang]['depparse']: raise AssertionError("Expected depparse model not in resources: %s" % default_processors['depparse']) elif lang not in allowed_empty_languages: raise AssertionError("Expected to find depparse models for language %s" % lang) if lang in default_ners: default_processors['ner'] = default_ners[lang] if lang in default_sentiment: default_processors['sentiment'] = default_sentiment[lang] if lang in default_constituency: default_processors['constituency'] = default_constituency[lang] optional = get_default_optional_processors(resources, lang) if optional: default_processors['optional'] = optional return default_processors def get_default_optional_processors(resources, lang): optional_processors = {} if lang in optional_constituency: optional_processors['constituency'] = optional_constituency[lang] if lang in optional_coref: optional_processors['coref'] = optional_coref[lang] return optional_processors def update_processor_add_transformer(resources, lang, current_processors, processor, transformer): if processor not in current_processors: return new_model = current_processors[processor].replace('_charlm', "_" + transformer).replace('_nocharlm', "_" + transformer) if new_model in resources[lang][processor]: current_processors[processor] = new_model else: print("WARNING: wanted to use %s for %s accurate %s, but that model does not exist" % (new_model, lang, processor)) def get_default_accurate(resources, lang): """ A package that, if available, uses charlm and transformer models for each processor """ default_processors = get_default_processors(resources, lang) tokenizer_model = default_processors['tokenize'] if tokenizer_model.endswith('_nocharlm'): tokenizer_model = tokenizer_model.replace('_nocharlm', '_charlm') elif 'charlm' not in tokenizer_model: tokenizer_model = tokenizer_model + 
'_charlm' if tokenizer_model.endswith('_charlm') and tokenizer_model in resources[lang]['tokenize']: default_processors['tokenize'] = tokenizer_model print("TOKENIZE found a charlm version %s for %s default_accurate" % (tokenizer_model, lang)) if 'lemma' in default_processors and default_processors['lemma'] != 'identity': lemma_model = default_processors['lemma'] lemma_model = lemma_model.replace('_nocharlm', '_charlm') charlm_package = get_lemma_charlm_package(lang, lemma_model) if charlm_package is not None: if lemma_model in resources[lang]['lemma']: default_processors['lemma'] = lemma_model else: print("WARNING: wanted to use %s for %s default_accurate lemma, but that model does not exist" % (lemma_model, lang)) transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None) if transformer is not None: for processor in ('pos', 'depparse', 'constituency', 'sentiment'): update_processor_add_transformer(resources, lang, default_processors, processor, transformer) if 'ner' in default_processors and (default_processors['ner'].endswith("_charlm") or default_processors['ner'].endswith("_nocharlm")): update_processor_add_transformer(resources, lang, default_processors, "ner", transformer) optional = get_optional_accurate(resources, lang) if optional: default_processors['optional'] = optional return default_processors def get_optional_accurate(resources, lang): optional_processors = get_default_optional_processors(resources, lang) transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None) if transformer is not None: for processor in ('pos', 'depparse', 'constituency', 'sentiment'): update_processor_add_transformer(resources, lang, optional_processors, processor, transformer) if lang in optional_coref: optional_processors['coref'] = optional_coref[lang] return optional_processors def get_default_fast(resources, lang): """ Build a packages entry which only has the nocharlm models Will make it easy for people to use the lower tier of models We 
do this by building the same default package as normal, then switching everything out for the lower tier model when possible. We also remove constituency, as it is super slow. Note that in the case of a language which doesn't have a charlm, that means we wind up building the same for default and default_nocharlm """ default_processors = get_default_processors(resources, lang) # this is a slow model and we don't have non-charlm versions of it yet if 'constituency' in default_processors: default_processors.pop('constituency') for processor, model in default_processors.items(): if "_charlm" in model: nocharlm = model.replace("_charlm", "_nocharlm") if nocharlm not in resources[lang][processor]: print("WARNING: wanted to use %s for %s default_fast processor %s, but that model does not exist" % (nocharlm, lang, processor)) else: default_processors[processor] = nocharlm return default_processors def process_packages(args): """ Build a package for a language's default processors and all of the treebanks specifically used for that language """ resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) for lang in resources: # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json if lang == 'url': continue if 'alias' in resources[lang]: continue if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()): continue if lang in allowed_empty_languages and lang not in default_treebanks: continue if lang not in default_treebanks: raise AssertionError(f'{lang} not in default treebanks!!!') if args.lang and lang not in args.lang.split(","): continue default_processors = get_default_processors(resources, lang) # TODO: eventually we can remove default_processors # For now, we want to keep this so that v1.5.1 is compatible # with the next iteration of resources files resources[lang]['default_processors'] = default_processors resources[lang][PACKAGES] = {} 
resources[lang][PACKAGES]['default'] = default_processors if lang not in no_pretrain_languages and lang != "multilingual": default_fast = get_default_fast(resources, lang) resources[lang][PACKAGES]['default_fast'] = default_fast default_accurate = get_default_accurate(resources, lang) resources[lang][PACKAGES]['default_accurate'] = default_accurate # Now we loop over each of the tokenizers for this language # ... we use this as a proxy for the available UD treebanks # This loop also catches things such as "craft" which are # included treebanks that aren't UD # We then create a package in the packages dict for each of those treebanks if 'tokenize' in resources[lang]: for package in resources[lang]['tokenize']: package, _, _ = split_package(package) if package in resources[lang][PACKAGES]: # can happen in the case of a _nocharlm and _charlm version of the tokenizer continue processors = {} # TODO: when we rebuild all the models, make all the tokenizers say _nocharlm if package in resources[lang]['tokenize']: processors["tokenize"] = package elif package + "_nocharlm" in resources[lang]['tokenize']: processors["tokenize"] = package + "_nocharlm" else: raise AssertionError("Should have found a tokenizer for lang %s package %s" % (lang, package)) if "mwt" in resources[lang] and package in resources[lang]["mwt"]: processors["mwt"] = package if "pos" in resources[lang]: if package + "_charlm" in resources[lang]["pos"]: processors["pos"] = package + "_charlm" elif package + "_nocharlm" in resources[lang]["pos"]: processors["pos"] = package + "_nocharlm" if "lemma" in resources[lang] and "pos" in processors: lemma_package = package + "_nocharlm" if lemma_package in resources[lang]["lemma"]: processors["lemma"] = lemma_package else: lemma_package = package + "_charlm" if lemma_package in resources[lang]['lemma']: processors['lemma'] = lemma_package print("WARNING: nocharlm lemmatizer for %s model does not exist, but %s does" % (package, lemma_package)) if "depparse" in 
resources[lang] and "pos" in processors: depparse_package = None if package + "_charlm" in resources[lang]["depparse"]: depparse_package = package + "_charlm" elif package + "_nocharlm" in resources[lang]["depparse"]: depparse_package = package + "_nocharlm" # we want to set the lemma first if it's identity # THEN set the depparse if depparse_package is not None: if "lemma" not in processors: processors["lemma"] = "identity" processors["depparse"] = depparse_package resources[lang][PACKAGES][package] = processors print("Processed packages. Writing resources.json") json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2) def process_lcode(args): resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) resources_new = {} resources_new["multilingual"] = resources["multilingual"] for lang in resources: if lang == 'multilingual': continue if 'alias' in resources[lang]: continue if lang not in lcode2lang: print(lang + ' not found in lcode2lang!') continue lang_name = lcode2lang[lang] resources[lang]['lang_name'] = lang_name resources_new[lang.lower()] = resources[lang.lower()] resources_new[lang_name.lower()] = {'alias': lang.lower()} if lang.lower() in two_to_three_letters: resources_new[two_to_three_letters[lang.lower()]] = {'alias': lang.lower()} elif lang.lower() in three_to_two_letters: resources_new[three_to_two_letters[lang.lower()]] = {'alias': lang.lower()} if lang.lower() in extra_lang_to_lcodes: alternative = extra_lang_to_lcodes[lang.lower()].lower() if alternative not in resources_new: resources_new[alternative] = {'alias': lang.lower()} print("Processed lcode aliases. Writing resources.json") json.dump(resources_new, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2) def process_misc(args): resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) resources['no'] = {'alias': 'nb'} resources['zh'] = {'alias': 'zh-hans'} # This is intended to be unformatted. 
expand_model_url in common.py will fill in the raw string # with the appropriate values in order to find the needed model file on huggingface resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}' print("Finalized misc attributes. Writing resources.json") json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2) def main(): args = parse_args() print("Converting models from %s to %s" % (args.input_dir, args.output_dir)) if not args.packages_only: process_dirs(args) process_packages(args) if not args.packages_only: process_default_zips(args) process_lcode(args) process_misc(args) if __name__ == '__main__': main() ================================================ FILE: stanza/resources/print_charlm_depparse.py ================================================ """ A small utility script to output which depparse models use charlm (It should skip en_genia, en_craft, but currently doesn't) Not frequently useful, but seems like the kind of thing that might get used a couple times """ from stanza.resources.common import load_resources_json from stanza.resources.default_packages import default_charlms, depparse_charlms def list_depparse(): charlm_langs = list(default_charlms.keys()) resources = load_resources_json() models = ["%s_%s" % (lang, model) for lang in charlm_langs for model in resources[lang].get("depparse", {}) if lang not in depparse_charlms or model not in depparse_charlms[lang] or depparse_charlms[lang][model] is not None] return models if __name__ == "__main__": models = list_depparse() print(" ".join(models)) ================================================ FILE: stanza/server/__init__.py ================================================ from stanza.protobuf import to_text from stanza.protobuf import Document, Sentence, Token, IndexedWord, Span from stanza.protobuf import ParseTree, DependencyGraph, CorefChain from stanza.protobuf import Mention, NERMention, Entity, 
Relation, RelationTriple, Timex from stanza.protobuf import Quote, SpeakerInfo from stanza.protobuf import Operator, Polarity from stanza.protobuf import SentenceFragment, TokenLocation from stanza.protobuf import MapStringString, MapIntString from .client import CoreNLPClient, AnnotationException, TimeoutException, PermanentlyFailedException, StartServer from .annotator import Annotator ================================================ FILE: stanza/server/annotator.py ================================================ """ Defines a base class that can be used to annotate. """ import io from multiprocessing import Process from http.server import BaseHTTPRequestHandler, HTTPServer from http import client as HTTPStatus from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString class Annotator(Process): """ This annotator base class hosts a lightweight server that accepts annotation requests from CoreNLP. Each annotator simply defines 3 functions: requires, provides and annotate. This class takes care of defining appropriate endpoints to interface with CoreNLP. """ @property def name(self): """ Name of the annotator (used by CoreNLP) """ raise NotImplementedError() @property def requires(self): """ Requires has to specify all the annotations required before we are called. """ raise NotImplementedError() @property def provides(self): """ The set of annotations guaranteed to be provided when we are done. NOTE: that these annotations are either fully qualified Java class names or refer to nested classes of edu.stanford.nlp.ling.CoreAnnotations (as is the case below). """ raise NotImplementedError() def annotate(self, ann): """ @ann: is a protobuf annotation object. Actually populate @ann with tokens. """ raise NotImplementedError() @property def properties(self): """ Defines a Java property to define this annotator to CoreNLP. 
""" return { "customAnnotatorClass.{}".format(self.name): "edu.stanford.nlp.pipeline.GenericWebServiceAnnotator", "generic.endpoint": "http://{}:{}".format(self.host, self.port), "generic.requires": ",".join(self.requires), "generic.provides": ",".join(self.provides), } class _Handler(BaseHTTPRequestHandler): annotator = None def __init__(self, request, client_address, server): BaseHTTPRequestHandler.__init__(self, request, client_address, server) def do_GET(self): """ Handle a ping request """ if not self.path.endswith("/"): self.path += "/" if self.path == "/ping/": msg = "pong".encode("UTF-8") self.send_response(HTTPStatus.OK) self.send_header("Content-Type", "text/application") self.send_header("Content-Length", len(msg)) self.end_headers() self.wfile.write(msg) else: self.send_response(HTTPStatus.BAD_REQUEST) self.end_headers() def do_POST(self): """ Handle an annotate request """ if not self.path.endswith("/"): self.path += "/" if self.path == "/annotate/": # Read message length = int(self.headers.get('content-length')) msg = self.rfile.read(length) # Do the annotation doc = Document() parseFromDelimitedString(doc, msg) self.annotator.annotate(doc) with io.BytesIO() as stream: writeToDelimitedString(doc, stream) msg = stream.getvalue() # write message self.send_response(HTTPStatus.OK) self.send_header("Content-Type", "application/x-protobuf") self.send_header("Content-Length", len(msg)) self.end_headers() self.wfile.write(msg) else: self.send_response(HTTPStatus.BAD_REQUEST) self.end_headers() def __init__(self, host="", port=8432): """ Launches a server endpoint to communicate with CoreNLP """ Process.__init__(self) self.host, self.port = host, port self._Handler.annotator = self def run(self): """ Runs the server using Python's simple HTTPServer. TODO: make this multithreaded. """ httpd = HTTPServer((self.host, self.port), self._Handler) sa = httpd.socket.getsockname() serve_message = "Serving HTTP on {host} port {port} (http://{host}:{port}/) ..." 
print(serve_message.format(host=sa[0], port=sa[1])) try: httpd.serve_forever() except KeyboardInterrupt: print("\nKeyboard interrupt received, exiting.") httpd.shutdown() ================================================ FILE: stanza/server/client.py ================================================ """ Client for accessing Stanford CoreNLP in Python """ import atexit import contextlib import enum import io import os import re import requests import logging import json import shlex import socket import subprocess import time import sys import uuid from datetime import datetime from pathlib import Path from urllib.parse import urlparse from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString, to_text __author__ = 'arunchaganty, kelvinguu, vzhong, wmonroe4' logger = logging.getLogger('stanza') # pattern tmp props file should follow SERVER_PROPS_TMP_FILE_PATTERN = re.compile('corenlp_server-(.*).props') # Check if str is CoreNLP supported language CORENLP_LANGS = ['ar', 'arabic', 'chinese', 'zh', 'english', 'en', 'french', 'fr', 'de', 'german', 'hu', 'hungarian', 'it', 'italian', 'es', 'spanish'] # map shorthands to full language names LANGUAGE_SHORTHANDS_TO_FULL = { "ar": "arabic", "zh": "chinese", "en": "english", "fr": "french", "de": "german", "hu": "hungarian", "it": "italian", "es": "spanish" } def is_corenlp_lang(props_str): """ Check if a string references a CoreNLP language """ return props_str.lower() in CORENLP_LANGS # Validate CoreNLP properties CORENLP_OUTPUT_VALS = ["conll", "conllu", "json", "serialized", "text", "xml", "inlinexml"] def validate_corenlp_props(properties=None, annotators=None, output_format=None): """ Do basic checks to validate CoreNLP properties """ if output_format and output_format.lower() not in CORENLP_OUTPUT_VALS: raise ValueError(f"{output_format} not a valid CoreNLP outputFormat value! 
Choose from: {CORENLP_OUTPUT_VALS}") if type(properties) == dict: if "outputFormat" in properties and properties["outputFormat"].lower() not in CORENLP_OUTPUT_VALS: raise ValueError(f"{properties['outputFormat']} not a valid CoreNLP outputFormat value! Choose from: " f"{CORENLP_OUTPUT_VALS}") class AnnotationException(Exception): """ Exception raised when there was an error communicating with the CoreNLP server. """ pass class TimeoutException(AnnotationException): """ Exception raised when the CoreNLP server timed out. """ pass class ShouldRetryException(Exception): """ Exception raised if the service should retry the request. """ pass class PermanentlyFailedException(Exception): """ Exception raised if the service should NOT retry the request. """ pass class StartServer(enum.Enum): DONT_START = 0 FORCE_START = 1 TRY_START = 2 def clean_props_file(props_file): # check if there is a temp server props file to remove and remove it if props_file: if os.path.isfile(props_file) and SERVER_PROPS_TMP_FILE_PATTERN.match(os.path.basename(props_file)): os.remove(props_file) class RobustService(object): """ Service that resuscitates itself if it is not available. 
""" CHECK_ALIVE_TIMEOUT = 120 def __init__(self, start_cmd, stop_cmd, endpoint, stdout=None, stderr=None, be_quiet=False, host=None, port=None, ignore_binding_error=False): self.start_cmd = start_cmd and shlex.split(start_cmd) self.stop_cmd = stop_cmd and shlex.split(stop_cmd) self.endpoint = endpoint self.stdout = stdout self.stderr = stderr self.server = None self.is_active = False self.be_quiet = be_quiet self.host = host self.port = port self.ignore_binding_error = ignore_binding_error atexit.register(self.atexit_kill) def is_alive(self): try: if not self.ignore_binding_error and self.server is not None and self.server.poll() is not None: return False return requests.get(self.endpoint + "/ping").ok except requests.exceptions.ConnectionError as e: raise ShouldRetryException(e) def start(self): if self.start_cmd: if self.host and self.port: with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: try: sock.bind((self.host, self.port)) except socket.error as e: if self.ignore_binding_error: logger.info(f"Connecting to existing CoreNLP server at {self.host}:{self.port}") self.server = None return else: raise PermanentlyFailedException("Error: unable to start the CoreNLP server on port %d " "(possibly something is already running there)" % self.port) from e if self.be_quiet: # Issue #26: subprocess.DEVNULL isn't supported in python 2.7. 
if hasattr(subprocess, 'DEVNULL'): stderr = subprocess.DEVNULL else: stderr = open(os.devnull, 'w') stdout = stderr else: stdout = self.stdout stderr = self.stderr logger.info(f"Starting server with command: {' '.join(self.start_cmd)}") try: self.server = subprocess.Popen(self.start_cmd, stderr=stderr, stdout=stdout) except FileNotFoundError as e: raise FileNotFoundError("When trying to run CoreNLP, a FileNotFoundError occurred, which frequently means Java was not installed or was not in the classpath.") from e def atexit_kill(self): # make some kind of effort to stop the service (such as a # CoreNLP server) at the end of the program. not waiting so # that the python script exiting isn't delayed if self.server and self.server.poll() is None: self.server.terminate() def stop(self): if self.server: self.server.terminate() try: self.server.wait(5) except subprocess.TimeoutExpired: # Resorting to more aggressive measures... self.server.kill() try: self.server.wait(5) except subprocess.TimeoutExpired: # oh well pass self.server = None if self.stop_cmd: subprocess.run(self.stop_cmd, check=True) self.is_active = False def __enter__(self): self.start() return self def __exit__(self, _, __, ___): self.stop() def ensure_alive(self): # Check if the service is active and alive if self.is_active: try: if self.is_alive(): return else: self.stop() except ShouldRetryException: pass # If not, try to start up the service. if self.server is None: self.start() # Wait for the service to start up. start_time = time.time() while True: try: if self.is_alive(): break except ShouldRetryException: pass if time.time() - start_time < self.CHECK_ALIVE_TIMEOUT: time.sleep(1) else: raise PermanentlyFailedException("Timed out waiting for service to come alive.") # At this point we are guaranteed that the service is alive. self.is_active = True def resolve_classpath(classpath=None): """ Returns the classpath to use for corenlp. Prefers to use the given classpath parameter, if available. 
If not, uses the CORENLP_HOME environment variable. Resolves $CLASSPATH (the exact string) in either the classpath parameter or $CORENLP_HOME. """ if classpath == '$CLASSPATH' or (classpath is None and os.getenv("CORENLP_HOME", None) == '$CLASSPATH'): classpath = os.getenv("CLASSPATH") elif classpath is None: classpath = os.getenv("CORENLP_HOME", os.path.join(str(Path.home()), 'stanza_corenlp')) if not os.path.exists(classpath): raise FileNotFoundError("Please install CoreNLP by running `stanza.install_corenlp()`. If you have installed it, please define " "$CORENLP_HOME to be location of your CoreNLP distribution or pass in a classpath parameter. " "$CORENLP_HOME={}".format(os.getenv("CORENLP_HOME"))) classpath = os.path.join(classpath, "*") return classpath class CoreNLPClient(RobustService): """ A client to the Stanford CoreNLP server. """ DEFAULT_ENDPOINT = "http://localhost:9000" DEFAULT_TIMEOUT = 60000 DEFAULT_THREADS = 5 DEFAULT_OUTPUT_FORMAT = "serialized" DEFAULT_MEMORY = "5G" DEFAULT_MAX_CHAR_LENGTH = 100000 def __init__(self, start_server=StartServer.FORCE_START, endpoint=DEFAULT_ENDPOINT, timeout=DEFAULT_TIMEOUT, threads=DEFAULT_THREADS, annotators=None, pretokenized=False, output_format=None, properties=None, stdout=None, stderr=None, memory=DEFAULT_MEMORY, be_quiet=False, max_char_length=DEFAULT_MAX_CHAR_LENGTH, preload=True, classpath=None, **kwargs): # whether or not server should be started by client self.start_server = start_server self.server_props_path = None self.server_start_time = None self.server_host = None self.server_port = None self.server_classpath = None # validate properties validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format) # set up client defaults self.properties = properties self.annotators = annotators self.pretokenized = pretokenized self.output_format = output_format self._setup_client_defaults() # start the server if isinstance(start_server, bool): warning_msg = f"Setting 
'start_server' to a boolean value when constructing {self.__class__.__name__} is deprecated and will stop" + \ " to function in a future version of stanza. Please consider switching to using a value from stanza.server.StartServer." logger.warning(warning_msg) start_server = StartServer.FORCE_START if start_server is True else StartServer.DONT_START # start the server if start_server is StartServer.FORCE_START or start_server is StartServer.TRY_START: # record info for server start self.server_start_time = datetime.now() # set up default properties for server self._setup_server_defaults() host, port = urlparse(endpoint).netloc.split(":") port = int(port) assert host == "localhost", "If starting a server, endpoint must be localhost" classpath = resolve_classpath(classpath) start_cmd = f"java -Xmx{memory} -cp '{classpath}' edu.stanford.nlp.pipeline.StanfordCoreNLPServer " \ f"-port {port} -timeout {timeout} -threads {threads} -maxCharLength {max_char_length} " \ f"-quiet {be_quiet} " self.server_classpath = classpath self.server_host = host self.server_port = port # set up server defaults if self.server_props_path is not None: start_cmd += f" -serverProperties {self.server_props_path}" # possibly set pretokenized if self.pretokenized: start_cmd += f" -preTokenized" # set annotators for server default if self.annotators is not None: annotators_str = self.annotators if type(annotators) == str else ",".join(annotators) start_cmd += f" -annotators {annotators_str}" # specify what to preload, if anything if preload: if type(preload) == bool: # -preload flag means to preload all default annotators start_cmd += " -preload" elif type(preload) == list: # turn list into comma separated list string, only preload these annotators start_cmd += f" -preload {','.join(preload)}" elif type(preload) == str: # comma separated list of annotators start_cmd += f" -preload {preload}" # set outputFormat for server default # if no output format requested by user, set to serialized start_cmd 
+= f" -outputFormat {self.output_format}" # additional options for server: # - server_id # - ssl # - status_port # - uriContext # - strict # - key # - username # - password # - blockList for kw in ['ssl', 'strict']: if kwargs.get(kw) is not None: start_cmd += f" -{kw}" for kw in ['status_port', 'uriContext', 'key', 'username', 'password', 'blockList', 'server_id']: if kwargs.get(kw) is not None: start_cmd += f" -{kw} {kwargs.get(kw)}" stop_cmd = None else: start_cmd = stop_cmd = None host = port = None super(CoreNLPClient, self).__init__(start_cmd, stop_cmd, endpoint, stdout, stderr, be_quiet, host=host, port=port, ignore_binding_error=(start_server == StartServer.TRY_START)) self.timeout = timeout def _setup_client_defaults(self): """ Do some processing of annotators and output_format specified for the client. If interacting with an externally started server, these will be defaults for annotate() calls. :return: None """ # normalize annotators to str if self.annotators is not None: self.annotators = self.annotators if type(self.annotators) == str else ",".join(self.annotators) # handle case where no output format is specified if self.output_format is None: if type(self.properties) == dict and 'outputFormat' in self.properties: self.output_format = self.properties['outputFormat'] else: self.output_format = CoreNLPClient.DEFAULT_OUTPUT_FORMAT def _setup_server_defaults(self): """ Set up the default properties for the server. The properties argument can take on one of 3 value types 1. File path on system or in CLASSPATH (e.g. /path/to/server.props or StanfordCoreNLP-french.properties 2. Name of a Stanford CoreNLP supported language (e.g. french or fr) 3. Python dictionary (properties written to tmp file for Java server, erased at end) In addition, an annotators list and output_format can be specified directly with arguments. These will overwrite any settings in the specified properties. 
If no properties are specified, the standard Stanford CoreNLP English server will be launched. The outputFormat will be set to 'serialized' and use the ProtobufAnnotationSerializer. """ # ensure properties is str or dict if self.properties is None or (not isinstance(self.properties, str) and not isinstance(self.properties, dict)): if self.properties is not None: logger.warning('properties passed invalid value (not a str or dict), setting properties = {}') self.properties = {} # check if properties is a string, pass on to server which can handle if isinstance(self.properties, str): # try to translate to Stanford CoreNLP language name, or assume properties is a path if is_corenlp_lang(self.properties): if self.properties.lower() in LANGUAGE_SHORTHANDS_TO_FULL: self.properties = LANGUAGE_SHORTHANDS_TO_FULL[self.properties] logger.info( f"Using CoreNLP default properties for: {self.properties}. Make sure to have " f"{self.properties} models jar (available for download here: " f"https://stanfordnlp.github.io/CoreNLP/) in CLASSPATH") else: if not os.path.isfile(self.properties): logger.warning(f"{self.properties} does not correspond to a file path. Make sure this file is in " f"your CLASSPATH.") self.server_props_path = self.properties elif isinstance(self.properties, dict): # make a copy server_start_properties = dict(self.properties) if self.annotators is not None: server_start_properties['annotators'] = self.annotators if self.output_format is not None and isinstance(self.output_format, str): server_start_properties['outputFormat'] = self.output_format # write desired server start properties to tmp file # set up to erase on exit tmp_path = write_corenlp_props(server_start_properties) logger.info(f"Writing properties to tmp file: {tmp_path}") atexit.register(clean_props_file, tmp_path) self.server_props_path = tmp_path def _request(self, buf, properties, reset_default=False, **kwargs): """ Send a request to the CoreNLP server. 
:param (str | bytes) buf: data to be sent with the request :param (dict) properties: properties that the server expects :return: request result """ if self.start_server is not StartServer.DONT_START: self.ensure_alive() try: input_format = properties.get("inputFormat", "text") if input_format == "text": ctype = "text/plain; charset=utf-8" elif input_format == "serialized": ctype = "application/x-protobuf" else: raise ValueError("Unrecognized inputFormat " + input_format) # handle auth if 'username' in kwargs and 'password' in kwargs: kwargs['auth'] = requests.auth.HTTPBasicAuth(kwargs['username'], kwargs['password']) kwargs.pop('username') kwargs.pop('password') r = requests.post(self.endpoint, params={'properties': str(properties), 'resetDefault': str(reset_default).lower()}, data=buf, headers={'content-type': ctype}, timeout=(self.timeout*2)/1000, **kwargs) r.raise_for_status() return r except requests.exceptions.Timeout as e: raise TimeoutException("Timeout requesting to CoreNLPServer. Maybe server is unavailable or your document is too long") except requests.exceptions.RequestException as e: if e.response is not None and e.response.text is not None: raise AnnotationException(e.response.text) from e elif e.args: raise AnnotationException(e.args[0]) from e raise AnnotationException() from e def annotate(self, text, annotators=None, output_format=None, properties=None, reset_default=None, **kwargs): """ Send a request to the CoreNLP server. :param (str | unicode) text: raw text for the CoreNLPServer to parse :param (list | string) annotators: list of annotators to use :param (str) output_format: output type from server: serialized, json, text, conll, conllu, or xml :param (dict) properties: additional request properties (written on top of defaults) :param (bool) reset_default: don't use server defaults Precedence for settings: 1. annotators and output_format args 2. Values from properties dict 3. 
Client defaults self.annotators and self.output_format (set during client construction) 4. Server defaults Additional request parameters (apart from CoreNLP pipeline properties) such as 'username' and 'password' can be specified with the kwargs. :return: request result """ # validate request properties validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format) # set request properties request_properties = {} # start with client defaults if self.annotators is not None: request_properties['annotators'] = self.annotators if self.output_format is not None: request_properties['outputFormat'] = self.output_format # add values from properties arg # handle str case if type(properties) == str: if is_corenlp_lang(properties): properties = {'pipelineLanguage': properties.lower()} if reset_default is None: reset_default = True else: raise ValueError(f"Unrecognized properties keyword {properties}") if type(properties) == dict: request_properties.update(properties) # if annotators list is specified, override with that # also can use the annotators field the object was created with if annotators is not None and (type(annotators) == str or type(annotators) == list): request_properties['annotators'] = annotators if type(annotators) == str else ",".join(annotators) # if output format is specified, override with that if output_format is not None and type(output_format) == str: request_properties['outputFormat'] = output_format # make the request # if not explicitly set or the case of pipelineLanguage, reset_default should be None if reset_default is None: reset_default = False r = self._request(text.encode('utf-8'), request_properties, reset_default, **kwargs) if request_properties["outputFormat"] == "json": return r.json() elif request_properties["outputFormat"] == "serialized": doc = Document() parseFromDelimitedString(doc, r.content) return doc elif request_properties["outputFormat"] in ["text", "conllu", "conll", "xml"]: return r.text else: 
return r def update(self, doc, annotators=None, properties=None): if properties is None: properties = {} properties.update({ 'inputFormat': 'serialized', 'outputFormat': 'serialized', 'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer' }) if annotators: properties['annotators'] = annotators if type(annotators) == str else ",".join(annotators) with io.BytesIO() as stream: writeToDelimitedString(doc, stream) msg = stream.getvalue() r = self._request(msg, properties) doc = Document() parseFromDelimitedString(doc, r.content) return doc def tokensregex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None): # this is required for some reason matches = self.__regex('/tokensregex', text, pattern, filter, annotators, properties) if to_words: matches = regex_matches_to_indexed_words(matches) return matches def semgrex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None): matches = self.__regex('/semgrex', text, pattern, filter, annotators, properties) if to_words: matches = regex_matches_to_indexed_words(matches) return matches def fill_tree_proto(self, tree, proto_tree): if tree.label: proto_tree.value = tree.label for child in tree.children: proto_child = proto_tree.child.add() self.fill_tree_proto(child, proto_child) def tregex(self, text=None, pattern=None, filter=False, annotators=None, properties=None, trees=None): # parse is not included by default in some of the pipelines, # so we may need to manually override the annotators # to include parse in order for tregex to do anything if annotators is None and self.annotators is not None: assert isinstance(self.annotators, str) pieces = self.annotators.split(",") if "parse" not in pieces: annotators = self.annotators + ",parse" else: annotators = "tokenize,ssplit,pos,parse" if pattern is None: raise ValueError("Cannot have None as a pattern for tregex") # TODO: we could also allow for passing in a complete document, # along with the original 
text, so that the spans returns are more accurate if trees is not None: if properties is None: properties = {} properties['inputFormat'] = 'serialized' if 'serializer' not in properties: properties['serializer'] = 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer' doc = Document() full_text = [] for tree_idx, tree in enumerate(trees): sentence = doc.sentence.add() sentence.sentenceIndex = tree_idx sentence.tokenOffsetBegin = len(full_text) leaves = tree.leaf_labels() full_text.extend(leaves) sentence.tokenOffsetEnd = len(full_text) self.fill_tree_proto(tree, sentence.parseTree) for word in leaves: token = sentence.token.add() # the other side uses both value and word, weirdly enough token.value = word token.word = word # without the actual tokenization, at least we can # stop the words from running together token.after = " " doc.text = " ".join(full_text) with io.BytesIO() as stream: writeToDelimitedString(doc, stream) text = stream.getvalue() return self.__regex('/tregex', text, pattern, filter, annotators, properties) def __regex(self, path, text, pattern, filter, annotators=None, properties=None): """ Send a regex-related request to the CoreNLP server. 
:param (str | unicode) path: the path for the regex endpoint :param text: raw text for the CoreNLPServer to apply the regex :param (str | unicode) pattern: regex pattern :param (bool) filter: option to filter sentences that contain matches, if false returns matches :param properties: option to filter sentences that contain matches, if false returns matches :return: request result """ if self.start_server is not StartServer.DONT_START: self.ensure_alive() if properties is None: properties = {} properties.update({ 'inputFormat': 'text', 'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer' }) if annotators: properties['annotators'] = ",".join(annotators) if isinstance(annotators, list) else annotators # force output for regex requests to be json properties['outputFormat'] = 'json' # if the server is trying to send back character offsets, it # should send back codepoints counts as well in case the text # has extra wide characters properties['tokenize.codepoint'] = 'true' try: # Error occurs unless put properties in params input_format = properties.get("inputFormat", "text") if input_format == "text": ctype = "text/plain; charset=utf-8" elif input_format == "serialized": ctype = "application/x-protobuf" else: raise ValueError("Unrecognized inputFormat " + input_format) # change request method from `get` to `post` as required by CoreNLP r = requests.post( self.endpoint + path, params={ 'pattern': pattern, 'filter': filter, 'properties': str(properties) }, data=text.encode('utf-8') if isinstance(text, str) else text, headers={'content-type': ctype}, timeout=(self.timeout*2)/1000, ) r.raise_for_status() if r.encoding is None: r.encoding = "utf-8" return json.loads(r.text) except requests.HTTPError as e: if r.text.startswith("Timeout"): raise TimeoutException(r.text) else: raise AnnotationException(r.text) except json.JSONDecodeError: raise AnnotationException(r.text) def scenegraph(self, text, properties=None): """ Send a request to the server which 
processes the text using SceneGraph This will require a new CoreNLP release, 4.5.5 or later """ # since we're using requests ourself, # check if the server has started or not if self.start_server is not StartServer.DONT_START: self.ensure_alive() if properties is None: properties = {} # the only thing the scenegraph knows how to use is text properties['inputFormat'] = 'text' ctype = "text/plain; charset=utf-8" # the json output format is much more useful properties['outputFormat'] = 'json' try: r = requests.post( self.endpoint + "/scenegraph", params={ 'properties': str(properties) }, data=text.encode('utf-8') if isinstance(text, str) else text, headers={'content-type': ctype}, timeout=(self.timeout*2)/1000, ) r.raise_for_status() if r.encoding is None: r.encoding = "utf-8" return json.loads(r.text) except requests.HTTPError as e: if r.text.startswith("Timeout"): raise TimeoutException(r.text) else: raise AnnotationException(r.text) except json.JSONDecodeError: raise AnnotationException(r.text) def read_corenlp_props(props_path): """ Read a Stanford CoreNLP properties file into a dict """ props_dict = {} with open(props_path) as props_file: entry_lines = [entry_line for entry_line in props_file.read().split('\n') if entry_line.strip() and not entry_line.startswith('#')] for entry_line in entry_lines: k = entry_line.split('=')[0] k_len = len(k+"=") v = entry_line[k_len:] props_dict[k.strip()] = v return props_dict def write_corenlp_props(props_dict, file_path=None): """ Write a Stanford CoreNLP properties dict to a file """ if file_path is None: file_path = f"corenlp_server-{uuid.uuid4().hex[:16]}.props" # confirm tmp file path matches pattern assert SERVER_PROPS_TMP_FILE_PATTERN.match(file_path) with open(file_path, 'w') as props_file: for k, v in props_dict.items(): if isinstance(v, list): writeable_v = ",".join(v) else: writeable_v = v props_file.write(f'{k} = {writeable_v}\n\n') return file_path def regex_matches_to_indexed_words(matches): """ Transforms 
tokensregex and semgrex matches to indexed words. :param matches: unprocessed regex matches :return: flat array of indexed words """ words = [dict(v, **dict([('sentence', i)])) for i, s in enumerate(matches['sentences']) for k, v in s.items() if k != 'length'] return words __all__ = ["CoreNLPClient", "AnnotationException", "TimeoutException", "to_text"] ================================================ FILE: stanza/server/dependency_converter.py ================================================ """ A converter from constituency trees to dependency trees using CoreNLP's UniversalEnglish converter. ONLY works on English. """ import stanza from stanza.protobuf import DependencyConverterRequest, DependencyConverterResponse from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext CONVERTER_JAVA = "edu.stanford.nlp.trees.ProcessDependencyConverterRequest" def send_converter_request(request, classpath=None): return send_request(request, DependencyConverterResponse, CONVERTER_JAVA, classpath=classpath) def build_request(doc): """ Request format is simple: one tree per sentence in the document """ request = DependencyConverterRequest() for sentence in doc.sentences: request.trees.append(build_tree(sentence.constituency, None)) return request def process_doc(doc, classpath=None): """ Convert the constituency trees in the document, then attach the resulting dependencies to the sentences """ request = build_request(doc) response = send_converter_request(request, classpath=classpath) attach_dependencies(doc, response) def attach_dependencies(doc, response): if len(doc.sentences) != len(response.conversions): raise ValueError("Sent %d sentences but got back %d conversions" % (len(doc.sentences), len(response.conversions))) for sent_idx, (sentence, conversion) in enumerate(zip(doc.sentences, response.conversions)): graph = conversion.graph # The deterministic conversion should have an equal number of words and one fewer edge # ... 
the root is represented by a word with no parent if len(sentence.words) != len(graph.node): raise ValueError("Sentence %d of the conversion should have %d words but got back %d nodes in the graph" % (sent_idx, len(sentence.words), len(graph.node))) if len(sentence.words) != len(graph.edge) + 1: raise ValueError("Sentence %d of the conversion should have %d edges (one per word, plus the root) but got back %d edges in the graph" % (sent_idx, len(sentence.words) - 1, len(graph.edge))) expected_nodes = set(range(1, len(sentence.words) + 1)) targets = set() for edge in graph.edge: if edge.target in targets: raise ValueError("Found two parents of %d in sentence %d" % (edge.target, sent_idx)) targets.add(edge.target) # -1 since the words are 0 indexed in the sentence, # but we count dependencies from 1 sentence.words[edge.target-1].head = edge.source sentence.words[edge.target-1].deprel = edge.dep roots = expected_nodes - targets assert len(roots) == 1 for root in roots: sentence.words[root-1].head = 0 sentence.words[root-1].deprel = "root" sentence.build_dependencies() class DependencyConverter(JavaProtobufContext): """ Context window for the dependency converter This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. 
""" def __init__(self, classpath=None): super(DependencyConverter, self).__init__(classpath, DependencyConverterResponse, CONVERTER_JAVA) def process(self, doc): """ Converts a constituency tree to dependency trees for each of the sentences in the document """ request = build_request(doc) response = self.process_request(request) attach_dependencies(doc, response) return doc def main(): nlp = stanza.Pipeline('en', processors='tokenize,pos,constituency') doc = nlp('I like blue antennae.') print("{:C}".format(doc)) process_doc(doc, classpath="$CLASSPATH") print("{:C}".format(doc)) doc = nlp('And I cannot lie.') print("{:C}".format(doc)) with DependencyConverter(classpath="$CLASSPATH") as converter: converter.process(doc) print("{:C}".format(doc)) if __name__ == '__main__': main() ================================================ FILE: stanza/server/java_protobuf_requests.py ================================================ from collections import deque import subprocess from stanza.models.common.utils import misc_to_space_after from stanza.models.constituency.parse_tree import Tree from stanza.protobuf import DependencyGraph, FlattenedParseTree from stanza.server.client import resolve_classpath def send_request(request, response_type, java_main, classpath=None): """ Use subprocess to run a Java protobuf processor on the given request Returns the protobuf response """ classpath = resolve_classpath(classpath) if classpath is None: raise ValueError("Classpath is None, Perhaps you need to set the $CLASSPATH or $CORENLP_HOME environment variable to point to a CoreNLP install.") pipe = subprocess.run(["java", "-cp", classpath, java_main], input=request.SerializeToString(), stdout=subprocess.PIPE, check=True) response = response_type() response.ParseFromString(pipe.stdout) return response def add_tree_nodes(proto_tree, tree, score): # add an open node node = proto_tree.nodes.add() node.openNode = True if score is not None: node.score = score # add the content of this node node 
= proto_tree.nodes.add() node.value = tree.label # add all children... # leaves get just one node # branches are called recursively for child in tree.children: if child.is_leaf(): node = proto_tree.nodes.add() node.value = child.label else: add_tree_nodes(proto_tree, child, None) node = proto_tree.nodes.add() node.closeNode = True def build_tree(tree, score): """ Builds a FlattenedParseTree from CoreNLP.proto Populates the value field from tree.label and iterates through the children via tree.children. Should work on any tree structure which follows that layout The score will be added to the top node (if it is not None) Operates by recursively calling add_tree_nodes """ proto_tree = FlattenedParseTree() add_tree_nodes(proto_tree, tree, score) return proto_tree def from_tree(proto_tree): """ Convert a FlattenedParseTree back into a Tree returns Tree, score (score might be None if it is missing) """ score = None stack = deque() for node in proto_tree.nodes: if node.HasField("score") and score is None: score = node.score if node.openNode: if len(stack) > 0 and isinstance(stack[-1], FlattenedParseTree.Node) and stack[-1].openNode: raise ValueError("Got a proto with no label on a node: {}".format(proto_tree)) stack.append(node) continue if not node.closeNode: child = Tree(label=node.value) # TODO: do something with the score stack.append(child) continue # must be a close operation... 
if len(stack) <= 1: raise ValueError("Got a proto with too many close operations: {}".format(proto_tree)) # on a close operation, pop until we hit the open # then turn everything in that span into a new node children = [] nextNode = stack.pop() while not isinstance(nextNode, FlattenedParseTree.Node): children.append(nextNode) nextNode = stack.pop() if len(children) == 0: raise ValueError("Got a proto with an open immediately followed by a close: {}".format(proto_tree)) children.reverse() label = children[0] children = children[1:] subtree = Tree(label=label.label, children=children) stack.append(subtree) if len(stack) > 1: raise ValueError("Got a proto which does not close all of the nodes: {}".format(proto_tree)) tree = stack.pop() if not isinstance(tree, Tree): raise ValueError("Got a proto which was just one Open operation: {}".format(proto_tree)) return tree, score def add_token(token_list, word, token): """ Add a token to a proto request. CoreNLP tokens have components of both word and token from stanza. 
We pass along "after" but not "before" """ if token is None and isinstance(word.id, int): raise AssertionError("Only expected word w/o token for 'extra' words") query_token = token_list.add() query_token.word = word.text query_token.value = word.text if word.lemma is not None: query_token.lemma = word.lemma if word.xpos is not None: query_token.pos = word.xpos if word.upos is not None: query_token.coarseTag = word.upos if word.feats and word.feats != "_": for feature in word.feats.split("|"): key, value = feature.split("=", maxsplit=1) query_token.conllUFeatures.key.append(key) query_token.conllUFeatures.value.append(value) if token is not None: if token.ner is not None: query_token.ner = token.ner if token is not None and len(token.id) > 1: query_token.mwtText = token.text query_token.isMWT = True query_token.isFirstMWT = token.id[0] == word.id if token.id[-1] != word.id: # if we are not the last word of an MWT token # we are absolutely not followed by space pass else: query_token.after = token.spaces_after query_token.index = word.id else: # presumably empty words won't really be written this way, # but we can still keep track of it query_token.after = misc_to_space_after(word.misc) query_token.index = word.id[0] query_token.emptyIndex = word.id[1] if word.misc and word.misc != "_": query_token.conllUMisc = word.misc if token is not None and token.misc and token.misc != "_": query_token.mwtMisc = token.misc def add_sentence(request_sentences, sentence, num_tokens): """ Add the tokens for this stanza sentence to a list of protobuf sentences """ request_sentence = request_sentences.add() request_sentence.tokenOffsetBegin = num_tokens request_sentence.tokenOffsetEnd = num_tokens + sum(len(token.words) for token in sentence.tokens) for token in sentence.tokens: for word in token.words: add_token(request_sentence.token, word, token) return request_sentence def add_word_to_graph(graph, word, sent_idx): """ Add a node and possibly an edge for a word in a basic 
dependency graph. """ node = graph.node.add() node.sentenceIndex = sent_idx+1 if isinstance(word.id, int): node.index = word.id else: node.index = word.id[0] node.emptyIndex = word.id[1] if word.head != 0 and word.head is not None: edge = graph.edge.add() edge.source = word.head if isinstance(word.id, int): edge.target = word.id else: edge.target = word.id[0] edge.targetEmpty = word.id[1] if word.deprel is not None: edge.dep = word.deprel else: # the receiving side doesn't like null as a dependency edge.dep = "_" def convert_networkx_graph(graph_proto, sentence, sent_idx): """ Turns a networkx graph into a DependencyGraph from the proto file """ for token in sentence.tokens: for word in token.words: add_token(graph_proto.token, word, token) for word in sentence.empty_words: add_token(graph_proto.token, word, None) dependencies = sentence._enhanced_dependencies for target in dependencies: if target == 0: # don't need to send the explicit root continue for source in dependencies.predecessors(target): if source == 0: # unlike with basic, we need to send over the roots, # as the enhanced can have loops graph_proto.rootNode.append(len(graph_proto.node)) continue for deprel in dependencies.get_edge_data(source, target): edge = graph_proto.edge.add() if isinstance(source, int): edge.source = source else: edge.source = source[0] if source[1] != 0: edge.sourceEmpty = source[1] if isinstance(target, int): edge.target = target else: edge.target = target[0] if target[1] != 0: edge.targetEmpty = target[1] edge.dep = deprel node = graph_proto.node.add() node.sentenceIndex = sent_idx + 1 # the nodes in the networkx graph are indexed from 1, not counting the root if isinstance(target, int): node.index = target else: node.index = target[0] if target[1] != 0: node.emptyIndex = target[1] return graph_proto def features_to_string(features): if not features: return None if len(features.key) == 0: return None return "|".join("%s=%s" % (key, value) for key, value in zip(features.key, 
features.value)) def misc_space_pieces(misc): """ Return only the space-related misc pieces """ if misc is None or misc == "" or misc == "_": return misc pieces = misc.split("|") pieces = [x for x in pieces if x.split("=", maxsplit=1)[0] in ("SpaceAfter", "SpacesAfter", "SpacesBefore")] if len(pieces) > 0: return "|".join(pieces) return None def remove_space_misc(misc): """ Remove any pieces from misc which are space-related """ if misc is None or misc == "" or misc == "_": return misc pieces = misc.split("|") pieces = [x for x in pieces if x.split("=", maxsplit=1)[0] not in ("SpaceAfter", "SpacesAfter", "SpacesBefore")] if len(pieces) > 0: return "|".join(pieces) return None def substitute_space_misc(misc, space_misc): space_misc_pieces = space_misc.split("|") if space_misc else [] space_misc_after = None space_misc_before = None for piece in space_misc_pieces: if piece.startswith("SpaceBefore"): space_misc_before = piece elif piece.startswith("SpaceAfter") or piece.startswith("SpacesAfter"): space_misc_after = piece else: raise AssertionError("An unknown piece wound up in the misc space fields: %s" % piece) pieces = misc.split("|") new_pieces = [] for piece in pieces: if piece.startswith("SpaceBefore"): if space_misc_before: new_pieces.append(space_misc_before) space_misc_before = None elif piece.startswith("SpaceAfter") or piece.startswith("SpacesAfter"): if space_misc_after: new_pieces.append(space_misc_after) space_misc_after = None else: new_pieces.append(piece) if space_misc_after: new_pieces.append(space_misc_after) if space_misc_before: new_pieces.append(space_misc_before) if len(new_pieces) == 0: return None return "|".join(new_pieces) class JavaProtobufContext(object): """ A generic context for sending requests to a java program using protobufs in a subprocess """ def __init__(self, classpath, build_response, java_main, extra_args=None): self.classpath = resolve_classpath(classpath) self.build_response = build_response self.java_main = java_main if 
extra_args is None: extra_args = [] self.extra_args = extra_args self.pipe = None def open_pipe(self): self.pipe = subprocess.Popen(["java", "-cp", self.classpath, self.java_main, "-multiple"] + self.extra_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE) def close_pipe(self): if self.pipe.poll() is None: self.pipe.stdin.write((0).to_bytes(4, 'big')) self.pipe.stdin.flush() self.pipe = None def __enter__(self): self.open_pipe() return self def __exit__(self, type, value, traceback): self.close_pipe() def process_request(self, request): if self.pipe is None: raise RuntimeError("Pipe to java process is not open or was closed") text = request.SerializeToString() self.pipe.stdin.write(len(text).to_bytes(4, 'big')) self.pipe.stdin.write(text) self.pipe.stdin.flush() response_length = self.pipe.stdout.read(4) if len(response_length) < 4: raise BrokenPipeError("Could not communicate with java process!") response_length = int.from_bytes(response_length, "big") response_text = self.pipe.stdout.read(response_length) response = self.build_response() response.ParseFromString(response_text) return response ================================================ FILE: stanza/server/main.py ================================================ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Simple shell program to pipe in """ import corenlp import json import re import csv import sys from collections import namedtuple, OrderedDict FLOAT_RE = re.compile(r"\d*\.\d+") INT_RE = re.compile(r"\d+") def dictstr(arg): """ Parse a key=value string as a tuple (key, value) that can be provided as an argument to dict() """ key, value = arg.split("=") if value.lower() == "true" or value.lower() == "false": value = bool(value) elif INT_RE.match(value): value = int(value) elif FLOAT_RE.match(value): value = float(value) return (key, value) def do_annotate(args): args.props = dict(args.props) if args.props else {} if args.sentence_mode: args.props["ssplit.isOneSentence"] = True with 
corenlp.CoreNLPClient(annotators=args.annotators, properties=args.props, be_quiet=not args.verbose_server) as client: for line in args.input: if line.startswith("#"): continue ann = client.annotate(line.strip(), output_format=args.format) if args.format == "json": if args.sentence_mode: ann = ann["sentences"][0] args.output.write(json.dumps(ann)) args.output.write("\n") def main(): import argparse parser = argparse.ArgumentParser(description='Annotate data') parser.add_argument('-i', '--input', type=argparse.FileType('r'), default=sys.stdin, help="Input file to process; each line contains one document (default: stdin)") parser.add_argument('-o', '--output', type=argparse.FileType('w'), default=sys.stdout, help="File to write annotations to (default: stdout)") parser.add_argument('-f', '--format', choices=["json",], default="json", help="Output format") parser.add_argument('-a', '--annotators', nargs="+", type=str, default=["tokenize ssplit lemma pos"], help="A list of annotators") parser.add_argument('-s', '--sentence-mode', action="store_true",help="Assume each line of input is a sentence.") parser.add_argument('-v', '--verbose-server', action="store_true",help="Server is made verbose") parser.add_argument('-m', '--memory', type=str, default="4G", help="Memory to use for the server") parser.add_argument('-p', '--props', nargs="+", type=dictstr, help="Properties as a list of key=value pairs") parser.set_defaults(func=do_annotate) ARGS = parser.parse_args() if ARGS.func is None: parser.print_help() sys.exit(1) else: ARGS.func(ARGS) if __name__ == "__main__": main() ================================================ FILE: stanza/server/morphology.py ================================================ """ Direct pipe connection to the Java CoreNLP Morphology class Only effective for English. 
Must be supplied with PTB scheme xpos, not upos """ from stanza.protobuf import MorphologyRequest, MorphologyResponse from stanza.server.java_protobuf_requests import send_request, JavaProtobufContext MORPHOLOGY_JAVA = "edu.stanford.nlp.process.ProcessMorphologyRequest" def send_morphology_request(request): return send_request(request, MorphologyResponse, MORPHOLOGY_JAVA) def build_request(words, xpos_tags): """ Turn a list of words and a list of tags into a request tags must be xpos, not upos """ request = MorphologyRequest() for word, tag in zip(words, xpos_tags): tagged_word = request.words.add() tagged_word.word = word tagged_word.xpos = tag return request def process_text(words, xpos_tags): """ Get the lemmata for each word/tag pair Currently the return is a MorphologyResponse from CoreNLP.proto tags must be xpos, not upos """ request = build_request(words, xpos_tags) return send_morphology_request(request) class Morphology(JavaProtobufContext): """ Morphology context window This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. 
(much faster than calling process_text over and over) """ def __init__(self, classpath=None): super(Morphology, self).__init__(classpath, MorphologyResponse, MORPHOLOGY_JAVA) def process(self, words, xpos_tags): """ Get the lemmata for each word/tag pair """ request = build_request(words, xpos_tags) return self.process_request(request) def main(): # TODO: turn this into a unit test, once a new CoreNLP is released words = ["Jennifer", "has", "the", "prettiest", "antennae"] tags = ["NNP", "VBZ", "DT", "JJS", "NNS"] expected = ["Jennifer", "have", "the", "pretty", "antenna"] result = process_text(words, tags) lemma = [x.lemma for x in result.words] print(lemma) assert lemma == expected with Morphology() as morph: result = morph.process(words, tags) lemma = [x.lemma for x in result.words] assert lemma == expected if __name__ == '__main__': main() ================================================ FILE: stanza/server/parser_eval.py ================================================ """ This class runs a Java process to evaluate a treebank prediction using CoreNLP """ from collections import namedtuple import sys import stanza from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext from stanza.models.constituency.tree_reader import read_treebank EVALUATE_JAVA = "edu.stanford.nlp.parser.metrics.EvaluateExternalParser" ParseResult = namedtuple("ParseResult", ['gold', 'predictions', 'state', 'constituents']) ScoredTree = namedtuple("ScoredTree", ['tree', 'score']) def build_request(treebank): """ treebank should be a list of pairs: [gold, predictions] each predictions is a list of tuples (prediction, score, state) state is ignored and can be None Note that for now, only one tree is measured, but this may be extensible in the future Trees should be in the form of a Tree from parse_tree.py """ request = EvaluateParserRequest() for raw_result in treebank: gold = 
raw_result.gold predictions = raw_result.predictions parse_result = request.treebank.add() parse_result.gold.CopyFrom(build_tree(gold, None)) for pred in predictions: if isinstance(pred, tuple): prediction, score = pred else: prediction = pred score = None try: parse_result.predicted.append(build_tree(prediction, score)) except Exception as e: raise RuntimeError("Unable to build parser request from tree {}".format(pred)) from e return request def collate(gold_treebank, predictions_treebank): """ Turns a list of gold and prediction into a evaluation object """ treebank = [] for gold, prediction in zip(gold_treebank, predictions_treebank): result = ParseResult(gold, [prediction], None, None) treebank.append(result) return treebank class EvaluateParser(JavaProtobufContext): """ Parser evaluation context window This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. """ def __init__(self, classpath=None, kbest=None, silent=False): if kbest is not None: extra_args = ["-evalPCFGkBest", "{}".format(kbest), "-evals", "pcfgTopK"] else: extra_args = [] if silent: extra_args.extend(["-evals", "summary=False"]) super(EvaluateParser, self).__init__(classpath, EvaluateParserResponse, EVALUATE_JAVA, extra_args=extra_args) def process(self, treebank): request = build_request(treebank) return self.process_request(request) def main(): gold = read_treebank(sys.argv[1]) predictions = read_treebank(sys.argv[2]) treebank = collate(gold, predictions) with EvaluateParser() as ep: ep.process(treebank) if __name__ == '__main__': main() ================================================ FILE: stanza/server/semgrex.py ================================================ """Invokes the Java semgrex on a document The server client has a method "semgrex" which sends text to Java CoreNLP for processing with a semgrex (SEMantic GRaph regEX) query: 
https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html However, this operates on text using the CoreNLP tools, which means the dependency graphs may not align with stanza's depparse module, and this also limits the languages for which it can be used. This module allows for running semgrex commands on the graphs produced by depparse. To use, first process text into a doc using stanza.Pipeline. Next, pass the processed doc and a list of semgrex patterns to process_doc in this module. It will run the java semgrex module as a subprocess and return the result in the form of a SemgrexResponse, whose description is in the proto file included with stanza. A minimal example is the main method of this module. Note that launching the subprocess is potentially quite expensive relative to the search if used many times on small documents. Ideally larger texts would be processed, and all of the desired semgrex patterns would be run at once. The worst thing to do would be to call this multiple times on a large document, one invocation per semgrex pattern, as that would serialize the document each time. Included here is a context manager which allows for keeping the same java process open for multiple requests. This saves on the subprocess launching time. It is still important not to wastefully serialize the same document over and over, though.
""" import argparse from collections import namedtuple import copy import os import re import stanza from stanza.protobuf import SemgrexRequest, SemgrexResponse from stanza.server.java_protobuf_requests import send_request, add_token, add_word_to_graph, JavaProtobufContext, convert_networkx_graph from stanza.utils.conll import CoNLL SEMGREX_JAVA = "edu.stanford.nlp.semgraph.semgrex.ProcessSemgrexRequest" SemgrexQuery = namedtuple("SemgrexQuery", "pattern comments") def send_semgrex_request(request): return send_request(request, SemgrexResponse, SEMGREX_JAVA) def build_request(doc, semgrex_patterns, enhanced=False): request = SemgrexRequest() if isinstance(semgrex_patterns, str): semgrex_patterns = [semgrex_patterns] semgrex_patterns = [x if isinstance(x, SemgrexQuery) else SemgrexQuery(x, []) for x in semgrex_patterns] for semgrex in semgrex_patterns: request.semgrex.append(semgrex.pattern) for sent_idx, sentence in enumerate(doc.sentences): query = request.query.add() if enhanced: # tokens will be added on to the graph object convert_networkx_graph(query.graph, sentence, sent_idx) else: word_idx = 0 for token in sentence.tokens: for word in token.words: add_token(query.token, word, token) add_word_to_graph(query.graph, word, sent_idx) word_idx = word_idx + 1 return request def process_doc(doc, *semgrex_patterns, enhanced=False): """ Returns the result of processing the given semgrex expression on the stanza doc. Currently the return is a SemgrexResponse from CoreNLP.proto """ request = build_request(doc, semgrex_patterns, enhanced=enhanced) return send_semgrex_request(request) class Semgrex(JavaProtobufContext): """ Semgrex context window This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. 
""" def __init__(self, classpath=None): super(Semgrex, self).__init__(classpath, SemgrexResponse, SEMGREX_JAVA) def process(self, doc, *semgrex_patterns): """ Apply each of the semgrex patterns to each of the dependency trees in doc """ request = build_request(doc, semgrex_patterns) return self.process_request(request) def annotate_doc(doc, semgrex_result, semgrex_patterns, matches_only, exclude_matches): """ Put comments on the sentences which describe the matching semgrex patterns """ doc = copy.deepcopy(doc) if isinstance(semgrex_patterns, str): semgrex_patterns = [semgrex_patterns] semgrex_patterns = [x if isinstance(x, SemgrexQuery) else SemgrexQuery(x, []) for x in semgrex_patterns] matched_ids = set() for sentence_result in semgrex_result.result: for pattern_result in sentence_result.result: for match in pattern_result.match: matched_ids.add(match.sentenceIndex) pattern_texts = [semgrex_pattern.pattern.replace("\n", " ") for semgrex_pattern in semgrex_patterns] matching_sentences = [] for sentence_result in semgrex_result.result: sentence_matched = False matched_semgrex_ids = set() for pattern_result in sentence_result.result: if len(pattern_result.match) == 0: continue highlight_tokens = [] highlight_edges = [] for match in pattern_result.match: sentence_matched = True sentence = doc.sentences[match.sentenceIndex] semgrex_pattern = semgrex_patterns[match.semgrexIndex] pattern_text = pattern_texts[match.semgrexIndex] matched_semgrex_ids.add(match.semgrexIndex) match_word = "%d:%s" % (match.matchIndex, sentence.words[match.matchIndex-1].text) if len(match.node) == 0: node_matches = "" else: node_matches = ["%s=%d:%s" % (node.name, node.matchIndex, sentence.words[node.matchIndex-1].text) for node in match.node] node_matches = " " + " ".join(node_matches) if len(match.varstring) == 0: var_values = "" else: var_values = ["%s=%s" % (v.name, v.value) for v in match.varstring] var_values = " " + " ".join(var_values) sentence.add_comment("# semgrex pattern |%s| 
matched at %s%s%s" % (pattern_text, match_word, node_matches, var_values)) for comment in semgrex_pattern.comments: sentence.add_comment("# semgrex comment: %s" % comment) highlight_tokens.append(match.matchIndex) for edge in match.edge: highlight_edges.append(edge.target) if len(highlight_tokens) > 0: sentence.add_comment("# highlight tokens = %s" % (" ".join("%d" % x for x in highlight_tokens))) if len(highlight_edges) > 0: sentence.add_comment("# highlight deprels = %s" % (" ".join("%d" % x for x in highlight_edges))) if sentence_matched and not matches_only: for semgrex_idx, pattern_text in enumerate(pattern_texts): if semgrex_idx not in matched_semgrex_ids: sentence.add_comment("# semgrex pattern |%s| did not match!" % pattern_text) if sentence_matched: matching_sentences.append(sentence) nonmatching_sentences = [sentence for sentence_idx, sentence in enumerate(doc.sentences) if sentence_idx not in matched_ids] for sentence in nonmatching_sentences: for semgrex_idx, pattern_text in enumerate(pattern_texts): sentence.add_comment("# semgrex pattern |%s| did not match!" % pattern_text) if matches_only: doc.sentences = matching_sentences elif exclude_matches: doc.sentences = nonmatching_sentences return doc def main(): """ Runs a toy example, or can run a given semgrex expression on the given input file. For example: python3 -m stanza.server.semgrex --input_file demo/semgrex_sample.conllu --matches_only to only print sentences that match the semgrex pattern --no_print_input to not print the input """ parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, default=None, help='Process this file or directory') parser.add_argument('--input_filter', type=str, default=None, help='Only process files that match this regex') parser.add_argument('semgrex', type=str, nargs="*", default=["{}=source >obj=zzz {}=target"], help="Semgrex to apply to the text. 
The default looks for sentences with objects") parser.add_argument('--semgrex_file', type=str, default=None, help="File to read semgrex patterns from - relevant in case the pattern you want to use doesn't work well on the command line, for example") parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy") parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output - gets kind of noisy") parser.add_argument('--matches_only', action='store_true', default=True, help="Only print the matching sentences") parser.add_argument('--no_matches_only', dest='matches_only', action='store_false', help="Only print the matching sentences") parser.add_argument('--exclude_matches', action='store_true', default=False, help="Only print the NON-matching sentences") parser.add_argument('--enhanced', action='store_true', default=False, help='Use the enhanced dependencies instead of the basic') parser.add_argument('--no_combined_doc', dest='combined_doc', action='store_false', default=True, help='By default, combine all the input docs into one big document. 
Allows for easier secondary processing like sorting') args = parser.parse_args() if args.semgrex_file: with open(args.semgrex_file) as fin: args.semgrex = [x.strip() for x in fin.readlines()] semgrex_patterns = [] current_comments = [] for line in args.semgrex: if not line: current_comments = [] elif line.startswith("#"): current_comments.append(line[1:].strip()) else: semgrex_patterns.append(SemgrexQuery(line, current_comments)) current_comments = [] args.semgrex = semgrex_patterns if args.input: if os.path.isfile(args.input): docs = [CoNLL.conll2doc(input_file=args.input, ignore_gapping=False)] else: filenames = sorted(os.listdir(args.input)) if args.input_filter: input_filter = re.compile(args.input_filter) filenames = [x for x in filenames if input_filter.match(x)] filenames = [os.path.join(args.input, filename) for filename in filenames] filenames = [filename for filename in filenames if os.path.isfile(filename)] docs = [CoNLL.conll2doc(input_file=filename, ignore_gapping=False) for filename in filenames] else: nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse') docs = [nlp('Uro ruined modern. Fortunately, Wotc banned him.')] if args.combined_doc: sentences = [sent for doc in docs for sent in doc.sentences] docs = [docs[0]] docs[0].sentences = sentences for doc in docs: if args.print_input: print("{:C}".format(doc)) print() print("-" * 75) print() semgrex_result = process_doc(doc, *args.semgrex, enhanced=args.enhanced) doc = annotate_doc(doc, semgrex_result, args.semgrex, args.matches_only, args.exclude_matches) if len(doc.sentences) > 0: print("{:C}\n".format(doc)) if __name__ == '__main__': main() ================================================ FILE: stanza/server/ssurgeon.py ================================================ """Invokes the Java ssurgeon on a document "ssurgeon" sends text to Java CoreNLP for processing with a ssurgeon (Semantic graph SURGEON) query The main program in this file gives a very short intro to how to use it. 
""" import argparse from collections import namedtuple import copy import os import re import sys from stanza.models.common.utils import misc_to_space_after, space_after_to_misc from stanza.protobuf import SsurgeonRequest, SsurgeonResponse from stanza.server import java_protobuf_requests from stanza.utils.conll import CoNLL from stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, Word, Token, Sentence SSURGEON_JAVA = "edu.stanford.nlp.semgraph.semgrex.ssurgeon.ProcessSsurgeonRequest" SsurgeonEdit = namedtuple("SsurgeonEdit", "semgrex_pattern ssurgeon_edits ssurgeon_id notes language", defaults=[None, None, "UniversalEnglish"]) def parse_ssurgeon_edits(ssurgeon_text): ssurgeon_text = ssurgeon_text.strip() ssurgeon_blocks = re.split("\n\n+", ssurgeon_text) ssurgeon_edits = [] for idx, block in enumerate(ssurgeon_blocks): lines = block.split("\n") comments = [line[1:].strip() for line in lines if line.startswith("#")] notes = " ".join(comments) lines = [x.strip() for x in lines if x.strip() and not x.startswith("#")] if len(lines) == 0: # was a block of entirely comments continue semgrex = lines[0] ssurgeon = lines[1:] ssurgeon_edits.append(SsurgeonEdit(semgrex, ssurgeon, "%d" % (idx + 1), notes)) return ssurgeon_edits def read_ssurgeon_edits(edit_file): with open(edit_file, encoding="utf-8") as fin: return parse_ssurgeon_edits(fin.read()) def send_ssurgeon_request(request): return java_protobuf_requests.send_request(request, SsurgeonResponse, SSURGEON_JAVA) def build_request(doc, ssurgeon_edits): request = SsurgeonRequest() for ssurgeon in ssurgeon_edits: ssurgeon_proto = request.ssurgeon.add() ssurgeon_proto.semgrex = ssurgeon.semgrex_pattern for operation in ssurgeon.ssurgeon_edits: ssurgeon_proto.operation.append(operation) if ssurgeon.ssurgeon_id is not None: ssurgeon_proto.id = ssurgeon.ssurgeon_id if ssurgeon.notes is not None: ssurgeon_proto.notes = ssurgeon.notes if ssurgeon.language is not 
None: ssurgeon_proto.language = ssurgeon.language try: for sent_idx, sentence in enumerate(doc.sentences): graph = request.graph.add() word_idx = 0 for token in sentence.tokens: for word in token.words: java_protobuf_requests.add_token(graph.token, word, token) java_protobuf_requests.add_word_to_graph(graph, word, sent_idx) word_idx = word_idx + 1 except Exception as e: raise RuntimeError("Failed to process sentence {}:\n{:C}".format(sent_idx, sentence)) from e return request def build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None): ssurgeon_edit = SsurgeonEdit(semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes) return build_request(doc, [ssurgeon_edit]) def process_doc(doc, ssurgeon_edits): """ Returns the result of processing the given semgrex expression and ssurgeon edits on the stanza doc. Currently the return is a SsurgeonResponse from CoreNLP.proto """ request = build_request(doc, ssurgeon_edits) return send_ssurgeon_request(request) def process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None): request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes) return send_ssurgeon_request(request) def build_word_entry(word_index, graph_word): word_entry = { ID: word_index, TEXT: graph_word.word if graph_word.word else None, LEMMA: graph_word.lemma if graph_word.lemma else None, UPOS: graph_word.coarseTag if graph_word.coarseTag else None, XPOS: graph_word.pos if graph_word.pos else None, FEATS: java_protobuf_requests.features_to_string(graph_word.conllUFeatures), DEPS: None, NER: graph_word.ner if graph_word.ner else None, MISC: None, START_CHAR: None, # TODO: fix this? 
one problem is the text positions END_CHAR: None, # might change across all of the sentences # presumably python will complain if this conflicts # with one of the constants above "is_mwt": graph_word.isMWT, "is_first_mwt": graph_word.isFirstMWT, "mwt_text": graph_word.mwtText, "mwt_misc": graph_word.mwtMisc, } # TODO: do "before" as well word_entry[MISC] = space_after_to_misc(graph_word.after) if graph_word.conllUMisc: word_entry[MISC] = java_protobuf_requests.substitute_space_misc(graph_word.conllUMisc, word_entry[MISC]) return word_entry def convert_response_to_doc(doc, semgrex_response, add_missing_text): doc = copy.deepcopy(doc) try: for sent_idx, (sentence, ssurgeon_result) in enumerate(zip(doc.sentences, semgrex_response.result)): # EditNode is currently bugged... :/ # TODO: change this after next CoreNLP release (after 4.5.3) #if not ssurgeon_result.changed: # continue ssurgeon_graph = ssurgeon_result.graph tokens = [] token_id_to_idx = {} for graph_node, graph_word in zip(ssurgeon_graph.node, ssurgeon_graph.token): word_entry = build_word_entry(graph_node.index, graph_word) token_id_to_idx[graph_node.index] = len(tokens) tokens.append(word_entry) for root in ssurgeon_graph.root: tokens[token_id_to_idx[root]][HEAD] = 0 tokens[token_id_to_idx[root]][DEPREL] = "root" for edge in ssurgeon_graph.edge: # can't do anything about the extra dependencies for now # TODO: put them all in .deps if edge.isExtra: continue tokens[token_id_to_idx[edge.target]][HEAD] = edge.source tokens[token_id_to_idx[edge.target]][DEPREL] = edge.dep tokens.sort(key=lambda x: x[ID]) # for any MWT, produce a token_entry which represents the word range mwt_tokens = [] for word_start_idx, word in enumerate(tokens): if not word["is_first_mwt"]: if word["is_mwt"]: word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC]) mwt_tokens.append(word) continue word_end_idx = word_start_idx + 1 while word_end_idx < len(tokens) and tokens[word_end_idx]["is_mwt"] and not 
tokens[word_end_idx]["is_first_mwt"]: word_end_idx += 1 mwt_token_entry = { # the tokens don't fencepost the way lists do ID: (tokens[word_start_idx][ID], tokens[word_end_idx-1][ID]), TEXT: word["mwt_text"], NER: word[NER], # use the SpaceAfter=No (or not) from the last word in the token MISC: None, } mwt_token_entry[MISC] = java_protobuf_requests.misc_space_pieces(tokens[word_end_idx-1][MISC]) if tokens[word_end_idx-1]["mwt_misc"]: mwt_token_entry[MISC] = java_protobuf_requests.substitute_space_misc(tokens[word_end_idx-1]["mwt_misc"], mwt_token_entry[MISC]) word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC]) mwt_tokens.append(mwt_token_entry) mwt_tokens.append(word) old_comments = list(sentence.comments) sentence = Sentence(mwt_tokens, doc) token_text = [] for token_idx, token in enumerate(sentence.tokens): token_text.append(token.text) if token_idx == len(sentence.tokens) - 1: break token_text.append(token.spaces_after) sentence_text = "".join(token_text) found_text = False for comment in old_comments: if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="): sentence.add_comment("# text = " + sentence_text) found_text = True else: sentence.add_comment(comment) if not found_text and add_missing_text: sentence.add_comment("# text = " + sentence_text) doc.sentences[sent_idx] = sentence sentence.rebuild_dependencies() except Exception as e: raise RuntimeError("Ssurgeon could not process sentence {}\nSsurgeon result:\n{}\nOriginal sentence:\n{:C}".format(sent_idx, ssurgeon_result, sentence)) from e return doc class Ssurgeon(java_protobuf_requests.JavaProtobufContext): """ Ssurgeon context window This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. 
""" def __init__(self, classpath=None): super(Ssurgeon, self).__init__(classpath, SsurgeonResponse, SSURGEON_JAVA) def process(self, doc, ssurgeon_edits): """ Apply each of the ssurgeon patterns to each of the dependency trees in doc """ request = build_request(doc, ssurgeon_edits) return self.process_request(request) def process_one_operation(self, doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None): """ Convenience method - build one operation, then apply it """ request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes) return self.process_request(request) SAMPLE_DOC = """ # sent_id = 271 # text = Hers is easy to clean. # previous = What did the dealer like about Alex's car? # comment = extraction/raising via "tough extraction" and clausal subject 1 Hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nsubj _ _ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _ 3 easy easy ADJ JJ Degree=Pos 0 root _ _ 4 to to PART TO _ 5 mark _ _ 5 clean clean VERB VB VerbForm=Inf 3 csubj _ SpaceAfter=No 6 . . PUNCT . _ 5 punct _ _ """ def main(): # for Windows, so that we aren't randomly printing garbage (or just failing to print) try: sys.stdout.reconfigure(encoding='utf-8') except AttributeError: # TODO: deprecate 3.6 support after the next release pass # The default semgrex detects sentences in the UD_English-Pronouns dataset which have both nsubj and csubj on the same word. 
# The default ssurgeon transforms the unwanted csubj to advcl # See https://github.com/UniversalDependencies/docs/issues/923 parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, default=None, help="Input file / directory to process (otherwise will process a sample text)") parser.add_argument('--output', type=str, default=None, help="Output location (otherwise will write back to the input directory)") parser.add_argument('--stdout', action='store_true', default=False, help='Output to stdout') parser.add_argument('--input_filter', type=str, default=".*[.]conllu", help="If processing a directory, only process files from --input that match this filter - regex, not shell filter. Default: %(default)s") parser.add_argument('--no_input_filter', action='store_const', const=None, dest="input_filter", help="Remove the default input filename filter") parser.add_argument('--edit_file', type=str, default=None, help="File to get semgrex and ssurgeon rules from") parser.add_argument('--semgrex', type=str, default="{}=source >nsubj {} >csubj=bad {}", help="Semgrex to apply to the text. A default detects words which have both an nsubj and a csubj. Default: %(default)s") parser.add_argument('ssurgeon', type=str, default=["relabelNamedEdge -edge bad -reln advcl"], nargs="*", help="Ssurgeon edits to apply based on the Semgrex. Can have multiple edits in a row. A default exists to transform csubj into advcl. Default: %(default)s") parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy. Default: %(default)s") parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output - gets kind of noisy") parser.add_argument('--no_add_missing_text', dest='add_missing_text', action='store_false', help="By default, the tool will add a #text comment if one does not exist. 
This leaves that blank") args = parser.parse_args() if args.edit_file: ssurgeon_edits = read_ssurgeon_edits(args.edit_file) else: ssurgeon_edits = [SsurgeonEdit(args.semgrex, args.ssurgeon)] if args.input: if os.path.isfile(args.input): docs = [CoNLL.conll2doc(input_file=args.input)] if args.output is None: outputs = [args.input] else: # TODO: could check if --output is a directory outputs = [args.output] input_output = zip(docs, outputs) else: if not args.output: args.output = args.input if not os.path.exists(args.output): os.makedirs(args.output) def read_docs(): for doc_filename in os.listdir(args.input): if args.input_filter: if not re.match(args.input_filter, doc_filename): continue doc_path = os.path.join(args.input, doc_filename) if not os.path.isfile(doc_path): continue output_path = os.path.join(args.output, doc_filename) print("Processing %s to %s" % (doc_path, output_path)) yield CoNLL.conll2doc(input_file=doc_path), output_path input_output = read_docs() else: docs = [CoNLL.conll2doc(input_str=SAMPLE_DOC)] outputs = [None] input_output = zip(docs, outputs) args.stdout = True for doc, output in input_output: if args.print_input: print("{:C}".format(doc)) ssurgeon_request = build_request(doc, ssurgeon_edits) ssurgeon_response = send_ssurgeon_request(ssurgeon_request) updated_doc = convert_response_to_doc(doc, ssurgeon_response, args.add_missing_text) if output is not None: with open(output, "w", encoding="utf-8") as fout: fout.write("{:C}\n\n".format(updated_doc)) if args.stdout: print("{:C}\n".format(updated_doc)) if __name__ == '__main__': main() ================================================ FILE: stanza/server/tokensregex.py ================================================ """Invokes the Java tokensregex on a document This operates tokensregex on docs processed with stanza models. https://nlp.stanford.edu/software/tokensregex.html A minimal example is the main method of this module. 
""" import stanza from stanza.protobuf import TokensRegexRequest, TokensRegexResponse from stanza.server.java_protobuf_requests import send_request, add_sentence def send_tokensregex_request(request): return send_request(request, TokensRegexResponse, "edu.stanford.nlp.ling.tokensregex.ProcessTokensRegexRequest") def process_doc(doc, *patterns): request = TokensRegexRequest() for pattern in patterns: request.pattern.append(pattern) request_doc = request.doc request_doc.text = doc.text num_tokens = 0 for sentence in doc.sentences: add_sentence(request_doc.sentence, sentence, num_tokens) num_tokens = num_tokens + sum(len(token.words) for token in sentence.tokens) return send_tokensregex_request(request) def main(): #nlp = stanza.Pipeline('en', # processors='tokenize,pos,lemma,ner') nlp = stanza.Pipeline('en', processors='tokenize') doc = nlp('Uro ruined modern. Fortunately, Wotc banned him') print(process_doc(doc, "him", "ruined")) if __name__ == '__main__': main() ================================================ FILE: stanza/server/tsurgeon.py ================================================ """Invokes the Java tsurgeon on a list of trees Included with CoreNLP is a mechanism for modifying trees based on existing patterns within a tree. The patterns are found using tregex: https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/trees/tregex/TregexPattern.html The modifications are then performed using tsurgeon: https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/trees/tregex/tsurgeon/Tsurgeon.html This module accepts Tree objects as produced by the conparser and returns the modified trees that result from one or more tsurgeon operations. 
""" from stanza.models.constituency import tree_reader from stanza.models.constituency.parse_tree import Tree from stanza.protobuf import TsurgeonRequest, TsurgeonResponse from stanza.server.java_protobuf_requests import send_request, build_tree, from_tree, JavaProtobufContext TSURGEON_JAVA = "edu.stanford.nlp.trees.tregex.tsurgeon.ProcessTsurgeonRequest" def send_tsurgeon_request(request): return send_request(request, TsurgeonResponse, TSURGEON_JAVA) def build_request(trees, operations): """ Build the TsurgeonRequest object trees: a list of trees operations: a list of (tregex, tsurgeon, tsurgeon, ...) """ if isinstance(trees, Tree): trees = (trees,) request = TsurgeonRequest() for tree in trees: request.trees.append(build_tree(tree, 0.0)) if all(isinstance(x, str) for x in operations): operations = (operations,) for operation in operations: if len(operation) == 1: raise ValueError("Expected [tregex, tsurgeon, ...] but just got a tregex") operation_request = request.operations.add() operation_request.tregex = operation[0] for tsurgeon in operation[1:]: operation_request.tsurgeon.append(tsurgeon) return request def process_trees(trees, *operations): """ Returns the result of processing the given tsurgeon operations on the given trees Returns a list of modified trees, eg, the result is already processed """ request = build_request(trees, operations) result = send_tsurgeon_request(request) return [from_tree(t)[0] for t in result.trees] class Tsurgeon(JavaProtobufContext): """ Tsurgeon context window This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. 
""" def __init__(self, classpath=None): super(Tsurgeon, self).__init__(classpath, TsurgeonResponse, TSURGEON_JAVA) def process(self, trees, *operations): request = build_request(trees, operations) result = self.process_request(request) return [from_tree(t)[0] for t in result.trees] def main(): """ A small demonstration of a tsurgeon operation """ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) tregex = "WP=wp" tsurgeon = "relabel wp WWWPPP" result = process_trees(trees, (tregex, tsurgeon)) print(result) if __name__ == '__main__': main() ================================================ FILE: stanza/server/ud_enhancer.py ================================================ import stanza from stanza.protobuf import DependencyEnhancerRequest, Document, Language from stanza.server.java_protobuf_requests import send_request, add_sentence, JavaProtobufContext ENHANCER_JAVA = "edu.stanford.nlp.trees.ud.ProcessUniversalEnhancerRequest" def build_enhancer_request(doc, language, pronouns_pattern): if bool(language) == bool(pronouns_pattern): raise ValueError("Should set exactly one of language and pronouns_pattern") request = DependencyEnhancerRequest() if pronouns_pattern: request.setRelativePronouns(pronouns_pattern) elif language.lower() in ("en", "english"): request.language = Language.UniversalEnglish elif language.lower() in ("zh", "zh-hans", "chinese"): request.language = Language.UniversalChinese else: raise ValueError("Sorry, but language " + language + " is not supported yet. 
Either set a pronouns pattern or file an issue at https://stanfordnlp.github.io/stanza suggesting a mechanism for converting this language") request_doc = request.document request_doc.text = doc.text num_tokens = 0 for sent_idx, sentence in enumerate(doc.sentences): request_sentence = add_sentence(request_doc.sentence, sentence, num_tokens) num_tokens = num_tokens + sum(len(token.words) for token in sentence.tokens) graph = request_sentence.basicDependencies nodes = [] word_index = 0 for token in sentence.tokens: for word in token.words: # TODO: refactor with the bit in java_protobuf_requests word_index = word_index + 1 node = graph.node.add() node.sentenceIndex = sent_idx node.index = word_index if word.head != 0: edge = graph.edge.add() edge.source = word.head edge.target = word_index edge.dep = word.deprel return request def process_doc(doc, language=None, pronouns_pattern=None): request = build_enhancer_request(doc, language, pronouns_pattern) return send_request(request, Document, ENHANCER_JAVA) class UniversalEnhancer(JavaProtobufContext): """ UniversalEnhancer context window This is a context window which keeps a process open. Should allow for multiple requests without launching new java processes each time. 
""" def __init__(self, language=None, pronouns_pattern=None, classpath=None): super(UniversalEnhancer, self).__init__(classpath, Document, ENHANCER_JAVA) if bool(language) == bool(pronouns_pattern): raise ValueError("Should set exactly one of language and pronouns_pattern") self.language = language self.pronouns_pattern = pronouns_pattern def process(self, doc): request = build_enhancer_request(doc, self.language, self.pronouns_pattern) return self.process_request(request) def main(): nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse') with UniversalEnhancer(language="en") as enhancer: doc = nlp("This is the car that I bought") result = enhancer.process(doc) print(result.sentence[0].enhancedDependencies) if __name__ == '__main__': main() ================================================ FILE: stanza/tests/__init__.py ================================================ """ Utilities for testing """ import os import re from platformdirs import user_cache_dir from stanza import __resources_version__ # Environment Variables # set this to specify working directory of tests TEST_HOME_VAR = 'STANZA_TEST_HOME' # Global Variables TEST_DIR_BASE_NAME = 'stanza_test' TEST_WORKING_DIR = os.getenv(TEST_HOME_VAR, None) if not TEST_WORKING_DIR: TEST_WORKING_DIR = user_cache_dir(TEST_DIR_BASE_NAME, 'StanfordNLP', __resources_version__) TEST_MODELS_DIR = f'{TEST_WORKING_DIR}/models' TEST_CORENLP_DIR = f'{TEST_WORKING_DIR}/corenlp_dir' # server resources SERVER_TEST_PROPS = f'{TEST_WORKING_DIR}/scripts/external_server.properties' # language resources LANGUAGE_RESOURCES = {} TOKENIZE_MODEL = 'tokenizer.pt' MWT_MODEL = 'mwt_expander.pt' POS_MODEL = 'tagger.pt' POS_PRETRAIN = 'pretrain.pt' LEMMA_MODEL = 'lemmatizer.pt' DEPPARSE_MODEL = 'parser.pt' DEPPARSE_PRETRAIN = 'pretrain.pt' MODEL_FILES = [TOKENIZE_MODEL, MWT_MODEL, POS_MODEL, POS_PRETRAIN, LEMMA_MODEL, DEPPARSE_MODEL, DEPPARSE_PRETRAIN] # English resources EN_KEY = 'en' EN_SHORTHAND = 'en_ewt' # models EN_MODELS_DIR 
= f'{TEST_WORKING_DIR}/models/{EN_SHORTHAND}_models' EN_MODEL_FILES = [f'{EN_MODELS_DIR}/{EN_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES] # French resources FR_KEY = 'fr' FR_SHORTHAND = 'fr_gsd' # regression file paths FR_TEST_IN = f'{TEST_WORKING_DIR}/in/fr_gsd.test.txt' FR_TEST_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out' FR_TEST_GOLD_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out.gold' # models FR_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{FR_SHORTHAND}_models' FR_MODEL_FILES = [f'{FR_MODELS_DIR}/{FR_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES] # Other language resources AR_SHORTHAND = 'ar_padt' DE_SHORTHAND = 'de_gsd' KK_SHORTHAND = 'kk_ktb' KO_SHORTHAND = 'ko_gsd' # utils for clean up # only allow removal of dirs/files in this approved list REMOVABLE_PATHS = ['en_ewt_models', 'en_ewt_tokenizer.pt', 'en_ewt_mwt_expander.pt', 'en_ewt_tagger.pt', 'en_ewt.pretrain.pt', 'en_ewt_lemmatizer.pt', 'en_ewt_parser.pt', 'fr_gsd_models', 'fr_gsd_tokenizer.pt', 'fr_gsd_mwt_expander.pt', 'fr_gsd_tagger.pt', 'fr_gsd.pretrain.pt', 'fr_gsd_lemmatizer.pt', 'fr_gsd_parser.pt', 'ar_padt_models', 'ar_padt_tokenizer.pt', 'ar_padt_mwt_expander.pt', 'ar_padt_tagger.pt', 'ar_padt.pretrain.pt', 'ar_padt_lemmatizer.pt', 'ar_padt_parser.pt', 'de_gsd_models', 'de_gsd_tokenizer.pt', 'de_gsd_mwt_expander.pt', 'de_gsd_tagger.pt', 'de_gsd.pretrain.pt', 'de_gsd_lemmatizer.pt', 'de_gsd_parser.pt', 'kk_ktb_models', 'kk_ktb_tokenizer.pt', 'kk_ktb_mwt_expander.pt', 'kk_ktb_tagger.pt', 'kk_ktb.pretrain.pt', 'kk_ktb_lemmatizer.pt', 'kk_ktb_parser.pt', 'ko_gsd_models', 'ko_gsd_tokenizer.pt', 'ko_gsd_mwt_expander.pt', 'ko_gsd_tagger.pt', 'ko_gsd.pretrain.pt', 'ko_gsd_lemmatizer.pt', 'ko_gsd_parser.pt'] def safe_rm(path_to_rm): """ Safely remove a directory of files or a file 1.) check path exists, files are files, dirs are dirs 2.) only remove things on approved list REMOVABLE_PATHS 3.) 
assert no longer exists """ # just return if path doesn't exist if not os.path.exists(path_to_rm): return # handle directory if os.path.isdir(path_to_rm): files_to_rm = [f'{path_to_rm}/{fname}' for fname in os.listdir(path_to_rm)] dir_to_rm = path_to_rm else: files_to_rm = [path_to_rm] dir_to_rm = None # clear out files for file_to_rm in files_to_rm: if os.path.isfile(file_to_rm) and os.path.basename(file_to_rm) in REMOVABLE_PATHS: os.remove(file_to_rm) assert not os.path.exists(file_to_rm), f'Error removing: {file_to_rm}' # clear out directory if dir_to_rm is not None and os.path.isdir(dir_to_rm): os.rmdir(dir_to_rm) assert not os.path.exists(dir_to_rm), f'Error removing: {dir_to_rm}' def compare_ignoring_whitespace(predicted, expected): predicted = re.sub('[ \t]+', ' ', predicted.strip()) predicted = re.sub('\r\n', '\n', predicted) expected = re.sub('[ \t]+', ' ', expected.strip()) expected = re.sub('\r\n', '\n', expected) assert predicted == expected ================================================ FILE: stanza/tests/classifiers/__init__.py ================================================ ================================================ FILE: stanza/tests/classifiers/test_classifier.py ================================================ import glob import os import pytest import numpy as np import torch import stanza import stanza.models.classifier as classifier import stanza.models.classifiers.data as data from stanza.models.classifiers.trainer import Trainer from stanza.models.common import pretrain from stanza.models.common import utils from stanza.tests import TEST_MODELS_DIR from stanza.tests.classifiers.test_data import train_file, dev_file, test_file, DATASET, SENTENCES pytestmark = [pytest.mark.pipeline, pytest.mark.travis] EMB_DIM = 5 @pytest.fixture(scope="module") def fake_embeddings(tmp_path_factory): """ will return a path to a fake embeddings file with the words in SENTENCES """ # could set np random seed here words = sorted(set([x.lower() for y in 
SENTENCES for x in y])) words = words[:-1] embedding_dir = tmp_path_factory.mktemp("data") embedding_txt = embedding_dir / "embedding.txt" embedding_pt = embedding_dir / "embedding.pt" embedding = np.random.random((len(words), EMB_DIM)) with open(embedding_txt, "w", encoding="utf-8") as fout: for word, emb in zip(words, embedding): fout.write(word) fout.write("\t") fout.write("\t".join(str(x) for x in emb)) fout.write("\n") pt = pretrain.Pretrain(str(embedding_pt), str(embedding_txt)) pt.load() assert os.path.exists(embedding_pt) return embedding_pt class TestClassifier: def build_model(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None): """ Build a model to be used by one of the later tests """ save_dir = str(tmp_path / "classifier") save_name = "model.pt" args = ["--save_dir", save_dir, "--save_name", save_name, "--wordvec_pretrain_file", str(fake_embeddings), "--filter_channels", "20", "--fc_shapes", "20,10", "--train_file", str(train_file), "--dev_file", str(dev_file), "--max_epochs", "2", "--batch_size", "60"] if extra_args is not None: args = args + extra_args args = classifier.parse_args(args) train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len) if checkpoint_file: trainer = Trainer.load(checkpoint_file, args, load_optimizer=True) else: trainer = Trainer.build_new_model(args, train_set) return trainer, train_set, args def run_training(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None): """ Iterate a couple times over a model """ trainer, train_set, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args, checkpoint_file) dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len) labels = data.dataset_labels(train_set) save_filename = os.path.join(args.save_dir, args.save_name) if checkpoint_file is None: checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, 
args.checkpoint_save_name) classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels) return trainer, save_filename, checkpoint_file def test_build_model(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test that building a basic model works """ self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"]) def test_save_load(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test that a basic model can save & load """ trainer, _, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"]) save_filename = os.path.join(args.save_dir, args.save_name) trainer.save(save_filename) args.load_name = args.save_name trainer = Trainer.load(args.load_name, args) args.load_name = save_filename trainer = Trainer.load(args.load_name, args) def test_train_basic(self, tmp_path, fake_embeddings, train_file, dev_file): self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"]) def test_train_bilstm(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test w/ and w/o bilstm variations of the classifier """ args = ["--bilstm", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) args = ["--no_bilstm"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) def test_train_maxpool_width(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test various maxpool widths Also sets --filter_channels to a multiple of 2 but not of 3 for the test to make sure the math is done correctly on a non-divisible width """ args = ["--maxpool_width", "1", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) args = ["--maxpool_width", "2", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, 
train_file, dev_file, args) args = ["--maxpool_width", "3", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) def test_train_conv_2d(self, tmp_path, fake_embeddings, train_file, dev_file): args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) def test_train_filter_channels(self, tmp_path, fake_embeddings, train_file, dev_file): args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--no_bilstm"] trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) assert trainer.model.fc_input_size == 40 args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "15,20", "--no_bilstm"] trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) # 50 = 2x15 for the 2d conv (over 5 dim embeddings) + 20 assert trainer.model.fc_input_size == 50 def test_train_bert(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test on a tiny Bert WITHOUT finetuning, which hopefully does not take up too much disk space or memory """ bert_model = "hf-internal-testing/tiny-bert" trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model]) assert os.path.exists(save_filename) saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True) # check that the bert model wasn't saved as part of the classifier assert not saved_model['params']['config']['force_bert_saved'] assert 
not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys()) def test_finetune_bert(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory """ bert_model = "hf-internal-testing/tiny-bert" trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune"]) assert os.path.exists(save_filename) saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True) # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer assert saved_model['params']['config']['force_bert_saved'] assert any(x.startswith("bert_model") for x in saved_model['params']['model'].keys()) def test_finetune_bert_layers(self, tmp_path, fake_embeddings, train_file, dev_file): """Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory, using 2 layers As an added bonus (or eager test), load the finished model and continue training from there. 
Then check that the initial model and the middle model are different, then that the middle model and final model are different """ bert_model = "hf-internal-testing/tiny-bert" trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models"]) assert os.path.exists(save_filename) save_path = os.path.split(save_filename)[0] initial_model = glob.glob(os.path.join(save_path, "*E0000*")) assert len(initial_model) == 1 initial_model = initial_model[0] initial_model = torch.load(initial_model, lambda storage, loc: storage, weights_only=True) second_model_file = glob.glob(os.path.join(save_path, "*E0002*")) assert len(second_model_file) == 1 second_model_file = second_model_file[0] second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True) for layer_idx in range(2): bert_names = [x for x in second_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." 
% layer_idx in x] assert len(bert_names) > 0 assert all(x in initial_model['params']['model'] and x in second_model['params']['model'] for x in bert_names) assert not all(torch.allclose(initial_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names) # put some random marker in the file to look for later, # check the continued training didn't clobber the expected file assert "asdf" not in second_model second_model["asdf"] = 1234 torch.save(second_model, second_model_file) trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file) second_model_file_redo = glob.glob(os.path.join(save_path, "*E0002*")) assert len(second_model_file_redo) == 1 assert second_model_file == second_model_file_redo[0] second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True) assert "asdf" in second_model fifth_model_file = glob.glob(os.path.join(save_path, "*E0005*")) assert len(fifth_model_file) == 1 final_model = torch.load(fifth_model_file[0], lambda storage, loc: storage, weights_only=True) for layer_idx in range(2): bert_names = [x for x in final_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." 
% layer_idx in x] assert len(bert_names) > 0 assert all(x in final_model['params']['model'] and x in second_model['params']['model'] for x in bert_names) assert not all(torch.allclose(final_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names) def test_finetune_peft(self, tmp_path, fake_embeddings, train_file, dev_file): """ Test on a tiny Bert with PEFT finetuning """ bert_model = "hf-internal-testing/tiny-bert" trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler"]) assert os.path.exists(save_filename) saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True) # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer, but only in peft form assert saved_model['params']['config']['bert_model'] == bert_model assert saved_model['params']['config']['force_bert_saved'] assert saved_model['params']['config']['use_peft'] assert not saved_model['params']['config']['has_charlm_forward'] assert not saved_model['params']['config']['has_charlm_backward'] assert len(saved_model['params']['bert_lora']) > 0 assert any(x.find(".pooler.") >= 0 for x in saved_model['params']['bert_lora']) assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora']) assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys()) # The Pipeline should load and run a PEFT trained model, # although obviously we don't expect the results to do # anything correct pipeline = stanza.Pipeline("en", download_method=None, model_dir=TEST_MODELS_DIR, processors="tokenize,sentiment", sentiment_model_path=save_filename, sentiment_pretrain_path=str(fake_embeddings)) doc = pipeline("This is a test") def test_finetune_peft_restart(self, tmp_path, fake_embeddings, train_file, dev_file): 
""" Test that if we restart training on a peft model, the peft weights change """ bert_model = "hf-internal-testing/tiny-bert" trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models"]) assert os.path.exists(save_file) saved_model = torch.load(save_file, lambda storage, loc: storage, weights_only=True) assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora']) trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file) save_path = os.path.split(save_file)[0] initial_model_file = glob.glob(os.path.join(save_path, "*E0000*")) assert len(initial_model_file) == 1 initial_model_file = initial_model_file[0] initial_model = torch.load(initial_model_file, lambda storage, loc: storage, weights_only=True) second_model_file = glob.glob(os.path.join(save_path, "*E0002*")) assert len(second_model_file) == 1 second_model_file = second_model_file[0] second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True) final_model_file = glob.glob(os.path.join(save_path, "*E0005*")) assert len(final_model_file) == 1 final_model_file = final_model_file[0] final_model = torch.load(final_model_file, lambda storage, loc: storage, weights_only=True) # params in initial_model & second_model start with "base_model.model." 
# whereas params in final_model start directly with "encoder" or "pooler" initial_lora = initial_model['params']['bert_lora'] second_lora = second_model['params']['bert_lora'] final_lora = final_model['params']['bert_lora'] for side in ("_A.", "_B."): for layer in (".0.", ".1."): initial_params = sorted([x for x in initial_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0]) second_params = sorted([x for x in second_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0]) final_params = sorted([x for x in final_lora if x.startswith("encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0]) assert len(initial_params) > 0 assert len(initial_params) == len(second_params) assert len(initial_params) == len(final_params) for x, y in zip(second_params, final_params): assert x.endswith(y) if side != "_A.": # the A tensors don't move very much, if at all assert not torch.allclose(initial_lora.get(x), second_lora.get(x)) assert not torch.allclose(second_lora.get(x), final_lora.get(y)) ================================================ FILE: stanza/tests/classifiers/test_constituency_classifier.py ================================================ import os import pytest import stanza import stanza.models.classifier as classifier import stanza.models.classifiers.data as data from stanza.models.classifiers.trainer import Trainer from stanza.tests import TEST_MODELS_DIR from stanza.tests.classifiers.test_classifier import fake_embeddings from stanza.tests.classifiers.test_data import train_file_with_trees, dev_file_with_trees from stanza.models.common import utils from stanza.tests.constituency.test_trainer import build_trainer, TREEBANK pytestmark = [pytest.mark.pipeline, pytest.mark.travis] class TestConstituencyClassifier: @pytest.fixture(scope="class") def constituency_model(self, fake_embeddings, tmp_path_factory): args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'] 
trainer = build_trainer(str(fake_embeddings), *args, treebank=TREEBANK) trainer_pt = str(tmp_path_factory.mktemp("constituency") / "constituency.pt") trainer.save(trainer_pt, save_optimizer=False) return trainer_pt def build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None): """ Build a Constituency Classifier model to be used by one of the later tests """ save_dir = str(tmp_path / "classifier") save_name = "model.pt" args = ["--save_dir", save_dir, "--save_name", save_name, "--model_type", "constituency", "--constituency_model", constituency_model, "--wordvec_pretrain_file", str(fake_embeddings), "--fc_shapes", "20,10", "--train_file", str(train_file_with_trees), "--dev_file", str(dev_file_with_trees), "--max_epochs", "2", "--batch_size", "60"] if extra_args is not None: args = args + extra_args args = classifier.parse_args(args) train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len) trainer = Trainer.build_new_model(args, train_set) return trainer, train_set, args def run_training(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None): """ Iterate a couple times over a model """ trainer, train_set, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args) dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len) labels = data.dataset_labels(train_set) save_filename = os.path.join(args.save_dir, args.save_name) checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name) classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels) return trainer, train_set, args def test_build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): """ Test that building a basic constituency-based model works """ 
self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees) def test_save_load(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): """ Test that a constituency model can save & load """ trainer, _, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees) save_filename = os.path.join(args.save_dir, args.save_name) trainer.save(save_filename) args.load_name = args.save_name trainer = Trainer.load(args.load_name, args) def test_train_basic(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees) def test_train_pipeline(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): """ Test that writing out a temp model, then loading it in the pipeline is a thing that works """ trainer, _, args = self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees) save_filename = os.path.join(args.save_dir, args.save_name) assert os.path.exists(save_filename) assert os.path.exists(args.constituency_model) pipeline_args = {"lang": "en", "download_method": None, "model_dir": TEST_MODELS_DIR, "processors": "tokenize,pos,constituency,sentiment", "tokenize_pretokenized": True, "constituency_model_path": args.constituency_model, "constituency_pretrain_path": args.wordvec_pretrain_file, "constituency_backward_charlm_path": None, "constituency_forward_charlm_path": None, "sentiment_model_path": save_filename, "sentiment_pretrain_path": args.wordvec_pretrain_file, "sentiment_backward_charlm_path": None, "sentiment_forward_charlm_path": None} pipeline = stanza.Pipeline(**pipeline_args) doc = pipeline("This is a test") # since the model is random, we have no expectations for what the result actually is assert 
doc.sentences[0].sentiment is not None def test_train_all_words(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_all_words']) self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_all_words']) def test_train_top_layer(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_top_layer']) self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_top_layer']) def test_train_attn(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees): self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--no_constituency_all_words']) self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--constituency_all_words']) self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_node_attn']) ================================================ FILE: stanza/tests/classifiers/test_data.py ================================================ import json import pytest import stanza.models.classifiers.data as data from stanza.models.classifiers.utils import WVType from stanza.models.common.vocab import PAD, UNK from stanza.models.constituency.parse_tree import Tree SENTENCES = [ ["I", "hate", "the", "Opal", "banning"], ["Tell", "my", "wife", "hello"], # obviously this is the neutral result ["I", "like", "Sh'reyan", "'s", "antennae"], ] DATASET = [ 
{"sentiment": "0", "text": SENTENCES[0]}, {"sentiment": "1", "text": SENTENCES[1]}, {"sentiment": "2", "text": SENTENCES[2]}, ] TREES = [ "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))", "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))", "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))", ] DATASET_WITH_TREES = [ {"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]}, {"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]}, {"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]}, ] @pytest.fixture(scope="module") def train_file(tmp_path_factory): train_set = DATASET * 20 train_filename = tmp_path_factory.mktemp("data") / "train.json" with open(train_filename, "w", encoding="utf-8") as fout: json.dump(train_set, fout, ensure_ascii=False) return train_filename @pytest.fixture(scope="module") def dev_file(tmp_path_factory): dev_set = DATASET * 2 dev_filename = tmp_path_factory.mktemp("data") / "dev.json" with open(dev_filename, "w", encoding="utf-8") as fout: json.dump(dev_set, fout, ensure_ascii=False) return dev_filename @pytest.fixture(scope="module") def test_file(tmp_path_factory): test_set = DATASET test_filename = tmp_path_factory.mktemp("data") / "test.json" with open(test_filename, "w", encoding="utf-8") as fout: json.dump(test_set, fout, ensure_ascii=False) return test_filename @pytest.fixture(scope="module") def train_file_with_trees(tmp_path_factory): train_set = DATASET_WITH_TREES * 20 train_filename = tmp_path_factory.mktemp("data") / "train_trees.json" with open(train_filename, "w", encoding="utf-8") as fout: json.dump(train_set, fout, ensure_ascii=False) return train_filename @pytest.fixture(scope="module") def dev_file_with_trees(tmp_path_factory): dev_set = DATASET_WITH_TREES * 2 dev_filename = tmp_path_factory.mktemp("data") / "dev_trees.json" with open(dev_filename, "w", encoding="utf-8") as fout: json.dump(dev_set, fout, 
ensure_ascii=False) return dev_filename class TestClassifierData: def test_read_data(self, train_file): """ Test reading of the json format """ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) assert len(train_set) == 60 def test_read_data_with_trees(self, train_file, train_file_with_trees): """ Test reading of the json format """ train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1) assert len(train_trees_set) == 60 for idx, x in enumerate(train_trees_set): assert isinstance(x.constituency, Tree) assert str(x.constituency) == TREES[idx % len(TREES)] train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) def test_dataset_vocab(self, train_file): """ Converting a dataset to vocab should have a specific set of words along with PAD and UNK """ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) vocab = data.dataset_vocab(train_set) expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y]) assert set(vocab) == expected def test_dataset_labels(self, train_file): """ Test the extraction of labels from a dataset """ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) labels = data.dataset_labels(train_set) assert labels == ["0", "1", "2"] def test_sort_by_length(self, train_file): """ There are two unique lengths in the toy dataset """ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) sorted_dataset = data.sort_dataset_by_len(train_set) assert list(sorted_dataset.keys()) == [4, 5] assert len(sorted_dataset[4]) == len(train_set) // 3 assert len(sorted_dataset[5]) == 2 * len(train_set) // 3 def test_check_labels(self, train_file): """ Check that an exception is thrown for an unknown label """ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) labels = sorted(set([x["sentiment"] for x in DATASET])) assert len(labels) > 1 data.check_labels(labels, train_set) with pytest.raises(RuntimeError): data.check_labels(labels[:1], train_set) 
================================================ FILE: stanza/tests/classifiers/test_process_utils.py ================================================ """ A few tests of the utils module for the sentiment datasets """ import os import pytest import stanza from stanza.models.classifiers import data from stanza.models.classifiers.data import SentimentDatum from stanza.models.classifiers.utils import WVType from stanza.utils.datasets.sentiment import process_utils from stanza.tests import TEST_MODELS_DIR from stanza.tests.classifiers.test_data import train_file, dev_file, test_file def test_write_list(tmp_path, train_file): """ Test that writing a single list of items to an output file works """ train_set = data.read_dataset(train_file, WVType.OTHER, 1) dataset_file = tmp_path / "foo.json" process_utils.write_list(dataset_file, train_set) train_copy = data.read_dataset(dataset_file, WVType.OTHER, 1) assert train_copy == train_set def test_write_dataset(tmp_path, train_file, dev_file, test_file): """ Test that writing all three parts of a dataset works """ dataset = [data.read_dataset(filename, WVType.OTHER, 1) for filename in (train_file, dev_file, test_file)] process_utils.write_dataset(dataset, tmp_path, "en_test") expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json'] dataset_files = os.listdir(tmp_path) assert sorted(dataset_files) == sorted(expected_files) for filename, expected in zip(expected_files, dataset): written = data.read_dataset(tmp_path / filename, WVType.OTHER, 1) assert written == expected def test_read_snippets(tmp_path): """ Test the basic operation of the read_snippets function """ filename = tmp_path / "foo.csv" with open(filename, "w", encoding="utf-8") as fout: fout.write("FOO\tThis is a test\thappy\n") fout.write("FOO\tThis is a second sentence\tsad\n") nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) mapping = {"happy": 0, "sad": 1} snippets = 
process_utils.read_snippets(filename, 2, 1, "en", mapping, nlp=nlp) assert len(snippets) == 2 assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']), SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])] def test_read_snippets_two_columns(tmp_path): """ Test what happens when multiple columns are combined for the sentiment value """ filename = tmp_path / "foo.csv" with open(filename, "w", encoding="utf-8") as fout: fout.write("FOO\tThis is a test\thappy\tfoo\n") fout.write("FOO\tThis is a second sentence\tsad\tbar\n") fout.write("FOO\tThis is a third sentence\tsad\tfoo\n") nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) mapping = {("happy", "foo"): 0, ("sad", "bar"): 1, ("sad", "foo"): 2} snippets = process_utils.read_snippets(filename, (2,3), 1, "en", mapping, nlp=nlp) assert len(snippets) == 3 assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']), SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']), SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])] ================================================ FILE: stanza/tests/common/__init__.py ================================================ ================================================ FILE: stanza/tests/common/test_bert_embedding.py ================================================ import pytest import torch from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings pytestmark = [pytest.mark.travis, pytest.mark.pipeline] BERT_MODEL = "hf-internal-testing/tiny-bert" @pytest.fixture(scope="module") def tiny_bert(): m, t = load_bert(BERT_MODEL) return m, t def test_load_bert(tiny_bert): """ Empty method that just tests loading the bert """ m, t = tiny_bert def test_run_bert(tiny_bert): m, t = tiny_bert device = next(m.parameters()).device extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "a", "test"]], device, True) def 
test_run_bert_empty_word(tiny_bert): m, t = tiny_bert device = next(m.parameters()).device foo = extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "-", "a", "test"]], device, True) bar = extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "", "a", "test"]], device, True) assert len(foo) == 1 assert torch.allclose(foo[0], bar[0]) ================================================ FILE: stanza/tests/common/test_char_model.py ================================================ """ Currently tests a few configurations of files for creating a charlm vocab Also has a skeleton test of loading & saving a charlm """ from collections import Counter import glob import lzma import os import tempfile import pytest from stanza.models import charlm from stanza.models.common import char_model from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.travis, pytest.mark.pipeline] fake_text_1 = """ Unban mox opal! I hate watching Peppa Pig """ fake_text_2 = """ This is plastic cheese """ class TestCharModel: def test_single_file_vocab(self): with tempfile.TemporaryDirectory() as tempdir: sample_file = os.path.join(tempdir, "text.txt") with open(sample_file, "w", encoding="utf-8") as fout: fout.write(fake_text_1) vocab = char_model.build_charlm_vocab(sample_file) for i in fake_text_1: assert i in vocab assert "Q" not in vocab def test_single_file_xz_vocab(self): with tempfile.TemporaryDirectory() as tempdir: sample_file = os.path.join(tempdir, "text.txt.xz") with lzma.open(sample_file, "wt", encoding="utf-8") as fout: fout.write(fake_text_1) vocab = char_model.build_charlm_vocab(sample_file) for i in fake_text_1: assert i in vocab assert "Q" not in vocab def test_single_file_dir_vocab(self): with tempfile.TemporaryDirectory() as tempdir: sample_file = os.path.join(tempdir, "text.txt") with open(sample_file, "w", encoding="utf-8") as fout: fout.write(fake_text_1) vocab = char_model.build_charlm_vocab(tempdir) for i in fake_text_1: assert i in vocab assert "Q" not in 
vocab def test_multiple_files_vocab(self): with tempfile.TemporaryDirectory() as tempdir: sample_file = os.path.join(tempdir, "t1.txt") with open(sample_file, "w", encoding="utf-8") as fout: fout.write(fake_text_1) sample_file = os.path.join(tempdir, "t2.txt.xz") with lzma.open(sample_file, "wt", encoding="utf-8") as fout: fout.write(fake_text_2) vocab = char_model.build_charlm_vocab(tempdir) for i in fake_text_1: assert i in vocab for i in fake_text_2: assert i in vocab assert "Q" not in vocab def test_cutoff_vocab(self): with tempfile.TemporaryDirectory() as tempdir: sample_file = os.path.join(tempdir, "t1.txt") with open(sample_file, "w", encoding="utf-8") as fout: fout.write(fake_text_1) sample_file = os.path.join(tempdir, "t2.txt.xz") with lzma.open(sample_file, "wt", encoding="utf-8") as fout: fout.write(fake_text_2) vocab = char_model.build_charlm_vocab(tempdir, cutoff=2) counts = Counter(fake_text_1) + Counter(fake_text_2) for letter, count in counts.most_common(): if count < 2: assert letter not in vocab else: assert letter in vocab def test_build_model(self): """ Test the whole thing on a small dataset for an iteration or two """ with tempfile.TemporaryDirectory() as tempdir: eval_file = os.path.join(tempdir, "en_test.dev.txt") with open(eval_file, "w", encoding="utf-8") as fout: fout.write(fake_text_1) train_file = os.path.join(tempdir, "en_test.train.txt") with open(train_file, "w", encoding="utf-8") as fout: for i in range(1000): fout.write(fake_text_1) fout.write("\n") fout.write(fake_text_2) fout.write("\n") save_name = 'en_test.forward.pt' vocab_save_name = 'en_text.vocab.pt' checkpoint_save_name = 'en_text.checkpoint.pt' args = ['--train_file', train_file, '--eval_file', eval_file, '--eval_steps', '0', # eval once per opoch '--epochs', '2', '--cutoff', '1', '--batch_size', '%d' % len(fake_text_1), '--shorthand', 'en_test', '--save_dir', tempdir, '--save_name', save_name, '--vocab_save_name', vocab_save_name, '--checkpoint_save_name', 
checkpoint_save_name] args = charlm.parse_args(args) charlm.train(args) assert os.path.exists(os.path.join(tempdir, vocab_save_name)) # test that saving & loading of the model worked assert os.path.exists(os.path.join(tempdir, save_name)) model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, save_name)) # test that saving & loading of the checkpoint worked assert os.path.exists(os.path.join(tempdir, checkpoint_save_name)) model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, checkpoint_save_name)) trainer = char_model.CharacterLanguageModelTrainer.load(args, os.path.join(tempdir, checkpoint_save_name)) assert trainer.global_step > 0 assert trainer.epoch == 2 # quick test to verify this method works with a trained model charlm.get_current_lr(trainer, args) # test loading a vocab built by the training method... vocab = charlm.load_char_vocab(os.path.join(tempdir, vocab_save_name)) trainer = char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab) # ... 
and test the get_current_lr for an untrained model as well # this test is super "eager" assert charlm.get_current_lr(trainer, args) == args['lr0'] @pytest.fixture(scope="class") def english_forward(self): # eg, stanza_test/models/en/forward_charlm/1billion.pt models_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "*") models = glob.glob(models_path) # we expect at least one English model downloaded for the tests assert len(models) >= 1 model_file = models[0] return char_model.CharacterLanguageModel.load(model_file) @pytest.fixture(scope="class") def english_backward(self): # eg, stanza_test/models/en/forward_charlm/1billion.pt models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*") models = glob.glob(models_path) # we expect at least one English model downloaded for the tests assert len(models) >= 1 model_file = models[0] return char_model.CharacterLanguageModel.load(model_file) def test_load_model(self, english_forward, english_backward): """ Check that basic loading functions work """ assert english_forward.is_forward_lm assert not english_backward.is_forward_lm def test_save_load_model(self, english_forward, english_backward): """ Load, save, and load again """ with tempfile.TemporaryDirectory() as tempdir: for model in (english_forward, english_backward): save_file = os.path.join(tempdir, "resaved", "charlm.pt") model.save(save_file) reloaded = char_model.CharacterLanguageModel.load(save_file) assert model.is_forward_lm == reloaded.is_forward_lm ================================================ FILE: stanza/tests/common/test_chuliu_edmonds.py ================================================ """ Test some use cases of the chuliu_edmonds algorithm (currently just the tarjan implementation) """ import numpy as np import pytest from stanza.models.common.chuliu_edmonds import tarjan pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_tarjan_basic(): simple = np.array([0, 4, 4, 4, 0]) result = tarjan(simple) assert result == 
[] simple = np.array([0, 2, 0, 4, 2, 2]) result = tarjan(simple) assert result == [] def test_tarjan_cycle(): cycle_graph = np.array([0, 3, 1, 2]) result = tarjan(cycle_graph) expected = np.array([False, True, True, True]) assert len(result) == 1 np.testing.assert_array_equal(result[0], expected) cycle_graph = np.array([0, 3, 1, 2, 5, 6, 4]) result = tarjan(cycle_graph) assert len(result) == 2 expected = [np.array([False, True, True, True, False, False, False]), np.array([False, False, False, False, True, True, True])] for r, e in zip(result, expected): np.testing.assert_array_equal(r, e) ================================================ FILE: stanza/tests/common/test_common_data.py ================================================ import pytest import stanza from stanza.tests import * from stanza.models.common.data import get_augment_ratio, augment_punct pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_augment_ratio(): data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] should_augment = lambda x: x >= 3 can_augment = lambda x: x >= 4 # check that zero is returned if no augmentation is needed # which will be the case since 2 are already satisfactory assert get_augment_ratio(data, should_augment, can_augment, desired_ratio=0.1) == 0.0 # this should throw an error with pytest.raises(AssertionError): get_augment_ratio(data, can_augment, should_augment) # with a desired ratio of 0.4, # there are already 2 that don't need augmenting # and 7 that are eligible to be augmented # so 2/7 will need to be augmented assert get_augment_ratio(data, should_augment, can_augment, desired_ratio=0.4) == pytest.approx(2/7) def test_augment_punct(): data = [["Simple", "test", "."]] should_augment = lambda x: x[-1] == "." 
can_augment = should_augment new_data = augment_punct(data, 1.0, should_augment, can_augment) assert new_data == [["Simple", "test"]] ================================================ FILE: stanza/tests/common/test_confusion.py ================================================ """ Test a couple simple confusion matrices and output formats """ from collections import defaultdict import pytest from stanza.utils.confusion import format_confusion, confusion_to_f1, confusion_to_macro_f1, confusion_to_weighted_f1 pytestmark = [pytest.mark.travis, pytest.mark.pipeline] @pytest.fixture def simple_confusion(): confusion = defaultdict(lambda: defaultdict(int)) confusion["B-ORG"]["B-ORG"] = 1 confusion["B-ORG"]["B-PER"] = 1 confusion["E-ORG"]["E-ORG"] = 1 confusion["E-ORG"]["E-PER"] = 1 confusion["O"]["O"] = 4 return confusion @pytest.fixture def short_confusion(): """ Same thing, but with a short name. This should not be sorted by entity type """ confusion = defaultdict(lambda: defaultdict(int)) confusion["A"]["B-ORG"] = 1 confusion["B-ORG"]["B-PER"] = 1 confusion["E-ORG"]["E-ORG"] = 1 confusion["E-ORG"]["E-PER"] = 1 confusion["O"]["O"] = 4 return confusion EXPECTED_SIMPLE_OUTPUT = """ t\\p O B-ORG E-ORG B-PER E-PER O 4 0 0 0 0 B-ORG 0 1 0 1 0 E-ORG 0 0 1 0 1 B-PER 0 0 0 0 0 E-PER 0 0 0 0 0 """[1:-1] # don't want to strip EXPECTED_SHORT_OUTPUT = """ t\\p O A B-ORG B-PER E-ORG E-PER O 4 0 0 0 0 0 A 0 0 1 0 0 0 B-ORG 0 0 0 1 0 0 B-PER 0 0 0 0 0 0 E-ORG 0 0 0 0 1 1 E-PER 0 0 0 0 0 0 """[1:-1] EXPECTED_HIDE_BLANK_SHORT_OUTPUT = """ t\\p O B-ORG E-ORG B-PER E-PER O 4 0 0 0 0 A 0 1 0 0 0 B-ORG 0 0 0 1 0 E-ORG 0 0 1 0 1 """[1:-1] def test_simple_output(simple_confusion): assert EXPECTED_SIMPLE_OUTPUT == format_confusion(simple_confusion) def test_short_output(short_confusion): assert EXPECTED_SHORT_OUTPUT == format_confusion(short_confusion) def test_hide_blank_short_output(short_confusion): assert EXPECTED_HIDE_BLANK_SHORT_OUTPUT == format_confusion(short_confusion, hide_blank=True) 
def test_macro_f1(simple_confusion, short_confusion): assert confusion_to_macro_f1(simple_confusion) == pytest.approx(0.466666666666) assert confusion_to_macro_f1(short_confusion) == pytest.approx(0.277777777777) def test_weighted_f1(simple_confusion, short_confusion): assert confusion_to_weighted_f1(simple_confusion) == pytest.approx(0.83333333) assert confusion_to_weighted_f1(short_confusion) == pytest.approx(0.66666666) assert confusion_to_weighted_f1(simple_confusion, exclude=["O"]) == pytest.approx(0.66666666) assert confusion_to_weighted_f1(short_confusion, exclude=["O"]) == pytest.approx(0.33333333) ================================================ FILE: stanza/tests/common/test_constant.py ================================================ """ Test the conversion to lcodes and splitting of dataset names """ import tempfile import pytest import stanza from stanza.models.common.constant import treebank_to_short_name, lang_to_langcode, is_right_to_left, two_to_three_letters, langlower2lcode from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_treebank(): """ Test the entire treebank name conversion """ # conversion of a UD_ name assert "hi_hdtb" == treebank_to_short_name("UD_Hindi-HDTB") # conversion of names without UD assert "hi_fire2013" == treebank_to_short_name("Hindi-fire2013") assert "hi_fire2013" == treebank_to_short_name("Hindi-Fire2013") assert "hi_fire2013" == treebank_to_short_name("Hindi-FIRE2013") # already short names are generally preserved assert "hi_fire2013" == treebank_to_short_name("hi-fire2013") assert "hi_fire2013" == treebank_to_short_name("hi_fire2013") # a special case assert "zh-hant_pud" == treebank_to_short_name("UD_Chinese-PUD") # a special case already converted once assert "zh-hant_pud" == treebank_to_short_name("zh-hant_pud") assert "zh-hant_pud" == treebank_to_short_name("zh-hant-pud") assert "zh-hans_gsdsimp" == treebank_to_short_name("zh-hans_gsdsimp") assert "wo_masakhane" == 
treebank_to_short_name("wo_masakhane") assert "wo_masakhane" == treebank_to_short_name("wol_masakhane") assert "wo_masakhane" == treebank_to_short_name("Wol_masakhane") assert "wo_masakhane" == treebank_to_short_name("wolof_masakhane") assert "wo_masakhane" == treebank_to_short_name("Wolof_masakhane") def test_lang_to_langcode(): assert "hi" == lang_to_langcode("Hindi") assert "hi" == lang_to_langcode("HINDI") assert "hi" == lang_to_langcode("hindi") assert "hi" == lang_to_langcode("HI") assert "hi" == lang_to_langcode("hi") def test_right_to_left(): assert is_right_to_left("ar") assert is_right_to_left("Arabic") assert not is_right_to_left("en") assert not is_right_to_left("English") def test_two_to_three(): assert lang_to_langcode("Wolof") == "wo" assert lang_to_langcode("wol") == "wo" assert "wo" in two_to_three_letters assert two_to_three_letters["wo"] == "wol" def test_langlower(): assert lang_to_langcode("WOLOF") == "wo" assert lang_to_langcode("nOrWeGiAn") == "nb" assert "soj" == langlower2lcode["soi"] assert "soj" == langlower2lcode["sohi"] ================================================ FILE: stanza/tests/common/test_data_conversion.py ================================================ """ Basic tests of the data conversion """ import io import pytest import tempfile from zipfile import ZipFile import stanza from stanza.utils.conll import CoNLL from stanza.models.common.doc import Document from stanza.tests import * pytestmark = pytest.mark.pipeline # data for testing CONLL = [[['1', 'Nous', 'il', 'PRON', '_', 'Number=Plur|Person=1|PronType=Prs', '3', 'nsubj', '_', 'start_char=0|end_char=4'], ['2', 'avons', 'avoir', 'AUX', '_', 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', '3', 'aux:tense', '_', 'start_char=5|end_char=10'], ['3', 'atteint', 'atteindre', 'VERB', '_', 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', '0', 'root', '_', 'start_char=11|end_char=18'], ['4', 'la', 'le', 'DET', '_', 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', 
'5', 'det', '_', 'start_char=19|end_char=21'], ['5', 'fin', 'fin', 'NOUN', '_', 'Gender=Fem|Number=Sing', '3', 'obj', '_', 'start_char=22|end_char=25'], ['6-7', 'du', '_', '_', '_', '_', '_', '_', '_', 'start_char=26|end_char=28'], ['6', 'de', 'de', 'ADP', '_', '_', '8', 'case', '_', '_'], ['7', 'le', 'le', 'DET', '_', 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', '8', 'det', '_', '_'], ['8', 'sentier', 'sentier', 'NOUN', '_', 'Gender=Masc|Number=Sing', '5', 'nmod', '_', 'start_char=29|end_char=36'], ['9', '.', '.', 'PUNCT', '_', '_', '3', 'punct', '_', 'start_char=36|end_char=37']]] DICT = [[{'id': (1,), 'text': 'Nous', 'lemma': 'il', 'upos': 'PRON', 'feats': 'Number=Plur|Person=1|PronType=Prs', 'head': 3, 'deprel': 'nsubj', 'misc': 'start_char=0|end_char=4'}, {'id': (2,), 'text': 'avons', 'lemma': 'avoir', 'upos': 'AUX', 'feats': 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', 'head': 3, 'deprel': 'aux:tense', 'misc': 'start_char=5|end_char=10'}, {'id': (3,), 'text': 'atteint', 'lemma': 'atteindre', 'upos': 'VERB', 'feats': 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', 'head': 0, 'deprel': 'root', 'misc': 'start_char=11|end_char=18'}, {'id': (4,), 'text': 'la', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', 'head': 5, 'deprel': 'det', 'misc': 'start_char=19|end_char=21'}, {'id': (5,), 'text': 'fin', 'lemma': 'fin', 'upos': 'NOUN', 'feats': 'Gender=Fem|Number=Sing', 'head': 3, 'deprel': 'obj', 'misc': 'start_char=22|end_char=25'}, {'id': (6, 7), 'text': 'du', 'misc': 'start_char=26|end_char=28'}, {'id': (6,), 'text': 'de', 'lemma': 'de', 'upos': 'ADP', 'head': 8, 'deprel': 'case'}, {'id': (7,), 'text': 'le', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', 'head': 8, 'deprel': 'det'}, {'id': (8,), 'text': 'sentier', 'lemma': 'sentier', 'upos': 'NOUN', 'feats': 'Gender=Masc|Number=Sing', 'head': 5, 'deprel': 'nmod', 'misc': 'start_char=29|end_char=36'}, 
{'id': (9,), 'text': '.', 'lemma': '.', 'upos': 'PUNCT', 'head': 3, 'deprel': 'punct', 'misc': 'start_char=36|end_char=37'}]] def test_conll_to_dict(): dicts, empty = CoNLL.convert_conll(CONLL) assert dicts == DICT assert len(dicts) == len(empty) assert all(len(x) == 0 for x in empty) def test_dict_to_conll(): document = Document(DICT) # :c = no comments conll = [[sentence.split("\t") for sentence in doc.split("\n")] for doc in "{:c}".format(document).split("\n\n")] assert conll == CONLL def test_dict_to_doc_and_doc_to_dict(): """ Test the conversion from raw dict to Document and back This code path will first turn start_char|end_char into start_char & end_char fields in the Document That version to a dict will have separate fields for each of those Finally, the conversion from that dict to a list of conll entries should convert that back to misc """ document = Document(DICT) dicts = document.to_dict() document = Document(dicts) conll = [[sentence.split("\t") for sentence in doc.split("\n")] for doc in "{:c}".format(document).split("\n\n")] assert conll == CONLL # sample is two sentences long so that the tests check multiple sentences RUSSIAN_SAMPLE=""" # sent_id = yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253 # genre = review # text = Как- то слишком мало цветов получают актёры после спектакля. 1 Как как-то ADV _ Degree=Pos|PronType=Ind 7 advmod _ SpaceAfter=No 2 - - PUNCT _ _ 3 punct _ _ 3 то то PART _ _ 1 list _ deprel=list:goeswith 4 слишком слишком ADV _ Degree=Pos 5 advmod _ _ 5 мало мало ADV _ Degree=Pos 6 advmod _ _ 6 цветов цветок NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Plur 7 obj _ _ 7 получают получать VERB _ Aspect=Imp|Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ _ 8 актёры актер NOUN _ Animacy=Anim|Case=Nom|Gender=Masc|Number=Plur 7 nsubj _ _ 9 после после ADP _ _ 10 case _ _ 10 спектакля спектакль NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Sing 7 obl _ SpaceAfter=No 11 . . 
PUNCT _ _ 7 punct _ _ # sent_id = 4 # genre = social # text = В женщине важна верность, а не красота. 1 В в ADP _ _ 2 case _ _ 2 женщине женщина NOUN _ Animacy=Anim|Case=Loc|Gender=Fem|Number=Sing 3 obl _ _ 3 важна важный ADJ _ Degree=Pos|Gender=Fem|Number=Sing|Variant=Short 0 root _ _ 4 верность верность NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 3 nsubj _ SpaceAfter=No 5 , , PUNCT _ _ 8 punct _ _ 6 а а CCONJ _ _ 8 cc _ _ 7 не не PART _ Polarity=Neg 8 advmod _ _ 8 красота красота NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 4 conj _ SpaceAfter=No 9 . . PUNCT _ _ 3 punct _ _ """.strip() RUSSIAN_TEXT = ["Как- то слишком мало цветов получают актёры после спектакля.", "В женщине важна верность, а не красота."] RUSSIAN_IDS = ["yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253", "4"] def check_russian_doc(doc): """ Refactored the test for the Russian doc so we can use it to test various file methods """ lines = RUSSIAN_SAMPLE.split("\n") assert len(doc.sentences) == 2 assert lines[0] == doc.sentences[0].comments[0] assert lines[1] == doc.sentences[0].comments[1] assert lines[2] == doc.sentences[0].comments[2] for sent_idx, (expected_text, expected_id, sentence) in enumerate(zip(RUSSIAN_TEXT, RUSSIAN_IDS, doc.sentences)): assert expected_text == sentence.text assert expected_id == sentence.sent_id assert sent_idx == sentence.index assert len(sentence.comments) == 3 assert not sentence.has_enhanced_dependencies() sentences = "{:C}".format(doc) sentences = sentences.split("\n\n") assert len(sentences) == 2 sentence = sentences[0].split("\n") assert len(sentence) == 14 assert lines[0] == sentence[0] assert lines[1] == sentence[1] assert lines[2] == sentence[2] # assert that the weird deprel=list:goeswith was properly handled assert doc.sentences[0].words[2].head == 1 assert doc.sentences[0].words[2].deprel == "list:goeswith" def test_write_russian_doc(tmp_path): """ Specifically test the write_doc2conll method """ filename = tmp_path / "russian.conll" 
doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE) check_russian_doc(doc) CoNLL.write_doc2conll(doc, filename) with open(filename, encoding="utf-8") as fin: text = fin.read() # the conll docs have to end with \n\n assert text.endswith("\n\n") # but to compare against the original, strip off the whitespace text = text.strip() # we skip the first sentence because the "deprel=list:goeswith" is weird # note that the deprel itself is checked in check_russian_doc text = text[text.find("# sent_id = 4"):] sample = RUSSIAN_SAMPLE[RUSSIAN_SAMPLE.find("# sent_id = 4"):] assert text == sample doc2 = CoNLL.conll2doc(filename) check_russian_doc(doc2) # random sentence from EN_Pronouns ENGLISH_SAMPLE = """ # newdoc # sent_id = 1 # text = It is hers. # previous = Which person owns this? # comment = copular subject 1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 3 nsubj _ _ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _ 3 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No 4 . . PUNCT . 
_ 3 punct _ _ """.strip() def test_write_to_io(): doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE) output = io.StringIO() CoNLL.write_doc2conll(doc, output) output_value = output.getvalue() assert output_value.endswith("\n\n") assert output_value.strip() == ENGLISH_SAMPLE def test_write_doc2conll_append(tmp_path): doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE) filename = tmp_path / "english.conll" CoNLL.write_doc2conll(doc, filename) CoNLL.write_doc2conll(doc, filename, mode="a") with open(filename) as fin: text = fin.read() expected = ENGLISH_SAMPLE + "\n\n" + ENGLISH_SAMPLE + "\n\n" assert text == expected def test_doc_with_comments(): """ Test that a doc with comments gets converted back with comments """ doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE) check_russian_doc(doc) def test_unusual_misc(): """ The above RUSSIAN_SAMPLE resulted in a blank misc field in one particular implementation of the conll code (the below test would fail) """ doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE) sentences = "{:C}".format(doc).split("\n\n") assert len(sentences) == 2 sentence = sentences[0].split("\n") assert len(sentence) == 14 for word in sentence: pieces = word.split("\t") assert len(pieces) == 1 or len(pieces) == 10 if len(pieces) == 10: assert all(piece for piece in pieces) def test_file(): """ Test loading a doc from a file """ with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "russian.conll") with open(filename, "w", encoding="utf-8") as fout: fout.write(RUSSIAN_SAMPLE) doc = CoNLL.conll2doc(input_file=filename) check_russian_doc(doc) def test_zip_file(): """ Test loading a doc from a zip file """ with tempfile.TemporaryDirectory() as tempdir: zip_file = os.path.join(tempdir, "russian.zip") filename = "russian.conll" with ZipFile(zip_file, "w") as zout: with zout.open(filename, "w") as fout: fout.write(RUSSIAN_SAMPLE.encode()) doc = CoNLL.conll2doc(input_file=filename, zip_file=zip_file) check_russian_doc(doc) SIMPLE_NER = """ # text = 
Teferi's best friend is Karn # sent_id = 0 1 Teferi _ _ _ _ 0 _ _ start_char=0|end_char=6|ner=S-PERSON 2 's _ _ _ _ 1 _ _ start_char=6|end_char=8|ner=O 3 best _ _ _ _ 2 _ _ start_char=9|end_char=13|ner=O 4 friend _ _ _ _ 3 _ _ start_char=14|end_char=20|ner=O 5 is _ _ _ _ 4 _ _ start_char=21|end_char=23|ner=O 6 Karn _ _ _ _ 5 _ _ start_char=24|end_char=28|ner=S-PERSON """.strip() def test_simple_ner_conversion(): """ Test that tokens get properly created with NER tags """ doc = CoNLL.conll2doc(input_str=SIMPLE_NER) assert len(doc.sentences) == 1 sentence = doc.sentences[0] assert len(sentence.tokens) == 6 EXPECTED_NER = ["S-PERSON", "O", "O", "O", "O", "S-PERSON"] for token, ner in zip(sentence.tokens, EXPECTED_NER): assert token.ner == ner # check that the ner, start_char, end_char fields were not put on the token's misc # those should all be set as specific fields on the token assert not token.misc assert len(token.words) == 1 # they should also not reach the word's misc field assert not token.words[0].misc conll = "{:C}".format(doc) assert conll == SIMPLE_NER MWT_NER = """ # text = This makes John's headache worse # sent_id = 0 1 This _ _ _ _ 0 _ _ start_char=0|end_char=4|ner=O 2 makes _ _ _ _ 1 _ _ start_char=5|end_char=10|ner=O 3-4 John's _ _ _ _ _ _ _ start_char=11|end_char=17|ner=S-PERSON 3 John _ _ _ _ 2 _ _ _ 4 's _ _ _ _ 3 _ _ _ 5 headache _ _ _ _ 4 _ _ start_char=18|end_char=26|ner=O 6 worse _ _ _ _ 5 _ _ start_char=27|end_char=32|ner=O """.strip() def test_mwt_ner_conversion(): """ Test that tokens including MWT get properly created with NER tags Note that this kind of thing happens with the EWT tokenizer for English, for example """ doc = CoNLL.conll2doc(input_str=MWT_NER) assert len(doc.sentences) == 1 sentence = doc.sentences[0] assert len(sentence.tokens) == 5 assert not sentence.has_enhanced_dependencies() EXPECTED_NER = ["O", "O", "S-PERSON", "O", "O"] EXPECTED_WORDS = [1, 1, 2, 1, 1] for token, ner, expected_words in zip(sentence.tokens, 
EXPECTED_NER, EXPECTED_WORDS): assert token.ner == ner # check that the ner, start_char, end_char fields were not put on the token's misc # those should all be set as specific fields on the token assert not token.misc assert len(token.words) == expected_words # they should also not reach the word's misc field assert not token.words[0].misc conll = "{:C}".format(doc) assert conll == MWT_NER ALL_OFFSETS_CONLLU = """ # text = This makes John's headache worse # sent_id = 0 1 This _ _ _ _ 0 _ _ start_char=0|end_char=4 2 makes _ _ _ _ 1 _ _ start_char=5|end_char=10 3-4 John's _ _ _ _ _ _ _ start_char=11|end_char=17 3 John _ _ _ _ 2 _ _ start_char=11|end_char=15 4 's _ _ _ _ 3 _ _ start_char=15|end_char=17 5 headache _ _ _ _ 4 _ _ start_char=18|end_char=26 6 worse _ _ _ _ 5 _ _ SpaceAfter=No|start_char=27|end_char=32 """.strip() NO_OFFSETS_CONLLU = """ # text = This makes John's headache worse # sent_id = 0 1 This _ _ _ _ 0 _ _ _ 2 makes _ _ _ _ 1 _ _ _ 3-4 John's _ _ _ _ _ _ _ _ 3 John _ _ _ _ 2 _ _ _ 4 's _ _ _ _ 3 _ _ _ 5 headache _ _ _ _ 4 _ _ _ 6 worse _ _ _ _ 5 _ _ SpaceAfter=No """.strip() NO_COMMENTS_NO_OFFSETS_CONLLU = """ 1 This _ _ _ _ 0 _ _ _ 2 makes _ _ _ _ 1 _ _ _ 3-4 John's _ _ _ _ _ _ _ _ 3 John _ _ _ _ 2 _ _ _ 4 's _ _ _ _ 3 _ _ _ 5 headache _ _ _ _ 4 _ _ _ 6 worse _ _ _ _ 5 _ _ SpaceAfter=No """.strip() def test_no_offsets_output(): doc = CoNLL.conll2doc(input_str=ALL_OFFSETS_CONLLU) assert len(doc.sentences) == 1 sentence = doc.sentences[0] assert len(sentence.tokens) == 5 conll = "{:C}".format(doc) assert conll == ALL_OFFSETS_CONLLU conll = "{:C-o}".format(doc) assert conll == NO_OFFSETS_CONLLU conll = "{:c-o}".format(doc) assert conll == NO_COMMENTS_NO_OFFSETS_CONLLU # A random sentence from et_ewt-ud-train.conllu # which we use to test the deps conversion for multiple deps ESTONIAN_DEPS = """ # newpar # sent_id = aia_foorum_37 # text = Sestpeale ei mõistagi neid, kes koduaias sortidega tegelevad. 
1 Sestpeale sest_peale ADV D _ 3 advmod 3:advmod _ 2 ei ei AUX V Polarity=Neg 3 aux 3:aux _ 3 mõistagi mõistma VERB V Connegative=Yes|Mood=Ind|Tense=Pres|VerbForm=Fin|Voice=Act 0 root 0:root _ 4 neid tema PRON P Case=Par|Number=Plur|Person=3|PronType=Prs 3 obj 3:obj|9:nsubj SpaceAfter=No 5 , , PUNCT Z _ 9 punct 9:punct _ 6 kes kes PRON P Case=Nom|Number=Plur|PronType=Int,Rel 9 nsubj 4:ref _ 7 koduaias kodu_aed NOUN S Case=Ine|Number=Sing 9 obl 9:obl _ 8 sortidega sort NOUN S Case=Com|Number=Plur 9 obl 9:obl _ 9 tegelevad tegelema VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 4 acl:relcl 4:acl SpaceAfter=No 10 . . PUNCT Z _ 3 punct 3:punct _ """.strip() def test_deps_conversion(): doc = CoNLL.conll2doc(input_str=ESTONIAN_DEPS) assert len(doc.sentences) == 1 sentence = doc.sentences[0] assert len(sentence.tokens) == 10 assert sentence.has_enhanced_dependencies() word = doc.sentences[0].words[3] assert word.deps == "3:obj|9:nsubj" conll = "{:C}".format(doc) assert conll == ESTONIAN_DEPS ESTONIAN_EMPTY_DEPS = """ # sent_id = ewtb2_000035_15 # text = Ja paari aasta pärast rôômalt maasikatele ... 1 Ja ja CCONJ J _ 3 cc 5.1:cc _ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1 6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes 7 ... ... PUNCT Z _ 3 punct 5.1:punct _ """.strip() ESTONIAN_EMPTY_END_DEPS = """ # sent_id = ewtb2_000035_15 # text = Ja paari aasta pärast rôômalt maasikatele ... 
1 Ja ja CCONJ J _ 3 cc 5.1:cc _ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1 """.strip() def test_empty_deps_conversion(): """ Check that we can read and then output a sentence with empty dependencies """ check_empty_deps_conversion(ESTONIAN_EMPTY_DEPS, 7) def test_empty_deps_at_end_conversion(): """ The empty deps conversion should also work if the empty dep is at the end """ check_empty_deps_conversion(ESTONIAN_EMPTY_END_DEPS, 5) def check_empty_deps_conversion(input_str, expected_words): doc = CoNLL.conll2doc(input_str=input_str, ignore_gapping=False) assert len(doc.sentences) == 1 assert len(doc.sentences[0].tokens) == expected_words assert len(doc.sentences[0].words) == expected_words assert len(doc.sentences[0].empty_words) == 1 sentence = doc.sentences[0] conll = "{:C}".format(doc) assert conll == input_str sentence_dict = doc.sentences[0].to_dict() assert len(sentence_dict) == expected_words + 1 # currently this is true for both of the examples we run assert sentence_dict[5]['id'] == (5, 1) # redo the above checks to make sure # there are no weird bugs in the accessors assert len(doc.sentences) == 1 assert len(doc.sentences[0].tokens) == expected_words assert len(doc.sentences[0].words) == expected_words assert len(doc.sentences[0].empty_words) == 1 ESTONIAN_DOC_ID = """ # doc_id = this_is_a_doc # sent_id = ewtb2_000035_15 # text = Ja paari aasta pärast rôômalt maasikatele ... 
1 Ja ja CCONJ J _ 3 cc 5.1:cc _ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1 6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes 7 ... ... PUNCT Z _ 3 punct 5.1:punct _ """.strip() def test_read_doc_id(): doc = CoNLL.conll2doc(input_str=ESTONIAN_DOC_ID, ignore_gapping=False) assert "{:C}".format(doc) == ESTONIAN_DOC_ID assert doc.sentences[0].doc_id == 'this_is_a_doc' SIMPLE_DEPENDENCY_INDEX_ERROR = """ # text = Teferi's best friend is Karn # sent_id = 0 # notes = this sentence has a dependency index outside the sentence. it should throw an IndexError 1 Teferi _ _ _ _ 0 root _ start_char=0|end_char=6|ner=S-PERSON 2 's _ _ _ _ 1 dep _ start_char=6|end_char=8|ner=O 3 best _ _ _ _ 2 dep _ start_char=9|end_char=13|ner=O 4 friend _ _ _ _ 3 dep _ start_char=14|end_char=20|ner=O 5 is _ _ _ _ 4 dep _ start_char=21|end_char=23|ner=O 6 Karn _ _ _ _ 8 dep _ start_char=24|end_char=28|ner=S-PERSON """.strip() def test_read_dependency_errors(): with pytest.raises(IndexError): doc = CoNLL.conll2doc(input_str=SIMPLE_DEPENDENCY_INDEX_ERROR) MULTIPLE_DOC_IDS = """ # doc_id = doc_1 # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0020 # text = His mother was also killed in the attack. 
1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 2 nmod:poss 2:nmod:poss _ 2 mother mother NOUN NN Number=Sing 5 nsubj:pass 5:nsubj:pass _ 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 5 aux:pass 5:aux:pass _ 4 also also ADV RB _ 5 advmod 5:advmod _ 5 killed kill VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ 6 in in ADP IN _ 8 case 8:case _ 7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _ 8 attack attack NOUN NN Number=Sing 5 obl 5:obl:in SpaceAfter=No 9 . . PUNCT . _ 5 punct 5:punct _ # doc_id = doc_1 # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0028 # text = This item is a small one and easily missed. 1 This this DET DT Number=Sing|PronType=Dem 2 det 2:det _ 2 item item NOUN NN Number=Sing 6 nsubj 6:nsubj|9:nsubj:pass _ 3 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ 4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ 5 small small ADJ JJ Degree=Pos 6 amod 6:amod _ 6 one one NOUN NN Number=Sing 0 root 0:root _ 7 and and CCONJ CC _ 9 cc 9:cc _ 8 easily easily ADV RB _ 9 advmod 9:advmod _ 9 missed miss VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 6 conj 6:conj:and SpaceAfter=No 10 . . PUNCT . _ 6 punct 6:punct _ # doc_id = doc_2 # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0029 # text = But in my view it is highly significant. 1 But but CCONJ CC _ 8 cc 8:cc _ 2 in in ADP IN _ 4 case 4:case _ 3 my my PRON PRP$ Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs 4 nmod:poss 4:nmod:poss _ 4 view view NOUN NN Number=Sing 8 obl 8:obl:in _ 5 it it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 8 nsubj 8:nsubj _ 6 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 8 cop 8:cop _ 7 highly highly ADV RB _ 8 advmod 8:advmod _ 8 significant significant ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No 9 . . PUNCT . 
_ 8 punct 8:punct _ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0040 # text = The trial begins again Nov.28. 1 The the DET DT Definite=Def|PronType=Art 2 det 2:det _ 2 trial trial NOUN NN Number=Sing 3 nsubj 3:nsubj _ 3 begins begin VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _ 4 again again ADV RB _ 3 advmod 3:advmod _ 5 Nov. November PROPN NNP Abbr=Yes|Number=Sing 3 obl:tmod 3:obl:tmod SpaceAfter=No 6 28 28 NUM CD NumForm=Digit|NumType=Card 5 nummod 5:nummod SpaceAfter=No 7 . . PUNCT . _ 3 punct 3:punct _ """.lstrip() def test_read_multiple_doc_ids(): docs = CoNLL.conll2multi_docs(input_str=MULTIPLE_DOC_IDS) assert len(docs) == 2 assert len(docs[0].sentences) == 2 assert len(docs[1].sentences) == 2 # remove the first doc_id comment text = "\n".join(MULTIPLE_DOC_IDS.split("\n")[1:]) docs = CoNLL.conll2multi_docs(input_str=text) assert len(docs) == 3 assert len(docs[0].sentences) == 1 assert len(docs[1].sentences) == 1 assert len(docs[2].sentences) == 2 ENGLISH_TEST_SENTENCE = """ # text = This is a test # sent_id = 0 1 This this PRON DT Number=Sing|PronType=Dem 4 nsubj _ start_char=0|end_char=4 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ start_char=5|end_char=7 3 a a DET DT Definite=Ind|PronType=Art 4 det _ start_char=8|end_char=9 4 test test NOUN NN Number=Sing 0 root _ SpaceAfter=No|start_char=10|end_char=14 """.lstrip() def test_convert_dict(): doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE) converted = CoNLL.convert_dict(doc.to_dict()) expected = [[['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'], ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'], ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'], ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', 
'_', 'SpaceAfter=No|start_char=10|end_char=14']]] assert converted == expected def test_line_numbers(): doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE, keep_line_numbers=True) # currently the line numbers are not output in the conllu format doc_conllu = "{:C}\n".format(doc) assert doc_conllu == ENGLISH_TEST_SENTENCE # currently the line numbers are not output in the dict format converted = CoNLL.convert_dict(doc.to_dict()) expected = [[['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'], ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'], ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'], ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 'SpaceAfter=No|start_char=10|end_char=14']]] assert converted == expected for word_idx, word in enumerate(doc.sentences[0].words): # the test sentence has two comments in it assert word.line_number == word_idx + 2 SPEAKER_EXAMPLE = """ # sent_id = GUM_fiction_pag-57 # speaker = Siri # addressee = Pag # text = "Sorry." 1 " " PUNCT `` _ 2 punct 2:punct Discourse=joint-sequence_m:130->128:1:_|SpaceAfter=No 2 Sorry sorry ADJ JJ Degree=Pos 0 root 0:root MSeg=Sorr-y|SpaceAfter=No 3 . . PUNCT . 
_ 2 punct 2:punct SpaceAfter=No 4 " " PUNCT '' _ 2 punct 2:punct _ """.lstrip() def test_speaker(): doc = CoNLL.conll2doc(input_str=SPEAKER_EXAMPLE) assert len(doc.sentences) == 1 assert doc.sentences[0].speaker == 'Siri' assert "# speaker = Siri" in doc.sentences[0].comments doc.sentences[0].speaker = "foo" assert doc.sentences[0].speaker == 'foo' assert any(comment.startswith("# speaker") for comment in doc.sentences[0].comments) assert "# speaker = foo" in doc.sentences[0].comments doc.sentences[0].speaker = None assert not any(comment.startswith("# speaker") for comment in doc.sentences[0].comments) assert doc.sentences[0].speaker is None doc.sentences[0].speaker = "Siri" assert doc.sentences[0].speaker == 'Siri' assert "# speaker = Siri" in doc.sentences[0].comments ================================================ FILE: stanza/tests/common/test_data_objects.py ================================================ """ Basic tests of the stanza data objects, especially the setter/getter routines """ import pytest import stanza from stanza.models.common.doc import Document, Sentence, Word from stanza.tests import * pytestmark = pytest.mark.pipeline # data for testing EN_DOC = "This is a test document. Pretty cool!" EN_DOC_UPOS_XPOS = (('PRON_DT', 'AUX_VBZ', 'DET_DT', 'NOUN_NN', 'NOUN_NN', 'PUNCT_.'), ('ADV_RB', 'ADJ_JJ', 'PUNCT_.')) EN_DOC2 = "Chris Manning wrote a sentence. Then another." 
@pytest.fixture(scope="module") def nlp_pipeline(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en') return nlp def test_readonly(nlp_pipeline): Document.add_property('some_property', 123) doc = nlp_pipeline(EN_DOC) assert doc.some_property == 123 with pytest.raises(ValueError): doc.some_property = 456 def test_getter(nlp_pipeline): Word.add_property('upos_xpos', getter=lambda self: f"{self.upos}_{self.xpos}") doc = nlp_pipeline(EN_DOC) assert EN_DOC_UPOS_XPOS == tuple(tuple(word.upos_xpos for word in sentence.words) for sentence in doc.sentences) def test_setter_getter(nlp_pipeline): int2str = {0: 'ok', 1: 'good', 2: 'bad'} str2int = {'ok': 0, 'good': 1, 'bad': 2} def setter(self, value): self._classname = str2int[value] Sentence.add_property('classname', getter=lambda self: int2str[self._classname] if self._classname is not None else None, setter=setter) doc = nlp_pipeline(EN_DOC) sentence = doc.sentences[0] sentence.classname = 'good' assert sentence._classname == 1 # don't try this at home sentence._classname = 2 assert sentence.classname == 'bad' def test_backpointer(nlp_pipeline): doc = nlp_pipeline(EN_DOC2) ent = doc.ents[0] assert ent.sent is doc.sentences[0] assert list(doc.iter_words())[0].sent is doc.sentences[0] assert list(doc.iter_tokens())[-1].sent is doc.sentences[-1] ================================================ FILE: stanza/tests/common/test_doc.py ================================================ import pytest import stanza from stanza.tests import * from stanza.models.common.doc import Document, ID, TEXT, NER, CONSTITUENCY, SENTIMENT pytestmark = [pytest.mark.travis, pytest.mark.pipeline] @pytest.fixture def sentences_dict(): return [[{ID: 1, TEXT: "unban"}, {ID: 2, TEXT: "mox"}, {ID: 3, TEXT: "opal"}], [{ID: 4, TEXT: "ban"}, {ID: 5, TEXT: "Lurrus"}]] @pytest.fixture def doc(sentences_dict): doc = Document(sentences_dict) return doc def test_basic_values(doc, sentences_dict): """ Test that sentences & token text are properly set when 
constructing a doc """ assert len(doc.sentences) == len(sentences_dict) for sentence, raw_sentence in zip(doc.sentences, sentences_dict): assert sentence.doc == doc assert len(sentence.tokens) == len(raw_sentence) for token, raw_token in zip(sentence.tokens, raw_sentence): assert token.text == raw_token[TEXT] def test_set_sentence(doc): """ Test setting a field on the sentences themselves """ doc.set(fields="sentiment", contents=["4", "0"], to_sentence=True) assert doc.sentences[0].sentiment == "4" assert doc.sentences[1].sentiment == "0" def test_set_tokens(doc): """ Test setting values on tokens """ ner_contents = ["O", "ARTIFACT", "ARTIFACT", "O", "CAT"] doc.set(fields=NER, contents=ner_contents, to_token=True) result = doc.get(NER, from_token=True) assert result == ner_contents def test_constituency_comment(doc): """ Test that setting the constituency tree on a doc sets the constituency comment """ for sentence in doc.sentences: assert len([x for x in sentence.comments if x.startswith("# constituency")]) == 0 # currently nothing is checking that the items are actually trees trees = ["asdf", "zzzz"] doc.set(fields=CONSTITUENCY, contents=trees, to_sentence=True) for sentence, expected in zip(doc.sentences, trees): constituency_comments = [x for x in sentence.comments if x.startswith("# constituency")] assert len(constituency_comments) == 1 assert constituency_comments[0].endswith(expected) # Test that if we replace the trees with an updated tree, the comment is also replaced trees = ["zzzz", "asdf"] doc.set(fields=CONSTITUENCY, contents=trees, to_sentence=True) for sentence, expected in zip(doc.sentences, trees): constituency_comments = [x for x in sentence.comments if x.startswith("# constituency")] assert len(constituency_comments) == 1 assert constituency_comments[0].endswith(expected) def test_sentiment_comment(doc): """ Test that setting the sentiment on a doc sets the sentiment comment """ for sentence in doc.sentences: assert len([x for x in 
sentence.comments if x.startswith("# sentiment")]) == 0 # currently nothing is checking that the items are actually trees sentiments = ["1", "2"] doc.set(fields=SENTIMENT, contents=sentiments, to_sentence=True) for sentence, expected in zip(doc.sentences, sentiments): sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")] assert len(sentiment_comments) == 1 assert sentiment_comments[0].endswith(expected) # Test that if we replace the trees with an updated tree, the comment is also replaced sentiments = ["3", "4"] doc.set(fields=SENTIMENT, contents=sentiments, to_sentence=True) for sentence, expected in zip(doc.sentences, sentiments): sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")] assert len(sentiment_comments) == 1 assert sentiment_comments[0].endswith(expected) def test_sent_id_comment(doc): """ Test that setting the sent_id on a sentence sets the sentiment comment """ for sent_idx, sentence in enumerate(doc.sentences): assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1 assert sentence.sent_id == "%d" % sent_idx doc.sentences[0].sent_id = "foo" assert doc.sentences[0].sent_id == "foo" assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1 assert "# sent_id = foo" in doc.sentences[0].comments doc.reindex_sentences(10) for sent_idx, sentence in enumerate(doc.sentences): assert sentence.sent_id == "%d" % (sent_idx + 10) assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1 assert "# sent_id = %d" % (sent_idx + 10) in sentence.comments doc.sentences[0].add_comment("# sent_id = bar") assert doc.sentences[0].sent_id == "bar" assert "# sent_id = bar" in doc.sentences[0].comments assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1 def test_doc_id_comment(doc): """ Test that setting the doc_id on a sentence sets the document comment """ assert doc.sentences[0].doc_id is None assert len([x for 
x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 0 doc.sentences[0].doc_id = "foo" assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1 assert "# doc_id = foo" in doc.sentences[0].comments assert doc.sentences[0].doc_id == "foo" doc.sentences[0].add_comment("# doc_id = bar") assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1 assert doc.sentences[0].doc_id == "bar" @pytest.fixture(scope="module") def pipeline(): return stanza.Pipeline(dir=TEST_MODELS_DIR) def test_serialized(pipeline): """ Brief test of the serialized format Checks that NER entities are correctly set. Also checks that constituency & sentiment are set on the sentences. """ text = "John Bauer works at Stanford" doc = pipeline(text) assert len(doc.ents) == 2 serialized = doc.to_serialized() doc2 = Document.from_serialized(serialized) assert len(doc2.sentences) == 1 assert len(doc2.ents) == 2 assert doc.sentences[0].constituency == doc2.sentences[0].constituency assert doc.sentences[0].sentiment == doc2.sentences[0].sentiment ================================================ FILE: stanza/tests/common/test_dropout.py ================================================ import pytest import torch import stanza from stanza.models.common.dropout import WordDropout pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_word_dropout(): """ Test that word_dropout is randomly dropping out the entire final dimension of a tensor Doing 600 small rows should be super fast, but it leaves us with something like a 1 in 10^180 chance of the test failing. 
Not very common, in other words """ wd = WordDropout(0.5) batch = torch.randn(600, 4) dropped = wd(batch) # the one time any of this happens, it's going to be really confusing assert not torch.allclose(batch, dropped) num_zeros = 0 for i in range(batch.shape[0]): assert torch.allclose(dropped[i], batch[i]) or torch.sum(dropped[i]) == 0.0 if torch.sum(dropped[i]) == 0.0: num_zeros += 1 assert num_zeros > 0 and num_zeros < batch.shape[0] ================================================ FILE: stanza/tests/common/test_foundation_cache.py ================================================ import glob import os import shutil import tempfile import pytest import stanza from stanza.models.common.foundation_cache import FoundationCache, load_charlm from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_charlm_cache(): models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*") models = glob.glob(models_path) # we expect at least one English model downloaded for the tests assert len(models) >= 1 model_file = models[0] cache = FoundationCache() with tempfile.TemporaryDirectory(dir=".") as test_dir: temp_file = os.path.join(test_dir, "charlm.pt") shutil.copy2(model_file, temp_file) # this will work model = load_charlm(temp_file) # this will save the model model = cache.load_charlm(temp_file) # this should no longer work with pytest.raises(FileNotFoundError): model = load_charlm(temp_file) # it should remember the cached version model = cache.load_charlm(temp_file) ================================================ FILE: stanza/tests/common/test_pretrain.py ================================================ import os import tempfile import pytest import numpy as np import torch from stanza.models.common import pretrain from stanza.models.common.vocab import UNK_ID from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def check_vocab(vocab): # 4 base vectors, plus the 3 vectors actually 
present in the file assert len(vocab) == 7 assert 'unban' in vocab assert 'mox' in vocab assert 'opal' in vocab def check_embedding(emb, unk=False): expected = np.array([[ 0., 0., 0., 0.,], [ 0., 0., 0., 0.,], [ 0., 0., 0., 0.,], [ 0., 0., 0., 0.,], [ 1., 2., 3., 4.,], [ 5., 6., 7., 8.,], [ 9., 10., 11., 12.,]]) if unk: expected[UNK_ID] = -1 np.testing.assert_allclose(emb, expected) def check_pretrain(pt): check_vocab(pt.vocab) check_embedding(pt.emb) def test_text_pretrain(): pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.txt', save_to_file=False) check_pretrain(pt) def test_xz_pretrain(): pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False) check_pretrain(pt) def test_gz_pretrain(): pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.gz', save_to_file=False) check_pretrain(pt) def test_zip_pretrain(): pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.zip', save_to_file=False) check_pretrain(pt) def test_csv_pretrain(): pt = pretrain.Pretrain(csv_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.csv', save_to_file=False) check_pretrain(pt) def test_resave_pretrain(): """ Test saving a pretrain and then loading from the existing file """ test_pt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".pt", delete=False) try: test_pt_file.close() # note that this tests the ability to save a pretrain and the # ability to fall back when the existing pretrain isn't working pt = pretrain.Pretrain(filename=test_pt_file.name, vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz') check_pretrain(pt) pt2 = pretrain.Pretrain(filename=test_pt_file.name, vec_filename=f'unban_mox_opal') check_pretrain(pt2) pt3 = torch.load(test_pt_file.name, weights_only=True) check_embedding(pt3['emb']) finally: os.unlink(test_pt_file.name) SPACE_PRETRAIN=""" 3 4 unban mox 1 2 3 4 opal 5 6 7 8 foo 9 10 11 12 """.strip() def test_whitespace(): """ Test reading a pretrain with an ascii space 
in it The vocab word with a space in it should have the correct number of dimensions read, with the space converted to nbsp """ test_txt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".txt", delete=False) try: test_txt_file.write(SPACE_PRETRAIN.encode()) test_txt_file.close() pt = pretrain.Pretrain(vec_filename=test_txt_file.name, save_to_file=False) check_embedding(pt.emb) assert "unban\xa0mox" in pt.vocab # this one also works because of the normalize_unit in vocab.py assert "unban mox" in pt.vocab finally: os.unlink(test_txt_file.name) NO_HEADER_PRETRAIN=""" unban 1 2 3 4 mox 5 6 7 8 opal 9 10 11 12 """.strip() def test_no_header(): """ Check loading a pretrain with no rows,cols header """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir: filename = os.path.join(tmpdir, "tiny.txt") with open(filename, "w", encoding="utf-8") as fout: fout.write(NO_HEADER_PRETRAIN) pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False) check_embedding(pt.emb) UNK_PRETRAIN=""" unban 1 2 3 4 mox 5 6 7 8 opal 9 10 11 12 -1 -1 -1 -1 """.strip() def test_no_header(): """ Check loading a pretrain with at the end, like GloVe does """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir: filename = os.path.join(tmpdir, "tiny.txt") with open(filename, "w", encoding="utf-8") as fout: fout.write(UNK_PRETRAIN) pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False) check_embedding(pt.emb, unk=True) ================================================ FILE: stanza/tests/common/test_relative_attn.py ================================================ import pytest import torch from stanza.models.common.relative_attn import RelativeAttention pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_attn(): foo = RelativeAttention(d_model=100, num_heads=2, window=8, dropout=0.0) bar = torch.randn(10, 13, 100) result = foo(bar) assert result.shape == bar.shape value = foo.value(bar) if not torch.allclose(result[:, -1, :], 
value[:, -1, :], atol=1e-06): raise ValueError(result[:, -1, :] - value[:, -1, :]) assert torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06) assert not torch.allclose(result[:, 0, :], value[:, 0, :]) def test_shorter_sequence(): # originally this was failing because the batch was smaller than the window foo = RelativeAttention(d_model=20, num_heads=2, window=5, dropout=0.0) bar = torch.randn(10, 3, 20) result = foo(bar) assert result.shape == bar.shape value = foo.value(bar) if not torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06): raise ValueError(result[:, -1, :] - value[:, -1, :]) assert torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06) assert not torch.allclose(result[:, 0, :], value[:, 0, :]) def test_reverse(): foo = RelativeAttention(d_model=100, num_heads=2, window=8, reverse=True, dropout=0.0) bar = torch.randn(10, 13, 100) result = foo(bar) assert result.shape == bar.shape value = foo.value(bar) if not torch.allclose(result[:, 0, :], value[:, 0, :], atol=1e-06): raise ValueError(result[:, 0, :] - value[:, 0, :]) assert torch.allclose(result[:, 0, :], value[:, 0, :], atol=1e-06) assert not torch.allclose(result[:, -1, :], value[:, -1, :]) ================================================ FILE: stanza/tests/common/test_short_name_to_treebank.py ================================================ import pytest import stanza from stanza.models.common import short_name_to_treebank pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_short_name(): assert short_name_to_treebank.short_name_to_treebank("en_ewt") == "UD_English-EWT" def test_canonical_name(): assert short_name_to_treebank.canonical_treebank_name("UD_URDU-UDTB") == "UD_Urdu-UDTB" assert short_name_to_treebank.canonical_treebank_name("ur_udtb") == "UD_Urdu-UDTB" assert short_name_to_treebank.canonical_treebank_name("Unban_Mox_Opal") == "Unban_Mox_Opal" ================================================ FILE: stanza/tests/common/test_utils.py 
================================================ import lzma import os import tempfile import pytest import stanza import stanza.models.common.utils as utils from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_wordvec_not_found(): """ get_wordvec_file should fail if neither word2vec nor fasttext exists """ with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir: with pytest.raises(FileNotFoundError): utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo') def test_word2vec_xz(): """ Test searching for word2vec and xz files """ with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir: # make a fake directory for English word vectors word2vec_dir = os.path.join(temp_dir, 'word2vec', 'English') os.makedirs(word2vec_dir) # make a fake English word vector file fake_file = os.path.join(word2vec_dir, 'en.vectors.xz') fout = open(fake_file, 'w') fout.close() # get_wordvec_file should now find this fake file filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo') assert filename == fake_file def test_fasttext_txt(): """ Test searching for fasttext and txt files """ with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir: # make a fake directory for English word vectors fasttext_dir = os.path.join(temp_dir, 'fasttext', 'English') os.makedirs(fasttext_dir) # make a fake English word vector file fake_file = os.path.join(fasttext_dir, 'en.vectors.txt') fout = open(fake_file, 'w') fout.close() # get_wordvec_file should now find this fake file filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo') assert filename == fake_file def test_wordvec_type(): """ If we supply our own wordvec type, get_wordvec_file should find that """ with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir: # make a fake directory for English word vectors google_dir = os.path.join(temp_dir, 'google', 'English') os.makedirs(google_dir) # 
make a fake English word vector file fake_file = os.path.join(google_dir, 'en.vectors.txt') fout = open(fake_file, 'w') fout.close() # get_wordvec_file should now find this fake file filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo', wordvec_type='google') assert filename == fake_file # this file won't be found using the normal defaults with pytest.raises(FileNotFoundError): utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo') def test_sort_with_indices(): data = [[1, 2, 3], [4, 5], [6]] ordered, orig_idx = utils.sort_with_indices(data, key=len) assert ordered == ([6], [4, 5], [1, 2, 3]) assert orig_idx == (2, 1, 0) unsorted = utils.unsort(ordered, orig_idx) assert data == unsorted def test_empty_sort_with_indices(): ordered, orig_idx = utils.sort_with_indices([]) assert len(ordered) == 0 assert len(orig_idx) == 0 unsorted = utils.unsort(ordered, orig_idx) assert [] == unsorted def test_split_into_batches(): data = [] for i in range(5): data.append(["Unban", "mox", "opal", str(i)]) data.append(["Do", "n't", "ban", "Urza", "'s", "Saga", "that", "card", "is", "great"]) data.append(["Ban", "Ragavan"]) # small batches will put one element in each interval batches = utils.split_into_batches(data, 5) assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)] # this one has a batch interrupted in the middle by a large element batches = utils.split_into_batches(data, 8) assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)] # this one has the large element at the start of its own batch batches = utils.split_into_batches(data[1:], 8) assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)] # overloading the test! 
assert that the key & reverse is working ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True) assert [len(x) for x in ordered] == [10, 4, 4, 4, 4, 4, 2] # this has the large element at the start batches = utils.split_into_batches(ordered, 8) assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)] # double check that unsort is working as expected assert data == utils.unsort(ordered, orig_idx) def test_find_missing_tags(): assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC"]) == [] assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"]) == ['ORG'] assert utils.find_missing_tags([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]]) == ['ORG'] def test_open_read_text(): """ test that we can read either .xz or regular txt """ TEXT = "this is a test" with tempfile.TemporaryDirectory() as tempdir: # test text file filename = os.path.join(tempdir, "foo.txt") with open(filename, "w") as fout: fout.write(TEXT) with utils.open_read_text(filename) as fin: in_text = fin.read() assert TEXT == in_text assert fin.closed # the context should close the file when we throw an exception! try: with utils.open_read_text(filename) as finex: assert not finex.closed raise ValueError("unban mox opal!") except ValueError: pass assert finex.closed # test xz file filename = os.path.join(tempdir, "foo.txt.xz") with lzma.open(filename, "wt") as fout: fout.write(TEXT) with utils.open_read_text(filename) as finxz: in_text = finxz.read() assert TEXT == in_text assert finxz.closed # the context should close the file when we throw an exception! 
try: with utils.open_read_text(filename) as finexxz: assert not finexxz.closed raise ValueError("unban mox opal!") except ValueError: pass assert finexxz.closed def test_checkpoint_name(): """ Test some expected results for the checkpoint names """ # use os.path.split so that the test is agnostic of file separator on Linux or Windows checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm.pt", None) assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint.pt") checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", None) assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint") checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", "othername.pt") assert os.path.split(checkpoint) == ("saved_models", "othername.pt") def test_punct_simplification(): """ Test a punctuation simplification that should make it so unexpected question/exclamation marks types are processed into ? and ! 
""" test = [[["!!!!"], ["‼‼‼‼"], ["????"], ["?!?!"], ["??︖"], ["?foo"], ["bar!"]]] test = utils.simplify_punct(test) expected = [[['!'], ['!'], ['?'], ['?'], ['?'], ['?foo'], ['bar!']]] assert test == expected ================================================ FILE: stanza/tests/constituency/__init__.py ================================================ ================================================ FILE: stanza/tests/constituency/test_convert_arboretum.py ================================================ """ Test a couple different classes of trees to check the output of the Arboretum conversion Note that the text has been removed """ import os import tempfile import pytest from stanza.server import tsurgeon from stanza.tests import TEST_WORKING_DIR from stanza.utils.datasets.constituency import convert_arboretum pytestmark = [pytest.mark.pipeline, pytest.mark.travis] PROJ_EXAMPLE=""" """ NOT_FIX_NONPROJ_EXAMPLE=""" """ NONPROJ_EXAMPLE=""" """ def test_projective_example(): """ Test reading a basic tree, along with some further manipulations from the conversion program """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir: test_name = os.path.join(tempdir, "proj.xml") with open(test_name, "w", encoding="utf-8") as fout: fout.write(PROJ_EXAMPLE) sentences = convert_arboretum.read_xml_file(test_name) assert len(sentences) == 1 tree, words = convert_arboretum.process_tree(sentences[0]) expected_tree = "(s (fcl (prop s2_1) (v-fin s2_2) (pron-pers s2_3) (adjp (adj s2_4) (pp (prp s2_5) (np (art s2_6) (adj s2_7) (n s2_8)))) (pu s2_9)))" assert str(tree) == expected_tree assert [w.word for w in words.values()] == ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', '.'] assert not convert_arboretum.word_sequence_missing_words(tree) with tsurgeon.Tsurgeon() as tsurgeon_processor: assert tree == convert_arboretum.check_words(tree, tsurgeon_processor) # check that the words can be replaced as expected replaced_tree = convert_arboretum.replace_words(tree, words) 
expected_tree = "(s (fcl (prop A) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))" assert str(replaced_tree) == expected_tree assert convert_arboretum.split_underscores(replaced_tree) == replaced_tree # fake a word which should be split words['s2_1'] = words['s2_1']._replace(word='foo_bar') replaced_tree = convert_arboretum.replace_words(tree, words) expected_tree = "(s (fcl (prop foo_bar) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))" assert str(replaced_tree) == expected_tree expected_tree = "(s (fcl (np (prop foo) (prop bar)) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))" assert str(convert_arboretum.split_underscores(replaced_tree)) == expected_tree def test_not_fix_example(): """ Test that a non-projective tree which we don't have a heuristic for quietly fails """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir: test_name = os.path.join(tempdir, "nofix.xml") with open(test_name, "w", encoding="utf-8") as fout: fout.write(NOT_FIX_NONPROJ_EXAMPLE) sentences = convert_arboretum.read_xml_file(test_name) assert len(sentences) == 1 tree, words = convert_arboretum.process_tree(sentences[0]) assert not convert_arboretum.word_sequence_missing_words(tree) with tsurgeon.Tsurgeon() as tsurgeon_processor: assert convert_arboretum.check_words(tree, tsurgeon_processor) is None def test_fix_proj_example(): """ Test that a non-projective tree can be rearranged as expected Note that there are several other classes of non-proj tree we could test as well... 
""" with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir: test_name = os.path.join(tempdir, "fix.xml") with open(test_name, "w", encoding="utf-8") as fout: fout.write(NONPROJ_EXAMPLE) sentences = convert_arboretum.read_xml_file(test_name) assert len(sentences) == 1 tree, words = convert_arboretum.process_tree(sentences[0]) assert not convert_arboretum.word_sequence_missing_words(tree) # the 4 and 5 are moved inside the 3-6 node expected_orig = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (v-pcp2 s9_6)) (prop s9_4) (adv s9_5) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))" expected_proj = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (prop s9_4) (adv s9_5) (v-pcp2 s9_6)) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))" assert str(tree) == expected_orig with tsurgeon.Tsurgeon() as tsurgeon_processor: assert str(convert_arboretum.check_words(tree, tsurgeon_processor)) == expected_proj ================================================ FILE: stanza/tests/constituency/test_convert_it_vit.py ================================================ """ Test a couple different classes of trees to check the output of the VIT conversion A couple representative trees are included, but hopefully not enough to be a problem in terms of our license. One of the tests is currently disabled as it relies on tregex & tsurgeon features not yet released """ import io import os import tempfile import pytest from stanza.server import tsurgeon from stanza.utils.conll import CoNLL from stanza.utils.datasets.constituency import convert_it_vit pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # just a sample! 
don't sue us please CON_SAMPLE = """ #ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]] #ID=sent_00318 dirsp-[fc-[congf-tuttavia, f-[sn-[sq-[ind-qualche], n-problema], ir_infl-[vsupir-potrebbe, vcl-esserci], compc-[clit-ci, sp-[p-per, sn-[art-la, n-commissione, sa-[ag-esteri], f2-[sp-[part-alla, relob-cui, sn-[n-presidenza]], f-[ibar-[vc-è], compc-[sn-[n-candidato], sn-[art-l, n-esponente, spd-[pd-di, sn-[mw-Alleanza, npro-Nazionale]], sn-[mw-Mirko, nh-Tremaglia]]]]]]]]]], dirs-':', f3-[sn-[art-una, n-candidatura, sc-[q-più, sa-[ppas-subìta], sc-[ccong-che, sa-[ppas-gradita]], compt-[spda-[partda-dalla, sn-[mw-Lega, npro-Nord, punt-',', f2-[rel-che, fc-[congf-tuttavia, f-[ir_infl-[vsupir-dovrebbe, vit-rispettare], compt-[sn-[art-gli, n-accordi]]]]]]]]]], punto-.]] #ID=sent_00589 f-[sn-[art-l, n-ottimismo, spd-[pd-di, sn-[nh-Kantor]]], ir_infl-[vsupir-potrebbe, congf-però, vcl-rivelarsi], compc-[sn-[in-ancora, art-una, nt-volta], sa-[ag-prematuro]], punto-.] 
""" UD_SAMPLE = """ # sent_id = VIT-2 # text = Negli ultimi anni la dinamica dei polo di attrazione è stata sempre più caratterizzata dall'emergere di una crescente concorrenza che si è progressivamente spostata dalle singole imprese ai sistemi economici e territoriali, determinando l'esigenza di una riconsiderazione dei rapporti esistenti tra soggetti produttivi e ambiente in cui questi operano. 1-2 Negli _ _ _ _ _ _ _ _ 1 In in ADP E _ 4 case _ _ 2 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 4 det _ _ 3 ultimi ultimo ADJ A Gender=Masc|Number=Plur 4 amod _ _ 4 anni anno NOUN S Gender=Masc|Number=Plur 16 obl _ _ 5 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 6 det _ _ 6 dinamica dinamica NOUN S Gender=Fem|Number=Sing 16 nsubj:pass _ _ 7-8 dei _ _ _ _ _ _ _ _ 7 di di ADP E _ 9 case _ _ 8 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 9 det _ _ 9 polo polo NOUN S Gender=Masc|Number=Sing 6 nmod _ _ 10 di di ADP E _ 11 case _ _ 11 attrazione attrazione NOUN S Gender=Fem|Number=Sing 9 nmod _ _ 12 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux _ _ 13 stata essere AUX VA Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 16 aux:pass _ _ 14 sempre sempre ADV B _ 15 advmod _ _ 15 più più ADV B _ 16 advmod _ _ 16 caratterizzata caratterizzare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 0 root _ _ 17-18 dall' _ _ _ _ _ _ _ SpaceAfter=No 17 da da ADP E _ 19 case _ _ 18 l' il DET RD Definite=Def|Number=Sing|PronType=Art 19 det _ _ 19 emergere emergere NOUN S Gender=Masc|Number=Sing 16 obl _ _ 20 di di ADP E _ 23 case _ _ 21 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 23 det _ _ 22 crescente crescente ADJ A Number=Sing 23 amod _ _ 23 concorrenza concorrenza NOUN S Gender=Fem|Number=Sing 19 nmod _ _ 24 che che PRON PR PronType=Rel 28 nsubj _ _ 25 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 28 expl _ _ 26 è essere AUX VA 
Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 28 aux _ _ 27 progressivamente progressivamente ADV B _ 28 advmod _ _ 28 spostata spostare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 23 acl:relcl _ _ 29-30 dalle _ _ _ _ _ _ _ _ 29 da da ADP E _ 32 case _ _ 30 le il DET RD Definite=Def|Gender=Fem|Number=Plur|PronType=Art 32 det _ _ 31 singole singolo ADJ A Gender=Fem|Number=Plur 32 amod _ _ 32 imprese impresa NOUN S Gender=Fem|Number=Plur 28 obl _ _ 33-34 ai _ _ _ _ _ _ _ _ 33 a a ADP E _ 35 case _ _ 34 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 35 det _ _ 35 sistemi sistema NOUN S Gender=Masc|Number=Plur 28 obl _ _ 36 economici economico ADJ A Gender=Masc|Number=Plur 35 amod _ _ 37 e e CCONJ CC _ 38 cc _ _ 38 territoriali territoriale ADJ A Number=Plur 36 conj _ SpaceAfter=No 39 , , PUNCT FF _ 28 punct _ _ 40 determinando determinare VERB V VerbForm=Ger 28 advcl _ _ 41 l' il DET RD Definite=Def|Number=Sing|PronType=Art 42 det _ SpaceAfter=No 42 esigenza esigenza NOUN S Gender=Fem|Number=Sing 40 obj _ _ 43 di di ADP E _ 45 case _ _ 44 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 45 det _ _ 45 riconsiderazione riconsiderazione NOUN S Gender=Fem|Number=Sing 42 nmod _ _ 46-47 dei _ _ _ _ _ _ _ _ 46 di di ADP E _ 48 case _ _ 47 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 48 det _ _ 48 rapporti rapporto NOUN S Gender=Masc|Number=Plur 45 nmod _ _ 49 esistenti esistente VERB V Number=Plur 48 acl _ _ 50 tra tra ADP E _ 51 case _ _ 51 soggetti soggetto NOUN S Gender=Masc|Number=Plur 49 obl _ _ 52 produttivi produttivo ADJ A Gender=Masc|Number=Plur 51 amod _ _ 53 e e CCONJ CC _ 54 cc _ _ 54 ambiente ambiente NOUN S Gender=Masc|Number=Sing 51 conj _ _ 55 in in ADP E _ 56 case _ _ 56 cui cui PRON PR PronType=Rel 58 obl _ _ 57 questi questo PRON PD Gender=Masc|Number=Plur|PronType=Dem 58 nsubj _ _ 58 operano operare VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 54 acl:relcl _ SpaceAfter=No 59 
. . PUNCT FS _ 16 punct _ _ # sent_id = VIT-318 # text = Tuttavia qualche problema potrebbe esserci per la commissione esteri alla cui presidenza è candidato l'esponente di Alleanza Nazionale Mirko Tremaglia: una candidatura più subìta che gradita dalla Lega Nord, che tuttavia dovrebbe rispettare gli accordi. 1 Tuttavia tuttavia CCONJ CC _ 5 cc _ _ 2 qualche qualche DET DI Number=Sing|PronType=Ind 3 det _ _ 3 problema problema NOUN S Gender=Masc|Number=Sing 5 nsubj _ _ 4 potrebbe potere AUX VA Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux _ _ 5-6 esserci _ _ _ _ _ _ _ _ 5 esser essere VERB V VerbForm=Inf 0 root _ _ 6 ci ci PRON PC Clitic=Yes|Number=Plur|Person=1|PronType=Prs 5 expl _ _ 7 per per ADP E _ 9 case _ _ 8 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 9 det _ _ 9 commissione commissione NOUN S Gender=Fem|Number=Sing 5 obl _ _ 10 esteri estero ADJ A Gender=Masc|Number=Plur 9 amod _ _ 11-12 alla _ _ _ _ _ _ _ _ 11 a a ADP E _ 14 case _ _ 12 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 14 det _ _ 13 cui cui DET DR PronType=Rel 14 det:poss _ _ 14 presidenza presidenza NOUN S Gender=Fem|Number=Sing 16 obl _ _ 15 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux:pass _ _ 16 candidato candidare VERB V Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 9 acl:relcl _ _ 17 l' il DET RD Definite=Def|Number=Sing|PronType=Art 18 det _ SpaceAfter=No 18 esponente esponente NOUN S Number=Sing 16 nsubj:pass _ _ 19 di di ADP E _ 20 case _ _ 20 Alleanza Alleanza PROPN SP _ 18 nmod _ _ 21 Nazionale Nazionale PROPN SP _ 20 flat:name _ _ 22 Mirko Mirko PROPN SP _ 18 nmod _ _ 23 Tremaglia Tremaglia PROPN SP _ 22 flat:name _ SpaceAfter=No 24 : : PUNCT FC _ 22 punct _ _ 25 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 26 det _ _ 26 candidatura candidatura NOUN S Gender=Fem|Number=Sing 22 appos _ _ 27 più più ADV B _ 28 advmod _ _ 28 subìta subire VERB V 
Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 26 advcl _ _ 29 che che CCONJ CC _ 30 cc _ _ 30 gradita gradito ADJ A Gender=Fem|Number=Sing 28 amod _ _ 31-32 dalla _ _ _ _ _ _ _ _ 31 da da ADP E _ 33 case _ _ 32 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 33 det _ _ 33 Lega Lega PROPN SP _ 28 obl:agent _ _ 34 Nord Nord PROPN SP _ 33 flat:name _ SpaceAfter=No 35 , , PUNCT FC _ 33 punct _ _ 36 che che PRON PR PronType=Rel 39 nsubj _ _ 37 tuttavia tuttavia CCONJ CC _ 39 cc _ _ 38 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 39 aux _ _ 39 rispettare rispettare VERB V VerbForm=Inf 33 acl:relcl _ _ 40 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 41 det _ _ 41 accordi accordio NOUN S Gender=Masc|Number=Plur 39 obj _ SpaceAfter=No 42 . . PUNCT FS _ 5 punct _ _ # sent_id = VIT-591 # text = L'ottimismo di Kantor potrebbe però rivelarsi ancora una volta prematuro. 1 L' il DET RD Definite=Def|Number=Sing|PronType=Art 2 det _ SpaceAfter=No 2 ottimismo ottimismo NOUN S Gender=Masc|Number=Sing 7 nsubj _ _ 3 di di ADP E _ 4 case _ _ 4 Kantor Kantor PROPN SP _ 2 nmod _ _ 5 potrebbe potere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux _ _ 6 però però ADV B _ 7 advmod _ _ 7-8 rivelarsi _ _ _ _ _ _ _ _ 7 rivelar rivelare VERB V VerbForm=Inf 0 root _ _ 8 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 7 expl _ _ 9 ancora ancora ADV B _ 7 advmod _ _ 10 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 11 det _ _ 11 volta volta NOUN S Gender=Fem|Number=Sing 7 obl _ _ 12 prematuro prematuro ADJ A Gender=Masc|Number=Sing 7 xcomp _ SpaceAfter=No 13 . . 
PUNCT FS _ 7 punct _ _ """ def test_process_mwts(): # dei appears multiple times # the verb/pron esserci will be ignored expected_mwts = {'Negli': ('In', 'gli'), 'dei': ('di', 'i'), "dall'": ('da', "l'"), 'dalle': ('da', 'le'), 'ai': ('a', 'i'), 'alla': ('a', 'la'), 'dalla': ('da', 'la')} ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE) mwts = convert_it_vit.get_mwt(ud_train_data) assert expected_mwts == mwts def test_raw_tree(): con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE)) expected_ids = ["#ID=sent_00002", "#ID=sent_00318", "#ID=sent_00589"] expected_trees = ["(ROOT (cp (sp (part negli) (sn (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd dei) (sn (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda dall) (sn (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda dalle) (sn (sa (ag singole)) (n imprese))) (sp (part ai) (sn (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd dei) (sn (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))", "(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part alla) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n 
candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda dalla) (sn (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))", "(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"] assert len(con_sentences) == 3 for sentence, expected_id, expected_tree in zip(con_sentences, expected_ids, expected_trees): assert sentence[0] == expected_id tree = convert_it_vit.raw_tree(sentence[1]) assert str(tree) == expected_tree def test_update_mwts(): con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE)) ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE) mwt_map = convert_it_vit.get_mwt(ud_train_data) expected_trees=["(ROOT (cp (sp (part In) (sn (art gli) (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd di) (sn (art i) (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda da) (sn (art l') (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda da) (sn (art le) (sa (ag singole)) (n imprese))) (sp (part a) (sn (art i) (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd di) (sn (art i) (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))", "(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind 
qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part a) (art la) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda da) (sn (art la) (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))", "(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (clit si) (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"] with tsurgeon.Tsurgeon() as tsurgeon_processor: for con_sentence, ud_sentence, expected_tree in zip(con_sentences, ud_train_data.sentences, expected_trees): con_tree = convert_it_vit.raw_tree(con_sentence[1]) updated_tree, _ = convert_it_vit.update_mwts_and_special_cases(con_tree, ud_sentence, mwt_map, tsurgeon_processor) assert str(updated_tree) == expected_tree CON_PERCENT_SAMPLE = """ ID#sent_00020 f-[sn-[art-il, n-tesoro], ibar-[vt-mette], compt-[sp-[part-sul, sn-[n-mercato]], sn-[art-il, num-51%, sp-[p-a, sn-[num-2, n-lire]], sp-[p-per, sn-[n-azione]]]], punto-.] 
ID#sent_00022 dirsp-[f3-[sn-[art-le, n-novità]], dirs-':', f3-[coord-[sn-[n-voto, spd-[pd-di, sn-[n-lista]]], cong-e, sn-[n-tetto, sp-[part-agli, sn-[n-acquisti]], sv3-[vppt-limitato, comppas-[sp-[part-allo, sn-[num-0/5%]]]]]], punto-.]] ID#sent_00517 dirsp-[fc-[f-[sn-[art-l, n-aumento, sa-[ag-mensile], spd-[pd-di, sn-[nt-aprile]]], ibar-[ause-è, vppc-stato], compc-[sq-[q-dell_, sn-[num-1/3%]], sp-[p-contro, sn-[art-lo, num-0/7/0/8%, spd-[partd-degli, sn-[sa-[ag-ultimi], num-due, sn-[nt-mesi]]]]]]]]] ID#sent_01117 fc-[f-[sn-[art-La, sa-[ag-crescente], n-ripresa, spd-[partd-dei, sn-[n-beni, spd-[pd-di, sn-[n-consumo]]]]], ibar-[vin-deriva], savv-[avv-esclusivamente], compin-[spda-[partda-dal, sn-[n-miglioramento, f2-[spd-[pd-di, sn-[relob-cui]], f-[ibar-[ausa-hanno, vppin-beneficiato], compin-[sn-[n-beni, coord-[sa-[ag-durevoli, fp-[par-'(', sn-[num-plus4/5%], par-')']], cong-e, sa-[ag-semidurevoli, fp-[par-'(', sn-[num-plus1/5%], par-')']]]]]]]]]]], punt-',', fs-[cosu-mentre, f-[sn-[art-i, n-beni, sa-[neg-non, ag-durevoli], fp-[par-'(', sn-[num-min1%], par-')']], ibar-[vt-accusano], cong-ancora, compt-[sn-[art-un, sa-[ag-evidente], n-ritardo]]]], punto-.] 
""" CON_PERCENT_LEAVES = [ ['il', 'tesoro', 'mette', 'sul', 'mercato', 'il', '51', '%%', 'a', '2', 'lire', 'per', 'azione', '.'], ['le', 'novità', ':', 'voto', 'di', 'lista', 'e', 'tetto', 'agli', 'acquisti', 'limitato', 'allo', '0,5', '%%', '.'], ['l', 'aumento', 'mensile', 'di', 'aprile', 'è', 'stato', "dell'", '1,3', '%%', 'contro', 'lo', '0/7,0/8', '%%', 'degli', 'ultimi', 'due', 'mesi'], # the plus and min look bad, but they get cleaned up when merging with the UD version of the dataset ['La', 'crescente', 'ripresa', 'dei', 'beni', 'di', 'consumo', 'deriva', 'esclusivamente', 'dal', 'miglioramento', 'di', 'cui', 'hanno', 'beneficiato', 'beni', 'durevoli', '(', 'plus4,5', '%%', ')', 'e', 'semidurevoli', '(', 'plus1,5', '%%', ')', ',', 'mentre', 'i', 'beni', 'non', 'durevoli', '(', 'min1', '%%', ')', 'accusano', 'ancora', 'un', 'evidente', 'ritardo', '.'] ] def test_read_percent(): con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_PERCENT_SAMPLE)) assert len(con_sentences) == len(CON_PERCENT_LEAVES) for (_, raw_tree), expected_leaves in zip(con_sentences, CON_PERCENT_LEAVES): tree = convert_it_vit.raw_tree(raw_tree) words = tree.leaf_labels() if expected_leaves is None: print(words) else: assert words == expected_leaves ================================================ FILE: stanza/tests/constituency/test_convert_starlang.py ================================================ """ Test a couple different classes of trees to check the output of the Starlang conversion """ import os import tempfile import pytest from stanza.utils.datasets.constituency import convert_starlang pytestmark = [pytest.mark.pipeline, pytest.mark.travis] TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP 
{morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )" def test_read_tree(): """ Test a basic tree read """ tree = convert_starlang.read_tree(TREE) assert "(ROOT (S (NP (NP Bayan) (NP Haag)) (VP (NP Elianti) (VP çalar)) (. .)))" == str(tree) def test_missing_word(): """ Test that an error is thrown if the word is missing """ tree_text = TREE.replace("turkish=", "foo=") with pytest.raises(ValueError): tree = convert_starlang.read_tree(tree_text) def test_bad_label(): """ Test that an unexpected label results in an error """ tree_text = TREE.replace("(S", "(s") with pytest.raises(ValueError): tree = convert_starlang.read_tree(tree_text) ================================================ FILE: stanza/tests/constituency/test_ensemble.py ================================================ """ Add a simple test of the Ensemble's inference path This just reuses one model several times - that should still check the main loop, at least """ import pytest from stanza import Pipeline from stanza.models.constituency import text_processing from stanza.models.constituency import tree_reader from stanza.models.constituency.ensemble import Ensemble, EnsembleTrainer from stanza.models.constituency.text_processing import parse_tokenized_sentences from stanza.tests import TEST_MODELS_DIR pytestmark = 
[pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def pipeline(): return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True) @pytest.fixture(scope="module") def saved_ensemble(tmp_path_factory, pipeline): tmp_path = tmp_path_factory.mktemp("ensemble") # test the ensemble by reusing the same parser multiple times con_processor = pipeline.processors["constituency"] model = con_processor._model args = dict(model.args) foundation_cache = pipeline.foundation_cache model_path = con_processor._config['model_path'] # reuse the same model 3 times just to make sure the code paths are working filenames = [model_path, model_path, model_path] ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache) save_path = tmp_path / "ensemble.pt" ensemble.save(save_path) return ensemble, save_path, args, foundation_cache def check_basic_predictions(trees): predictions = [x.predictions for x in trees] assert len(predictions) == 2 assert all(len(x) == 1 for x in predictions) trees = [x[0].tree for x in predictions] result = ["{}".format(tree) for tree in trees] expected = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] assert result == expected def test_ensemble_inference(pipeline): # test the ensemble by reusing the same parser multiple times con_processor = pipeline.processors["constituency"] model = con_processor._model args = dict(model.args) foundation_cache = pipeline.foundation_cache model_path = con_processor._config['model_path'] # reuse the same model 3 times just to make sure the code paths are working filenames = [model_path, model_path, model_path] ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache) ensemble = ensemble.model sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]] trees = parse_tokenized_sentences(args, 
ensemble, [pipeline], sentences) check_basic_predictions(trees) def test_ensemble_save(saved_ensemble): """ Depending on the saved_ensemble fixture should be enough to ensure that the ensemble was correctly saved (loading is tested separately) """ def test_ensemble_save_load(pipeline, saved_ensemble): _, save_path, args, foundation_cache = saved_ensemble ensemble = EnsembleTrainer.load(save_path, args, foundation_cache=foundation_cache) sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]] trees = parse_tokenized_sentences(args, ensemble.model, [pipeline], sentences) check_basic_predictions(trees) def test_parse_text(tmp_path, pipeline, saved_ensemble): _, model_path, args, foundation_cache = saved_ensemble raw_file = str(tmp_path / "test_input.txt") with open(raw_file, "w") as fout: fout.write("This is a test\nThis is another test\n") output_file = str(tmp_path / "test_output.txt") args = dict(args) args['tokenized_file'] = raw_file args['predict_file'] = output_file text_processing.load_model_parse_text(args, model_path, [pipeline]) trees = tree_reader.read_treebank(output_file) trees = ["{}".format(x) for x in trees] expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] assert trees == expected_trees def test_pipeline(saved_ensemble): _, model_path, _, foundation_cache = saved_ensemble nlp = Pipeline("en", processors="tokenize,pos,constituency", constituency_model_path=str(model_path), foundation_cache=foundation_cache, download_method=None) doc = nlp("This is a test") tree = "{}".format(doc.sentences[0].constituency) assert tree == "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))" ================================================ FILE: stanza/tests/constituency/test_in_order_compound_oracle.py ================================================ import pytest from stanza.models.constituency import in_order_compound_oracle from 
stanza.models.constituency import tree_reader from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme from stanza.models.constituency.transition_sequence import build_treebank from stanza.tests.constituency.test_transition_sequence import reconstruct_tree pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # A sample tree from PTB with a triple unary transition (at a location other than root) # Here we test the incorrect closing of various brackets TRIPLE_UNARY_START_TREE = """ ( (S (PRN (S (NP-SBJ (-NONE- *) ) (VP (VB See) ))) (, ,) (NP-SBJ (NP (DT the) (JJ other) (NN rule) ) (PP (IN of) (NP (NN thumb) )) (PP (IN about) (NP (NN ballooning) ))))) """ TREES = [TRIPLE_UNARY_START_TREE] TREEBANK = "\n".join(TREES) ROOT_LABELS = ["ROOT"] @pytest.fixture(scope="module") def trees(): trees = tree_reader.read_trees(TREEBANK) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) == len(TREES) return trees @pytest.fixture(scope="module") def gold_sequences(trees): gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND) return gold_sequences def get_repairs(gold_sequence, wrong_transition, repair_fn): """ Use the repair function and the wrong transition to iterate over the gold sequence Returns a list of possible repairs, one for each position in the sequence Repairs are tuples, (idx, seq) """ repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None)) for idx, gold_transition in enumerate(gold_sequence)] repairs = [x for x in repairs if x[1] is not None] return repairs def test_fix_shift_close(): trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) == 1 tree = trees[0] gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND) # there are three places in this tree where a long bracket (more than 2 subtrees) # could theoretically 
be closed and then reopened repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_shift_close_error) assert len(repairs) == 3 expected_trees = ["(ROOT (S (S (PRN (S (VP (VB See)))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))", "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other)) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))", "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)))) (PP (IN about) (NP (NN ballooning))))))"] for repair, expected in zip(repairs, expected_trees): repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND) assert str(repaired_tree) == expected def test_fix_open_close(): trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) == 1 tree = trees[0] gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND) repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_open_close_error) print("------------------") for repair in repairs: print(repair) repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND) print("{:P}".format(repaired_tree)) ================================================ FILE: stanza/tests/constituency/test_in_order_oracle.py ================================================ import itertools import pytest from stanza.models.constituency import parse_transitions from stanza.models.constituency import tree_reader from stanza.models.constituency.base_model import SimpleModel from stanza.models.constituency.in_order_oracle import * from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme from stanza.models.constituency.transition_sequence import 
build_treebank from stanza.tests import * from stanza.tests.constituency.test_transition_sequence import reconstruct_tree pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # A sample tree from PTB with a single unary transition (at a location other than root) SINGLE_UNARY_TREE = """ ( (S (NP-SBJ-1 (DT A) (NN record) (NN date) ) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set) (NP (-NONE- *-1) )))) (. .) )) """ # [Shift, OpenConstituent(('NP-SBJ-1',)), Shift, Shift, CloseConstituent, OpenConstituent(('S',)), Shift, OpenConstituent(('VP',)), Shift, Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), CloseConstituent, CloseConstituent, CloseConstituent, CloseConstituent, Shift, CloseConstituent, OpenConstituent(('ROOT',)), CloseConstituent] # A sample tree from PTB with a double unary transition (at a location other than root) DOUBLE_UNARY_TREE = """ ( (S (NP-SBJ (NP (RB Not) (PDT all) (DT those) ) (SBAR (WHNP-3 (WP who) ) (S (NP-SBJ (-NONE- *T*-3) ) (VP (VBD wrote) )))) (VP (VBP oppose) (NP (DT the) (NNS changes) )) (. .) 
)) """ # A sample tree from PTB with a triple unary transition (at a location other than root) # The triple unary is at the START of the next bracket, which affects how the # dynamic oracle repairs the transition sequence TRIPLE_UNARY_START_TREE = """ ( (S (PRN (S (NP-SBJ (-NONE- *) ) (VP (VB See) ))) (, ,) (NP-SBJ (NP (DT the) (JJ other) (NN rule) ) (PP (IN of) (NP (NN thumb) )) (PP (IN about) (NP (NN ballooning) ))))) """ # A sample tree from PTB with a triple unary transition (at a location other than root) # The triple unary is at the END of the next bracket, which affects how the # dynamic oracle repairs the transition sequence TRIPLE_UNARY_END_TREE = """ ( (S (NP (NNS optimists) ) (VP (VBP expect) (S (NP-SBJ-4 (NNP Hong) (NNP Kong) ) (VP (TO to) (VP (VB hum) (ADVP-CLR (RB along) ) (SBAR-MNR (RB as) (S (NP-SBJ (-NONE- *-4) ) (VP (-NONE- *?*) (ADVP-TMP (IN before) )))))))))) """ TREES = [SINGLE_UNARY_TREE, DOUBLE_UNARY_TREE, TRIPLE_UNARY_START_TREE, TRIPLE_UNARY_END_TREE] TREEBANK = "\n".join(TREES) NOUN_PHRASE_TREE = """ ( (NP (NP (NNP Chicago) (POS 's)) (NNP Goodman) (NNP Theatre))) """ WIDE_NP_TREE = """ ( (S (NP-SBJ (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP-SBJ (NNS mice)) (VP (VBP are) (NP-PRD (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful)) (JJ experimental) (NN system)) (SBAR (WHADVP-2 (-NONE- *0*)) (S (NP-SBJ (-NONE- *PRO*)) (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics))))))))))))) """ WIDE_TREES = [NOUN_PHRASE_TREE, WIDE_NP_TREE] WIDE_TREEBANK = "\n".join(WIDE_TREES) ROOT_LABELS = ["ROOT"] def get_repairs(gold_sequence, wrong_transition, repair_fn): """ Use the repair function and the wrong transition to iterate over the gold sequence Returns a list of possible repairs, one for each position in the sequence Repairs are tuples, (idx, seq) """ repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None)) for idx, gold_transition in enumerate(gold_sequence)] repairs 
= [x for x in repairs if x[1] is not None] return repairs @pytest.fixture(scope="module") def unary_trees(): trees = tree_reader.read_trees(TREEBANK) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) == len(TREES) return trees @pytest.fixture(scope="module") def gold_sequences(unary_trees): gold_sequences = build_treebank(unary_trees, TransitionScheme.IN_ORDER) return gold_sequences @pytest.fixture(scope="module") def wide_trees(): trees = tree_reader.read_trees(WIDE_TREEBANK) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) == len(WIDE_TREES) return trees def test_wrong_open_root(gold_sequences): """ Test the results of the dynamic oracle on a few trees if the ROOT is mishandled. """ wrong_transition = OpenConstituent("S") gold_transition = OpenConstituent("ROOT") close_transition = CloseConstituent() for gold_sequence in gold_sequences: # each of the sequences should be ended with ROOT, Close assert gold_sequence[-2] == gold_transition repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_root_error) # there is only spot in the sequence with a ROOT, so there should # be exactly one location which affords a S/ROOT replacement assert len(repairs) == 1 repair = repairs[0] # the repair should occur at the -2 position, which is where ROOT is assert repair[0] == len(gold_sequence) - 2 # and the resulting list should have the wrong transition followed by a Close # to give the model another chance to close the tree expected = gold_sequence[:-2] + [wrong_transition, close_transition] + gold_sequence[-2:] assert repair[1] == expected def test_missed_unary(gold_sequences): """ Test the repairs of an open/open error if it is effectively a skipped unary transition """ wrong_transition = OpenConstituent("S") repairs = get_repairs(gold_sequences[0], wrong_transition, fix_wrong_open_unary_chain) assert len(repairs) == 0 # here we are simulating picking NT-S instead of NT-VP # the DOUBLE_UNARY tree has one 
location where this is relevant, index 11 repairs = get_repairs(gold_sequences[1], wrong_transition, fix_wrong_open_unary_chain) assert len(repairs) == 1 assert repairs[0][0] == 11 assert repairs[0][1] == gold_sequences[1][:11] + gold_sequences[1][13:] # the TRIPLE_UNARY_START tree has two locations where this is relevant # at index 1, the pattern goes (S (VP ...)) # so choosing S instead of VP means you can skip the VP and only miss that one bracket # at index 5, the pattern goes (S (PRN (S (VP ...))) (...)) # note that this is capturing a unary transition into a larger constituent # skipping the PRN is satisfactory repairs = get_repairs(gold_sequences[2], wrong_transition, fix_wrong_open_unary_chain) assert len(repairs) == 2 assert repairs[0][0] == 1 assert repairs[0][1] == gold_sequences[2][:1] + gold_sequences[2][3:] assert repairs[1][0] == 5 assert repairs[1][1] == gold_sequences[2][:5] + gold_sequences[2][7:] # The TRIPLE_UNARY_END tree has 2 sections of tree for a total of 3 locations # where the repair might happen # Surprisingly the unary transition at the very start can only be # repaired by skipping it and using the outer S transition instead # The second repair overall (first repair in the second location) # should have a double skip to reach the S node repairs = get_repairs(gold_sequences[3], wrong_transition, fix_wrong_open_unary_chain) assert len(repairs) == 3 assert repairs[0][0] == 1 assert repairs[0][1] == gold_sequences[3][:1] + gold_sequences[3][3:] assert repairs[1][0] == 21 assert repairs[1][1] == gold_sequences[3][:21] + gold_sequences[3][25:] assert repairs[2][0] == 23 assert repairs[2][1] == gold_sequences[3][:23] + gold_sequences[3][25:] def test_open_with_stuff(unary_trees, gold_sequences): wrong_transition = OpenConstituent("S") expected_trees = [ "(ROOT (S (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. 
.)))", "(ROOT (S (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))", None, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))" ] for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees): repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_stuff_unary) if expected is None: assert len(repairs) == 0 else: assert len(repairs) == 1 result = reconstruct_tree(tree, repairs[0][1]) assert str(result) == expected def test_general_open(gold_sequences): wrong_transition = OpenConstituent("SBARQ") for sequence in gold_sequences: repairs = get_repairs(sequence, wrong_transition, fix_wrong_open_general) assert len(repairs) == sum(isinstance(x, OpenConstituent) for x in sequence) - 1 for repair in repairs: assert len(repair[1]) == len(sequence) assert repair[1][repair[0]] == wrong_transition assert repair[1][:repair[0]] == sequence[:repair[0]] assert repair[1][repair[0]+1:] == sequence[repair[0]+1:] def test_missed_unary(unary_trees, gold_sequences): shift_transition = Shift() close_transition = CloseConstituent() expected_close_results = [ [(12, 2)], [(11, 4), (13, 2)], # (NP NN thumb) and (NP NN ballooning) are both candidates for this repair [(18, 2), (24, 2)], [(21, 6), (23, 4), (25, 2)], ] expected_shift_results = [ (), (), (), # (ADVP-CLR (RB along)) is followed by a shift [(16, 2)], ] for tree, sequence, expected_close, expected_shift in zip(unary_trees, gold_sequences, expected_close_results, expected_shift_results): repairs = get_repairs(sequence, close_transition, fix_missed_unary) assert len(repairs) == len(expected_close) for repair, (expected_idx, expected_len) in zip(repairs, expected_close): assert repair[0] == expected_idx assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:] repairs = 
get_repairs(sequence, shift_transition, fix_missed_unary) assert len(repairs) == len(expected_shift) for repair, (expected_idx, expected_len) in zip(repairs, expected_shift): assert repair[0] == expected_idx assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:] def test_open_shift(unary_trees, gold_sequences): shift_transition = Shift() expected_repairs = [ [(7, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"), (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VBN been) (VP (VBN set))) (. .)))")], [(7, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WP who) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), (9, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), (19, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose) (NP (DT the) (NNS changes)) (. .)))"), (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (DT the) (NNS changes)) (. 
.)))")], [(14, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"), (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"), (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about) (NP (NN ballooning)))))")], [(5, "(ROOT (S (NP (NNS optimists)) (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (10, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (RB as) (S (VP (ADVP (IN before))))))))))")] ] for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs): repairs = get_repairs(sequence, shift_transition, fix_open_shift) assert len(repairs) == len(expected) for repair, (idx, expected_tree) in zip(repairs, expected): assert repair[0] == idx result_tree = reconstruct_tree(tree, repair[1]) assert str(result_tree) == expected_tree def test_open_close(unary_trees, gold_sequences): close_transition = CloseConstituent() expected_repairs = [ [(7, "(ROOT (S (S (NP (DT A) (NN record) (NN date)) (VBZ has)) (RB n't) (VP (VBN been) (VP (VBN set))) (. 
.)))"), (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VP (VBZ has) (RB n't) (VBN been)) (VP (VBN set))) (. .)))")], # missed the WHNP. The surrounding SBAR cannot be created, either [(7, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), # missed the SBAR (9, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who))) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), # missed the VP around "oppose the changes" (19, "(ROOT (S (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose)) (NP (DT the) (NNS changes)) (. .)))"), # missed the NP in "the changes", looks pretty bad tbh (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VP (VBP oppose) (DT the)) (NNS changes)) (. .)))")], [(14, "(ROOT (S (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule))) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"), (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (IN of)) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"), (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about)) (NP (NN ballooning)))))")], [(5, "(ROOT (S (S (NP (NNS optimists)) (VBP expect)) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (10, "(ROOT (S (NP (NNS optimists)) (VP (VP (VBP expect) (NP (NNP Hong) (NNP Kong))) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (S (NP (NNP Hong) (NNP Kong)) (TO to)) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP 
Kong)) (VP (VP (TO to) (VB hum)) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (ADVP (RB along)) (RB as)) (S (VP (ADVP (IN before))))))))))")] ] for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs): repairs = get_repairs(sequence, close_transition, fix_open_close) assert len(repairs) == len(expected) for repair, (idx, expected_tree) in zip(repairs, expected): assert repair[0] == idx result_tree = reconstruct_tree(tree, repair[1]) assert str(result_tree) == expected_tree def test_shift_close(unary_trees, gold_sequences): """ Test the fix for a shift -> close These errors can occur pretty much everywhere, and the fix is quite simple, so we only test a few cases. """ close_transition = CloseConstituent() expected_tree = "(ROOT (S (NP (NP (DT A)) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))" repairs = get_repairs(gold_sequences[0], close_transition, fix_shift_close) assert len(repairs) == 7 result_tree = reconstruct_tree(unary_trees[0], repairs[0][1]) assert str(result_tree) == expected_tree repairs = get_repairs(gold_sequences[1], close_transition, fix_shift_close) assert len(repairs) == 8 repairs = get_repairs(gold_sequences[2], close_transition, fix_shift_close) assert len(repairs) == 8 repairs = get_repairs(gold_sequences[3], close_transition, fix_shift_close) assert len(repairs) == 9 for rep in repairs: if rep[0] == 16: # This one is special because it occurs as part of a unary # in other words, it should go unary, shift # and instead we are making it close where the unary should be # ... 
the unary would create "(ADVP (RB along))" expected_tree = "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))" result_tree = reconstruct_tree(unary_trees[3], rep[1]) assert str(result_tree) == expected_tree break else: raise AssertionError("Did not find an expected repair location") def test_close_open_shift_nested(unary_trees, gold_sequences): shift_transition = Shift() expected_trees = [{}, {4: "(ROOT (S (NP (RB Not) (PDT all) (DT those) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"}, {4: "(ROOT (S (VP (VB See)) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))", 13: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"}, {}] for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees): repairs = get_repairs(gold_sequence, shift_transition, fix_close_open_shift_nested) assert len(repairs) == len(expected) if len(expected) >= 1: for repair in repairs: assert repair[0] in expected.keys() result_tree = reconstruct_tree(tree, repair[1]) assert str(result_tree) == expected[repair[0]] def check_repairs(trees, gold_sequences, expected_trees, transition, repair_fn): for tree_idx, (gold_tree, gold_sequence, expected) in enumerate(zip(trees, gold_sequences, expected_trees)): repairs = get_repairs(gold_sequence, transition, repair_fn) if expected is not None: assert len(repairs) == len(expected) for repair in repairs: assert repair[0] in expected result_tree = reconstruct_tree(gold_tree, repair[1]) assert str(result_tree) == expected[repair[0]] else: print("---------------------") print("{:P}".format(gold_tree)) print(gold_sequence) #print(repairs) for repair in repairs: print("---------------------") print(gold_sequence) 
print(repair[1]) result_tree = reconstruct_tree(gold_tree, repair[1]) print("{:P}".format(gold_tree)) print("{:P}".format(result_tree)) print(tree_idx) print(repair[0]) print(result_tree) def test_close_open_shift_unambiguous(unary_trees, gold_sequences): shift_transition = Shift() expected_trees = [{}, {8: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who) (S (VP (VBD wrote)))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"}, {}, {2: "(ROOT (S (NP (NNS optimists) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))", 9: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}] check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_unambiguous_bracket) def test_close_open_shift_ambiguous_early(unary_trees, gold_sequences): shift_transition = Shift() expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))))) (. .)))"}, {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes)))) (. .)))"}, {2: "(ROOT (S (PRN (S (VP (VB See) (, ,)))) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))", 6: "(ROOT (S (PRN (S (VP (VB See))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"}, {}] check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_early) def test_close_open_shift_ambiguous_late(unary_trees, gold_sequences): shift_transition = Shift() expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. 
.))))"}, {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .))))"}, {2: "(ROOT (S (PRN (S (VP (VB See) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))))", 6: "(ROOT (S (PRN (S (VP (VB See))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))"}, {}] check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_late) def test_close_shift_shift(unary_trees, wide_trees): """ Test that close -> shift works when there is a single block shifted after Includes a test specifically that there is no oracle action when there are two blocks after the missed close """ shift_transition = Shift() expected_trees = [{15: "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .))))"}, {24: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes)) (. 
.))))"}, {20: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning)))))))"}, {17: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}, {}, {}] test_trees = unary_trees + wide_trees gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER) check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_unambiguous) def test_close_shift_shift_early(unary_trees, wide_trees): """ Test that close -> shift works when there are multiple blocks shifted after Also checks that the single block case is skipped, so as to keep them separate when testing A tree with the expected property was specifically added for this test """ shift_transition = Shift() test_trees = unary_trees + wide_trees gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER) expected_trees = [{}, {}, {}, {}, {}, {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental)) (NN system)) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}] check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_early) def test_close_shift_shift_late(unary_trees, wide_trees): """ Test that close -> shift works when there are multiple blocks shifted after Also checks that the single block case is skipped, so as to keep them separate when testing A tree with the expected property was specifically added for this test """ shift_transition = Shift() test_trees = unary_trees + wide_trees gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER) expected_trees = [{}, {}, {}, {}, {}, {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) 
(SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental) (NN system))) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}] check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_late) ================================================ FILE: stanza/tests/constituency/test_lstm_model.py ================================================ import os import pytest import torch from stanza.models.common import pretrain from stanza.models.common.utils import set_random_seed from stanza.models.constituency import parse_transitions from stanza.tests import * from stanza.tests.constituency import test_parse_transitions from stanza.tests.constituency.test_trainer import build_trainer pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def pretrain_file(): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' def build_model(pretrain_file, *args): # By default, we turn off multistage, since that can turn off various other structures in the initial training args = ['--no_multistage', '--pattn_num_layers', '4', '--pattn_d_model', '256', '--hidden_size', '128', '--use_lattn'] + list(args) trainer = build_trainer(pretrain_file, *args) return trainer.model @pytest.fixture(scope="module") def unary_model(pretrain_file): return build_model(pretrain_file, "--transition_scheme", "TOP_DOWN_UNARY") def test_initial_state(unary_model): test_parse_transitions.test_initial_state(unary_model) def test_shift(pretrain_file): # TODO: might be good to include some tests specifically for shift # in the context of a model with unaries model = build_model(pretrain_file) test_parse_transitions.test_shift(model) def test_unary(unary_model): test_parse_transitions.test_unary(unary_model) def test_unary_requires_root(unary_model): test_parse_transitions.test_unary_requires_root(unary_model) def test_open(unary_model): 
test_parse_transitions.test_open(unary_model) def test_compound_open(pretrain_file): model = build_model(pretrain_file, '--transition_scheme', "TOP_DOWN_COMPOUND") test_parse_transitions.test_compound_open(model) def test_in_order_open(pretrain_file): model = build_model(pretrain_file, '--transition_scheme', "IN_ORDER") test_parse_transitions.test_in_order_open(model) def test_close(unary_model): test_parse_transitions.test_close(unary_model) def run_forward_checks(model, num_states=1): """ Run a couple small transitions and a forward pass on the given model Results are not checked in any way. This function allows for testing that building models with various options results in a functional model. """ states = test_parse_transitions.build_initial_state(model, num_states) model(states) shift = parse_transitions.Shift() shifts = [shift for _ in range(num_states)] states = model.bulk_apply(states, shifts) model(states) open_transition = parse_transitions.OpenConstituent("NP") open_transitions = [open_transition for _ in range(num_states)] assert open_transition.is_legal(states[0], model) states = model.bulk_apply(states, open_transitions) assert states[0].num_opens == 1 model(states) states = model.bulk_apply(states, shifts) model(states) states = model.bulk_apply(states, shifts) model(states) assert states[0].num_opens == 1 # now should have "mox", "opal" on the constituents close_transition = parse_transitions.CloseConstituent() close_transitions = [close_transition for _ in range(num_states)] assert close_transition.is_legal(states[0], model) states = model.bulk_apply(states, close_transitions) assert states[0].num_opens == 0 model(states) def test_unary_forward(unary_model): """ Checks that the forward pass doesn't crash when run after various operations Doesn't check the forward pass for making reasonable answers """ run_forward_checks(unary_model) def test_lstm_forward(pretrain_file): model = build_model(pretrain_file) run_forward_checks(model, num_states=1) 
run_forward_checks(model, num_states=2) def test_lstm_layers(pretrain_file): model = build_model(pretrain_file, '--num_lstm_layers', '1') run_forward_checks(model) model = build_model(pretrain_file, '--num_lstm_layers', '2') run_forward_checks(model) model = build_model(pretrain_file, '--num_lstm_layers', '3') run_forward_checks(model) def test_multiple_output_forward(pretrain_file): """ Test a couple different sizes of output layers """ model = build_model(pretrain_file, '--num_output_layers', '1', '--num_lstm_layers', '2') run_forward_checks(model) model = build_model(pretrain_file, '--num_output_layers', '2', '--num_lstm_layers', '2') run_forward_checks(model) model = build_model(pretrain_file, '--num_output_layers', '3', '--num_lstm_layers', '2') run_forward_checks(model) def test_no_tag_embedding_forward(pretrain_file): """ Test that the model continues to work if the tag embedding is turned on or off """ model = build_model(pretrain_file, '--tag_embedding_dim', '20') run_forward_checks(model) model = build_model(pretrain_file, '--tag_embedding_dim', '0') run_forward_checks(model) def test_forward_combined_dummy(pretrain_file): """ Tests combined dummy and open node embeddings """ model = build_model(pretrain_file, '--combined_dummy_embedding') run_forward_checks(model) model = build_model(pretrain_file, '--no_combined_dummy_embedding') run_forward_checks(model) def test_nonlinearity_init(pretrain_file): """ Tests that different initialization methods of the nonlinearities result in valid tensors """ model = build_model(pretrain_file, '--nonlinearity', 'relu') run_forward_checks(model) model = build_model(pretrain_file, '--nonlinearity', 'tanh') run_forward_checks(model) model = build_model(pretrain_file, '--nonlinearity', 'silu') run_forward_checks(model) def test_forward_charlm(pretrain_file): """ Tests loading and running a charlm Note that this doesn't test the results of the charlm itself, just that the model is shaped correctly """ forward_charlm_path = 
os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "1billion.pt") backward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "1billion.pt") assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)" assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)" model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'none') run_forward_checks(model) model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'words') run_forward_checks(model) def test_forward_bert(pretrain_file): """ Test on a tiny Bert, which hopefully does not take up too much disk space or memory """ bert_model = "hf-internal-testing/tiny-bert" model = build_model(pretrain_file, '--bert_model', bert_model) run_forward_checks(model) def test_forward_xlnet(pretrain_file): """ Test on a tiny xlnet, which hopefully does not take up too much disk space or memory """ bert_model = "hf-internal-testing/tiny-random-xlnet" model = build_model(pretrain_file, '--bert_model', bert_model) run_forward_checks(model) def test_forward_sentence_boundaries(pretrain_file): """ Test start & stop boundary vectors """ model = build_model(pretrain_file, '--sentence_boundary_vectors', 'everything') run_forward_checks(model) model = build_model(pretrain_file, '--sentence_boundary_vectors', 'words') run_forward_checks(model) model = build_model(pretrain_file, '--sentence_boundary_vectors', 'none') run_forward_checks(model) def test_forward_constituency_composition(pretrain_file): """ Test different constituency composition functions """ model = build_model(pretrain_file, '--constituency_composition', 'bilstm') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, 
'--constituency_composition', 'max') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'key') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'untied_key') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'untied_max') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'bilstm_max') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm_cx') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'bigram') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'attn') run_forward_checks(model, num_states=2) def test_forward_key_position(pretrain_file): """ Test KEY and UNTIED_KEY either with or without reduce_position """ model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '0') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '32') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '0') run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '32') run_forward_checks(model, num_states=2) def test_forward_attn_hidden_size(pretrain_file): """ Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size """ model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129') assert 
model.hidden_size >= 129 assert model.hidden_size % model.reduce_heads == 0 run_forward_checks(model, num_states=2) model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129', '--reduce_heads', '10') assert model.hidden_size == 130 assert model.reduce_heads == 10 def test_forward_partitioned_attention(pretrain_file): """ Test with & without partitioned attention layers """ model = build_model(pretrain_file, '--pattn_num_heads', '8', '--pattn_num_layers', '8') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_heads', '0', '--pattn_num_layers', '0') run_forward_checks(model) def test_forward_labeled_attention(pretrain_file): """ Test with & without labeled attention layers """ model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16') run_forward_checks(model) model = build_model(pretrain_file, '--lattn_d_proj', '0', '--lattn_d_l', '0') run_forward_checks(model) model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input') run_forward_checks(model) def test_lattn_partitioned(pretrain_file): model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_partitioned') run_forward_checks(model) model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned') run_forward_checks(model) def test_lattn_projection(pretrain_file): """ Test with & without labeled attention layers """ with pytest.raises(ValueError): # this is too small model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256', '--lattn_partitioned') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256') run_forward_checks(model) model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', 
'--lattn_d_input_proj', '768') run_forward_checks(model) # check that it works if we turn off the projection, # in case having it on beccomes the default model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '0') run_forward_checks(model) def test_forward_timing_choices(pretrain_file): """ Test different timing / position encodings """ model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'sin') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'learned') run_forward_checks(model) def test_transition_stack(pretrain_file): """ Test different transition stack types: lstm & attention """ model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_stack', 'attn', '--transition_heads', '1') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_stack', 'attn', '--transition_heads', '4') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_stack', 'lstm') run_forward_checks(model) def test_constituent_stack(pretrain_file): """ Test different constituent stack types: lstm & attention """ model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--constituent_stack', 'attn', '--constituent_heads', '1') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--constituent_stack', 'attn', '--constituent_heads', '4') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--constituent_stack', 'lstm') run_forward_checks(model) def test_different_transition_sizes(pretrain_file): """ If the transition hidden size and embedding size are different, the model should still work """ 
model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_embedding_dim', '10', '--transition_hidden_size', '10', '--sentence_boundary_vectors', 'everything') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_embedding_dim', '20', '--transition_hidden_size', '10', '--sentence_boundary_vectors', 'everything') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_embedding_dim', '10', '--transition_hidden_size', '20', '--sentence_boundary_vectors', 'everything') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_embedding_dim', '10', '--transition_hidden_size', '10', '--sentence_boundary_vectors', 'none') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_embedding_dim', '20', '--transition_hidden_size', '10', '--sentence_boundary_vectors', 'none') run_forward_checks(model) model = build_model(pretrain_file, '--pattn_num_layers', '0', '--lattn_d_proj', '0', '--transition_embedding_dim', '10', '--transition_hidden_size', '20', '--sentence_boundary_vectors', 'none') run_forward_checks(model) def test_relative_attention(pretrain_file): model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat') run_forward_checks(model) def test_relative_attention_cat(pretrain_file): model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat') run_forward_checks(model) cat_size = model.word_input_size model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat') run_forward_checks(model) no_cat_size = model.word_input_size assert cat_size > no_cat_size def test_relative_attention_directional(pretrain_file): model = 
build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_forward', '--no_rattn_cat') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_reverse', '--no_rattn_cat') run_forward_checks(model) def test_relative_attention_sinks(pretrain_file): model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_window', '2', '--rattn_sinks', '1') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_sinks', '1') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_window', '2', '--rattn_sinks', '2') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_sinks', '2') run_forward_checks(model) def test_relative_attention_cat_sinks(pretrain_file): model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_window', '2', '--rattn_sinks', '1') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_sinks', '1') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_window', '2', '--rattn_sinks', '2') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_sinks', '2') run_forward_checks(model) def test_relative_attention_endpoint_sinks(pretrain_file): model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '1') run_forward_checks(model) 
model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '1') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '2') run_forward_checks(model) model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '2') run_forward_checks(model) def test_lstm_tree_forward(pretrain_file): """ Test the LSTM_TREE forward pass """ model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm') run_forward_checks(model) model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm') run_forward_checks(model) model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm') run_forward_checks(model) def test_lstm_tree_cx_forward(pretrain_file): """ Test the LSTM_TREE_CX forward pass """ model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm_cx') run_forward_checks(model) model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm_cx') run_forward_checks(model) model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm_cx') run_forward_checks(model) def test_maxout(pretrain_file): """ Test with and without maxout layers for output """ model = build_model(pretrain_file, '--maxout_k', '0') run_forward_checks(model) # check the output size & implicitly check the type # to check for a particularly silly bug assert model.output_layers[-1].weight.shape[0] == len(model.transitions) model = build_model(pretrain_file, '--maxout_k', '2') run_forward_checks(model) assert model.output_layers[-1].linear.weight.shape[0] == 
len(model.transitions) * 2 model = build_model(pretrain_file, '--maxout_k', '3') run_forward_checks(model) assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 3 def check_structure_test(pretrain_file, args1, args2): """ Test that the "copy" method copies the parameters from one model to another Also check that the copied models produce the same results """ set_random_seed(1000) other = build_model(pretrain_file, *args1) other.eval() set_random_seed(1001) model = build_model(pretrain_file, *args2) model.eval() assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight) assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight) model.copy_with_new_structure(other) assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight) assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight) # the norms will be the same, as the non-zero values are all the same assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0)) # now, check that applying one transition to an initial state # results in the same values in the output states for both models # as the pattn layer inputs are 0, the output values should be equal shift = [parse_transitions.Shift()] model_states = test_parse_transitions.build_initial_state(model, 1) model_states = model.bulk_apply(model_states, shift) other_states = test_parse_transitions.build_initial_state(other, 1) other_states = other.bulk_apply(other_states, shift) for i, j in zip(other_states[0].word_queue, model_states[0].word_queue): assert torch.allclose(i.hx, j.hx, atol=1e-07) for i, j in zip(other_states[0].transitions, model_states[0].transitions): assert torch.allclose(i.lstm_hx, j.lstm_hx) assert torch.allclose(i.lstm_cx, j.lstm_cx) for i, j in zip(other_states[0].constituents, model_states[0].constituents): assert (i.value is None) == (j.value is None) if 
i.value is not None: assert torch.allclose(i.value.tree_hx, j.value.tree_hx, atol=1e-07) assert torch.allclose(i.lstm_hx, j.lstm_hx) assert torch.allclose(i.lstm_cx, j.lstm_cx) def test_copy_with_new_structure_same(pretrain_file): """ Test that copying the structure with no changes works as expected """ check_structure_test(pretrain_file, ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'], ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']) def test_copy_with_new_structure_untied(pretrain_file): """ Test that copying the structure with no changes works as expected """ check_structure_test(pretrain_file, ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'MAX'], ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'UNTIED_MAX']) def test_copy_with_new_structure_pattn(pretrain_file): check_structure_test(pretrain_file, ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'], ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2']) def test_copy_with_new_structure_both(pretrain_file): check_structure_test(pretrain_file, ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'], ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2']) def test_copy_with_new_structure_lattn(pretrain_file): check_structure_test(pretrain_file, ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'], ['--pattn_num_layers', '1', 
'--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2']) def test_parse_tagged_words(pretrain_file): """ Small test which doesn't check results, just execution """ model = build_model(pretrain_file) sentence = [("I", "PRP"), ("am", "VBZ"), ("Luffa", "NNP")] # we don't expect a useful tree out of a random model # so we don't check the result # just check that it works without crashing result = model.parse_tagged_words([sentence], 10) assert len(result) == 1 pts = [x for x in result[0].yield_preterminals()] for word, pt in zip(sentence, pts): assert pt.children[0].label == word[0] assert pt.label == word[1] ================================================ FILE: stanza/tests/constituency/test_parse_transitions.py ================================================ import pytest from stanza.models.constituency import parse_transitions from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT from stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def build_initial_state(model, num_states=1): words = ["Unban", "Mox", "Opal"] tags = ["VB", "NNP", "NNP"] sentences = [list(zip(words, tags)) for _ in range(num_states)] states = model.initial_state_from_words(sentences) assert len(states) == num_states assert all(state.num_transitions == 0 for state in states) return states def test_initial_state(model=None): if model is None: model = SimpleModel() states = build_initial_state(model) assert len(states) == 1 state = states[0] assert state.sentence_length == 3 assert state.num_opens == 0 # each stack has a sentinel value at the end assert len(state.word_queue) == 5 assert len(state.constituents) == 1 assert len(state.transitions) == 1 assert state.word_position == 0 def test_shift(model=None): if model is None: model = SimpleModel() state = 
build_initial_state(model)[0] open_transition = parse_transitions.OpenConstituent("ROOT") state = open_transition.apply(state, model) open_transition = parse_transitions.OpenConstituent("S") state = open_transition.apply(state, model) shift = parse_transitions.Shift() assert shift.is_legal(state, model) assert len(state.word_queue) == 5 assert state.word_position == 0 state = shift.apply(state, model) assert len(state.word_queue) == 5 # 4 because of the dummy created by the opens assert len(state.constituents) == 4 assert len(state.transitions) == 4 assert shift.is_legal(state, model) assert state.word_position == 1 assert not state.empty_word_queue() state = shift.apply(state, model) assert len(state.word_queue) == 5 assert len(state.constituents) == 5 assert len(state.transitions) == 5 assert shift.is_legal(state, model) assert state.word_position == 2 assert not state.empty_word_queue() state = shift.apply(state, model) assert len(state.word_queue) == 5 assert len(state.constituents) == 6 assert len(state.transitions) == 6 assert not shift.is_legal(state, model) assert state.word_position == 3 assert state.empty_word_queue() constituents = state.constituents assert model.get_top_constituent(constituents).children[0].label == 'Opal' constituents = constituents.pop() assert model.get_top_constituent(constituents).children[0].label == 'Mox' constituents = constituents.pop() assert model.get_top_constituent(constituents).children[0].label == 'Unban' def test_initial_unary(model=None): # it doesn't make sense to start with a CompoundUnary if model is None: model = SimpleModel() state = build_initial_state(model)[0] unary = parse_transitions.CompoundUnary('ROOT', 'VP') assert unary.label == ('ROOT', 'VP',) assert not unary.is_legal(state, model) unary = parse_transitions.CompoundUnary('VP') assert unary.label == ('VP',) assert not unary.is_legal(state, model) def test_unary(model=None): if model is None: model = SimpleModel() state = build_initial_state(model)[0] 
shift = parse_transitions.Shift() state = shift.apply(state, model) # this is technically the wrong parse but we're being lazy unary = parse_transitions.CompoundUnary('S', 'VP') assert unary.is_legal(state, model) state = unary.apply(state, model) assert not unary.is_legal(state, model) tree = model.get_top_constituent(state.constituents) assert tree.label == 'S' assert len(tree.children) == 1 tree = tree.children[0] assert tree.label == 'VP' assert len(tree.children) == 1 tree = tree.children[0] assert tree.label == 'VB' assert tree.is_preterminal() def test_unary_requires_root(model=None): if model is None: model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY) state = build_initial_state(model)[0] open_transition = parse_transitions.OpenConstituent("S") assert open_transition.is_legal(state, model) state = open_transition.apply(state, model) shift = parse_transitions.Shift() assert shift.is_legal(state, model) state = shift.apply(state, model) assert shift.is_legal(state, model) state = shift.apply(state, model) assert shift.is_legal(state, model) state = shift.apply(state, model) assert not shift.is_legal(state, model) close_transition = parse_transitions.CloseConstituent() assert close_transition.is_legal(state, model) state = close_transition.apply(state, model) assert not open_transition.is_legal(state, model) assert not close_transition.is_legal(state, model) np_unary = parse_transitions.CompoundUnary("NP") assert not np_unary.is_legal(state, model) root_unary = parse_transitions.CompoundUnary("ROOT") assert root_unary.is_legal(state, model) assert not state.finished(model) state = root_unary.apply(state, model) assert not root_unary.is_legal(state, model) assert state.finished(model) def test_open(model=None): if model is None: model = SimpleModel() state = build_initial_state(model)[0] shift = parse_transitions.Shift() state = shift.apply(state, model) state = shift.apply(state, model) assert state.num_opens == 0 open_transition = 
parse_transitions.OpenConstituent("VP") assert open_transition.is_legal(state, model) state = open_transition.apply(state, model) assert open_transition.is_legal(state, model) assert state.num_opens == 1 # check that it is illegal if there are too many opens already for i in range(20): state = open_transition.apply(state, model) assert not open_transition.is_legal(state, model) assert state.num_opens == 21 # check that it is illegal if the state is out of words state = build_initial_state(model)[0] state = shift.apply(state, model) state = shift.apply(state, model) state = shift.apply(state, model) assert not open_transition.is_legal(state, model) def test_compound_open(model=None): if model is None: model = SimpleModel() state = build_initial_state(model)[0] open_transition = parse_transitions.OpenConstituent("ROOT", "S") assert open_transition.is_legal(state, model) shift = parse_transitions.Shift() close_transition = parse_transitions.CloseConstituent() state = open_transition.apply(state, model) state = shift.apply(state, model) state = shift.apply(state, model) state = shift.apply(state, model) state = close_transition.apply(state, model) tree = model.get_top_constituent(state.constituents) assert tree.label == 'ROOT' assert len(tree.children) == 1 tree = tree.children[0] assert tree.label == 'S' assert len(tree.children) == 3 assert tree.children[0].children[0].label == 'Unban' assert tree.children[1].children[0].label == 'Mox' assert tree.children[2].children[0].label == 'Opal' def test_in_order_open(model=None): if model is None: model = SimpleModel(TransitionScheme.IN_ORDER) state = build_initial_state(model)[0] shift = parse_transitions.Shift() assert shift.is_legal(state, model) state = shift.apply(state, model) assert not shift.is_legal(state, model) open_vp = parse_transitions.OpenConstituent("VP") assert open_vp.is_legal(state, model) state = open_vp.apply(state, model) assert not open_vp.is_legal(state, model) close_trans = 
parse_transitions.CloseConstituent() assert close_trans.is_legal(state, model) state = close_trans.apply(state, model) open_s = parse_transitions.OpenConstituent("S") assert open_s.is_legal(state, model) state = open_s.apply(state, model) assert not open_vp.is_legal(state, model) # check that root transitions won't happen in the middle of a parse open_root = parse_transitions.OpenConstituent("ROOT") assert not open_root.is_legal(state, model) # build (NP (NNP Mox) (NNP Opal)) open_np = parse_transitions.OpenConstituent("NP") assert shift.is_legal(state, model) state = shift.apply(state, model) assert open_np.is_legal(state, model) # make sure root can't happen in places where an arbitrary open is legal assert not open_root.is_legal(state, model) state = open_np.apply(state, model) assert shift.is_legal(state, model) state = shift.apply(state, model) assert close_trans.is_legal(state, model) state = close_trans.apply(state, model) assert close_trans.is_legal(state, model) state = close_trans.apply(state, model) assert open_root.is_legal(state, model) state = open_root.apply(state, model) def test_too_many_unaries_close(): """ This tests rejecting Close at the start of a sequence after too many unary transitions The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence """ model = SimpleModel(TransitionScheme.IN_ORDER) state = build_initial_state(model)[0] shift = parse_transitions.Shift() assert shift.is_legal(state, model) state = shift.apply(state, model) open_np = parse_transitions.OpenConstituent("NP") close_trans = parse_transitions.CloseConstituent() for _ in range(UNARY_LIMIT): assert open_np.is_legal(state, model) state = open_np.apply(state, model) assert close_trans.is_legal(state, model) state = close_trans.apply(state, model) assert open_np.is_legal(state, model) state = open_np.apply(state, model) assert not close_trans.is_legal(state, model) def test_too_many_unaries_open(): """ This tests rejecting Open in the 
middle of a sequence after too many unary transitions The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence """ model = SimpleModel(TransitionScheme.IN_ORDER) state = build_initial_state(model)[0] shift = parse_transitions.Shift() assert shift.is_legal(state, model) state = shift.apply(state, model) open_np = parse_transitions.OpenConstituent("NP") close_trans = parse_transitions.CloseConstituent() assert open_np.is_legal(state, model) state = open_np.apply(state, model) assert not open_np.is_legal(state, model) assert shift.is_legal(state, model) state = shift.apply(state, model) for _ in range(UNARY_LIMIT): assert open_np.is_legal(state, model) state = open_np.apply(state, model) assert close_trans.is_legal(state, model) state = close_trans.apply(state, model) assert not open_np.is_legal(state, model) def test_close(model=None): if model is None: model = SimpleModel() # this one actually tests an entire subtree building state = build_initial_state(model)[0] open_transition_vp = parse_transitions.OpenConstituent("VP") assert open_transition_vp.is_legal(state, model) state = open_transition_vp.apply(state, model) assert state.num_opens == 1 shift = parse_transitions.Shift() assert shift.is_legal(state, model) state = shift.apply(state, model) open_transition_np = parse_transitions.OpenConstituent("NP") assert open_transition_np.is_legal(state, model) state = open_transition_np.apply(state, model) assert state.num_opens == 2 assert shift.is_legal(state, model) state = shift.apply(state, model) assert shift.is_legal(state, model) state = shift.apply(state, model) assert not shift.is_legal(state, model) assert state.num_opens == 2 # now should have "mox", "opal" on the constituents close_transition = parse_transitions.CloseConstituent() assert close_transition.is_legal(state, model) state = close_transition.apply(state, model) assert state.num_opens == 1 assert close_transition.is_legal(state, model) state = 
close_transition.apply(state, model) assert state.num_opens == 0 assert not close_transition.is_legal(state, model) tree = model.get_top_constituent(state.constituents) assert tree.label == 'VP' assert len(tree.children) == 2 tree = tree.children[1] assert tree.label == 'NP' assert len(tree.children) == 2 assert tree.children[0].is_preterminal() assert tree.children[1].is_preterminal() assert tree.children[0].children[0].label == 'Mox' assert tree.children[1].children[0].label == 'Opal' # extra one for None at the start of the TreeStack assert len(state.constituents) == 2 assert state.all_transitions(model) == [open_transition_vp, shift, open_transition_np, shift, shift, close_transition, close_transition] def test_in_order_compound_finalize(model=None): """ Test the Finalize transition is only legal at the end of a sequence """ if model is None: model = SimpleModel(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND) state = build_initial_state(model)[0] finalize = parse_transitions.Finalize("ROOT") shift = parse_transitions.Shift() assert shift.is_legal(state, model) assert not finalize.is_legal(state, model) state = shift.apply(state, model) open_transition = parse_transitions.OpenConstituent("NP") assert open_transition.is_legal(state, model) assert not finalize.is_legal(state, model) state = open_transition.apply(state, model) assert state.num_opens == 1 assert shift.is_legal(state, model) assert not finalize.is_legal(state, model) state = shift.apply(state, model) assert shift.is_legal(state, model) assert not finalize.is_legal(state, model) state = shift.apply(state, model) close_transition = parse_transitions.CloseConstituent() assert close_transition.is_legal(state, model) state = close_transition.apply(state, model) assert state.num_opens == 0 assert not close_transition.is_legal(state, model) assert finalize.is_legal(state, model) state = finalize.apply(state, model) assert not finalize.is_legal(state, model) tree = 
model.get_top_constituent(state.constituents) assert tree.label == 'ROOT' def test_hashes(): transitions = set() shift = parse_transitions.Shift() assert shift not in transitions transitions.add(shift) assert shift in transitions shift = parse_transitions.Shift() assert shift in transitions for i in range(5): transitions.add(shift) assert len(transitions) == 1 unary = parse_transitions.CompoundUnary("asdf") assert unary not in transitions transitions.add(unary) assert unary in transitions unary = parse_transitions.CompoundUnary("asdf", "zzzz") assert unary not in transitions transitions.add(unary) transitions.add(unary) transitions.add(unary) unary = parse_transitions.CompoundUnary("asdf", "zzzz") assert unary in transitions oc = parse_transitions.OpenConstituent("asdf") assert oc not in transitions transitions.add(oc) assert oc in transitions transitions.add(oc) transitions.add(oc) assert len(transitions) == 4 assert parse_transitions.OpenConstituent("asdf") in transitions cc = parse_transitions.CloseConstituent() assert cc not in transitions transitions.add(cc) transitions.add(cc) transitions.add(cc) assert cc in transitions cc = parse_transitions.CloseConstituent() assert cc in transitions assert len(transitions) == 5 def test_sort(): expected = [] expected.append(parse_transitions.Shift()) expected.append(parse_transitions.CloseConstituent()) expected.append(parse_transitions.CompoundUnary("NP")) expected.append(parse_transitions.CompoundUnary("NP", "VP")) expected.append(parse_transitions.OpenConstituent("mox")) expected.append(parse_transitions.OpenConstituent("opal")) expected.append(parse_transitions.OpenConstituent("unban")) transitions = set(expected) transitions = sorted(transitions) assert transitions == expected def test_check_transitions(): """ Test that check_transitions passes or fails a couple simple, small test cases """ transitions = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")} other = {Shift(), CloseConstituent(), 
OpenConstituent("NP"), OpenConstituent("VP")} parse_transitions.check_transitions(transitions, other, "test") # This will get a pass because it is a unary made out of existing unaries other = {Shift(), CloseConstituent(), OpenConstituent("NP", "VP")} parse_transitions.check_transitions(transitions, other, "test") # This should fail with pytest.raises(RuntimeError): other = {Shift(), CloseConstituent(), OpenConstituent("NP", "ZP")} parse_transitions.check_transitions(transitions, other, "test") ================================================ FILE: stanza/tests/constituency/test_parse_tree.py ================================================ import pytest from stanza.models.constituency.parse_tree import Tree from stanza.models.constituency import tree_reader from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_leaf_preterminal(): foo = Tree(label="foo") assert foo.is_leaf() assert not foo.is_preterminal() assert len(foo.children) == 0 assert str(foo) == 'foo' bar = Tree(label="bar", children=foo) assert not bar.is_leaf() assert bar.is_preterminal() assert len(bar.children) == 1 assert str(bar) == "(bar foo)" baz = Tree(label="baz", children=[bar]) assert not baz.is_leaf() assert not baz.is_preterminal() assert len(baz.children) == 1 assert str(baz) == "(baz (bar foo))" def test_yield_preterminals(): text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" trees = tree_reader.read_trees(text) preterminals = list(trees[0].yield_preterminals()) assert len(preterminals) == 3 assert str(preterminals) == "[(VB Unban), (NNP Mox), (NNP Opal)]" def test_depth(): text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" trees = tree_reader.read_trees(text) assert trees[0].depth() == 0 assert trees[1].depth() == 4 def test_unique_labels(): """ Test getting the unique labels from a tree Assumes tree_reader works, which should be fine since it is tested elsewhere """ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT 
this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) labels = Tree.get_unique_constituent_labels(trees) expected = ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP'] assert labels == expected def test_unique_tags(): """ Test getting the unique tags from a tree """ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) tags = Tree.get_unique_tags(trees) expected = ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP'] assert tags == expected def test_unique_words(): """ Test getting the unique words from a tree """ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) words = Tree.get_unique_words(trees) expected = ['?', 'Who', 'in', 'seat', 'sits', 'this'] assert words == expected def test_rare_words(): """ Test getting the unique words from a tree """ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))" trees = tree_reader.read_trees(text) words = Tree.get_rare_words(trees, 0.5) expected = ['Who', 'in', 'sits'] assert words == expected def test_common_words(): """ Test getting the unique words from a tree """ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))" trees = tree_reader.read_trees(text) words = Tree.get_common_words(trees, 3) expected = ['?', 'seat', 'this'] assert words == expected def test_root_labels(): text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) assert ["ROOT"] == Tree.get_root_labels(trees) text=("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))") trees = tree_reader.read_trees(text) assert ["ROOT"] == Tree.get_root_labels(trees) text="(FOO) (BAR)" trees = tree_reader.read_trees(text) assert ["BAR", "FOO"] == Tree.get_root_labels(trees) def test_prune_none(): text=["((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))", # test one dead node "((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", # test recursive dead nodes "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))"] # test all children dead expected=["(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))", "(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"] for t, e in zip(text, expected): trees = tree_reader.read_trees(t) assert len(trees) == 1 tree = trees[0].prune_none() assert e == str(tree) def test_simplify_labels(): text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))" expected = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))" trees = tree_reader.read_trees(text) trees = [t.simplify_labels() for t in trees] assert len(trees) == 1 assert expected == str(trees[0]) def test_remap_constituent_labels(): text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" expected="(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. 
?)))" label_map = { "SBARQ": "FOO" } trees = tree_reader.read_trees(text) trees = [t.remap_constituent_labels(label_map) for t in trees] assert len(trees) == 1 assert expected == str(trees[0]) def test_remap_constituent_words(): text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))" word_map = { "Who": "unban", "sits": "mox", "in": "opal" } trees = tree_reader.read_trees(text) trees = [t.remap_words(word_map) for t in trees] assert len(trees) == 1 assert expected == str(trees[0]) def test_replace_words(): text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))" new_words = ["unban", "mox", "opal", "?"] trees = tree_reader.read_trees(text) assert len(trees) == 1 tree = trees[0] new_tree = tree.replace_words(new_words) assert expected == str(new_tree) def test_compound_constituents(): # TODO: add skinny trees like this to the various transition tests text="((VP (VB Unban)))" trees = tree_reader.read_trees(text) assert Tree.get_compound_constituents(trees) == [('ROOT', 'VP')] text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" trees = tree_reader.read_trees(text) assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)] text="((VP (VB Unban))) (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. 
?)))" trees = tree_reader.read_trees(text) assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)] def test_equals(): """ Check one tree from the actual dataset for == when built with compound Open, this didn't work because of a silly bug """ text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))" trees = tree_reader.read_trees(text) assert len(trees) == 1 tree = trees[0] assert tree == tree trees2 = tree_reader.read_trees(text) tree2 = trees2[0] assert tree is not tree2 assert tree == tree2 # This tree was causing the model to barf on CTB7, # although it turns out the problem was just the # depth of the unary, not the list CHINESE_LONG_LIST_TREE = """ (ROOT (IP (NP (NNP 证券法)) (VP (PP (IN 对) (NP (DNP (NP (NP (NNP 中国)) (NP (NN 证券) (NN 市场))) (DEC 的)) (NP (NN 运作)))) (, ,) (PP (PP (IN 从) (NP (NP (NN 股票)) (NP (VV 发行) (EC 、) (VV 交易)))) (, ,) (PP (VV 到) (NP (NP (NN 上市) (NN 公司) (NN 收购)) (EC 、) (NP (NN 证券) (NN 交易所)) (EC 、) (NP (NN 证券) (NN 公司)) (EC 、) (NP (NN 登记) (NN 结算) (NN 机构)) (EC 、) (NP (NN 交易) (NN 服务) (NN 机构)) (EC 、) (NP (NN 证券业) (NN 协会)) (EC 、) (NP (NN 证券) (NN 监督) (NN 管理) (NN 机构)) (CC 和) (NP (DNP (NP (CP (CP (IP (VP (VV 违法)))))) (DEC 的)) (NP (NN 法律) (NN 责任)))))) (ADVP (RB 都)) (VP (VV 作) (AS 了) (NP (ADJP (JJ 详细)) (NP (NN 规定))))) (. 
。))) """ WEIRD_UNARY = """ (DNP (NP (CP (CP (IP (VP (ASDF (NP (NN 上市) (NN 公司) (NN 收购)) (EC 、) (NP (NN 证券) (NN 交易所)) (EC 、) (NP (NN 证券) (NN 公司)) (EC 、) (NP (NN 登记) (NN 结算) (NN 机构)) (EC 、) (NP (NN 交易) (NN 服务) (NN 机构)) (EC 、) (NP (NN 证券业) (NN 协会)) (EC 、) (NP (NN 证券) (NN 监督) (NN 管理) (NN 机构)))))))) (DEC 的)) """ def test_count_unaries(): trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE) assert len(trees) == 1 assert trees[0].count_unary_depth() == 5 trees = tree_reader.read_trees(WEIRD_UNARY) assert len(trees) == 1 assert trees[0].count_unary_depth() == 5 def test_str_bracket_labels(): text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" expected = "(_ROOT (_S (_VP (_VB Unban )_VB )_VP (_NP (_NNP Mox )_NNP (_NNP Opal )_NNP )_NP )_S )_ROOT" trees = tree_reader.read_trees(text) assert len(trees) == 1 assert "{:L}".format(trees[0]) == expected def test_all_leaves_are_preterminals(): text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" trees = tree_reader.read_trees(text) assert len(trees) == 1 assert trees[0].all_leaves_are_preterminals() text = "((S (VP (VB Unban)) (NP (Mox) (NNP Opal))))" trees = tree_reader.read_trees(text) assert len(trees) == 1 assert not trees[0].all_leaves_are_preterminals() def test_latex(): """ Test the latex format for trees """ expected = "\\Tree [.S [.NP Jennifer ] [.VP has [.NP nice antennae ] ] ]" tree = "(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ nice) (NNS antennae)))))" tree = tree_reader.read_trees(tree)[0] text = "{:T}".format(tree) assert text == expected def test_pretty_print(): """ Pretty print a couple trees - newlines & indentation """ text = "(ROOT (S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal)))) (ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric)))))))" trees = tree_reader.read_trees(text) assert len(trees) == 2 expected = """(ROOT (S (VP (VB Unban)) (NP (NNP Mox) 
(NNP Opal)))) """ assert "{:P}".format(trees[0]) == expected expected = """(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric))))))) """ assert "{:P}".format(trees[1]) == expected assert text == "{:O} {:O}".format(*trees) def test_reverse(): text = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))" trees = tree_reader.read_trees(text) assert len(trees) == 1 reversed_tree = trees[0].reverse() assert str(reversed_tree) == "(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))" ================================================ FILE: stanza/tests/constituency/test_positional_encoding.py ================================================ import pytest import torch from stanza import Pipeline from stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_positional_encoding(): encoding = SinusoidalEncoding(model_dim=10, max_len=6) foo = encoding(torch.tensor([5])) assert foo.shape == (1, 10) # TODO: check the values def test_resize(): encoding = SinusoidalEncoding(model_dim=10, max_len=3) foo = encoding(torch.tensor([5])) assert foo.shape == (1, 10) def test_arange(): encoding = SinusoidalEncoding(model_dim=10, max_len=2) foo = encoding(torch.arange(4)) assert foo.shape == (4, 10) assert encoding.max_len() == 4 def test_add(): encoding = AddSinusoidalEncoding(d_model=10, max_len=4) x = torch.zeros(1, 4, 10) y = encoding(x) r = torch.randn(1, 4, 10) r2 = encoding(r) assert torch.allclose(r2 - r, y, atol=1e-07) r = torch.randn(2, 4, 10) r2 = encoding(r) assert torch.allclose(r2[0] - r[0], y, atol=1e-07) assert torch.allclose(r2[1] - r[1], y, atol=1e-07) 
================================================ FILE: stanza/tests/constituency/test_selftrain_vi_quad.py ================================================ """ Test some of the methods in the vi_quad dataset Uses a small section of the dataset as a test """ import pytest from stanza.utils.datasets.constituency import selftrain_vi_quad from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] SAMPLE_TEXT = """ {"version": "1.1", "data": [{"title": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng", "paragraphs": [{"qas": [{"question": "T\u00ean g\u1ecdi n\u00e0o \u0111\u01b0\u1ee3c Ph\u1ea1m V\u0103n \u0110\u1ed3ng s\u1eed d\u1ee5ng khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m?", "answers": [{"answer_start": 507, "text": "L\u00e2m B\u00e1 Ki\u1ec7t"}], "id": "uit_01__05272_0_1"}, {"question": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec trong b\u1ed9 m\u00e1y Nh\u00e0 n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam?", "answers": [{"answer_start": 60, "text": "Th\u1ee7 t\u01b0\u1edbng"}], "id": "uit_01__05272_0_2"}, {"question": "Giai \u0111o\u1ea1n n\u0103m 1955-1976, Ph\u1ea1m V\u0103n \u0110\u1ed3ng n\u1eafm gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec?", "answers": [{"answer_start": 245, "text": "Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a"}], "id": "uit_01__05272_0_3"}], "context": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng (1 th\u00e1ng 3 n\u0103m 1906 \u2013 29 th\u00e1ng 4 n\u0103m 2000) l\u00e0 Th\u1ee7 t\u01b0\u1edbng \u0111\u1ea7u ti\u00ean c\u1ee7a n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam t\u1eeb n\u0103m 1976 (t\u1eeb n\u0103m 1981 g\u1ecdi l\u00e0 Ch\u1ee7 t\u1ecbch H\u1ed9i \u0111\u1ed3ng B\u1ed9 tr\u01b0\u1edfng) cho \u0111\u1ebfn khi ngh\u1ec9 h\u01b0u n\u0103m 1987. 
Tr\u01b0\u1edbc \u0111\u00f3 \u00f4ng t\u1eebng gi\u1eef ch\u1ee9c v\u1ee5 Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a t\u1eeb n\u0103m 1955 \u0111\u1ebfn n\u0103m 1976. \u00d4ng l\u00e0 v\u1ecb Th\u1ee7 t\u01b0\u1edbng Vi\u1ec7t Nam t\u1ea1i v\u1ecb l\u00e2u nh\u1ea5t (1955\u20131987). \u00d4ng l\u00e0 h\u1ecdc tr\u00f2, c\u1ed9ng s\u1ef1 c\u1ee7a Ch\u1ee7 t\u1ecbch H\u1ed3 Ch\u00ed Minh. \u00d4ng c\u00f3 t\u00ean g\u1ecdi th\u00e2n m\u1eadt l\u00e0 T\u00f4, \u0111\u00e2y t\u1eebng l\u00e0 b\u00ed danh c\u1ee7a \u00f4ng. \u00d4ng c\u00f2n c\u00f3 t\u00ean g\u1ecdi l\u00e0 L\u00e2m B\u00e1 Ki\u1ec7t khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m (Ch\u1ee7 nhi\u1ec7m l\u00e0 H\u1ed3 H\u1ecdc L\u00e3m)."}, {"qas": [{"question": "S\u1ef1 ki\u1ec7n quan tr\u1ecdng n\u00e0o \u0111\u00e3 di\u1ec5n ra v\u00e0o ng\u00e0y 20/7/1954?", "answers": [{"answer_start": 364, "text": "b\u1ea3n Hi\u1ec7p \u0111\u1ecbnh \u0111\u00ecnh ch\u1ec9 chi\u1ebfn s\u1ef1 \u1edf Vi\u1ec7t Nam, Campuchia v\u00e0 L\u00e0o \u0111\u00e3 \u0111\u01b0\u1ee3c k\u00fd k\u1ebft th\u1eeba nh\u1eadn t\u00f4n tr\u1ecdng \u0111\u1ed9c l\u1eadp, ch\u1ee7 quy\u1ec1n, c\u1ee7a n\u01b0\u1edbc Vi\u1ec7t Nam, L\u00e0o v\u00e0 Campuchia"}], "id": "uit_01__05272_1_1"}, {"question": "Ch\u1ee9c v\u1ee5 m\u00e0 Ph\u1ea1m V\u0103n \u0110\u1ed3ng \u0111\u1ea3m nhi\u1ec7m t\u1ea1i H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng?", "answers": [{"answer_start": 33, "text": "Tr\u01b0\u1edfng ph\u00e1i \u0111o\u00e0n Ch\u00ednh ph\u1ee7"}], "id": "uit_01__05272_1_2"}, {"question": "H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng c\u00f3 t\u00ednh ch\u1ea5t nh\u01b0 th\u1ebf n\u00e0o?", "answers": [{"answer_start": 262, "text": "r\u1ea5t c\u0103ng th\u1eb3ng v\u00e0 ph\u1ee9c t\u1ea1p"}], "id": "uit_01__05272_1_3"}]}]}]} """ EXPECTED = ['Tên gọi nào được Phạm Văn Đồng sử 
dụng khi làm Phó chủ nhiệm cơ quan Biện sự xứ tại Quế Lâm?', 'Phạm Văn Đồng giữ chức vụ gì trong bộ máy Nhà nước Cộng hòa Xã hội chủ nghĩa Việt Nam?', 'Giai đoạn năm 1955-1976, Phạm Văn Đồng nắm giữ chức vụ gì?', 'Sự kiện quan trọng nào đã diễn ra vào ngày 20/7/1954?', 'Chức vụ mà Phạm Văn Đồng đảm nhiệm tại Hội nghị Genève về Đông Dương?', 'Hội nghị Genève về Đông Dương có tính chất như thế nào?'] def test_read_file(): results = selftrain_vi_quad.parse_quad(SAMPLE_TEXT) assert results == EXPECTED ================================================ FILE: stanza/tests/constituency/test_text_processing.py ================================================ """ Run through the various text processing methods for using the parser on text files / directories Uses a simple tree where the parser should always get it right, but things could potentially go wrong """ import glob import os import pytest from stanza import Pipeline from stanza.models.constituency import text_processing from stanza.models.constituency import tree_reader from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def pipeline(): return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True) def test_read_tokenized_file(tmp_path): filename = str(tmp_path / "test_input.txt") with open(filename, "w") as fout: # test that the underscore token comes back with spaces fout.write("This is a_small test\nLine two\n") text, ids = text_processing.read_tokenized_file(filename) assert text == [['This', 'is', 'a small', 'test'], ['Line', 'two']] assert ids == [None, None] def test_parse_tokenized_sentences(pipeline): con_processor = pipeline.processors["constituency"] model = con_processor._model args = model.args sentences = [["This", "is", "a", "test"]] trees = text_processing.parse_tokenized_sentences(args, model, [pipeline], sentences) predictions = [x.predictions for x in trees] assert 
len(predictions) == 1 scored_trees = predictions[0] assert len(scored_trees) == 1 result = "{}".format(scored_trees[0].tree) expected = "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))" assert result == expected def test_parse_text(tmp_path, pipeline): con_processor = pipeline.processors["constituency"] model = con_processor._model args = model.args raw_file = str(tmp_path / "test_input.txt") with open(raw_file, "w") as fout: fout.write("This is a test\nThis is another test\n") output_file = str(tmp_path / "test_output.txt") text_processing.parse_text(args, model, [pipeline], tokenized_file=raw_file, predict_file=output_file) trees = tree_reader.read_treebank(output_file) trees = ["{}".format(x) for x in trees] expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] assert trees == expected_trees def test_parse_dir(tmp_path, pipeline): con_processor = pipeline.processors["constituency"] model = con_processor._model args = model.args raw_dir = str(tmp_path / "input") os.makedirs(raw_dir) raw_f1 = str(tmp_path / "input" / "f1.txt") raw_f2 = str(tmp_path / "input" / "f2.txt") output_dir = str(tmp_path / "output") with open(raw_f1, "w") as fout: fout.write("This is a test") with open(raw_f2, "w") as fout: fout.write("This is another test") text_processing.parse_dir(args, model, [pipeline], raw_dir, output_dir) output_files = sorted(glob.glob(os.path.join(output_dir, "*"))) expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] for output_file, expected_tree in zip(output_files, expected_trees): trees = tree_reader.read_treebank(output_file) assert len(trees) == 1 assert "{}".format(trees[0]) == expected_tree def test_parse_text(tmp_path, pipeline): con_processor = pipeline.processors["constituency"] model = con_processor._model args = dict(model.args) model_path 
= con_processor._config['model_path'] raw_file = str(tmp_path / "test_input.txt") with open(raw_file, "w") as fout: fout.write("This is a test\nThis is another test\n") output_file = str(tmp_path / "test_output.txt") args['tokenized_file'] = raw_file args['predict_file'] = output_file text_processing.load_model_parse_text(args, model_path, [pipeline]) trees = tree_reader.read_treebank(output_file) trees = ["{}".format(x) for x in trees] expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] assert trees == expected_trees ================================================ FILE: stanza/tests/constituency/test_top_down_oracle.py ================================================ import pytest from stanza.models.constituency.base_model import SimpleModel from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, TransitionScheme from stanza.models.constituency.top_down_oracle import * from stanza.models.constituency.transition_sequence import build_sequence from stanza.models.constituency.tree_reader import read_trees from stanza.tests.constituency.test_transition_sequence import reconstruct_tree pytestmark = [pytest.mark.pipeline, pytest.mark.travis] OPEN_SHIFT_EXAMPLE_TREE = """ ( (S (NP (NNP Jennifer) (NNP Sh\'reyan)) (VP (VBZ has) (NP (RB nice) (NNS antennae))))) """ OPEN_SHIFT_PROBLEM_TREE = """ (ROOT (S (NP (NP (NP (DT The) (`` ``) (JJ Thin) (NNP Man) ('' '') (NN series)) (PP (IN of) (NP (NNS movies)))) (, ,) (CONJP (RB as) (RB well) (IN as)) (NP (JJ many) (NNS others)) (, ,)) (VP (VBD based) (NP (PRP$ their) (JJ entire) (JJ comedic) (NN appeal)) (PP (IN on) (NP (NP (DT the) (NN star) (NNS detectives) (POS ')) (JJ witty) (NNS quips) (CC and) (NNS puns))) (SBAR (IN as) (S (NP (NP (JJ other) (NNS characters)) (PP (IN in) (NP (DT the) (NNS movies)))) (VP (VBD were) (VP (VBN murdered)))))) (. 
.))) """ ROOT_LABELS = ["ROOT"] def get_single_repair(gold_sequence, wrong_transition, repair_fn, idx, *args, **kwargs): return repair_fn(gold_sequence[idx], wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None, *args, **kwargs) def build_state(model, tree, num_transitions): transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) states = model.initial_state_from_gold_trees([tree], [transitions]) for idx, t in enumerate(transitions[:num_transitions]): assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence) states = model.bulk_apply(states, [t]) state = states[0] return state def test_fix_open_shift(): trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] EXPECTED_FIX_EARLY = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] EXPECTED_FIX_LATE = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == EXPECTED_ORIG new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2) assert new_transitions == EXPECTED_FIX_EARLY new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 8) assert new_transitions == EXPECTED_FIX_LATE def test_fix_open_shift_observed_error(): """ Ran into an error on this tree, need to fix it 
The problem is the multiple Open in a row all need to be removed when a Shift happens """ trees = read_trees(OPEN_SHIFT_PROBLEM_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2) assert new_transitions is None new_transitions = get_single_repair(transitions, Shift(), fix_multiple_open_shift, 2) # Can break the expected transitions down like this: # [OpenConstituent(('ROOT',)), OpenConstituent(('S',)), # all gone: OpenConstituent(('NP',)), OpenConstituent(('NP',)), OpenConstituent(('NP',)), # Shift, Shift, Shift, Shift, Shift, Shift, # gone: CloseConstituent, # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), Shift, CloseConstituent, CloseConstituent, # gone: CloseConstituent, # Shift, OpenConstituent(('CONJP',)), Shift, Shift, Shift, CloseConstituent, OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, Shift, # gone: CloseConstituent, # and then the rest: # OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), # Shift, Shift, Shift, Shift, CloseConstituent, # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), # OpenConstituent(('NP',)), Shift, Shift, Shift, Shift, # CloseConstituent, Shift, Shift, Shift, Shift, CloseConstituent, # CloseConstituent, OpenConstituent(('SBAR',)), Shift, # OpenConstituent(('S',)), OpenConstituent(('NP',)), # OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), # Shift, Shift, CloseConstituent, CloseConstituent, # CloseConstituent, OpenConstituent(('VP',)), Shift, # OpenConstituent(('VP',)), Shift, CloseConstituent, # CloseConstituent, CloseConstituent, CloseConstituent, # CloseConstituent, Shift, CloseConstituent, CloseConstituent] expected_transitions = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), Shift(), Shift(), Shift(), Shift(), OpenConstituent('PP'), Shift(), 
OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), Shift(), OpenConstituent('CONJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('SBAR'), Shift(), OpenConstituent('S'), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('VP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] assert new_transitions == expected_transitions def test_open_open_ambiguous_unary_fix(): trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == EXPECTED_ORIG new_transitions = get_single_repair(transitions, OpenConstituent('VP'), 
fix_open_open_ambiguous_unary, 2) assert new_transitions == EXPECTED_FIX def test_open_open_ambiguous_later_fix(): trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == EXPECTED_ORIG new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_later, 2) assert new_transitions == EXPECTED_FIX CLOSE_SHIFT_EXAMPLE_TREE = """ ( (NP (DT a) (ADJP (NN stock) (HYPH -) (VBG picking)) (NN tool))) """ # not intended to be a correct tree CLOSE_SHIFT_DEEP_EXAMPLE_TREE = """ ( (NP (DT a) (VP (ADJP (NN stock) (HYPH -) (VBG picking))) (NN tool))) """ # not intended to be a correct tree CLOSE_SHIFT_OPEN_EXAMPLE_TREE = """ ( (NP (DT a) (ADJP (NN stock) (HYPH -) (VBG picking)) (NP (NN tool)))) """ CLOSE_SHIFT_AMBIGUOUS_TREE = """ ( (NP (DT a) (ADJP (NN stock) (HYPH -) (VBG picking)) (NN tool) (NN foo))) """ def test_fix_close_shift_ambiguous_immediate(): """ Test the result when a close/shift error occurs and we want to close the new, incorrect constituent immediately """ trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_later, 7) 
expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original assert new_sequence == expected_update def test_fix_close_shift_ambiguous_later(): # test that the one with two shifts, which is ambiguous, gets rejected trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_immediate, 7) expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original assert new_sequence == expected_update def test_oracle_with_optional_level(): tree = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)[0] gold_sequence = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) assert transitions == gold_sequence oracle = 
TopDownOracle(ROOT_LABELS, 1, "", "") model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY, root_labels=ROOT_LABELS) state = build_state(model, tree, 7) fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8], model=model, state=state) assert fix is RepairType.OTHER_CLOSE_SHIFT assert new_sequence is None oracle = TopDownOracle(ROOT_LABELS, 1, "CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR", "") fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8], model=model, state=state) assert fix is RepairType.CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR assert new_sequence == expected_update def test_fix_close_shift(): """ Test a tree of the kind we expect the close/shift to be able to get right """ trees = read_trees(CLOSE_SHIFT_EXAMPLE_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7) expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original assert new_sequence == expected_update # test that the one with two shifts, which is ambiguous, gets rejected trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7) assert new_sequence is None def test_fix_close_shift_deeper_tree(): """ Test a tree of the kind we expect the close/shift to be able to get right """ trees = read_trees(CLOSE_SHIFT_DEEP_EXAMPLE_TREE) assert len(trees) == 1 tree 
= trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) for count_opens in [True, False]: new_sequence = get_single_repair(transitions, transitions[10], fix_close_shift, 8, count_opens=count_opens) expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original assert new_sequence == expected_update def test_fix_close_shift_open_tree(): """ We would like the close/shift to get this case right as well """ trees = read_trees(CLOSE_SHIFT_OPEN_EXAMPLE_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift, 7, count_opens=False) assert new_sequence is None new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift_with_opens, 7) expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original assert new_sequence == expected_update CLOSE_OPEN_EXAMPLE_TREE = """ ( (VP (VBZ eat) (NP (NN spaghetti)) (PP (IN with) (DT a) (NN fork)))) """ CLOSE_OPEN_DIFFERENT_LABEL_TREE = """ ( (VP (VBZ eat) (NP (NN spaghetti)) (NP (DT a) 
(NN fork)))) """ CLOSE_OPEN_TWO_LABELS_TREE = """ ( (VP (VBZ eat) (NP (NN spaghetti)) (PP (IN with) (DT a) (NN fork)) (PP (IN in) (DT a) (NN restaurant)))) """ def test_fix_close_open(): trees = read_trees(CLOSE_OPEN_EXAMPLE_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) assert isinstance(transitions[5], CloseConstituent) assert transitions[6] == OpenConstituent("PP") new_transitions = get_single_repair(transitions, transitions[6], fix_close_open_correct_open, 5) expected_original = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] expected_update = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original assert new_transitions == expected_update def test_fix_close_open_invalid(): for TREE in (CLOSE_OPEN_DIFFERENT_LABEL_TREE, CLOSE_OPEN_TWO_LABELS_TREE): trees = read_trees(TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) assert isinstance(transitions[5], CloseConstituent) assert isinstance(transitions[6], OpenConstituent) new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5) assert new_transitions is None def test_fix_close_open_ambiguous_immediate(): """ Test that a fix for an ambiguous close/open works as expected """ trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) assert isinstance(transitions[5], CloseConstituent) assert isinstance(transitions[6], OpenConstituent) reconstructed = 
reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN) assert tree == reconstructed new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5, check_close=False) reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN) expected = """ ( (VP (VBZ eat) (NP (NN spaghetti) (PP (IN with) (DT a) (NN fork))) (PP (IN in) (DT a) (NN restaurant)))) """ expected = read_trees(expected)[0] assert reconstructed == expected def test_fix_close_open_ambiguous_later(): """ Test that a fix for an ambiguous close/open works as expected """ trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) assert isinstance(transitions[5], CloseConstituent) assert isinstance(transitions[6], OpenConstituent) reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN) assert tree == reconstructed new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open_ambiguous_later, 5, check_close=False) reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN) expected = """ ( (VP (VBZ eat) (NP (NN spaghetti) (PP (IN with) (DT a) (NN fork)) (PP (IN in) (DT a) (NN restaurant))))) """ expected = read_trees(expected)[0] assert reconstructed == expected SHIFT_CLOSE_EXAMPLES = [ ("((S (NP (DT an) (NML (NNP Oct) (CD 19)) (NN review))))", "((S (NP (DT an) (NML (NNP Oct) (CD 19))) (NN review)))", 8), ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", "((S (NP (` `) (NP (DT The)) (NN Misanthrope) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", 6), ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", "((S (NP (` `) (NP (DT The) (NN Misanthrope))) (` `) (PP (IN at) (NP 
(NNP Goodman) (NNP Theatre)))))", 8), ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", "((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman)) (NNP Theatre)))))", 13), ] def test_shift_close(): for idx, (orig_tree, expected_tree, shift_position) in enumerate(SHIFT_CLOSE_EXAMPLES): trees = read_trees(orig_tree) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) if shift_position is None: print(transitions) continue assert isinstance(transitions[shift_position], Shift) new_transitions = get_single_repair(transitions, CloseConstituent(), fix_shift_close, shift_position) reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN) if expected_tree is None: print(transitions) print(new_transitions) print("{:P}".format(reconstructed)) else: expected_tree = read_trees(expected_tree) assert len(expected_tree) == 1 expected_tree = expected_tree[0] assert reconstructed == expected_tree def test_shift_open_ambiguous_unary(): """ Test what happens if a Shift is turned into an Open in an ambiguous manner """ trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original new_sequence = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_unary, 4) expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] assert new_sequence == 
expected_updated def test_shift_open_ambiguous_later(): """ Test what happens if a Shift is turned into an Open in an ambiguous manner """ trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) assert len(trees) == 1 tree = trees[0] transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] assert transitions == expected_original new_sequence = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_later, 4) expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] assert new_sequence == expected_updated ================================================ FILE: stanza/tests/constituency/test_trainer.py ================================================ from collections import defaultdict import logging import pathlib import tempfile import pytest import torch from torch import nn from torch import optim from stanza import Pipeline from stanza.models import constituency_parser from stanza.models.common import pretrain from stanza.models.common.bert_embedding import load_bert, load_tokenizer from stanza.models.common.foundation_cache import FoundationCache from stanza.models.common.utils import set_random_seed from stanza.models.constituency import lstm_model from stanza.models.constituency.parse_transitions import Transition from stanza.models.constituency import parser_training from stanza.models.constituency import trainer from stanza.models.constituency import tree_reader from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] logger = logging.getLogger('stanza.constituency.trainer') 
logger.setLevel(logging.WARNING) TREEBANK = """ ( (S (VP (VBG Enjoying) (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition))) (. .))) ( (NP (VP (VBG Sitting) (PP (IN in) (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station))) (VP (VBG waiting) (PP (IN for) (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train))))) (. .))) ( (S (NP (PRP I)) (VP (ADVP (RB really)) (VBP hate) (NP (DT the) (NNP @MBTA))))) ( (S (S (VP (VB Seek))) (CC and) (S (NP (PRP ye)) (VP (MD shall) (VP (VB find)))) (. .))) """ def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK): # TODO: build a fake embedding some other way? train_trees = tree_reader.read_trees(treebank) dev_trees = train_trees[-1:] silver_trees = [] args = ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args) args = constituency_parser.parse_args(args) foundation_cache = FoundationCache() # might be None, unless we're testing loading an existing model model_load_name = args['load_name'] model, _, _, _ = parser_training.build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_name) assert isinstance(model.model, lstm_model.LSTMModel) return model class TestTrainer: @pytest.fixture(scope="class") def wordvec_pretrain_file(self): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' @pytest.fixture(scope="class") def tiny_random_xlnet(self, tmp_path_factory): """ Download the tiny-random-xlnet model and make a concrete copy of it The issue here is that the "random" nature of the original makes it difficult or impossible to test that the values in the transformer don't change during certain operations. 
Saving a concrete instantiation of those random numbers makes it so we can test there is no difference when training only a subset of the layers, for example """ xlnet_name = 'hf-internal-testing/tiny-random-xlnet' xlnet_model, xlnet_tokenizer = load_bert(xlnet_name) path = str(tmp_path_factory.mktemp('tiny-random-xlnet')) xlnet_model.save_pretrained(path) xlnet_tokenizer.save_pretrained(path) return path @pytest.fixture(scope="class") def tiny_random_bart(self, tmp_path_factory): """ Download the tiny-random-bart model and make a concrete copy of it Issue is the same as with tiny_random_xlnet """ bart_name = 'hf-internal-testing/tiny-random-bart' bart_model, bart_tokenizer = load_bert(bart_name) path = str(tmp_path_factory.mktemp('tiny-random-bart')) bart_model.save_pretrained(path) bart_tokenizer.save_pretrained(path) return path def test_initial_model(self, wordvec_pretrain_file): """ does nothing, just tests that the construction went okay """ args = ['wordvec_pretrain_file', wordvec_pretrain_file] build_trainer(wordvec_pretrain_file) def test_save_load_model(self, wordvec_pretrain_file): """ Just tests that saving and loading works without crashs. 
Currently no test of the values themselves (checks some fields to make sure they are regenerated correctly) """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: tr = build_trainer(wordvec_pretrain_file) transitions = tr.model.transitions # attempt saving filename = os.path.join(tmpdirname, "parser.pt") tr.save(filename) assert os.path.exists(filename) # load it back in tr2 = tr.load(filename) trans2 = tr2.model.transitions assert(transitions == trans2) assert all(isinstance(x, Transition) for x in trans2) def test_relearn_structure(self, wordvec_pretrain_file): """ Test that starting a trainer with --relearn_structure copies the old model """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: set_random_seed(1000) args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'] tr = build_trainer(wordvec_pretrain_file, *args) # attempt saving filename = os.path.join(tmpdirname, "parser.pt") tr.save(filename) set_random_seed(1001) args = ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--relearn_structure', '--load_name', filename] tr2 = build_trainer(wordvec_pretrain_file, *args) assert torch.allclose(tr.model.delta_embedding.weight, tr2.model.delta_embedding.weight) assert torch.allclose(tr.model.output_layers[0].weight, tr2.model.output_layers[0].weight) # the norms will be the same, as the non-zero values are all the same assert torch.allclose(torch.linalg.norm(tr.model.word_lstm.weight_ih_l0), torch.linalg.norm(tr2.model.word_lstm.weight_ih_l0)) def write_treebanks(self, tmpdirname): train_treebank_file = os.path.join(tmpdirname, "train.mrg") with open(train_treebank_file, 'w', encoding='utf-8') as fout: fout.write(TREEBANK) fout.write(TREEBANK) eval_treebank_file = os.path.join(tmpdirname, "eval.mrg") with open(eval_treebank_file, 'w', encoding='utf-8') as fout: fout.write(TREEBANK) return train_treebank_file, 
eval_treebank_file def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args): # let's not make the model huge... args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10', '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname, '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_start', '0', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'), '--train_file', train_treebank_file, '--eval_file', eval_treebank_file, '--epoch_size', '6', '--train_batch_size', '3', '--shorthand', 'en_test'] args = args + list(additional_args) args = constituency_parser.parse_args(args) # just in case we change the defaults in the future args['wandb'] = None return args def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False, foundation_cache=None): """ Runs a test of the trainer for a few iterations. 
Checks some basic properties of the saved model, but doesn't check for the accuracy of the results """ if extra_args is None: extra_args = [] extra_args += ['--epochs', '%d' % num_epochs] train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname) if use_silver: extra_args += ['--silver_file', str(eval_treebank_file)] args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args) each_name = args['save_each_name'] if not exists_ok: assert not os.path.exists(args['save_name']) retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True, dir=TEST_MODELS_DIR, foundation_cache=foundation_cache, download_method=None) trained_model = parser_training.train(args, None, [retag_pipeline]) # check that hooks are in the model if expected for p in trained_model.model.parameters(): if p.requires_grad: if args['grad_clipping'] is not None: assert len(p._backward_hooks) == 1 else: assert p._backward_hooks is None # check that the model can be loaded back assert os.path.exists(args['save_name']) peft_name = trained_model.model.peft_name tr = trainer.Trainer.load(args['save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name) assert tr.optimizer is not None assert tr.scheduler is not None assert tr.epochs_trained >= 1 for p in tr.model.parameters(): if p.requires_grad: assert p._backward_hooks is None tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name) assert tr.optimizer is not None assert tr.scheduler is not None assert tr.epochs_trained == num_epochs for i in range(1, num_epochs+1): model_name = each_name % i assert os.path.exists(model_name) tr = trainer.Trainer.load(model_name, load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name) assert 
tr.epochs_trained == i assert tr.batches_trained == (4 * i if use_silver else 2 * i) return args, trained_model def test_train(self, wordvec_pretrain_file): """ Test the whole thing for a few iterations on the fake data """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: self.run_train_test(wordvec_pretrain_file, tmpdirname) def test_early_dropout(self, wordvec_pretrain_file): """ Test the whole thing for a few iterations on the fake data """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: args = ['--early_dropout', '3'] _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args) model = model.model dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)] assert len(dropouts) > 0, "Didn't find any dropouts in the model!" for name, module in dropouts: assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout" with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: # test that when turned off, early_dropout doesn't happen args = ['--early_dropout', '-1'] _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args) model = model.model dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)] assert len(dropouts) > 0, "Didn't find any dropouts in the model!" 
if all(module.p == 0.0 for _, module in dropouts): raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1") def test_train_silver(self, wordvec_pretrain_file): """ Test the whole thing for a few iterations on the fake data This tests that it works if you give it a silver file The check for the use of the silver data is that the number of batches trained should go up """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True) def test_train_checkpoint(self, wordvec_pretrain_file): """ Test the whole thing for a few iterations, then restart This tests that the 5th iteration save file is not rewritten and that the iterations continue to 10 TODO: could make it more robust by verifying that only 5 more epochs are trained. Perhaps a "most recent epochs" could be saved in the trainer """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False) save_5 = args['save_each_name'] % 5 save_10 = args['save_each_name'] % 10 assert os.path.exists(save_5) assert not os.path.exists(save_10) save_5_stat = pathlib.Path(save_5).stat() self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True) assert os.path.exists(save_5) assert os.path.exists(save_10) assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None): train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname) args = ['--multistage', '--pattn_num_layers', '1'] if use_lattn: args += ['--lattn_d_proj', '16'] if extra_args: args += extra_args args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args) each_name = os.path.join(args['save_dir'], 'each_%02d.pt') word_input_sizes = defaultdict(list) for i in range(1, 
9): model_name = each_name % i assert os.path.exists(model_name) tr = trainer.Trainer.load(model_name, load_optimizer=True) assert tr.epochs_trained == i word_input_sizes[tr.model.word_input_size].append(i) if use_lattn: # there should be three stages: no attn, pattn, pattn+lattn assert len(word_input_sizes) == 3 word_input_keys = sorted(word_input_sizes.keys()) assert word_input_sizes[word_input_keys[0]] == [1, 2, 3] assert word_input_sizes[word_input_keys[1]] == [4, 5] assert word_input_sizes[word_input_keys[2]] == [6, 7, 8] else: # with no lattn, there are two stages: no attn, pattn assert len(word_input_sizes) == 2 word_input_keys = sorted(word_input_sizes.keys()) assert word_input_sizes[word_input_keys[0]] == [1, 2, 3] assert word_input_sizes[word_input_keys[1]] == [4, 5, 6, 7, 8] def test_multistage_lattn(self, wordvec_pretrain_file): """ Test a multistage training for a few iterations on the fake data This should start with no pattn or lattn, have pattn in the middle, then lattn at the end """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True) def test_multistage_no_lattn(self, wordvec_pretrain_file): """ Test a multistage training for a few iterations on the fake data This should start with no pattn or lattn, have pattn in the middle, then lattn at the end """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False) def test_multistage_optimizer(self, wordvec_pretrain_file): """ Test that the correct optimizers are built for a multistage training process """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: extra_args = ['--optim', 'adamw'] self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args) # check that the optimizers which get rebuilt when loading # the models are adadelta for the first half of the # 
multistage, then adamw each_name = os.path.join(tmpdirname, 'each_%02d.pt') for i in range(1, 3): model_name = each_name % i tr = trainer.Trainer.load(model_name, load_optimizer=True) assert tr.epochs_trained == i assert isinstance(tr.optimizer, optim.Adadelta) # double check that this is actually a valid test assert not isinstance(tr.optimizer, optim.AdamW) for i in range(4, 8): model_name = each_name % i tr = trainer.Trainer.load(model_name, load_optimizer=True) assert tr.epochs_trained == i assert isinstance(tr.optimizer, optim.AdamW) def test_grad_clip_hooks(self, wordvec_pretrain_file): """ Verify that grad clipping is not saved with the model, but is attached at training time """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: args = ['--grad_clipping', '25'] self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) def test_analyze_trees(self, wordvec_pretrain_file): test_str = "(ROOT (S (NP (PRP I)) (VP (VBP wan) (S (VP (TO na) (VP (VB lick) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))))) (ROOT (S (NP (DT This) (NN interface)) (VP (VBZ sucks))))" test_tree = tree_reader.read_trees(test_str) assert len(test_tree) == 2 args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'] tr = build_trainer(wordvec_pretrain_file, *args) results = tr.model.analyze_trees(test_tree) assert len(results) == 2 assert len(results[0].predictions) == 1 assert results[0].predictions[0].tree == test_tree[0] assert results[0].state is not None assert isinstance(results[0].state.score, torch.Tensor) assert results[0].state.score.shape == torch.Size([]) assert len(results[0].constituents) == 9 assert results[0].constituents[-1].value == test_tree[0] # the way the results are built, the next-to-last entry # should be the thing just below the root assert results[0].constituents[-2].value == test_tree[0].children[0] assert len(results[1].predictions) == 1 assert results[1].predictions[0].tree 
== test_tree[1] assert results[1].state is not None assert isinstance(results[1].state.score, torch.Tensor) assert results[1].state.score.shape == torch.Size([]) assert len(results[1].constituents) == 4 assert results[1].constituents[-1].value == test_tree[1] assert results[1].constituents[-2].value == test_tree[1].children[0] def bert_weights_allclose(self, bert_model, parser_model): """ Return True if all bert weights are close, False otherwise """ for name, parameter in bert_model.named_parameters(): other_name = "bert_model." + name other_parameter = parser_model.model.get_parameter(other_name) if not torch.allclose(parameter.cpu(), other_parameter.cpu()): return False return True def frozen_transformer_test(self, wordvec_pretrain_file, transformer_name): with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: foundation_cache = FoundationCache() args = ['--bert_model', transformer_name] args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args, foundation_cache=foundation_cache) bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) assert self.bert_weights_allclose(bert_model, trained_model) checkpoint = torch.load(args['save_name'], lambda storage, loc: storage, weights_only=True) params = checkpoint['params'] # check that the bert model wasn't saved in the model assert all(not x.startswith("bert_model.") for x in params['model'].keys()) # make sure we're looking at the right thing assert any(x.startswith("output_layers.") for x in params['model'].keys()) # check that the cached model is used as expected when loading a bert model trained_model = trainer.Trainer.load(args['save_name'], foundation_cache=foundation_cache) assert trained_model.model.bert_model is bert_model def test_bert_frozen(self, wordvec_pretrain_file): """ Check that the parameters of the bert model don't change when training a basic model """ self.frozen_transformer_test(wordvec_pretrain_file, 
'hf-internal-testing/tiny-bert') def test_xlnet_frozen(self, wordvec_pretrain_file, tiny_random_xlnet): """ Check that the parameters of an xlnet model don't change when training a basic model """ self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_xlnet) def test_bart_frozen(self, wordvec_pretrain_file, tiny_random_bart): """ Check that the parameters of an xlnet model don't change when training a basic model """ self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_bart) def test_bert_finetune_one_epoch(self, wordvec_pretrain_file): """ Check that the parameters the bert model DO change over a single training step """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: transformer_name = 'hf-internal-testing/tiny-bert' args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adadelta'] args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=1, extra_args=args) # check that the weights are different foundation_cache = FoundationCache() bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) assert not self.bert_weights_allclose(bert_model, trained_model) # double check that a new bert is created instead of using the FoundationCache when the bert has been trained model_name = args['save_name'] assert os.path.exists(model_name) no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name) tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache) assert tr.model.bert_model is not bert_model assert not self.bert_weights_allclose(bert_model, tr) assert self.bert_weights_allclose(trained_model.model.bert_model, tr) new_save_name = os.path.join(tmpdirname, "test_resave_bert.pt") assert not os.path.exists(new_save_name) tr.save(new_save_name, save_optimizer=False) tr2 = trainer.Trainer.load(new_save_name, 
args=no_finetune_args, foundation_cache=foundation_cache) # check that the resaved model included its finetuned bert weights assert tr2.model.bert_model is not bert_model # the finetuned bert weights should also be scheduled for saving the next time as well assert not tr2.model.is_unsaved_module("bert_model") def finetune_transformer_test(self, wordvec_pretrain_file, transformer_name): """ Check that the parameters of the transformer DO change when using bert_finetune """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw'] args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) # check that the weights are different foundation_cache = FoundationCache() bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) assert not self.bert_weights_allclose(bert_model, trained_model) # double check that a new bert is created instead of using the FoundationCache when the bert has been trained no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name) trained_model = trainer.Trainer.load(args['save_name'], args=no_finetune_args, foundation_cache=foundation_cache) assert not trained_model.model.args['bert_finetune'] assert not trained_model.model.args['stage1_bert_finetune'] assert trained_model.model.bert_model is not bert_model def test_bert_finetune(self, wordvec_pretrain_file): """ Check that the parameters of a bert model DO change when using bert_finetune """ self.finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert') def test_xlnet_finetune(self, wordvec_pretrain_file, tiny_random_xlnet): """ Check that the parameters of an xlnet model DO change when using bert_finetune """ self.finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet) def test_stage1_bert_finetune(self, 
wordvec_pretrain_file): """ Check that the parameters the bert model DO change when using stage1_bert_finetune, but only for the first couple steps """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: bert_model_name = 'hf-internal-testing/tiny-bert' args = ['--bert_model', bert_model_name, '--stage1_bert_finetune', '--optim', 'adamw'] # need to use num_epochs==6 so that epochs 1 and 2 are saved to be different # a test of 5 or less means that sometimes it will reload the params # at step 2 to get ready for the following iterations with adamw args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args) # check that the weights are different foundation_cache = FoundationCache() bert_model, bert_tokenizer = foundation_cache.load_bert(bert_model_name) assert not self.bert_weights_allclose(bert_model, trained_model) # double check that a new bert is created instead of using the FoundationCache when the bert has been trained no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', bert_model_name, '--optim', 'adamw') num_epochs = trained_model.model.args['epochs'] each_name = os.path.join(tmpdirname, 'each_%02d.pt') for i in range(1, num_epochs+1): model_name = each_name % i assert os.path.exists(model_name) tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache) assert tr.model.bert_model is not bert_model assert not self.bert_weights_allclose(bert_model, tr) if i >= num_epochs // 2: assert self.bert_weights_allclose(trained_model.model.bert_model, tr) # verify that models 1 and 2 are saved to be different model_name_1 = each_name % 1 model_name_2 = each_name % 2 tr_1 = trainer.Trainer.load(model_name_1, args=no_finetune_args, foundation_cache=foundation_cache) tr_2 = trainer.Trainer.load(model_name_2, args=no_finetune_args, foundation_cache=foundation_cache) assert not 
self.bert_weights_allclose(tr_1.model.bert_model, tr_2) def one_layer_finetune_transformer_test(self, wordvec_pretrain_file, transformer_name): """ Check that the parameters the bert model DO change when using bert_finetune """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: args = ['--bert_model', transformer_name, '--bert_finetune', '--bert_finetune_layers', '1', '--optim', 'adamw', '--bert_finetune_layers', '1'] args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) # check that the weights of the last layer are different, # but the weights of the earlier layers and # non-transformer-layers are the same foundation_cache = FoundationCache() bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) assert bert_model.config.num_hidden_layers > 1 layer_name = "layer.%d." % (bert_model.config.num_hidden_layers - 1) for name, parameter in bert_model.named_parameters(): other_name = "bert_model." + name other_parameter = trained_model.model.get_parameter(other_name) if layer_name in name: if 'rel_attn.seg_embed' in name or 'rel_attn.r_s_bias' in name: # not sure why this happens for xlnet, just roll with it continue assert not torch.allclose(parameter.cpu(), other_parameter.cpu()) else: assert torch.allclose(parameter.cpu(), other_parameter.cpu()) def test_bert_finetune_one_layer(self, wordvec_pretrain_file): self.one_layer_finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert') def test_xlnet_finetune_one_layer(self, wordvec_pretrain_file, tiny_random_xlnet): self.one_layer_finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet) def test_peft_finetune(self, tmp_path, wordvec_pretrain_file): transformer_name = 'hf-internal-testing/tiny-bert' args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw', '--use_peft'] args, trained_model = self.run_train_test(wordvec_pretrain_file, str(tmp_path), extra_args=args) def 
test_peft_twostage_finetune(self, wordvec_pretrain_file): with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: num_epochs = 6 transformer_name = 'hf-internal-testing/tiny-bert' args = ['--bert_model', transformer_name, '--stage1_bert_finetune', '--optim', 'adamw', '--use_peft'] args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=num_epochs, extra_args=args) for epoch in range(num_epochs): filename_prev = args['save_each_name'] % epoch filename_next = args['save_each_name'] % (epoch+1) trainer_prev = trainer.Trainer.load(filename_prev, args=args, load_optimizer=False) trainer_next = trainer.Trainer.load(filename_next, args=args, load_optimizer=False) lora_names = [name for name, _ in trainer_prev.model.bert_model.named_parameters() if name.find("lora") >= 0] if epoch < 2: assert not any(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(), trainer_next.model.bert_model.get_parameter(name).cpu()) for name in lora_names) elif epoch > 2: assert all(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(), trainer_next.model.bert_model.get_parameter(name).cpu()) for name in lora_names) ================================================ FILE: stanza/tests/constituency/test_transformer_tree_stack.py ================================================ import pytest import torch from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_initial_state(): """ Test that the initial state has the expected shapes """ ts = TransformerTreeStack(3, 5, 0.0) initial = ts.initial_state() assert len(initial) == 1 assert initial.value.output.shape == torch.Size([5]) assert initial.value.key_stack.shape == torch.Size([1, 5]) assert initial.value.value_stack.shape == torch.Size([1, 5]) def test_output(): """ Test that you can get an expected output shape from the TTS """ ts = TransformerTreeStack(3, 5, 0.0) initial = 
ts.initial_state() out = ts.output(initial) assert out.shape == torch.Size([5]) assert torch.allclose(initial.value.output, out) def test_push_state_single(): """ Test that stacks are being updated correctly when using a single stack Values of the attention are not verified, though """ ts = TransformerTreeStack(3, 5, 0.0) initial = ts.initial_state() rand_input = torch.randn(1, 3) stacks = ts.push_states([initial], ["A"], rand_input) stacks = ts.push_states(stacks, ["B"], rand_input) assert len(stacks) == 1 assert len(stacks[0]) == 3 assert stacks[0].value.value == "B" assert stacks[0].pop().value.value == "A" assert stacks[0].pop().pop().value.value is None def test_push_state_same_length(): """ Test that stacks are being updated correctly when using 3 stacks of the same length Values of the attention are not verified, though """ ts = TransformerTreeStack(3, 5, 0.0) initial = ts.initial_state() rand_input = torch.randn(3, 3) stacks = ts.push_states([initial, initial, initial], ["A", "A", "A"], rand_input) stacks = ts.push_states(stacks, ["B", "B", "B"], rand_input) stacks = ts.push_states(stacks, ["C", "C", "C"], rand_input) assert len(stacks) == 3 for s in stacks: assert len(s) == 4 assert s.value.key_stack.shape == torch.Size([4, 5]) assert s.value.value_stack.shape == torch.Size([4, 5]) assert s.value.value == "C" assert s.pop().value.value == "B" assert s.pop().pop().value.value == "A" assert s.pop().pop().pop().value.value is None def test_push_state_different_length(): """ Test what happens if stacks of different lengths are passed in """ ts = TransformerTreeStack(3, 5, 0.0) initial = ts.initial_state() rand_input = torch.randn(2, 3) one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0] stacks = [one_step, initial] stacks = ts.push_states(stacks, ["B", "C"], rand_input) assert len(stacks) == 2 assert len(stacks[0]) == 3 assert len(stacks[1]) == 2 assert stacks[0].pop().value.value == 'A' assert stacks[0].value.value == 'B' assert 
stacks[1].value.value == 'C' assert stacks[0].value.key_stack.shape == torch.Size([3, 5]) assert stacks[1].value.key_stack.shape == torch.Size([2, 5]) def test_mask(): """ Test that a mask prevents the softmax from picking up unwanted values """ ts = TransformerTreeStack(3, 5, 0.0) random_v = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]]) double_v = random_v * 2 value = torch.cat([random_v, double_v], axis=1) random_k = torch.randn(1, 1, 5) key = torch.cat([random_k, random_k], axis=1) query = torch.randn(1, 5) output = ts.attention(key, query, value) # when the two keys are equal, we expect the attention to be 50/50 expected_output = (random_v + double_v) / 2 assert torch.allclose(output, expected_output) # If the first entry is masked out, the second one should be the # only one represented mask = torch.zeros(1, 2, dtype=torch.bool) mask[0][0] = True output = ts.attention(key, query, value, mask) assert torch.allclose(output, double_v) # If the second entry is masked out, the first one should be the # only one represented mask = torch.zeros(1, 2, dtype=torch.bool) mask[0][1] = True output = ts.attention(key, query, value, mask) assert torch.allclose(output, random_v) def test_position(): """ Test that nothing goes horribly wrong when position encodings are used Does not actually test the results of the encodings """ ts = TransformerTreeStack(4, 5, 0.0, use_position=True) initial = ts.initial_state() assert len(initial) == 1 assert initial.value.output.shape == torch.Size([5]) assert initial.value.key_stack.shape == torch.Size([1, 5]) assert initial.value.value_stack.shape == torch.Size([1, 5]) rand_input = torch.randn(2, 4) one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0] stacks = [one_step, initial] stacks = ts.push_states(stacks, ["B", "C"], rand_input) def test_length_limit(): """ Test that the length limit drops nodes as the length limit is exceeded """ ts = TransformerTreeStack(4, 5, 0.0, length_limit = 2) initial = ts.initial_state() assert 
len(initial) == 1 assert initial.value.output.shape == torch.Size([5]) assert initial.value.key_stack.shape == torch.Size([1, 5]) assert initial.value.value_stack.shape == torch.Size([1, 5]) data = torch.tensor([[0.1, 0.2, 0.3, 0.4]]) stacks = ts.push_states([initial], ["A"], data) stacks = ts.push_states(stacks, ["B"], data) assert len(stacks) == 1 assert len(stacks[0]) == 3 assert stacks[0].value.key_stack.shape[0] == 3 assert stacks[0].value.value_stack.shape[0] == 3 stacks = ts.push_states(stacks, ["C"], data) assert len(stacks) == 1 assert len(stacks[0]) == 4 assert stacks[0].value.key_stack.shape[0] == 3 assert stacks[0].value.value_stack.shape[0] == 3 stacks = ts.push_states(stacks, ["D"], data) assert len(stacks) == 1 assert len(stacks[0]) == 5 assert stacks[0].value.key_stack.shape[0] == 3 assert stacks[0].value.value_stack.shape[0] == 3 def test_two_heads(): """ Test that the length limit drops nodes as the length limit is exceeded """ ts = TransformerTreeStack(4, 6, 0.0, num_heads = 2) initial = ts.initial_state() assert len(initial) == 1 assert initial.value.output.shape == torch.Size([6]) assert initial.value.key_stack.shape == torch.Size([1, 6]) assert initial.value.value_stack.shape == torch.Size([1, 6]) rand_input = torch.randn(2, 4) one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0] stacks = [one_step, initial] stacks = ts.push_states(stacks, ["B", "C"], rand_input) assert len(stacks) == 2 assert len(stacks[0]) == 3 assert len(stacks[1]) == 2 assert stacks[0].pop().value.value == 'A' assert stacks[0].value.value == 'B' assert stacks[1].value.value == 'C' assert stacks[0].value.key_stack.shape == torch.Size([3, 6]) assert stacks[1].value.key_stack.shape == torch.Size([2, 6]) ================================================ FILE: stanza/tests/constituency/test_transition_sequence.py ================================================ import pytest from stanza.models.constituency import parse_transitions from stanza.models.constituency 
import transition_sequence from stanza.models.constituency import tree_reader from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT from stanza.models.constituency.parse_transitions import * from stanza.tests import * from stanza.tests.constituency.test_parse_tree import CHINESE_LONG_LIST_TREE pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def reconstruct_tree(tree, sequence, transition_scheme=TransitionScheme.IN_ORDER, unary_limit=UNARY_LIMIT, reverse=False): """ Starting from a tree and a list of transitions, build the tree caused by the transitions """ model = SimpleModel(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse) states = model.initial_state_from_gold_trees([tree]) assert(len(states)) == 1 assert states[0].num_transitions == 0 # TODO: could fold this into parse_sentences (similar to verify_transitions in trainer.py) for idx, t in enumerate(sequence): assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence) states = model.bulk_apply(states, [t]) result_tree = states[0].constituents.value if reverse: result_tree = result_tree.reverse() return result_tree def check_reproduce_tree(transition_scheme): text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" trees = tree_reader.read_trees(text) model = SimpleModel(transition_scheme) transitions = transition_sequence.build_sequence(trees[0], transition_scheme) states = model.initial_state_from_gold_trees(trees) assert(len(states)) == 1 state = states[0] assert state.num_transitions == 0 for t in transitions: assert t.is_legal(state, model) state = t.apply(state, model) # one item for the final tree # one item for the sentinel at the end assert len(state.constituents) == 2 # the transition sequence should put all of the words # from the buffer onto the tree # one spot left for the sentinel value assert len(state.word_queue) == 8 assert state.sentence_length == 6 assert state.word_position == state.sentence_length assert len(state.transitions) == len(transitions) + 1 result_tree = state.constituents.value assert result_tree == trees[0] def test_top_down_unary(): check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY) def test_top_down_no_unary(): check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN) def test_in_order(): check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER) def test_in_order_compound(): check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND) def test_in_order_unary(): check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_UNARY) def test_all_transitions(): text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) model = SimpleModel() transitions = transition_sequence.build_treebank(trees) expected = [Shift(), CloseConstituent(), CompoundUnary("ROOT"), CompoundUnary("SQ"), CompoundUnary("WHNP"), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("SBARQ"), OpenConstituent("VP")] assert transition_sequence.all_transitions(transitions) == expected def test_all_transitions_no_unary(): text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" trees = tree_reader.read_trees(text) model = SimpleModel() transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN) expected = [Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("ROOT"), OpenConstituent("SBARQ"), OpenConstituent("SQ"), OpenConstituent("VP"), OpenConstituent("WHNP")] assert transition_sequence.all_transitions(transitions) == expected def test_top_down_compound_unary(): text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. 
.)))" trees = tree_reader.read_trees(text) assert len(trees) == 1 model = SimpleModel() transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND) states = model.initial_state_from_gold_trees(trees) assert len(states) == 1 state = states[0] for t in transitions: assert t.is_legal(state, model) state = t.apply(state, model) result = model.get_top_constituent(state.constituents) assert trees[0] == result def test_chinese_tree(): trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE) transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN) redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN) assert redone == trees[0] transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER) with pytest.raises(AssertionError): redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER) redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) assert redone == trees[0] def test_chinese_tree_reversed(): """ test that the reversed transitions also work """ trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE) transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN, reverse=True) redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN, reverse=True) assert redone == trees[0] with pytest.raises(AssertionError): # turn off reverse - it should fail to rebuild the tree redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN) assert redone == trees[0] transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER, reverse=True) with pytest.raises(AssertionError): redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, 
reverse=True) redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6, reverse=True) assert redone == trees[0] with pytest.raises(AssertionError): # turn off reverse - it should fail to rebuild the tree redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) assert redone == trees[0] ================================================ FILE: stanza/tests/constituency/test_tree_reader.py ================================================ import pytest from stanza.models.constituency import tree_reader from stanza.models.constituency.tree_reader import MixedTreeError, UnclosedTreeError, UnlabeledTreeError from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_simple(): """ Tests reading two simple trees from the same text """ text = "(VB Unban) (NNP Opal)" trees = tree_reader.read_trees(text) assert len(trees) == 2 assert trees[0].is_preterminal() assert trees[0].label == 'VB' assert trees[0].children[0].label == 'Unban' assert trees[1].is_preterminal() assert trees[1].label == 'NNP' assert trees[1].children[0].label == 'Opal' def test_newlines(): """ The same test should work if there are newlines """ text = "(VB Unban)\n\n(NNP Opal)" trees = tree_reader.read_trees(text) assert len(trees) == 2 def test_parens(): """ Parens should be escaped in the tree files and escaped when written """ text = "(-LRB- -LRB-) (-RRB- -RRB-)" trees = tree_reader.read_trees(text) assert len(trees) == 2 assert trees[0].label == '-LRB-' assert trees[0].children[0].label == '(' assert "{}".format(trees[0]) == '(-LRB- -LRB-)' assert trees[1].label == '-RRB-' assert trees[1].children[0].label == ')' assert "{}".format(trees[1]) == '(-RRB- -RRB-)' def test_complicated(): """ A more complicated tree that should successfully read """ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" trees = tree_reader.read_trees(text) assert len(trees) == 1 tree = trees[0] assert not tree.is_leaf() assert not tree.is_preterminal() assert tree.label == 'ROOT' assert len(tree.children) == 1 assert tree.children[0].label == 'SBARQ' assert len(tree.children[0].children) == 3 assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.'] # etc etc def test_one_word(): """ Check that one node trees are correctly read probably not super relevant for the parsing use case """ text="(FOO) (BAR)" trees = tree_reader.read_trees(text) assert len(trees) == 2 assert trees[0].is_leaf() assert trees[0].label == 'FOO' assert trees[1].is_leaf() assert trees[1].label == 'BAR' def test_missing_close_parens(): """ Test the unclosed error condition """ text = "(Foo) \n (Bar \n zzz" try: trees = tree_reader.read_trees(text) raise AssertionError("Expected an exception") except UnclosedTreeError as e: assert e.line_num == 1 def test_mixed_tree(): """ Test the mixed error condition """ text = "(Foo) \n (Bar) \n (Unban (Mox) Opal)" try: trees = tree_reader.read_trees(text) raise AssertionError("Expected an exception") except MixedTreeError as e: assert e.line_num == 2 trees = tree_reader.read_trees(text, broken_ok=True) assert len(trees) == 3 def test_unlabeled_tree(): """ Test the unlabeled error condition """ text = "(ROOT ((Foo) (Bar)))" try: trees = tree_reader.read_trees(text) raise AssertionError("Expected an exception") except UnlabeledTreeError as e: assert e.line_num == 0 trees = tree_reader.read_trees(text, broken_ok=True) assert len(trees) == 1 ================================================ FILE: stanza/tests/constituency/test_tree_stack.py ================================================ import pytest from stanza.models.constituency.tree_stack import TreeStack from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_simple(): stack = TreeStack(value=5, parent=None, length=1) stack = stack.push(3) stack = stack.push(1) 
expected_values = [1, 3, 5] for value in expected_values: assert stack.value == value stack = stack.pop() assert stack is None def test_iter(): stack = TreeStack(value=5, parent=None, length=1) stack = stack.push(3) stack = stack.push(1) stack_list = list(stack) assert list(stack) == [1, 3, 5] def test_str(): stack = TreeStack(value=5, parent=None, length=1) stack = stack.push(3) stack = stack.push(1) assert str(stack) == "TreeStack(1, 3, 5)" def test_len(): stack = TreeStack(value=5, parent=None, length=1) assert len(stack) == 1 stack = stack.push(3) stack = stack.push(1) assert len(stack) == 3 def test_long_len(): """ Original stack had a bug where this took exponential time... """ stack = TreeStack(value=0, parent=None, length=1) for i in range(1, 40): stack = stack.push(i) assert len(stack) == 40 ================================================ FILE: stanza/tests/constituency/test_utils.py ================================================ import pytest from stanza import Pipeline from stanza.models.constituency import tree_reader from stanza.models.constituency import utils from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def pipeline(): return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos", tokenize_pretokenized=True) def test_xpos_retag(pipeline): """ Test using the English tagger that trees will be correctly retagged by read_trees using xpos """ text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))" expected = "((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))" trees = tree_reader.read_trees(text) new_trees = utils.retag_trees(trees, [pipeline], xpos=True) assert new_trees == tree_reader.read_trees(expected) def test_upos_retag(pipeline): """ Test using the English tagger that trees will be correctly retagged by read_trees using upos """ 
text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))" expected = "((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))" trees = tree_reader.read_trees(text) new_trees = utils.retag_trees(trees, [pipeline], xpos=False) assert new_trees == tree_reader.read_trees(expected) def test_replace_tags(): """ Test the underlying replace_tags method Also tests that the method throws exceptions when it is supposed to """ text = "((S (VP (X Find)) (NP (X Mox) (X Opal))))" expected = "((S (VP (A Find)) (NP (B Mox) (C Opal))))" trees = tree_reader.read_trees(text) new_tags = ["A", "B", "C"] new_tree = trees[0].replace_tags(new_tags) assert new_tree == tree_reader.read_trees(expected)[0] with pytest.raises(ValueError): new_tags = ["A", "B"] new_tree = trees[0].replace_tags(new_tags) with pytest.raises(ValueError): new_tags = ["A", "B", "C", "D"] new_tree = trees[0].replace_tags(new_tags) ================================================ FILE: stanza/tests/constituency/test_vietnamese.py ================================================ """ A few tests for Vietnamese parsing, which has some difficulties related to spaces in words Technically some other languages can have this, too, like that one French token """ import os import tempfile import pytest from stanza.models.common import pretrain from stanza.models.constituency import tree_reader from stanza.tests.constituency.test_trainer import build_trainer pytestmark = [pytest.mark.pipeline, pytest.mark.travis] VI_TREEBANK = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài Loan) (" ") (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))' VI_TREEBANK_UNDERSCORE = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài_Loan) (" ") (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. 
.)))' VI_TREEBANK_SIMPLE = '(ROOT (S (NP (" ") (N Đảo) (Np Đài Loan) (" ") (PP (E ở) (NP (N đồng bằng) (NP (N sông) (Np Cửu Long))))) (. .)))' VI_TREEBANK_PAREN = '(ROOT (S-TTL (NP (PUNCT -LRB-) (N-H Đảo) (Np Đài Loan) (PUNCT -RRB-) (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))' VI_TREEBANK_VLSP = '\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n' VI_TREEBANK_VLSP_50 = '\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n' VI_TREEBANK_VLSP_100 = '\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n' EXPECTED_LABELED_BRACKETS = '(_ROOT (_S (_NP (_" " )_" (_N Đảo )_N (_Np Đài_Loan )_Np (_" " )_" (_PP (_E ở )_E (_NP (_N đồng_bằng )_N (_NP (_N sông )_N (_Np Cửu_Long )_Np )_NP )_NP )_PP )_NP (_. . )_. )_S )_ROOT' def test_read_vi_tree(): """ Test that an individual tree with spaces in the leaves is being processed as we expect """ text = VI_TREEBANK.split("\n")[0] trees = tree_reader.read_trees(text) assert len(trees) == 1 assert str(trees[0]) == text # this is the first NP # the third node of that NP, eg (Np Đài Loan) node = trees[0].children[0].children[0].children[2] assert node.is_preterminal() assert node.children[0].label == "Đài Loan" VI_EMBEDDING = """ 4 4 Đảo 0.11 0.21 0.31 0.41 Đài Loan 0.12 0.22 0.32 0.42 đồng bằng 0.13 0.23 0.33 0.43 sông 0.14 0.24 0.34 0.44 """.strip() def test_vi_embedding(): """ Test that a VI embedding's words are correctly found when processing trees """ text = VI_TREEBANK.split("\n")[0] trees = tree_reader.read_trees(text) words = set(trees[0].leaf_labels()) with tempfile.TemporaryDirectory() as tempdir: emb_filename = os.path.join(tempdir, "emb.txt") pt_filename = os.path.join(tempdir, "emb.pt") with open(emb_filename, "w", encoding="utf-8") as fout: 
fout.write(VI_EMBEDDING) pt = pretrain.Pretrain(filename=pt_filename, vec_filename=emb_filename, save_to_file=True) pt.load() trainer = build_trainer(pt_filename) model = trainer.model assert model.num_words_known(words) == 4 def test_space_formatting(): """ By default, spaces are left as spaces, but there is a format option to change spaces """ text = VI_TREEBANK.split("\n")[0] trees = tree_reader.read_trees(text) assert len(trees) == 1 assert str(trees[0]) == text assert "{}".format(trees[0]) == VI_TREEBANK assert "{:_O}".format(trees[0]) == VI_TREEBANK_UNDERSCORE def test_vlsp_formatting(): text = VI_TREEBANK_PAREN.split("\n")[0] trees = tree_reader.read_trees(text) assert len(trees) == 1 assert str(trees[0]) == text assert "{:_V}".format(trees[0]) == VI_TREEBANK_VLSP trees[0].tree_id = 50 assert "{:_Vi}".format(trees[0]) == VI_TREEBANK_VLSP_50 trees[0].tree_id = 100 assert "{:_Vi}".format(trees[0]) == VI_TREEBANK_VLSP_100 empty = tree_reader.read_trees("(ROOT)")[0] with pytest.raises(ValueError): "{:V}".format(empty) branches = tree_reader.read_trees("(ROOT (1) (2) (3))")[0] with pytest.raises(ValueError): "{:V}".format(branches) def test_language_formatting(): """ Test turning the parse tree into a 'language' for GPT """ text = VI_TREEBANK.split("\n")[0] trees = tree_reader.read_trees(text) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) == 1 assert str(trees[0]) == VI_TREEBANK_SIMPLE text = "{:L}".format(trees[0]) assert text == EXPECTED_LABELED_BRACKETS ================================================ FILE: stanza/tests/datasets/__init__.py ================================================ ================================================ FILE: stanza/tests/datasets/coref/__init__.py ================================================ ================================================ FILE: stanza/tests/datasets/coref/test_hebrew_iahlt.py ================================================ import pytest from stanza import Pipeline from 
stanza.tests import TEST_MODELS_DIR from stanza.utils.datasets.coref.convert_hebrew_iahlt import extract_doc pytestmark = [pytest.mark.travis, pytest.mark.pipeline] @pytest.fixture(scope="module") def tokenizer(): pipe = Pipeline(lang="he", processors="tokenize", dir=TEST_MODELS_DIR, download_method=None) return pipe TEXT = """ מבולבלים​? גם אנחנו​: ל​מסעדנים ו​ה​מלצרים יש עוד סימני שאלה על ה​טיפים​ ה​פער בין פסיקת בית ה​דין ל​עבודה לבין פסיקה קודמת של בג"ץ​, משאיר את ה​ענף ב​חוסר וודאות​, ו​ה -​1 ב​ינואר כבר מעבר ל​פינה . "​מ​בחינת​י , הייתי מוסיף ל​תפריט תוספת שירות של 17​% "​, אמר בעלים של מסעדה ב​שדרות​ ב​רשות ה​מיסים מסתפקים ב​מסר עמום באשר ל​כוונותי​הם לאור פסק דין ה​טיפים ש​צפוי להיכנס ל​תוקפ​ו ב​-​1 ב​ינואר . על פי פרשנות​ם ה​מקצועית , הבהירו​, יש מקום לחייב את כספי ה​טיפים ב​מע"מ , "​עם זאת​, ה​רשות עדין בוחנת את ה​סוגיה ו​טרם התקבלה החלטה אופרטיבית ב​עניין "​. ו​איך אמורים ה​מסעדנים להיערך בינתיים ל​יישום ה​פסיקה ו​ל​מחזור ה​שנה ה​באה ? ב​יום חמישי יפגשו אנשי ארגון '​מסעדנים חזקים ביחד​' עם מנהל רשות ה​מיסים ערן יעקב​, ו​ידרשו תשובות ברורות​.​ "​אני עדיין לא מדבר עם ה​עובדים של​י , ו​אני גם לא יודע איך להיערך החל מ​עוד שבועיים​"​, אמר ל​'​דבר ראשון​' ניר שוחט​, ה​בעלים של מסעדת סושי מוטו ב​שדרות ו​מוסיף כי יהיה קשה להתאים את ה​פסיקה ל​מציאות ב​שטח . "​אף אחד לא יודע​. יש המון סתירות – עורך ה​דין אומר דבר אחד ו​רואה ה​חשבון דבר אחר​. עדיין לא הצליחו להבין את ה​חוק ל​אשור​ו "​.​ "​מ​בחינת​י , הייתי מוסיף ל​תפריט תוספת שירות של 17​% . זה יגלם גם את ה​מע"מ ו​ה​טיפים ו​מ​זה אני אשלם ל​מלצרים . 
די כבר עם ה​טיפים ה​אלה , מספיק​.​"​ """ CLUSTER = {'metadata': {'name': 'המסעדנים', 'entity': 'person'}, 'mentions': [[28, 35, {}], [572, 581, {}]]} def test_extract_doc(tokenizer): doc = {'text': TEXT, 'clusters': [CLUSTER], 'metadata': { 'doc_id': 'test' } } extracted = extract_doc(tokenizer, [doc]) assert len(extracted) == 1 assert len(extracted[0].coref_spans) == 2 assert extracted[0].coref_spans[1] == [(0, 4, 4)] assert extracted[0].coref_spans[6] == [(0, 3, 4)] ================================================ FILE: stanza/tests/datasets/ner/__init__.py ================================================ ================================================ FILE: stanza/tests/datasets/ner/test_prepare_ner_file.py ================================================ """ Test some simple conversions of NER bio files """ import pytest import json from stanza.models.common.doc import Document from stanza.utils.datasets.ner.prepare_ner_file import process_dataset BIO_1 = """ Jennifer B-PERSON Sh'reyan I-PERSON has O lovely O antennae O """.strip() BIO_2 = """ but O I O don't O like O the O way O Jennifer B-PERSON treated O Beckett B-PERSON on O the O Cerritos B-LOCATION """.strip() def check_json_file(doc, raw_text, expected_sentences, expected_tokens): raw_sentences = raw_text.strip().split("\n\n") assert len(raw_sentences) == expected_sentences if isinstance(expected_tokens, int): expected_tokens = [expected_tokens] for raw_sentence, expected_len in zip(raw_sentences, expected_tokens): assert len(raw_sentence.strip().split("\n")) == expected_len assert len(doc.sentences) == expected_sentences for sentence, expected_len in zip(doc.sentences, expected_tokens): assert len(sentence.tokens) == expected_len for sentence, raw_sentence in zip(doc.sentences, raw_sentences): for token, line in zip(sentence.tokens, raw_sentence.strip().split("\n")): word, tag = line.strip().split() assert token.text == word assert token.ner == tag def write_and_convert(tmp_path, raw_text): bio_file = 
tmp_path / "test.bio" with open(bio_file, "w", encoding="utf-8") as fout: fout.write(raw_text) json_file = tmp_path / "json.bio" process_dataset(bio_file, json_file) with open(json_file) as fin: doc = Document(json.load(fin)) return doc def run_test(tmp_path, raw_text, expected_sentences, expected_tokens): doc = write_and_convert(tmp_path, raw_text) check_json_file(doc, raw_text, expected_sentences, expected_tokens) def test_simple(tmp_path): run_test(tmp_path, BIO_1, 1, 5) def test_ner_at_end(tmp_path): run_test(tmp_path, BIO_2, 1, 12) def test_two_sentences(tmp_path): raw_text = BIO_1 + "\n\n" + BIO_2 run_test(tmp_path, raw_text, 2, [5, 12]) ================================================ FILE: stanza/tests/datasets/ner/test_utils.py ================================================ """ Test the utils file of the NER dataset processing """ import pytest from stanza.utils.datasets.ner.utils import list_doc_entities from stanza.tests.datasets.ner.test_prepare_ner_file import BIO_1, BIO_2, write_and_convert def test_list_doc_entities(tmp_path): """ Test the function which lists all of the entities in a doc """ doc = write_and_convert(tmp_path, BIO_1) entities = list_doc_entities(doc) expected = [(('Jennifer', "Sh'reyan"), 'PERSON')] assert expected == entities doc = write_and_convert(tmp_path, BIO_2) entities = list_doc_entities(doc) expected = [(('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')] assert expected == entities doc = write_and_convert(tmp_path, "\n\n".join([BIO_1, BIO_2])) entities = list_doc_entities(doc) expected = [(('Jennifer', "Sh'reyan"), 'PERSON'), (('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')] assert expected == entities doc = write_and_convert(tmp_path, "\n\n".join([BIO_1, BIO_1, BIO_2])) entities = list_doc_entities(doc) expected = [(('Jennifer', "Sh'reyan"), 'PERSON'), (('Jennifer', "Sh'reyan"), 'PERSON'), (('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 
'LOCATION')] assert expected == entities ================================================ FILE: stanza/tests/datasets/test_common.py ================================================ """ Test conllu manipulating routines in stanza/utils/dataset/common.py """ import pytest from stanza.utils.datasets.common import maybe_add_fake_dependencies # from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.pipeline] DEPS_EXAMPLE=""" # text = Sh'reyan's antennae are hella thicc 1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 3 nmod:poss 3:nmod:poss SpaceAfter=No 2 's 's PART POS _ 1 case 1:case _ 3 antennae antenna NOUN NNS Number=Plur 6 nsubj 6:nsubj _ 4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ 5 hella hella ADV RB _ 6 advmod 6:advmod _ 6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _ """.strip().split("\n") ONLY_ROOT_EXAMPLE=""" # text = Sh'reyan's antennae are hella thicc 1 Sh'reyan Sh'reyan PROPN NNP Number=Sing _ _ _ SpaceAfter=No 2 's 's PART POS _ _ _ _ _ 3 antennae antenna NOUN NNS Number=Plur _ _ _ _ 4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin _ _ _ _ 5 hella hella ADV RB _ _ _ _ _ 6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _ """.strip().split("\n") ONLY_ROOT_EXPECTED=""" # text = Sh'reyan's antennae are hella thicc 1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 6 dep _ SpaceAfter=No 2 's 's PART POS _ 1 dep _ _ 3 antennae antenna NOUN NNS Number=Plur 1 dep _ _ 4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 1 dep _ _ 5 hella hella ADV RB _ 1 dep _ _ 6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _ """.strip().split("\n") NO_DEPS_EXAMPLE=""" # text = Sh'reyan's antennae are hella thicc 1 Sh'reyan Sh'reyan PROPN NNP Number=Sing _ _ _ SpaceAfter=No 2 's 's PART POS _ _ _ _ _ 3 antennae antenna NOUN NNS Number=Plur _ _ _ _ 4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin _ _ _ _ 5 hella hella ADV RB _ _ _ _ _ 6 thicc thicc ADJ JJ Degree=Pos _ _ _ _ """.strip().split("\n") NO_DEPS_EXPECTED=""" # text = Sh'reyan's antennae are hella 
thicc 1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 0 root _ SpaceAfter=No 2 's 's PART POS _ 1 dep _ _ 3 antennae antenna NOUN NNS Number=Plur 1 dep _ _ 4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 1 dep _ _ 5 hella hella ADV RB _ 1 dep _ _ 6 thicc thicc ADJ JJ Degree=Pos 1 dep _ _ """.strip().split("\n") def test_fake_deps_no_change(): result = maybe_add_fake_dependencies(DEPS_EXAMPLE) assert result == DEPS_EXAMPLE def test_fake_deps_all_tokens(): result = maybe_add_fake_dependencies(NO_DEPS_EXAMPLE) assert result == NO_DEPS_EXPECTED def test_fake_deps_only_root(): result = maybe_add_fake_dependencies(ONLY_ROOT_EXAMPLE) assert result == ONLY_ROOT_EXPECTED ================================================ FILE: stanza/tests/datasets/test_vietnamese_renormalization.py ================================================ import pytest import os from stanza.utils.datasets.vietnamese import renormalize pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_replace_all(): text = "SỌAmple tụy test file" expected = "SOẠmple tuỵ test file" assert renormalize.replace_all(text) == expected def test_replace_file(tmp_path): text = "SỌAmple tụy test file" expected = "SOẠmple tuỵ test file" orig = tmp_path / "orig.txt" converted = tmp_path / "converted.txt" with open(orig, "w", encoding="utf-8") as fout: for i in range(10): fout.write(text) fout.write("\n") renormalize.convert_file(orig, converted) assert os.path.exists(converted) with open(converted, encoding="utf-8") as fin: lines = fin.readlines() assert len(lines) == 10 for i in lines: assert i.strip() == expected ================================================ FILE: stanza/tests/depparse/__init__.py ================================================ ================================================ FILE: stanza/tests/depparse/test_depparse_data.py ================================================ """ Test some pieces of the depparse dataloader """ import pytest from stanza.models import parser from 
stanza.models.depparse.data import data_to_batches, DataLoader from stanza.utils.conll import CoNLL pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def make_fake_data(*lengths): data = [] for i, length in enumerate(lengths): word = chr(ord('A') + i) chunk = [[word] * length] data.append(chunk) return data def check_batches(batched_data, expected_sizes, expected_order): for chunk, size in zip(batched_data, expected_sizes): assert sum(len(x[0]) for x in chunk) == size word_order = [] for chunk in batched_data: for sentence in chunk: word_order.append(sentence[0][0]) assert word_order == expected_order def test_data_to_batches_eval_mode(): """ Tests the chunking of batches in eval_mode A few options are tested, such as whether or not to sort and the maximum sentence size """ data = make_fake_data(1, 2, 3) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None) check_batches(batched_data[0], [5, 1], ['C', 'B', 'A']) data = make_fake_data(1, 2, 6) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None) check_batches(batched_data[0], [6, 3], ['C', 'B', 'A']) data = make_fake_data(3, 2, 1) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None) check_batches(batched_data[0], [5, 1], ['A', 'B', 'C']) data = make_fake_data(3, 5, 2) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None) check_batches(batched_data[0], [5, 5], ['B', 'A', 'C']) data = make_fake_data(3, 5, 2) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3) check_batches(batched_data[0], [3, 5, 2], ['A', 'B', 'C']) data = make_fake_data(4, 1, 1) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, 
min_length_to_batch_separately=3) check_batches(batched_data[0], [4, 2], ['A', 'B', 'C']) data = make_fake_data(1, 4, 1) batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3) check_batches(batched_data[0], [1, 4, 1], ['A', 'B', 'C']) EWT_PUNCT_SAMPLE = """ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048 # text = Bush asked for permission to go to Alabama to work on a Senate campaign. 1 Bush Bush PROPN NNP Number=Sing 2 nsubj 2:nsubj _ 2 asked ask VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 3 for for ADP IN _ 4 case 4:case _ 4 permission permission NOUN NN Number=Sing 2 obl 2:obl:for _ 5 to to PART TO _ 6 mark 6:mark _ 6 go go VERB VB VerbForm=Inf 4 acl 4:acl:to _ 7 to to ADP IN _ 8 case 8:case _ 8 Alabama Alabama PROPN NNP Number=Sing 6 obl 6:obl:to _ 9 to to PART TO _ 10 mark 10:mark _ 10 work work VERB VB VerbForm=Inf 6 advcl 6:advcl:to _ 11 on on ADP IN _ 14 case 14:case _ 12 a a DET DT Definite=Ind|PronType=Art 14 det 14:det _ 13 Senate Senate PROPN NNP Number=Sing 14 compound 14:compound _ 14 campaign campaign NOUN NN Number=Sing 10 obl 10:obl:on SpaceAfter=No 15 !!!!! ! PUNCT . _ 2 punct 2:punct _ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049 # text = His superior officers said OK. 1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nmod:poss 3:nmod:poss _ 2 superior superior ADJ JJ Degree=Pos 3 amod 3:amod _ 3 officers officer NOUN NNS Number=Plur 4 nsubj 4:nsubj _ 4 said say VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 5 OK ok INTJ UH _ 4 obj 4:obj SpaceAfter=No 6 ????? ? PUNCT . _ 4 punct 4:punct _ """ def test_punct_simplification(): """ Test a punctuation simplification that should make it so unexpected question/exclamation marks types are processed into ? and ! 
""" sample = CoNLL.conll2doc(input_str=EWT_PUNCT_SAMPLE) args = parser.parse_args(args=["--batch_size", "1000", "--shorthand", "en_test"]) data = DataLoader(sample, 5000, args, None) batches = [batch for batch in data] assert batches[0][-1] == [['Bush', 'asked', 'for', 'permission', 'to', 'go', 'to', 'Alabama', 'to', 'work', 'on', 'a', 'Senate', 'campaign', '!'], ['His', 'superior', 'officers', 'said', 'OK', '?']] if __name__ == '__main__': test_data_to_batches() ================================================ FILE: stanza/tests/depparse/test_parser.py ================================================ """ Run the tagger for a couple iterations on some fake data Uses a couple sentences of UD_English-EWT as training/dev data """ import os import pytest import zipfile import torch from stanza.models import parser from stanza.models.common import pretrain from stanza.models.depparse.trainer import Trainer from stanza.tests import TEST_WORKING_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] TRAIN_DATA = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003 # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad. 
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No 2 : : PUNCT : _ 1 punct 1:punct _ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _ 6 that that SCONJ IN _ 9 mark 9:mark _ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _ 8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _ 11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _ 15 in in ADP IN _ 16 case 16:case _ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No 17 . . PUNCT . _ 1 punct 1:punct _ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004 # text = Two of them were being run by 2 officials of the Ministry of the Interior! 
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _ 2 of of ADP IN _ 3 case 3:case _ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ 7 by by ADP IN _ 9 case 9:case _ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _ 10 of of ADP IN _ 12 case 12:case _ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _ 13 of of ADP IN _ 15 case 15:case _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No 16 ! ! PUNCT . _ 6 punct 6:punct _ """.lstrip() DEV_DATA = """ 1 From from ADP IN _ 3 case 3:case _ 2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _ 3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _ 4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _ 5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _ 6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _ 7 : : PUNCT : _ 4 punct 4:punct _ """.lstrip() class TestParser: @pytest.fixture(scope="class") def wordvec_pretrain_file(self): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None, zip_train_data=False): """ Run the training for a few iterations, load & return the model """ train_file = str(tmp_path / "train.zip") if zip_train_data else str(tmp_path / "train.conllu") dev_file = str(tmp_path / "dev.conllu") pred_file = str(tmp_path / "pred.conllu") save_name = "test_parser.pt" save_file = str(tmp_path / save_name) if zip_train_data: with zipfile.ZipFile(train_file, "w") as zout: with 
zout.open('train.conllu', 'w') as fout: fout.write(train_text.encode()) else: with open(train_file, "w", encoding="utf-8") as fout: fout.write(train_text) with open(dev_file, "w", encoding="utf-8") as fout: fout.write(dev_text) args = ["--wordvec_pretrain_file", wordvec_pretrain_file, "--train_file", train_file, "--eval_file", dev_file, "--output_file", pred_file, "--log_step", "10", "--eval_interval", "20", "--max_steps", "100", "--shorthand", "en_test", "--save_dir", str(tmp_path), "--save_name", save_name, # in case we are doing a bert test "--bert_start_finetuning", "10", "--bert_warmup_steps", "10", "--lang", "en"] if not augment_nopunct: args.extend(["--augment_nopunct", "0.0"]) if extra_args is not None: args = args + extra_args trainer, _ = parser.main(args) assert os.path.exists(save_file) pt = pretrain.Pretrain(wordvec_pretrain_file) # test loading the saved model saved_model = Trainer(pretrain=pt, model_file=save_file) return trainer def test_train(self, tmp_path, wordvec_pretrain_file): """ Simple test of a few 'epochs' of tagger training """ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA) def test_arc_embedding(self, tmp_path, wordvec_pretrain_file): """ Simple test w/ and w/o arc embedding """ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--use_arc_embedding']) def test_no_arc_embedding(self, tmp_path, wordvec_pretrain_file): """ Simple test w/ and w/o arc embedding """ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--no_use_arc_embedding']) def test_zipfile_train(self, tmp_path, wordvec_pretrain_file): """ Simple test of a few 'epochs' of tagger training with a zipfile """ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, zip_train_data=True) def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file): self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 
'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2']) def test_with_bert_finetuning(self, tmp_path, wordvec_pretrain_file): trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2']) assert 'bert_optimizer' in trainer.optimizer.keys() assert 'bert_scheduler' in trainer.scheduler.keys() def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file): """ Check that if we save, then load, then save a model with a finetuned bert, that bert isn't lost """ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2']) assert 'bert_optimizer' in trainer.optimizer.keys() assert 'bert_scheduler' in trainer.scheduler.keys() save_name = trainer.args['save_name'] filename = tmp_path / save_name assert os.path.exists(filename) checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) assert any(x.startswith("bert_model") for x in checkpoint['model'].keys()) # Test loading the saved model, saving it, and still having bert in it # even if we have set bert_finetune to False for this incarnation pt = pretrain.Pretrain(wordvec_pretrain_file) args = {"bert_finetune": False} saved_model = Trainer(pretrain=pt, model_file=filename, args=args) saved_model.save(filename) # This is the part that would fail if the force_bert_saved option did not exist checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) assert any(x.startswith("bert_model") for x in checkpoint['model'].keys()) def test_with_peft(self, tmp_path, wordvec_pretrain_file): trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2', '--use_peft']) assert 
'bert_optimizer' in trainer.optimizer.keys() assert 'bert_scheduler' in trainer.scheduler.keys() def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file): trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam']) save_dir = trainer.args['save_dir'] save_name = trainer.args['save_name'] checkpoint_name = trainer.args["checkpoint_save_name"] assert os.path.exists(os.path.join(save_dir, save_name)) assert checkpoint_name is not None assert os.path.exists(checkpoint_name) assert len(trainer.optimizer) == 1 for opt in trainer.optimizer.values(): assert isinstance(opt, torch.optim.Adam) pt = pretrain.Pretrain(wordvec_pretrain_file) checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name) assert checkpoint.optimizer is not None assert len(checkpoint.optimizer) == 1 for opt in checkpoint.optimizer.values(): assert isinstance(opt, torch.optim.Adam) def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file): trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40']) save_dir = trainer.args['save_dir'] save_name = trainer.args['save_name'] checkpoint_name = trainer.args["checkpoint_save_name"] assert os.path.exists(os.path.join(save_dir, save_name)) assert checkpoint_name is not None assert os.path.exists(checkpoint_name) assert len(trainer.optimizer) == 1 for opt in trainer.optimizer.values(): assert isinstance(opt, torch.optim.SGD) pt = pretrain.Pretrain(wordvec_pretrain_file) checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name) assert checkpoint.optimizer is not None assert len(checkpoint.optimizer) == 1 for opt in trainer.optimizer.values(): assert isinstance(opt, torch.optim.SGD) ================================================ FILE: stanza/tests/langid/__init__.py ================================================ 
================================================ FILE: stanza/tests/langid/test_langid.py ================================================ """ Basic tests of langid module """ import pytest from stanza.models.common.doc import Document from stanza.pipeline.core import Pipeline from stanza.pipeline.langid_processor import LangIDProcessor from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] #pytestmark = pytest.mark.skip @pytest.fixture(scope="module") def basic_multilingual(): return Pipeline(dir=TEST_MODELS_DIR, lang='multilingual', processors="langid") @pytest.fixture(scope="module") def enfr_multilingual(): return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en", "fr"]) @pytest.fixture(scope="module") def en_multilingual(): return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"]) @pytest.fixture(scope="module") def clean_multilingual(): return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_clean_text=True) def test_langid(basic_multilingual): """ Basic test of language identification """ english_text = "This is an English sentence." french_text = "C'est une phrase française." docs = [english_text, french_text] docs = [Document([], text=text) for text in docs] basic_multilingual(docs) predictions = [doc.lang for doc in docs] assert predictions == ["en", "fr"] def test_langid_benchmark(basic_multilingual): """ Run lang id model on 500 examples, confirm reasonable accuracy. 
""" examples = [ {"text": "contingentiam in naturalibus causis.", "label": "la"}, {"text": "I jak opowiadał nieżyjący już pan Czesław", "label": "pl"}, {"text": "Sonera gilt seit längerem als Übernahmekandidat", "label": "de"}, {"text": "与银类似,汞也可以与空气中的硫化氢反应。", "label": "zh-hans"}, {"text": "contradictionem implicat.", "label": "la"}, {"text": "Bis zu Prozent gingen die Offerten etwa im", "label": "de"}, {"text": "inneren Sicherheit vorgeschlagene Ausweitung der", "label": "de"}, {"text": "Multimedia-PDA mit Mini-Tastatur", "label": "de"}, {"text": "Ponášalo sa to na rovnicu o dvoch neznámych.", "label": "sk"}, {"text": "이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의", "label": "ko"}, {"text": "Die Arbeitsgruppe bedauert , dass der weit über", "label": "de"}, {"text": "И только раз довелось поговорить с ним не вполне", "label": "ru"}, {"text": "de a-l lovi cu piciorul și conștiința că era", "label": "ro"}, {"text": "relación coas pretensións do demandante e que, nos", "label": "gl"}, {"text": "med petdeset in sedemdeset", "label": "sl"}, {"text": "Catalunya; el Consell Comarcal del Vallès Oriental", "label": "ca"}, {"text": "kunnen worden.", "label": "nl"}, {"text": "Witkin je ve většině ohledů zcela jiný.", "label": "cs"}, {"text": "lernen, so zu agieren, dass sie positive oder auch", "label": "de"}, {"text": "olurmuş...", "label": "tr"}, {"text": "sarcasmo de Altman, desde as «peruas» que discutem", "label": "pt"}, {"text": "خلاف فوجداری مقدمہ درج کرے۔", "label": "ur"}, {"text": "Norddal kommune :", "label": "no"}, {"text": "dem Windows-.-Zeitalter , soll in diesem Jahr", "label": "de"}, {"text": "przeklętych ucieleśniają mit poety-cygana,", "label": "pl"}, {"text": "We do not believe the suspect has ties to this", "label": "en"}, {"text": "groziņu pīšanu.", "label": "lv"}, {"text": "Senior Vice-President David M. 
Thomas möchte", "label": "de"}, {"text": "neomylně vybral nějakou knihu a začetl se.", "label": "cs"}, {"text": "Statt dessen darf beispielsweise der Browser des", "label": "de"}, {"text": "outubro, alcançando R $ bilhões em .", "label": "pt"}, {"text": "(Porte, ), as it does other disciplines", "label": "en"}, {"text": "uskupení se mylně domnívaly, že podporu", "label": "cs"}, {"text": "Übernahme von Next Ende an dem System herum , das", "label": "de"}, {"text": "No podemos decir a la Hacienda que los alemanes", "label": "es"}, {"text": "и рѣста еи братья", "label": "orv"}, {"text": "الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية", "label": "ar"}, {"text": "uurides Rootsi sõjaarhiivist toodud . sajandi", "label": "et"}, {"text": "selskapets penger til å pusse opp sin enebolig på", "label": "no"}, {"text": "средней полосе и севернее в Ярославской,", "label": "ru"}, {"text": "il-massa żejda fil-ġemgħat u superġemgħat ta'", "label": "mt"}, {"text": "The Global Beauties on internetilehekülg, mida", "label": "et"}, {"text": "이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며", "label": "ko"}, {"text": "Snad ještě dodejme jeden ekonomický argument.", "label": "cs"}, {"text": "Spalio d. 
vykusiame pirmajame rinkimų ture", "label": "lt"}, {"text": "und schlechter Journalismus ein gutes Geschäft .", "label": "de"}, {"text": "Du sodiečiai sėdi ant potvynio apsemtų namų stogo.", "label": "lt"}, {"text": "цей є автентичним.", "label": "uk"}, {"text": "Și îndegrabă fu cu îngerul mulțime de șireaguri", "label": "ro"}, {"text": "sobra personal cualificado.", "label": "es"}, {"text": "Tako se u Njemačkoj dvije trećine liječnika služe", "label": "hr"}, {"text": "Dual-Athlon-Chipsatz noch in diesem Jahr", "label": "de"}, {"text": "यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का", "label": "hi"}, {"text": "Li forestier du mont avale", "label": "fro"}, {"text": "Netzwerken für Privatanwender zu bewundern .", "label": "de"}, {"text": "만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다", "label": "ko"}, {"text": "balance and weight distribution but not really for", "label": "en"}, {"text": "og så e # tente vi opp den om morgonen å sfyrte", "label": "nn"}, {"text": "변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .", "label": "ko"}, {"text": "puteare fac aceastea.", "label": "ro"}, {"text": "Waitt seine Führungsmannschaft nicht dem", "label": "de"}, {"text": "juhtimisega, tulid sealt.", "label": "et"}, {"text": "Veränderungen .", "label": "de"}, {"text": "banda en el Bayer Leverkusen de la Bundesliga de", "label": "es"}, {"text": "В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава", "label": "orv"}, {"text": "пославъ приведе я мастеры ѿ грекъ", "label": "orv"}, {"text": "En un nou escenari difícil d'imaginar fa poques", "label": "ca"}, {"text": "καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου", "label": "grc"}, {"text": "직접적인 관련이 있다 .", "label": "ko"}, {"text": "가까운 듯하면서도 멀다 .", "label": "ko"}, {"text": "Er bietet ein ähnliches Leistungsniveau und", "label": "de"}, {"text": "民都洛水牛是獨居的,並不會以群族聚居。", "label": "zh-hant"}, {"text": "την τρομοκρατία.", "label": "el"}, {"text": "hurbiltzen diren neurrian.", "label": "eu"}, {"text": "Ah dimenticavo, ma tutta sta caciara per fare un", 
"label": "it"}, {"text": "На первом этапе (-) прошла так называемая", "label": "ru"}, {"text": "of games are on the market.", "label": "en"}, {"text": "находится Мост дружбы, соединяющий узбекский и", "label": "ru"}, {"text": "lessié je voldroie que li saint fussent aporté", "label": "fro"}, {"text": "Дошла очередь и до Гималаев.", "label": "ru"}, {"text": "vzácným suknem táhly pouští, si jednou chtěl do", "label": "cs"}, {"text": "E no terceiro tipo sitúa a familias (%), nos que a", "label": "gl"}, {"text": "وجابت دوريات امريكية وعراقية شوارع المدينة، فيما", "label": "ar"}, {"text": "Jeg har bodd her i år .", "label": "no"}, {"text": "Pohrozil, že odbory zostří postoj, pokud se", "label": "cs"}, {"text": "tinham conseguido.", "label": "pt"}, {"text": "Nicht-Erkrankten einen Anfangsverdacht für einen", "label": "de"}, {"text": "permanece em aberto.", "label": "pt"}, {"text": "questi possono promettere rendimenti fino a un", "label": "it"}, {"text": "Tema juurutatud kahevedurisüsteemita oleksid", "label": "et"}, {"text": "Поведение внешне простой игрушки оказалось", "label": "ru"}, {"text": "Bundesländern war vom Börsenverein des Deutschen", "label": "de"}, {"text": "acció, 'a mesura que avanci l'estiu, amb l'augment", "label": "ca"}, {"text": "Dove trovare queste risorse? 
Jay Naidoo, ministro", "label": "it"}, {"text": "essas gordurinhas.", "label": "pt"}, {"text": "Im zweiten Schritt sollen im übernächsten Jahr", "label": "de"}, {"text": "allveelaeva pole enam vaja, kuna külm sõda on läbi", "label": "et"}, {"text": "उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा", "label": "hi"}, {"text": "@user nella sfortuna sei fortunata ..", "label": "it"}, {"text": "математических школ в виде грозовых туч.", "label": "ru"}, {"text": "No cambiaremos nunca nuestra forma de jugar por un", "label": "es"}, {"text": "dla tej klasy ani wymogów minimalnych, z wyjątkiem", "label": "pl"}, {"text": "en todo el mundo, mientras que en España consiguió", "label": "es"}, {"text": "политики считать надежное обеспечение военной", "label": "ru"}, {"text": "gogoratzen du, genio alemana delakoaren", "label": "eu"}, {"text": "Бычий глаз.", "label": "ru"}, {"text": "Opeření se v pravidelných obdobích obnovuje", "label": "cs"}, {"text": "I no és només la seva, es tracta d'una resposta", "label": "ca"}, {"text": "오경을 가르쳤다 .", "label": "ko"}, {"text": "Nach der so genannten Start-up-Periode vergibt die", "label": "de"}, {"text": "Saulista huomasi jo lapsena , että hänellä on", "label": "fi"}, {"text": "Министерство культуры сочло нецелесообразным, и", "label": "ru"}, {"text": "znepřátelené tábory v Tádžikistánu předseda", "label": "cs"}, {"text": "καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον", "label": "grc"}, {"text": "Вечером, в продукте, этот же человек говорил о", "label": "ru"}, {"text": "lugar á formación de xuizos máis complexos.", "label": "gl"}, {"text": "cheaper, in the end?", "label": "en"}, {"text": "الوزارة في شأن صفقات بيع الشركات العامة التي تم", "label": "ar"}, {"text": "tärkeintä elämässäni .", "label": "fi"}, {"text": "Виконання Мінських угод було заблоковано Росією та", "label": "uk"}, {"text": "Aby szybko rozpoznać żołnierzy desantu, należy", "label": "pl"}, {"text": "Bankengeschäfte liegen vorn , sagte Strothmann .", "label": "de"}, 
{"text": "продолжение работы.", "label": "ru"}, {"text": "Metro AG plant Online-Offensive", "label": "de"}, {"text": "nu vor veni, și să vor osîndi, aceia nu pot porni", "label": "ro"}, {"text": "Ich denke , es geht in Wirklichkeit darum , NT bei", "label": "de"}, {"text": "de turism care încasează contravaloarea", "label": "ro"}, {"text": "Aurkaria itotzea da helburua, baloia lapurtu eta", "label": "eu"}, {"text": "com a centre de formació en Tecnologies de la", "label": "ca"}, {"text": "oportet igitur quod omne agens in agendo intendat", "label": "la"}, {"text": "Jerzego Andrzejewskiego, oparty na chińskich", "label": "pl"}, {"text": "sau một vài câu chuyện xã giao không dính dáng tới", "label": "vi"}, {"text": "что экономическому прорыву жесткий авторитарный", "label": "ru"}, {"text": "DRAM-Preisen scheinen DSPs ein", "label": "de"}, {"text": "Jos dajan nubbái: Mana!", "label": "sme"}, {"text": "toți carii ascultară de el să răsipiră.", "label": "ro"}, {"text": "odpowiedzialności, które w systemie własności", "label": "pl"}, {"text": "Dvomesečno potovanje do Mollenda v Peruju je", "label": "sl"}, {"text": "d'entre les agències internacionals.", "label": "ca"}, {"text": "Fahrzeugzugangssysteme gefertigt und an viele", "label": "de"}, {"text": "in an answer to the sharers' petition in Cuthbert", "label": "en"}, {"text": "Europa-Domain per Verordnung zu regeln .", "label": "de"}, {"text": "#Balotelli. 
Su ebay prezzi stracciati per Silvio", "label": "it"}, {"text": "Ne na košickém trávníku, ale už včera v letadle se", "label": "cs"}, {"text": "zaměstnanosti a investičních strategií.", "label": "cs"}, {"text": "Tatínku, udělej den", "label": "cs"}, {"text": "frecuencia con Mary.", "label": "es"}, {"text": "Свеаборге.", "label": "ru"}, {"text": "opatření slovenské strany o certifikaci nejvíce", "label": "cs"}, {"text": "En todas me decían: 'Espera que hagamos un estudio", "label": "es"}, {"text": "Die Demonstration sollte nach Darstellung der", "label": "de"}, {"text": "Ci vorrà un assoluto rigore se dietro i disavanzi", "label": "it"}, {"text": "Tatínku, víš, že Honzovi odešla maminka?", "label": "cs"}, {"text": "Die Anzahl der Rechner wuchs um % auf und die", "label": "de"}, {"text": "האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין", "label": "he"}, {"text": "Volán Egyesülés, a Közlekedési Főfelügyelet is.", "label": "hu"}, {"text": "Schejbala, který stejnou hru s velkým úspěchem", "label": "cs"}, {"text": "depends on the data type of the field.", "label": "en"}, {"text": "Umsatzwarnung zu Wochenbeginn zeitweise auf ein", "label": "de"}, {"text": "niin heti nukun .", "label": "fi"}, {"text": "Mobilfunkunternehmen gegen die Anwendung der so", "label": "de"}, {"text": "sapessi le intenzioni del governo Monti e dell'UE", "label": "it"}, {"text": "Di chi è figlia Martine Aubry?", "label": "it"}, {"text": "avec le reste du monde.", "label": "fr"}, {"text": "Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի", "label": "hy"}, {"text": "și în cazul destrămării cenaclului.", "label": "ro"}, {"text": "befriedigen kann , und ohne die auftretenden", "label": "de"}, {"text": "Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.", "label": "grc"}, {"text": "færdiguddannede.", "label": "da"}, {"text": "Schmidt war Sohn eines Rittergutsbesitzers.", "label": "de"}, {"text": "и вдаша попадь ѡпрати", "label": "orv"}, {"text": "cine nu știe învățătură”.", "label": "ro"}, {"text": 
"détacha et cette dernière tenta de tuer le jeune", "label": "fr"}, {"text": "Der har saka også ei lengre forhistorie.", "label": "nn"}, {"text": "Pieprz roztłuc w moździerzu, dodać do pasty,", "label": "pl"}, {"text": "Лежа за гребнем оврага, как за бруствером, Ушаков", "label": "ru"}, {"text": "gesucht habe, vielen Dank nochmals!", "label": "de"}, {"text": "инструментальных сталей, повышения", "label": "ru"}, {"text": "im Halbfinale Patrick Smith und im Finale dann", "label": "de"}, {"text": "البنوك التريث في منح تسهيلات جديدة لمنتجي حديد", "label": "ar"}, {"text": "una bolsa ventral, la cual se encuentra debajo de", "label": "es"}, {"text": "za SETimes.", "label": "sr"}, {"text": "de Irak, a un piloto italiano que había violado el", "label": "es"}, {"text": "Er könne sich nicht erklären , wie die Zeitung auf", "label": "de"}, {"text": "Прохорова.", "label": "ru"}, {"text": "la democrazia perde sulla tecnocrazia? #", "label": "it"}, {"text": "entre ambas instituciones, confirmó al medio que", "label": "es"}, {"text": "Austlandet, vart det funne om lag førti", "label": "nn"}, {"text": "уровнями власти.", "label": "ru"}, {"text": "Dá tedy primáři úplatek, a často ne malý.", "label": "cs"}, {"text": "brillantes del acto, al llevar a cabo en el", "label": "es"}, {"text": "eee druga zadeva je majhen priročen gre kamorkoli", "label": "sl"}, {"text": "Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse", "label": "de"}, {"text": "Za vodné bylo v prvním pololetí zaplaceno v ČR", "label": "cs"}, {"text": "Даже на полсантиметра.", "label": "ru"}, {"text": "com la del primer tinent d'alcalde en funcions,", "label": "ca"}, {"text": "кількох оповідань в цілості — щось на зразок того", "label": "uk"}, {"text": "sed ad divitias congregandas, vel superfluum", "label": "la"}, {"text": "Norma Talmadge, spela mot Valentino i en version", "label": "sv"}, {"text": "Dlatego chciał się jej oświadczyć w niezwykłym", "label": "pl"}, {"text": "будут выступать на одинаковых снарядах.", 
"label": "ru"}, {"text": "Orang-orang terbunuh di sana.", "label": "id"}, {"text": "لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب", "label": "ar"}, {"text": "Wirklichkeit verlagerten und kaum noch", "label": "de"}, {"text": "как перемешивают костяшки перед игрой в домино, и", "label": "ru"}, {"text": "В средине дня, когда солнце светило в нашу", "label": "ru"}, {"text": "d'aventure aux rôles de jeune romantique avec une", "label": "fr"}, {"text": "My teď hledáme organizace, jež by s námi chtěly", "label": "cs"}, {"text": "Urteilsfähigkeit einbüßen , wenn ich eigene", "label": "de"}, {"text": "sua appartenenza anche a voci diverse da quella in", "label": "it"}, {"text": "Aufträge dieses Jahr verdoppeln werden .", "label": "de"}, {"text": "M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę", "label": "pl"}, {"text": "secundum contactum virtutis, cum careat dimensiva", "label": "la"}, {"text": "ezinbestekoa dela esan zuen.", "label": "eu"}, {"text": "Anek hurbiltzeko eskatzen zion besaulkitik, eta", "label": "eu"}, {"text": "perfectius alio videat, quamvis uterque videat", "label": "la"}, {"text": "Die Strecke war anspruchsvoll und führte unter", "label": "de"}, {"text": "саморазоблачительным уроком, западные СМИ не", "label": "ru"}, {"text": "han representerer radikal islamisme .", "label": "no"}, {"text": "Què s'hi respira pel que fa a la reforma del", "label": "ca"}, {"text": "previsto para também ser desconstruido.", "label": "pt"}, {"text": "Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ", "label": "grc"}, {"text": "para jovens de a anos nos Cieps.", "label": "pt"}, {"text": "संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।", "label": "hi"}, {"text": "objeví i u nás.", "label": "cs"}, {"text": "kvitteringer.", "label": "da"}, {"text": "This report is no exception.", "label": "en"}, {"text": "Разлепват доносниците до избирателните списъци", "label": "bg"}, {"text": "anderem ihre Bewegungsfreiheit in den USA", "label": "de"}, {"text": "Ñu tegoon ca kaw gor ña ay 
njotti bopp yu kenn", "label": "wo"}, {"text": "Struktur kann beispielsweise der Schwerpunkt mehr", "label": "de"}, {"text": "% la velocidad permitida, la sanción es muy grave.", "label": "es"}, {"text": "Teles-Einstieg in ADSL-Markt", "label": "de"}, {"text": "ettekäändeks liiga suure osamaksu.", "label": "et"}, {"text": "als Indiz für die geänderte Marktpolitik des", "label": "de"}, {"text": "quod quidem aperte consequitur ponentes", "label": "la"}, {"text": "de negociación para el próximo de junio.", "label": "es"}, {"text": "Tyto důmyslné dekorace doznaly v poslední době", "label": "cs"}, {"text": "največjega uspeha doslej.", "label": "sl"}, {"text": "Paul Allen je jedan od suosnivača Interval", "label": "hr"}, {"text": "Federal (Seac / DF) eo Sindicato das Empresas de", "label": "pt"}, {"text": "Quartal mit . Mark gegenüber dem gleichen Quartal", "label": "de"}, {"text": "otros clubes y del Barça B saldrán varios", "label": "es"}, {"text": "Jaskula (Pol.) -", "label": "cs"}, {"text": "umožnily říci, že je možné přejít k mnohem", "label": "cs"}, {"text": "اعلن الجنرال تومي فرانكس قائد القوات الامريكية", "label": "ar"}, {"text": "Telekom-Chef Ron Sommer und der Vorstandssprecher", "label": "de"}, {"text": "My, jako průmyslový a finanční holding, můžeme", "label": "cs"}, {"text": "voorlichting onder andere betrekking kan hebben:", "label": "nl"}, {"text": "Hinrichtung geistig Behinderter applaudiert oder", "label": "de"}, {"text": "wie beispielsweise Anzahl erzielte Klicks ,", "label": "de"}, {"text": "Intel-PC-SDRAM-Spezifikation in der Version . 
(", "label": "de"}, {"text": "plângere în termen de zile de la comunicarea", "label": "ro"}, {"text": "и Испания ще изгубят втория си комисар в ЕК.", "label": "bg"}, {"text": "इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।", "label": "hi"}, {"text": "aunque se mostró contrario a establecer un", "label": "es"}, {"text": "des letzten Jahres von auf Millionen Euro .", "label": "de"}, {"text": "Ankara se također poziva da u cijelosti ratificira", "label": "hr"}, {"text": "herunterlädt .", "label": "de"}, {"text": "стрессовую ситуацию для организма, каковой", "label": "ru"}, {"text": "Státního shromáždění (parlamentu).", "label": "cs"}, {"text": "diskutieren , ob und wie dieser Dienst weiterhin", "label": "de"}, {"text": "Verbindungen zu FPÖ-nahen Polizisten gepflegt und", "label": "de"}, {"text": "Pražského volebního lídra ovšem nevybírá Miloš", "label": "cs"}, {"text": "Nach einem Bericht der Washington Post bleibt das", "label": "de"}, {"text": "للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما", "label": "ar"}, {"text": "не желаят запазването на статуквото.", "label": "bg"}, {"text": "Offenburg gewesen .", "label": "de"}, {"text": "ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε", "label": "grc"}, {"text": "all'odiato compagno di squadra Prost, il quale", "label": "it"}, {"text": "historischen Gänselieselbrunnens.", "label": "de"}, {"text": "למידע מלווייני הריגול האמריקאיים העוקבים אחר", "label": "he"}, {"text": "οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν", "label": "grc"}, {"text": "movementos migratorios.", "label": "gl"}, {"text": "Handy und ein Spracherkennungsprogramm sämtliche", "label": "de"}, {"text": "Kümne aasta jooksul on Eestisse ohjeldamatult", "label": "et"}, {"text": "H.G. 
Bücknera.", "label": "pl"}, {"text": "protiv krijumčarenja, ili pak traženju ukidanja", "label": "hr"}, {"text": "Topware-Anteile mehrere Millionen Mark gefordert", "label": "de"}, {"text": "Maar de mensen die nu over Van Dijk bij FC Twente", "label": "nl"}, {"text": "poidan experimentar as percepcións do interesado,", "label": "gl"}, {"text": "Miał przecież w kieszeni nóż.", "label": "pl"}, {"text": "Avšak žádná z nich nepronikla za hranice přímé", "label": "cs"}, {"text": "esim. helpottamalla luottoja muiden", "label": "fi"}, {"text": "Podle předběžných výsledků zvítězila v", "label": "cs"}, {"text": "Nicht nur das Web-Frontend , auch die", "label": "de"}, {"text": "Regierungsinstitutionen oder Universitäten bei", "label": "de"}, {"text": "Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս", "label": "hy"}, {"text": "Афганистана, где в последние дни идут ожесточенные", "label": "ru"}, {"text": "лѧхове же не идоша", "label": "orv"}, {"text": "Mit Hilfe von IBMs Chip-Management-Systemen sollen", "label": "de"}, {"text": ", als Manager zu Telefonica zu wechseln .", "label": "de"}, {"text": "którym zajmuje się człowiek, zmienia go i pozwala", "label": "pl"}, {"text": "činí kyperských liber, to je asi USD.", "label": "cs"}, {"text": "Studienplätze getauscht werden .", "label": "de"}, {"text": "учёных, орнитологов признают вид.", "label": "ru"}, {"text": "acordare a concediilor prevăzute de legislațiile", "label": "ro"}, {"text": "at større innsats for fornybar, berekraftig energi", "label": "nn"}, {"text": "Politiet veit ikkje kor mange personar som deltok", "label": "nn"}, {"text": "offentligheten av unge , sinte menn som har", "label": "no"}, {"text": "însuși în jurul lapunei, care încet DISPARE în", "label": "ro"}, {"text": "O motivo da decisão é evitar uma sobrecarga ainda", "label": "pt"}, {"text": "El Apostolado de la prensa contribuye en modo", "label": "es"}, {"text": "Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer", "label": "de"}, {"text": "grozījumus un 
iesniegt tos Apvienoto Nāciju", "label": "lv"}, {"text": "Gestalt einer deutschen Nationalmannschaft als", "label": "de"}, {"text": "D überholt zu haben , konterte am heutigen Montag", "label": "de"}, {"text": "Softwarehersteller Oracle hat im dritten Quartal", "label": "de"}, {"text": "Během nich se ekonomické podmínky mohou radikálně", "label": "cs"}, {"text": "Dziki kot w górach zeskakuje z kamienia.", "label": "pl"}, {"text": "Ačkoliv ligový nováček prohrál, opět potvrdil, že", "label": "cs"}, {"text": "des Tages , Portraits internationaler Stars sowie", "label": "de"}, {"text": "Communicator bekannt wurde .", "label": "de"}, {"text": "τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν", "label": "grc"}, {"text": "Triadú tenia, mentre redactava 'Dies de memòria',", "label": "ca"}, {"text": "دسته‌جمعی در درخشندگی ماه سیم‌گون زمزمه ستاینده و", "label": "fa"}, {"text": "Книгу, наполненную мелочной заботой об одежде,", "label": "ru"}, {"text": "putares canem leporem persequi.", "label": "la"}, {"text": "В дальнейшем эта яркость слегка померкла, но в", "label": "ru"}, {"text": "offizielles Verfahren gegen die Telekom", "label": "de"}, {"text": "podrían haber sido habitantes de la Península", "label": "es"}, {"text": "Grundlage für dieses Verfahren sind spezielle", "label": "de"}, {"text": "Rechtsausschuß vorgelegten Entwurf der Richtlinie", "label": "de"}, {"text": "Im so genannten Portalgeschäft sei das Unternehmen", "label": "de"}, {"text": "ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ", "label": "cop"}, {"text": "juego podían matar a cualquier herbívoro, pero", "label": "es"}, {"text": "Nach Angaben von Axent nutzen Unternehmen aus der", "label": "de"}, {"text": "hrdiny Havlovy Zahradní slavnosti (premiéra ) se", "label": "cs"}, {"text": "Een zin van heb ik jou daar", "label": "nl"}, {"text": "hat sein Hirn an der CeBIT-Kasse vergessen .", "label": "de"}, {"text": "καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους", "label": "grc"}, {"text": "nachgewiesenen 
langfristigen Kosten , sowie den im", "label": "de"}, {"text": "jučer nakon četiri dana putovanja u Helsinki.", "label": "hr"}, {"text": "pašto paslaugos teikėjas gali susitarti su", "label": "lt"}, {"text": "В результате, эти золотые кадры переходят из одной", "label": "ru"}, {"text": "द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन", "label": "hi"}, {"text": "výpis o počtu akcií.", "label": "cs"}, {"text": "Enfin, elles arrivent à un pavillon chinois", "label": "fr"}, {"text": "Tentu saja, tren yang berhubungandengan", "label": "id"}, {"text": "Arbeidarpartiet og SV har sikra seg fleirtal mot", "label": "nn"}, {"text": "eles: 'Tudo isso está errado' , disse um", "label": "pt"}, {"text": "The islands are in their own time zone, minutes", "label": "en"}, {"text": "Auswahl debütierte er am .", "label": "de"}, {"text": "Bu komisyonlar, arazilerini satın almak için", "label": "tr"}, {"text": "Geschütze gegen Redmond aufgefahren .", "label": "de"}, {"text": "Time scything the hours, but at the top, over the", "label": "en"}, {"text": "Di musim semi , berharap mengadaptasi Tintin untuk", "label": "id"}, {"text": "крупнейшей геополитической катастрофой XX века.", "label": "ru"}, {"text": "Rajojen avaaminen ei suju ongelmitta .", "label": "fi"}, {"text": "непроницаемым, как для СССР.", "label": "ru"}, {"text": "Ma non mancano le polemiche.", "label": "it"}, {"text": "Internet als Ort politischer Diskussion und auch", "label": "de"}, {"text": "incomplets.", "label": "ca"}, {"text": "Su padre luchó al lado de Luis Moya, primer Jefe", "label": "es"}, {"text": "informazione.", "label": "it"}, {"text": "Primacom bietet für Telekom-Kabelnetz", "label": "de"}, {"text": "Oświadczenie prezydencji w imieniu Unii", "label": "pl"}, {"text": "foran rattet i familiens gamle Baleno hvis døra på", "label": "no"}, {"text": "[speaker:laughter]", "label": "sl"}, {"text": "Dog med langt mindre utstyr med seg.", "label": "nn"}, {"text": "dass es nicht schon mit der anfänglichen", "label": 
"de"}, {"text": "इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।", "label": "hi"}, {"text": "کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ", "label": "ur"}, {"text": "dell'Assemblea Costituente che posseggono i", "label": "it"}, {"text": "и аште вьси съблазнѧтъ сѧ нъ не азъ", "label": "cu"}, {"text": "In Irvine hat auch das Logistikunternehmen Atlas", "label": "de"}, {"text": "законодательных норм, принимаемых существующей", "label": "ru"}, {"text": "Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν", "label": "grc"}, {"text": "МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.", "label": "ru"}, {"text": "unterschiedlicher Meinung .", "label": "de"}, {"text": "Jospa joku ystävällinen sielu auttaisi kassieni", "label": "fi"}, {"text": "Añadió que, en el futuro se harán otros", "label": "es"}, {"text": "Sessiz tonlama hem Fince, hem de Kuzey Sami", "label": "tr"}, {"text": "nicht ihnen gehört und sie nicht alles , was sie", "label": "de"}, {"text": "Etelästä Kuivajärveen laskee Tammelan Liesjärvestä", "label": "fi"}, {"text": "ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis", "label": "de"}, {"text": "Norsk politikk frå til kan dermed, i", "label": "nn"}, {"text": "Głosowało posłów.", "label": "pl"}, {"text": "Danny Jones -- smithjones@ev.net", "label": "en"}, {"text": "sebeuvědomění moderní civilizace sehrála lučavka", "label": "cs"}, {"text": "относительно спокойный сон: тому гарантия", "label": "ru"}, {"text": "A halte voiz prist li pedra a crïer", "label": "fro"}, {"text": "آن‌ها امیدوارند این واکسن به‌زودی در دسترس بیماران", "label": "fa"}, {"text": "vlastní důstojnou vousatou tváří.", "label": "cs"}, {"text": "ora aprire la strada a nuove cause e alimentare il", "label": "it"}, {"text": "Die Zahl der Vielleser nahm von auf Prozent zu ,", "label": "de"}, {"text": "Finanzvorstand von Hotline-Dienstleister InfoGenie", "label": "de"}, {"text": "entwickeln .", "label": "de"}, {"text": "incolumità pubblica.", "label": "it"}, {"text": "lehtija televisiomainonta", "label": 
"fi"}, {"text": "joistakin kohdista eri mieltä.", "label": "fi"}, {"text": "Hlavně anglická nezávislá scéna, Dead Can Dance,", "label": "cs"}, {"text": "pásmech od do bodů bodové stupnice.", "label": "cs"}, {"text": "Zu Beginn des Ersten Weltkrieges zählte das", "label": "de"}, {"text": "Així van sorgir, damunt els antics cementiris,", "label": "ca"}, {"text": "In manchem Gedicht der spätern Alten, wie zum", "label": "de"}, {"text": "gaweihaida jah insandida in þana fairƕu jus qiþiþ", "label": "got"}, {"text": "Beides sollte gelöscht werden!", "label": "de"}, {"text": "modifiqués la seva petició inicial de anys de", "label": "ca"}, {"text": "В день открытия симпозиума состоялась закладка", "label": "ru"}, {"text": "tõestatud.", "label": "et"}, {"text": "ἵππῳ πίπτει αὐτοῦ ταύτῃ", "label": "grc"}, {"text": "bisher nie enttäuscht!", "label": "de"}, {"text": "De bohte ollu tuollárat ja suttolaččat ja", "label": "sme"}, {"text": "Klarsignal från röstlängdsläsaren, tre tryck i", "label": "sv"}, {"text": "Tvůrcem nového termínu je Joseph Fisher.", "label": "cs"}, {"text": "Nie miałem czasu na reakcję twierdzi Norbert,", "label": "pl"}, {"text": "potentia Schöpfer.", "label": "de"}, {"text": "Un poquito caro, pero vale mucho la pena;", "label": "es"}, {"text": "οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος", "label": "grc"}, {"text": "vajec, sladového výtažku a některých vitamínových", "label": "cs"}, {"text": "Настоящие герои, те, чьи истории потом", "label": "ru"}, {"text": "praesumptio:", "label": "la"}, {"text": "Olin justkui nende vastutusel.", "label": "et"}, {"text": "Jokainen keinahdus tuo lähemmäksi hetkeä jolloin", "label": "fi"}, {"text": "ekonomicky výhodných způsobů odvodnění těžkých,", "label": "cs"}, {"text": "Poprvé ve své historii dokázala v kvalifikaci pro", "label": "cs"}, {"text": "zpracovatelského a spotřebního průmyslu bude nutné", "label": "cs"}, {"text": "Windows CE zu integrieren .", "label": "de"}, {"text": "Armangué, a través d'un decret, 
ordenés l'aturada", "label": "ca"}, {"text": "to, co nás Evropany spojuje, než to, co nás od", "label": "cs"}, {"text": "ergänzt durch einen gesetzlich verankertes", "label": "de"}, {"text": "Насчитал, что с начала года всего три дня были", "label": "ru"}, {"text": "Borisovu tražeći od njega da prihvati njenu", "label": "sr"}, {"text": "la presenza di ben veleni diversi: . chili di", "label": "it"}, {"text": "καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς", "label": "grc"}, {"text": "pretraživale obližnju bolnicu i stambene zgrade u", "label": "hr"}, {"text": "An rund Katzen habe Wolf seine Spiele getestet ,", "label": "de"}, {"text": "investigating since March.", "label": "en"}, {"text": "Tonböden (Mullböden).", "label": "de"}, {"text": "Stálý dopisovatel LN v SRN Bedřich Utitz", "label": "cs"}, {"text": "červnu předložené smlouvy.", "label": "cs"}, {"text": "πνεύματι ᾧ ἐλάλει", "label": "grc"}, {"text": ".%의 신장세를 보였다.", "label": "ko"}, {"text": "Foae verde, foi de nuc, Prin pădure, prin colnic,", "label": "ro"}, {"text": "διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι", "label": "grc"}, {"text": "المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في", "label": "ar"}, {"text": "As informações são da Dow Jones.", "label": "pt"}, {"text": "Milliarde DM ausgestattet sein .", "label": "de"}, {"text": "De utgår fortfarande från att kvinnans jämlikhet", "label": "sv"}, {"text": "Sneeuw maakte in Davos bij de voorbereiding een", "label": "nl"}, {"text": "De ahí que en este mercado puedan negociarse", "label": "es"}, {"text": "intenzívnějšímu sbírání a studiu.", "label": "cs"}, {"text": "और औसकर ४.० पैकेज का प्रयोग किया गया है ।", "label": "hi"}, {"text": "Adipati Kuningan karena Kuningan menjadi bagian", "label": "id"}, {"text": "Svako je bar jednom poželeo da mašine prosto umeju", "label": "sr"}, {"text": "Im vergangenen Jahr haben die Regierungen einen", "label": "de"}, {"text": "durat motus, aliquid fit et non est;", "label": "la"}, {"text": "Dominować będą piosenki do 
tekstów Edwarda", "label": "pl"}, {"text": "beantwortet .", "label": "de"}, {"text": "О гуманитариях было кому рассказывать, а вот за", "label": "ru"}, {"text": "Helsingin kaupunki riitautti vuokrasopimuksen", "label": "fi"}, {"text": "chợt tan biến.", "label": "vi"}, {"text": "avtomobil ločuje od drugih.", "label": "sl"}, {"text": "Congress has proven itself ineffective as a body.", "label": "en"}, {"text": "मैक्सिको ने इस तरह का शो इस समय आयोजित करने का", "label": "hi"}, {"text": "No minimum order amount.", "label": "en"}, {"text": "Convertassa .", "label": "fi"}, {"text": "Как это можно сделать?", "label": "ru"}, {"text": "tha mi creidsinn gu robh iad ceart cho saor shuas", "label": "gd"}, {"text": "실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고", "label": "ko"}, {"text": "Da un semplice richiamo all'ordine fino a grandi", "label": "it"}, {"text": "pozoruhodný nejen po umělecké stránce, jež", "label": "cs"}, {"text": "La comida y el servicio aprueban.", "label": "es"}, {"text": "again, connected not with each other but to the", "label": "en"}, {"text": "Protokol výslovně stanoví, že nikdo nemůže být", "label": "cs"}, {"text": "ఒక విషయం అడగాలని ఉంది .", "label": "te"}, {"text": "Безгранично почитая дирекцию, ловя на лету каждое", "label": "ru"}, {"text": "rovnoběžných růstových vrstev, zůstávají krychlové", "label": "cs"}, {"text": "प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री", "label": "hi"}, {"text": "Bronzen medaille in de Europese marathon.", "label": "nl"}, {"text": "- gadu vecumā viņi to nesaprot.", "label": "lv"}, {"text": "Realizó sus estudios primarios en la Escuela Julia", "label": "es"}, {"text": "cuartos de final, su clasificación para la final a", "label": "es"}, {"text": "Sem si pro něho přiletí americký raketoplán, na", "label": "cs"}, {"text": "Way to go!", "label": "en"}, {"text": "gehört der neuen SPD-Führung unter Parteichef", "label": "de"}, {"text": "Somit simuliert der Player mit einer GByte-Platte", "label": "de"}, {"text": "Berufung 
auf kommissionsnahe Kreise , die bereits", "label": "de"}, {"text": "Dist Clarïen", "label": "fro"}, {"text": "Schon nach den Gerüchten , die Telekom wolle den", "label": "de"}, {"text": "Software von NetObjects ist nach Angaben des", "label": "de"}, {"text": "si enim per legem iustitia ergo Christus gratis", "label": "la"}, {"text": "ducerent in ipsam magis quam in corpus christi,", "label": "la"}, {"text": "Neustar-Melbourne-IT-Partnerschaft NeuLevel .", "label": "de"}, {"text": "forderte dagegen seine drastische Verschärfung.", "label": "de"}, {"text": "pemmican på hundrede forskellige måder.", "label": "da"}, {"text": "Lehån, själv matematiklärare, visar hur den nya", "label": "sv"}, {"text": "I highly recommend his shop.", "label": "en"}, {"text": "verità, giovani fedeli prostratevi #amen", "label": "it"}, {"text": "उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार", "label": "hi"}, {"text": "() روزی مےں کشادگی ہوتی ہے۔", "label": "ur"}, {"text": "Prozessorgeschäft profitieren kann , stellen", "label": "de"}, {"text": "školy začalo počítat pytle s moukou a zjistilo, že", "label": "cs"}, {"text": "प्रभावशाली पर गैर सरकारी लोगों के घरों में भी", "label": "hi"}, {"text": "geschichtslos , oder eine Farce , wie sich", "label": "de"}, {"text": "Ústrednými mocnosťami v marci však spôsobilo, že", "label": "sk"}, {"text": "التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض", "label": "ar"}, {"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"}, {"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"}] docs = [Document([], text=example["text"]) for example in examples] gold_labels = [example["label"] for example in examples] basic_multilingual(docs) accuracy = sum([(doc.lang == label) for doc,label in zip(docs,gold_labels)])/len(docs) assert accuracy >= 0.98 def test_text_cleaning(basic_multilingual, clean_multilingual): """ Basic test of cleaning text """ docs = ["Bonjour le monde! #thisisfrench #ilovefrance", "Bonjour le monde! 
https://t.co/U0Zjp3tusD"] docs = [Document([], text=text) for text in docs] basic_multilingual(docs) assert [doc.lang for doc in docs] == ["it", "it"] assert clean_multilingual.processors["langid"]._clean_text clean_multilingual(docs) assert [doc.lang for doc in docs] == ["fr", "fr"] def test_emoji_cleaning(): TEXT = ["Sh'reyan has nice antennae :thumbs_up:", "This is🐱 a cat"] EXPECTED = ["Sh'reyan has nice antennae", "This is a cat"] for text, expected in zip(TEXT, EXPECTED): assert LangIDProcessor.clean_text(text) == expected def test_lang_subset(basic_multilingual, enfr_multilingual, en_multilingual): """ Basic test of restricting output to subset of languages """ docs = ["Bonjour le monde! #thisisfrench #ilovefrance", "Bonjour le monde! https://t.co/U0Zjp3tusD"] docs = [Document([], text=text) for text in docs] basic_multilingual(docs) assert [doc.lang for doc in docs] == ["it", "it"] assert enfr_multilingual.processors["langid"]._model.lang_subset == ["en", "fr"] enfr_multilingual(docs) assert [doc.lang for doc in docs] == ["fr", "fr"] assert en_multilingual.processors["langid"]._model.lang_subset == ["en"] en_multilingual(docs) assert [doc.lang for doc in docs] == ["en", "en"] def test_lang_subset_unlikely_language(en_multilingual): """ Test that the language subset masking chooses a legal language, even if all legal languages are supa unlikely """ sentences = ["你好" * 200] docs = [Document([], text=text) for text in sentences] en_multilingual(docs) assert [doc.lang for doc in docs] == ["en"] processor = en_multilingual.processors['langid'] model = processor._model text_tensor = processor._text_to_tensor(sentences) en_idx = model.tag_to_idx['en'] predictions = model(text_tensor) assert predictions[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. 
Update the test by picking a different combination of languages & input" ================================================ FILE: stanza/tests/langid/test_multilingual.py ================================================ """ Tests specifically for the MultilingualPipeline """ from collections import defaultdict import pytest from stanza.pipeline.multilingual import MultilingualPipeline from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=True, **kwargs): english_text = "This is an English sentence." english_words = ["This", "is", "an", "English", "sentence", "."] english_deps_gold = "\n".join(( "('This', 5, 'nsubj')", "('is', 5, 'cop')", "('an', 5, 'det')", "('English', 5, 'amod')", "('sentence', 0, 'root')", "('.', 5, 'punct')" )) if not en_has_dependencies: english_deps_gold = "" french_text = "C'est une phrase française." french_words = ["C'", "est", "une", "phrase", "française", "."] french_deps_gold = "\n".join(( "(\"C'\", 4, 'nsubj')", "('est', 4, 'cop')", "('une', 4, 'det')", "('phrase', 0, 'root')", "('française', 4, 'amod')", "('.', 4, 'punct')" )) if not fr_has_dependencies: french_deps_gold = "" if 'lang_configs' in kwargs: nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, **kwargs) else: lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse"}, "fr": {"processors": "tokenize,pos,lemma,depparse"}} nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, lang_configs=lang_configs, **kwargs) docs = [english_text, french_text] docs = nlp(docs) assert docs[0].lang == "en" assert len(docs[0].sentences) == 1 assert [x.text for x in docs[0].sentences[0].words] == english_words assert docs[0].sentences[0].dependencies_string() == english_deps_gold assert len(docs[1].sentences) == 1 assert docs[1].lang == "fr" assert [x.text for x in docs[1].sentences[0].words] == french_words assert 
docs[1].sentences[0].dependencies_string() == french_deps_gold def test_multilingual_pipeline(): """ Basic test of multilingual pipeline """ run_multilingual_pipeline() def test_multilingual_pipeline_small_cache(): """ Test with the cache size 1 """ run_multilingual_pipeline(max_cache_size=1) def test_multilingual_config(): """ Test with only tokenize for the EN pipeline """ lang_configs = { "en": {"processors": "tokenize"} } run_multilingual_pipeline(en_has_dependencies=False, lang_configs=lang_configs) def test_multilingual_processors_limited(): """ Test loading an available subset of processors """ run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors="tokenize") run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs={"en": {"processors": "tokenize,pos,lemma,depparse"}}, processors="tokenize") # this should not fail, as it will drop the zzzzzzzzzz processor for the languages which don't have it run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors="tokenize,zzzzzzzzzz") def test_defaultdict_config(): """ Test that you can pass in a defaultdict for the lang_configs argument """ lang_configs = defaultdict(lambda: dict(processors="tokenize")) run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs=lang_configs) lang_configs = defaultdict(lambda: dict(processors="tokenize")) lang_configs["en"] = {"processors": "tokenize,pos,lemma,depparse"} run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs=lang_configs) ================================================ FILE: stanza/tests/lemma/__init__.py ================================================ ================================================ FILE: stanza/tests/lemma/test_data.py ================================================ """ Test a couple basic data functions, such as processing a doc for its 
lemmas """ import pytest from stanza.models.common.doc import Document from stanza.models.lemma.data import DataLoader from stanza.utils.conll import CoNLL TRAIN_DATA = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003 # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad. 1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No 2 : : PUNCT : _ 1 punct 1:punct _ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _ 6 that that SCONJ IN _ 9 mark 9:mark _ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _ 8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _ 11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _ 15 in in ADP IN _ 16 case 16:case _ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No 17 . . PUNCT . _ 1 punct 1:punct _ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004 # text = Two of them were being run by 2 officials of the Ministry of the Interior! 
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _ 2 of of ADP IN _ 3 case 3:case _ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ 7 by by ADP IN _ 9 case 9:case _ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _ 10 of of ADP IN _ 12 case 12:case _ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _ 13 of of ADP IN _ 15 case 15:case _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No 16 ! ! PUNCT . _ 6 punct 6:punct _ """.lstrip() GOESWITH_DATA = """ # sent_id = email-enronsent27_01-0041 # newpar id = email-enronsent27_01-p0005 # text = Ken Rice@ENRON COMMUNICATIONS 1 Ken kenrice@enroncommunications X GW Typo=Yes 0 root 0:root _ 2 Rice@ENRON _ X GW _ 1 goeswith 1:goeswith _ 3 COMMUNICATIONS _ X ADD _ 1 goeswith 1:goeswith _ """.lstrip() CORRECT_FORM_DATA = """ # sent_id = weblog-blogspot.com_healingiraq_20040409053012_ENG_20040409_053012-0019 # text = They are targetting ambulances 1 They they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 3 nsubj 3:nsubj _ 2 are be AUX VBP Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _ 3 targetting target VERB VBG Tense=Pres|Typo=Yes|VerbForm=Part 0 root 0:root CorrectForm=targeting 4 ambulances ambulance NOUN NNS Number=Plur 3 obj 3:obj SpaceAfter=No """ BLANKS_DATA = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0018 # text = Guerrillas killed an engineer, Asi Ali, from Tikrit. 
1 Guerrillas _ NOUN NNS Number=Plur 2 nsubj 2:nsubj _ 2 killed _ VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 3 an a DET DT Definite=Ind|PronType=Art 4 det 4:det _ 4 engineer _ NOUN NN Number=Sing 2 obj 2:obj SpaceAfter=No """.lstrip() def test_load_document(): train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True) assert len(data) == 33 # meticulously counted by hand assert all(len(x) == 3 for x in data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False) assert len(data) == 33 assert all(len(x) == 3 for x in data) def test_load_goeswith(): raw_data = TRAIN_DATA + GOESWITH_DATA train_doc = CoNLL.conll2doc(input_str=raw_data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True) assert len(data) == 36 # will be the same as in test_load_document with three additional words assert all(len(x) == 3 for x in data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False) assert len(data) == 33 # will be the same as in test_load_document, but with the trailing 3 GOESWITH removed assert all(len(x) == 3 for x in data) def test_correct_form(): raw_data = TRAIN_DATA + CORRECT_FORM_DATA train_doc = CoNLL.conll2doc(input_str=raw_data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True) assert len(data) == 37 # the 'targeting' correction should not be applied if evaluation=True # when evaluation=False, then the CorrectForms will be applied assert not any(x[0] == 'targeting' for x in data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False) assert len(data) == 38 # the same, but with an extra row so the model learns both 'targetting' and 'targeting' assert any(x[0] == 'targeting' for x in data) assert any(x[0] == 'targetting' for x in data) def 
test_load_blank(): raw_data = TRAIN_DATA + BLANKS_DATA train_doc = CoNLL.conll2doc(input_str=raw_data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False) assert len(data) == 37 # will be the same as in test_load_document with FOUR additional words assert all(len(x) == 3 for x in data) data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=True, evaluation=False) assert len(data) == 34 # will be the same as in test_load_document, but one extra word is added. others were blank assert all(len(x) == 3 for x in data) ================================================ FILE: stanza/tests/lemma/test_lemma_trainer.py ================================================ """ Test a couple basic functions - load & save an existing model """ import pytest import glob import os import tempfile import torch from stanza.models import lemmatizer from stanza.models.lemma import trainer from stanza.tests import * from stanza.utils.training.common import choose_lemma_charlm, build_charlm_args pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def english_model(): models_path = os.path.join(TEST_MODELS_DIR, "en", "lemma", "*") models = glob.glob(models_path) # we expect at least one English model downloaded for the tests assert len(models) >= 1, "No English lemma models downloaded during setup! Please make sure to run the setup script." for model_file in models: if "nocharlm" in model_file: return trainer.Trainer(model_file=model_file) raise FileNotFoundError("Should have downloaded the nocharlm English lemmatizer during setup. 
Please rerun the setup script.") def test_load_model(english_model): """ Does nothing, just tests that loading works """ def test_save_load_model(english_model): """ Load, save, and load again """ with tempfile.TemporaryDirectory() as tempdir: save_file = os.path.join(tempdir, "resaved", "lemma.pt") english_model.save(save_file) reloaded = trainer.Trainer(model_file=save_file) TRAIN_DATA = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003 # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad. 1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No 2 : : PUNCT : _ 1 punct 1:punct _ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _ 6 that that SCONJ IN _ 9 mark 9:mark _ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _ 8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _ 11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _ 15 in in ADP IN _ 16 case 16:case _ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No 17 . . PUNCT . _ 1 punct 1:punct _ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004 # text = Two of them were being run by 2 officials of the Ministry of the Interior! 
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _ 2 of of ADP IN _ 3 case 3:case _ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ 7 by by ADP IN _ 9 case 9:case _ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _ 10 of of ADP IN _ 12 case 12:case _ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _ 13 of of ADP IN _ 15 case 15:case _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No 16 ! ! PUNCT . _ 6 punct 6:punct _ """.lstrip() DEV_DATA = """ 1 From from ADP IN _ 3 case 3:case _ 2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _ 3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _ 4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _ 5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _ 6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _ 7 : : PUNCT : _ 4 punct 4:punct _ """.lstrip() class TestLemmatizer: @pytest.fixture(scope="class") def charlm_args(self): charlm = choose_lemma_charlm("en", "test", "default") charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR) return charlm_args def run_training(self, tmp_path, train_text, dev_text, extra_args=None): """ Run the training for a few iterations, load & return the model """ pred_file = str(tmp_path / "pred.conllu") save_name = "test_tagger.pt" save_file = str(tmp_path / save_name) train_file = str(tmp_path / "train.conllu") with open(train_file, "w", encoding="utf-8") as fout: fout.write(train_text) dev_file = str(tmp_path / "dev.conllu") with open(dev_file, "w", 
encoding="utf-8") as fout: fout.write(dev_text) args = ["--train_file", train_file, "--eval_file", dev_file, "--output_file", pred_file, "--num_epoch", "2", "--log_step", "10", "--save_dir", str(tmp_path), "--save_name", save_name, "--shorthand", "en_test"] if extra_args is not None: args = args + extra_args lemmatizer.main(args) assert os.path.exists(save_file) saved_model = trainer.Trainer(model_file=save_file) return saved_model def test_basic_train(self, tmp_path): """ Simple test of a few 'epochs' of lemmatizer training """ self.run_training(tmp_path, TRAIN_DATA, DEV_DATA) def test_charlm_train(self, tmp_path, charlm_args): """ Simple test of a few 'epochs' of lemmatizer training """ saved_model = self.run_training(tmp_path, TRAIN_DATA, DEV_DATA, extra_args=charlm_args) # check that the charlm wasn't saved in here args = saved_model.args save_name = os.path.join(args['save_dir'], args['save_name']) checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True) assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys()) ================================================ FILE: stanza/tests/lemma/test_lowercase.py ================================================ import pytest from stanza.models.lemmatizer import all_lowercase from stanza.utils.conll import CoNLL LATIN_CONLLU = """ # sent_id = train-s1 # text = unde et philosophus dicit felicitatem esse operationem perfectam. 
# reference = ittb-scg-s4203 1 unde unde ADV O4 AdvType=Loc|PronType=Rel 4 advmod:lmod _ _ 2 et et CCONJ O4 _ 3 advmod:emph _ _ 3 philosophus philosophus NOUN B1|grn1|casA|gen1 Case=Nom|Gender=Masc|InflClass=IndEurO|Number=Sing 4 nsubj _ _ 4 dicit dico VERB N3|modA|tem1|gen6 Aspect=Imp|InflClass=LatX|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens 5 felicitatem felicitas NOUN C1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 7 nsubj _ _ 6 esse sum AUX N3|modH|tem1 Aspect=Imp|Tense=Pres|VerbForm=Inf 7 cop _ _ 7 operationem operatio NOUN C1|grn1|casD|gen2|vgr1 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 4 ccomp _ _ 8 perfectam perfectus ADJ A1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurA|Number=Sing 7 amod _ SpaceAfter=No 9 . . PUNCT Punc _ 4 punct _ _ # sent_id = train-s2 # text = perfectio autem operationis dependet ex quatuor. # reference = ittb-scg-s4204 1 perfectio perfectio NOUN C1|grn1|casA|gen2 Case=Nom|Gender=Fem|InflClass=IndEurX|Number=Sing 4 nsubj _ _ 2 autem autem PART O4 _ 4 discourse _ _ 3 operationis operatio NOUN C1|grn1|casB|gen2|vgr1 Case=Gen|Gender=Fem|InflClass=IndEurX|Number=Sing 1 nmod _ _ 4 dependet dependeo VERB K3|modA|tem1|gen6 Aspect=Imp|InflClass=LatE|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens 5 ex ex ADP S4|vgr2 _ 6 case _ _ 6 quatuor quattuor NUM G1|gen3|vgr1 NumForm=Word|NumType=Card 4 obl:arg _ SpaceAfter=No 7 . . PUNCT Punc _ 4 punct _ _ """.lstrip() ENG_CONLLU = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007 # text = You wonder if he was manipulating the market with his bombing targets. 
1 You you PRON PRP Case=Nom|Person=2|PronType=Prs 2 nsubj 2:nsubj _ 2 wonder wonder VERB VBP Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin 0 root 0:root _ 3 if if SCONJ IN _ 6 mark 6:mark _ 4 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 6 nsubj 6:nsubj _ 5 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 6 manipulating manipulate VERB VBG Tense=Pres|VerbForm=Part 2 ccomp 2:ccomp _ 7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _ 8 market market NOUN NN Number=Sing 6 obj 6:obj _ 9 with with ADP IN _ 12 case 12:case _ 10 his his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 12 nmod:poss 12:nmod:poss _ 11 bombing bombing NOUN NN Number=Sing 12 compound 12:compound _ 12 targets target NOUN NNS Number=Plur 6 obl 6:obl:with SpaceAfter=No 13 . . PUNCT . _ 2 punct 2:punct _ """.lstrip() def test_all_lowercase(): doc = CoNLL.conll2doc(input_str=LATIN_CONLLU) assert all_lowercase(doc) def test_not_all_lowercase(): doc = CoNLL.conll2doc(input_str=ENG_CONLLU) assert not all_lowercase(doc) ================================================ FILE: stanza/tests/lemma_classifier/__init__.py ================================================ ================================================ FILE: stanza/tests/lemma_classifier/test_data_preparation.py ================================================ import os import pytest import stanza.models.lemma_classifier.utils as utils import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier pytestmark = [pytest.mark.pipeline, pytest.mark.travis] EWT_ONE_SENTENCE = """ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002 # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002 # text = Here's a Miami Herald interview 1-2 Here's _ _ _ _ _ _ _ _ 1 Here here ADV RB PronType=Dem 0 root 0:root _ 2 's be AUX VBZ 
Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _ 3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ 4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _ 5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _ 6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _ """.lstrip() EWT_TRAIN_SENTENCES = """ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002 # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002 # text = Here's a Miami Herald interview 1-2 Here's _ _ _ _ _ _ _ _ 1 Here here ADV RB PronType=Dem 0 root 0:root _ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _ 3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ 4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _ 5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _ 6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0027 # text = But Posada's nearly 80 years old 1 But but CCONJ CC _ 7 cc 7:cc _ 2-3 Posada's _ _ _ _ _ _ _ _ 2 Posada Posada PROPN NNP Number=Sing 7 nsubj 7:nsubj _ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 cop 7:cop _ 4 nearly nearly ADV RB _ 5 advmod 5:advmod _ 5 80 80 NUM CD NumForm=Digit|NumType=Card 6 nummod 6:nummod _ 6 years year NOUN NNS Number=Plur 7 obl:npmod 7:obl:npmod _ 7 old old ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0067 # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0011 # text = Now that's a post I can relate to. 
1 Now now ADV RB _ 5 advmod 5:advmod _ 2-3 that's _ _ _ _ _ _ _ _ 2 that that PRON DT Number=Sing|PronType=Dem 5 nsubj 5:nsubj _ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _ 4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _ 5 post post NOUN NN Number=Sing 0 root 0:root _ 6 I I PRON PRP Case=Nom|Number=Sing|Person=1|PronType=Prs 8 nsubj 8:nsubj _ 7 can can AUX MD VerbForm=Fin 8 aux 8:aux _ 8 relate relate VERB VB VerbForm=Inf 5 acl:relcl 5:acl:relcl _ 9 to to ADP IN _ 8 obl 8:obl SpaceAfter=No 10 . . PUNCT . _ 5 punct 5:punct _ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0073 # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0012 # text = hey that's a great blog 1 hey hey INTJ UH _ 6 discourse 6:discourse _ 2-3 that's _ _ _ _ _ _ _ _ 2 that that PRON DT Number=Sing|PronType=Dem 6 nsubj 6:nsubj _ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ 4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ 5 great great ADJ JJ Degree=Pos 6 amod 6:amod _ 6 blog blog NOUN NN Number=Sing 0 root 0:root SpaceAfter=No # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0089 # text = And It's Not Hard To Do 1 And and CCONJ CC _ 5 cc 5:cc _ 2-3 It's _ _ _ _ _ _ _ _ 2 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 expl 5:expl _ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _ 4 Not not PART RB _ 5 advmod 5:advmod _ 5 Hard hard ADJ JJ Degree=Pos 0 root 0:root _ 6 To to PART TO _ 7 mark 7:mark _ 7 Do do VERB VB VerbForm=Inf 5 csubj 5:csubj SpaceAfter=No # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0029 # text = Meanwhile, a decision's been reached 1 Meanwhile meanwhile ADV RB _ 7 advmod 7:advmod SpaceAfter=No 2 , , PUNCT , _ 1 punct 1:punct _ 3 a a DET DT Definite=Ind|PronType=Art 4 det 4:det _ 4-5 
decision's _ _ _ _ _ _ _ _ 4 decision decision NOUN NN Number=Sing 7 nsubj:pass 7:nsubj:pass _ 5 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux 7:aux _ 6 been be AUX VBN Tense=Past|VerbForm=Part 7 aux:pass 7:aux:pass _ 7 reached reach VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0138 # text = It's become a guardian of morality 1-2 It's _ _ _ _ _ _ _ _ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|5:nsubj:xsubj _ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _ 3 become become VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ 4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _ 5 guardian guardian NOUN NN Number=Sing 3 xcomp 3:xcomp _ 6 of of ADP IN _ 7 case 7:case _ 7 morality morality NOUN NN Number=Sing 5 nmod 5:nmod:of _ # sent_id = email-enronsent15_01-0018 # text = It's got its own bathroom and tv 1-2 It's _ _ _ _ _ _ _ _ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|13:nsubj _ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _ 3 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ 4 its its PRON PRP$ Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 6 nmod:poss 6:nmod:poss _ 5 own own ADJ JJ Degree=Pos 6 amod 6:amod _ 6 bathroom bathroom NOUN NN Number=Sing 3 obj 3:obj _ 7 and and CCONJ CC _ 8 cc 8:cc _ 8 tv TV NOUN NN Number=Sing 6 conj 3:obj|6:conj:and SpaceAfter=No # sent_id = newsgroup-groups.google.com_alt.animals.cat_01ff709c4bf2c60c_ENG_20040418_040100-0022 # text = It's also got the website 1-2 It's _ _ _ _ _ _ _ _ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _ 3 also also ADV RB _ 4 advmod 4:advmod _ 4 got get VERB VBN 
Tense=Past|VerbForm=Part 0 root 0:root _ 5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _ 6 website website NOUN NN Number=Sing 4 obj 4:obj|12:obl _ """.lstrip() # from the train set, actually EWT_DEV_SENTENCES = """ # sent_id = answers-20111108104724AAuBUR7_ans-0044 # text = He's only exhibited weight loss and some muscle atrophy 1-2 He's _ _ _ _ _ _ _ _ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _ 3 only only ADV RB _ 4 advmod 4:advmod _ 4 exhibited exhibit VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ 5 weight weight NOUN NN Number=Sing 6 compound 6:compound _ 6 loss loss NOUN NN Number=Sing 4 obj 4:obj _ 7 and and CCONJ CC _ 10 cc 10:cc _ 8 some some DET DT PronType=Ind 10 det 10:det _ 9 muscle muscle NOUN NN Number=Sing 10 compound 10:compound _ 10 atrophy atrophy NOUN NN Number=Sing 6 conj 4:obj|6:conj:and SpaceAfter=No # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0097 # text = It's a good thing too. 1-2 It's _ _ _ _ _ _ _ _ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _ 3 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _ 4 good good ADJ JJ Degree=Pos 5 amod 5:amod _ 5 thing thing NOUN NN Number=Sing 0 root 0:root _ 6 too too ADV RB _ 5 advmod 5:advmod SpaceAfter=No 7 . . PUNCT . _ 5 punct 5:punct _ """.lstrip() # from the train set, actually EWT_TEST_SENTENCES = """ # sent_id = reviews-162422-0015 # text = He said he's had a long and bad day. 
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 2 nsubj 2:nsubj _ 2 said say VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 3-4 he's _ _ _ _ _ _ _ _ 3 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _ 4 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux 5:aux _ 5 had have VERB VBN Tense=Past|VerbForm=Part 2 ccomp 2:ccomp _ 6 a a DET DT Definite=Ind|PronType=Art 10 det 10:det _ 7 long long ADJ JJ Degree=Pos 10 amod 10:amod _ 8 and and CCONJ CC _ 9 cc 9:cc _ 9 bad bad ADJ JJ Degree=Pos 7 conj 7:conj:and|10:amod _ 10 day day NOUN NN Number=Sing 5 obj 5:obj SpaceAfter=No 11 . . PUNCT . _ 2 punct 2:punct _ # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0100 # text = What's a few dead soldiers 1-2 What's _ _ _ _ _ _ _ _ 1 What what PRON WP PronType=Int 6 nsubj 6:nsubj _ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ 3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ 4 few few ADJ JJ Degree=Pos 6 amod 6:amod _ 5 dead dead ADJ JJ Degree=Pos 6 amod 6:amod _ 6 soldiers soldier NOUN NNS Number=Plur 0 root 0:root _ """ def write_test_dataset(tmp_path, texts, datasets): ud_path = tmp_path / "ud" input_path = ud_path / "UD_English-EWT" output_path = tmp_path / "data" / "lemma_classifier" os.makedirs(input_path, exist_ok=True) for text, dataset in zip(texts, datasets): sample_file = input_path / ("en_ewt-ud-%s.conllu" % dataset) with open(sample_file, "w", encoding="utf-8") as fout: fout.write(text) paths = {"UDBASE": ud_path, "LEMMA_CLASSIFIER_DATA_DIR": output_path} return paths def write_english_test_dataset(tmp_path): texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES) datasets = prepare_lemma_classifier.SECTIONS return write_test_dataset(tmp_path, texts, datasets) def convert_english_dataset(tmp_path): paths = write_english_test_dataset(tmp_path) converted_files = 
prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have") assert len(converted_files) == 3 return converted_files def test_convert_one_sentence(tmp_path): texts = [EWT_ONE_SENTENCE] datasets = ["train"] paths = write_test_dataset(tmp_path, texts, datasets) converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"]) assert len(converted_files) == 1 dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False) assert len(dataset) == 1 assert dataset.label_decoder == {'be': 0} id_to_upos = {y: x for x, y in dataset.upos_to_id.items()} for text_batches, _, upos_batches, _ in dataset: assert text_batches == [['Here', "'s", 'a', 'Miami', 'Herald', 'interview']] upos = [id_to_upos[x] for x in upos_batches[0]] assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN'] def test_convert_dataset(tmp_path): converted_files = convert_english_dataset(tmp_path) dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False) assert len(dataset) == 1 label_decoder = dataset.label_decoder assert len(label_decoder) == 2 assert "be" in label_decoder assert "have" in label_decoder for text_batches, _, _, _ in dataset: assert len(text_batches) == 9 dataset = utils.Dataset(converted_files[1], get_counts=True, batch_size=10, shuffle=False) assert len(dataset) == 1 for text_batches, _, _, _ in dataset: assert len(text_batches) == 2 dataset = utils.Dataset(converted_files[2], get_counts=True, batch_size=10, shuffle=False) assert len(dataset) == 1 for text_batches, _, _, _ in dataset: assert len(text_batches) == 2 ================================================ FILE: stanza/tests/lemma_classifier/test_training.py ================================================ import glob import os import pytest pytestmark = [pytest.mark.pipeline, pytest.mark.travis] from stanza.models.lemma_classifier import train_lstm_model from stanza.models.lemma_classifier import 
train_transformer_model from stanza.models.lemma_classifier.base_model import LemmaClassifier from stanza.models.lemma_classifier.evaluate_models import evaluate_model from stanza.tests import TEST_WORKING_DIR from stanza.tests.lemma_classifier.test_data_preparation import convert_english_dataset @pytest.fixture(scope="module") def pretrain_file(): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' def test_train_lstm(tmp_path, pretrain_file): converted_files = convert_english_dataset(tmp_path) save_name = str(tmp_path / 'lemma.pt') train_file = converted_files[0] eval_file = converted_files[1] train_args = ['--wordvec_pretrain_file', pretrain_file, '--save_name', save_name, '--train_file', train_file, '--eval_file', eval_file] trainer = train_lstm_model.main(train_args) evaluate_model(trainer.model, eval_file) # test that loading the model works model = LemmaClassifier.load(save_name, None) def test_train_transformer(tmp_path, pretrain_file): converted_files = convert_english_dataset(tmp_path) save_name = str(tmp_path / 'lemma.pt') train_file = converted_files[0] eval_file = converted_files[1] train_args = ['--bert_model', 'hf-internal-testing/tiny-bert', '--save_name', save_name, '--train_file', train_file, '--eval_file', eval_file] trainer = train_transformer_model.main(train_args) evaluate_model(trainer.model, eval_file) # test that loading the model works model = LemmaClassifier.load(save_name, None) ================================================ FILE: stanza/tests/morphseg/__init__.py ================================================ ================================================ FILE: stanza/tests/morphseg/conftest.py ================================================ """ Shared pytest fixtures and configuration """ import pytest from morphseg import MorphemeSegmenter pytestmark = [pytest.mark.travis, pytest.mark.pipeline] @pytest.fixture(scope="session") def english_segmenter(): """ Load English segmenter once for the entire test session """ return 
MorphemeSegmenter('en') @pytest.fixture(scope="session") def all_segmenters(): """ Load all supported language segmenters """ segmenters = {} for lang in MorphemeSegmenter.PRETRAINED_MODEL_LANGS: segmenters[lang] = MorphemeSegmenter(lang) return segmenters def pytest_configure(config): """ Custom pytest configuration """ config.addinivalue_line( "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')" ) config.addinivalue_line( "markers", "multilingual: marks tests that test multiple languages" ) ================================================ FILE: stanza/tests/morphseg/test_integration.py ================================================ """ Integration tests for morphseg """ import pytest from morphseg import MorphemeSegmenter pytestmark = [pytest.mark.travis, pytest.mark.pipeline] class TestIntegration: def test_full_pipeline(self): """Test complete segmentation pipeline""" segmenter = MorphemeSegmenter('en') text = "According to all known laws of aviation, there is no way a bee should be able to fly." 
result = segmenter.segment(text, output_string=False) # Should segment multiple words assert len(result) > 10 # Each word should have at least one morpheme for word_morphemes in result: assert len(word_morphemes) >= 1 def test_consistency_across_modes(self): """Test that list and string output modes are consistent""" segmenter = MorphemeSegmenter('en') words = ['running', 'dogs', 'aviation'] for word in words: list_result = segmenter.segment(word, output_string=False) string_result = segmenter.segment(word, output_string=True, delimiter=' @@') # String result should be reconstructable from list result expected_string = ' @@'.join(list_result[0]) assert string_result == expected_string, \ f"List and string outputs don't match for '{word}'" def test_unicode_handling(self): """Test handling of unicode characters""" segmenter = MorphemeSegmenter('fr') text = "café résumé" result = segmenter.segment(text, output_string=False) assert isinstance(result, list) assert len(result) >= 1 def test_mixed_case(self): """Test handling of mixed case input""" segmenter = MorphemeSegmenter('en') # Should normalize to lowercase result1 = segmenter.segment('Running', output_string=False) result2 = segmenter.segment('RUNNING', output_string=False) result3 = segmenter.segment('running', output_string=False) # All should produce the same result assert result1 == result2 == result3 ================================================ FILE: stanza/tests/morphseg/test_morpheme_segmenter.py ================================================ """ Tests for MorphemeSegmenter class """ import pytest from morphseg import MorphemeSegmenter pytestmark = [pytest.mark.travis, pytest.mark.pipeline] class TestMorphemeSegmenter: @pytest.fixture(scope="class") def english_segmenter(self): """Load English model once for all tests""" return MorphemeSegmenter('en') def test_basic_segmentation(self, english_segmenter): """Test basic morpheme segmentation""" result = english_segmenter.segment('running', 
output_string=False) assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], list) assert len(result[0]) >= 1 def test_multiple_words(self, english_segmenter): """Test segmentation of multiple words""" result = english_segmenter.segment('running quickly', output_string=False) assert isinstance(result, list) assert len(result) == 2 for segmentation in result: assert isinstance(segmentation, list) assert len(segmentation) >= 1 def test_known_segmentations(self, english_segmenter): """Test known morpheme segmentations""" test_cases = { 'dogs': ['dog', 's'], 'aviation': ['aviate', 'ion'], 'known': ['know', 'n'], } for word, expected in test_cases.items(): result = english_segmenter.segment(word, output_string=False) assert result[0] == expected, f"Expected {expected}, got {result[0]} for '{word}'" def test_output_string_mode(self, english_segmenter): """Test string output mode""" result = english_segmenter.segment('running quickly', output_string=True) assert isinstance(result, str) assert ' @@' in result # Default delimiter def test_custom_delimiter(self, english_segmenter): """Test custom delimiter in output""" result = english_segmenter.segment('running', output_string=True, delimiter='-') assert isinstance(result, str) assert '-' in result or result == 'running' # May be unsegmented def test_empty_input(self, english_segmenter): """Test handling of empty input""" result = english_segmenter.segment('', output_string=False) assert result == [] result = english_segmenter.segment('', output_string=True) assert result == "" def test_single_character(self, english_segmenter): """Test single character input""" result = english_segmenter.segment('a', output_string=False) assert isinstance(result, list) assert len(result) == 1 assert result[0] == ['a'] def test_punctuation(self, english_segmenter): """Test handling of punctuation""" result = english_segmenter.segment('Hello, world!', output_string=False) assert isinstance(result, list) # Should 
segment only words, not punctuation assert len(result) > 0 class TestDeterminism: """ Tests to ensure predictions are deterministic """ def test_deterministic_predictions(self): """Test that same input produces same output consistently""" segmenter = MorphemeSegmenter('en') test_words = ['running', 'dogs', 'quickly', 'aviation'] for word in test_words: results = [] for _ in range(5): result = segmenter.segment(word, output_string=False) results.append(result) # All results should be identical for i in range(1, len(results)): assert results[i] == results[0], \ f"Non-deterministic results for '{word}': {results[0]} vs {results[i]}" def test_deterministic_batch(self): """Test determinism with batch processing""" segmenter = MorphemeSegmenter('en') text = "The dogs are running quickly through the fields." results = [] for _ in range(3): result = segmenter.segment(text, output_string=False) results.append(result) # All results should be identical for i in range(1, len(results)): assert results[i] == results[0], \ f"Non-deterministic batch results: {results[0]} vs {results[i]}" class TestMultilingual: @pytest.mark.parametrize("lang", ['cs', 'en', 'es', 'fr', 'hu', 'it', 'la', 'ru']) def test_language_loading(self, lang): """Test that all supported languages can be loaded""" segmenter = MorphemeSegmenter(lang) assert segmenter.lang == lang assert segmenter.sequence_labeller is not None @pytest.mark.parametrize("lang,word", [ ('en', 'running'), ('es', 'corriendo'), ('fr', 'rapidement'), ('ru', 'бегущий'), # Russian instead of German ]) def test_multilingual_segmentation(self, lang, word): """Test segmentation across languages""" if lang not in MorphemeSegmenter.PRETRAINED_MODEL_LANGS: pytest.skip(f"Language {lang} not supported") segmenter = MorphemeSegmenter(lang) result = segmenter.segment(word, output_string=False) assert isinstance(result, list) assert len(result) >= 1 class TestErrorHandling: def test_invalid_language(self): """Test handling of invalid language 
code""" with pytest.warns(UserWarning): segmenter = MorphemeSegmenter('invalid_lang') assert segmenter.sequence_labeller is None def test_invalid_input_type(self): """Test handling of invalid input types""" segmenter = MorphemeSegmenter('en') with pytest.raises(ValueError, match="Input sequence must be a string"): segmenter.segment(123) with pytest.raises(ValueError, match="Input sequence must be a string"): segmenter.segment(['not', 'a', 'string']) def test_invalid_output_string_type(self): """Test handling of invalid output_string parameter""" segmenter = MorphemeSegmenter('en') with pytest.raises(ValueError, match="output_string must be a boolean"): segmenter.segment('test', output_string='yes') def test_invalid_delimiter_type(self): """Test handling of invalid delimiter parameter""" segmenter = MorphemeSegmenter('en') with pytest.raises(ValueError, match="Delimiter must be a string"): segmenter.segment('test', delimiter=123) def test_model_not_trained(self): """Test error when using untrained model""" segmenter = MorphemeSegmenter('en') segmenter.sequence_labeller = None with pytest.raises(RuntimeError, match="Model not trained"): segmenter.segment('test') class TestModelState: """ Tests to ensure model is in correct state """ def test_model_in_eval_mode(self): """Test that loaded model is in eval mode""" segmenter = MorphemeSegmenter('en') # Check that model is in eval mode assert not segmenter.sequence_labeller.model.model.training, \ "Model should be in eval mode after loading" def test_model_stays_in_eval_mode(self): """Test that model stays in eval mode after predictions""" segmenter = MorphemeSegmenter('en') # Make several predictions for _ in range(3): segmenter.segment('running', output_string=False) # Model should still be in eval mode assert not segmenter.sequence_labeller.model.model.training, \ "Model should remain in eval mode after predictions" ================================================ FILE: stanza/tests/morphseg/test_stanza_integration.py 
================================================ """ Integration tests for Stanza MorphSeg Processor Tests the morpheme segmentation processor within the Stanza pipeline """ import pytest import stanza from stanza.models.common.doc import Document from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.travis, pytest.mark.pipeline] class TestMorphSegProcessor: """Tests for the MorphSeg processor in Stanza pipeline""" @pytest.fixture(scope="class") def en_pipeline(self): """Create English pipeline with morphseg processor""" return stanza.Pipeline( lang='en', processors='tokenize,morphseg', model_dir=TEST_MODELS_DIR, download_method=None ) def test_processor_loads(self, en_pipeline): """Test that morphseg processor loads successfully""" assert 'morphseg' in en_pipeline.processors assert en_pipeline.processors['morphseg'] is not None def test_basic_segmentation(self, en_pipeline): """Test basic morpheme segmentation through pipeline""" doc = en_pipeline("running") assert len(doc.sentences) == 1 assert len(doc.sentences[0].words) == 1 word = doc.sentences[0].words[0] assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) assert len(word.morphemes) >= 1 def test_known_segmentations(self, en_pipeline): """Test known morpheme segmentations""" # Note: These are actual segmentations from the en2 model # Some words may be unsegmented depending on the model test_cases = { 'dogs': ['dog', 's'], 'aviation': ['aviate', 'ion'], 'known': ['know', 'n'], } for word_text, expected in test_cases.items(): doc = en_pipeline(word_text) word = doc.sentences[0].words[0] assert word.morphemes == expected, \ f"Expected {expected}, got {word.morphemes} for '{word_text}'" def test_segmentation_consistency(self, en_pipeline): """Test that segmentation is consistent and produces valid output""" words = ['running', 'quickly', 'walked', 'playing'] for word_text in words: doc = en_pipeline(word_text) word = doc.sentences[0].words[0] # Should have morphemes attribute 
assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) assert len(word.morphemes) >= 1 # All morphemes should be strings for morpheme in word.morphemes: assert isinstance(morpheme, str) assert len(morpheme) > 0 def test_multiple_words(self, en_pipeline): """Test segmentation of multiple words in a sentence""" doc = en_pipeline("The dogs are running quickly.") # Check that all words have morphemes attribute for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) assert len(word.morphemes) >= 1 def test_punctuation_handling(self, en_pipeline): """Test that punctuation is handled correctly""" doc = en_pipeline("Hello, world!") # All tokens should have morphemes, including punctuation for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') # Punctuation should be unsegmented if word.text in [',', '!', '.']: assert word.morphemes == [word.text] def test_long_text(self, en_pipeline): """Test processing of longer text""" text = "According to all known laws of aviation, there is no way a bee should be able to fly." 
doc = en_pipeline(text) # Should have multiple sentences or one long sentence assert len(doc.sentences) >= 1 # Count words with morpheme segmentation total_words = sum(len(sent.words) for sent in doc.sentences) assert total_words > 10 # All words should have morphemes for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') def test_empty_input(self, en_pipeline): """Test handling of empty input""" doc = en_pipeline("") assert len(doc.sentences) == 0 def test_single_character(self, en_pipeline): """Test single character input""" doc = en_pipeline("I") assert len(doc.sentences) == 1 word = doc.sentences[0].words[0] assert word.morphemes == ['i'] # Normalized to lowercase def test_morphemes_attribute_persistence(self, en_pipeline): """Test that morphemes attribute persists through pipeline""" doc = en_pipeline("running quickly") # Store morphemes morphemes_list = [] for sentence in doc.sentences: for word in sentence.words: morphemes_list.append(word.morphemes) # Access again to ensure persistence for i, sentence in enumerate(doc.sentences): for j, word in enumerate(sentence.words): assert hasattr(word, 'morphemes') assert word.morphemes is not None class TestMultilingualMorphSeg: """Test morpheme segmentation across different languages""" @pytest.mark.parametrize("lang,text,expected_word", [ ('en', 'running', 'running'), ('es', 'corriendo', 'corriendo'), ('fr', 'rapidement', 'rapidement'), ('cs', 'běžící', 'běžící'), ('it', 'correndo', 'correndo'), ]) def test_multilingual_support(self, lang, text, expected_word): """Test that different languages can be processed""" try: nlp = stanza.Pipeline( lang=lang, processors='tokenize,morphseg', download_method=None ) doc = nlp(text) assert len(doc.sentences) >= 1 assert len(doc.sentences[0].words) >= 1 word = doc.sentences[0].words[0] assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) except Exception as e: pytest.skip(f"Language {lang} not available: {e}") class 
TestMorphSegWithOtherProcessors: """Test morphseg processor in combination with other processors""" def test_with_mwt(self): """Test morphseg with MWT processor""" nlp = stanza.Pipeline( lang='en', processors='tokenize,mwt,morphseg', download_method=None ) doc = nlp("The dogs are running.") for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') def test_with_pos(self): """Test morphseg with POS tagging""" try: nlp = stanza.Pipeline( lang='en', processors='tokenize,pos,morphseg', download_method=None ) doc = nlp("running quickly") for sentence in doc.sentences: for word in sentence.words: # Should have both POS and morphemes assert hasattr(word, 'morphemes') assert hasattr(word, 'upos') or hasattr(word, 'xpos') except Exception as e: pytest.skip(f"POS processor not available: {e}") def test_with_lemma(self): """Test morphseg with lemmatization""" try: nlp = stanza.Pipeline( lang='en', processors='tokenize,pos,lemma,morphseg', download_method=None ) doc = nlp("The dogs were running quickly.") for sentence in doc.sentences: for word in sentence.words: # Should have both lemma and morphemes assert hasattr(word, 'morphemes') assert hasattr(word, 'lemma') except Exception as e: pytest.skip(f"Lemma processor not available: {e}") class TestMorphSegDeterminism: """Test that morphseg processor produces deterministic results""" def test_deterministic_results(self): """Test that same input produces same output""" nlp = stanza.Pipeline( lang='en', processors='tokenize,morphseg', download_method=None ) text = "running dogs aviation" results = [] for _ in range(3): doc = nlp(text) morphemes = [word.morphemes for sent in doc.sentences for word in sent.words] results.append(morphemes) # All results should be identical for i in range(1, len(results)): assert results[i] == results[0], \ f"Non-deterministic results: {results[0]} vs {results[i]}" def test_batch_determinism(self): """Test determinism with batch processing""" nlp = stanza.Pipeline( 
lang='en', processors='tokenize,morphseg', download_method=None ) texts = [ "The dogs are running.", "Aviation is amazing.", "Known facts are helpful." ] # Process multiple times all_results = [] for _ in range(2): batch_results = [] for text in texts: doc = nlp(text) morphemes = [word.morphemes for sent in doc.sentences for word in sent.words] batch_results.append(morphemes) all_results.append(batch_results) # Results should be identical assert all_results[0] == all_results[1] class TestMorphSegEdgeCases: """Test edge cases and special inputs""" @pytest.fixture(scope="class") def en_pipeline(self): """Create English pipeline""" return stanza.Pipeline( lang='en', processors='tokenize,morphseg', download_method=None ) def test_numbers(self, en_pipeline): """Test handling of numbers""" doc = en_pipeline("123 456") for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') def test_mixed_case(self, en_pipeline): """Test mixed case handling""" # Should normalize to same result doc1 = en_pipeline("Running") doc2 = en_pipeline("RUNNING") doc3 = en_pipeline("running") morphemes1 = doc1.sentences[0].words[0].morphemes morphemes2 = doc2.sentences[0].words[0].morphemes morphemes3 = doc3.sentences[0].words[0].morphemes assert morphemes1 == morphemes2 == morphemes3 def test_unicode_characters(self, en_pipeline): """Test handling of unicode characters""" doc = en_pipeline("café résumé") for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) def test_special_characters(self, en_pipeline): """Test handling of special characters""" doc = en_pipeline("test@example.com $100 50%") for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') def test_very_long_word(self, en_pipeline): """Test handling of very long words""" long_word = "antidisestablishmentarianism" doc = en_pipeline(long_word) word = doc.sentences[0].words[0] assert 
hasattr(word, 'morphemes') assert len(word.morphemes) >= 1 def test_repeated_words(self, en_pipeline): """Test handling of repeated words""" doc = en_pipeline("running running running") # All instances should have same segmentation morphemes_list = [word.morphemes for word in doc.sentences[0].words] assert morphemes_list[0] == morphemes_list[1] == morphemes_list[2] def test_whitespace_handling(self, en_pipeline): """Test handling of various whitespace""" doc = en_pipeline("word1 word2\tword3\nword4") # Should properly segment all words despite whitespace word_count = sum(len(sent.words) for sent in doc.sentences) assert word_count >= 4 for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') class TestMorphSegConfiguration: """Test different configurations of morphseg processor""" def test_custom_model_path(self): """Test loading with custom model path configuration""" # Test that the configuration accepts model_path parameter # Using default behavior (no custom path) nlp = stanza.Pipeline( lang='en', processors='tokenize,morphseg', download_method=None ) doc = nlp("testing") assert len(doc.sentences) > 0 assert hasattr(doc.sentences[0].words[0], 'morphemes') def test_custom_model_path_with_file(self): """Test loading with an actual custom model file path""" # This test would require a custom model file to exist # Skip if no custom model is available pytest.skip("Custom model path test requires a specific model file") def test_processor_requirements(self): """Test that morphseg requires tokenize""" # MorphSeg requires TOKENIZE processor nlp = stanza.Pipeline( lang='en', processors='tokenize,morphseg', download_method=None ) # Verify tokenize is present assert 'tokenize' in nlp.processors or 'tokenize' in str(nlp.processors) class TestMorphSegOutputFormat: """Test output format of morpheme segmentations""" @pytest.fixture(scope="class") def en_pipeline(self): """Create English pipeline""" return stanza.Pipeline( lang='en', 
processors='tokenize,morphseg', download_method=None ) def test_morphemes_is_list(self, en_pipeline): """Test that morphemes attribute is always a list""" doc = en_pipeline("The dogs are running quickly.") for sentence in doc.sentences: for word in sentence.words: assert isinstance(word.morphemes, list) def test_morphemes_are_strings(self, en_pipeline): """Test that all morphemes are strings""" doc = en_pipeline("The dogs are running quickly.") for sentence in doc.sentences: for word in sentence.words: for morpheme in word.morphemes: assert isinstance(morpheme, str) def test_morphemes_non_empty(self, en_pipeline): """Test that morphemes list is never empty""" doc = en_pipeline("The dogs are running quickly.") for sentence in doc.sentences: for word in sentence.words: assert len(word.morphemes) >= 1 def test_unsegmented_words(self, en_pipeline): """Test that unsegmented words have single morpheme""" # Words like 'the', 'is', 'a' typically don't segment doc = en_pipeline("The dog is a pet.") for sentence in doc.sentences: for word in sentence.words: # Even if unsegmented, should have the word itself as morpheme if len(word.morphemes) == 1: # The single morpheme should match the normalized word assert isinstance(word.morphemes[0], str) class TestMorphSegRepeatedly: """Test repeated processing of multiple documents""" def test_sequential_document_processing(self): """Test processing multiple documents one after another""" nlp = stanza.Pipeline( lang='en', processors='tokenize,morphseg', download_method=None ) texts = [ "The dogs are running.", "Aviation is fascinating.", "Programming requires patience." 
] for text in texts: doc = nlp(text) for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) def test_multi_sentence_document(self): """Test processing a document with multiple sentences (internal batching)""" nlp = stanza.Pipeline( lang='en', processors='tokenize,morphseg', download_method=None ) doc = nlp("The dogs are running. Aviation is fascinating. Programming requires patience.") assert len(doc.sentences) == 3 for sentence in doc.sentences: for word in sentence.words: assert hasattr(word, 'morphemes') assert isinstance(word.morphemes, list) ================================================ FILE: stanza/tests/mwt/__init__.py ================================================ ================================================ FILE: stanza/tests/mwt/test_character_classifier.py ================================================ import os import pytest from stanza.models import mwt_expander from stanza.models.mwt.character_classifier import CharacterClassifier from stanza.models.mwt.data import DataLoader from stanza.models.mwt.trainer import Trainer from stanza.utils.conll import CoNLL pytestmark = [pytest.mark.pipeline, pytest.mark.travis] ENG_TRAIN = """ # text = Elena's motorcycle tour 1-2 Elena's _ _ _ _ _ _ _ _ 1 Elena Elena PROPN NNP Number=Sing 4 nmod:poss 4:nmod:poss _ 2 's 's PART POS _ 1 case 1:case _ 3 motorcycle motorcycle NOUN NN Number=Sing 4 compound 4:compound _ 4 tour tour NOUN NN Number=Sing 0 root 0:root _ # text = women's reproductive health 1-2 women's _ _ _ _ _ _ _ _ 1 women woman NOUN NNS Number=Plur 4 nmod:poss 4:nmod:poss _ 2 's 's PART POS _ 1 case 1:case _ 3 reproductive reproductive ADJ JJ Degree=Pos 4 amod 4:amod _ 4 health health NOUN NN Number=Sing 0 root 0:root SpaceAfter=No # text = The Chernobyl Children's Project 1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _ 2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _ 3-4 Children's _ _ _ _ 
_ _ _ _ 3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _ 4 's 's PART POS _ 3 case 3:case _ 5 Project Project PROPN NNP Number=Sing 0 root 0:root _ """.lstrip() ENG_DEV = """ # text = The Chernobyl Children's Project 1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _ 2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _ 3-4 Children's _ _ _ _ _ _ _ _ 3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _ 4 's 's PART POS _ 3 case 3:case _ 5 Project Project PROPN NNP Number=Sing 0 root 0:root _ """.lstrip() def test_train(tmp_path): test_train = str(os.path.join(tmp_path, "en_test.train.conllu")) with open(test_train, "w") as fout: fout.write(ENG_TRAIN) test_dev = str(os.path.join(tmp_path, "en_test.dev.conllu")) with open(test_dev, "w") as fout: fout.write(ENG_DEV) test_output = str(os.path.join(tmp_path, "en_test.dev.pred.conllu")) model_name = "en_test_mwt.pt" args = [ "--data_dir", str(tmp_path), "--train_file", test_train, "--eval_file", test_dev, "--gold_file", test_dev, "--lang", "en", "--shorthand", "en_test", "--output_file", test_output, "--save_dir", str(tmp_path), "--save_name", model_name, "--num_epoch", "10", ] mwt_expander.main(args=args) model = Trainer(model_file=os.path.join(tmp_path, model_name)) assert model.model is not None assert isinstance(model.model, CharacterClassifier) doc = CoNLL.conll2doc(input_str=ENG_DEV) dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True) preds = [] for i, batch in enumerate(dataloader.to_loader()): assert i == 0 # there should only be one batch preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab) assert len(preds) == 1 # it is possible to make a version of the test where this happens almost every time # for example, running for 100 epochs makes the model succeed 30 times in a row # (never saw a failure) # but the one time that failure happened, it would be really annoying #assert preds[0] 
== "Children 's" ================================================ FILE: stanza/tests/mwt/test_english_corner_cases.py ================================================ """ Test a couple English MWT corner cases which might be more widely applicable to other MWT languages - unknown English character doesn't result in bizarre splits - Casing or CASING doesn't get lost in the dictionary lookup In the English UD datasets, the MWT are composed exactly of the subwords, so the MWT model should be chopping up the input text rather than generating new text. Furthermore, SHE'S and She's should be split "SHE 'S" and "She 's" respectively """ import pytest import stanza from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_mwt_unknown_char(): pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None) mwt_trainer = pipeline.processors['mwt']._trainer assert mwt_trainer.args['force_exact_pieces'] # find a letter 'i' which isn't in the training data # the MWT model should still recognize a possessive containing this letter assert "i" in mwt_trainer.vocab for letter in "ĩîíìī": if letter not in mwt_trainer.vocab: break else: raise AssertionError("Need to update the MWT test - all of the non-standard letters 'i' are now in the MWT vocab") word = "Jenn" + letter + "fer" possessive = word + "'s" text = "I wanna lick " + possessive + " antennae" doc = pipeline(text) assert doc.sentences[0].tokens[1].text == 'wanna' assert len(doc.sentences[0].tokens[1].words) == 2 assert "".join(x.text for x in doc.sentences[0].tokens[1].words) == 'wanna' assert doc.sentences[0].tokens[3].text == possessive assert len(doc.sentences[0].tokens[3].words) == 2 assert "".join(x.text for x in doc.sentences[0].tokens[3].words) == possessive def test_english_mwt_casing(): """ Test that for a word where the lowercase split is known, the correct casing is still used Once upon a time, the logic used in the MWT 
expander would split SHE'S -> she 's which is a very surprising tokenization to people expecting the original text in the output document """ pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None) mwt_trainer = pipeline.processors['mwt']._trainer for i in range(1, 20): # many test cases follow this pattern for some reason, # so we should proactively look for a test case which hasn't # made its way into the MWT dictionary unknown_name = "jennife" + "r" * i + "'s" if unknown_name not in mwt_trainer.expansion_dict and unknown_name.upper() not in mwt_trainer.expansion_dict: unknown_name = unknown_name.upper() break else: raise AssertionError("Need a new heuristic for the unknown word in the English MWT!") # this SHOULD show up in the expansion dict assert "she's" in mwt_trainer.expansion_dict, "Expected |she's| to be in the English MWT expansion dict... perhaps find a different test case" text = [x.text for x in pipeline("JENNIFER HAS NICE ANTENNAE").sentences[0].words] assert text == ['JENNIFER', 'HAS', 'NICE', 'ANTENNAE'] text = [x.text for x in pipeline(unknown_name + " GOT NICE ANTENNAE").sentences[0].words] assert text == [unknown_name[:-2], "'S", 'GOT', 'NICE', 'ANTENNAE'] text = [x.text for x in pipeline("SHE'S GOT NICE ANTENNAE").sentences[0].words] assert text == ['SHE', "'S", 'GOT', 'NICE', 'ANTENNAE'] text = [x.text for x in pipeline("She's GOT NICE ANTENNAE").sentences[0].words] assert text == ['She', "'s", 'GOT', 'NICE', 'ANTENNAE'] ================================================ FILE: stanza/tests/mwt/test_prepare_mwt.py ================================================ import pytest pytestmark = [pytest.mark.pipeline, pytest.mark.travis] from stanza.utils.datasets.prepare_mwt_treebank import check_mwt_composition SAMPLE_GOOD_TEXT = """ # sent_id = weblog-typepad.com_ripples_20040407125600_ENG_20040407_125600-0057 # text = The Chernobyl Children's Project (http://www.adiccp.org/home/default.asp) offers 
several ways to help the children of that region. 1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _ 2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _ 3-4 Children's _ _ _ _ _ _ _ _ 3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _ 4 's 's PART POS _ 3 case 3:case _ 5 Project Project PROPN NNP Number=Sing 9 nsubj 9:nsubj _ 6 ( ( PUNCT -LRB- _ 7 punct 7:punct SpaceAfter=No 7 http://www.adiccp.org/home/default.asp http://www.adiccp.org/home/default.asp X ADD _ 5 appos 5:appos SpaceAfter=No 8 ) ) PUNCT -RRB- _ 7 punct 7:punct _ 9 offers offer VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _ 10 several several ADJ JJ Degree=Pos 11 amod 11:amod _ 11 ways way NOUN NNS Number=Plur 9 obj 9:obj _ 12 to to PART TO _ 13 mark 13:mark _ 13 help help VERB VB VerbForm=Inf 11 acl 11:acl:to _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 children child NOUN NNS Number=Plur 13 obj 13:obj _ 16 of of ADP IN _ 18 case 18:case _ 17 that that DET DT Number=Sing|PronType=Dem 18 det 18:det _ 18 region region NOUN NN Number=Sing 15 nmod 15:nmod:of SpaceAfter=No 19 . . PUNCT . _ 9 punct 9:punct _ """.lstrip() SAMPLE_BAD_TEXT = """ # sent_id = weblog-typepad.com_ripples_20040407125600_ENG_20040407_125600-0057 # text = The Chernobyl Children's Project (http://www.adiccp.org/home/default.asp) offers several ways to help the children of that region. 
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _ 2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _ 3-4 Children's _ _ _ _ _ _ _ _ 3 Childrez Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _ 4 's 's PART POS _ 3 case 3:case _ 5 Project Project PROPN NNP Number=Sing 9 nsubj 9:nsubj _ 6 ( ( PUNCT -LRB- _ 7 punct 7:punct SpaceAfter=No 7 http://www.adiccp.org/home/default.asp http://www.adiccp.org/home/default.asp X ADD _ 5 appos 5:appos SpaceAfter=No 8 ) ) PUNCT -RRB- _ 7 punct 7:punct _ 9 offers offer VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _ 10 several several ADJ JJ Degree=Pos 11 amod 11:amod _ 11 ways way NOUN NNS Number=Plur 9 obj 9:obj _ 12 to to PART TO _ 13 mark 13:mark _ 13 help help VERB VB VerbForm=Inf 11 acl 11:acl:to _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 children child NOUN NNS Number=Plur 13 obj 13:obj _ 16 of of ADP IN _ 18 case 18:case _ 17 that that DET DT Number=Sing|PronType=Dem 18 det 18:det _ 18 region region NOUN NN Number=Sing 15 nmod 15:nmod:of SpaceAfter=No 19 . . PUNCT . 
_ 9 punct 9:punct _ """.lstrip() def test_check_mwt_composition(tmp_path): mwt_file = tmp_path / "good.mwt" with open(mwt_file, "w", encoding="utf-8") as fout: fout.write(SAMPLE_GOOD_TEXT) check_mwt_composition(mwt_file) mwt_file = tmp_path / "bad.mwt" with open(mwt_file, "w", encoding="utf-8") as fout: fout.write(SAMPLE_BAD_TEXT) with pytest.raises(ValueError): check_mwt_composition(mwt_file) ================================================ FILE: stanza/tests/mwt/test_utils.py ================================================ """ Test the MWT resplitting of preexisting tokens without word splits """ import pytest import stanza from stanza.models.mwt.utils import resplit_mwt from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def pipeline(): """ A reusable pipeline with the NER module """ return stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,mwt", package="gum") def test_resplit_keep_tokens(pipeline): """ Test splitting with enforced token boundaries """ tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]] doc = resplit_mwt(tokens, pipeline) assert len(doc.sentences) == 2 assert len(doc.sentences[0].tokens) == 4 assert len(doc.sentences[0].tokens[1].words) == 2 assert doc.sentences[0].tokens[1].words[0].text == "ca" assert doc.sentences[0].tokens[1].words[1].text == "n't" assert len(doc.sentences[1].tokens) == 2 # updated GUM MWT splits "I can't" into three segments # the way we want, "I - ca - n't" # previously it would split "I - can - 't" assert len(doc.sentences[1].tokens[0].words) == 3 assert doc.sentences[1].tokens[0].words[0].text == "I" assert doc.sentences[1].tokens[0].words[1].text == "ca" assert doc.sentences[1].tokens[0].words[2].text == "n't" def test_resplit_no_keep_tokens(pipeline): """ Test splitting without enforced token boundaries """ tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]] doc = resplit_mwt(tokens, pipeline, 
keep_tokens=False) assert len(doc.sentences) == 2 assert len(doc.sentences[0].tokens) == 4 assert len(doc.sentences[0].tokens[1].words) == 2 assert doc.sentences[0].tokens[1].words[0].text == "ca" assert doc.sentences[0].tokens[1].words[1].text == "n't" assert len(doc.sentences[1].tokens) == 3 assert len(doc.sentences[1].tokens[1].words) == 2 assert doc.sentences[1].tokens[1].words[0].text == "ca" assert doc.sentences[1].tokens[1].words[1].text == "n't" ================================================ FILE: stanza/tests/ner/__init__.py ================================================ ================================================ FILE: stanza/tests/ner/test_bsf_2_beios.py ================================================ """ Tests the conversion code for the lang_uk NER dataset """ import unittest from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo import pytest pytestmark = [pytest.mark.travis, pytest.mark.pipeline] class TestBsf2Beios(unittest.TestCase): def test_empty_markup(self): res = convert_bsf('', '') self.assertEqual('', res) def test_1line_markup(self): data = 'тележурналіст Василь' bsf_markup = 'T1 PERS 14 20 Василь' expected = '''тележурналіст O Василь S-PERS''' self.assertEqual(expected, convert_bsf(data, bsf_markup)) def test_1line_follow_markup(self): data = 'тележурналіст Василь .' bsf_markup = 'T1 PERS 14 20 Василь' expected = '''тележурналіст O Василь S-PERS . O''' self.assertEqual(expected, convert_bsf(data, bsf_markup)) def test_1line_2tok_markup(self): data = 'тележурналіст Василь Нагірний .' bsf_markup = 'T1 PERS 14 29 Василь Нагірний' expected = '''тележурналіст O Василь B-PERS Нагірний E-PERS . 
O''' self.assertEqual(expected, convert_bsf(data, bsf_markup)) def test_1line_Long_tok_markup(self): data = 'А в музеї Гуцульщини і Покуття можна ' bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття' expected = '''А O в O музеї B-ORG Гуцульщини I-ORG і I-ORG Покуття E-ORG можна O''' self.assertEqual(expected, convert_bsf(data, bsf_markup)) def test_2line_2tok_markup(self): data = '''тележурналіст Василь Нагірний . В івано-франківському видавництві «Лілея НВ» вийшла друком''' bsf_markup = '''T1 PERS 14 29 Василь Нагірний T2 ORG 67 75 Лілея НВ''' expected = '''тележурналіст O Василь B-PERS Нагірний E-PERS . O В O івано-франківському O видавництві O « O Лілея B-ORG НВ E-ORG » O вийшла O друком O''' self.assertEqual(expected, convert_bsf(data, bsf_markup)) def test_real_markup(self): data = '''Через напіввоєнний стан в Україні та збільшення телефонних терористичних погроз українці купуватимуть sim-карти тільки за паспортами . Про це повідомив начальник управління зв'язків зі ЗМІ адміністрації Держспецзв'язку Віталій Кукса . Він зауважив , що днями відомство опублікує проект змін до правил надання телекомунікаційних послуг , де будуть прописані норми ідентифікації громадян . Абонентів , які на сьогодні вже мають sim-карту , за словами Віталія Кукси , реєструватимуть , коли ті звертатимуться в службу підтримки свого оператора мобільного зв'язку . Однак мобільні оператори побоюються , що таке нововведення помітно зменшить продаж стартових пакетів , адже спеціалізовані магазини є лише у містах . Відтак купити сімку в невеликих населених пунктах буде неможливо . Крім того , нова процедура ідентифікації абонентів вимагатиме від операторів мобільного зв'язку додаткових витрат . - Близько 90 % українських абонентів - це абоненти передоплати . Якщо мова буде йти навіть про поетапну їх ідентифікацію , зробити це буде складно , довго і дорого . 
Мобільним операторам доведеться йти на чималі витрати , пов'язані з укладанням і зберіганням договорів , веденням баз даних , - розповіла « Економічній правді » начальник відділу зв'язків з громадськістю « МТС-Україна » Вікторія Рубан . ''' bsf_markup = '''T1 LOC 26 33 Україні T2 ORG 203 218 Держспецзв'язку T3 PERS 219 232 Віталій Кукса T4 PERS 449 462 Віталія Кукси T5 ORG 1201 1219 Економічній правді T6 ORG 1267 1278 МТС-Україна T7 PERS 1281 1295 Вікторія Рубан ''' expected = '''Через O напіввоєнний O стан O в O Україні S-LOC та O збільшення O телефонних O терористичних O погроз O українці O купуватимуть O sim-карти O тільки O за O паспортами O . O Про O це O повідомив O начальник O управління O зв'язків O зі O ЗМІ O адміністрації O Держспецзв'язку S-ORG Віталій B-PERS Кукса E-PERS . O Він O зауважив O , O що O днями O відомство O опублікує O проект O змін O до O правил O надання O телекомунікаційних O послуг O , O де O будуть O прописані O норми O ідентифікації O громадян O . O Абонентів O , O які O на O сьогодні O вже O мають O sim-карту O , O за O словами O Віталія B-PERS Кукси E-PERS , O реєструватимуть O , O коли O ті O звертатимуться O в O службу O підтримки O свого O оператора O мобільного O зв'язку O . O Однак O мобільні O оператори O побоюються O , O що O таке O нововведення O помітно O зменшить O продаж O стартових O пакетів O , O адже O спеціалізовані O магазини O є O лише O у O містах O . O Відтак O купити O сімку O в O невеликих O населених O пунктах O буде O неможливо O . O Крім O того O , O нова O процедура O ідентифікації O абонентів O вимагатиме O від O операторів O мобільного O зв'язку O додаткових O витрат O . O - O Близько O 90 O % O українських O абонентів O - O це O абоненти O передоплати O . O Якщо O мова O буде O йти O навіть O про O поетапну O їх O ідентифікацію O , O зробити O це O буде O складно O , O довго O і O дорого O . 
O Мобільним O операторам O доведеться O йти O на O чималі O витрати O , O пов'язані O з O укладанням O і O зберіганням O договорів O , O веденням O баз O даних O , O - O розповіла O « O Економічній B-ORG правді E-ORG » O начальник O відділу O зв'язків O з O громадськістю O « O МТС-Україна S-ORG » O Вікторія B-PERS Рубан E-PERS . O''' self.assertEqual(expected, convert_bsf(data, bsf_markup)) class TestBsf(unittest.TestCase): def test_empty_bsf(self): self.assertEqual(parse_bsf(''), []) def test_empty2_bsf(self): self.assertEqual(parse_bsf(' \n \n'), []) def test_1line_bsf(self): bsf = 'T1 PERS 103 118 Василь Нагірний' res = parse_bsf(bsf) expected = BsfInfo('T1', 'PERS', 103, 118, 'Василь Нагірний') self.assertEqual(len(res), 1) self.assertEqual(res, [expected]) def test_2line_bsf(self): bsf = '''T9 PERS 778 783 Карла T10 MISC 814 819 міста''' res = parse_bsf(bsf) expected = [BsfInfo('T9', 'PERS', 778, 783, 'Карла'), BsfInfo('T10', 'MISC', 814, 819, 'міста')] self.assertEqual(len(res), 2) self.assertEqual(res, expected) def test_multiline_bsf(self): bsf = '''T3 PERS 220 235 Андрієм Кіщуком T4 MISC 251 285 А . Kubler . Світло і тіні маестро T5 PERS 363 369 Кіблер''' res = parse_bsf(bsf) expected = [BsfInfo('T3', 'PERS', 220, 235, 'Андрієм Кіщуком'), BsfInfo('T4', 'MISC', 251, 285, '''А . Kubler . Світло і тіні маестро'''), BsfInfo('T5', 'PERS', 363, 369, 'Кіблер')] self.assertEqual(len(res), len(expected)) self.assertEqual(res, expected) if __name__ == '__main__': unittest.main() ================================================ FILE: stanza/tests/ner/test_bsf_2_iob.py ================================================ """ Tests the conversion code for the lang_uk NER dataset """ import unittest from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo import pytest pytestmark = [pytest.mark.travis, pytest.mark.pipeline] class TestBsf2Iob(unittest.TestCase): def test_1line_follow_markup_iob(self): data = 'тележурналіст Василь .' 
bsf_markup = 'T1 PERS 14 20 Василь' expected = '''тележурналіст O Василь B-PERS . O''' self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob')) def test_1line_2tok_markup_iob(self): data = 'тележурналіст Василь Нагірний .' bsf_markup = 'T1 PERS 14 29 Василь Нагірний' expected = '''тележурналіст O Василь B-PERS Нагірний I-PERS . O''' self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob')) def test_1line_Long_tok_markup_iob(self): data = 'А в музеї Гуцульщини і Покуття можна ' bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття' expected = '''А O в O музеї B-ORG Гуцульщини I-ORG і I-ORG Покуття I-ORG можна O''' self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob')) def test_2line_2tok_markup_iob(self): data = '''тележурналіст Василь Нагірний . В івано-франківському видавництві «Лілея НВ» вийшла друком''' bsf_markup = '''T1 PERS 14 29 Василь Нагірний T2 ORG 67 75 Лілея НВ''' expected = '''тележурналіст O Василь B-PERS Нагірний I-PERS . O В O івано-франківському O видавництві O « O Лілея B-ORG НВ I-ORG » O вийшла O друком O''' self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob')) def test_all_multiline_iob(self): data = '''його книжечка «А . Kubler . Світло і тіні маестро» . Причому''' bsf_markup = '''T4 MISC 15 49 А . Kubler . Світло і тіні маестро ''' expected = '''його O книжечка O « O А B-MISC . I-MISC Kubler I-MISC . I-MISC Світло I-MISC і I-MISC тіні I-MISC маестро I-MISC » O . 
O Причому O''' self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob')) if __name__ == '__main__': unittest.main() ================================================ FILE: stanza/tests/ner/test_combine_ner_datasets.py ================================================ import json import os import pytest pytestmark = [pytest.mark.travis, pytest.mark.pipeline] from stanza.models.common.doc import Document from stanza.tests.ner.test_ner_training import write_temp_file, EN_TRAIN_BIO, EN_DEV_BIO from stanza.utils.datasets.ner import combine_ner_datasets def test_combine(tmp_path): """ Test that if we write two short datasets and combine them, we get back one slightly longer dataset To simplify matters, we just use the same input text with longer amounts of text for each shard. """ SHARDS = ("train", "dev", "test") for s_num, shard in enumerate(SHARDS): t1_json = tmp_path / ("en_t1.%s.json" % shard) # eg, 1x, 2x, 3x the test data from test_ner_training write_temp_file(t1_json, "\n\n".join([EN_TRAIN_BIO] * (s_num + 1))) t2_json = tmp_path / ("en_t2.%s.json" % shard) write_temp_file(t2_json, "\n\n".join([EN_DEV_BIO] * (s_num + 1))) args = ["--output_dataset", "en_c", "en_t1", "en_t2", "--input_dir", str(tmp_path), "--output_dir", str(tmp_path)] combine_ner_datasets.main(args) for s_num, shard in enumerate(SHARDS): filename = tmp_path / ("en_c.%s.json" % shard) assert os.path.exists(filename) with open(filename, encoding="utf-8") as fin: doc = Document(json.load(fin)) assert len(doc.sentences) == (s_num + 1) * 3 ================================================ FILE: stanza/tests/ner/test_convert_amt.py ================================================ """ Test some of the functions used for converting an AMT json to a Stanza json """ import os import pytest import stanza from stanza.utils.datasets.ner import convert_amt from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] TEXT = "Jennifer Sh'reyan has lovely 
antennae." def fake_label(label, start_char, end_char): return {'label': label, 'startOffset': start_char, 'endOffset': end_char} LABELS = [ fake_label('Person', 0, 8), fake_label('Person', 9, 17), fake_label('Person', 0, 17), fake_label('Andorian', 0, 8), fake_label('Appendage', 29, 37), fake_label('Person', 1, 8), fake_label('Person', 0, 7), fake_label('Person', 0, 9), fake_label('Appendage', 29, 38), ] def fake_labels(*indices): return [LABELS[x] for x in indices] def fake_docs(*indices): return [(TEXT, fake_labels(*indices))] def test_remove_nesting(): """ Test a few orders on nested items to make sure the desired results are coming back """ # this should be unchanged result = convert_amt.remove_nesting(fake_docs(0, 1)) assert result == fake_docs(0, 1) # this should be returned sorted result = convert_amt.remove_nesting(fake_docs(0, 4, 1)) assert result == fake_docs(0, 1, 4) # this should just have one copy result = convert_amt.remove_nesting(fake_docs(0, 0)) assert result == fake_docs(0) # outer one preferred result = convert_amt.remove_nesting(fake_docs(0, 2)) assert result == fake_docs(2) result = convert_amt.remove_nesting(fake_docs(1, 2)) assert result == fake_docs(2) result = convert_amt.remove_nesting(fake_docs(5, 2)) assert result == fake_docs(2) # order doesn't matter result = convert_amt.remove_nesting(fake_docs(0, 4, 2)) assert result == fake_docs(2, 4) result = convert_amt.remove_nesting(fake_docs(2, 4, 0)) assert result == fake_docs(2, 4) # first one preferred result = convert_amt.remove_nesting(fake_docs(0, 3)) assert result == fake_docs(0) result = convert_amt.remove_nesting(fake_docs(3, 0)) assert result == fake_docs(3) def test_process_doc(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) def check_results(doc, *expected): ner = [x[1] for x in doc[0]] assert ner == list(expected) # test a standard case of all the values lining up doc = convert_amt.process_doc(TEXT, fake_labels(2, 4), nlp) 
check_results(doc, "B-Person", "I-Person", "O", "O", "B-Appendage", "O") # test a slightly wrong start index doc = convert_amt.process_doc(TEXT, fake_labels(5, 1, 4), nlp) check_results(doc, "B-Person", "B-Person", "O", "O", "B-Appendage", "O") # test a slightly wrong end index doc = convert_amt.process_doc(TEXT, fake_labels(6, 1, 4), nlp) check_results(doc, "B-Person", "B-Person", "O", "O", "B-Appendage", "O") # test a slightly wronger end index doc = convert_amt.process_doc(TEXT, fake_labels(7, 4), nlp) check_results(doc, "B-Person", "O", "O", "O", "B-Appendage", "O") # test a period at the end of a text - should not be captured doc = convert_amt.process_doc(TEXT, fake_labels(7, 8), nlp) check_results(doc, "B-Person", "O", "O", "O", "B-Appendage", "O") ================================================ FILE: stanza/tests/ner/test_convert_nkjp.py ================================================ import pytest import io import os import xml.etree.ElementTree as ET from stanza.utils.datasets.ner.convert_nkjp import MORPH_FILE, NER_FILE, extract_entities_from_subfolder, extract_entities_from_sentence, extract_unassigned_subfolder_entities pytestmark = [pytest.mark.pipeline, pytest.mark.travis] EXPECTED_ENTITIES = { '1-p': { '1.39-s': [{'ent_id': 'named_1.39-s_n1', 'index': 0, 'orth': 'Sił Zbrojnych', 'ner_type': 'orgName', 'ner_subtype': None, 'targets': ['1.37-seg', '1.38-seg']}], '1.56-s': [], '1.79-s': [] }, '2-p': { '2.30-s': [], '2.45-s': [] }, '3-p': { '3.70-s': [] } } @pytest.fixture(scope="module") def dataset(tmp_path_factory): dataset_path = tmp_path_factory.mktemp("nkjp_dataset") sample_path = dataset_path / "sample" os.mkdir(sample_path) ann_path = sample_path / NER_FILE with open(ann_path, "w", encoding="utf-8") as fout: fout.write(SAMPLE_ANN) morph_path = sample_path / MORPH_FILE with open(morph_path, "w", encoding="utf-8") as fout: fout.write(SAMPLE_MORPHO) return dataset_path EXPECTED_TOKENS = [ {'seg_id': '1.1-seg', 'i': 0, 'orth': '2', 'text': '2', 
'tag': '_', 'ner': 'O', 'ner_subtype': None}, {'seg_id': '1.37-seg', 'i': 36, 'orth': 'Sił', 'text': 'Sił', 'tag': '_', 'ner': 'B-orgName', 'ner_subtype': None}, {'seg_id': '1.38-seg', 'i': 37, 'orth': 'Zbrojnych', 'text': 'Zbrojnych', 'tag': '_', 'ner': 'I-orgName', 'ner_subtype': None}, ] def test_extract_entities_from_subfolder(dataset): entities = extract_entities_from_subfolder("sample", dataset) assert len(entities) == 1 assert len(entities['1-p']) == 1 assert len(entities['1-p']['1.39-s']) == 39 assert entities['1-p']['1.39-s']['1.1-seg'] == EXPECTED_TOKENS[0] assert entities['1-p']['1.39-s']['1.37-seg'] == EXPECTED_TOKENS[1] assert entities['1-p']['1.39-s']['1.38-seg'] == EXPECTED_TOKENS[2] def test_extract_unassigned(dataset): entities = extract_unassigned_subfolder_entities("sample", dataset) assert entities == EXPECTED_ENTITIES SENTENCE_SAMPLE = """ Sił Zbrojnych Siły Zbrojne """.strip() EMPTY_SENTENCE = """""" def test_extract_entities_from_sentence(): rt = ET.fromstring(SENTENCE_SAMPLE) entities = extract_entities_from_sentence(rt) assert entities == EXPECTED_ENTITIES['1-p']['1.39-s'] rt = ET.fromstring(EMPTY_SENTENCE) entities = extract_entities_from_sentence(rt) assert entities == [] # picked completely at random, one sample file for testing: # 610-1-000248/ann_named.xml # only the first sentence is used in the morpho file SAMPLE_ANN = """

Sił Zbrojnych Siły Zbrojne

""".lstrip() SAMPLE_MORPHO = """

2 2 2:adj:sg:nom:n:pos . . .:interp Wezwanie wezwanie wezwać wezwanie:subst:sg:acc:n , , ,:interp o o o ojciec o:prep:loc którym który który:adj:sg:loc:n:pos mowa mowa mowa:subst:sg:nom:f w w wiek wielki wiersz wieś wyspa w:prep:loc:nwok ust usta ustęp ustęp:brev:pun . . .:interp 1 1 1:adj:sg:loc:m3:pos , , ,:interp doręcza doręczać doręcze doręczać:fin:sg:ter:imperf się się się:qub na na na na:prep:acc czternaście czternaście czternaście:num:pl:acc:m3:rec dni dni dzień dzień:subst:pl:gen:m3 przed przed przed:prep:inst:nwok terminem termin termin:subst:sg:inst:m3 wykonania wykonanie wykonać wykonać:ger:sg:gen:n:perf:aff świadczenia świadczenie świadczyć świadczenie:subst:sg:gen:n , , ,:interp z z z zeszyt z:prep:inst:nwok wyjątkiem wyjątek wyjątek:subst:sg:inst:m3 przypadków przypadek przypadek:subst:pl:gen:m3 , , ,:interp w w wiek wielki wiersz wieś wyspa w:prep:loc:nwok których który który:adj:pl:loc:m3:pos wykonanie wykonanie wykonać wykonać:ger:sg:nom:n:perf:aff świadczenia świadczenie świadczyć świadczenie:subst:sg:gen:n następuje następować następować:fin:sg:ter:imperf w w wiek wielki wiersz wieś wyspa w:prep:loc:nwok celu Cela cel cel:subst:sg:loc:m3 sprawdzenia sprawdzić sprawdzić:ger:sg:gen:n:perf:aff gotowości gotowość gotowość:subst:sg:gen:f mobilizacyjnej mobilizacyjny mobilizacyjny:adj:sg:gen:f:pos Sił siła siły siła:subst:pl:gen:f Zbrojnych zbrojny zbrojny:adj:pl:gen:f:pos . . .:interp

""".lstrip() ================================================ FILE: stanza/tests/ner/test_convert_starlang_ner.py ================================================ """ Test a couple different classes of trees to check the output of the Starlang conversion for NER """ import os import tempfile import pytest from stanza.utils.datasets.ner import convert_starlang_ner pytestmark = [pytest.mark.pipeline, pytest.mark.travis] TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. 
{morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )" def test_read_tree(): """ Test a basic tree read """ sentence = convert_starlang_ner.read_tree(TREE) expected = [('Bayan', 'PERSON'), ('Haag', 'PERSON'), ('Elianti', 'O'), ('çalar', 'O'), ('.', 'O')] assert sentence == expected ================================================ FILE: stanza/tests/ner/test_data.py ================================================ import json import pytest from stanza.models import ner_tagger from stanza.models.common.doc import Document from stanza.models.ner.data import DataLoader from stanza.tests import TEST_WORKING_DIR pytestmark = [pytest.mark.travis, pytest.mark.pipeline] ONE_SENTENCE = """ [ [ { "text": "EU", "ner": "B-ORG" }, { "text": "rejects", "ner": "O" }, { "text": "German", "ner": "B-MISC" }, { "text": "call", "ner": "O" }, { "text": "to", "ner": "O" }, { "text": "boycott", "ner": "O" }, { "text": "Mox", "ner": "B-MISC" }, { "text": "Opal", "ner": "I-MISC" }, { "text": ".", "ner": "O" } ] ] """ @pytest.fixture(scope="module") def pretrain_file(): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' @pytest.fixture(scope="module") def one_sentence_json_path(tmpdir_factory): filename = tmpdir_factory.mktemp('data').join("sentence.json") with open(filename, 'w') as fout: fout.write(ONE_SENTENCE) return filename def test_build_vocab(pretrain_file, one_sentence_json_path, tmp_path): """ Test that when loading a data file, we get back """ args = ner_tagger.parse_args(["--wordvec_pretrain_file", pretrain_file]) pt = ner_tagger.load_pretrain(args) with open(one_sentence_json_path) as fin: train_doc = Document(json.load(fin)) train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words']) vocab = train_batch.vocab pt_words = list(vocab['word']) assert pt_words == ['', '', '', '', 
'unban', 'mox', 'opal'] delta_words = list(vocab['delta']) assert delta_words == ['', '', '', '', 'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'mox', 'opal', '.'] tags = list(vocab['tag']) assert tags == [[''], [''], [], [''], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']] def test_build_vocab_ignore_repeats(pretrain_file, one_sentence_json_path, tmp_path): """ Test that when loading a data file, we get back """ args = ner_tagger.parse_args(["--wordvec_pretrain_file", pretrain_file, "--emb_finetune_known_only"]) pt = ner_tagger.load_pretrain(args) with open(one_sentence_json_path) as fin: train_doc = Document(json.load(fin)) train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words']) vocab = train_batch.vocab pt_words = list(vocab['word']) assert pt_words == ['', '', '', '', 'unban', 'mox', 'opal'] delta_words = list(vocab['delta']) assert delta_words == ['', '', '', '', 'mox', 'opal'] tags = list(vocab['tag']) assert tags == [[''], [''], [], [''], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']] ================================================ FILE: stanza/tests/ner/test_from_conllu.py ================================================ import pytest from stanza import Pipeline from stanza.utils.conll import CoNLL from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_from_conllu(): """ If the doc does not have the entire text available, make sure it still safely processes the text Test case supplied from user - see issue #1428 """ pipe = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", download_method=None) doc = pipe("In February, I traveled to Seattle. Dr. 
Pritchett gave me a new hip") ents = [x.text for x in doc.ents] # the default NER model ought to find these three assert ents == ['February', 'Seattle', 'Pritchett'] doc_conllu = "{:C}\n\n".format(doc) doc = CoNLL.conll2doc(input_str=doc_conllu) pipe = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", tokenize_pretokenized=True, download_method=None) pipe(doc) ents = [x.text for x in doc.ents] # this should still work when processed from a CoNLLu document # the bug previously caused a crash because the text to construct # the entities was not available, since the Document wouldn't have # the entire document text available assert ents == ['February', 'Seattle', 'Pritchett'] ================================================ FILE: stanza/tests/ner/test_models_ner_scorer.py ================================================ """ Simple test of the scorer module for NER """ import pytest import stanza from stanza.tests import * from stanza.models.ner.scorer import score_by_token, score_by_entity pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_ner_scorer(): pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'], ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']] gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'], ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']] token_p, token_r, token_f, confusion = score_by_token(pred_sequences, gold_sequences) assert pytest.approx(token_p, abs=0.00001) == 0.625 assert pytest.approx(token_r, abs=0.00001) == 0.5 assert pytest.approx(token_f, abs=0.00001) == 0.55555 entity_p, entity_r, entity_f, entity_f1 = score_by_entity(pred_sequences, gold_sequences) assert pytest.approx(entity_p, abs=0.00001) == 0.4 assert pytest.approx(entity_r, abs=0.00001) == 0.33333 assert pytest.approx(entity_f, abs=0.00001) == 0.36363 assert entity_f1 == {'LOC': 0.0, 'MISC': 1.0, 'ORG': 0.0, 'PER': 0.5} ================================================ FILE: 
stanza/tests/ner/test_ner_tagger.py ================================================ """ Basic testing of the NER tagger. """ import os import pytest import stanza from stanza.tests import * from stanza.models import ner_tagger from stanza.utils.confusion import confusion_to_macro_f1 import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file from stanza.utils.training.run_ner import build_pretrain_args pytestmark = [pytest.mark.travis, pytest.mark.pipeline] EN_DOC = "Chris Manning is a good man. He works in Stanford University." EN_DOC_GOLD = """ """.strip() EN_BIO = """ Chris B-PERSON Manning E-PERSON is O a O good O man O . O He O works O in O Stanford B-ORG University E-ORG . O """.strip().replace(" ", "\t") EN_EXPECTED_OUTPUT = """ Chris B-PERSON B-PERSON Manning E-PERSON E-PERSON is O O a O O good O O man O O . O O He O O works O O in O O Stanford B-ORG B-ORG University E-ORG E-ORG . O O """.strip().replace(" ", "\t") def test_ner(): nlp = stanza.Pipeline(**{'processors': 'tokenize,ner', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'logging_level': 'error'}) doc = nlp(EN_DOC) assert EN_DOC_GOLD == '\n'.join([ent.pretty_print() for ent in doc.ents]) def test_evaluate(tmp_path): """ This simple example should have a 1.0 f1 for the ontonote model """ package = "ontonotes-ww-multi_charlm" model_path = os.path.join(TEST_MODELS_DIR, "en", "ner", package + ".pt") assert os.path.exists(model_path), "The {} model should be downloaded as part of setup.py".format(package) os.makedirs(tmp_path, exist_ok=True) test_bio_filename = tmp_path / "test.bio" test_json_filename = tmp_path / "test.json" test_output_filename = tmp_path / "output.bio" with open(test_bio_filename, "w", encoding="utf-8") as fout: fout.write(EN_BIO) prepare_ner_file.process_dataset(test_bio_filename, test_json_filename) args = ["--save_name", str(model_path), "--eval_file", str(test_json_filename), "--eval_output_file", str(test_output_filename), "--mode", "predict"] args = args + 
build_pretrain_args("en", package, model_dir=TEST_MODELS_DIR, extra_args=[]) args = ner_tagger.parse_args(args=args) confusion = ner_tagger.evaluate(args) assert confusion_to_macro_f1(confusion) == pytest.approx(1.0) with open(test_output_filename, encoding="utf-8") as fin: results = fin.read().strip() assert results == EN_EXPECTED_OUTPUT ================================================ FILE: stanza/tests/ner/test_ner_trainer.py ================================================ import pytest from stanza.tests import * from stanza.models.ner import trainer pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_fix_singleton_tags(): TESTS = [ (["O"], ["O"]), (["B-PER"], ["S-PER"]), (["B-PER", "I-PER"], ["B-PER", "E-PER"]), (["B-PER", "O", "B-PER"], ["S-PER", "O", "S-PER"]), (["B-PER", "B-PER", "I-PER"], ["S-PER", "B-PER", "E-PER"]), (["B-PER", "I-PER", "O", "B-PER"], ["B-PER", "E-PER", "O", "S-PER"]), (["B-PER", "B-PER", "I-PER", "B-PER"], ["S-PER", "B-PER", "E-PER", "S-PER"]), (["B-PER", "I-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]), (["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]), (["S-PER", "B-PER", "E-PER"], ["S-PER", "B-PER", "E-PER"]), (["E-PER"], ["S-PER"]), (["E-PER", "O", "E-PER"], ["S-PER", "O", "S-PER"]), (["B-PER", "E-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]), (["I-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]), (["B-PER", "I-PER", "I-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]), (["B-PER", "I-PER", "E-PER", "O", "I-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]), (["B-PER", "I-PER", "E-PER", "O", "B-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]), (["I-PER", "I-PER", "I-PER", "O", "I-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]), ] for unfixed, expected in TESTS: assert trainer.fix_singleton_tags(unfixed) == 
expected, "Error converting {} to {}".format(unfixed, expected) ================================================ FILE: stanza/tests/ner/test_ner_training.py ================================================ import json import logging import os import warnings import pytest import torch pytestmark = [pytest.mark.travis, pytest.mark.pipeline] from stanza.models import ner_tagger from stanza.models.ner.trainer import Trainer from stanza.tests import TEST_WORKING_DIR from stanza.utils.datasets.ner.prepare_ner_file import process_dataset logger = logging.getLogger('stanza') EN_TRAIN_BIO = """ Chris B-PERSON Manning E-PERSON is O a O good O man O . O He O works O in O Stanford B-ORG University E-ORG . O """.lstrip().replace(" ", "\t") EN_DEV_BIO = """ Chris B-PERSON Manning E-PERSON is O part O of O Computer B-ORG Science E-ORG """.lstrip().replace(" ", "\t") EN_TRAIN_2TAG = """ Chris B-PERSON B-PER Manning E-PERSON E-PER is O O a O O good O O man O O . O O He O O works O O in O O Stanford B-ORG B-ORG University E-ORG B-ORG . O O """.strip().replace(" ", "\t") EN_TRAIN_2TAG_EMPTY2 = """ Chris B-PERSON - Manning E-PERSON - is O - a O - good O - man O - . O - He O - works O - in O - Stanford B-ORG - University E-ORG - . 
O - """.strip().replace(" ", "\t") EN_DEV_2TAG = """ Chris B-PERSON B-PER Manning E-PERSON E-PER is O O part O O of O O Computer B-ORG B-ORG Science E-ORG E-ORG """.strip().replace(" ", "\t") @pytest.fixture(scope="module") def pretrain_file(): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' def write_temp_file(filename, bio_data): bio_filename = os.path.splitext(filename)[0] + ".bio" with open(bio_filename, "w", encoding="utf-8") as fout: fout.write(bio_data) process_dataset(bio_filename, filename) def write_temp_2tag(filename, bio_data): doc = [] sentences = bio_data.split("\n\n") for sentence in sentences: doc.append([]) for word in sentence.split("\n"): text, tags = word.split("\t", maxsplit=1) doc[-1].append({ "text": text, "multi_ner": tags.split() }) with open(filename, "w", encoding="utf-8") as fout: json.dump(doc, fout) def get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args): save_dir = tmp_path / "models" args = ["--data_dir", str(tmp_path), "--wordvec_pretrain_file", pretrain_file, "--train_file", str(train_json), "--eval_file", str(dev_json), "--shorthand", "en_test", "--max_steps", "100", "--eval_interval", "40", "--save_dir", str(save_dir)] args = args + list(extra_args) return args def run_two_tag_training(pretrain_file, tmp_path, *extra_args, train_data=EN_TRAIN_2TAG): train_json = tmp_path / "en_test.train.json" write_temp_2tag(train_json, train_data) dev_json = tmp_path / "en_test.dev.json" write_temp_2tag(dev_json, EN_DEV_2TAG) args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args) return ner_tagger.main(args) def test_basic_two_tag_training(pretrain_file, tmp_path): trainer = run_two_tag_training(pretrain_file, tmp_path) assert len(trainer.model.tag_clfs) == 2 assert len(trainer.model.crits) == 2 assert len(trainer.vocab['tag'].lens()) == 2 def test_two_tag_training_backprop(pretrain_file, tmp_path): """ Test that the training is backproping both tags We can do this by using the "finetune" mechanism and 
verifying that the output tensors are different """ trainer = run_two_tag_training(pretrain_file, tmp_path) # first, need to save the final model before restarting # (alternatively, could reload the final checkpoint) trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name'])) new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune") assert len(trainer.model.tag_clfs) == 2 assert len(new_trainer.model.tag_clfs) == 2 for old_clf, new_clf in zip(trainer.model.tag_clfs, new_trainer.model.tag_clfs): assert not torch.allclose(old_clf.weight, new_clf.weight) def test_two_tag_training_c2_backprop(pretrain_file, tmp_path): """ Test that the training is backproping only one tag if one column is blank We can do this by using the "finetune" mechanism and verifying that the output tensors are different in just the first column """ trainer = run_two_tag_training(pretrain_file, tmp_path) # first, need to save the final model before restarting # (alternatively, could reload the final checkpoint) trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name'])) new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune", train_data=EN_TRAIN_2TAG_EMPTY2) assert len(trainer.model.tag_clfs) == 2 assert len(new_trainer.model.tag_clfs) == 2 assert not torch.allclose(trainer.model.tag_clfs[0].weight, new_trainer.model.tag_clfs[0].weight) assert torch.allclose(trainer.model.tag_clfs[1].weight, new_trainer.model.tag_clfs[1].weight) def test_connected_two_tag_training(pretrain_file, tmp_path): trainer = run_two_tag_training(pretrain_file, tmp_path, "--connect_output_layers") assert len(trainer.model.tag_clfs) == 2 assert len(trainer.model.crits) == 2 assert len(trainer.vocab['tag'].lens()) == 2 # this checks that with the connected output layers, # the second output layer has its size increased # by the number of tags known to the first output layer assert trainer.model.tag_clfs[1].weight.shape[1] == 
trainer.vocab['tag'].lens()[0] + trainer.model.tag_clfs[0].weight.shape[1] def run_training(pretrain_file, tmp_path, *extra_args): train_json = tmp_path / "en_test.train.json" write_temp_file(train_json, EN_TRAIN_BIO) dev_json = tmp_path / "en_test.dev.json" write_temp_file(dev_json, EN_DEV_BIO) args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args) return ner_tagger.main(args) def test_train_model_gpu(pretrain_file, tmp_path): """ Briefly train an NER model (no expectation of correctness) and check that it is on the GPU """ trainer = run_training(pretrain_file, tmp_path) if not torch.cuda.is_available(): warnings.warn("Cannot check that the NER model is on the GPU, since GPU is not available") return model = trainer.model device = next(model.parameters()).device assert str(device).startswith("cuda") def test_train_model_cpu(pretrain_file, tmp_path): """ Briefly train an NER model (no expectation of correctness) and check that it is on the GPU """ trainer = run_training(pretrain_file, tmp_path, "--cpu") model = trainer.model device = next(model.parameters()).device assert str(device).startswith("cpu") def model_file_has_bert(filename): checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) return any(x.startswith("bert_model.") for x in checkpoint['model'].keys()) def test_with_bert(pretrain_file, tmp_path): trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert') model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name']) assert not model_file_has_bert(model_file) def test_with_bert_finetune(pretrain_file, tmp_path): trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune') model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name']) assert model_file_has_bert(model_file) foo_save_filename = os.path.join(tmp_path, "foo_" + trainer.args['save_name']) bar_save_filename = 
os.path.join(tmp_path, "bar_" + trainer.args['save_name']) trainer.save(foo_save_filename) assert model_file_has_bert(foo_save_filename) # TODO: technically this should still work if we turn off bert finetuning when reloading reloaded_trainer = Trainer(args=trainer.args, model_file=foo_save_filename) reloaded_trainer.save(bar_save_filename) assert model_file_has_bert(bar_save_filename) def test_with_peft_finetune(pretrain_file, tmp_path): # TODO: check that the peft tensors are moving when training? trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft') model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name']) checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True) assert 'bert_lora' in checkpoint assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys()) # test loading reloaded_trainer = Trainer(args=trainer.args, model_file=model_file) ================================================ FILE: stanza/tests/ner/test_ner_utils.py ================================================ import pytest from stanza.tests import * from stanza.models.common.vocab import EMPTY from stanza.models.ner import utils pytestmark = [pytest.mark.travis, pytest.mark.pipeline] WORDS = [["Unban", "Mox", "Opal"], ["Ragavan", "is", "red"], ["Urza", "Lord", "High", "Artificer", "goes", "infinite", "with", "Thopter", "Sword"]] BIO_TAGS = [["O", "B-ART", "I-ART"], ["B-MONKEY", "O", "B-COLOR"], ["B-PER", "I-PER", "I-PER", "I-PER", "O", "O", "O", "B-WEAPON", "B-WEAPON"]] BIO_U_TAGS = [["O", "B_ART", "I_ART"], ["B_MONKEY", "O", "B_COLOR"], ["B_PER", "I_PER", "I_PER", "I_PER", "O", "O", "O", "B_WEAPON", "B_WEAPON"]] BIOES_TAGS = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "S-WEAPON", "S-WEAPON"]] # note the problem with not using BIO tags - the consecutive tags for thopter/sword get treated as one item BASIC_TAGS = 
[["O", "ART", "ART"], ["MONKEY", "O", "COLOR"], [ "PER", "PER", "PER", "PER", "O", "O", "O", "WEAPON", "WEAPON"]] BASIC_BIOES = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "B-WEAPON", "E-WEAPON"]] ALT_BIO = [["O", "B-MANA", "I-MANA"], ["B-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]] ALT_BIOES = [["O", "B-MANA", "E-MANA"], ["S-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]] NONE_BIO = [["O", "B-MANA", "I-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]] NONE_BIOES = [["O", "B-MANA", "E-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]] EMPTY_BIO = [["O", "B-MANA", "I-MANA"], [EMPTY, EMPTY, EMPTY], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]] def test_normalize_empty_tags(): sentences = [[(word[0], (word[1],)) for word in zip(*sentence)] for sentence in zip(WORDS, NONE_BIO)] new_sentences = utils.normalize_empty_tags(sentences) expected = [[(word[0], (word[1],)) for word in zip(*sentence)] for sentence in zip(WORDS, EMPTY_BIO)] assert new_sentences == expected def check_reprocessed_tags(words, input_tags, expected_tags): sentences = [list(zip(x, y)) for x, y in zip(words, input_tags)] retagged = utils.process_tags(sentences=sentences, scheme="bioes") # process_tags selectively returns tuples or strings based on the input # so we don't need to fiddle with the expected output format here expected_retagged = [list(zip(x, y)) for x, y in zip(words, expected_tags)] assert retagged == expected_retagged def test_process_tags_bio(): check_reprocessed_tags(WORDS, BIO_TAGS, BIOES_TAGS) # check that the alternate version is correct as well # that way we can independently check the two layer version check_reprocessed_tags(WORDS, ALT_BIO, ALT_BIOES) def test_process_tags_with_none(): # if there is a block of tags 
with None in them, the Nones should be skipped over check_reprocessed_tags(WORDS, NONE_BIO, NONE_BIOES) def merge_tags(*tags): merged_tags = [[tuple(x) for x in zip(*sentences)] # combine tags such as ("O", "O"), ("B-ART", "B-MANA"), ... for sentences in zip(*tags)] # ... for each set of sentences return merged_tags def test_combined_tags_bio(): bio_tags = merge_tags(BIO_TAGS, ALT_BIO) expected = merge_tags(BIOES_TAGS, ALT_BIOES) check_reprocessed_tags(WORDS, bio_tags, expected) def test_combined_tags_mixed(): bio_tags = merge_tags(BIO_TAGS, ALT_BIOES) expected = merge_tags(BIOES_TAGS, ALT_BIOES) check_reprocessed_tags(WORDS, bio_tags, expected) def test_process_tags_basic(): check_reprocessed_tags(WORDS, BASIC_TAGS, BASIC_BIOES) def test_process_tags_bioes(): """ This one should not change, naturally """ check_reprocessed_tags(WORDS, BIOES_TAGS, BIOES_TAGS) check_reprocessed_tags(WORDS, BASIC_BIOES, BASIC_BIOES) def run_flattened(fn, tags): return fn([x for x in y for y in tags]) def test_check_bio(): assert utils.is_bio_scheme([x for y in BIO_TAGS for x in y]) assert not utils.is_bio_scheme([x for y in BIOES_TAGS for x in y]) assert not utils.is_bio_scheme([x for y in BASIC_TAGS for x in y]) assert not utils.is_bio_scheme([x for y in BASIC_BIOES for x in y]) def test_check_basic(): assert not utils.is_basic_scheme([x for y in BIO_TAGS for x in y]) assert not utils.is_basic_scheme([x for y in BIOES_TAGS for x in y]) assert utils.is_basic_scheme([x for y in BASIC_TAGS for x in y]) assert not utils.is_basic_scheme([x for y in BASIC_BIOES for x in y]) def test_underscores(): """ Check that the methods work if the inputs are underscores instead of dashes """ assert not utils.is_basic_scheme([x for y in BIO_U_TAGS for x in y]) check_reprocessed_tags(WORDS, BIO_U_TAGS, BIOES_TAGS) def test_merge_tags(): """ Check a few versions of the tag sequence merging """ seq1 = [ "O", "O", "O", "B-FOO", "E-FOO", "O"] seq2 = [ "S-FOO", "O", "B-FOO", "E-FOO", "O", "O"] seq3 = [ 
"B-FOO", "E-FOO", "B-FOO", "E-FOO", "O", "O"] seq_err = [ "O", "B-FOO", "O", "B-FOO", "E-FOO", "O"] seq_err2 = [ "O", "B-FOO", "O", "B-FOO", "B-FOO", "O"] seq_err3 = [ "O", "B-FOO", "O", "B-FOO", "I-FOO", "O"] seq_err4 = [ "O", "B-FOO", "O", "B-FOO", "I-FOO", "I-FOO"] result = utils.merge_tags(seq1, seq2) expected = [ "S-FOO", "O", "O", "B-FOO", "E-FOO", "O"] assert result == expected result = utils.merge_tags(seq2, seq1) expected = [ "S-FOO", "O", "B-FOO", "E-FOO", "O", "O"] assert result == expected result = utils.merge_tags(seq1, seq3) expected = [ "B-FOO", "E-FOO", "O", "B-FOO", "E-FOO", "O"] assert result == expected with pytest.raises(ValueError): result = utils.merge_tags(seq1, seq_err) with pytest.raises(ValueError): result = utils.merge_tags(seq1, seq_err2) with pytest.raises(ValueError): result = utils.merge_tags(seq1, seq_err3) with pytest.raises(ValueError): result = utils.merge_tags(seq1, seq_err4) ================================================ FILE: stanza/tests/ner/test_pay_amt_annotators.py ================================================ """ Simple test for tracking AMT annotator work """ import os import zipfile import pytest from stanza.tests import TEST_WORKING_DIR from stanza.utils.ner import paying_annotators DATA_SOURCE = os.path.join(TEST_WORKING_DIR, "in", "aws_annotations.zip") @pytest.fixture(scope="module") def completed_amt_job_metadata(tmp_path_factory): assert os.path.exists(DATA_SOURCE) unzip_path = tmp_path_factory.mktemp("amt_test") input_path = unzip_path / "ner" / "aws_labeling_copy" with zipfile.ZipFile(DATA_SOURCE, 'r') as zin: zin.extractall(unzip_path) return input_path def test_amt_annotator_track(completed_amt_job_metadata): workers = { "7efc17ac-3397-4472-afe5-89184ad145d0": "Worker1", "afce8c28-969c-4e73-a20f-622ef122f585": "Worker2", "91f6236e-63c6-4a84-8fd6-1efbab6dedab": "Worker3", "6f202e93-e6b6-4e1d-8f07-0484b9a9093a": "Worker4", "2b674d33-f656-44b0-8f90-d70a1ab71ec2": "Worker5" } # map AMT annotator subs to 
relevant identifier tracked_work = paying_annotators.track_tasks(completed_amt_job_metadata, workers) assert tracked_work == {'Worker4': 20, 'Worker5': 20, 'Worker2': 3, 'Worker3': 16} def test_amt_annotator_track_no_map(completed_amt_job_metadata): sub_to_count = paying_annotators.track_tasks(completed_amt_job_metadata) assert sub_to_count == {'6f202e93-e6b6-4e1d-8f07-0484b9a9093a': 20, '2b674d33-f656-44b0-8f90-d70a1ab71ec2': 20, 'afce8c28-969c-4e73-a20f-622ef122f585': 3, '91f6236e-63c6-4a84-8fd6-1efbab6dedab': 16} def main(): test_amt_annotator_track() test_amt_annotator_track_no_map() if __name__ == "__main__": main() print("TESTS COMPLETED!") ================================================ FILE: stanza/tests/ner/test_split_wikiner.py ================================================ """ Runs a few tests on the split_wikiner file """ import os import tempfile import pytest from stanza.utils.datasets.ner import split_wikiner from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # two sentences from the Italian dataset, split into many pieces # to test the splitting functionality FBK_SAMPLE = """ Il O Papa O si O aggrava O Le O condizioni O di O Papa O Giovanni PER Paolo PER II PER si O sono O aggravate O in O il O corso O di O la O giornata O di O giovedì O . O Il O portavoce O Navarro PER Valls PER ha O dichiarato O che O il O Santo O Padre O in O la O giornata O di O oggi O è O stato O colpito O da O una O affezione O altamente O febbrile O provocata O da O una O infezione O documentata O di O le O vie O urinarie O . O A O il O momento O non O è O previsto O il O ricovero O a O il O Policlinico LOC Gemelli LOC , O come O ha O precisato O il O responsabile O di O il O dipartimento O di O emergenza O professor O Rodolfo PER Proietti PER . 
O """ def test_read_sentences(): with tempfile.TemporaryDirectory() as tempdir: raw_filename = os.path.join(tempdir, "raw.tsv") with open(raw_filename, "w") as fout: fout.write(FBK_SAMPLE) sentences = split_wikiner.read_sentences(raw_filename, "utf-8") assert len(sentences) == 20 text = [["\t".join(word) for word in sent] for sent in sentences] text = ["\n".join(sent) for sent in text] text = "\n\n".join(text) assert FBK_SAMPLE.strip() == text def test_write_sentences(): with tempfile.TemporaryDirectory() as tempdir: raw_filename = os.path.join(tempdir, "raw.tsv") with open(raw_filename, "w") as fout: fout.write(FBK_SAMPLE) sentences = split_wikiner.read_sentences(raw_filename, "utf-8") copy_filename = os.path.join(tempdir, "copy.tsv") split_wikiner.write_sentences_to_file(sentences, copy_filename) sent2 = split_wikiner.read_sentences(raw_filename, "utf-8") assert sent2 == sentences def run_split_wikiner(expected_train=14, expected_dev=3, expected_test=3, **kwargs): """ Runs a test using various parameters to check the results of the splitting process """ with tempfile.TemporaryDirectory() as indir: raw_filename = os.path.join(indir, "raw.tsv") with open(raw_filename, "w") as fout: fout.write(FBK_SAMPLE) with tempfile.TemporaryDirectory() as outdir: split_wikiner.split_wikiner(outdir, raw_filename, **kwargs) train_file = os.path.join(outdir, "it_fbk.train.bio") dev_file = os.path.join(outdir, "it_fbk.dev.bio") test_file = os.path.join(outdir, "it_fbk.test.bio") assert os.path.exists(train_file) assert os.path.exists(dev_file) if kwargs["test_section"]: assert os.path.exists(test_file) else: assert not os.path.exists(test_file) train_sent = split_wikiner.read_sentences(train_file, "utf-8") dev_sent = split_wikiner.read_sentences(dev_file, "utf-8") assert len(train_sent) == expected_train assert len(dev_sent) == expected_dev if kwargs["test_section"]: test_sent = split_wikiner.read_sentences(test_file, "utf-8") assert len(test_sent) == expected_test else: test_sent = 
[] if kwargs["shuffle"]: orig_sents = sorted(split_wikiner.read_sentences(raw_filename, "utf-8")) split_sents = sorted(train_sent + dev_sent + test_sent) else: orig_sents = split_wikiner.read_sentences(raw_filename, "utf-8") split_sents = train_sent + dev_sent + test_sent assert orig_sents == split_sents def test_no_shuffle_split(): run_split_wikiner(prefix="it_fbk", shuffle=False, test_section=True) def test_shuffle_split(): run_split_wikiner(prefix="it_fbk", shuffle=True, test_section=True) def test_resize(): run_split_wikiner(expected_train=12, expected_dev=2, expected_test=6, train_fraction=0.6, dev_fraction=0.1, prefix="it_fbk", shuffle=True, test_section=True) def test_no_test_split(): run_split_wikiner(expected_train=17, train_fraction=0.85, prefix="it_fbk", shuffle=False, test_section=False) ================================================ FILE: stanza/tests/ner/test_suc3.py ================================================ """ Tests the conversion code for the SUC3 NER dataset """ import os import tempfile from zipfile import ZipFile import pytest pytestmark = [pytest.mark.travis, pytest.mark.pipeline] import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob TEST_CONLL = """ 1 Den den PN PN UTR|SIN|DEF|SUB/OBJ _ _ _ _ O _ ac01b-030:2328 2 Gud Gud PM PM NOM _ _ _ _ B myth ac01b-030:2329 3 giver giva VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2330 4 ämbetet ämbete NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2331 5 får få VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2332 6 också också AB AB _ _ _ _ O _ ac01b-030:2333 7 förståndet förstånd NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2334 8 . . 
MAD MAD _ _ _ _ O _ ac01b-030:2335 1 Han han PN PN UTR|SIN|DEF|SUB _ _ _ _ O _ aa01a-017:227 2 berättar berätta VB VB PRS|AKT _ _ _ _ O _ aa01a-017:228 3 anekdoten anekdot NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:229 4 som som HP HP -|-|- _ _ _ _ O _ aa01a-017:230 5 FN-medlaren FN-medlare NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:231 6 Brian Brian PM PM NOM _ _ _ _ B person aa01a-017:232 7 Urquhart Urquhart PM PM NOM _ _ _ _ I person aa01a-017:233 8 myntat mynta VB VB SUP|AKT _ _ _ _ O _ aa01a-017:234 9 : : MAD MAD _ _ _ _ O _ aa01a-017:235 """ EXPECTED_IOB = """ Den O Gud B-myth giver O ämbetet O får O också O förståndet O . O Han O berättar O anekdoten O som O FN-medlaren O Brian B-person Urquhart I-person myntat O : O """ def test_read_zip(): """ Test creating a fake zip file, then converting it to an .iob file """ with tempfile.TemporaryDirectory() as tempdir: zip_name = os.path.join(tempdir, "test.zip") in_filename = "conll" with ZipFile(zip_name, "w") as zout: with zout.open(in_filename, "w") as fout: fout.write(TEST_CONLL.encode()) out_filename = os.path.join(tempdir, "iob") num = suc_conll_to_iob.extract_from_zip(zip_name, in_filename, out_filename) assert num == 2 with open(out_filename) as fin: result = fin.read() assert EXPECTED_IOB.strip() == result.strip() def test_read_raw(): """ Test a direct text file conversion w/o the zip file """ with tempfile.TemporaryDirectory() as tempdir: in_filename = os.path.join(tempdir, "test.txt") with open(in_filename, "w", encoding="utf-8") as fout: fout.write(TEST_CONLL) out_filename = os.path.join(tempdir, "iob") with open(in_filename, encoding="utf-8") as fin, open(out_filename, "w", encoding="utf-8") as fout: num = suc_conll_to_iob.extract(fin, fout) assert num == 2 with open(out_filename) as fin: result = fin.read() assert EXPECTED_IOB.strip() == result.strip() ================================================ FILE: stanza/tests/pipeline/__init__.py ================================================ 
================================================ FILE: stanza/tests/pipeline/pipeline_device_tests.py ================================================ """ Utility methods to check that all processors are on the expected device Refactored since it can be used for multiple pipelines """ import warnings import torch def check_on_gpu(pipeline): """ Check that the processors are all on the GPU and that basic execution works """ if not torch.cuda.is_available(): warnings.warn("Unable to run the test that checks the pipeline is on the GPU, as there is no GPU available!") return for name, proc in pipeline.processors.items(): if proc.trainer is not None: device = next(proc.trainer.model.parameters()).device else: device = next(proc._model.parameters()).device assert str(device).startswith("cuda"), "Processor %s was not on the GPU" % name # just check that there are no cpu/cuda tensor conflicts # when running on the GPU pipeline("This is a small test") def check_on_cpu(pipeline): """ Check that the processors are all on the CPU and that basic execution works """ for name, proc in pipeline.processors.items(): if proc.trainer is not None: device = next(proc.trainer.model.parameters()).device else: device = next(proc._model.parameters()).device assert str(device).startswith("cpu"), "Processor %s was not on the CPU" % name # just check that there are no cpu/cuda tensor conflicts # when running on the CPU pipeline("This is a small test") ================================================ FILE: stanza/tests/pipeline/test_arabic_pipeline.py ================================================ """ Small test of loading the Arabic pipeline The main goal is to check that nothing goes wrong with RtL languages, but incidentally this would have caught a bug where the xpos tags were split into individual pieces instead of reassembled as expected """ import pytest import stanza from stanza.tests import TEST_MODELS_DIR pytestmark = pytest.mark.pipeline def test_arabic_pos_pipeline(): pipe = 
stanza.Pipeline(**{'processors': 'tokenize,pos', 'dir': TEST_MODELS_DIR, 'download_method': None, 'lang': 'ar'}) text = "ولم يتم اعتقال احد بحسب المتحدث باسم الشرطة." doc = pipe(text) # the first token translates to "and not", seems common enough # that we should be able to rely on it having a stable MWT and tag assert len(doc.sentences) == 1 assert doc.sentences[0].tokens[0].text == "ولم" assert doc.sentences[0].words[0].xpos == "C---------" assert doc.sentences[0].words[1].xpos == "F---------" ================================================ FILE: stanza/tests/pipeline/test_core.py ================================================ import pytest import shutil import tempfile import stanza from stanza.tests import * from stanza.pipeline import core from stanza.resources.common import get_md5, load_resources_json pytestmark = pytest.mark.pipeline def test_pretagged(): """ Test that the pipeline does or doesn't build if pos is left out and pretagged is specified """ nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,pos,lemma,depparse") with pytest.raises(core.PipelineRequirementsException): nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse") nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse", depparse_pretagged=True) nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse", pretagged=True) # test that the module specific flag overrides the general flag nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse", depparse_pretagged=True, pretagged=False) def test_download_missing_ner_model(): """ Test that the pipeline will automatically download missing models """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined", verbose=False) pipe = stanza.Pipeline("en", model_dir=test_dir, 
processors="tokenize,ner", package={"ner": ("ontonotes_charlm")}) assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] en_dir = os.path.join(test_dir, 'en') en_dir_listing = sorted(os.listdir(en_dir)) assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize'] assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt'] def test_download_missing_resources(): """ Test that the pipeline will automatically download missing models """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner", package={"tokenize": "combined", "ner": "ontonotes_charlm"}) assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] en_dir = os.path.join(test_dir, 'en') en_dir_listing = sorted(os.listdir(en_dir)) assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize'] assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt'] def test_download_resources_overwrites(): """ Test that the DOWNLOAD_RESOURCES method overwrites an existing resources.json """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}) assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] resources_path = os.path.join(test_dir, 'resources.json') mod_time = os.path.getmtime(resources_path) pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}) new_mod_time = os.path.getmtime(resources_path) assert mod_time != new_mod_time def test_reuse_resources_overwrites(): """ Test that the REUSE_RESOURCES method does *not* overwrite an existing resources.json """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: pipe = stanza.Pipeline("en", download_method=core.DownloadMethod.REUSE_RESOURCES, model_dir=test_dir, 
processors="tokenize", package={"tokenize": "combined"}) assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] resources_path = os.path.join(test_dir, 'resources.json') mod_time = os.path.getmtime(resources_path) pipe = stanza.Pipeline("en", download_method=core.DownloadMethod.REUSE_RESOURCES, model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}) new_mod_time = os.path.getmtime(resources_path) assert mod_time == new_mod_time def test_download_not_repeated(): """ Test that a model is only downloaded once if it already matches the expected model from the resources file """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined") assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] en_dir = os.path.join(test_dir, 'en') en_dir_listing = sorted(os.listdir(en_dir)) assert en_dir_listing == ['mwt', 'tokenize'] tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt") mod_time = os.path.getmtime(tokenize_path) pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}) assert os.path.getmtime(tokenize_path) == mod_time def test_download_none(): with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("it", model_dir=test_dir, processors="tokenize", package="combined") stanza.download("it", model_dir=test_dir, processors="tokenize", package="vit") it_dir = os.path.join(test_dir, 'it') it_dir_listing = sorted(os.listdir(it_dir)) assert sorted(it_dir_listing) == ['mwt', 'tokenize'] combined_path = os.path.join(it_dir, "tokenize", "combined.pt") vit_path = os.path.join(it_dir, "tokenize", "vit.pt") assert os.path.exists(combined_path) assert os.path.exists(vit_path) combined_md5 = get_md5(combined_path) vit_md5 = get_md5(vit_path) # check that the models are different # otherwise the test is not testing anything assert combined_md5 != vit_md5 
shutil.copyfile(vit_path, combined_path) assert get_md5(combined_path) == vit_md5 pipe = stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=None) assert get_md5(combined_path) == vit_md5 pipe = stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}) assert get_md5(combined_path) != vit_md5 def check_download_method_updates(download_method): """ Run a single test of creating a pipeline with a given download_method, checking that the model is updated """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined") assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] en_dir = os.path.join(test_dir, 'en') en_dir_listing = sorted(os.listdir(en_dir)) assert en_dir_listing == ['mwt', 'tokenize'] tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt") with open(tokenize_path, "w") as fout: fout.write("Unban mox opal!") mod_time = os.path.getmtime(tokenize_path) pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=download_method) assert os.path.getmtime(tokenize_path) != mod_time def test_download_fixed(): """ Test that a model is fixed if the existing model doesn't match the md5sum """ for download_method in (core.DownloadMethod.REUSE_RESOURCES, core.DownloadMethod.DOWNLOAD_RESOURCES): check_download_method_updates(download_method) def test_download_strings(): """ Same as the test of the download_method, but tests that the pipeline works for string download_method """ for download_method in ("reuse_resources", "download_resources"): check_download_method_updates(download_method) def test_limited_pipeline(): """ Test loading a pipeline, but then only using a couple processors """ pipe = stanza.Pipeline(processors="tokenize,pos,lemma,depparse,ner", dir=TEST_MODELS_DIR) doc = pipe("John Bauer 
works at Stanford") assert all(word.upos is not None for sentence in doc.sentences for word in sentence.words) assert all(token.ner is not None for sentence in doc.sentences for token in sentence.tokens) doc = pipe("John Bauer works at Stanford", processors=["tokenize","pos"]) assert all(word.upos is not None for sentence in doc.sentences for word in sentence.words) assert not any(token.ner is not None for sentence in doc.sentences for token in sentence.tokens) doc = pipe("John Bauer works at Stanford", processors="tokenize") assert not any(word.upos is not None for sentence in doc.sentences for word in sentence.words) assert not any(token.ner is not None for sentence in doc.sentences for token in sentence.tokens) doc = pipe("John Bauer works at Stanford", processors="tokenize,ner") assert not any(word.upos is not None for sentence in doc.sentences for word in sentence.words) assert all(token.ner is not None for sentence in doc.sentences for token in sentence.tokens) with pytest.raises(ValueError): # this should fail doc = pipe("John Bauer works at Stanford", processors="tokenize,depparse") @pytest.fixture(scope="module") def unknown_language_name(): resources = load_resources_json(model_dir=TEST_MODELS_DIR) name = "en" while name in resources: name = name + "z" assert name != "en" return name def test_empty_unknown_language(unknown_language_name): """ Check that there is an error for trying to load an unknown language """ with pytest.raises(ValueError): pipe = stanza.Pipeline(unknown_language_name, model_dir=TEST_MODELS_DIR, download_method=None) def test_unknown_language_tokenizer(unknown_language_name): """ Test that loading tokenize works for an unknown language """ base_pipe = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) # even if we one day add MWT to English, the tokenizer by itself should still work tokenize_processor = base_pipe.processors["tokenize"] pipe=stanza.Pipeline(unknown_language_name, 
processors="tokenize", allow_unknown_language=True, tokenize_model_path=tokenize_processor.config['model_path'], download_method=None) doc = pipe("This is a test") words = [x.text for x in doc.sentences[0].words] assert words == ['This', 'is', 'a', 'test'] def test_unknown_language_mwt(unknown_language_name): """ Test that loading tokenize & mwt works for an unknown language """ base_pipe = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, processors="tokenize,mwt", download_method=None) assert len(base_pipe.processors) == 2 tokenize_processor = base_pipe.processors["tokenize"] mwt_processor = base_pipe.processors["mwt"] pipe=stanza.Pipeline(unknown_language_name, model_dir=TEST_MODELS_DIR, processors="tokenize,mwt", allow_unknown_language=True, tokenize_model_path=tokenize_processor.config['model_path'], mwt_model_path=mwt_processor.config['model_path'], download_method=None) ================================================ FILE: stanza/tests/pipeline/test_decorators.py ================================================ """ Basic tests of the depparse processor boolean flags """ import pytest import stanza from stanza.models.common.doc import Document from stanza.pipeline.core import PipelineRequirementsException from stanza.pipeline.processor import Processor, ProcessorVariant, register_processor, register_processor_variant, ProcessorRegisterException from stanza.utils.conll import CoNLL from stanza.tests import * pytestmark = pytest.mark.pipeline # data for testing EN_DOC = "This is a test sentence. This is another!" 
# NOTE(review): the lines below hold the rest of stanza/tests/pipeline/test_decorators.py
# (custom Processor / ProcessorVariant registration tests), all of test_depparse.py and
# test_english_pipeline.py, and the header of test_french_pipeline.py.
# NOTE(review): several expected-output literals here look damaged by text extraction and
# should be compared against the upstream repository before being trusted:
#   - EN_DOC_LOWERCASE_TOKENS, EN_DOC_LOL_TOKENS, EN_DOC_COOL_LEMMAS and EN_DOC_TOKENS_GOLD
#     contain only "]>" fragments, as if the "<Token ...>" / "<Word ...>" markup was stripped;
#   - EN_DOC_WORDS_GOLD is an empty string after .strip();
#   - the CoNLL-U gold strings show every column and row separated by single spaces,
#     which suggests their tabs and line breaks were collapsed;
#   - EN_DOC uses single spaces between sentences, yet EN_DOC_CONLLU_GOLD records
#     SpacesAfter=\s\s and start_char=34 for "He", so its double spaces appear collapsed too.
EN_DOC_LOWERCASE_TOKENS = ''']> ]> ]> ]> ]> ]> ]> ]> ]> ]>''' EN_DOC_LOL_TOKENS = ''']> ]> ]> ]> ]> ]> ]> ]>''' EN_DOC_COOL_LEMMAS = ''']> ]> ]> ]> ]> ]> ]> ]> ]> ]>''' @register_processor("lowercase") class LowercaseProcessor(Processor): ''' Processor that lowercases all text ''' _requires = set(['tokenize']) _provides = set(['lowercase']) def __init__(self, config, pipeline, device): pass def _set_up_model(self, *args): pass def process(self, doc): doc.text = doc.text.lower() for sent in doc.sentences: for tok in sent.tokens: tok.text = tok.text.lower() for word in sent.words: word.text = word.text.lower() return doc def test_register_processor(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,lowercase', download_method=None) doc = nlp(EN_DOC) assert EN_DOC_LOWERCASE_TOKENS == '\n\n'.join(sent.tokens_string() for sent in doc.sentences) def test_register_nonprocessor(): with pytest.raises(ProcessorRegisterException): @register_processor("nonprocessor") class NonProcessor: pass @register_processor_variant("tokenize", "lol") class LOLTokenizer(ProcessorVariant): ''' An alternative tokenizer that splits text by space and replaces all tokens with LOL ''' def __init__(self, lang): pass def process(self, text): sentence = [{'id': (i+1, ), 'text': 'LOL'} for i, tok in enumerate(text.split())] return Document([sentence], text) def test_register_processor_variant(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={"tokenize": "lol"}, package=None, download_method=None) doc = nlp(EN_DOC) assert EN_DOC_LOL_TOKENS == '\n\n'.join(sent.tokens_string() for sent in doc.sentences) @register_processor_variant("lemma", "cool") class CoolLemmatizer(ProcessorVariant): ''' An alternative lemmatizer that lemmatizes every word to "cool".
''' OVERRIDE = True def __init__(self, lang): pass def process(self, document): for sentence in document.sentences: for word in sentence.words: word.lemma = "cool" return document def test_register_processor_variant_with_override(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={"tokenize": "combined", "pos": "combined", "lemma": "cool"}, package=None, download_method=None) doc = nlp(EN_DOC) result = '\n\n'.join(sent.tokens_string() for sent in doc.sentences) assert EN_DOC_COOL_LEMMAS == result def test_register_nonprocessor_variant(): with pytest.raises(ProcessorRegisterException): @register_processor_variant("tokenize", "nonvariant") class NonVariant: pass ================================================ FILE: stanza/tests/pipeline/test_depparse.py ================================================ """ Basic tests of the depparse processor boolean flags """ import gc import pytest import stanza from stanza.pipeline.core import PipelineRequirementsException from stanza.utils.conll import CoNLL from stanza.tests import * pytestmark = pytest.mark.pipeline # data for testing EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard." EN_DOC_CONLLU_PRETAGGED = """ 1 Barack Barack PROPN NNP Number=Sing 0 _ _ _ 2 Obama Obama PROPN NNP Number=Sing 1 _ _ _ 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 2 _ _ _ 4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 3 _ _ _ 5 in in ADP IN _ 4 _ _ _ 6 Hawaii Hawaii PROPN NNP Number=Sing 5 _ _ _ 7 . . PUNCT . _ 6 _ _ _ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 0 _ _ _ 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 _ _ _ 3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 2 _ _ _ 4 president president PROPN NNP Number=Sing 3 _ _ _ 5 in in ADP IN _ 4 _ _ _ 6 2008 2008 NUM CD NumType=Card 5 _ _ _ 7 . . PUNCT .
_ 6 _ _ _ 1 Obama Obama PROPN NNP Number=Sing 0 _ _ _ 2 attended attend VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 1 _ _ _ 3 Harvard Harvard PROPN NNP Number=Sing 2 _ _ _ 4 . . PUNCT . _ 3 _ _ _ """.lstrip() EN_DOC_DEPENDENCY_PARSES_GOLD = """ ('Barack', 4, 'nsubj:pass') ('Obama', 1, 'flat') ('was', 4, 'aux:pass') ('born', 0, 'root') ('in', 6, 'case') ('Hawaii', 4, 'obl') ('.', 4, 'punct') ('He', 3, 'nsubj:pass') ('was', 3, 'aux:pass') ('elected', 0, 'root') ('president', 3, 'xcomp') ('in', 6, 'case') ('2008', 3, 'obl') ('.', 3, 'punct') ('Obama', 2, 'nsubj') ('attended', 0, 'root') ('Harvard', 2, 'obj') ('.', 2, 'punct') """.strip() @pytest.fixture(scope="module") def en_depparse_pipeline(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,pos,lemma,depparse') gc.collect() return nlp def test_depparse(en_depparse_pipeline): doc = en_depparse_pipeline(EN_DOC) assert EN_DOC_DEPENDENCY_PARSES_GOLD == '\n\n'.join([sent.dependencies_string() for sent in doc.sentences]) def test_depparse_with_pretagged_doc(): nlp = stanza.Pipeline(**{'processors': 'depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'depparse_pretagged': True}) doc = CoNLL.conll2doc(input_str=EN_DOC_CONLLU_PRETAGGED) processed_doc = nlp(doc) assert EN_DOC_DEPENDENCY_PARSES_GOLD == '\n\n'.join( [sent.dependencies_string() for sent in processed_doc.sentences]) def test_raises_requirements_exception_if_pretagged_not_passed(): with pytest.raises(PipelineRequirementsException): stanza.Pipeline(**{'processors': 'depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'}) ================================================ FILE: stanza/tests/pipeline/test_english_pipeline.py ================================================ """ Basic testing of the English pipeline """ import pytest import stanza from stanza.utils.conll import CoNLL from stanza.models.common.doc import Document from stanza.tests import * from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu pytestmark =
[pytest.mark.pipeline, pytest.mark.travis] # data for testing EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard." EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."] EN_DOC_TOKENS_GOLD = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() EN_DOC_WORDS_GOLD = """ """.strip() EN_DOC_DEPENDENCY_PARSES_GOLD = """ ('Barack', 4, 'nsubj:pass') ('Obama', 1, 'flat') ('was', 4, 'aux:pass') ('born', 0, 'root') ('in', 6, 'case') ('Hawaii', 4, 'obl') ('.', 4, 'punct') ('He', 3, 'nsubj:pass') ('was', 3, 'aux:pass') ('elected', 0, 'root') ('president', 3, 'xcomp') ('in', 6, 'case') ('2008', 3, 'obl') ('.', 3, 'punct') ('Obama', 2, 'nsubj') ('attended', 0, 'root') ('Harvard', 2, 'obj') ('.', 2, 'punct') """.strip() EN_DOC_CONLLU_GOLD = """ # text = Barack Obama was born in Hawaii. # sent_id = 0 # constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .))) # sentiment = 1 1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON 2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O 4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O 5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O 6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ SpaceAfter=No|start_char=25|end_char=31|ner=S-GPE 7 . . PUNCT . _ 4 punct _ SpacesAfter=\\s\\s|start_char=31|end_char=32|ner=O # text = He was elected president in 2008. # sent_id = 1 # constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (.
.))) # sentiment = 1 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=34|end_char=36|ner=O 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=37|end_char=40|ner=O 3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=41|end_char=48|ner=O 4 president president NOUN NN Number=Sing 3 xcomp _ start_char=49|end_char=58|ner=O 5 in in ADP IN _ 6 case _ start_char=59|end_char=61|ner=O 6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ SpaceAfter=No|start_char=62|end_char=66|ner=S-DATE 7 . . PUNCT . _ 3 punct _ SpacesAfter=\\s\\s|start_char=66|end_char=67|ner=O # text = Obama attended Harvard. # sent_id = 2 # constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .))) # sentiment = 1 1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=69|end_char=74|ner=S-PERSON 2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=75|end_char=83|ner=O 3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ SpaceAfter=No|start_char=84|end_char=91|ner=S-ORG 4 . . PUNCT . _ 2 punct _ SpaceAfter=No|start_char=91|end_char=92|ner=O """.strip() EN_DOC_CONLLU_GOLD_MULTIDOC = """ # text = Barack Obama was born in Hawaii. # sent_id = 0 # constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (.
.))) # sentiment = 1 1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON 2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O 4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O 5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O 6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ SpaceAfter=No|start_char=25|end_char=31|ner=S-GPE 7 . . PUNCT . _ 4 punct _ SpaceAfter=No|start_char=31|end_char=32|ner=O # text = He was elected president in 2008. # sent_id = 1 # constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .))) # sentiment = 1 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=0|end_char=2|ner=O 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=3|end_char=6|ner=O 3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=7|end_char=14|ner=O 4 president president NOUN NN Number=Sing 3 xcomp _ start_char=15|end_char=24|ner=O 5 in in ADP IN _ 6 case _ start_char=25|end_char=27|ner=O 6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ SpaceAfter=No|start_char=28|end_char=32|ner=S-DATE 7 . . PUNCT . _ 3 punct _ SpaceAfter=No|start_char=32|end_char=33|ner=O # text = Obama attended Harvard. # sent_id = 2 # constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .))) # sentiment = 1 1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=5|ner=S-PERSON 2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=6|end_char=14|ner=O 3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ SpaceAfter=No|start_char=15|end_char=22|ner=S-ORG 4 . . PUNCT .
_ 2 punct _ SpaceAfter=No|start_char=22|end_char=23|ner=O """.strip() PRETOKENIZED_TEXT = "Jennifer has lovely blue antennae ." PRETOKENIZED_PIECES = [PRETOKENIZED_TEXT.split()] EXPECTED_TOKENIZED_ONLY_CONLLU = """ # text = Jennifer has lovely blue antennae . # sent_id = 0 1 Jennifer _ _ _ _ 0 _ _ start_char=0|end_char=8 2 has _ _ _ _ 1 _ _ start_char=9|end_char=12 3 lovely _ _ _ _ 2 _ _ start_char=13|end_char=19 4 blue _ _ _ _ 3 _ _ start_char=20|end_char=24 5 antennae _ _ _ _ 4 _ _ start_char=25|end_char=33 6 . _ _ _ _ 5 _ _ SpaceAfter=No|start_char=34|end_char=35 """.strip() EXPECTED_PRETOKENIZED_CONLLU = """ # text = Jennifer has lovely blue antennae . # sent_id = 0 # constituency = (ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ lovely) (JJ blue) (NNS antennae))) (. .))) # sentiment = 2 1 Jennifer Jennifer PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=8|ner=S-PERSON 2 has have VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ start_char=9|end_char=12|ner=O 3 lovely lovely ADJ JJ Degree=Pos 5 amod _ start_char=13|end_char=19|ner=O 4 blue blue ADJ JJ Degree=Pos 5 amod _ start_char=20|end_char=24|ner=O 5 antennae antenna NOUN NNS Number=Plur 2 obj _ start_char=25|end_char=33|ner=O 6 . . PUNCT .
_ 2 punct _ SpaceAfter=No|start_char=34|end_char=35|ner=O """.strip() class TestEnglishPipeline: @pytest.fixture(scope="class") def pipeline(self): return stanza.Pipeline(dir=TEST_MODELS_DIR, download_method=None) @pytest.fixture(scope="class") def pretokenized_pipeline(self): return stanza.Pipeline(dir=TEST_MODELS_DIR, tokenize_pretokenized=True, download_method=None) @pytest.fixture(scope="class") def tokenizer_pipeline(self): return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) @pytest.fixture(scope="class") def processed_doc(self, pipeline): """ Document created by running full English pipeline on a few sentences """ return pipeline(EN_DOC) def test_text(self, processed_doc): assert processed_doc.text == EN_DOC def test_conllu(self, processed_doc): assert "{:C}".format(processed_doc) == EN_DOC_CONLLU_GOLD def test_process_conllu(self, pipeline): """ Process a conllu text directly This can use the pipeline which still uses tokenization, as process_conllu skips the tokenize and mwt processors """ doc = pipeline.process_conllu(EN_DOC_CONLLU_GOLD) result = "{:C}".format(doc) assert result == EN_DOC_CONLLU_GOLD def test_tokens(self, processed_doc): assert "\n\n".join([sent.tokens_string() for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD def test_words(self, processed_doc): assert "\n\n".join([sent.words_string() for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD def test_dependency_parse(self, processed_doc): assert "\n\n".join([sent.dependencies_string() for sent in processed_doc.sentences]) == \ EN_DOC_DEPENDENCY_PARSES_GOLD def test_empty(self, pipeline): # make sure that various models handle the degenerate empty case pipeline("") pipeline("--") def test_bulk_process(self, pipeline): """ Double check that the bulk_process method in Pipeline converts documents as expected """ # it should process strings processed = pipeline.bulk_process(EN_DOCS) assert "\n\n".join(["{:C}".format(doc) for doc in processed]) ==
EN_DOC_CONLLU_GOLD_MULTIDOC # it should pass Documents through successfully docs = [Document([], text=t) for t in EN_DOCS] processed = pipeline.bulk_process(docs) assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC def test_empty_bulk_process(self, pipeline): """ Previously we had a bug where an empty document list would cause a crash """ processed = pipeline.bulk_process([]) assert processed == [] def test_pretokenized(self, pretokenized_pipeline, tokenizer_pipeline): doc = pretokenized_pipeline(PRETOKENIZED_PIECES) conllu = "{:C}".format(doc).strip() assert conllu == EXPECTED_PRETOKENIZED_CONLLU doc = tokenizer_pipeline(PRETOKENIZED_TEXT) conllu = "{:C}".format(doc).strip() assert conllu == EXPECTED_TOKENIZED_ONLY_CONLLU # putting a doc with tokens into the pipeline should also work reparsed = pretokenized_pipeline(doc) conllu = "{:C}".format(reparsed).strip() assert conllu == EXPECTED_PRETOKENIZED_CONLLU def test_bulk_pretokenized(self, pretokenized_pipeline, tokenizer_pipeline): doc = tokenizer_pipeline(PRETOKENIZED_TEXT) conllu = "{:C}".format(doc).strip() assert conllu == EXPECTED_TOKENIZED_ONLY_CONLLU docs = pretokenized_pipeline([doc, doc]) assert len(docs) == 2 for doc in docs: conllu = "{:C}".format(doc).strip() assert conllu == EXPECTED_PRETOKENIZED_CONLLU def test_conll2doc_pretokenized(self, pretokenized_pipeline): doc = CoNLL.conll2doc(input_str=EXPECTED_TOKENIZED_ONLY_CONLLU) # this was bug from version 1.10.1 sent to us from a user # the pretokenized tokenize_processor would try to whitespace tokenize a document # even if the document already had sentences & words & stuff # not only would that be wrong if the text wouldn't whitespace tokenize into the words # (such as with punctuation and SpaceAfter=No), # it wouldn't even work in the case of conll2doc, since the document.text wasn't set docs = pretokenized_pipeline([doc, doc]) assert len(docs) == 2 for doc in docs: conllu = "{:C}".format(doc).strip() assert
conllu == EXPECTED_PRETOKENIZED_CONLLU def test_stream(self, pipeline): """ Test the streaming interface to the Pipeline """ # Test all of the documents in one batch # (the default batch size is significantly more than |EN_DOCS|) processed = [doc for doc in pipeline.stream(EN_DOCS)] assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC # It should also work on an iterator rather than an iterable processed = [doc for doc in pipeline.stream(iter(EN_DOCS))] assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC # Stream one at a time processed = [doc for doc in pipeline.stream(EN_DOCS, batch_size=1)] processed = ["{:C}".format(doc) for doc in processed] assert "\n\n".join(processed) == EN_DOC_CONLLU_GOLD_MULTIDOC @pytest.fixture(scope="class") def processed_multidoc(self, pipeline): """ Document created by running full English pipeline on a few sentences """ docs = [Document([], text=t) for t in EN_DOCS] return pipeline(docs) def test_conllu_multidoc(self, processed_multidoc): assert "\n\n".join(["{:C}".format(doc) for doc in processed_multidoc]) == EN_DOC_CONLLU_GOLD_MULTIDOC def test_tokens_multidoc(self, processed_multidoc): assert "\n\n".join([sent.tokens_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD def test_words_multidoc(self, processed_multidoc): assert "\n\n".join([sent.words_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD def test_sentence_indices_multidoc(self, processed_multidoc): sentences = [sent for doc in processed_multidoc for sent in doc.sentences] for sent_idx, sentence in enumerate(sentences): assert sent_idx == sentence.index def test_dependency_parse_multidoc(self, processed_multidoc): assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == \ EN_DOC_DEPENDENCY_PARSES_GOLD
@pytest.fixture(scope="class") def processed_multidoc_variant(self): """ Document created by running full English pipeline on a few sentences """ docs = [Document([], text=t) for t in EN_DOCS] nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors={'tokenize': 'spacy'}) return nlp(docs) def test_dependency_parse_multidoc_variant(self, processed_multidoc_variant): assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \ EN_DOC_DEPENDENCY_PARSES_GOLD def test_constituency_parser(self): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency") doc = nlp("This is a test") assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))' def test_on_gpu(self, pipeline): """ The default pipeline should have all the models on the GPU """ check_on_gpu(pipeline) def test_on_cpu(self): """ Create a pipeline on the CPU, check that all the models on CPU """ pipeline = stanza.Pipeline("en", dir=TEST_MODELS_DIR, use_gpu=False) check_on_cpu(pipeline) ================================================ FILE: stanza/tests/pipeline/test_french_pipeline.py ================================================ """ Basic testing of French pipeline The benefit of this test is to verify that the bulk processing works for languages with MWT in them """ import pytest import stanza from stanza.models.common.doc import Document from stanza.tests import * from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu pytestmark = pytest.mark.pipeline FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \ "l'Industrie et du Numérique."
EXPECTED_RESULT = """ [ [ { "id": 1, "text": "Alors", "lemma": "alors", "upos": "ADV", "head": 3, "deprel": "advmod", "start_char": 0, "end_char": 5 }, { "id": 2, "text": "encore", "lemma": "encore", "upos": "ADV", "head": 3, "deprel": "advmod", "start_char": 6, "end_char": 12 }, { "id": 3, "text": "inconnu", "lemma": "inconnu", "upos": "ADJ", "feats": "Gender=Masc|Number=Sing", "head": 11, "deprel": "advcl", "start_char": 13, "end_char": 20 }, { "id": [ 4, 5 ], "text": "du", "start_char": 21, "end_char": 23 }, { "id": 4, "text": "de", "lemma": "de", "upos": "ADP", "head": 7, "deprel": "case" }, { "id": 5, "text": "le", "lemma": "le", "upos": "DET", "feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art", "head": 7, "deprel": "det" }, { "id": 6, "text": "grand", "lemma": "grand", "upos": "ADJ", "feats": "Gender=Masc|Number=Sing", "head": 7, "deprel": "amod", "start_char": 24, "end_char": 29 }, { "id": 7, "text": "public", "lemma": "public", "upos": "NOUN", "feats": "Gender=Masc|Number=Sing", "head": 3, "deprel": "obl:arg", "start_char": 30, "end_char": 36, "misc": "SpaceAfter=No" }, { "id": 8, "text": ",", "lemma": ",", "upos": "PUNCT", "head": 3, "deprel": "punct", "start_char": 36, "end_char": 37 }, { "id": 9, "text": "Emmanuel", "lemma": "Emmanuel", "upos": "PROPN", "head": 11, "deprel": "nsubj", "start_char": 38, "end_char": 46 }, { "id": 10, "text": "Macron", "lemma": "Macron", "upos": "PROPN", "head": 9, "deprel": "flat:name", "start_char": 47, "end_char": 53 }, { "id": 11, "text": "devient", "lemma": "devenir", "upos": "VERB", "feats": "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin", "head": 0, "deprel": "root", "start_char": 54, "end_char": 61 }, { "id": 12, "text": "en", "lemma": "en", "upos": "ADP", "head": 13, "deprel": "case", "start_char": 62, "end_char": 64 }, { "id": 13, "text": "2014", "lemma": "2014", "upos": "NUM", "feats": "Number=Plur", "head": 11, "deprel": "obl:mod", "start_char": 65, "end_char": 69 }, { "id": 14, "text": 
"ministre", "lemma": "ministre", "upos": "NOUN", "feats": "Gender=Masc|Number=Sing", "head": 11, "deprel": "xcomp", "start_char": 70, "end_char": 78 }, { "id": 15, "text": "de", "lemma": "de", "upos": "ADP", "head": 17, "deprel": "case", "start_char": 79, "end_char": 81 }, { "id": 16, "text": "l'", "lemma": "le", "upos": "DET", "feats": "Definite=Def|Number=Sing|PronType=Art", "head": 17, "deprel": "det", "start_char": 82, "end_char": 84, "misc": "SpaceAfter=No" }, { "id": 17, "text": "Économie", "lemma": "économie", "upos": "NOUN", "feats": "Gender=Fem|Number=Sing", "head": 14, "deprel": "nmod", "start_char": 84, "end_char": 92, "misc": "SpaceAfter=No" }, { "id": 18, "text": ",", "lemma": ",", "upos": "PUNCT", "head": 21, "deprel": "punct", "start_char": 92, "end_char": 93 }, { "id": 19, "text": "de", "lemma": "de", "upos": "ADP", "head": 21, "deprel": "case", "start_char": 94, "end_char": 96 }, { "id": 20, "text": "l'", "lemma": "le", "upos": "DET", "feats": "Definite=Def|Number=Sing|PronType=Art", "head": 21, "deprel": "det", "start_char": 97, "end_char": 99, "misc": "SpaceAfter=No" }, { "id": 21, "text": "Industrie", "lemma": "industrie", "upos": "NOUN", "feats": "Gender=Fem|Number=Sing", "head": 17, "deprel": "conj", "start_char": 99, "end_char": 108 }, { "id": 22, "text": "et", "lemma": "et", "upos": "CCONJ", "head": 25, "deprel": "cc", "start_char": 109, "end_char": 111 }, { "id": [ 23, 24 ], "text": "du", "start_char": 112, "end_char": 114 }, { "id": 23, "text": "de", "lemma": "de", "upos": "ADP", "head": 25, "deprel": "case" }, { "id": 24, "text": "le", "lemma": "le", "upos": "DET", "feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art", "head": 25, "deprel": "det" }, { "id": 25, "text": "Numérique", "lemma": "numérique", "upos": "NOUN", "feats": "Gender=Masc|Number=Sing", "head": 17, "deprel": "conj", "start_char": 115, "end_char": 124, "misc": "SpaceAfter=No" }, { "id": 26, "text": ".", "lemma": ".", "upos": "PUNCT", "head": 11, "deprel": "punct", 
"start_char": 124, "end_char": 125, "misc": "SpaceAfter=No" } ] ] """ class TestFrenchPipeline: @pytest.fixture(scope="class") def pipeline(self): """ Create a pipeline with French models """ pipeline = stanza.Pipeline(processors='tokenize,mwt,pos,lemma,depparse', dir=TEST_MODELS_DIR, lang='fr') return pipeline def test_single(self, pipeline): doc = pipeline(FR_MWT_SENTENCE) compare_ignoring_whitespace(str(doc), EXPECTED_RESULT) def test_bulk(self, pipeline): NUM_DOCS = 10 raw_text = [FR_MWT_SENTENCE] * NUM_DOCS raw_doc = [Document([], text=doccontent) for doccontent in raw_text] result = pipeline(raw_doc) assert len(result) == NUM_DOCS for doc in result: compare_ignoring_whitespace(str(doc), EXPECTED_RESULT) assert len(doc.sentences) == 1 assert doc.num_words == 26 assert doc.num_tokens == 24 def test_on_gpu(self, pipeline): """ The default pipeline should have all the models on the GPU """ check_on_gpu(pipeline) def test_on_cpu(self): """ Create a pipeline on the CPU, check that all the models on CPU """ pipeline = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, use_gpu=False) check_on_cpu(pipeline) ================================================ FILE: stanza/tests/pipeline/test_lemmatizer.py ================================================ """ Basic testing of lemmatization """ import pytest import stanza from stanza.tests import * from stanza.models.common.doc import TEXT, UPOS, LEMMA pytestmark = pytest.mark.pipeline EN_DOC = "Joe Smith was born in California." EN_DOC_IDENTITY_GOLD = """ Joe Joe Smith Smith was was born born in in California California . . """.strip() EN_DOC_LEMMATIZER_MODEL_GOLD = """ Joe Joe Smith Smith was be born bear in in California California . . 
""".strip() def test_identity_lemmatizer(): nlp = stanza.Pipeline(**{'processors': 'tokenize,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_use_identity': True}, download_method=None) doc = nlp(EN_DOC) word_lemma_pairs = [] for w in doc.iter_words(): word_lemma_pairs += [f"{w.text} {w.lemma}"] assert EN_DOC_IDENTITY_GOLD == "\n".join(word_lemma_pairs) def test_full_lemmatizer(): nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, download_method=None) doc = nlp(EN_DOC) word_lemma_pairs = [] for w in doc.iter_words(): word_lemma_pairs += [f"{w.text} {w.lemma}"] assert EN_DOC_LEMMATIZER_MODEL_GOLD == "\n".join(word_lemma_pairs) def find_unknown_word(lemmatizer, base): for i in range(10): base = base + "z" if base not in lemmatizer.word_dict and all(x[0] != base for x in lemmatizer.composite_dict.keys()): return base raise RuntimeError("wtf?") def test_store_results(): nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, lemma_store_results=True, download_method=None) lemmatizer = nlp.processors["lemma"]._trainer az = find_unknown_word(lemmatizer, "a") bz = find_unknown_word(lemmatizer, "b") cz = find_unknown_word(lemmatizer, "c") # try sentences with the order long, short doc = nlp("I found an " + az + " in my " + bz + ". It was a " + cz) stuff = doc.get([TEXT, UPOS, LEMMA]) assert len(stuff) == 12 assert stuff[3][0] == az assert stuff[6][0] == bz assert stuff[11][0] == cz assert lemmatizer.composite_dict[(az, stuff[3][1])] == stuff[3][2] assert lemmatizer.composite_dict[(bz, stuff[6][1])] == stuff[6][2] assert lemmatizer.composite_dict[(cz, stuff[11][1])] == stuff[11][2] doc2 = nlp("I found an " + az + " in my " + bz + ". 
It was a " + cz) stuff2 = doc2.get([TEXT, UPOS, LEMMA]) assert stuff == stuff2 dz = find_unknown_word(lemmatizer, "d") ez = find_unknown_word(lemmatizer, "e") fz = find_unknown_word(lemmatizer, "f") # try sentences with the order long, short doc = nlp("It was a " + dz + ". I found an " + ez + " in my " + fz) stuff = doc.get([TEXT, UPOS, LEMMA]) assert len(stuff) == 12 assert stuff[3][0] == dz assert stuff[8][0] == ez assert stuff[11][0] == fz assert lemmatizer.composite_dict[(dz, stuff[3][1])] == stuff[3][2] assert lemmatizer.composite_dict[(ez, stuff[8][1])] == stuff[8][2] assert lemmatizer.composite_dict[(fz, stuff[11][1])] == stuff[11][2] doc2 = nlp("It was a " + dz + ". I found an " + ez + " in my " + fz) stuff2 = doc2.get([TEXT, UPOS, LEMMA]) assert stuff == stuff2 assert az not in lemmatizer.word_dict def test_caseless_lemmatizer(): """ Test that setting the lemmatizer as caseless at Pipeline time lowercases the text """ nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None) # the capital letter here should throw off the lemmatizer & it won't remove the plural # although weirdly the current English model *does* lowercase the A doc = nlp("Here is an Excerpt") assert doc.sentences[0].words[-1].lemma == 'excerpt' nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None, lemma_caseless=True) # with the model set to lowercasing, the word will be treated as if it were 'antennae' doc = nlp("Here is an Excerpt") assert doc.sentences[0].words[-1].lemma == 'Excerpt' def test_latin_caseless_lemmatizer(): """ Test the Latin caseless lemmatizer """ nlp = stanza.Pipeline('la', package='ittb', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None) lemmatizer = nlp.processors['lemma'] assert lemmatizer.config['caseless'] doc = nlp("Quod Erat Demonstrandum") expected_lemmas = "qui sum demonstro".split() assert len(doc.sentences) == 1 assert 
len(doc.sentences[0].words) == 3 for word, expected in zip(doc.sentences[0].words, expected_lemmas): assert word.lemma == expected def test_contextual_lemmatizer(): nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, package={"lemma": "default_accurate"}, download_method="reuse_resources") lemmatizer = nlp.processors['lemma']._trainer # the accurate model should have a 's classifier assert len(lemmatizer.contextual_lemmatizers) > 0 # ideally the doc would have 'have' as the lemma for the second # word, but maybe it's not always accurate. actually, it works # fine at the time of this test doc = nlp("He's added a contextual lemmatizer") ================================================ FILE: stanza/tests/pipeline/test_pipeline_constituency_processor.py ================================================ import gc import pytest import stanza from stanza.models.common.foundation_cache import FoundationCache from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # data for testing TEST_TEXT = "This is a test. Another sentence. Are these sorted?" 
TEST_TOKENS = [["This", "is", "a", "test", "."], ["Another", "sentence", "."], ["Are", "these", "sorted", "?"]] @pytest.fixture(scope="module") def foundation_cache(): # the test suite sometimes winds up holding on to GPU memory for too long, # resulting in an OOM error # occasionally calling gc.collect() will help gc.collect() return FoundationCache() def check_results(doc): assert len(doc.sentences) == len(TEST_TOKENS) for sentence, expected in zip(doc.sentences, TEST_TOKENS): assert sentence.constituency.leaf_labels() == expected def test_sorted_big_batch(foundation_cache): pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None) doc = pipe(TEST_TEXT) check_results(doc) def test_comments(foundation_cache): """ Test that the pipeline is creating constituency comments """ pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None) doc = pipe(TEST_TEXT) check_results(doc) for sentence in doc.sentences: assert any(x.startswith("# constituency = ") for x in sentence.comments) doc.sentences[0].constituency = "asdf" assert "# constituency = asdf" in doc.sentences[0].comments for sentence in doc.sentences: assert len([x for x in sentence.comments if x.startswith("# constituency")]) == 1 def test_illegal_batch_size(foundation_cache): stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos", constituency_batch_size="zzz", foundation_cache=foundation_cache, download_method=None) with pytest.raises(ValueError): stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size="zzz", foundation_cache=foundation_cache, download_method=None) def test_sorted_one_batch(foundation_cache): pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=1, 
foundation_cache=foundation_cache, download_method=None) doc = pipe(TEST_TEXT) check_results(doc) def test_sorted_two_batch(foundation_cache): pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=2, foundation_cache=foundation_cache, download_method=None) doc = pipe(TEST_TEXT) check_results(doc) def test_get_constituents(foundation_cache): pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None) assert "SBAR" in pipe.processors["constituency"].get_constituents() ================================================ FILE: stanza/tests/pipeline/test_pipeline_depparse_processor.py ================================================ """ Basic testing of part of speech tagging """ import pytest import stanza from stanza.models.common.vocab import VOCAB_PREFIX from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] class TestClassifier: @pytest.fixture(scope="class") def english_depparse(self): """ Get a depparse_processor for English """ nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'}) assert 'depparse' in nlp.processors return nlp.processors['depparse'] def test_get_known_relations(self, english_depparse): """ Test getting the known relations from a processor. 
Doesn't test that all the relations exist, since who knows what will change in the future """ relations = english_depparse.get_known_relations() assert len(relations) > 5 assert 'case' in relations for i in VOCAB_PREFIX: assert i not in relations ================================================ FILE: stanza/tests/pipeline/test_pipeline_mwt_expander.py ================================================ """ Basic testing of multi-word-token expansion """ import pytest import stanza from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # mwt data for testing FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \ "l'Industrie et du Numérique." FR_MWT_TOKEN_TO_WORDS_GOLD = """ token: Alors words: [] token: encore words: [] token: inconnu words: [] token: du words: [, ] token: grand words: [] token: public words: [] token: , words: [] token: Emmanuel words: [] token: Macron words: [] token: devient words: [] token: en words: [] token: 2014 words: [] token: ministre words: [] token: de words: [] token: l' words: [] token: Économie words: [] token: , words: [] token: de words: [] token: l' words: [] token: Industrie words: [] token: et words: [] token: du words: [, ] token: Numérique words: [] token: . 
words: [] """.strip() FR_MWT_WORD_TO_TOKEN_GOLD = """ word: Alors token parent:1-Alors word: encore token parent:2-encore word: inconnu token parent:3-inconnu word: de token parent:4-5-du word: le token parent:4-5-du word: grand token parent:6-grand word: public token parent:7-public word: , token parent:8-, word: Emmanuel token parent:9-Emmanuel word: Macron token parent:10-Macron word: devient token parent:11-devient word: en token parent:12-en word: 2014 token parent:13-2014 word: ministre token parent:14-ministre word: de token parent:15-de word: l' token parent:16-l' word: Économie token parent:17-Économie word: , token parent:18-, word: de token parent:19-de word: l' token parent:20-l' word: Industrie token parent:21-Industrie word: et token parent:22-et word: de token parent:23-24-du word: le token parent:23-24-du word: Numérique token parent:25-Numérique word: . token parent:26-. """.strip() def test_mwt(): pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='fr', download_method=None) doc = pipeline(FR_MWT_SENTENCE) token_to_words = "\n".join( [f'token: {token.text.ljust(9)}\t\twords: [{", ".join([word.pretty_print() for word in token.words])}]' for sent in doc.sentences for token in sent.tokens] ).strip() word_to_token = "\n".join( [f'word: {word.text.ljust(9)}\t\ttoken parent:{"-".join([str(x) for x in word.parent.id])}-{word.parent.text}' for sent in doc.sentences for word in sent.words]).strip() assert token_to_words == FR_MWT_TOKEN_TO_WORDS_GOLD assert word_to_token == FR_MWT_WORD_TO_TOKEN_GOLD def test_unknown_character(): """ The MWT processor has a mechanism to temporarily add unknown characters to the vocab Here we check that it is properly adding the characters from a test case a user sent us """ pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None) text = "Björkängshallen's" mwt_processor = pipeline.processors["mwt"] trainer = mwt_processor.trainer # verify that 
the test case is still valid # (perhaps an updated MWT model will have all of these characters in the future) assert not all(x in trainer.vocab._unit2id for x in text) doc = pipeline(text) batch = mwt_processor.build_batch(doc) # the vocab used in this batch should have the missing characters assert all(x in batch.vocab._unit2id for x in text) def test_unknown_word(): """ Test a word which wasn't in the MWT training data The seq2seq model for MWT was randomly hallucinating, but with the CharacterClassifier, it should be able to process unusual MWT without hallucinations """ pipe = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None) doc = pipe("I read the newspaper's report.") assert len(doc.sentences) == 1 assert len(doc.sentences[0].tokens) == 6 assert len(doc.sentences[0].tokens[3].words) == 2 assert doc.sentences[0].tokens[3].words[0].text == 'newspaper' # double check that this is something unknown to the model mwt_processor = pipe.processors["mwt"] trainer = mwt_processor.trainer expansion = trainer.dict_expansion("newspaper's") assert expansion is None ================================================ FILE: stanza/tests/pipeline/test_pipeline_ner_processor.py ================================================ import pytest import stanza from stanza.utils.conll import CoNLL from stanza.models.common.doc import Document from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # data for testing EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."] EXPECTED_ENTS = [[{ "text": "Barack Obama", "type": "PERSON", "start_char": 0, "end_char": 12 }, { "text": "Hawaii", "type": "GPE", "start_char": 25, "end_char": 31 }], [{ "text": "2008", "type": "DATE", "start_char": 28, "end_char": 32 }], [{ "text": "Obama", "type": "PERSON", "start_char": 0, "end_char": 5 }, { "text": "Harvard", "type": "ORG", "start_char": 15, "end_char": 22 }]] def 
check_entities_equal(doc, expected): """ Checks that the entities of a doc are equal to the given list of maps """ assert len(doc.ents) == len(expected) for doc_entity, expected_entity in zip(doc.ents, expected): for k in expected_entity: assert getattr(doc_entity, k) == expected_entity[k] class TestNERProcessor: @pytest.fixture(scope="class") def pipeline(self): """ A reusable pipeline with the NER module """ return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,ner") @pytest.fixture(scope="class") def processed_doc(self, pipeline): """ Document created by running full English pipeline on a few sentences """ return [pipeline(text) for text in EN_DOCS] @pytest.fixture(scope="class") def processed_bulk(self, pipeline): """ Document created by running full English pipeline on a few sentences """ docs = [Document([], text=t) for t in EN_DOCS] return pipeline(docs) def test_bulk_ents(self, processed_bulk): assert len(processed_bulk) == len(EXPECTED_ENTS) for doc, expected in zip(processed_bulk, EXPECTED_ENTS): check_entities_equal(doc, expected) def test_ents(self, processed_doc): assert len(processed_doc) == len(EXPECTED_ENTS) for doc, expected in zip(processed_doc, EXPECTED_ENTS): check_entities_equal(doc, expected) EXPECTED_MULTI_ENTS = [{ "text": "John Bauer", "type": "PERSON", "start_char": 0, "end_char": 10 }, { "text": "Stanford", "type": "ORG", "start_char": 20, "end_char": 28 }, { "text": "hip arthritis", "type": "DISEASE", "start_char": 37, "end_char": 50 }, { "text": "Chris Manning", "type": "PERSON", "start_char": 66, "end_char": 79 }] EXPECTED_MULTI_NER = [ [('O', 'B-PERSON'), ('O', 'E-PERSON'), ('O', 'O'), ('O', 'O'), ('O', 'S-ORG'), ('O', 'O'), ('O', 'O'), ('B-DISEASE', 'O'), ('E-DISEASE', 'O'), ('O', 'O')], [('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'B-PERSON'), ('O', 'E-PERSON'),]] class TestMultiNERProcessor: @pytest.fixture(scope="class") def pipeline(self): """ A reusable pipeline with TWO ner models """ return 
stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,ner", package={"ner": ["ncbi_disease", "ontonotes_charlm"]}) def test_multi_example(self, pipeline): doc = pipeline("John Bauer works at Stanford and has hip arthritis. He works for Chris Manning") check_entities_equal(doc, EXPECTED_MULTI_ENTS) def test_multi_ner(self, pipeline): """ Test that multiple NER labels are correctly assigned in tuples """ doc = pipeline("John Bauer works at Stanford and has hip arthritis. He works for Chris Manning") multi_ner = [[token.multi_ner for token in sentence.tokens] for sentence in doc.sentences] assert multi_ner == EXPECTED_MULTI_NER def test_known_tags(self, pipeline): assert pipeline.processors["ner"].get_known_tags() == ["DISEASE"] assert len(pipeline.processors["ner"].get_known_tags(1)) == 18 ================================================ FILE: stanza/tests/pipeline/test_pipeline_pos_processor.py ================================================ """ Basic testing of part of speech tagging """ import pytest import stanza from stanza.tests import * pytestmark = pytest.mark.pipeline EN_DOC = "Joe Smith was born in California." EN_DOC_GOLD = """ ]> ]> ]> ]> ]> ]> ]> """.strip() @pytest.fixture(scope="module") def pos_pipeline(): return stanza.Pipeline(**{'processors': 'tokenize,pos', 'dir': TEST_MODELS_DIR, 'download_method': None, 'lang': 'en'}) def test_part_of_speech(pos_pipeline): doc = pos_pipeline(EN_DOC) assert EN_DOC_GOLD == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) def test_get_known_xpos(pos_pipeline): tags = pos_pipeline.processors['pos'].get_known_xpos() # make sure we have xpos... assert 'DT' in tags # ... and not upos assert 'DET' not in tags def test_get_known_upos(pos_pipeline): tags = pos_pipeline.processors['pos'].get_known_upos() # make sure we have upos... assert 'DET' in tags # ... 
and not xpos assert 'DT' not in tags def test_get_known_feats(pos_pipeline): feats = pos_pipeline.processors['pos'].get_known_feats() # I appreciate how self-referential the Abbr feat is assert 'Abbr' in feats assert 'Yes' in feats['Abbr'] ================================================ FILE: stanza/tests/pipeline/test_pipeline_sentiment_processor.py ================================================ import gc import pytest import stanza from stanza.utils.conll import CoNLL from stanza.models.common.doc import Document from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] # data for testing EN_DOCS = ["Ragavan is terrible and should go away.", "Today is okay.", "Urza's Saga is great."] EN_DOC = " ".join(EN_DOCS) EXPECTED = [0, 1, 2] class TestSentimentPipeline: @pytest.fixture(scope="class") def pipeline(self): """ A reusable pipeline with the NER module """ gc.collect() return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,sentiment") def test_simple(self, pipeline): results = [] for text in EN_DOCS: doc = pipeline(text) assert len(doc.sentences) == 1 results.append(doc.sentences[0].sentiment) assert EXPECTED == results def test_multiple_sentences(self, pipeline): doc = pipeline(EN_DOC) assert len(doc.sentences) == 3 results = [sentence.sentiment for sentence in doc.sentences] assert EXPECTED == results def test_empty_text(self, pipeline): """ Test empty text and a text which might get reduced to empty text by removing dashes """ doc = pipeline("") assert len(doc.sentences) == 0 doc = pipeline("--") assert len(doc.sentences) == 1 ================================================ FILE: stanza/tests/pipeline/test_requirements.py ================================================ """ Test the requirements functionality for processors """ import pytest import stanza from stanza.pipeline.core import PipelineRequirementsException from stanza.pipeline.processor import ProcessorRequirementsException from stanza.tests import * pytestmark 
= pytest.mark.pipeline def check_exception_vals(req_exception, req_exception_vals): """ Check the values of a ProcessorRequirementsException against a dict of expected values. :param req_exception: the ProcessorRequirementsException to evaluate :param req_exception_vals: expected values for the ProcessorRequirementsException :return: None """ assert isinstance(req_exception, ProcessorRequirementsException) assert req_exception.processor_type == req_exception_vals['processor_type'] assert req_exception.processors_list == req_exception_vals['processors_list'] assert req_exception.err_processor.requires == req_exception_vals['requires'] def test_missing_requirements(): """ Try to build several pipelines with bad configs and check thrown exceptions against gold exceptions. :return: None """ # list of (bad configs, list of gold ProcessorRequirementsExceptions that should be thrown) pairs bad_config_lists = [ # missing tokenize ( # input config {'processors': 'pos,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, # 2 expected exceptions [ {'processor_type': 'POSProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]), 'requires': set(['tokenize'])}, {'processor_type': 'DepparseProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]), 'requires': set(['tokenize','pos', 'lemma'])} ] ), # no pos when lemma_pos set to True; for english mwt should not be included in the loaded processor list ( # input config {'processors': 'tokenize,mwt,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_pos': True}, # 1 expected exception [ {'processor_type': 'LemmaProcessor', 'processors_list': ['tokenize', 'mwt', 'lemma'], 'provided_reqs': set(['tokenize', 'mwt']), 'requires': set(['tokenize', 'pos'])} ] ) ] # try to build each bad config, catch exceptions, check against gold pipeline_fails = 0 for bad_config, gold_exceptions in bad_config_lists: try: stanza.Pipeline(**bad_config) except PipelineRequirementsException as e: pipeline_fails += 1 assert 
isinstance(e, PipelineRequirementsException) assert len(e.processor_req_fails) == len(gold_exceptions) for processor_req_e, gold_exception in zip(e.processor_req_fails,gold_exceptions): # compare the thrown ProcessorRequirementsExceptions against gold check_exception_vals(processor_req_e, gold_exception) # check pipeline building failed twice assert pipeline_fails == 2 ================================================ FILE: stanza/tests/pipeline/test_tokenizer.py ================================================ """ Basic testing of tokenization """ import pytest import stanza from stanza.tests import * pytestmark = pytest.mark.pipeline EN_DOC = "Joe Smith lives in California. Joe's favorite food is pizza. He enjoys going to the beach." EN_DOC_WITH_EXTRA_WHITESPACE = "Joe Smith \n lives in\n California. Joe's favorite food \tis pizza. \t\t\tHe enjoys \t\tgoing to the beach." EN_DOC_GOLD_TOKENS = """ ]> ]> ]> ]> ]> ]> , ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() # spaCy doesn't have MWT EN_DOC_SPACY_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() EN_DOC_POSTPROCESSOR_TOKENS_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], [("Joe's", True), 'favorite', 'food', 'is', 'pizza', '.'], ['He', 'enjoys', 'going', 'to', 'the', 'beach', '.']] EN_DOC_POSTPROCESSOR_COMBINED_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], ['Joe', "'s", 'favorite', 'food', 'is', 'pizza', '.'], ['He', 'enjoys', 'going', "to the beach", '.']] EN_DOC_POSTPROCESSOR_COMBINED_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """ # ensure that the entry above has spaces somewhere to test that spaces work in between tokens EN_DOC_GOLD_NOSSPLIT_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() EN_DOC_PRETOKENIZED = \ "Joe Smith lives in California .\nJoe's favorite food is pizza .\n\nHe enjoys going to the beach.\n" EN_DOC_PRETOKENIZED_GOLD_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> 
]> ]> ]> ]> ]> ]> """.strip() EN_DOC_PRETOKENIZED_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], ['He', 'loves', 'pizza', '.']] EN_DOC_PRETOKENIZED_LIST_GOLD_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() EN_DOC_NO_SSPLIT = ["This is a sentence. This is another.", "This is a third."] EN_DOC_NO_SSPLIT_SENTENCES = [['This', 'is', 'a', 'sentence', '.', 'This', 'is', 'another', '.'], ['This', 'is', 'a', 'third', '.']] FR_DOC = "Le prince va manger du poulet aux les magasins aujourd'hui." FR_DOC_POSTPROCESSOR_TOKENS_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', "aujourd'hui", '.']] FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', ("aujourd'hui", ["aujourd'", "hui"]), '.']] FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS = """ ]> ]> ]> ]> , ]> ]> , ]> ]> ]> , ]> ]> """ JA_DOC = "北京は中国の首都です。 北京の人口は2152万人です。\n" # add some random whitespaces that need to be skipped JA_DOC_GOLD_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() JA_DOC_GOLD_NOSSPLIT_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() ZH_DOC = "北京是中国的首都。 北京有2100万人口,是一个直辖市。\n" ZH_DOC1 = "北\n京是中\n国的首\n都。 北京有2100万人口,是一个直辖市。\n" ZH_DOC2 = "北\n京是中\n国的首\n都。\n\n 北京有2100万人口,是一个直辖市。\n" ZH_DOC_GOLD_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() ZH_DOC1_GOLD_TOKENS=""" ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() ZH_DOC_GOLD_NOSSPLIT_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() ZH_PARENS_DOC = "我们一起学(猫叫)" TH_DOC = "ข้าราชการได้รับการหมุนเวียนเป็นระยะ และเขาได้รับมอบหมายให้ประจำในระดับภูมิภาค" TH_DOC_GOLD_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() TH_DOC_GOLD_NOSSPLIT_TOKENS = """ ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> ]> """.strip() @pytest.fixture(scope="module") def basic_pipeline(): """ Create a pipeline with a basic English tokenizer """ nlp = 
stanza.Pipeline(processors='tokenize', dir=TEST_MODELS_DIR, lang='en', download_method=None) return nlp @pytest.fixture(scope="module") def pretokenized_pipeline(): """ Create a pipeline with a basic English pretokenized tokenizer """ nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'tokenize_pretokenized': True, 'download_method': None}) return nlp @pytest.fixture(scope="module") def zh_pipeline(): """ Create a pipeline with a basic Chinese tokenizer """ nlp = stanza.Pipeline(lang='zh', processors='tokenize', dir=TEST_MODELS_DIR, download_method=None) return nlp def test_tokenize(basic_pipeline): doc = basic_pipeline(EN_DOC) assert EN_DOC_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_tokenize_ssplit_robustness(basic_pipeline): doc = basic_pipeline(EN_DOC_WITH_EXTRA_WHITESPACE) assert EN_DOC_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_pretokenized(pretokenized_pipeline): doc = pretokenized_pipeline(EN_DOC_PRETOKENIZED) assert EN_DOC_PRETOKENIZED_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) doc = pretokenized_pipeline(EN_DOC_PRETOKENIZED_LIST) assert EN_DOC_PRETOKENIZED_LIST_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_pretokenized_multidoc(pretokenized_pipeline): doc = pretokenized_pipeline(EN_DOC_PRETOKENIZED) assert EN_DOC_PRETOKENIZED_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() 
for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) doc = pretokenized_pipeline([stanza.Document([], text=EN_DOC_PRETOKENIZED_LIST)])[0] assert EN_DOC_PRETOKENIZED_LIST_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_postprocessor(): def dummy_postprocessor(in_doc): # Importantly, EN_DOC_POSTPROCESSOR_COMBINED_LIST returns a few tokens joinde # with space. As some languages (such as VN) contains tokens with space in between # its important to have joined space tested as one of the tokens assert in_doc == EN_DOC_POSTPROCESSOR_TOKENS_LIST return EN_DOC_POSTPROCESSOR_COMBINED_LIST nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'download_method': None, 'tokenize_postprocessor': dummy_postprocessor}) doc = nlp(EN_DOC) assert EN_DOC_POSTPROCESSOR_COMBINED_TOKENS.strip() == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]).strip() def test_postprocessor_mwt(): def dummy_postprocessor(input): # Importantly, EN_DOC_POSTPROCESSOR_COMBINED_LIST returns a few tokens joinde # with space. 
As some languages (such as VN) contains tokens with space in between # its important to have joined space tested as one of the tokens assert input == FR_DOC_POSTPROCESSOR_TOKENS_LIST return FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'fr', 'download_method': None, 'tokenize_postprocessor': dummy_postprocessor}) doc = nlp(FR_DOC) assert FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS.strip() == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]).strip() def test_postprocessor_typeerror(): with pytest.raises(ValueError): nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'download_method': None, 'tokenize_postprocessor': "iamachicken"}) def test_no_ssplit(): nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'download_method': None, 'tokenize_no_ssplit': True}) doc = nlp(EN_DOC_NO_SSPLIT) assert EN_DOC_NO_SSPLIT_SENTENCES == [[w.text for w in s.words] for s in doc.sentences] assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_zh_tokenizer_skip_newline(zh_pipeline): doc = zh_pipeline(ZH_DOC1) assert ZH_DOC1_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char].replace('\n', '') == token.text for sent in doc.sentences for token in sent.tokens]) def test_zh_tokenizer_skip_newline_offsets(zh_pipeline): doc = zh_pipeline(ZH_DOC2) assert ZH_DOC1_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char].replace('\n', '') == token.text for sent in doc.sentences for token in sent.tokens]) def test_zh_tokenizer_parens(zh_pipeline): """ The original fix for newlines in Chinese text broke () in Chinese text """ doc = zh_pipeline(ZH_PARENS_DOC) # ... 
the results are kind of bad for this expression, so no testing of the results yet #assert ZH_PARENS_DOC_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) def test_spacy(): nlp = stanza.Pipeline(processors='tokenize', dir=TEST_MODELS_DIR, lang='en', tokenize_with_spacy=True, download_method=None) doc = nlp(EN_DOC) # make sure the loaded tokenizer is actually spacy assert "SpacyTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ assert EN_DOC_SPACY_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_spacy_no_ssplit(): nlp = stanza.Pipeline(processors='tokenize', dir=TEST_MODELS_DIR, lang='en', tokenize_with_spacy=True, tokenize_no_ssplit=True, download_method=None) doc = nlp(EN_DOC) # make sure the loaded tokenizer is actually spacy assert "SpacyTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ assert EN_DOC_GOLD_NOSSPLIT_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_sudachipy(): nlp = stanza.Pipeline(lang='ja', dir=TEST_MODELS_DIR, processors={'tokenize': 'sudachipy'}, package=None, download_method=None) doc = nlp(JA_DOC) assert "SudachiPyTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ assert JA_DOC_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_sudachipy_no_ssplit(): nlp = stanza.Pipeline(lang='ja', dir=TEST_MODELS_DIR, processors={'tokenize': 'sudachipy'}, tokenize_no_ssplit=True, package=None, download_method=None) doc = nlp(JA_DOC) assert "SudachiPyTokenizer" == 
nlp.processors['tokenize']._variant.__class__.__name__ assert JA_DOC_GOLD_NOSSPLIT_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_jieba(): nlp = stanza.Pipeline(lang='zh', dir=TEST_MODELS_DIR, processors={'tokenize': 'jieba'}, package=None, download_method=None) doc = nlp(ZH_DOC) assert "JiebaTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ assert ZH_DOC_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_jieba_no_ssplit(): nlp = stanza.Pipeline(lang='zh', dir=TEST_MODELS_DIR, processors={'tokenize': 'jieba'}, tokenize_no_ssplit=True, package=None, download_method=None) doc = nlp(ZH_DOC) assert "JiebaTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ assert ZH_DOC_GOLD_NOSSPLIT_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_pythainlp(): nlp = stanza.Pipeline(lang='th', dir=TEST_MODELS_DIR, processors={'tokenize': 'pythainlp'}, package=None, download_method=None) doc = nlp(TH_DOC) assert "PyThaiNLPTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ assert TH_DOC_GOLD_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) def test_pythainlp_no_ssplit(): nlp = stanza.Pipeline(lang='th', dir=TEST_MODELS_DIR, processors={'tokenize': 'pythainlp'}, tokenize_no_ssplit=True, package=None, download_method=None) doc = nlp(TH_DOC) assert "PyThaiNLPTokenizer" == nlp.processors['tokenize']._variant.__class__.__name__ 
assert TH_DOC_GOLD_NOSSPLIT_TOKENS == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]) assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens]) ================================================ FILE: stanza/tests/pos/__init__.py ================================================ ================================================ FILE: stanza/tests/pos/test_data.py ================================================ """ A few tests of specific operations from the Dataset """ import os import pytest from stanza.models.common.doc import * from stanza.models import tagger from stanza.models.pos.data import Dataset, ShuffledDataset from stanza.utils.conll import CoNLL from stanza.tests.pos.test_tagger import TRAIN_DATA, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_FEATS def test_basic_reading(): """ Test that a dataset with no xpos is detected by the Dataset """ # empty args for building the data object args = tagger.parse_args(args=[]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA) data = Dataset(train_doc, args, None) assert data.has_upos assert data.has_xpos assert data.has_feats def test_no_xpos(): """ Test that a dataset with no xpos is detected by the Dataset """ # empty args for building the data object args = tagger.parse_args(args=[]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA_NO_XPOS) data = Dataset(train_doc, args, None) assert data.has_upos assert not data.has_xpos assert data.has_feats def test_no_upos(): """ Test that a dataset with no upos is detected by the Dataset """ # empty args for building the data object args = tagger.parse_args(args=[]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA_NO_UPOS) data = Dataset(train_doc, args, None) assert not data.has_upos assert data.has_xpos assert data.has_feats def test_no_feats(): """ Test that a dataset with no feats is detected by the Dataset """ # empty args for building the data object args = 
tagger.parse_args(args=[]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA_NO_FEATS) data = Dataset(train_doc, args, None) assert data.has_upos assert data.has_xpos assert not data.has_feats def test_no_augment(): """ Test that with no punct removing augmentation, the doc always has punct at the end """ args = tagger.parse_args(args=["--shorthand", "en_test", "--augment_nopunct", "0.0"]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA) data = Dataset(train_doc, args, None) data = data.to_loader(batch_size=2) for i in range(50): for batch in data: for text in batch.text: assert text[-1] in (".", "!") def test_augment(): """ Test that with 100% punct removing augmentation, the doc never has punct at the end """ args = tagger.parse_args(args=["--shorthand", "en_test", "--augment_nopunct", "1.0"]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA) data = Dataset(train_doc, args, None) data = data.to_loader(batch_size=2) for i in range(50): for batch in data: for text in batch.text: assert text[-1] not in (".", "!") def test_sometimes_augment(): """ Test 50% punct removing augmentation With this frequency, we should get a reasonable number of docs with a punct at the end and a reasonable without. 
""" args = tagger.parse_args(args=["--shorthand", "en_test", "--augment_nopunct", "0.5"]) train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA) data = Dataset(train_doc, args, None) data = data.to_loader(batch_size=2) count_with = 0 count_without = 0 for i in range(50): for batch in data: for text in batch.text: if text[-1] in (".", "!"): count_with += 1 else: count_without += 1 # this should never happen # literally less than 1 in 10^20th odds assert count_with > 5 assert count_without > 5 NO_XPOS_TEMPLATE = """ # text = Noxpos {indexp} # sent_id = {index} 1 Noxpos noxpos NOUN _ Number=Sing 0 root _ start_char=0|end_char=8|ner=O 2 {indexp} {indexp} NUM _ NumForm=Digit|NumType=Card 1 dep _ start_char=9|end_char=10|ner=S-CARDINAL """.strip() YES_XPOS_TEMPLATE = """ # text = Yesxpos {indexp} # sent_id = {index} 1 Yesxpos yesxpos NOUN NN Number=Sing 0 root _ start_char=0|end_char=8|ner=O 2 {indexp} {indexp} NUM CD NumForm=Digit|NumType=Card 1 dep _ start_char=9|end_char=10|ner=S-CARDINAL """.strip() def test_shuffle(tmp_path): args = tagger.parse_args(args=["--batch_size", "10", "--shorthand", "en_test", "--augment_nopunct", "0.0"]) # 100 looked nice but was actually a 1/1000000 chance of the test failing # so let's crank it up to 1000 and make it 1/10^58 no_xpos = [NO_XPOS_TEMPLATE.format(index=idx, indexp=idx+1) for idx in range(1000)] no_doc = CoNLL.conll2doc(input_str="\n\n".join(no_xpos)) no_data = Dataset(no_doc, args, None) yes_xpos = [YES_XPOS_TEMPLATE.format(index=idx, indexp=idx+101) for idx in range(1000)] yes_doc = CoNLL.conll2doc(input_str="\n\n".join(yes_xpos)) yes_data = Dataset(yes_doc, args, None) shuffled = ShuffledDataset([no_data, yes_data], 10) assert sum(1 for _ in shuffled) == 200 num_with = 0 num_without = 0 for batch in shuffled: if batch.xpos is not None: num_with += 1 else: num_without += 1 # at the halfway point of the iteration, there should be at # least one in each category # for example, if we had forgotten to shuffle, this assertion would 
fail if num_with + num_without == 100: assert num_with > 1 assert num_without > 1 assert num_with == 100 assert num_without == 100 EWT_SAMPLE = """ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048 # text = Bush asked for permission to go to Alabama to work on a Senate campaign. 1 Bush Bush PROPN NNP Number=Sing 2 nsubj 2:nsubj _ 2 asked ask VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 3 for for ADP IN _ 4 case 4:case _ 4 permission permission NOUN NN Number=Sing 2 obl 2:obl:for _ 5 to to PART TO _ 6 mark 6:mark _ 6 go go VERB VB VerbForm=Inf 4 acl 4:acl:to _ 7 to to ADP IN _ 8 case 8:case _ 8 Alabama Alabama PROPN NNP Number=Sing 6 obl 6:obl:to _ 9 to to PART TO _ 10 mark 10:mark _ 10 work work VERB VB VerbForm=Inf 6 advcl 6:advcl:to _ 11 on on ADP IN _ 14 case 14:case _ 12 a a DET DT Definite=Ind|PronType=Art 14 det 14:det _ 13 Senate Senate PROPN NNP Number=Sing 14 compound 14:compound _ 14 campaign campaign NOUN NN Number=Sing 10 obl 10:obl:on SpaceAfter=No 15 . . PUNCT . _ 2 punct 2:punct _ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049 # text = His superior officers said OK. 1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nmod:poss 3:nmod:poss _ 2 superior superior ADJ JJ Degree=Pos 3 amod 3:amod _ 3 officers officer NOUN NNS Number=Plur 4 nsubj 4:nsubj _ 4 said say VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 5 OK ok INTJ UH _ 4 obj 4:obj SpaceAfter=No 6 . . PUNCT . _ 4 punct 4:punct _ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0053 # text = In ’72 or ’73, if you were a pilot, active or Guard, and you had an obligation and wanted to get out, no problem. 
1 In in ADP IN _ 2 case 2:case _ 2 ’72 '72 NUM CD NumForm=Digit|NumType=Card 10 obl 10:obl:in _ 3 or or CCONJ CC _ 4 cc 4:cc _ 4 ’73 '73 NUM CD NumForm=Digit|NumType=Card 2 conj 2:conj:or|10:obl:in SpaceAfter=No 5 , , PUNCT , _ 2 punct 2:punct _ 6 if if SCONJ IN _ 10 mark 10:mark _ 7 you you PRON PRP Case=Nom|Person=2|PronType=Prs 10 nsubj 10:nsubj _ 8 were be AUX VBD Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin 10 cop 10:cop _ 9 a a DET DT Definite=Ind|PronType=Art 10 det 10:det _ 10 pilot pilot NOUN NN Number=Sing 28 advcl 28:advcl:if SpaceAfter=No 11 , , PUNCT , _ 12 punct 12:punct _ 12 active active ADJ JJ Degree=Pos 10 amod 10:amod _ 13 or or CCONJ CC _ 14 cc 14:cc _ 14 Guard Guard PROPN NNP Number=Sing 12 conj 10:amod|12:conj:or SpaceAfter=No 15 , , PUNCT , _ 18 punct 18:punct _ 16 and and CCONJ CC _ 18 cc 18:cc _ 17 you you PRON PRP Case=Nom|Person=2|PronType=Prs 18 nsubj 18:nsubj|22:nsubj|24:nsubj:xsubj _ 18 had have VERB VBD Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin 10 conj 10:conj:and|28:advcl:if _ 19 an a DET DT Definite=Ind|PronType=Art 20 det 20:det _ 20 obligation obligation NOUN NN Number=Sing 18 obj 18:obj _ 21 and and CCONJ CC _ 22 cc 22:cc _ 22 wanted want VERB VBD Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin 18 conj 18:conj:and _ 23 to to PART TO _ 24 mark 24:mark _ 24 get get VERB VB VerbForm=Inf 22 xcomp 22:xcomp _ 25 out out ADV RB _ 24 advmod 24:advmod SpaceAfter=No 26 , , PUNCT , _ 10 punct 10:punct _ 27 no no DET DT PronType=Neg 28 det 28:det _ 28 problem problem NOUN NN Number=Sing 0 root 0:root SpaceAfter=No 29 . . PUNCT . 
_ 28 punct 28:punct _ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0054 # text = In fact, you were helping them solve their problem.” 1 In in ADP IN _ 2 case 2:case _ 2 fact fact NOUN NN Number=Sing 6 obl 6:obl:in SpaceAfter=No 3 , , PUNCT , _ 2 punct 2:punct _ 4 you you PRON PRP Case=Nom|Person=2|PronType=Prs 6 nsubj 6:nsubj _ 5 were be AUX VBD Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 6 helping help VERB VBG Tense=Pres|VerbForm=Part 0 root 0:root _ 7 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 6 obj 6:obj|8:nsubj:xsubj _ 8 solve solve VERB VB VerbForm=Inf 6 xcomp 6:xcomp _ 9 their their PRON PRP$ Case=Gen|Number=Plur|Person=3|Poss=Yes|PronType=Prs 10 nmod:poss 10:nmod:poss _ 10 problem problem NOUN NN Number=Sing 8 obj 8:obj SpaceAfter=No 11 . . PUNCT . _ 6 punct 6:punct SpaceAfter=No 12 ” " PUNCT '' _ 6 punct 6:punct _ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0055 # text = So Bush stopped flying. 1 So so ADV RB _ 3 advmod 3:advmod _ 2 Bush Bush PROPN NNP Number=Sing 3 nsubj 3:nsubj|4:nsubj:xsubj _ 3 stopped stop VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 4 flying fly VERB VBG VerbForm=Ger 3 xcomp 3:xcomp SpaceAfter=No 5 . . PUNCT . 
_ 3 punct 3:punct _ """.lstrip() def test_length_limited_dataloader(): sample = CoNLL.conll2doc(input_str=EWT_SAMPLE) args = tagger.parse_args(args=["--batch_size", "10", "--shorthand", "en_test", "--augment_nopunct", "0.0"]) data = Dataset(sample, args, None) # this should read the whole dataset dl = data.to_length_limited_loader(5, 1000) batches = [batch.idx for batch in dl] assert batches == [(0, 1, 2, 3, 4)] dl = data.to_length_limited_loader(4, 1000) batches = [batch.idx for batch in dl] assert batches == [(0, 1, 2, 3), (4,)] dl = data.to_length_limited_loader(2, 1000) batches = [batch.idx for batch in dl] assert batches == [(0, 1), (2, 3), (4,)] # the first three sentences should reach this limit dl = data.to_length_limited_loader(5, 55) batches = [batch.idx for batch in dl] assert batches == [(0, 1, 2), (3, 4)] # the third sentence (2) is already past this limit by itself dl = data.to_length_limited_loader(5, 25) batches = [batch.idx for batch in dl] assert batches == [(0, 1), (2,), (3, 4)] EWT_PUNCT_SAMPLE = """ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048 # text = Bush asked for permission to go to Alabama to work on a Senate campaign. 1 Bush Bush PROPN NNP Number=Sing 2 nsubj 2:nsubj _ 2 asked ask VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 3 for for ADP IN _ 4 case 4:case _ 4 permission permission NOUN NN Number=Sing 2 obl 2:obl:for _ 5 to to PART TO _ 6 mark 6:mark _ 6 go go VERB VB VerbForm=Inf 4 acl 4:acl:to _ 7 to to ADP IN _ 8 case 8:case _ 8 Alabama Alabama PROPN NNP Number=Sing 6 obl 6:obl:to _ 9 to to PART TO _ 10 mark 10:mark _ 10 work work VERB VB VerbForm=Inf 6 advcl 6:advcl:to _ 11 on on ADP IN _ 14 case 14:case _ 12 a a DET DT Definite=Ind|PronType=Art 14 det 14:det _ 13 Senate Senate PROPN NNP Number=Sing 14 compound 14:compound _ 14 campaign campaign NOUN NN Number=Sing 10 obl 10:obl:on SpaceAfter=No 15 !!!!! ! PUNCT . 
_ 2 punct 2:punct _ # sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049 # text = His superior officers said OK. 1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nmod:poss 3:nmod:poss _ 2 superior superior ADJ JJ Degree=Pos 3 amod 3:amod _ 3 officers officer NOUN NNS Number=Plur 4 nsubj 4:nsubj _ 4 said say VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 5 OK ok INTJ UH _ 4 obj 4:obj SpaceAfter=No 6 ????? ? PUNCT . _ 4 punct 4:punct _ """ def test_punct_simplification(): """ Test a punctuation simplification that should make it so unexpected question/exclamation marks types are processed into ? and ! """ sample = CoNLL.conll2doc(input_str=EWT_PUNCT_SAMPLE) args = tagger.parse_args(args=["--batch_size", "10", "--shorthand", "en_test", "--augment_nopunct", "0.0"]) data = Dataset(sample, args, None) dl = data.to_length_limited_loader(2, 1000) batches = [batch for batch in dl] batch_idx = [batch.idx for batch in batches] assert batch_idx == [(0, 1)] assert batches[0].text[0][-1] == '!' assert batches[0].text[1][-1] == '?' 
assert batches[0].text[0] == ['Bush', 'asked', 'for', 'permission', 'to', 'go', 'to', 'Alabama', 'to', 'work', 'on', 'a', 'Senate', 'campaign', '!'] assert batches[0].text[1] == ['His', 'superior', 'officers', 'said', 'OK', '?'] ================================================ FILE: stanza/tests/pos/test_tagger.py ================================================ """ Run the tagger for a couple iterations on some fake data Uses a couple sentences of UD_English-EWT as training/dev data """ import os import pytest import torch import stanza from stanza.models import tagger from stanza.models.common import pretrain from stanza.models.pos.trainer import Trainer from stanza.tests import TEST_WORKING_DIR, TEST_MODELS_DIR from stanza.utils.training.common import choose_pos_charlm, build_charlm_args pytestmark = [pytest.mark.pipeline, pytest.mark.travis] TRAIN_DATA = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003 # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad. 
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No 2 : : PUNCT : _ 1 punct 1:punct _ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _ 6 that that SCONJ IN _ 9 mark 9:mark _ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _ 8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _ 11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _ 15 in in ADP IN _ 16 case 16:case _ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No 17 . . PUNCT . _ 1 punct 1:punct _ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004 # text = Two of them were being run by 2 officials of the Ministry of the Interior! 
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _ 2 of of ADP IN _ 3 case 3:case _ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ 7 by by ADP IN _ 9 case 9:case _ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _ 10 of of ADP IN _ 12 case 12:case _ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _ 13 of of ADP IN _ 15 case 15:case _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No 16 ! ! PUNCT . _ 6 punct 6:punct _ """.lstrip() TRAIN_DATA_2 = """ # sent_id = 11 # text = It's all hers! # previous = Which person owns this? # comment = predeterminer modifier 1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _ 3 all all DET DT Case=Nom 4 det:predet _ _ 4 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No 5 ! ! PUNCT . _ 4 punct _ _ """.lstrip() TRAIN_DATA_NO_UPOS = """ # sent_id = 11 # text = It's all hers! # previous = Which person owns this? # comment = predeterminer modifier 1 It it _ PRP Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No 2 's be _ VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _ 3 all all _ DT Case=Nom 4 det:predet _ _ 4 hers hers _ PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No 5 ! ! _ . _ 4 punct _ _ """.lstrip() TRAIN_DATA_NO_XPOS = """ # sent_id = 11 # text = It's all hers! # previous = Which person owns this? 
# comment = predeterminer modifier 1 It it PRON _ Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No 2 's be AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _ 3 all all DET _ Case=Nom 4 det:predet _ _ 4 hers hers PRON _ Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No 5 ! ! PUNCT _ _ 4 punct _ _ """.lstrip() TRAIN_DATA_NO_FEATS = """ # sent_id = 11 # text = It's all hers! # previous = Which person owns this? # comment = predeterminer modifier 1 It it PRON PRP _ 4 nsubj _ SpaceAfter=No 2 's be AUX VBZ _ 4 cop _ _ 3 all all DET DT _ 4 det:predet _ _ 4 hers hers PRON PRP _ 0 root _ SpaceAfter=No 5 ! ! PUNCT . _ 4 punct _ _ """.lstrip() DEV_DATA = """ 1 From from ADP IN _ 3 case 3:case _ 2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _ 3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _ 4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _ 5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _ 6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _ 7 : : PUNCT : _ 4 punct 4:punct _ """.lstrip() class TestTagger: @pytest.fixture(scope="class") def wordvec_pretrain_file(self): return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' @pytest.fixture(scope="class") def charlm_args(self): charlm = choose_pos_charlm("en", "test", "default") charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR) return charlm_args def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None): """ Run the training for a few iterations, load & return the model """ dev_file = str(tmp_path / "dev.conllu") pred_file = str(tmp_path / "pred.conllu") save_name = "test_tagger.pt" save_file = str(tmp_path / save_name) if isinstance(train_text, str): train_text = [train_text] train_files = [] for idx, train_blob in enumerate(train_text): train_file = str(tmp_path / ("train_%d.conllu" % idx)) with open(train_file, "w", encoding="utf-8") as fout: 
fout.write(train_blob) train_files.append(train_file) train_file = ";".join(train_files) with open(dev_file, "w", encoding="utf-8") as fout: fout.write(dev_text) args = ["--wordvec_pretrain_file", wordvec_pretrain_file, "--train_file", train_file, "--eval_file", dev_file, "--output_file", pred_file, "--log_step", "10", "--eval_interval", "20", "--max_steps", "100", "--shorthand", "en_test", "--save_dir", str(tmp_path), "--save_name", save_name, "--lang", "en"] if not augment_nopunct: args.extend(["--augment_nopunct", "0.0"]) if extra_args is not None: args = args + extra_args tagger.main(args) assert os.path.exists(save_file) pt = pretrain.Pretrain(wordvec_pretrain_file) saved_model = Trainer(pretrain=pt, model_file=save_file) return saved_model def test_train(self, tmp_path, wordvec_pretrain_file, augment_nopunct=True): """ Simple test of a few 'epochs' of tagger training """ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA) def test_vocab_cutoff(self, tmp_path, wordvec_pretrain_file): """ Test that the vocab cutoff leaves words we expect in the vocab, but not rare words """ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=["--word_cutoff", "3"]) word_vocab = trainer.vocab['word'] assert 'of' in word_vocab assert 'officials' in TRAIN_DATA assert 'officials' not in word_vocab def test_multiple_files(self, tmp_path, wordvec_pretrain_file): """ Test that multiple train files works Checks for evidence of it working by looking for words from the second file in the vocab """ trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA, TRAIN_DATA_2 * 3], DEV_DATA, extra_args=["--word_cutoff", "3"]) word_vocab = trainer.vocab['word'] assert 'of' in word_vocab assert 'officials' in TRAIN_DATA assert 'officials' not in word_vocab assert ' hers ' not in TRAIN_DATA assert ' hers ' in TRAIN_DATA_2 assert 'hers' in word_vocab def test_train_zero_augment(self, tmp_path, wordvec_pretrain_file): """ 
Train with the punct augmentation set to zero Distinguishs cases where training works w/ or w/o augmentation """ extra_args = ['--augment_nopunct', '0.0'] trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args) def test_train_100_augment(self, tmp_path, wordvec_pretrain_file): """ Train with the punct augmentation set to 1.0 Distinguishs cases where training works w/ or w/o augmentation """ extra_args = ['--augment_nopunct', '1.0'] trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args) def test_train_charlm(self, tmp_path, wordvec_pretrain_file, charlm_args): trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=charlm_args) def test_train_charlm_projection(self, tmp_path, wordvec_pretrain_file, charlm_args): extra_args = charlm_args + ['--charlm_transform_dim', '100'] trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args) def test_missing_column(self, tmp_path, wordvec_pretrain_file): """ Test that using train files with missing columns works In this test, we create three separate files, each with a single training entry. We then train on an amalgam of those three files with a batch size of 1, saving after each batch. This will ensure that only one item is used for each training loop and we can inspect the models which were saved. Since each of the three files have exactly one column missing from the training data, we expect to see the output maps for each column stay unchanged in one iteration and change in the other two. """ # use SGD because some old versions of pytorch with Adam keep # learning a value even if the loss is 0 in subsequent steps # (perhaps it had a momentum by default?) 
extra_args = ['--save_each', '--eval_interval', '1', '--max_steps', '3', '--batch_size', '1', '--optim', 'sgd'] trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_FEATS], DEV_DATA, extra_args=extra_args) save_each_name = tagger.save_each_file_name(trainer.args) model_files = [save_each_name % i for i in range(4)] assert all(os.path.exists(x) for x in model_files) pt = pretrain.Pretrain(wordvec_pretrain_file) saved_trainers = [Trainer(pretrain=pt, model_file=model_file) for model_file in model_files] upos_unchanged = 0 xpos_unchanged = 0 ufeats_unchanged = 0 for t1, t2 in zip(saved_trainers[:-1], saved_trainers[1:]): upos_unchanged += torch.allclose(t1.model.upos_clf.weight, t2.model.upos_clf.weight) xpos_unchanged += torch.allclose(t1.model.xpos_clf.W_bilin.weight, t2.model.xpos_clf.W_bilin.weight) ufeats_unchanged += all(torch.allclose(f1.W_bilin.weight, f2.W_bilin.weight) for f1, f2 in zip(t1.model.ufeats_clf, t2.model.ufeats_clf)) upos_norms = [torch.linalg.norm(t.model.upos_clf.weight) for t in saved_trainers] assert upos_unchanged == 1, "Unchanged: {} {} {} {}".format(upos_unchanged, xpos_unchanged, ufeats_unchanged, upos_norms) assert xpos_unchanged == 1, "Unchanged: %d %d %d" % (upos_unchanged, xpos_unchanged, ufeats_unchanged) assert ufeats_unchanged == 1, "Unchanged: %d %d %d" % (upos_unchanged, xpos_unchanged, ufeats_unchanged) def test_save_each(self, tmp_path, wordvec_pretrain_file): extra_args = ['--save_each'] trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args) save_each_name = tagger.save_each_file_name(trainer.args) expected_models = sorted(set([save_each_name % i for i in range(0, trainer.args['max_steps']+1, trainer.args['eval_interval'])])) assert len(expected_models) == 6 for model_name in expected_models: assert os.path.exists(model_name) def test_with_bert(self, tmp_path, wordvec_pretrain_file): self.run_training(tmp_path, 
wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert']) def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file): self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2']) def test_with_bert_finetune(self, tmp_path, wordvec_pretrain_file): self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_learning_rate', '0.01', '--bert_hidden_layers', '2']) def test_bert_pipeline(self, tmp_path, wordvec_pretrain_file): """ Test training the tagger, then using it in a pipeline The pipeline use of the tagger also tests the longer-than-maxlen workaround for the transformer """ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert']) save_name = trainer.args['save_name'] save_file = str(tmp_path / save_name) assert os.path.exists(save_file) pipe = stanza.Pipeline("en", processors="tokenize,pos", models_dir=TEST_MODELS_DIR, pos_model_path=save_file, pos_pretrain_path=wordvec_pretrain_file) trainer = pipe.processors['pos'].trainer assert trainer.args['save_name'] == save_name # these should be one chunk only doc = pipe("foo " * 100) doc = pipe("foo " * 500) # this is two chunks of bert embedding doc = pipe("foo " * 1000) # this is multiple chunks doc = pipe("foo " * 2000) ================================================ FILE: stanza/tests/pos/test_xpos_vocab_factory.py ================================================ """ Test some pieces of the depparse dataloader """ import pytest import logging import os import tempfile from stanza.models import tagger from stanza.models.common import pretrain from stanza.models.pos.data import Dataset from stanza.models.pos.trainer import Trainer from stanza.models.pos.vocab import WordVocab, 
XPOSVocab from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory from stanza.utils.conll import CoNLL from stanza.tests import TEST_WORKING_DIR pytestmark = [pytest.mark.travis, pytest.mark.pipeline] logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory') EN_EXAMPLE=""" 1 Sh'reyan Sh'reyan PROPN NNP%(tag)s Number=Sing 3 nmod:poss 3:nmod:poss _ 2 's 's PART POS%(tag)s _ 1 case 1:case _ 3 antennae antenna NOUN%(tag)s NNS Number=Plur 6 nsubj 6:nsubj _ 4 are be VERB VBP%(tag)s Mood=Ind|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ 5 hella hella ADV RB%(tag)s _ 6 advmod 6:advmod _ 6 thicc thicc ADJ JJ%(tag)s Degree=Pos 0 root 0:root _ """ EMPTY_TAG = lambda x: "" DASH_TAGS = lambda x: "-%d" % x def build_doc(iterations, suffix): """ build N copies of the english text above, with a lambda function applied for the tag suffices for example: lambda x: "" means the suffices are all blank (NNP, POS, NNS, etc) for each iteration lambda x: "-%d" % x means they go (NNP-0, NNP-1, NNP-2, etc) for the first word's tag """ texts = [EN_EXAMPLE % {"tag": suffix(i)} for i in range(iterations)] text = "\n\n".join(texts) doc = CoNLL.conll2doc(input_str=text) return doc def build_data(iterations, suffix): """ Same thing, but passes the Doc through a POS Tagger DataLoader """ doc = build_doc(iterations, suffix) data = Dataset.load_doc(doc) return data class ErrorFatalHandler(logging.Handler): """ This handler turns any error logs into a fatal error Theoretically you could change the level to make other things fatal as well """ def __init__(self): super().__init__() self.setLevel(logging.ERROR) def emit(self, record): raise AssertionError("Oh no, we printed an error") class TestXPOSVocabFactory: @classmethod def setup_class(cls): """ Add a logger to the xpos factory logger so that it will throw an assertion instead of logging an error We don't actually want 
assertions, since that would be a huge pain in the event one of the models actually changes, so instead we just logger.error in the factory. Using this handler is a simple way to check that the error is correctly logged when something changes """ logger.info("About to start xpos_vocab_factory tests - logger.error in that module will now cause AssertionError") handler = ErrorFatalHandler() logger.addHandler(handler) @classmethod def teardown_class(cls): """ Remove the handler we installed earlier """ handlers = [x for x in logger.handlers if isinstance(x, ErrorFatalHandler)] for handler in handlers: logger.removeHandler(handler) logger.error("Done with xpos_vocab_factory tests - this should not throw an error") def test_basic_en_ewt(self): """ en_ewt is currently the basic vocab note that this may change if the dataset is drastically relabeled in the future """ data = build_data(1, EMPTY_TAG) vocab = xpos_vocab_factory(data, "en_ewt") assert isinstance(vocab, WordVocab) def test_basic_en_unknown(self): """ With only 6 tags, it should use a basic vocab for an unknown dataset """ data = build_data(10, EMPTY_TAG) vocab = xpos_vocab_factory(data, "en_unknown") assert isinstance(vocab, WordVocab) def test_dash_en_unknown(self): """ With this many different tags, it should choose to reduce it to the base xpos removing the - """ data = build_data(10, DASH_TAGS) vocab = xpos_vocab_factory(data, "en_unknown") assert isinstance(vocab, XPOSVocab) assert vocab.sep == "-" def test_dash_en_ewt_wrong(self): """ The dataset looks like XPOS(-), which is wrong for en_ewt """ with pytest.raises(AssertionError): data = build_data(10, DASH_TAGS) vocab = xpos_vocab_factory(data, "en_ewt") assert isinstance(vocab, XPOSVocab) assert vocab.sep == "-" def check_reload(self, pt, shorthand, iterations, suffix, expected_vocab): """ Build a Trainer (no actual training), save it, and load it back in to check the type of Vocab restored TODO: This test may be a bit "eager" in that there are no 
other tests which check building, saving, & loading a pos trainer. Could add tests to test_trainer.py, for example """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: args = tagger.parse_args(["--batch_size", "1", "--shorthand", shorthand]) train_doc = build_doc(iterations, suffix) train_batch = Dataset(train_doc, args, pt, evaluation=False) vocab = train_batch.vocab assert isinstance(vocab['xpos'], expected_vocab) trainer = Trainer(args=args, vocab=vocab, pretrain=pt, device="cpu") model_file = os.path.join(tmpdirname, "foo.pt") trainer.save(model_file) new_trainer = Trainer(model_file=model_file, pretrain=pt) assert isinstance(new_trainer.vocab['xpos'], expected_vocab) @pytest.fixture(scope="class") def pt(self): pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False) return pt def test_reload_word_vocab(self, pt): """ Test that building a model with a known word vocab shorthand, saving it, and loading it gets back a word vocab """ self.check_reload(pt, "en_ewt", 10, EMPTY_TAG, WordVocab) def test_reload_unknown_word_vocab(self, pt): """ Test that building a model with an unknown word vocab, saving it, and loading it gets back a word vocab """ self.check_reload(pt, "en_unknown", 10, EMPTY_TAG, WordVocab) def test_reload_unknown_xpos_vocab(self, pt): """ Test that building a model with an unknown xpos vocab, saving it, and loading it gets back an xpos vocab """ self.check_reload(pt, "en_unknown", 10, DASH_TAGS, XPOSVocab) ================================================ FILE: stanza/tests/pytest.ini ================================================ [pytest] markers = travis: all tests that will be run in travis CI client: all tests that are related to the CoreNLP client interface pipeline: all tests that are related to the Stanza neural pipeline morphseg: all tests that are related to morpheme segmentation ================================================ FILE: stanza/tests/resources/__init__.py 
================================================ ================================================ FILE: stanza/tests/resources/test_charlm_depparse.py ================================================ import pytest from stanza.resources.default_packages import default_charlms, depparse_charlms from stanza.resources.print_charlm_depparse import list_depparse def test_list_depparse(): models = list_depparse() # check that it's picking up the models which don't have specific charlms # first, make sure the default assumption of the test is still true... # if this test fails, find a different language which isn't in depparse_charlms assert "af" not in depparse_charlms assert "af" in default_charlms assert "af_afribooms_charlm" in models assert "af_afribooms_nocharlm" in models # assert that it's picking up the models which do have specific charlms that aren't None # again, first make sure the default assumptions are true # if one of these next few tests fail, just update the test assert "en" in depparse_charlms assert "en" in default_charlms assert "ewt" not in depparse_charlms["en"] assert "craft" in depparse_charlms["en"] assert "mimic" in depparse_charlms["en"] # now, check the results assert "en_ewt_charlm" in models assert "en_ewt_nocharlm" in models assert "en_mimic_charlm" in models # haven't yet trained w/ and w/o for the bio models assert "en_mimic_nocharlm" not in models assert "en_craft_charlm" not in models assert "en_craft_nocharlm" in models ================================================ FILE: stanza/tests/resources/test_common.py ================================================ """ Test various resource downloading functions from resources/common.py """ import os import pytest import tempfile import stanza from stanza.resources import common from stanza.tests import TEST_MODELS_DIR, TEST_WORKING_DIR pytestmark = [pytest.mark.travis, pytest.mark.client] def test_assert_file_exists(): with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: 
filename = os.path.join(test_dir, "test.txt") with pytest.raises(FileNotFoundError): common.assert_file_exists(filename) with open(filename, "w", encoding="utf-8") as fout: fout.write("Unban mox opal!") # MD5 of the fake model file, not any real model files in the system EXPECTED_MD5 = "44dbf21b4e89cea5184615a72a825a36" common.assert_file_exists(filename) common.assert_file_exists(filename, md5=EXPECTED_MD5) with pytest.raises(ValueError): common.assert_file_exists(filename, md5="12345") with pytest.raises(ValueError): common.assert_file_exists(filename, md5="12345", alternate_md5="12345") common.assert_file_exists(filename, md5="12345", alternate_md5=EXPECTED_MD5) def test_download_tokenize_mwt(): with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") assert isinstance(pipeline, stanza.Pipeline) # mwt should be added to the list assert len(pipeline.loaded_processors) == 2 def test_download_non_default(): """ Test the download path for a single file rather than the default zip The expectation is that an NER model will also download two charlm models. 
If that layout changes on purpose, this test will fail and will need to be updated """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="ner", package="ontonotes_charlm", verbose=False) assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] en_dir = os.path.join(test_dir, 'en') en_dir_listing = sorted(os.listdir(en_dir)) assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'ner', 'pretrain'] assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt'] for i in en_dir_listing: assert len(os.listdir(os.path.join(en_dir, i))) == 1 def test_download_two_models(): """ Test the download path for two NER models The package system should now allow for multiple NER models to be specified, and a consequence of that is it should be possible to download two models at once The expectation is that the two different NER models both download a different forward & backward charlm. If that changes, the test will fail. 
Best way to update it will be two different models which download two different charlms """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="ner", package={"ner": ["ontonotes_charlm", "anatem"]}, verbose=False) assert sorted(os.listdir(test_dir)) == ['en', 'resources.json'] en_dir = os.path.join(test_dir, 'en') en_dir_listing = sorted(os.listdir(en_dir)) assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'ner', 'pretrain'] assert sorted(os.listdir(os.path.join(en_dir, 'ner'))) == ['anatem.pt', 'ontonotes_charlm.pt'] for i in en_dir_listing: assert len(os.listdir(os.path.join(en_dir, i))) == 2 def test_process_pipeline_parameters(): """ Test a few options for specifying which processors to load """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: lang, model_dir, package, processors = common.process_pipeline_parameters("en", test_dir, None, "tokenize,pos") assert processors == {"tokenize": "default", "pos": "default"} assert package == None lang, model_dir, package, processors = common.process_pipeline_parameters("en", test_dir, {"tokenize": "spacy"}, "tokenize,pos") assert processors == {"tokenize": "spacy", "pos": "default"} assert package == None lang, model_dir, package, processors = common.process_pipeline_parameters("en", test_dir, {"pos": "ewt"}, "tokenize,pos") assert processors == {"tokenize": "default", "pos": "ewt"} assert package == None lang, model_dir, package, processors = common.process_pipeline_parameters("en", test_dir, "ewt", "tokenize,pos") assert processors == {"tokenize": "ewt", "pos": "ewt"} assert package == None def test_language_resources(): resources = common.load_resources_json(TEST_MODELS_DIR) # check that an unknown language comes back as None bad_lang = 'z' while bad_lang in resources and len(bad_lang) < 100: bad_lang = bad_lang + 'z' assert bad_lang not in resources assert common.get_language_resources(resources, bad_lang) == None # 
check the parameters of the test make sense # there should be 'zh' which is an alias of 'zh-hans' assert "zh" in resources assert "alias" in resources["zh"] assert resources["zh"]["alias"] == "zh-hans" # check that getting the resources for either 'zh' or 'zh-hans' # return the simplified Chinese resources zh_resources = common.get_language_resources(resources, "zh") assert "tokenize" in zh_resources assert "alias" not in zh_resources assert "Chinese" in zh_resources["lang_name"] zh_hans_resources = common.get_language_resources(resources, "zh-hans") assert zh_resources == zh_hans_resources ================================================ FILE: stanza/tests/resources/test_default_packages.py ================================================ import pytest import stanza from stanza.resources import default_packages def test_default_pretrains(): """ Test that all languages with a default treebank have a default pretrain or are specifically marked as not having a pretrain """ for lang in default_packages.default_treebanks.keys(): assert lang in default_packages.no_pretrain_languages or lang in default_packages.default_pretrains, "Lang %s does not have a default pretrain marked!" % lang def test_no_pretrain_languages(): """ Test that no languages have no_default_pretrain marked despite having a pretrain """ for lang in default_packages.no_pretrain_languages: assert lang not in default_packages.default_pretrains, "Lang %s is marked as no_pretrain but has a default pretrain!" % lang ================================================ FILE: stanza/tests/resources/test_installation.py ================================================ """ Test installation functions. 
""" import os import pytest import shutil import tempfile import stanza from stanza.tests import TEST_WORKING_DIR pytestmark = [pytest.mark.travis, pytest.mark.client] def test_install_corenlp(): # we do not reset the CORENLP_HOME variable since this may impact the # client tests with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: # the download method doesn't install over existing directories shutil.rmtree(test_dir) stanza.install_corenlp(dir=test_dir) assert os.path.isdir(test_dir), "Installation destination directory not found." jar_files = [f for f in os.listdir(test_dir) \ if f.endswith('.jar') and f.startswith('stanford-corenlp')] assert len(jar_files) > 0, \ "Cannot find stanford-corenlp jar files in the installation directory." assert not os.path.exists(os.path.join(test_dir, 'corenlp.zip')), \ "Downloaded zip file was not removed." def test_download_corenlp_models(): model_name = "arabic" version = "4.2.2" with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download_corenlp_models(model=model_name, version=version, dir=test_dir) dest_file = os.path.join(test_dir, f"stanford-corenlp-{version}-models-{model_name}.jar") assert os.path.isfile(dest_file), "Downloaded model file not found." 
def test_download_tokenize_mwt(): with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") assert isinstance(pipeline, stanza.Pipeline) # mwt should be added to the list assert len(pipeline.loaded_processors) == 2 ================================================ FILE: stanza/tests/resources/test_prepare_resources.py ================================================ import pytest import stanza import stanza.resources.prepare_resources as prepare_resources from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_split_model_name(): # Basic test lang, package, processor = prepare_resources.split_model_name('ro_nonstandard_tagger.pt') assert lang == 'ro' assert package == 'nonstandard' assert processor == 'pos' # Check that nertagger is found even though it also ends with tagger # Check that ncbi_disease is correctly partitioned despite the extra _ lang, package, processor = prepare_resources.split_model_name('en_ncbi_disease_nertagger.pt') assert lang == 'en' assert package == 'ncbi_disease' assert processor == 'ner' # assert that processors with _ in them are also okay lang, package, processor = prepare_resources.split_model_name('en_pubmed_forward_charlm.pt') assert lang == 'en' assert package == 'pubmed' assert processor == 'forward_charlm' ================================================ FILE: stanza/tests/server/__init__.py ================================================ ================================================ FILE: stanza/tests/server/test_client.py ================================================ """ Tests that call a running CoreNLPClient. 
""" from http.server import BaseHTTPRequestHandler, HTTPServer import multiprocessing import pytest import requests import stanza.server as corenlp import stanza.server.client as client import shlex import subprocess import time from stanza.models.constituency import tree_reader from stanza.tests import * # set the marker for this module pytestmark = [pytest.mark.travis, pytest.mark.client] TEXT = "Chris wrote a simple sentence that he parsed with Stanford CoreNLP.\n" MAX_REQUEST_ATTEMPTS = 5 EN_GOLD = """ Sentence #1 (12 tokens): Chris wrote a simple sentence that he parsed with Stanford CoreNLP. Tokens: [Text=Chris CharacterOffsetBegin=0 CharacterOffsetEnd=5 PartOfSpeech=NNP] [Text=wrote CharacterOffsetBegin=6 CharacterOffsetEnd=11 PartOfSpeech=VBD] [Text=a CharacterOffsetBegin=12 CharacterOffsetEnd=13 PartOfSpeech=DT] [Text=simple CharacterOffsetBegin=14 CharacterOffsetEnd=20 PartOfSpeech=JJ] [Text=sentence CharacterOffsetBegin=21 CharacterOffsetEnd=29 PartOfSpeech=NN] [Text=that CharacterOffsetBegin=30 CharacterOffsetEnd=34 PartOfSpeech=WDT] [Text=he CharacterOffsetBegin=35 CharacterOffsetEnd=37 PartOfSpeech=PRP] [Text=parsed CharacterOffsetBegin=38 CharacterOffsetEnd=44 PartOfSpeech=VBD] [Text=with CharacterOffsetBegin=45 CharacterOffsetEnd=49 PartOfSpeech=IN] [Text=Stanford CharacterOffsetBegin=50 CharacterOffsetEnd=58 PartOfSpeech=NNP] [Text=CoreNLP CharacterOffsetBegin=59 CharacterOffsetEnd=66 PartOfSpeech=NNP] [Text=. CharacterOffsetBegin=66 CharacterOffsetEnd=67 PartOfSpeech=.] 
""".strip() def run_webserver(port, timeout_secs): class HTTPTimeoutHandler(BaseHTTPRequestHandler): def do_POST(self): time.sleep(timeout_secs) self.send_response(200) self.send_header('Content-type', 'text/plain; charset=utf-8') self.end_headers() self.wfile.write("HTTPMockServerTimeout") HTTPServer(('127.0.0.1', port), HTTPTimeoutHandler).serve_forever() class HTTPMockServerTimeoutContext: """ For launching an HTTP server on certain port with an specified delay at responses """ def __init__(self, port, timeout_secs): self.port = port self.timeout_secs = timeout_secs def __enter__(self): self.p = multiprocessing.Process(target=run_webserver, args=(self.port, self.timeout_secs)) self.p.daemon = True self.p.start() def __exit__(self, exc_type, exc_value, exc_traceback): self.p.terminate() class TestCoreNLPClient: @pytest.fixture(scope="class") def corenlp_client(self): """ Client to run tests on """ client = corenlp.CoreNLPClient(annotators='tokenize,ssplit,pos,lemma,ner,depparse', server_id='stanza_main_test_server') yield client client.stop() def test_connect(self, corenlp_client): corenlp_client.ensure_alive() assert corenlp_client.is_active assert corenlp_client.is_alive() def test_context_manager(self): with corenlp.CoreNLPClient(annotators="tokenize,ssplit", endpoint="http://localhost:9001") as context_client: ann = context_client.annotate(TEXT) assert corenlp.to_text(ann.sentence[0]) == TEXT[:-1] def test_no_duplicate_servers(self): """We expect a second server on the same port to fail""" with pytest.raises(corenlp.PermanentlyFailedException): with corenlp.CoreNLPClient(annotators="tokenize,ssplit") as duplicate_server: raise RuntimeError("This should have failed") def test_annotate(self, corenlp_client): ann = corenlp_client.annotate(TEXT) assert corenlp.to_text(ann.sentence[0]) == TEXT[:-1] def test_update(self, corenlp_client): ann = corenlp_client.annotate(TEXT) ann = corenlp_client.update(ann) assert corenlp.to_text(ann.sentence[0]) == TEXT[:-1] def 
test_tokensregex(self, corenlp_client): pattern = '([ner: PERSON]+) /wrote/ /an?/ []{0,3} /sentence|article/' matches = corenlp_client.tokensregex(TEXT, pattern) assert len(matches["sentences"]) == 1 assert matches["sentences"][0]["length"] == 1 assert matches == { "sentences": [{ "0": { "text": "Chris wrote a simple sentence", "begin": 0, "end": 5, "1": { "text": "Chris", "begin": 0, "end": 1 }}, "length": 1 },]} def test_semgrex(self, corenlp_client): pattern = '{word:wrote} >nsubj {}=subject >obj {}=object' matches = corenlp_client.semgrex(TEXT, pattern, to_words=True) assert matches == [ { "text": "wrote", "begin": 1, "end": 2, "$subject": { "text": "Chris", "begin": 0, "end": 1 }, "$object": { "text": "sentence", "begin": 4, "end": 5 }, "sentence": 0,}] def test_tregex(self, corenlp_client): # the PP should be easy to parse pattern = 'PP < NP' matches = corenlp_client.tregex(TEXT, pattern) print(matches) assert matches == { 'sentences': [ {'0': {'sentIndex': 0, 'characterOffsetBegin': 45, 'codepointOffsetBegin': 45, 'characterOffsetEnd': 66, 'codepointOffsetEnd': 66, 'match': '(PP (IN with)\n (NP (NNP Stanford) (NNP CoreNLP)))\n', 'spanString': 'with Stanford CoreNLP', 'namedNodes': []}} ] } def test_tregex_trees(self, corenlp_client): """ Test the results of tregex run on trees w/o parsing """ trees = tree_reader.read_trees("(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ blue) (NN skin))))) (ROOT (S (NP (PRP I)) (VP (VBP like) (NP (PRP$ her) (NNS antennae)))))") pattern = "VP < NP" matches = corenlp_client.tregex(pattern=pattern, trees=trees) assert matches == { 'sentences': [ {'0': {'sentIndex': 0, 'match': '(VP (VBZ has)\n (NP (JJ blue) (NN skin)))\n', 'spanString': 'has blue skin', 'namedNodes': []}}, {'0': {'sentIndex': 1, 'match': '(VP (VBP like)\n (NP (PRP$ her) (NNS antennae)))\n', 'spanString': 'like her antennae', 'namedNodes': []}} ] } @pytest.fixture def external_server_9001(self): corenlp_home = client.resolve_classpath(None) start_cmd = 
f'java -Xmx5g -cp "{corenlp_home}" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9001 ' \ f'-timeout 60000 -server_id stanza_external_server -serverProperties {SERVER_TEST_PROPS}' start_cmd = start_cmd and shlex.split(start_cmd) external_server_process = subprocess.Popen(start_cmd) yield external_server_process assert external_server_process external_server_process.terminate() external_server_process.wait(5) def test_external_server_legacy_start_server(self, external_server_9001): """ Test starting up an external server and accessing with a client with start_server=False """ with corenlp.CoreNLPClient(start_server=False, endpoint="http://localhost:9001") as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') assert ann.strip() == EN_GOLD def test_external_server_available(self, external_server_9001): """ Test starting up an external available server and accessing with a client with start_server=StartServer.DONT_START """ time.sleep(5) # wait and make sure the external CoreNLP server is up and running with corenlp.CoreNLPClient(start_server=corenlp.StartServer.DONT_START, endpoint="http://localhost:9001") as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') assert ann.strip() == EN_GOLD def test_external_server_unavailable(self): """ Test accessing with a client with start_server=StartServer.DONT_START to an external unavailable server """ with pytest.raises(corenlp.AnnotationException): with corenlp.CoreNLPClient(start_server=corenlp.StartServer.DONT_START, endpoint="http://localhost:9001") as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') def test_external_server_timeout(self): """ Test starting up an external server with long response time (20 seconds) and accessing with a client with start_server=StartServer.DONT_START and 
timeout=5000""" with HTTPMockServerTimeoutContext(9001, 20): time.sleep(5) # wait and make sure the external HTTPMockServer server is up and running with pytest.raises(corenlp.TimeoutException): with corenlp.CoreNLPClient(start_server=corenlp.StartServer.DONT_START, endpoint="http://localhost:9001", timeout=5000) as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') def test_external_server_try_start_with_external(self, external_server_9001): """ Test starting up an external server and accessing with a client with start_server=StartServer.TRY_START """ time.sleep(5) # wait and make sure the external CoreNLP server is up and running with corenlp.CoreNLPClient(start_server=corenlp.StartServer.TRY_START, annotators='tokenize,ssplit,pos', endpoint="http://localhost:9001") as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') assert external_server_client.server is None, "If this is not None, that indicates the client started a server instead of reusing an existing one" assert ann.strip() == EN_GOLD def test_external_server_try_start(self): """ Test starting up a server with a client with start_server=StartServer.TRY_START """ with corenlp.CoreNLPClient(start_server=corenlp.StartServer.TRY_START, annotators='tokenize,ssplit,pos', endpoint="http://localhost:9001") as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') assert ann.strip() == EN_GOLD def test_external_server_force_start(self, external_server_9001): """ Test starting up an external server and accessing with a client with start_server=StartServer.FORCE_START """ time.sleep(5) # wait and make sure the external CoreNLP server is up and running with pytest.raises(corenlp.PermanentlyFailedException): with corenlp.CoreNLPClient(start_server=corenlp.StartServer.FORCE_START, 
endpoint="http://localhost:9001") as external_server_client: ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text') ================================================ FILE: stanza/tests/server/test_java_protobuf_requests.py ================================================ import tempfile import pytest from stanza.models.common.utils import misc_to_space_after, space_after_to_misc from stanza.models.constituency import tree_reader from stanza.server import java_protobuf_requests from stanza.tests import * from stanza.utils.conll import CoNLL from stanza.protobuf import DependencyGraph pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def check_tree(proto_tree, py_tree, py_score): tree, tree_score = java_protobuf_requests.from_tree(proto_tree) assert tree_score == py_score assert tree == py_tree def test_build_tree(): text="((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\n( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) assert len(trees) == 2 for tree in trees: proto_tree = java_protobuf_requests.build_tree(trees[0], 1.0) check_tree(proto_tree, trees[0], 1.0) ESTONIAN_EMPTY_DEPS = """ # sent_id = ewtb2_000035_15 # text = Ja paari aasta pärast rôômalt maasikatele ... 1 Ja ja CCONJ J _ 3 cc 5.1:cc _ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1 6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes 7 ... ... 
PUNCT Z _ 3 punct 5.1:punct _ """.strip() def test_convert_networkx_graph(): doc = CoNLL.conll2doc(input_str=ESTONIAN_EMPTY_DEPS, ignore_gapping=False) deps = doc.sentences[0]._enhanced_dependencies graph = DependencyGraph() java_protobuf_requests.convert_networkx_graph(graph, doc.sentences[0], 0) assert len(graph.rootNode) == 1 assert graph.rootNode[0] == 0 nodes = sorted([(x.index, x.emptyIndex) for x in graph.node]) expected_nodes = [(1,0), (2,0), (3,0), (4,0), (5,0), (5,1), (6,0), (7,0)] assert nodes == expected_nodes edges = [(x.target, x.dep) for x in graph.edge if x.source == 5 and x.sourceEmpty == 1] edges = sorted(edges) expected_edges = [(1, 'cc'), (3, 'obl'), (5, 'advmod'), (6, 'obl'), (7, 'punct')] assert edges == expected_edges ENGLISH_NBSP_SAMPLE=""" # sent_id = newsgroup-groups.google.com_n3td3v_e874a1e5eb995654_ENG_20060120_052200-0011 # text = Please note that neither the e-mail address nor name of the sender have been verified. 1 Please please INTJ UH _ 2 discourse _ _ 2 note note VERB VB Mood=Imp|VerbForm=Fin 0 root _ _ 3 that that SCONJ IN _ 15 mark _ _ 4 neither neither CCONJ CC _ 7 cc:preconj _ _ 5 the the DET DT Definite=Def|PronType=Art 7 det _ _ 6 e-mail e-mail NOUN NN Number=Sing 7 compound _ _ 7 address address NOUN NN Number=Sing 15 nsubj:pass _ _ 8 nor nor CCONJ CC _ 9 cc _ _ 9 name name NOUN NN Number=Sing 7 conj _ _ 10 of of ADP IN _ 12 case _ _ 11 the the DET DT Definite=Def|PronType=Art 12 det _ _ 12 sender sender NOUN NN Number=Sing 7 nmod _ _ 13 have have AUX VBP Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 15 aux _ SpacesAfter=\\u00A0 14 been be AUX VBN Tense=Past|VerbForm=Part 15 aux:pass _ _ 15 verified verify VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 2 ccomp _ SpaceAfter=No 16 . . PUNCT . 
_ 2 punct _ _ """.strip() def test_nbsp_doc(): """ Test that the space conversion methods will convert to and from NBSP """ doc = CoNLL.conll2doc(input_str=ENGLISH_NBSP_SAMPLE) assert doc.sentences[0].text == "Please note that neither the e-mail address nor name of the sender have been verified." assert doc.sentences[0].tokens[12].spaces_after == " " assert misc_to_space_after("SpacesAfter=\\u00A0") == ' ' assert space_after_to_misc(' ') == "SpacesAfter=\\u00A0" conllu = "{:C}".format(doc) assert conllu == ENGLISH_NBSP_SAMPLE ================================================ FILE: stanza/tests/server/test_morphology.py ================================================ """ Test the most basic functionality of the morphology script """ import pytest from stanza.server.morphology import Morphology, process_text words = ["Jennifer", "has", "the", "prettiest", "antennae"] tags = ["NNP", "VBZ", "DT", "JJS", "NNS"] expected = ["Jennifer", "have", "the", "pretty", "antenna"] def test_process_text(): result = process_text(words, tags) lemma = [x.lemma for x in result.words] print(lemma) assert lemma == expected def test_basic_morphology(): with Morphology() as morph: result = morph.process(words, tags) lemma = [x.lemma for x in result.words] assert lemma == expected ================================================ FILE: stanza/tests/server/test_parser_eval.py ================================================ """ Test the parser eval interface """ import pytest import stanza from stanza.models.constituency import tree_reader from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse from stanza.server.parser_eval import build_request, collate, EvaluateParser, ParseResult from stanza.tests.server.test_java_protobuf_requests import check_tree from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.client] def build_one_tree_treebank(fake_scores=True): text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" trees = tree_reader.read_trees(text) 
assert len(trees) == 1 gold = trees[0] if fake_scores: prediction = (gold, 1.0) treebank = [ParseResult(gold, [prediction], None, None)] return treebank else: prediction = gold return collate([gold], [prediction]) def check_build(fake_scores=True): treebank = build_one_tree_treebank(fake_scores) request = build_request(treebank) assert len(request.treebank) == 1 check_tree(request.treebank[0].gold, treebank[0][0], None) assert len(request.treebank[0].predicted) == 1 if fake_scores: check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1]) else: check_tree(request.treebank[0].predicted[0], treebank[0][1][0], None) def test_build_tuple_request(): check_build(True) def test_build_notuple_request(): check_build(False) def test_score_one_tree_tuples(): treebank = build_one_tree_treebank(True) with EvaluateParser() as ep: response = ep.process(treebank) assert response.f1 == pytest.approx(1.0) def test_score_one_tree_notuples(): treebank = build_one_tree_treebank(False) with EvaluateParser() as ep: response = ep.process(treebank) assert response.f1 == pytest.approx(1.0) ================================================ FILE: stanza/tests/server/test_protobuf.py ================================================ """ Tests to read a stored protobuf. Also serves as an example of how to parse sentences, tokens, pos, lemma, ner, dependencies and mentions. The test corresponds to annotations for the following sentence: Chris wrote a simple sentence that he parsed with Stanford CoreNLP. 
""" import os from pathlib import Path import pytest from pytest import fixture from stanza.protobuf import Document, Sentence, Token, DependencyGraph,\ CorefChain from stanza.protobuf import parseFromDelimitedString, writeToDelimitedString, to_text # set the marker for this module pytestmark = [pytest.mark.travis, pytest.mark.client] # Text that was annotated TEXT = "Chris wrote a simple sentence that he parsed with Stanford CoreNLP.\n" @fixture def doc_pb(): test_dir = os.path.dirname(os.path.abspath(__file__)) test_dir = Path(test_dir).parent test_data = os.path.join(test_dir, 'data', 'test.dat') with open(test_data, 'rb') as f: buf = f.read() doc = Document() parseFromDelimitedString(doc, buf) return doc def test_parse_protobuf(doc_pb): assert doc_pb.ByteSize() == 4709 def test_write_protobuf(doc_pb): stream = writeToDelimitedString(doc_pb) buf = stream.getvalue() stream.close() doc_pb_ = Document() parseFromDelimitedString(doc_pb_, buf) assert doc_pb == doc_pb_ def test_document_text(doc_pb): assert doc_pb.text == TEXT def test_sentences(doc_pb): assert len(doc_pb.sentence) == 1 sentence = doc_pb.sentence[0] assert isinstance(sentence, Sentence) # check sentence length assert sentence.characterOffsetEnd - sentence.characterOffsetBegin == 67 # Note that the sentence text should actually be recovered from the tokens. 
assert sentence.text == '' assert to_text(sentence) == TEXT[:-1] def test_tokens(doc_pb): sentence = doc_pb.sentence[0] tokens = sentence.token assert len(tokens) == 12 assert isinstance(tokens[0], Token) # Word words = "Chris wrote a simple sentence that he parsed with Stanford CoreNLP .".split() words_ = [t.word for t in tokens] assert words_ == words # Lemma lemmas = "Chris write a simple sentence that he parse with Stanford CoreNLP .".split() lemmas_ = [t.lemma for t in tokens] assert lemmas_ == lemmas # POS pos = "NNP VBD DT JJ NN IN PRP VBD IN NNP NNP .".split() pos_ = [t.pos for t in tokens] assert pos_ == pos # NER ner = "PERSON O O O O O O O O ORGANIZATION O O".split() ner_ = [t.ner for t in tokens] assert ner_ == ner # character offsets begin = [int(i) for i in "0 6 12 14 21 30 35 38 45 50 59 66".split()] end = [int(i) for i in "5 11 13 20 29 34 37 44 49 58 66 67".split()] begin_ = [t.beginChar for t in tokens] end_ = [t.endChar for t in tokens] assert begin_ == begin assert end_ == end def test_dependency_parse(doc_pb): """ Extract the dependency parse from the annotation. """ sentence = doc_pb.sentence[0] # You can choose from the following types of dependencies. # In general, you'll want enhancedPlusPlus assert sentence.basicDependencies.ByteSize() > 0 assert sentence.enhancedDependencies.ByteSize() > 0 assert sentence.enhancedPlusPlusDependencies.ByteSize() > 0 tree = sentence.enhancedPlusPlusDependencies isinstance(tree, DependencyGraph) # Indices are 1-indexd with 0 being the "pseudo root" assert tree.root # 'wrote' is the root. == [2] # There are as many nodes as there are tokens. 
assert len(tree.node) == len(sentence.token) # Enhanced++ dependencies often contain additional edges and are # not trees -- here, 'parsed' would also have an edge to # 'sentence' assert len(tree.edge) == 12 # This edge goes from "wrote" to "Chirs" edge = tree.edge[0] assert edge.source == 2 assert edge.target == 1 assert edge.dep == "nsubj" def test_coref_chain(doc_pb): """ Extract the corefence chains from the annotation. """ # Coreference chains span sentences and are stored in the # document. chains = doc_pb.corefChain # In this document there is 1 chain with Chris and he. assert len(chains) == 1 chain = chains[0] assert isinstance(chain, CorefChain) assert chain.mention[0].beginIndex == 0 # 'Chris' assert chain.mention[0].endIndex == 1 assert chain.mention[0].gender == "MALE" assert chain.mention[1].beginIndex == 6 # 'he' assert chain.mention[1].endIndex == 7 assert chain.mention[1].gender == "MALE" assert chain.representative == 0 # Head of the chain is 'Chris' ================================================ FILE: stanza/tests/server/test_semgrex.py ================================================ """ Test the semgrex interface """ import pytest import stanza import stanza.server.semgrex as semgrex from stanza.models.common.doc import Document from stanza.protobuf import SemgrexRequest from stanza.utils.conll import CoNLL from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.client] TEST_ONE_SENTENCE = [[ { "id": 1, "text": "Unban", "lemma": "unban", "upos": "VERB", "xpos": "VB", "feats": "Mood=Imp|VerbForm=Fin", "head": 0, "deprel": "root", "misc": "start_char=0|end_char=5" }, { "id": 2, "text": "Mox", "lemma": "Mox", "upos": "PROPN", "xpos": "NNP", "feats": "Number=Sing", "head": 3, "deprel": "compound", "misc": "start_char=6|end_char=9" }, { "id": 3, "text": "Opal", "lemma": "Opal", "upos": "PROPN", "xpos": "NNP", "feats": "Number=Sing", "head": 1, "deprel": "obj", "misc": "start_char=10|end_char=14", "ner": "GEM" }, { "id": 4, "text": 
"!", "lemma": "!", "upos": "PUNCT", "xpos": ".", "head": 1, "deprel": "punct", "misc": "start_char=14|end_char=15" }]] TEST_TWO_SENTENCES = [[ { "id": 1, "text": "Unban", "lemma": "unban", "upos": "VERB", "xpos": "VB", "feats": "Mood=Imp|VerbForm=Fin", "head": 0, "deprel": "root", "misc": "start_char=0|end_char=5" }, { "id": 2, "text": "Mox", "lemma": "Mox", "upos": "PROPN", "xpos": "NNP", "feats": "Number=Sing", "head": 3, "deprel": "compound", "misc": "start_char=6|end_char=9" }, { "id": 3, "text": "Opal", "lemma": "Opal", "upos": "PROPN", "xpos": "NNP", "feats": "Number=Sing", "head": 1, "deprel": "obj", "misc": "start_char=10|end_char=14" }, { "id": 4, "text": "!", "lemma": "!", "upos": "PUNCT", "xpos": ".", "head": 1, "deprel": "punct", "misc": "start_char=14|end_char=15" }], [{ "id": 1, "text": "Unban", "lemma": "unban", "upos": "VERB", "xpos": "VB", "feats": "Mood=Imp|VerbForm=Fin", "head": 0, "deprel": "root", "misc": "start_char=16|end_char=21" }, { "id": 2, "text": "Mox", "lemma": "Mox", "upos": "PROPN", "xpos": "NNP", "feats": "Number=Sing", "head": 3, "deprel": "compound", "misc": "start_char=22|end_char=25" }, { "id": 3, "text": "Opal", "lemma": "Opal", "upos": "PROPN", "xpos": "NNP", "feats": "Number=Sing", "head": 1, "deprel": "obj", "misc": "start_char=26|end_char=30" }, { "id": 4, "text": "!", "lemma": "!", "upos": "PUNCT", "xpos": ".", "head": 1, "deprel": "punct", "misc": "start_char=30|end_char=31" }]] ONE_SENTENCE_DOC = Document(TEST_ONE_SENTENCE, "Unban Mox Opal!") TWO_SENTENCE_DOC = Document(TEST_TWO_SENTENCES, "Unban Mox Opal! 
Unban Mox Opal!") def check_response(response, response_len=1, semgrex_len=1, source_index=1, target_index=3, reln='obj'): assert len(response.result) == response_len for sentence_idx, sentence_result in enumerate(response.result): for semgrex_result in sentence_result.result: for match in semgrex_result.match: assert sentence_idx == match.sentenceIndex assert len(response.result[0].result) == semgrex_len for semgrex_result in response.result[0].result: assert len(semgrex_result.match) == 1 assert semgrex_result.match[0].matchIndex == source_index for match in semgrex_result.match: assert len(match.node) == 2 assert match.node[0].name == 'source' assert match.node[0].matchIndex == source_index assert match.node[1].name == 'target' assert match.node[1].matchIndex == target_index assert len(match.reln) == 1 assert match.reln[0].name == 'zzz' assert match.reln[0].reln == reln def test_multi(): with semgrex.Semgrex() as sem: response = sem.process(ONE_SENTENCE_DOC, "{}=source >obj=zzz {}=target") check_response(response) response = sem.process(ONE_SENTENCE_DOC, "{}=source >obj=zzz {}=target") check_response(response) response = sem.process(TWO_SENTENCE_DOC, "{}=source >obj=zzz {}=target") check_response(response, response_len=2) def test_single_sentence(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{}=source >obj=zzz {}=target") check_response(response) def test_two_semgrex(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{}=source >obj=zzz {}=target", "{}=source >obj=zzz {}=target") check_response(response, semgrex_len=2) def test_two_sentences(): response = semgrex.process_doc(TWO_SENTENCE_DOC, "{}=source >obj=zzz {}=target") check_response(response, response_len=2) def test_word_attribute(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{word:Mox}=source <=zzz {word:Opal}=target") check_response(response, response_len=1, source_index=2, reln='compound') def test_lemma_attribute(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{lemma:Mox}=source 
<=zzz {lemma:Opal}=target") check_response(response, response_len=1, source_index=2, reln='compound') def test_xpos_attribute(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{tag:NNP}=source <=zzz {word:Opal}=target") check_response(response, response_len=1, source_index=2, reln='compound') response = semgrex.process_doc(ONE_SENTENCE_DOC, "{pos:NNP}=source <=zzz {word:Opal}=target") check_response(response, response_len=1, source_index=2, reln='compound') def test_upos_attribute(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{cpos:PROPN}=source <=zzz {word:Opal}=target") check_response(response, response_len=1, source_index=2, reln='compound') def test_ner_attribute(): response = semgrex.process_doc(ONE_SENTENCE_DOC, "{cpos:PROPN}=source <=zzz {ner:GEM}=target") check_response(response, response_len=1, source_index=2, reln='compound') def test_hand_built_request(): """ Essentially a test program: the result should be a response with one match, two named nodes, one named relation """ request = SemgrexRequest() request.semgrex.append("{}=source >obj=zzz {}=target") query = request.query.add() for idx, word in enumerate(['Unban', 'Mox', 'Opal']): token = query.token.add() token.word = word token.value = word node = query.graph.node.add() node.sentenceIndex = 1 node.index = idx+1 edge = query.graph.edge.add() edge.source = 1 edge.target = 3 edge.dep = 'obj' edge = query.graph.edge.add() edge.source = 3 edge.target = 2 edge.dep = 'compound' response = semgrex.send_semgrex_request(request) check_response(response) BLANK_DEPENDENCY_SENTENCE = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007 # text = You wonder if he was manipulating the market with his bombing targets. 
1 You you PRON PRP Case=Nom|Person=2|PronType=Prs 2 nsubj _ _ 2 wonder wonder VERB VBP Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin 1 _ _ _ 3 if if SCONJ IN _ 6 mark _ _ 4 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 6 nsubj _ _ 5 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 6 aux _ _ 6 manipulating manipulate VERB VBG Tense=Pres|VerbForm=Part 2 ccomp _ _ 7 the the DET DT Definite=Def|PronType=Art 8 det _ _ 8 market market NOUN NN Number=Sing 6 obj _ _ 9 with with ADP IN _ 12 case _ _ 10 his his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 12 nmod:poss _ _ 11 bombing bombing NOUN NN Number=Sing 12 compound _ _ 12 targets target NOUN NNS Number=Plur 6 obl _ SpaceAfter=No 13 . . PUNCT . _ 2 punct _ _ """.lstrip() def test_blank_dependency(): """ A user / contributor sent a dependency file with blank dependency labels and twisted up roots """ blank_dep_doc = CoNLL.conll2doc(input_str=BLANK_DEPENDENCY_SENTENCE) blank_dep_request = semgrex.build_request(blank_dep_doc, "{}=root <_=edge {}") response = semgrex.send_semgrex_request(blank_dep_request) assert len(response.result) == 1 assert len(response.result[0].result) == 1 assert len(response.result[0].result[0].match) == 1 # there should be a named node... assert len(response.result[0].result[0].match[0].node) == 1 assert response.result[0].result[0].match[0].node[0].name == 'root' assert response.result[0].result[0].match[0].node[0].matchIndex == 2 # ... and a named edge assert len(response.result[0].result[0].match[0].edge) == 1 assert response.result[0].result[0].match[0].edge[0].source == 1 assert response.result[0].result[0].match[0].edge[0].target == 2 assert response.result[0].result[0].match[0].edge[0].reln == "_" EXPECTED_ONE_SENTENCE_MATCH = """ # text = Unban Mox Opal! 
# sent_id = 0 # semgrex pattern |{cpos:PROPN}=source <=zzz {ner:GEM}=target| matched at 2:Mox source=2:Mox target=3:Opal # highlight tokens = 2 # highlight deprels = 2 1 Unban unban VERB VB Mood=Imp|VerbForm=Fin 0 root _ start_char=0|end_char=5 2 Mox Mox PROPN NNP Number=Sing 3 compound _ start_char=6|end_char=9 3 Opal Opal PROPN NNP Number=Sing 1 obj _ SpaceAfter=No|start_char=10|end_char=14|ner=GEM 4 ! ! PUNCT . _ 1 punct _ SpaceAfter=No|start_char=14|end_char=15 """.strip() def test_ner_annotated(): semgrex_pattern = "{cpos:PROPN}=source <=zzz {ner:GEM}=target" # not using the existing ONE_SENTENCE_DOC as the Document may be mutated doc = Document(TEST_ONE_SENTENCE, "Unban Mox Opal!") response = semgrex.process_doc(doc, semgrex_pattern) doc = semgrex.annotate_doc(doc, response, semgrex_pattern, True, False) formatted = "{:C}".format(doc).strip() assert formatted == EXPECTED_ONE_SENTENCE_MATCH EXPECTED_ONE_SENTENCE_NO_MATCH = """ # text = Unban Mox Opal! # sent_id = 0 # semgrex pattern |{cpos:ZZZZ}| did not match! 1 Unban unban VERB VB Mood=Imp|VerbForm=Fin 0 root _ start_char=0|end_char=5 2 Mox Mox PROPN NNP Number=Sing 3 compound _ start_char=6|end_char=9 3 Opal Opal PROPN NNP Number=Sing 1 obj _ SpaceAfter=No|start_char=10|end_char=14|ner=GEM 4 ! ! PUNCT . 
_ 1 punct _ SpaceAfter=No|start_char=14|end_char=15 """.strip() def test_not_annotated(): semgrex_pattern = "{cpos:ZZZZ}" # not using the existing ONE_SENTENCE_DOC as the Document may be mutated doc = Document(TEST_ONE_SENTENCE, "Unban Mox Opal!") response = semgrex.process_doc(doc, semgrex_pattern) doc = semgrex.annotate_doc(doc, response, semgrex_pattern, False, False) formatted = "{:C}".format(doc).strip() assert formatted == EXPECTED_ONE_SENTENCE_NO_MATCH def test_empty_not_annotated(): """ If there are no responses and match_only is set, the returned doc should be empty """ semgrex_pattern = "{cpos:ZZZZ}" # not using the existing ONE_SENTENCE_DOC as the Document may be mutated doc = Document(TEST_ONE_SENTENCE, "Unban Mox Opal!") response = semgrex.process_doc(doc, semgrex_pattern) doc = semgrex.annotate_doc(doc, response, semgrex_pattern, True, False) formatted = "{:C}".format(doc).strip() assert formatted == "" def test_only_not_annotated(): semgrex_pattern = "{cpos:ZZZZ}" # not using the existing ONE_SENTENCE_DOC as the Document may be mutated doc = Document(TEST_ONE_SENTENCE, "Unban Mox Opal!") response = semgrex.process_doc(doc, semgrex_pattern) doc = semgrex.annotate_doc(doc, response, semgrex_pattern, False, True) formatted = "{:C}".format(doc).strip() assert formatted == EXPECTED_ONE_SENTENCE_NO_MATCH ================================================ FILE: stanza/tests/server/test_server_misc.py ================================================ """ Misc tests for the server """ import pytest import re import stanza.server as corenlp from stanza.tests import compare_ignoring_whitespace pytestmark = pytest.mark.client EN_DOC = "Joe Smith lives in California." EN_DOC_GOLD = """ Sentence #1 (6 tokens): Joe Smith lives in California. 
Tokens: [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP Lemma=Joe NamedEntityTag=PERSON] [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP Lemma=Smith NamedEntityTag=PERSON] [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ Lemma=live NamedEntityTag=O] [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN Lemma=in NamedEntityTag=O] [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP Lemma=California NamedEntityTag=STATE_OR_PROVINCE] [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=. Lemma=. NamedEntityTag=O] Dependency Parse (enhanced plus plus dependencies): root(ROOT-0, lives-3) compound(Smith-2, Joe-1) nsubj(lives-3, Smith-2) case(California-5, in-4) obl:in(lives-3, California-5) punct(lives-3, .-6) Extracted the following NER entity mentions: Joe Smith PERSON PERSON:0.9972202681743931 California STATE_OR_PROVINCE LOCATION:0.9990868267559281 Extracted the following KBP triples: 1.0 Joe Smith per:statesorprovinces_of_residence California """ EN_DOC_POS_ONLY_GOLD = """ Sentence #1 (6 tokens): Joe Smith lives in California. Tokens: [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP] [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP] [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ] [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN] [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP] [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.] 
""" def test_english_request(): """ Test case of starting server with Spanish defaults, and then requesting default English properties """ with corenlp.CoreNLPClient(properties='spanish', server_id='test_spanish_english_request') as client: ann = client.annotate(EN_DOC, properties='english', output_format='text') compare_ignoring_whitespace(ann, EN_DOC_GOLD) # Rerun the test with a server created in English mode to verify # that the expected output is what the defaults actually give us with corenlp.CoreNLPClient(properties='english', server_id='test_english_request') as client: ann = client.annotate(EN_DOC, output_format='text') compare_ignoring_whitespace(ann, EN_DOC_GOLD) def test_default_annotators(): """ Test case of creating a client with start_server=False and a set of annotators The annotators should be used instead of the server's default annotators """ with corenlp.CoreNLPClient(server_id='test_default_annotators', output_format='text', annotators=['tokenize','ssplit','pos','lemma','ner','depparse']) as client: with corenlp.CoreNLPClient(start_server=False, output_format='text', annotators=['tokenize','ssplit','pos']) as client2: ann = client2.annotate(EN_DOC) expected_codepoints = ((0, 1), (2, 4), (5, 8), (9, 15), (16, 20)) expected_characters = ((0, 1), (2, 4), (5, 10), (11, 17), (18, 22)) codepoint_doc = "I am 𝒚̂𝒊 random text" def test_codepoints(): """ Test case of asking for codepoints from the English tokenizer """ with corenlp.CoreNLPClient(annotators=['tokenize','ssplit'], # 'depparse','coref'], properties={'tokenize.codepoint': 'true'}) as client: ann = client.annotate(codepoint_doc) for i, (codepoints, characters) in enumerate(zip(expected_codepoints, expected_characters)): token = ann.sentence[0].token[i] assert token.codepointOffsetBegin == codepoints[0] assert token.codepointOffsetEnd == codepoints[1] assert token.beginChar == characters[0] assert token.endChar == characters[1] def test_codepoint_text(): """ Test case of extracting the correct 
sentence text using codepoints """ text = 'Unban mox opal 🐱. This is a second sentence.' with corenlp.CoreNLPClient(annotators=["tokenize","ssplit"], properties={'tokenize.codepoint': 'true'}) as client: ann = client.annotate(text) text_start = ann.sentence[0].token[0].codepointOffsetBegin text_end = ann.sentence[0].token[-1].codepointOffsetEnd sentence_text = text[text_start:text_end] assert sentence_text == 'Unban mox opal 🐱.' text_start = ann.sentence[1].token[0].codepointOffsetBegin text_end = ann.sentence[1].token[-1].codepointOffsetEnd sentence_text = text[text_start:text_end] assert sentence_text == 'This is a second sentence.' ================================================ FILE: stanza/tests/server/test_server_pretokenized.py ================================================ """ Misc tests for the server """ import pytest import re from stanza.server import CoreNLPClient pytestmark = pytest.mark.client tokens = {} tags = {} # Italian examples tokens["italian"] = [ "È vero , tutti possiamo essere sostituiti .\n Alcune chiamate partirono da il Quirinale ." ] tags["italian"] = [ [ ["AUX", "ADJ", "PUNCT", "PRON", "AUX", "AUX", "VERB", "PUNCT"], ["DET", "NOUN", "VERB", "ADP", "DET", "PROPN", "PUNCT"], ], ] # French examples tokens["french"] = [ ( "Les études durent six ans mais leur contenu diffère donc selon les Facultés .\n" "Il est fêté le 22 mai ." 
) ] tags["french"] = [ [ ["DET", "NOUN", "VERB", "NUM", "NOUN", "CCONJ", "DET", "NOUN", "VERB", "ADV", "ADP", "DET", "PROPN", "PUNCT"], ["PRON", "AUX", "VERB", "DET", "NUM", "NOUN", "PUNCT"] ], ] # English examples tokens["english"] = ["This shouldn't be split .\n I hope it's not ."] tags["english"] = [ [ ["DT", "NN", "VB", "VBN", "."], ["PRP", "VBP", "PRP$", "RB", "."], ], ] def pretokenized_test(lang): """Test submitting pretokenized French text.""" with CoreNLPClient( properties=lang, annotators="pos", pretokenized=True, be_quiet=True, ) as client: for input_text, gold_tags in zip(tokens[lang], tags[lang]): ann = client.annotate(input_text) for sentence_tags, sentence in zip(gold_tags, ann.sentence): result_tags = [tok.pos for tok in sentence.token] assert sentence_tags == result_tags def test_english_pretokenized(): pretokenized_test("english") def test_italian_pretokenized(): pretokenized_test("italian") def test_french_pretokenized(): pretokenized_test("french") ================================================ FILE: stanza/tests/server/test_server_request.py ================================================ """ Tests for setting request properties of servers """ import json import pytest import stanza.server as corenlp from stanza.protobuf import Document from stanza.tests import TEST_WORKING_DIR, compare_ignoring_whitespace pytestmark = pytest.mark.client EN_DOC = "Joe Smith lives in California." # results with an example properties file EN_DOC_GOLD = """ Sentence #1 (6 tokens): Joe Smith lives in California. Tokens: [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP] [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP] [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ] [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN] [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP] [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.] 
""" GERMAN_DOC = "Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland." GERMAN_DOC_GOLD = """ Sentence #1 (10 tokens): Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland. Tokens: [Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN] [Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN] [Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17 PartOfSpeech=AUX] [Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22 PartOfSpeech=ADP] [Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27 PartOfSpeech=NUM] [Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43 PartOfSpeech=NOUN] [Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47 PartOfSpeech=DET] [Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62 PartOfSpeech=PROPN] [Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74 PartOfSpeech=PROPN] [Text=. CharacterOffsetBegin=74 CharacterOffsetEnd=75 PartOfSpeech=PUNCT] """ FRENCH_CUSTOM_PROPS = {'annotators': 'tokenize,ssplit,mwt,pos,parse', 'tokenize.language': 'fr', 'pos.model': 'edu/stanford/nlp/models/pos-tagger/french-ud.tagger', 'parse.model': 'edu/stanford/nlp/models/srparser/frenchSR.ser.gz', 'mwt.mappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt.tsv', 'mwt.pos.model': 'edu/stanford/nlp/models/mwt/french/french-mwt.tagger', 'mwt.statisticalMappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt-statistical.tsv', 'mwt.preserveCasing': 'false', 'outputFormat': 'text'} FRENCH_EXTRA_PROPS = {'annotators': 'tokenize,ssplit,mwt,pos,depparse', 'tokenize.language': 'fr', 'pos.model': 'edu/stanford/nlp/models/pos-tagger/french-ud.tagger', 'mwt.mappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt.tsv', 'mwt.pos.model': 'edu/stanford/nlp/models/mwt/french/french-mwt.tagger', 'mwt.statisticalMappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt-statistical.tsv', 'mwt.preserveCasing': 'false', 'depparse.model': 
'edu/stanford/nlp/models/parser/nndep/UD_French.gz'} FRENCH_DOC = "Cette enquête préliminaire fait suite aux révélations de l’hebdomadaire quelques jours plus tôt." FRENCH_CUSTOM_GOLD = """ Sentence #1 (16 tokens): Cette enquête préliminaire fait suite aux révélations de l’hebdomadaire quelques jours plus tôt. Tokens: [Text=Cette CharacterOffsetBegin=0 CharacterOffsetEnd=5 PartOfSpeech=DET] [Text=enquête CharacterOffsetBegin=6 CharacterOffsetEnd=13 PartOfSpeech=NOUN] [Text=préliminaire CharacterOffsetBegin=14 CharacterOffsetEnd=26 PartOfSpeech=ADJ] [Text=fait CharacterOffsetBegin=27 CharacterOffsetEnd=31 PartOfSpeech=VERB] [Text=suite CharacterOffsetBegin=32 CharacterOffsetEnd=37 PartOfSpeech=NOUN] [Text=à CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=ADP] [Text=les CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=DET] [Text=révélations CharacterOffsetBegin=42 CharacterOffsetEnd=53 PartOfSpeech=NOUN] [Text=de CharacterOffsetBegin=54 CharacterOffsetEnd=56 PartOfSpeech=ADP] [Text=l’ CharacterOffsetBegin=57 CharacterOffsetEnd=59 PartOfSpeech=NOUN] [Text=hebdomadaire CharacterOffsetBegin=59 CharacterOffsetEnd=71 PartOfSpeech=ADJ] [Text=quelques CharacterOffsetBegin=72 CharacterOffsetEnd=80 PartOfSpeech=DET] [Text=jours CharacterOffsetBegin=81 CharacterOffsetEnd=86 PartOfSpeech=NOUN] [Text=plus CharacterOffsetBegin=87 CharacterOffsetEnd=91 PartOfSpeech=ADV] [Text=tôt CharacterOffsetBegin=92 CharacterOffsetEnd=95 PartOfSpeech=ADV] [Text=. CharacterOffsetBegin=95 CharacterOffsetEnd=96 PartOfSpeech=PUNCT] Constituency parse: (ROOT (SENT (NP (DET Cette) (MWN (NOUN enquête) (ADJ préliminaire))) (VN (MWV (VERB fait) (NOUN suite))) (PP (ADP à) (NP (DET les) (NOUN révélations) (PP (ADP de) (NP (NOUN l’) (AP (ADJ hebdomadaire)))))) (NP (DET quelques) (NOUN jours)) (AdP (ADV plus) (ADV tôt)) (PUNCT .))) """ FRENCH_EXTRA_GOLD = """ Sentence #1 (16 tokens): Cette enquête préliminaire fait suite aux révélations de l’hebdomadaire quelques jours plus tôt. 
Tokens: [Text=Cette CharacterOffsetBegin=0 CharacterOffsetEnd=5 PartOfSpeech=DET] [Text=enquête CharacterOffsetBegin=6 CharacterOffsetEnd=13 PartOfSpeech=NOUN] [Text=préliminaire CharacterOffsetBegin=14 CharacterOffsetEnd=26 PartOfSpeech=ADJ] [Text=fait CharacterOffsetBegin=27 CharacterOffsetEnd=31 PartOfSpeech=VERB] [Text=suite CharacterOffsetBegin=32 CharacterOffsetEnd=37 PartOfSpeech=NOUN] [Text=à CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=ADP] [Text=les CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=DET] [Text=révélations CharacterOffsetBegin=42 CharacterOffsetEnd=53 PartOfSpeech=NOUN] [Text=de CharacterOffsetBegin=54 CharacterOffsetEnd=56 PartOfSpeech=ADP] [Text=l’ CharacterOffsetBegin=57 CharacterOffsetEnd=59 PartOfSpeech=NOUN] [Text=hebdomadaire CharacterOffsetBegin=59 CharacterOffsetEnd=71 PartOfSpeech=ADJ] [Text=quelques CharacterOffsetBegin=72 CharacterOffsetEnd=80 PartOfSpeech=DET] [Text=jours CharacterOffsetBegin=81 CharacterOffsetEnd=86 PartOfSpeech=NOUN] [Text=plus CharacterOffsetBegin=87 CharacterOffsetEnd=91 PartOfSpeech=ADV] [Text=tôt CharacterOffsetBegin=92 CharacterOffsetEnd=95 PartOfSpeech=ADV] [Text=. CharacterOffsetBegin=95 CharacterOffsetEnd=96 PartOfSpeech=PUNCT] Dependency Parse (enhanced plus plus dependencies): root(ROOT-0, fait-4) det(enquête-2, Cette-1) nsubj(fait-4, enquête-2) amod(enquête-2, préliminaire-3) obj(fait-4, suite-5) case(révélations-8, à-6) det(révélations-8, les-7) obl:à(fait-4, révélations-8) case(l’-10, de-9) nmod:de(révélations-8, l’-10) amod(révélations-8, hebdomadaire-11) det(jours-13, quelques-12) obl(fait-4, jours-13) advmod(tôt-15, plus-14) advmod(jours-13, tôt-15) punct(fait-4, .-16) """ FRENCH_JSON_GOLD = json.loads(open(f'{TEST_WORKING_DIR}/out/example_french.json', encoding="utf-8").read()) ES_DOC = 'Andrés Manuel López Obrador es el presidente de México.' 
ES_PROPS = {'annotators': 'tokenize,ssplit,mwt,pos,depparse', 'tokenize.language': 'es', 'pos.model': 'edu/stanford/nlp/models/pos-tagger/spanish-ud.tagger', 'mwt.mappingFile': 'edu/stanford/nlp/models/mwt/spanish/spanish-mwt.tsv', 'depparse.model': 'edu/stanford/nlp/models/parser/nndep/UD_Spanish.gz'} ES_PROPS_GOLD = """ Sentence #1 (10 tokens): Andrés Manuel López Obrador es el presidente de México. Tokens: [Text=Andrés CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN] [Text=Manuel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN] [Text=López CharacterOffsetBegin=14 CharacterOffsetEnd=19 PartOfSpeech=PROPN] [Text=Obrador CharacterOffsetBegin=20 CharacterOffsetEnd=27 PartOfSpeech=PROPN] [Text=es CharacterOffsetBegin=28 CharacterOffsetEnd=30 PartOfSpeech=AUX] [Text=el CharacterOffsetBegin=31 CharacterOffsetEnd=33 PartOfSpeech=DET] [Text=presidente CharacterOffsetBegin=34 CharacterOffsetEnd=44 PartOfSpeech=NOUN] [Text=de CharacterOffsetBegin=45 CharacterOffsetEnd=47 PartOfSpeech=ADP] [Text=México CharacterOffsetBegin=48 CharacterOffsetEnd=54 PartOfSpeech=PROPN] [Text=. 
CharacterOffsetBegin=54 CharacterOffsetEnd=55 PartOfSpeech=PUNCT] Dependency Parse (enhanced plus plus dependencies): root(ROOT-0, presidente-7) nsubj(presidente-7, Andrés-1) flat(Andrés-1, Manuel-2) flat(Andrés-1, López-3) flat(Andrés-1, Obrador-4) cop(presidente-7, es-5) det(presidente-7, el-6) case(México-9, de-8) nmod:de(presidente-7, México-9) punct(presidente-7, .-10) """ class TestServerRequest: @pytest.fixture(scope="class") def corenlp_client(self): """ Client to run tests on """ client = corenlp.CoreNLPClient(annotators='tokenize,ssplit,pos', server_id='stanza_request_tests_server') yield client client.stop() def test_basic(self, corenlp_client): """ Basic test of making a request, test default output format is a Document """ ann = corenlp_client.annotate(EN_DOC, output_format="text") compare_ignoring_whitespace(ann, EN_DOC_GOLD) ann = corenlp_client.annotate(EN_DOC) assert isinstance(ann, Document) def test_python_dict(self, corenlp_client): """ Test using a Python dictionary to specify all request properties """ ann = corenlp_client.annotate(ES_DOC, properties=ES_PROPS, output_format="text") compare_ignoring_whitespace(ann, ES_PROPS_GOLD) ann = corenlp_client.annotate(FRENCH_DOC, properties=FRENCH_CUSTOM_PROPS) compare_ignoring_whitespace(ann, FRENCH_CUSTOM_GOLD) def test_lang_setting(self, corenlp_client): """ Test using a Stanford CoreNLP supported languages as a properties key """ ann = corenlp_client.annotate(GERMAN_DOC, properties="german", output_format="text") compare_ignoring_whitespace(ann, GERMAN_DOC_GOLD) def test_annotators_and_output_format(self, corenlp_client): """ Test setting the annotators and output_format """ ann = corenlp_client.annotate(FRENCH_DOC, properties=FRENCH_EXTRA_PROPS, annotators="tokenize,ssplit,mwt,pos", output_format="json") assert ann == FRENCH_JSON_GOLD ================================================ FILE: stanza/tests/server/test_server_start.py ================================================ """ Tests for 
starting a server in Python code """

import pytest
import stanza.server as corenlp
from stanza.server.client import AnnotationException
import time

from stanza.tests import *

pytestmark = pytest.mark.client

# English input text shared by most of the server start-up tests below
EN_DOC = "Joe Smith lives in California."

# results on EN_DOC with standard StanfordCoreNLP defaults (checked by test_preload)
EN_PRELOAD_GOLD = """ Sentence #1 (6 tokens): Joe Smith lives in California. Tokens: [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP Lemma=Joe NamedEntityTag=PERSON] [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP Lemma=Smith NamedEntityTag=PERSON] [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ Lemma=live NamedEntityTag=O] [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN Lemma=in NamedEntityTag=O] [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP Lemma=California NamedEntityTag=STATE_OR_PROVINCE] [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=. Lemma=. NamedEntityTag=O] Dependency Parse (enhanced plus plus dependencies): root(ROOT-0, lives-3) compound(Smith-2, Joe-1) nsubj(lives-3, Smith-2) case(California-5, in-4) obl:in(lives-3, California-5) punct(lives-3, .-6) Extracted the following NER entity mentions: Joe Smith PERSON PERSON:0.9972202681743931 California STATE_OR_PROVINCE LOCATION:0.9990868267559281 Extracted the following KBP triples: 1.0 Joe Smith per:statesorprovinces_of_residence California """

# results with an example properties file (checked by test_props_file, which starts
# the server with SERVER_TEST_PROPS)
EN_PROPS_FILE_GOLD = """ Sentence #1 (6 tokens): Joe Smith lives in California. Tokens: [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP] [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP] [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ] [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN] [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP] [Text=.
CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.] """

# German input text shared by the German start-up tests below
GERMAN_DOC = "Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland."

# results with standard German properties (server started with properties='german')
GERMAN_FULL_PROPS_GOLD = """ Sentence #1 (10 tokens): Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland. Tokens: [Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN Lemma=angela NamedEntityTag=PERSON] [Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN Lemma=merkel NamedEntityTag=PERSON] [Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17 PartOfSpeech=AUX Lemma=ist NamedEntityTag=O] [Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22 PartOfSpeech=ADP Lemma=seit NamedEntityTag=O] [Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27 PartOfSpeech=NUM Lemma=2005 NamedEntityTag=O] [Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43 PartOfSpeech=NOUN Lemma=bundeskanzlerin NamedEntityTag=O] [Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47 PartOfSpeech=DET Lemma=der NamedEntityTag=O] [Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62 PartOfSpeech=PROPN Lemma=bundesrepublik NamedEntityTag=LOCATION] [Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74 PartOfSpeech=PROPN Lemma=deutschland NamedEntityTag=LOCATION] [Text=. CharacterOffsetBegin=74 CharacterOffsetEnd=75 PartOfSpeech=PUNCT Lemma=.
NamedEntityTag=O] Dependency Parse (enhanced plus plus dependencies): root(ROOT-0, Bundeskanzlerin-6) nsubj(Bundeskanzlerin-6, Angela-1) flat(Angela-1, Merkel-2) cop(Bundeskanzlerin-6, ist-3) case(2005-5, seit-4) nmod:seit(Bundeskanzlerin-6, 2005-5) det(Bundesrepublik-8, der-7) nmod(Bundeskanzlerin-6, Bundesrepublik-8) appos(Bundesrepublik-8, Deutschland-9) punct(Bundeskanzlerin-6, .-10) Extracted the following NER entity mentions: Angela Merkel PERSON PERSON:0.9999981583351504 Bundesrepublik Deutschland LOCATION LOCATION:0.9682902289749544 """

# Python dictionary used as the server's default properties in test_python_dict
# and test_python_dict_w_annotators
GERMAN_SMALL_PROPS = {'annotators': 'tokenize,ssplit,pos', 'tokenize.language': 'de', 'pos.model': 'edu/stanford/nlp/models/pos-tagger/german-ud.tagger'}

# results with custom Python dictionary set properties
GERMAN_SMALL_PROPS_GOLD = """ Sentence #1 (10 tokens): Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland. Tokens: [Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN] [Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN] [Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17 PartOfSpeech=AUX] [Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22 PartOfSpeech=ADP] [Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27 PartOfSpeech=NUM] [Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43 PartOfSpeech=NOUN] [Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47 PartOfSpeech=DET] [Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62 PartOfSpeech=PROPN] [Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74 PartOfSpeech=PROPN] [Text=. CharacterOffsetBegin=74 CharacterOffsetEnd=75 PartOfSpeech=PUNCT] """

# results with custom Python dictionary set properties and annotators=tokenize,ssplit
# (no PartOfSpeech, because the annotators override drops pos)
GERMAN_SMALL_PROPS_W_ANNOTATORS_GOLD = """ Sentence #1 (10 tokens): Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.
Tokens: [Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6] [Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13] [Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17] [Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22] [Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27] [Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43] [Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47] [Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62] [Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74] [Text=. CharacterOffsetBegin=74 CharacterOffsetEnd=75] """

# properties for username/password example
USERNAME_PASS_PROPS = {'annotators': 'tokenize,ssplit,pos'}
# expected text output of EN_DOC under USERNAME_PASS_PROPS (test_username_password)
USERNAME_PASS_GOLD = """ Sentence #1 (6 tokens): Joe Smith lives in California. Tokens: [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP] [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP] [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ] [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN] [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP] [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]
""" def annotate_and_time(client, text, properties={}): """ Submit an annotation request and return how long it took """ start = time.time() ann = client.annotate(text, properties=properties, output_format="text") end = time.time() return {'annotation': ann, 'start_time': start, 'end_time': end} def test_preload(): """ Test that the default annotators load fully immediately upon server start """ with corenlp.CoreNLPClient(server_id='test_server_start_preload') as client: # wait for annotators to load time.sleep(140) results = annotate_and_time(client, EN_DOC) compare_ignoring_whitespace(results['annotation'], EN_PRELOAD_GOLD) assert results['end_time'] - results['start_time'] < 3 def test_props_file(): """ Test starting the server with a props file """ with corenlp.CoreNLPClient(properties=SERVER_TEST_PROPS, server_id='test_server_start_props_file') as client: ann = client.annotate(EN_DOC, output_format="text") assert ann.strip() == EN_PROPS_FILE_GOLD.strip() def test_lang_start(): """ Test starting the server with a Stanford CoreNLP language name """ with corenlp.CoreNLPClient(properties='german', server_id='test_server_start_lang_name') as client: ann = client.annotate(GERMAN_DOC, output_format='text') compare_ignoring_whitespace(ann, GERMAN_FULL_PROPS_GOLD) def test_python_dict(): """ Test starting the server with a Python dictionary as default properties """ with corenlp.CoreNLPClient(properties=GERMAN_SMALL_PROPS, server_id='test_server_start_python_dict') as client: ann = client.annotate(GERMAN_DOC, output_format='text') assert ann.strip() == GERMAN_SMALL_PROPS_GOLD.strip() def test_python_dict_w_annotators(): """ Test starting the server with a Python dictionary as default properties, override annotators """ with corenlp.CoreNLPClient(properties=GERMAN_SMALL_PROPS, annotators="tokenize,ssplit", server_id='test_server_start_python_dict_w_annotators') as client: ann = client.annotate(GERMAN_DOC, output_format='text') assert ann.strip() == 
GERMAN_SMALL_PROPS_W_ANNOTATORS_GOLD.strip() def test_username_password(): """ Test starting a server with a username and password """ with corenlp.CoreNLPClient(properties=USERNAME_PASS_PROPS, username='user-1234', password='1234', server_id="test_server_username_pass") as client: # check with correct password ann = client.annotate(EN_DOC, output_format='text', username='user-1234', password='1234') assert ann.strip() == USERNAME_PASS_GOLD.strip() # check with incorrect password, should throw AnnotationException try: ann = client.annotate(EN_DOC, output_format='text', username='user-1234', password='12345') assert False except AnnotationException as ae: pass except Exception as e: assert False ================================================ FILE: stanza/tests/server/test_ssurgeon.py ================================================ import pytest from stanza.tests import compare_ignoring_whitespace pytestmark = [pytest.mark.travis, pytest.mark.client] from stanza.utils.conll import CoNLL import stanza.server.ssurgeon as ssurgeon SAMPLE_DOC_INPUT = """ # sent_id = 271 # text = Hers is easy to clean. # previous = What did the dealer like about Alex's car? # comment = extraction/raising via "tough extraction" and clausal subject 1 Hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nsubj _ _ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _ 3 easy easy ADJ JJ Degree=Pos 0 root _ _ 4 to to PART TO _ 5 mark _ _ 5 clean clean VERB VB VerbForm=Inf 3 csubj _ SpaceAfter=No 6 . . PUNCT . _ 5 punct _ _ """ SAMPLE_DOC_EXPECTED = """ # sent_id = 271 # text = Hers is easy to clean. # previous = What did the dealer like about Alex's car? 
# comment = extraction/raising via "tough extraction" and clausal subject 1 Hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nsubj _ _ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _ 3 easy easy ADJ JJ Degree=Pos 0 root _ _ 4 to to PART TO _ 5 mark _ _ 5 clean clean VERB VB VerbForm=Inf 3 advcl _ SpaceAfter=No 6 . . PUNCT . _ 5 punct _ _ """ def test_ssurgeon_same_length(): semgrex_pattern = "{}=source >nsubj {} >csubj=bad {}" ssurgeon_edits = ["relabelNamedEdge -edge bad -reln advcl"] doc = CoNLL.conll2doc(input_str=SAMPLE_DOC_INPUT) ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits) updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False) result = "{:C}".format(updated_doc) #print(result) #print(SAMPLE_DOC_EXPECTED) compare_ignoring_whitespace(result, SAMPLE_DOC_EXPECTED) ADD_WORD_DOC_INPUT = """ # text = Jennifer has lovely antennae. # sent_id = 12 # comment = if you're in to that kind of thing 1 Jennifer Jennifer PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=8|ner=S-PERSON 2 has have VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ start_char=9|end_char=12|ner=O 3 lovely lovely ADJ JJ Degree=Pos 4 amod _ start_char=13|end_char=19|ner=O 4 antennae antenna NOUN NNS Number=Plur 2 obj _ start_char=20|end_char=28|ner=O|SpaceAfter=No 5 . . PUNCT . _ 2 punct _ start_char=28|end_char=29|ner=O """ ADD_WORD_DOC_EXPECTED = """ # text = Jennifer has lovely blue antennae. # sent_id = 12 # comment = if you're in to that kind of thing 1 Jennifer Jennifer PROPN NNP Number=Sing 2 nsubj _ ner=S-PERSON 2 has have VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ ner=O 3 lovely lovely ADJ JJ Degree=Pos 5 amod _ ner=O 4 blue blue ADJ JJ _ 5 amod _ ner=O 5 antennae antenna NOUN NNS Number=Plur 2 obj _ SpaceAfter=No|ner=O 6 . . PUNCT . 
_ 2 punct _ ner=O """ def test_ssurgeon_different_length(): semgrex_pattern = "{word:antennae}=antennae !> {word:blue}" ssurgeon_edits = ["addDep -gov antennae -reln amod -word blue -lemma blue -cpos ADJ -pos JJ -ner O -position -antennae -after \" \""] doc = CoNLL.conll2doc(input_str=ADD_WORD_DOC_INPUT) #print() #print("{:C}".format(doc)) ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits) updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False) result = "{:C}".format(updated_doc) #print(result) #print(ADD_WORD_DOC_EXPECTED) compare_ignoring_whitespace(result, ADD_WORD_DOC_EXPECTED) BECOME_MWT_DOC_INPUT = """ # sent_id = 25 # text = It's not yours! # comment = negation 1 It it PRON PRP Number=Sing|Person=2|PronType=Prs 4 nsubj _ SpaceAfter=No 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _ 3 not not PART RB Polarity=Neg 4 advmod _ _ 4 yours yours PRON PRP Gender=Neut|Number=Sing|Person=2|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No 5 ! ! PUNCT . _ 4 punct _ _ """ BECOME_MWT_DOC_EXPECTED = """ # sent_id = 25 # text = It's not yours! # comment = negation 1-2 It's _ _ _ _ _ _ _ _ 1 It it PRON PRP Number=Sing|Person=2|PronType=Prs 4 nsubj _ _ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _ 3 not not PART RB Polarity=Neg 4 advmod _ _ 4 yours yours PRON PRP Gender=Neut|Number=Sing|Person=2|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No 5 ! ! PUNCT . _ 4 punct _ _ """ def test_ssurgeon_become_mwt(): """ Test that converting a document, adding a new MWT, works as expected """ semgrex_pattern = "{word:It}=it . 
{word:/'s/}=s" ssurgeon_edits = ["EditNode -node it -is_mwt true -is_first_mwt true -mwt_text It's", "EditNode -node s -is_mwt true -is_first_mwt false -mwt_text It's"] doc = CoNLL.conll2doc(input_str=BECOME_MWT_DOC_INPUT) ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits) updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False) result = "{:C}".format(updated_doc) compare_ignoring_whitespace(result, BECOME_MWT_DOC_EXPECTED) EXISTING_MWT_DOC_INPUT = """ # sent_id = newsgroup-groups.google.com_GayMarriage_0ccbb50b41a5830b_ENG_20050321_181500-0005 # text = One of “NCRC4ME’s” 1 One one NUM CD NumType=Card 0 root 0:root _ 2 of of ADP IN _ 4 case 4:case _ 3 “ " PUNCT `` _ 4 punct 4:punct SpaceAfter=No 4-5 NCRC4ME’s _ _ _ _ _ _ _ SpaceAfter=No 4 NCRC4ME NCRC4ME PROPN NNP Number=Sing 1 compound 1:compound _ 5 ’s 's PART POS _ 4 case 4:case _ 6 ” " PUNCT '' _ 4 punct 4:punct _ """ # TODO: also, we shouldn't lose the enhanced dependencies... EXISTING_MWT_DOC_EXPECTED = """ # sent_id = newsgroup-groups.google.com_GayMarriage_0ccbb50b41a5830b_ENG_20050321_181500-0005 # text = One of “NCRC4ME’s” 1 One one NUM CD NumType=Card 0 root _ _ 2 of of ADP IN _ 4 case _ _ 3 “ " PUNCT `` _ 4 punct _ SpaceAfter=No 4-5 NCRC4ME’s _ _ _ _ _ _ _ SpaceAfter=No 4 NCRC4ME NCRC4ME PROPN NNP Number=Sing 1 compound _ _ 5 ’s 's PART POS _ 4 case _ _ 6 ” " PUNCT '' _ 4 punct _ _ """ def test_ssurgeon_existing_mwt_no_change(): """ Test that converting a document with an MWT works as expected Note regarding this test: Currently it works because ssurgeon.py doesn't look at the "changed" flag because of a bug in EditNode in CoreNLP 4.5.3 If that is fixed, but the enhanced dependencies aren't fixed, this test will fail because the enhanced dependencies *aren't* removed. Fixing the enhanced dependencies as well will fix that, though. """ semgrex_pattern = "{word:It}=it . 
{word:/'s/}=s" ssurgeon_edits = ["EditNode -node it -is_mwt true -is_first_mwt true -mwt_text It's", "EditNode -node s -is_mwt true -is_first_mwt false -mwt_text It's"] doc = CoNLL.conll2doc(input_str=EXISTING_MWT_DOC_INPUT) ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits) updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False) result = "{:C}".format(updated_doc) compare_ignoring_whitespace(result, EXISTING_MWT_DOC_EXPECTED) def check_empty_test(input_text, expected=None, echo=False): if expected is None: expected = input_text doc = CoNLL.conll2doc(input_str=input_text) # we don't want to edit this, just test the to/from conversion ssurgeon_response = ssurgeon.process_doc(doc, []) updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False) result = "{:C}".format(updated_doc) if echo: print("INPUT") print(input_text) print("EXPECTED") print(expected) print("RESULT") print(result) compare_ignoring_whitespace(result, expected) ITALIAN_MWT_INPUT = """ # sent_id = train_78 # text = @user dovrebbe fare pace col cervello # twittiro = IMPLICIT ANALOGY 1 @user @user SYM SYM _ 3 nsubj _ _ 2 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux _ _ 3 fare fare VERB V VerbForm=Inf 0 root _ _ 4 pace pace NOUN S Gender=Fem|Number=Sing 3 obj _ _ 5-6 col _ _ _ _ _ _ _ _ 5 con con ADP E _ 7 case _ _ 6 il il DET RD Definite=Def|Gender=Masc|Number=Sing|PronType=Art 7 det _ _ 7 cervello cervello NOUN S Gender=Masc|Number=Sing 3 obl _ _ """ def test_ssurgeon_mwt_text(): """ Test that an MWT which is split into pieces which don't make up the original token results in a correct #text annotation For example, in Italian, "col" splits into "con il", and we want the #text to contain "col" """ check_empty_test(ITALIAN_MWT_INPUT) ITALIAN_SPACES_AFTER_INPUT=""" # sent_id = train_1114 # text = ““““ buona scuola ““““ # twittiro = EXPLICIT OTHER 1 “ “ 
PUNCT FB _ 6 punct _ SpaceAfter=No 2 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 3 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 4 “ “ PUNCT FB _ 6 punct _ _ 5 buona buono ADJ A Gender=Fem|Number=Sing 6 amod _ _ 6 scuola scuola NOUN S Gender=Fem|Number=Sing 0 root _ _ 7 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 8 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 9 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 10 “ “ PUNCT FB _ 6 punct _ SpacesAfter=\\n """

# Same sentence as ITALIAN_SPACES_AFTER_INPUT, except that word 4 carries an
# explicit (redundant) SpaceAfter=Yes
ITALIAN_SPACES_AFTER_YES_INPUT=""" # sent_id = train_1114 # text = ““““ buona scuola ““““ # twittiro = EXPLICIT OTHER 1 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 2 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 3 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 4 “ “ PUNCT FB _ 6 punct _ SpaceAfter=Yes 5 buona buono ADJ A Gender=Fem|Number=Sing 6 amod _ _ 6 scuola scuola NOUN S Gender=Fem|Number=Sing 0 root _ _ 7 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 8 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 9 “ “ PUNCT FB _ 6 punct _ SpaceAfter=No 10 “ “ PUNCT FB _ 6 punct _ SpacesAfter=\\n """


def test_ssurgeon_spaces_after_text():
    """
    Test that SpacesAfter goes and comes back the same way

    Tested using some random example from the UD_Italian-TWITTIRO dataset
    """
    check_empty_test(ITALIAN_SPACES_AFTER_INPUT)


def test_ssurgeon_spaces_after_yes():
    """
    Test that an unnecessary SpaceAfter=Yes is eliminated

    The round trip of the SpaceAfter=Yes variant must equal the plain input.
    """
    check_empty_test(ITALIAN_SPACES_AFTER_YES_INPUT, ITALIAN_SPACES_AFTER_INPUT)


EMPTY_VALUES_INPUT = """ # text = Jennifer has lovely antennae. # sent_id = 12 # comment = if you're in to that kind of thing 1 Jennifer _ _ _ Number=Sing 2 nsubj _ ner=S-PERSON 2 has _ _ _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ ner=O 3 lovely _ _ _ Degree=Pos 4 amod _ ner=O 4 antennae _ _ _ Number=Plur 2 obj _ SpaceAfter=No|ner=O 5 .
_ _ _ _ 2 punct _ ner=O """


def test_ssurgeon_blank_values():
    """
    Check that various None fields such as lemma & xpos are not turned into blanks

    Tests, like regulations, are often written in blood
    """
    check_empty_test(EMPTY_VALUES_INPUT)


# first couple sentences of UD_Cantonese-HK
# we change the order of the misc column in word 3 to make sure the
# pieces don't get unnecessarily reordered by ssurgeon
# NOTE(review): as written below, word 3 of sentence 1 has the same misc key order
# (Translit|Gloss|SpaceAfter) as words 1, 2, 4 and 5 -- confirm the reordering
# described above is actually present in this fixture
CANTONESE_MISC_WORDS_INPUT = """ # sent_id = 1 # text = 你喺度搵乜嘢呀? 1 你 你 PRON _ _ 3 nsubj _ Translit=nei5|Gloss=2SG|SpaceAfter=No 2 喺度 喺度 ADV _ _ 3 advmod _ Translit=hai2dou6|Gloss=PROG|SpaceAfter=No 3 搵 搵 VERB _ _ 0 root _ Translit=wan2|Gloss=find|SpaceAfter=No 4 乜嘢 乜嘢 PRON _ _ 3 obj _ Translit=mat1je5|Gloss=what|SpaceAfter=No 5 呀 呀 PART _ _ 3 discourse:sp _ Translit=aa3|Gloss=SFP|SpaceAfter=No 6 ? ? PUNCT _ _ 3 punct _ SpaceAfter=No # sent_id = 2 # text = 咪執返啲嘢去阿哥個新屋度囖。 1 咪 咪 ADV _ _ 2 advmod _ SpaceAfter=No 2 執 執 VERB _ _ 0 root _ SpaceAfter=No 3 返 返 VERB _ _ 2 compound:dir _ SpaceAfter=No 4 啲 啲 NOUN _ NounType=Clf 5 clf:det _ SpaceAfter=No 5 嘢 嘢 NOUN _ _ 3 obj _ SpaceAfter=No 6 去 去 VERB _ _ 2 conj _ SpaceAfter=No 7 阿哥 阿哥 NOUN _ _ 10 nmod _ SpaceAfter=No 8 個 個 NOUN _ NounType=Clf 10 clf:det _ SpaceAfter=No 9 新 新 ADJ _ _ 10 amod _ SpaceAfter=No 10 屋 屋 NOUN _ _ 6 obj _ SpaceAfter=No 11 度 度 ADP _ _ 10 case:loc _ SpaceAfter=No 12 囖 囖 PART _ _ 2 discourse:sp _ SpaceAfter=No 13 。 。 PUNCT _ _ 2 punct _ SpaceAfter=No """


def test_ssurgeon_misc_words():
    """
    Check that the misc column of words (Translit / Gloss / SpaceAfter)
    survives a round trip through ssurgeon unchanged, including the order
    of its pieces
    """
    check_empty_test(CANTONESE_MISC_WORDS_INPUT)


ITALIAN_MWT_SPACE_AFTER_INPUT = """ # sent_id = train_78 # text = @user dovrebbe fare pace colcervello # twittiro = IMPLICIT ANALOGY 1 @user @user SYM SYM _ 3 nsubj _ _ 2 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux _ _ 3 fare fare VERB V VerbForm=Inf 0 root _ _ 4 pace pace NOUN S Gender=Fem|Number=Sing 3 obj
_ _ 5-6 col _ _ _ _ _ _ _ SpaceAfter=No 5 con con ADP E _ 7 case _ _ 6 il il DET RD Definite=Def|Gender=Masc|Number=Sing|PronType=Art 7 det _ _ 7 cervello cervello NOUN S Gender=Masc|Number=Sing 3 obl _ RandomFeature=foo """


def test_ssurgeon_mwt_space_after():
    """
    Check the SpaceAfter=No on an MWT (rather than a word)

    The RandomFeature=foo is on account of a silly bug in the initial
    version of passing in MWT misc features
    """
    check_empty_test(ITALIAN_MWT_SPACE_AFTER_INPUT)


ITALIAN_MWT_MISC_INPUT = """ # sent_id = train_78 # text = @user dovrebbe farepacecolcervello # twittiro = IMPLICIT ANALOGY 1 @user @user SYM SYM _ 3 nsubj _ _ 2 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux _ _ 3-4 farepace _ _ _ _ _ _ _ Players=GonnaPlay|SpaceAfter=No 3 fare fare VERB V VerbForm=Inf 0 root _ _ 4 pace pace NOUN S Gender=Fem|Number=Sing 3 obj _ _ 5-6 col _ _ _ _ _ _ _ Haters=GonnaHate|SpaceAfter=No 5 con con ADP E _ 7 case _ _ 6 il il DET RD Definite=Def|Gender=Masc|Number=Sing|PronType=Art 7 det _ _ 7 cervello cervello NOUN S Gender=Masc|Number=Sing 3 obl _ RandomFeature=foo """


def test_ssurgeon_mwt_misc():
    """
    Check that arbitrary misc features on MWTs (Players=GonnaPlay,
    Haters=GonnaHate), alongside SpaceAfter=No, survive a round trip
    through ssurgeon unchanged
    """
    check_empty_test(ITALIAN_MWT_MISC_INPUT)


# Input with a blank head/deprel on word 6, i.e. no root edge at all
SINDHI_ROOT_EXAMPLE = """ # sent_id = 1 # text = غلام رهڻ سان ماڻهو منافق ٿئي ٿو . 1 غلام غلام NOUN NN__اسم Case=Acc|Gender=Masc|Number=Sing|Person=3 2 compound _ _ 2 رهڻ ره VERB VB__فعل Number=Sing 6 advcl _ _ 3 سان سان ADP IN__حرفِ_جر Number=Sing 2 mark _ _ 4 ماڻهو ماڻهو NOUN NN__اسم Case=Nom|Gender=Masc|Number=Sing|Person=3 6 nsubj _ _ 5 منافق منافق ADJ JJ__صفت Case=Acc|Number=Sing|Person=3 6 xcomp _ _ 6 ٿئي ٿي VERB VB__فعل Number=Sing _ _ _ _ 7 ٿو ٿو AUX VB__فعل Number=Sing 6 aux _ _ 8 . . PUNCT -__پورو_دم _ 6 punct _ _ """.lstrip()

SINDHI_ROOT_EXPECTED = """ # sent_id = 1 # text = غلام رهڻ سان ماڻهو منافق ٿئي ٿو .
1 غلام غلام NOUN NN__اسم Case=Acc|Gender=Masc|Number=Sing|Person=3 2 compound _ _ 2 رهڻ ره VERB VB__فعل Number=Sing 6 advcl _ _ 3 سان سان ADP IN__حرفِ_جر Number=Sing 2 mark _ _ 4 ماڻهو ماڻهو NOUN NN__اسم Case=Nom|Gender=Masc|Number=Sing|Person=3 6 nsubj _ _ 5 منافق منافق ADJ JJ__صفت Case=Acc|Number=Sing|Person=3 6 xcomp _ _ 6 ٿئي ٿي VERB VB__فعل Number=Sing 0 root _ _ 7 ٿو ٿو AUX VB__فعل Number=Sing 6 aux _ _ 8 . . PUNCT -__پورو_دم _ 6 punct _ _ """.strip() SINDHI_EDIT = """ {}=root !< {} setRoots root """ def test_ssurgeon_rewrite_sindhi_roots(): """ A user / contributor sent a dependency file with blank roots """ edits = ssurgeon.parse_ssurgeon_edits(SINDHI_EDIT) expected_edits = [ssurgeon.SsurgeonEdit(semgrex_pattern='{}=root !< {}', ssurgeon_edits=['setRoots root'], ssurgeon_id='1', notes='', language='UniversalEnglish')] assert edits == expected_edits blank_dep_doc = CoNLL.conll2doc(input_str=SINDHI_ROOT_EXAMPLE) # test that the conversion will work w/o crashing, such as because of a missing root edge request = ssurgeon.build_request(blank_dep_doc, edits) response = ssurgeon.process_doc(blank_dep_doc, edits) updated_doc = ssurgeon.convert_response_to_doc(blank_dep_doc, response, add_missing_text=False) result = "{:C}".format(updated_doc) assert result == SINDHI_ROOT_EXPECTED ================================================ FILE: stanza/tests/server/test_tokensregex.py ================================================ import pytest from stanza.tests import * from stanza.models.common.doc import Document import stanza.server.tokensregex as tokensregex pytestmark = [pytest.mark.travis, pytest.mark.client] from stanza.tests.server.test_semgrex import ONE_SENTENCE_DOC, TWO_SENTENCE_DOC def test_single_sentence(): #expected: #match { # sentence: 0 # match { # text: "Opal" # begin: 2 # end: 3 # } #} response = tokensregex.process_doc(ONE_SENTENCE_DOC, "Opal") assert len(response.match) == 1 assert len(response.match[0].match) == 1 assert 
response.match[0].match[0].sentence == 0 assert response.match[0].match[0].match.text == "Opal" assert response.match[0].match[0].match.begin == 2 assert response.match[0].match[0].match.end == 3 def test_ner_sentence(): #expected: #match { # sentence: 0 # match { # text: "Opal" # begin: 2 # end: 3 # } #} response = tokensregex.process_doc(ONE_SENTENCE_DOC, "[ner: GEM]") assert len(response.match) == 1 assert len(response.match[0].match) == 1 assert response.match[0].match[0].sentence == 0 assert response.match[0].match[0].match.text == "Opal" assert response.match[0].match[0].match.begin == 2 assert response.match[0].match[0].match.end == 3 ================================================ FILE: stanza/tests/server/test_tsurgeon.py ================================================ """ Test the semgrex interface """ import pytest import stanza from stanza.models.constituency import tree_reader from stanza.server.tsurgeon import process_trees, Tsurgeon from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.client] def test_simple(): text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) tregex = "WP=wp" tsurgeon = "relabel wp WWWPPP" result = process_trees(trees, (tregex, tsurgeon)) assert len(result) == 1 assert str(result[0]) == "(ROOT (SBARQ (WHNP (WWWPPP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" def test_context(): """ Processing the same thing twice should work twice... """ with Tsurgeon() as processor: text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" trees = tree_reader.read_trees(text) tregex = "WP=wp" tsurgeon = "relabel wp WWWPPP" result = processor.process(trees, (tregex, tsurgeon)) assert len(result) == 1 assert str(result[0]) == "(ROOT (SBARQ (WHNP (WWWPPP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" result = processor.process(trees, (tregex, tsurgeon)) assert len(result) == 1 assert str(result[0]) == "(ROOT (SBARQ (WHNP (WWWPPP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" def test_arboretum(): """ Test a couple expressions used when processing the Arboretum treebank That particular treebank was the original inspiration for adding the Tsurgeon interface """ with Tsurgeon() as processor: text = "(s (par (fcl (n s1_1) (vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))" expected = "(s (par (fcl (n s1_1) (vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4)) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))" trees = tree_reader.read_trees(text) tregex = "s1_4 > (__=home > (__=parent > __=grandparent)) . 
(s1_3 > (__=move > =grandparent))" tsurgeon = "move move $+ home" result = processor.process(trees, (tregex, tsurgeon)) assert len(result) == 1 assert str(result[0]) == expected text = "(s (par (fcl (n s1_1) (vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))" expected = "(s (par (fcl (n s1_1) (vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4)) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))" trees = tree_reader.read_trees(text) tregex = "s1_4 > (__=home > (__=parent $+ (__=move <<, s1_3 <<- s1_3)))" tsurgeon = "move move $+ home" result = processor.process(trees, (tregex, tsurgeon)) assert len(result) == 1 assert str(result[0]) == expected ================================================ FILE: stanza/tests/server/test_ud_enhancer.py ================================================ import pytest import stanza from stanza.tests import * from stanza.models.common.doc import Document import stanza.server.ud_enhancer as ud_enhancer pytestmark = [pytest.mark.pipeline] def check_edges(graph, source, target, num, isExtra=None): edges = [edge for edge in graph.edge if edge.source == source and edge.target == target] assert len(edges) == num if num == 1: assert edges[0].isExtra == isExtra def test_one_sentence(): nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,lemma,depparse") doc = nlp("This is the car that I bought") result = ud_enhancer.process_doc(doc, language="en", pronouns_pattern=None) assert 
len(result.sentence) == 1 sentence = result.sentence[0] basic = sentence.basicDependencies assert len(basic.node) == 7 assert len(basic.edge) == 6 check_edges(basic, 4, 7, 1, False) check_edges(basic, 7, 4, 0) enhanced = sentence.enhancedDependencies assert len(enhanced.node) == 7 assert len(enhanced.edge) == 7 check_edges(enhanced, 4, 7, 1, False) # this is the new edge check_edges(enhanced, 7, 4, 1, True) ================================================ FILE: stanza/tests/setup.py ================================================ import glob import logging import os import shutil import stanza from stanza.resources import installation from stanza.tests import TEST_HOME_VAR, TEST_WORKING_DIR logger = logging.getLogger('stanza') test_dir = os.getenv(TEST_HOME_VAR, None) if not test_dir: test_dir = TEST_WORKING_DIR logger.info("STANZA_TEST_HOME not set. Will assume %s", test_dir) logger.info("To use a different directory, export or set STANZA_TEST_HOME=...") in_dir = os.path.join(test_dir, "in") out_dir = os.path.join(test_dir, "out") scripts_dir = os.path.join(test_dir, "scripts") models_dir=os.path.join(test_dir, "models") corenlp_dir=os.path.join(test_dir, "corenlp_dir") os.makedirs(test_dir, exist_ok=True) os.makedirs(in_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True) os.makedirs(scripts_dir, exist_ok=True) os.makedirs(models_dir, exist_ok=True) os.makedirs(corenlp_dir, exist_ok=True) logger.info("COPYING FILES") shutil.copy("stanza/tests/data/external_server.properties", scripts_dir) shutil.copy("stanza/tests/data/example_french.json", out_dir) shutil.copy("stanza/tests/data/aws_annotations.zip", in_dir) for emb_file in glob.glob("stanza/tests/data/tiny_emb.*"): shutil.copy(emb_file, in_dir) logger.info("DOWNLOADING MODELS") stanza.download(lang='en', model_dir=models_dir, logging_level='info') stanza.download(lang="en", model_dir=models_dir, package=None, processors={"ner":"ncbi_disease"}) stanza.download(lang='fr', model_dir=models_dir, 
logging_level='info') # Latin ITTB has no case information for the lemmatizer stanza.download(lang='he', model_dir=models_dir, processors='tokenize', logging_level='info') stanza.download(lang='la', model_dir=models_dir, package='ittb', logging_level='info') stanza.download(lang='zh', model_dir=models_dir, logging_level='info') # useful not just for verifying RtL, but because the default Arabic has a unique style of xpos tags stanza.download(lang='ar', model_dir=models_dir, logging_level='info') stanza.download(lang='multilingual', model_dir=models_dir, logging_level='info') logger.info("DOWNLOADING STANZA TOKENIZERS FOR MORPHSEG TESTS") morphseg_langs = ['en', 'es', 'ru', 'fr', 'it', 'cs', 'hu', 'la'] for lang in morphseg_langs: stanza.download(lang=lang, model_dir=models_dir, processors='tokenize', logging_level='info') logger.info(f"Downloaded {lang} tokenizer for morphseg tests") logger.info("DOWNLOADING CORENLP") installation.install_corenlp(dir=corenlp_dir) installation.download_corenlp_models(model="french", version="main", dir=corenlp_dir) installation.download_corenlp_models(model="german", version="main", dir=corenlp_dir) installation.download_corenlp_models(model="italian", version="main", dir=corenlp_dir) installation.download_corenlp_models(model="spanish", version="main", dir=corenlp_dir) logger.info("Test setup completed.") ================================================ FILE: stanza/tests/tokenization/__init__.py ================================================ ================================================ FILE: stanza/tests/tokenization/test_prepare_tokenizer_treebank.py ================================================ import pytest import stanza from stanza.tests import * from stanza.utils.datasets import prepare_tokenizer_treebank pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_has_space_after_no(): assert prepare_tokenizer_treebank.has_space_after_no("SpaceAfter=No") assert 
prepare_tokenizer_treebank.has_space_after_no("UnbanMoxOpal=Yes|SpaceAfter=No") assert prepare_tokenizer_treebank.has_space_after_no("SpaceAfter=No|UnbanMoxOpal=Yes") assert not prepare_tokenizer_treebank.has_space_after_no("SpaceAfter=Yes") assert not prepare_tokenizer_treebank.has_space_after_no("CorrectSpaceAfter=No") assert not prepare_tokenizer_treebank.has_space_after_no("_") def test_add_space_after_no(): assert prepare_tokenizer_treebank.add_space_after_no("_") == "SpaceAfter=No" assert prepare_tokenizer_treebank.add_space_after_no("MoxOpal=Unban") == "MoxOpal=Unban|SpaceAfter=No" with pytest.raises(ValueError): prepare_tokenizer_treebank.add_space_after_no("SpaceAfter=No") def test_remove_space_after_no(): assert prepare_tokenizer_treebank.remove_space_after_no("SpaceAfter=No") == "_" assert prepare_tokenizer_treebank.remove_space_after_no("SpaceAfter=No|MoxOpal=Unban") == "MoxOpal=Unban" assert prepare_tokenizer_treebank.remove_space_after_no("MoxOpal=Unban|SpaceAfter=No") == "MoxOpal=Unban" with pytest.raises(ValueError): prepare_tokenizer_treebank.remove_space_after_no("_") def read_test_doc(doc): sentences = [x.strip().split("\n") for x in doc.split("\n\n")] return sentences SPANISH_QM_TEST_CASE = """ # sent_id = train-s7914 # text = ¿Cómo explicarles entonces que el mar tiene varios dueños y que a partir de la frontera de aquella ola el pescado ya no es tuyo?. # orig_file_sentence 080#14 # this sentence will have the intiial ¿ removed. 
an MWT should be preserved 1 ¿ ¿ PUNCT _ PunctSide=Ini|PunctType=Qest 3 punct _ SpaceAfter=No 2 Cómo cómo PRON _ PronType=Ind 3 obl _ _ 3-4 explicarles _ _ _ _ _ _ _ _ 3 explicar explicar VERB _ VerbForm=Inf 0 root _ _ 4 les él PRON _ Case=Dat|Number=Plur|Person=3|PronType=Prs 3 obj _ _ 5 entonces entonces ADV _ _ 3 advmod _ _ 6 que que SCONJ _ _ 9 mark _ _ 7 el el DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 8 det _ _ 8 mar mar NOUN _ Number=Sing 9 nsubj _ _ 9 tiene tener VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 ccomp _ _ 10 varios varios DET _ Gender=Masc|Number=Plur|PronType=Ind 11 det _ _ 11 dueños dueño NOUN _ Gender=Masc|Number=Plur 9 obj _ _ 12 y y CCONJ _ _ 27 cc _ _ 13 que que SCONJ _ _ 27 mark _ _ 14 a a ADP _ _ 18 case _ MWE=a_partir_de|MWEPOS=ADP 15 partir partir NOUN _ _ 14 fixed _ _ 16 de de ADP _ _ 14 fixed _ _ 17 la el DET _ Definite=Def|Gender=Fem|Number=Sing|PronType=Art 18 det _ _ 18 frontera frontera NOUN _ Gender=Fem|Number=Sing 27 obl _ _ 19 de de ADP _ _ 21 case _ _ 20 aquella aquel DET _ Gender=Fem|Number=Sing|PronType=Dem 21 det _ _ 21 ola ola NOUN _ Gender=Fem|Number=Sing 18 nmod _ _ 22 el el DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 23 det _ _ 23 pescado pescado NOUN _ Gender=Masc|Number=Sing 27 nsubj _ _ 24 ya ya ADV _ _ 27 advmod _ _ 25 no no ADV _ Polarity=Neg 27 advmod _ _ 26 es ser AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 27 cop _ _ 27 tuyo tuyo PRON _ Gender=Masc|Number=Sing|Number[psor]=Sing|Person=2|Poss=Yes|PronType=Ind 9 conj _ SpaceAfter=No 28 ? ? PUNCT _ PunctSide=Fin|PunctType=Qest 3 punct _ SpaceAfter=No 29 . . PUNCT _ PunctType=Peri 3 punct _ _ # sent_id = train-s8516 # text = ¿ Pero es divertido en la vida real? - -. 
# orig_file_sentence 086#16 # this sentence will have the ¿ removed even with no SpaceAfter=No 1 ¿ ¿ PUNCT _ PunctSide=Ini|PunctType=Qest 4 punct _ _ 2 Pero pero CCONJ _ _ 4 advmod _ _ 3 es ser AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _ 4 divertido divertido ADJ _ Gender=Masc|Number=Sing|VerbForm=Part 0 root _ _ 5 en en ADP _ _ 7 case _ _ 6 la el DET _ Definite=Def|Gender=Fem|Number=Sing|PronType=Art 7 det _ _ 7 vida vida NOUN _ Gender=Fem|Number=Sing 4 obl _ _ 8 real real ADJ _ Number=Sing 7 amod _ SpaceAfter=No 9 ? ? PUNCT _ PunctSide=Fin|PunctType=Qest 4 punct _ _ 10 - - PUNCT _ PunctType=Dash 4 punct _ _ 11 - - PUNCT _ PunctType=Dash 4 punct _ SpaceAfter=No 12 . . PUNCT _ PunctType=Peri 4 punct _ _ # sent_id = train-s2337 # text = Es imposible. # orig_file_sentence 024#37 # Also included is a sentence which should be skipped (note that it does not show up in the expected result) 1 Es ser AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 2 cop _ _ 2 imposible imposible ADJ _ Number=Sing 0 root _ SpaceAfter=No 3 . . PUNCT _ PunctType=Peri 2 punct _ _ # sent_id = 3LB-CAST-a1-2-s6 # text = ¿Para qué seguir? # orig_file_sentence 006#22 # The treebank now includes basic dependencies in the additional dependencies column 1 ¿ ¿ PUNCT fia PunctSide=Ini|PunctType=Qest 4 punct 4:punct SpaceAfter=No 2 Para para ADP sps00 _ 3 case 3:case _ 3 qué qué PRON pt0cs000 Number=Sing|PronType=Int,Rel 4 obl 4:obl _ 4 seguir seguir VERB vmn0000 VerbForm=Inf 0 root 0:root SpaceAfter=No 5 ? ? PUNCT fit PunctSide=Fin|PunctType=Qest 4 punct 4:punct _ # sent_id = CESS-CAST-P-19990901-16-s19 # text = ¿Estará fingiendo?. 
# orig_file_sentence 097#24 # also it includes some copy nodes 1 ¿ ¿ PUNCT fia PunctSide=Ini|PunctType=Qest 3 punct 3:punct SpaceAfter=No 2 Estará estar AUX vmif3s0 Mood=Ind|Number=Sing|Person=3|Tense=Fut|VerbForm=Fin 3 aux 3:aux _ 3 fingiendo fingir VERB vmg0000 VerbForm=Ger 0 root 0:root SpaceAfter=No 3.1 _ _ PRON p _ _ _ 3:nsubj Entity=(CESSCASTP1999090116c2-person-1-CorefType:ident,gstype:spec) 4 ? ? PUNCT fit PunctSide=Fin|PunctType=Qest 3 punct 3:punct SpaceAfter=No 5 . . PUNCT fp PunctType=Peri 3 punct 3:punct _ # sent_id = CESS-CAST-P-20000401-126-s31 # text = ¿Qué pensó cuando se quedó # orig_file_sentence 087#37 # this one has a colon in the dependency name 1 ¿ ¿ PUNCT fia PunctSide=Ini|PunctType=Qest 3 punct 3:punct SpaceAfter=No|Entity=(CESSCASTP20000401126c27--3 2 Qué qué PRON pt0cs000 Number=Sing|PronType=Int,Rel 3 obj 3:obj _ 3 pensó pensar VERB vmis3s0 Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 3.1 _ _ PRON p _ _ _ 3:nsubj Entity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec) 4 cuando cuando SCONJ cs _ 6 mark 6:mark _ 4.1 _ _ PRON p _ _ _ 6:nsubj Entity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec) 5 se él PRON p0300000 Case=Acc|Person=3|PrepCase=Npr|PronType=Prs|Reflex=Yes 6 expl:pv 6:expl:pv _ 6 quedó quedar VERB vmis3s0 Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 advcl 3:advcl _ """ SPANISH_QM_RESULT = """ # sent_id = train-s7914 # text = Cómo explicarles entonces que el mar tiene varios dueños y que a partir de la frontera de aquella ola el pescado ya no es tuyo?. # orig_file_sentence 080#14 # this sentence will have the intiial ¿ removed. 
an MWT should be preserved 1 Cómo cómo PRON _ PronType=Ind 2 obl _ _ 2-3 explicarles _ _ _ _ _ _ _ _ 2 explicar explicar VERB _ VerbForm=Inf 0 root _ _ 3 les él PRON _ Case=Dat|Number=Plur|Person=3|PronType=Prs 2 obj _ _ 4 entonces entonces ADV _ _ 2 advmod _ _ 5 que que SCONJ _ _ 8 mark _ _ 6 el el DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 7 det _ _ 7 mar mar NOUN _ Number=Sing 8 nsubj _ _ 8 tiene tener VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 2 ccomp _ _ 9 varios varios DET _ Gender=Masc|Number=Plur|PronType=Ind 10 det _ _ 10 dueños dueño NOUN _ Gender=Masc|Number=Plur 8 obj _ _ 11 y y CCONJ _ _ 26 cc _ _ 12 que que SCONJ _ _ 26 mark _ _ 13 a a ADP _ _ 17 case _ MWE=a_partir_de|MWEPOS=ADP 14 partir partir NOUN _ _ 13 fixed _ _ 15 de de ADP _ _ 13 fixed _ _ 16 la el DET _ Definite=Def|Gender=Fem|Number=Sing|PronType=Art 17 det _ _ 17 frontera frontera NOUN _ Gender=Fem|Number=Sing 26 obl _ _ 18 de de ADP _ _ 20 case _ _ 19 aquella aquel DET _ Gender=Fem|Number=Sing|PronType=Dem 20 det _ _ 20 ola ola NOUN _ Gender=Fem|Number=Sing 17 nmod _ _ 21 el el DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 22 det _ _ 22 pescado pescado NOUN _ Gender=Masc|Number=Sing 26 nsubj _ _ 23 ya ya ADV _ _ 26 advmod _ _ 24 no no ADV _ Polarity=Neg 26 advmod _ _ 25 es ser AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 26 cop _ _ 26 tuyo tuyo PRON _ Gender=Masc|Number=Sing|Number[psor]=Sing|Person=2|Poss=Yes|PronType=Ind 8 conj _ SpaceAfter=No 27 ? ? PUNCT _ PunctSide=Fin|PunctType=Qest 2 punct _ SpaceAfter=No 28 . . PUNCT _ PunctType=Peri 2 punct _ _ # sent_id = train-s8516 # text = Pero es divertido en la vida real? - -. 
# orig_file_sentence 086#16 # this sentence will have the ¿ removed even with no SpaceAfter=No 1 Pero pero CCONJ _ _ 3 advmod _ _ 2 es ser AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _ 3 divertido divertido ADJ _ Gender=Masc|Number=Sing|VerbForm=Part 0 root _ _ 4 en en ADP _ _ 6 case _ _ 5 la el DET _ Definite=Def|Gender=Fem|Number=Sing|PronType=Art 6 det _ _ 6 vida vida NOUN _ Gender=Fem|Number=Sing 3 obl _ _ 7 real real ADJ _ Number=Sing 6 amod _ SpaceAfter=No 8 ? ? PUNCT _ PunctSide=Fin|PunctType=Qest 3 punct _ _ 9 - - PUNCT _ PunctType=Dash 3 punct _ _ 10 - - PUNCT _ PunctType=Dash 3 punct _ SpaceAfter=No 11 . . PUNCT _ PunctType=Peri 3 punct _ _ # sent_id = 3LB-CAST-a1-2-s6 # text = Para qué seguir? # orig_file_sentence 006#22 # The treebank now includes basic dependencies in the additional dependencies column 1 Para para ADP sps00 _ 2 case 2:case _ 2 qué qué PRON pt0cs000 Number=Sing|PronType=Int,Rel 3 obl 3:obl _ 3 seguir seguir VERB vmn0000 VerbForm=Inf 0 root 0:root SpaceAfter=No 4 ? ? PUNCT fit PunctSide=Fin|PunctType=Qest 3 punct 3:punct _ # sent_id = CESS-CAST-P-19990901-16-s19 # text = Estará fingiendo?. # orig_file_sentence 097#24 # also it includes some copy nodes 1 Estará estar AUX vmif3s0 Mood=Ind|Number=Sing|Person=3|Tense=Fut|VerbForm=Fin 2 aux 2:aux _ 2 fingiendo fingir VERB vmg0000 VerbForm=Ger 0 root 0:root SpaceAfter=No 2.1 _ _ PRON p _ _ _ 2:nsubj Entity=(CESSCASTP1999090116c2-person-1-CorefType:ident,gstype:spec) 3 ? ? PUNCT fit PunctSide=Fin|PunctType=Qest 2 punct 2:punct SpaceAfter=No 4 . . 
PUNCT fp PunctType=Peri 2 punct 2:punct _ # sent_id = CESS-CAST-P-20000401-126-s31 # text = Qué pensó cuando se quedó # orig_file_sentence 087#37 # this one has a colon in the dependency name 1 Qué qué PRON pt0cs000 Number=Sing|PronType=Int,Rel 2 obj 2:obj _ 2 pensó pensar VERB vmis3s0 Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ 2.1 _ _ PRON p _ _ _ 2:nsubj Entity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec) 3 cuando cuando SCONJ cs _ 5 mark 5:mark _ 3.1 _ _ PRON p _ _ _ 5:nsubj Entity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec) 4 se él PRON p0300000 Case=Acc|Person=3|PrepCase=Npr|PronType=Prs|Reflex=Yes 5 expl:pv 5:expl:pv _ 5 quedó quedar VERB vmis3s0 Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 2 advcl 2:advcl _ """ def test_augment_initial_punct(): doc = read_test_doc(SPANISH_QM_TEST_CASE) doc2 = prepare_tokenizer_treebank.augment_initial_punct(doc, ratio=1.0) expected = doc + read_test_doc(SPANISH_QM_RESULT) assert doc2 == expected SPANISH_SHOULD_THROW = """ # sent_id = 3LB-CAST-a1-2-s6 # text = ¿Para qué seguir? # orig_file_sentence 006#22 # multiple heads are not handled yet in the augmented dependencies column 1 ¿ ¿ PUNCT fia PunctSide=Ini|PunctType=Qest 4 punct 4:punct SpaceAfter=No 2 Para para ADP sps00 _ 3 case 3:case _ 3 qué qué PRON pt0cs000 Number=Sing|PronType=Int,Rel 4 obl 4:obl,3:foo _ 4 seguir seguir VERB vmn0000 VerbForm=Inf 0 root 0:root SpaceAfter=No 5 ? ? 
PUNCT fit PunctSide=Fin|PunctType=Qest 4 punct 4:punct _ """ def test_augment_initial_punct_error(): """ The augment script should protect against the single dependency assumption changing in the future """ doc = read_test_doc(SPANISH_SHOULD_THROW) with pytest.raises(NotImplementedError): doc2 = prepare_tokenizer_treebank.augment_initial_punct(doc, ratio=1.0) # first sentence should have the space added # second sentence should be unchanged ARABIC_SPACE_AFTER_TEST_CASE = """ # newpar id = afp.20000815.0079:p6 # sent_id = afp.20000815.0079:p6u1 # text = وتتميز امسية الاربعاء الدولية باقامة 16 مباراة ودية. # orig_file_sentence AFP_ARB_20000815.0079#6 1-2 وتتميز _ _ _ _ _ _ _ _ 1 و وَ CCONJ C--------- _ 0 root 0:root Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 2 تتميز تَمَيَّز VERB VIIA-3FS-- Aspect=Imp|Gender=Fem|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act 1 parataxis 1:parataxis Vform=تَتَمَيَّزُ|Gloss=be_distinguished,stand_out,discern,distinguish|Root=m_y_z|Translit=tatamayyazu|LTranslit=tamayyaz 3 امسية أُمسِيَّة NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 2 nsubj 2:nsubj Vform=أُمسِيَّةُ|Gloss=evening,soiree|Root=m_s_w|Translit=ʾumsīyatu|LTranslit=ʾumsīyat 4 الاربعاء أَربِعَاء NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 3 nmod 3:nmod:gen Vform=اَلأَربِعَاءِ|Gloss=Wednesday|Root=r_b_`|Translit=al-ʾarbiʿāʾi|LTranslit=ʾarbiʿāʾ 5 الدولية دُوَلِيّ ADJ A-----FS1D Case=Nom|Definite=Def|Gender=Fem|Number=Sing 3 amod 3:amod Vform=اَلدُّوَلِيَّةُ|Gloss=international,world|Root=d_w_l|Translit=ad-duwalīyatu|LTranslit=duwalīy 6-7 باقامة _ _ _ _ _ _ _ _ 6 ب بِ ADP P--------- AdpType=Prep 7 case 7:case Vform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi 7 إقامة إِقَامَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 2 obl 2:obl:بِ:gen Vform=إِقَامَةِ|Gloss=residency,setting_up|Root=q_w_m|Translit=ʾiqāmati|LTranslit=ʾiqāmat 8 16 16 NUM Q--------- NumForm=Digit 7 nummod 7:nummod Vform=١٦|Translit=16 9 مباراة مُبَارَاة NOUN N------S4I 
Case=Acc|Definite=Ind|Number=Sing 8 nmod 8:nmod:acc Vform=مُبَارَاةً|Gloss=match,game,competition|Root=b_r_y|Translit=mubārātan|LTranslit=mubārāt 10 ودية وُدِّيّ ADJ A-----FS4I Case=Acc|Definite=Ind|Gender=Fem|Number=Sing 9 amod 9:amod SpaceAfter=No|Vform=وُدِّيَّةً|Gloss=friendly,amicable|Root=w_d_d|Translit=wuddīyatan|LTranslit=wuddīy 11 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=. # newdoc id = afp.20000715.0075 # newpar id = afp.20000715.0075:p1 # sent_id = afp.20000715.0075:p1u1 # text = برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة "ليوبارد" الالمانية # orig_file_sentence AFP_ARB_20000715.0075#1 1 برلين بَرلِين X X--------- Foreign=Yes 2 nsubj 2:nsubj Vform=بَرلِين|Gloss=Berlin|Root=barlIn|Translit=barlīn|LTranslit=barlīn 2 ترفض رَفَض VERB VIIA-3FS-- Aspect=Imp|Gender=Fem|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act 0 root 0:root Vform=تَرفُضُ|Gloss=reject,refuse|Root=r_f_.d|Translit=tarfuḍu|LTranslit=rafaḍ 3 حصول حُصُول NOUN N------S4R Case=Acc|Definite=Cons|Number=Sing 2 obj 2:obj Vform=حُصُولَ|Gloss=acquisition,obtaining,occurrence,happening|Root=.h_.s_l|Translit=ḥuṣūla|LTranslit=ḥuṣūl 4 شركة شَرِكَة NOUN N------S2I Case=Gen|Definite=Ind|Number=Sing 3 nmod 3:nmod:gen Vform=شَرِكَةٍ|Gloss=company,corporation|Root=^s_r_k|Translit=šarikatin|LTranslit=šarikat 5 اميركية أَمِيرِكِيّ ADJ A-----FS2I Case=Gen|Definite=Ind|Gender=Fem|Number=Sing 4 amod 4:amod Vform=أَمِيرِكِيَّةٍ|Gloss=American|Root='amIrik|Translit=ʾamīrikīyatin|LTranslit=ʾamīrikīy 6 على عَلَى ADP P--------- AdpType=Prep 7 case 7:case Vform=عَلَى|Gloss=on,above|Root=`_l_w|Translit=ʿalā|LTranslit=ʿalā 7 رخصة رُخصَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 3 obl:arg 3:obl:arg:عَلَى:gen Vform=رُخصَةِ|Gloss=license,permit|Root=r__h_.s|Translit=ruḫṣati|LTranslit=ruḫṣat 8 تصنيع تَصنِيع NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 7 nmod 7:nmod:gen Vform=تَصنِيعِ|Gloss=fabrication,industrialization,processing|Root=.s_n_`|Translit=taṣnīʿi|LTranslit=taṣnīʿ 9 دبابة 
دَبَّابَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 8 nmod 8:nmod:gen Vform=دَبَّابَةِ|Gloss=tank|Root=d_b_b|Translit=dabbābati|LTranslit=dabbābat 10 " " PUNCT G--------- _ 11 punct 11:punct SpaceAfter=No|Vform="|Translit=" 11 ليوبارد لِيُوبَارد X X--------- Foreign=Yes 9 nmod 9:nmod SpaceAfter=No|Vform=لِيُوبَارد|Gloss=Leopard|Root=liyUbArd|Translit=liyūbārd|LTranslit=liyūbārd 12 " " PUNCT G--------- _ 11 punct 11:punct Vform="|Translit=" 13 الالمانية أَلمَانِيّ ADJ A-----FS2D Case=Gen|Definite=Def|Gender=Fem|Number=Sing 9 amod 9:amod Vform=اَلأَلمَانِيَّةِ|Gloss=German|Root='almAn|Translit=al-ʾalmānīyati|LTranslit=ʾalmānīy """ ARABIC_SPACE_AFTER_RESULT = """ # newpar id = afp.20000815.0079:p6 # sent_id = afp.20000815.0079:p6u1 # text = وتتميز امسية الاربعاء الدولية باقامة 16 مباراة ودية . # orig_file_sentence AFP_ARB_20000815.0079#6 1-2 وتتميز _ _ _ _ _ _ _ _ 1 و وَ CCONJ C--------- _ 0 root 0:root Vform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa 2 تتميز تَمَيَّز VERB VIIA-3FS-- Aspect=Imp|Gender=Fem|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act 1 parataxis 1:parataxis Vform=تَتَمَيَّزُ|Gloss=be_distinguished,stand_out,discern,distinguish|Root=m_y_z|Translit=tatamayyazu|LTranslit=tamayyaz 3 امسية أُمسِيَّة NOUN N------S1R Case=Nom|Definite=Cons|Number=Sing 2 nsubj 2:nsubj Vform=أُمسِيَّةُ|Gloss=evening,soiree|Root=m_s_w|Translit=ʾumsīyatu|LTranslit=ʾumsīyat 4 الاربعاء أَربِعَاء NOUN N------S2D Case=Gen|Definite=Def|Number=Sing 3 nmod 3:nmod:gen Vform=اَلأَربِعَاءِ|Gloss=Wednesday|Root=r_b_`|Translit=al-ʾarbiʿāʾi|LTranslit=ʾarbiʿāʾ 5 الدولية دُوَلِيّ ADJ A-----FS1D Case=Nom|Definite=Def|Gender=Fem|Number=Sing 3 amod 3:amod Vform=اَلدُّوَلِيَّةُ|Gloss=international,world|Root=d_w_l|Translit=ad-duwalīyatu|LTranslit=duwalīy 6-7 باقامة _ _ _ _ _ _ _ _ 6 ب بِ ADP P--------- AdpType=Prep 7 case 7:case Vform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi 7 إقامة إِقَامَة NOUN N------S2R Case=Gen|Definite=Cons|Number=Sing 2 obl 2:obl:بِ:gen 
Vform=إِقَامَةِ|Gloss=residency,setting_up|Root=q_w_m|Translit=ʾiqāmati|LTranslit=ʾiqāmat 8 16 16 NUM Q--------- NumForm=Digit 7 nummod 7:nummod Vform=١٦|Translit=16 9 مباراة مُبَارَاة NOUN N------S4I Case=Acc|Definite=Ind|Number=Sing 8 nmod 8:nmod:acc Vform=مُبَارَاةً|Gloss=match,game,competition|Root=b_r_y|Translit=mubārātan|LTranslit=mubārāt 10 ودية وُدِّيّ ADJ A-----FS4I Case=Acc|Definite=Ind|Gender=Fem|Number=Sing 9 amod 9:amod Vform=وُدِّيَّةً|Gloss=friendly,amicable|Root=w_d_d|Translit=wuddīyatan|LTranslit=wuddīy 11 . . PUNCT G--------- _ 1 punct 1:punct Vform=.|Translit=. """ def test_augment_space_final_punct(): doc = read_test_doc(ARABIC_SPACE_AFTER_TEST_CASE) doc2 = prepare_tokenizer_treebank.augment_arabic_padt(doc, ratio=1.0) expected = doc + read_test_doc(ARABIC_SPACE_AFTER_RESULT) assert doc2 == expected ENGLISH_COMMA_SWAP_TEST_CASE=""" # sent_id = reviews-086839-0004 # text = Approx 4 months later, the compressor went out. 1 Approx approx ADV RB _ 3 advmod 3:advmod _ 2 4 4 NUM CD NumType=Card 3 nummod 3:nummod _ 3 months month NOUN NNS Number=Plur 4 obl:npmod 4:obl:npmod _ 4 later late ADV RBR Degree=Cmp 8 advmod 8:advmod SpaceAfter=No 5 , , PUNCT , _ 8 punct 8:punct _ 6 the the DET DT Definite=Def|PronType=Art 7 det 7:det _ 7 compressor compressor NOUN NN Number=Sing 8 nsubj 8:nsubj _ 8 went go VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 0 root 0:root _ 9 out out ADP RP _ 8 compound:prt 8:compound:prt SpaceAfter=No 10 . . PUNCT . _ 8 punct 8:punct _ # sent_id = reviews-086839-0004b # text = Approx 4 months later , the compressor went out. 
1 Approx approx ADV RB _ 3 advmod 3:advmod _ 2 4 4 NUM CD NumType=Card 3 nummod 3:nummod _ 3 months month NOUN NNS Number=Plur 4 obl:npmod 4:obl:npmod _ 4 later late ADV RBR Degree=Cmp 8 advmod 8:advmod _ 5 , , PUNCT , _ 8 punct 8:punct _ 6 the the DET DT Definite=Def|PronType=Art 7 det 7:det _ 7 compressor compressor NOUN NN Number=Sing 8 nsubj 8:nsubj _ 8 went go VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 0 root 0:root _ 9 out out ADP RP _ 8 compound:prt 8:compound:prt SpaceAfter=No 10 . . PUNCT . _ 8 punct 8:punct _ """ ENGLISH_COMMA_SWAP_RESULT=""" # sent_id = reviews-086839-0004 # text = Approx 4 months later ,the compressor went out. 1 Approx approx ADV RB _ 3 advmod 3:advmod _ 2 4 4 NUM CD NumType=Card 3 nummod 3:nummod _ 3 months month NOUN NNS Number=Plur 4 obl:npmod 4:obl:npmod _ 4 later late ADV RBR Degree=Cmp 8 advmod 8:advmod _ 5 , , PUNCT , _ 8 punct 8:punct SpaceAfter=No 6 the the DET DT Definite=Def|PronType=Art 7 det 7:det _ 7 compressor compressor NOUN NN Number=Sing 8 nsubj 8:nsubj _ 8 went go VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 0 root 0:root _ 9 out out ADP RP _ 8 compound:prt 8:compound:prt SpaceAfter=No 10 . . PUNCT . _ 8 punct 8:punct _ # sent_id = reviews-086839-0004b # text = Approx 4 months later , the compressor went out. 1 Approx approx ADV RB _ 3 advmod 3:advmod _ 2 4 4 NUM CD NumType=Card 3 nummod 3:nummod _ 3 months month NOUN NNS Number=Plur 4 obl:npmod 4:obl:npmod _ 4 later late ADV RBR Degree=Cmp 8 advmod 8:advmod _ 5 , , PUNCT , _ 8 punct 8:punct _ 6 the the DET DT Definite=Def|PronType=Art 7 det 7:det _ 7 compressor compressor NOUN NN Number=Sing 8 nsubj 8:nsubj _ 8 went go VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 0 root 0:root _ 9 out out ADP RP _ 8 compound:prt 8:compound:prt SpaceAfter=No 10 . . PUNCT . 
_ 8 punct 8:punct _ """ def test_augment_space_final_punct(): doc = read_test_doc(ENGLISH_COMMA_SWAP_TEST_CASE) doc2 = prepare_tokenizer_treebank.augment_move_comma(doc, ratio=1.0) expected = read_test_doc(ENGLISH_COMMA_SWAP_RESULT) assert doc2 == expected COMMA_SEP_TEST_CASE = """ # text = Fuzzy people, floating people 1 Fuzzy fuzzy ADJ JJ Degree=Pos 2 amod 2:amod _ 2 people people NOUN NNS Number=Plur 0 root 0:root SpaceAfter=No 3 , , PUNCT , _ 2 punct 2:punct _ 4 floating float VERB VBG VerbForm=Ger 5 amod 5:amod _ 5 people people NOUN NNS Number=Plur 2 appos 2:appos _ """ COMMA_SEP_TEST_EXPECTED = """ # text = Fuzzy people,floating people 1 Fuzzy fuzzy ADJ JJ Degree=Pos 2 amod 2:amod _ 2 people people NOUN NNS Number=Plur 0 root 0:root SpaceAfter=No 3 , , PUNCT , _ 2 punct 2:punct SpaceAfter=No 4 floating float VERB VBG VerbForm=Ger 5 amod 5:amod _ 5 people people NOUN NNS Number=Plur 2 appos 2:appos _ """ def test_augment_comma_separations(): doc = read_test_doc(COMMA_SEP_TEST_CASE) doc2 = prepare_tokenizer_treebank.augment_comma_separations(doc, ratio=1.0) assert len(doc2) == 2 expected = read_test_doc(COMMA_SEP_TEST_EXPECTED) assert doc2[1] == expected[0] ================================================ FILE: stanza/tests/tokenization/test_replace_long_tokens.py ================================================ """ Check to make sure long tokens are replaced with "UNK" by the tokenization processor """ import pytest import stanza from stanza.pipeline import tokenize_processor from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] def test_replace_long_tokens(): nlp = stanza.Pipeline(lang="en", download_method=None, model_dir=TEST_MODELS_DIR, processors="tokenize") test_str = "foo " + "x" * 10000 + " bar" res = nlp(test_str) assert res.sentences[0].words[1].text == tokenize_processor.TOKEN_TOO_LONG_REPLACEMENT def test_set_max_len(): nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 
'en', 'download_method': None, 'tokenize_max_seqlen': 20}) doc = nlp("This is a doc withaverylongtokenthatshouldbereplaced") assert len(doc.sentences) == 1 assert len(doc.sentences[0].words) == 5 assert doc.sentences[0].words[-1].text == tokenize_processor.TOKEN_TOO_LONG_REPLACEMENT ================================================ FILE: stanza/tests/tokenization/test_spaces.py ================================================ """ Test that when tokenizing a document, the Space annotations get set the way we expect """ import stanza from stanza.tests import TEST_MODELS_DIR EXPECTED_NO_MWT = """ # text = Jennifer has nice antennae. # sent_id = 0 1 Jennifer _ _ _ _ 0 _ _ SpacesBefore=\\s\\s|start_char=2|end_char=10 2 has _ _ _ _ 1 _ _ start_char=11|end_char=14 3 nice _ _ _ _ 2 _ _ start_char=15|end_char=19 4 antennae _ _ _ _ 3 _ _ SpaceAfter=No|start_char=20|end_char=28 5 . _ _ _ _ 4 _ _ SpacesAfter=\\s\\s|start_char=28|end_char=29 # text = Not very nice person, though. # sent_id = 1 1 Not _ _ _ _ 0 _ _ start_char=31|end_char=34 2 very _ _ _ _ 1 _ _ start_char=35|end_char=39 3 nice _ _ _ _ 2 _ _ start_char=40|end_char=44 4 person _ _ _ _ 3 _ _ SpaceAfter=No|start_char=45|end_char=51 5 , _ _ _ _ 4 _ _ start_char=51|end_char=52 6 though _ _ _ _ 5 _ _ SpaceAfter=No|start_char=53|end_char=59 7 . _ _ _ _ 6 _ _ SpacesAfter=\\s\\s|start_char=59|end_char=60 """.strip() def test_spaces_no_mwt(): """ Test what happens if the words in a document have SpacesBefore and/or After """ nlp = stanza.Pipeline(**{'processors': 'tokenize', 'download_method': None, 'dir': TEST_MODELS_DIR, 'lang': 'en'}) doc = nlp(" Jennifer has nice antennae. Not very nice person, though. ") result = "{:C}".format(doc) result = result.strip() assert EXPECTED_NO_MWT == result EXPECTED_MWT = """ # text = She's not a nice person. 
# sent_id = 0 1-2 She's _ _ _ _ _ _ _ SpacesBefore=\\s\\s|start_char=2|end_char=7 1 She _ _ _ _ 0 _ _ start_char=2|end_char=5 2 's _ _ _ _ 1 _ _ start_char=5|end_char=7 3 not _ _ _ _ 2 _ _ start_char=8|end_char=11 4 a _ _ _ _ 3 _ _ start_char=12|end_char=13 5 nice _ _ _ _ 4 _ _ start_char=14|end_char=18 6 person _ _ _ _ 5 _ _ SpaceAfter=No|start_char=19|end_char=25 7 . _ _ _ _ 6 _ _ SpacesAfter=\\s\\s|start_char=25|end_char=26 # text = However, the best antennae on the Cerritos are Jennifer's. # sent_id = 1 1 However _ _ _ _ 0 _ _ SpaceAfter=No|start_char=28|end_char=35 2 , _ _ _ _ 1 _ _ start_char=35|end_char=36 3 the _ _ _ _ 2 _ _ start_char=37|end_char=40 4 best _ _ _ _ 3 _ _ start_char=41|end_char=45 5 antennae _ _ _ _ 4 _ _ start_char=46|end_char=54 6 on _ _ _ _ 5 _ _ start_char=55|end_char=57 7 the _ _ _ _ 6 _ _ start_char=58|end_char=61 8 Cerritos _ _ _ _ 7 _ _ start_char=62|end_char=70 9 are _ _ _ _ 8 _ _ start_char=71|end_char=74 10-11 Jennifer's _ _ _ _ _ _ _ SpaceAfter=No|start_char=75|end_char=85 10 Jennifer _ _ _ _ 9 _ _ start_char=75|end_char=83 11 's _ _ _ _ 10 _ _ start_char=83|end_char=85 12 . _ _ _ _ 11 _ _ SpacesAfter=\\s\\s|start_char=85|end_char=86 """.strip() def test_spaces_mwt(): """ Similar to the above test, but now we test it with MWT """ nlp = stanza.Pipeline(**{'processors': 'tokenize', 'download_method': None, 'dir': TEST_MODELS_DIR, 'lang': 'en'}) doc = nlp(" She's not a nice person. However, the best antennae on the Cerritos are Jennifer's. 
") result = "{:C}".format(doc) result = result.strip() assert EXPECTED_MWT == result ================================================ FILE: stanza/tests/tokenization/test_tokenization_lst20.py ================================================ import os import tempfile import pytest import stanza from stanza.tests import * from stanza.utils.datasets.common import convert_conllu_to_txt from stanza.utils.datasets.tokenization.convert_th_lst20 import read_document from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section pytestmark = [pytest.mark.travis, pytest.mark.pipeline] SMALL_LST_SAMPLE=""" สุรยุทธ์ NN B_PER B_CLS ยัน VV O I_CLS ปฏิเสธ VV O I_CLS ลงนาม VV O I_CLS _ PU O I_CLS MOU NN O I_CLS _ PU O I_CLS กับ PS O I_CLS อียู NN B_ORG I_CLS ไม่ NG O I_CLS กระทบ VV O I_CLS สัมพันธ์ NN O E_CLS 1 NU B_DTM B_CLS _ PU I_DTM I_CLS กันยายน NN I_DTM I_CLS _ PU I_DTM I_CLS 2550 NU E_DTM I_CLS _ PU O I_CLS 12:21 NU B_DTM I_CLS _ PU I_DTM I_CLS น. CL E_DTM E_CLS ผู้สื่อข่าว NN O B_CLS รายงาน VV O I_CLS เพิ่มเติม VV O I_CLS ว่า CC O E_CLS _ PU O O จาก PS O B_CLS การ FX O I_CLS ลง VV O I_CLS พื้นที่ NN O I_CLS พบ VV O I_CLS ว่า CC O E_CLS """.strip() EXPECTED_CONLLU=""" 1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes 2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 ลงนาม _ _ _ _ 3 dep 3:dep _ 5 MOU _ _ _ _ 4 dep 4:dep _ 6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No 7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No 8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No 9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No 10 สัมพันธ์ _ _ _ _ 9 dep 9:dep SpaceAfter=No 1 1 _ _ _ _ 0 root 0:root _ 2 กันยายน _ _ _ _ 1 dep 1:dep _ 3 2550 _ _ _ _ 2 dep 2:dep _ 4 12:21 _ _ _ _ 3 dep 3:dep _ 5 น. 
_ _ _ _ 4 dep 4:dep SpaceAfter=No 1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No 2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 ว่า _ _ _ _ 3 dep 3:dep _ 5 จาก _ _ _ _ 4 dep 4:dep SpaceAfter=No 6 การ _ _ _ _ 5 dep 5:dep SpaceAfter=No 7 ลง _ _ _ _ 6 dep 6:dep SpaceAfter=No 8 พื้นที่ _ _ _ _ 7 dep 7:dep SpaceAfter=No 9 พบ _ _ _ _ 8 dep 8:dep SpaceAfter=No 10 ว่า _ _ _ _ 9 dep 9:dep SpaceAfter=No """.strip() # Note: these DO NOT line up perfectly (in an emacs window, at least) # because Thai characters have a length greater than 1. # The lengths of the words are: # สุรยุทธ์ 8 # ยัน 3 # ปฏิเสธ 6 # ลงนาม 5 # MOU 3 # กับ 3 # อียู 4 # ไม่ 3 # กระทบ 5 # สัมพันธ์ 8 # 1 1 # กันยายน 7 # 2550 4 # 12:21 5 # น. 2 # ผู้สื่อข่าว 11 # รายงาน 6 # เพิ่มเติม 9 # ว่า 3 # จาก 3 # การ 3 # ลง 2 # พื้นที่ 7 # พบ 2 # ว่า 3 EXPECTED_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n" EXPECTED_LABELS = "000000010010000010000100010001000100100001000000021000000010000100000100200000000001000001000000001001000100101000000101002\n\n" # counting spaces 1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12123456789AB123456123456789123_12312312123456712123 # note that the word splits go on the final letter of the word in the # UD conllu datasets, so that is what we mimic here # for example, from EWT: # Al-Zaman : American forces killed Shaikh Abdullah # 0110000101000000001000000100000010000001000000001 def check_results(documents, expected_conllu, expected_txt, expected_labels): with tempfile.TemporaryDirectory() as output_dir: write_section(output_dir, "lst20", "train", documents) with open(os.path.join(output_dir, "th_lst20.train.gold.conllu")) as fin: conllu = fin.read().strip() with open(os.path.join(output_dir, "th_lst20.train.txt")) as fin: txt = fin.read() with open(os.path.join(output_dir, "th_lst20-ud-train.toklabels")) as fin: labels = 
fin.read() assert conllu == expected_conllu assert txt == expected_txt assert labels == expected_labels assert len(txt) == len(labels) def test_small(): """ A small test just to verify that the output is being produced as we want Note that there currently are no spaces after the first sentence. Apparently this is wrong, but weirdly, doing that makes the model even worse. """ lines = SMALL_LST_SAMPLE.strip().split("\n") documents = read_document(lines, spaces_after=False, split_clauses=False) check_results(documents, EXPECTED_CONLLU, EXPECTED_TXT, EXPECTED_LABELS) EXPECTED_SPACE_CONLLU=""" 1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes 2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 ลงนาม _ _ _ _ 3 dep 3:dep _ 5 MOU _ _ _ _ 4 dep 4:dep _ 6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No 7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No 8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No 9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No 10 สัมพันธ์ _ _ _ _ 9 dep 9:dep _ 1 1 _ _ _ _ 0 root 0:root _ 2 กันยายน _ _ _ _ 1 dep 1:dep _ 3 2550 _ _ _ _ 2 dep 2:dep _ 4 12:21 _ _ _ _ 3 dep 3:dep _ 5 น. _ _ _ _ 4 dep 4:dep _ 1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No 2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 ว่า _ _ _ _ 3 dep 3:dep _ 5 จาก _ _ _ _ 4 dep 4:dep SpaceAfter=No 6 การ _ _ _ _ 5 dep 5:dep SpaceAfter=No 7 ลง _ _ _ _ 6 dep 6:dep SpaceAfter=No 8 พื้นที่ _ _ _ _ 7 dep 7:dep SpaceAfter=No 9 พบ _ _ _ _ 8 dep 8:dep SpaceAfter=No 10 ว่า _ _ _ _ 9 dep 9:dep _ """.strip() EXPECTED_SPACE_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. 
ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n" EXPECTED_SPACE_LABELS = "00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001001000100101000000101002\n\n" def test_space_after(): """ This version of the test adds the space after attribute """ lines = SMALL_LST_SAMPLE.strip().split("\n") documents = read_document(lines, spaces_after=True, split_clauses=False) check_results(documents, EXPECTED_SPACE_CONLLU, EXPECTED_SPACE_TXT, EXPECTED_SPACE_LABELS) EXPECTED_CLAUSE_CONLLU=""" 1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes 2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 ลงนาม _ _ _ _ 3 dep 3:dep _ 5 MOU _ _ _ _ 4 dep 4:dep _ 6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No 7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No 8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No 9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No 10 สัมพันธ์ _ _ _ _ 9 dep 9:dep _ 1 1 _ _ _ _ 0 root 0:root _ 2 กันยายน _ _ _ _ 1 dep 1:dep _ 3 2550 _ _ _ _ 2 dep 2:dep _ 4 12:21 _ _ _ _ 3 dep 3:dep _ 5 น. _ _ _ _ 4 dep 4:dep _ 1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No 2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 ว่า _ _ _ _ 3 dep 3:dep _ 1 จาก _ _ _ _ 0 root 0:root SpaceAfter=No 2 การ _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 ลง _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 พื้นที่ _ _ _ _ 3 dep 3:dep SpaceAfter=No 5 พบ _ _ _ _ 4 dep 4:dep SpaceAfter=No 6 ว่า _ _ _ _ 5 dep 5:dep _ """.strip() EXPECTED_CLAUSE_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. 
ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n" EXPECTED_CLAUSE_LABELS = "00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001002000100101000000101002\n\n" def test_split_clause(): """ This version of the test also resplits on spaces between clauses """ lines = SMALL_LST_SAMPLE.strip().split("\n") documents = read_document(lines, spaces_after=True, split_clauses=True) check_results(documents, EXPECTED_CLAUSE_CONLLU, EXPECTED_CLAUSE_TXT, EXPECTED_CLAUSE_LABELS) if __name__ == "__main__": lines = SMALL_LST_SAMPLE.strip().split("\n") documents = read_document(lines, spaces_after=False, split_clauses=False) write_section("foo", "lst20", "train", documents) ================================================ FILE: stanza/tests/tokenization/test_tokenization_orchid.py ================================================ import os import tempfile import pytest import xml.etree.ElementTree as ET import stanza from stanza.tests import * from stanza.utils.datasets.common import convert_conllu_to_txt from stanza.utils.datasets.tokenization.convert_th_orchid import parse_xml from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section pytestmark = [pytest.mark.travis, pytest.mark.pipeline] SMALL_DOC=""" """ EXPECTED_RESULTS=""" 1 การ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes 2 ประชุม _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 ทาง _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 วิชาการ _ _ _ _ 3 dep 3:dep _ 5 ครั้ง _ _ _ _ 4 dep 4:dep SpaceAfter=No 6 ที่ 1 _ _ _ _ 5 dep 5:dep _ 1 โครงการวิจัยและพัฒนา _ _ _ _ 0 root 0:root SpaceAfter=No 2 อิเล็กทรอนิกส์ _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 และ _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 คอมพิวเตอร์ _ _ _ _ 3 dep 3:dep _ 1 วัน _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes 2 ที่ 15 _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 - _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 16 _ _ _ _ 3 dep 3:dep _ 5 สิงหาคม _ _ _ _ 4 dep 4:dep _ 6 2532 _ _ _ _ 5 dep 5:dep _ """.strip() 
EXPECTED_TEXT="""การประชุมทางวิชาการ ครั้งที่ 1 โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์ วันที่ 15-16 สิงหาคม 2532 """ EXPECTED_LABELS="""0010000010010000001000001000020000000000000000000010000000000000100100000000002 0010000011010000000100002 """ def check_results(documents, expected_conllu, expected_txt, expected_labels): with tempfile.TemporaryDirectory() as output_dir: write_section(output_dir, "orchid", "train", documents) with open(os.path.join(output_dir, "th_orchid.train.gold.conllu")) as fin: conllu = fin.read().strip() with open(os.path.join(output_dir, "th_orchid.train.txt")) as fin: txt = fin.read() with open(os.path.join(output_dir, "th_orchid-ud-train.toklabels")) as fin: labels = fin.read() assert conllu == expected_conllu assert txt == expected_txt assert labels == expected_labels assert len(txt) == len(labels) def test_orchid(): tree = ET.ElementTree(ET.fromstring(SMALL_DOC)) documents = parse_xml(tree) check_results(documents, EXPECTED_RESULTS, EXPECTED_TEXT, EXPECTED_LABELS) ================================================ FILE: stanza/tests/tokenization/test_tokenize_data.py ================================================ """ Very simple test of the mwt counting functionality in tokenization/data.py TODO: could add a bunch more simple tests, including tests of reading the data from a temp file, for example """ import pytest import tempfile import numpy as np import stanza from stanza import Pipeline from stanza.tests import * from stanza.models.tokenization.data import DataLoader, NUMERIC_RE pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def write_tokenizer_input(test_dir, raw_text, labels): """ Writes raw_text and labels to randomly named files in test_dir Note that the tempfiles are not set to automatically clean up. This will not be a problem if you put them in a tempdir. 
""" with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', dir=test_dir, delete=False) as fout: txt_file = fout.name fout.write(raw_text) with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', dir=test_dir, delete=False) as fout: label_file = fout.name fout.write(labels) return txt_file, label_file # A single slice of the German tokenization data with no MWT in it NO_MWT_TEXT = "Sehr gute Beratung, schnelle Behebung der Probleme" NO_MWT_LABELS = "00010000100000000110000000010000000010001000000002" # A single slice of the German tokenization data with an MWT in it MWT_TEXT = " Die Kosten sind definitiv auch im Rahmen." MWT_LABELS = "000100000010000100000000010000100300000012" FAKE_PROPERTIES = { "lang":"de", 'feat_funcs': ("space_before","capitalized"), 'max_seqlen': 300, 'use_dictionary': False, } def test_has_mwt(): """ One dataset has no mwt, the other does """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: txt_file, label_file = write_tokenizer_input(test_dir, NO_MWT_TEXT, NO_MWT_LABELS) data = DataLoader(args=FAKE_PROPERTIES, input_files={'txt': txt_file, 'label': label_file}) assert not data.has_mwt() txt_file, label_file = write_tokenizer_input(test_dir, MWT_TEXT, MWT_LABELS) data = DataLoader(args=FAKE_PROPERTIES, input_files={'txt': txt_file, 'label': label_file}) assert data.has_mwt() @pytest.fixture(scope="module") def tokenizer(): pipeline = Pipeline("en", dir=TEST_MODELS_DIR, download_method=None, processors="tokenize") tokenizer = pipeline.processors['tokenize'] return tokenizer @pytest.fixture(scope="module") def zhtok(): pipeline = Pipeline("zh-hans", dir=TEST_MODELS_DIR, download_method=None, processors="tokenize") tokenizer = pipeline.processors['tokenize'] return tokenizer EXPECTED_TWO_NL_RAW = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0)], [('f', 0), ('o', 0), ('o', 0)]] # in this test, the newline after test becomes a 
space labeled 0 EXPECTED_ONE_NL_RAW = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), (' ', 0), ('f', 0), ('o', 0), ('o', 0)]] EXPECTED_SKIP_NL_RAW = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('f', 0), ('o', 0), ('o', 0)]] def test_convert_units_raw_text(tokenizer): """ Tests converting a couple small segments to units """ raw_text = "This is a test\n\nfoo" batches = DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) assert batches.data == EXPECTED_TWO_NL_RAW raw_text = "This is a test\nfoo" batches = DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) assert batches.data == EXPECTED_ONE_NL_RAW skip_newline_config = dict(tokenizer.config) skip_newline_config['skip_newline'] = True batches = DataLoader(skip_newline_config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) assert batches.data == EXPECTED_SKIP_NL_RAW EXPECTED_TWO_NL_FILE = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('.', 1)], [('f', 0), ('o', 0), ('o', 0)]] EXPECTED_TWO_NL_FILE_LABELS = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=np.int32), np.array([0, 0, 0], dtype=np.int32)] # in this test, the newline after test becomes a space labeled 0 EXPECTED_ONE_NL_FILE = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('.', 1), (' ', 0), ('f', 0), ('o', 0), ('o', 0)]] EXPECTED_ONE_NL_FILE_LABELS = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int32)] EXPECTED_SKIP_NL_FILE 
= [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('.', 1), ('f', 0), ('o', 0), ('o', 0)]] EXPECTED_SKIP_NL_FILE_LABELS = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], dtype=np.int32)] def check_labels(labels, expected_labels): assert len(labels) == len(expected_labels) for label, expected in zip(labels, expected_labels): assert np.array_equiv(label, expected) def test_convert_units_file(tokenizer): """ Tests reading some text from a file and converting that to units """ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir: # two nl test case, read from file labels = "00000000000000000001\n\n000\n\n" raw_text = "This is a test.\n\nfoo\n\n" txt_file, label_file = write_tokenizer_input(test_dir, raw_text, labels) batches = DataLoader(tokenizer.config, input_files={'txt': txt_file, 'label': label_file}, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) assert batches.data == EXPECTED_TWO_NL_FILE check_labels(batches.labels(), EXPECTED_TWO_NL_FILE_LABELS) # one nl test case, read from file labels = "000000000000000000010000\n\n" raw_text = "This is a test.\nfoo\n\n" txt_file, label_file = write_tokenizer_input(test_dir, raw_text, labels) batches = DataLoader(tokenizer.config, input_files={'txt': txt_file, 'label': label_file}, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) assert batches.data == EXPECTED_ONE_NL_FILE check_labels(batches.labels(), EXPECTED_ONE_NL_FILE_LABELS) skip_newline_config = dict(tokenizer.config) skip_newline_config['skip_newline'] = True labels = "000000000000000000010000\n\n" raw_text = "This is a test.\nfoo\n\n" txt_file, label_file = write_tokenizer_input(test_dir, raw_text, labels) batches = DataLoader(skip_newline_config, input_files={'txt': txt_file, 'label': label_file}, vocab=tokenizer.vocab, evaluation=True, 
dictionary=tokenizer.trainer.dictionary) assert batches.data == EXPECTED_SKIP_NL_FILE check_labels(batches.labels(), EXPECTED_SKIP_NL_FILE_LABELS) def test_dictionary(zhtok): """ Tests some features of the zh tokenizer dictionary The expectation is that the Chinese tokenizer will be serialized with a dictionary (if it ever gets serialized without, this test will warn us!) """ assert zhtok.trainer.lexicon is not None assert zhtok.trainer.dictionary is not None assert "老师" in zhtok.trainer.lexicon # egg-white-stuff, eg protein assert "蛋白质" in zhtok.trainer.lexicon # egg-white assert "蛋白" in zhtok.trainer.dictionary['prefixes'] # egg assert "蛋" in zhtok.trainer.dictionary['prefixes'] # white-stuff assert "白质" in zhtok.trainer.dictionary['suffixes'] # stuff assert "质" in zhtok.trainer.dictionary['suffixes'] def test_dictionary_feats(zhtok): """ Test the results of running a sentence into the dictionary featurizer """ raw_text = "我想吃蛋白质" batches = DataLoader(zhtok.config, input_text=raw_text, vocab=zhtok.vocab, evaluation=True, dictionary=zhtok.trainer.dictionary) data = batches.data assert len(data) == 1 assert len(data[0]) == 6 expected_features = [ # in our example, the 2-grams made by the one character words at the start # don't form any prefixes or suffixes [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0], ] for i, expected in enumerate(expected_features): dict_features = batches.extract_dict_feat(data[0], i) assert dict_features == expected def test_numeric_re(): """ Test the "is numeric" function This function is entirely based on an RE in data.py """ # the last one is Thai matches = ["57", "135245345", "12535.", "852358.458345", "435345...345345", "111,,,111,,,111,,,111", "5318008", "5", "๕"] # note that we might want to consider .4 a numeric token after all # however, changing that means retraining all the models # the really long one only works if 
NUMERIC_RE avoids catastrophic backtracking not_matches = [".4", "54353a", "5453 35345", "aaa143234", "a,a,a,a", "sh'reyan", "asdaf786876asdfasdf", "", "11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111a"] for x in matches: assert NUMERIC_RE.match(x) is not None for x in not_matches: assert NUMERIC_RE.match(x) is None ================================================ FILE: stanza/tests/tokenization/test_tokenize_files.py ================================================ import pytest from stanza.models.tokenization import tokenize_files from stanza.tests import TEST_MODELS_DIR pytestmark = [pytest.mark.pipeline, pytest.mark.travis] EXPECTED = """ This is a test . This is a second sentence . I took my daughter ice skating """.lstrip() def test_tokenize_files(tmp_path): input_file = tmp_path / "input.txt" with open(input_file, "w") as fout: fout.write("This is a test. This is a second sentence.\n\nI took my daughter ice skating") output_file = tmp_path / "output.txt" tokenize_files.main([str(input_file), "--lang", "en", "--output_file", str(output_file), "--model_dir", TEST_MODELS_DIR]) with open(output_file) as fin: text = fin.read() assert EXPECTED == text ================================================ FILE: stanza/tests/tokenization/test_tokenize_utils.py ================================================ """ Very simple test of the sentence slicing by tags TODO: could add a bunch more simple tests for the tokenization utils """ import pytest import stanza from stanza import Pipeline from stanza.tests import * from stanza.models.common import doc from stanza.models.tokenization import data from stanza.models.tokenization import utils pytestmark = [pytest.mark.travis, pytest.mark.pipeline] def test_find_spans(): """ Test various raw -> span manipulations """ raw = ['u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l'] assert utils.find_spans(raw) == [(0, 14)] raw = ['u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 
'x', ' ', 'o', 'p', 'a', 'l', ''] assert utils.find_spans(raw) == [(0, 14)] raw = ['', 'u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l', ''] assert utils.find_spans(raw) == [(1, 15)] raw = ['', 'u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l'] assert utils.find_spans(raw) == [(1, 15)] raw = ['', 'u', 'n', 'b', 'a', 'n', '', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l'] assert utils.find_spans(raw) == [(1, 6), (7, 15)] def check_offsets(doc, expected_offsets): """ Compare the start_char and end_char of the tokens in the doc with the given list of list of offsets """ assert len(doc.sentences) == len(expected_offsets) for sentence, offsets in zip(doc.sentences, expected_offsets): assert len(sentence.tokens) == len(offsets) for token, offset in zip(sentence.tokens, offsets): assert token.start_char == offset[0] assert token.end_char == offset[1] def test_match_tokens_with_text(): """ Test the conversion of pretokenized text to Document """ doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatest") expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)]] check_offsets(doc, expected_offsets) doc = utils.match_tokens_with_text([["This", "is", "a", "test"], ["unban", "mox", "opal", "!"]], "Thisisatest unban mox opal!") expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)], [(13, 18), (19, 22), (24, 28), (28, 29)]] check_offsets(doc, expected_offsets) with pytest.raises(ValueError): doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatestttt") with pytest.raises(ValueError): doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisates") with pytest.raises(ValueError): doc = utils.match_tokens_with_text([["This", "iz", "a", "test"]], "Thisisatest") def test_long_paragraph(): """ Test the tokenizer's capacity to break text up into smaller chunks """ pipeline = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize") tokenizer = pipeline.processors['tokenize'] raw_text = "TIL not to 
ask a date to dress up as Smurfette on a first date. " * 100 # run a test to make sure the chunk operation is called # if not, the test isn't actually testing what we need to test batches = data.TokenizationDataset(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) batches.advance_old_batch = None with pytest.raises(TypeError): _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000, orig_text=raw_text, no_ssplit=tokenizer.config.get('no_ssplit', False)) # a new DataLoader should not be crippled as the above one was batches = data.TokenizationDataset(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000, orig_text=raw_text, no_ssplit=tokenizer.config.get('no_ssplit', False)) document = doc.Document(document, raw_text) assert len(document.sentences) == 100 def test_postprocessor_application(): """ Check that the postprocessor behaves correctly by applying the identity postprocessor and hoping that it does indeed return correctly. """ good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']] text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken." 
target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]] def postprocesor(_): return good_tokenization res = utils.postprocess_doc(target_doc, postprocesor, text) assert res == target_doc def test_reassembly_indexing(): """ Check that the reassembly code counts the indicies correctly, and including OOV chars. """ good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']] good_mwts = [[False for _ in range(len(i))] for i in good_tokenization] good_expansions = [[None for _ in range(len(i))] for i in good_tokenization] text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken." 
target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]] res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text) assert res == target_doc def test_reassembly_reference_failures(): """ Check that the reassembly code complains correctly when the user adds tokens that doesn't exist """ bad_addition_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Southern', 'California', '.']] bad_addition_mwts = [[False for _ in range(len(bad_addition_tokenization[0]))]] bad_addition_expansions = [[None for _ in range(len(bad_addition_tokenization[0]))]] bad_inline_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Californiaa', '.']] bad_inline_mwts = [[False for _ in range(len(bad_inline_tokenization[0]))]] bad_inline_expansions = [[None for _ in range(len(bad_inline_tokenization[0]))]] good_tokenization = [['Joe', 'Smith', 'lives', 'in', 'California', '.']] good_mwts = [[False for _ in range(len(good_tokenization[0]))]] good_expansions = [[None for _ in range(len(good_tokenization[0]))]] text = "Joe Smith lives in California." 
with pytest.raises(ValueError): utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, bad_addition_expansions, text) with pytest.raises(ValueError): utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, bad_inline_mwts, text) utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text) TRAIN_DATA = """ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003 # text = DPA: Iraqi authorities announced that they'd busted up three terrorist cells operating in Baghdad. 1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No 2 : : PUNCT : _ 1 punct 1:punct _ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _ 6 that that SCONJ IN _ 9 mark 9:mark _ 7-8 they'd _ _ _ _ _ _ _ _ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _ 8 'd have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _ 11 three three NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _ 15 in in ADP IN _ 16 case 16:case _ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No 17 . . PUNCT . _ 1 punct 1:punct _ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004 # text = Two of them were being run by 2 officials of the Ministry of the Interior! 
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _ 2 of of ADP IN _ 3 case 3:case _ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ 7 by by ADP IN _ 9 case 9:case _ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _ 10 of of ADP IN _ 12 case 12:case _ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _ 13 of of ADP IN _ 15 case 15:case _ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No 16 ! ! PUNCT . _ 6 punct 6:punct _ """.lstrip() def test_lexicon_from_training_data(tmp_path): """ Test a couple aspects of building a lexicon from training data expected number of words eliminated for being too long duplicate words counted once numbers eliminated """ conllu_file = str(tmp_path / "train.conllu") with open(conllu_file, "w", encoding="utf-8") as fout: fout.write(TRAIN_DATA) lexicon, num_dict_feat = utils.create_lexicon("en_test", conllu_file) lexicon = sorted(lexicon) expected_lexicon = ["'d", 'announced', 'baghdad', 'being', 'busted', 'by', 'cells', 'dpa', 'in', 'interior', 'iraqi', 'ministry', 'of', 'officials', 'operating', 'run', 'terrorist', 'that', 'the', 'them', 'they', "they'd", 'three', 'two', 'up', 'were'] assert lexicon == expected_lexicon assert num_dict_feat == max(len(x) for x in lexicon) ================================================ FILE: stanza/tests/tokenization/test_vocab.py ================================================ import pytest from stanza.models.common.vocab import UNK, PAD from stanza.models.tokenization.vocab import Vocab pytestmark = 
[pytest.mark.travis, pytest.mark.pipeline] def test_build(): """ Test that building a vocab out of a text produces the expected units and ids in the vocab """ text = ["this is a test"] vocab = Vocab(data=text, lang="en") expected = {'', '', 't', 's', ' ', 'i', 'h', 'a', 'e'} assert expected == set(vocab._id2unit) for unit in vocab._id2unit: assert vocab.id2unit(vocab.unit2id(unit)) == unit def test_append(): text = ["this is a test"] vocab = Vocab(data=text, lang="en") assert 'z' not in vocab vocab.append('z') expected = {'', '', 't', 's', ' ', 'i', 'h', 'a', 'e', 'z'} assert expected == set(vocab._id2unit) for unit in vocab._id2unit: assert vocab.id2unit(vocab.unit2id(unit)) == unit ================================================ FILE: stanza/utils/__init__.py ================================================ ================================================ FILE: stanza/utils/avg_sent_len.py ================================================ import sys import json def avg_sent_len(toklabels): if toklabels.endswith('.json'): with open(toklabels, 'r') as f: l = json.load(f) l = [''.join([str(x[1]) for x in para]) for para in l] else: with open(toklabels, 'r') as f: l = ''.join(f.readlines()) l = l.split('\n\n') sentlen = [len(x) + 1 for para in l for x in para.split('2')] return sum(sentlen) / len(sentlen) if __name__ == '__main__': print(avg_sent_len(sys.args[1])) ================================================ FILE: stanza/utils/charlm/__init__.py ================================================ ================================================ FILE: stanza/utils/charlm/conll17_to_text.py ================================================ """ Turns a directory of conllu files from the conll 2017 shared task to a text file Part of the process for building a charlm dataset python conll17_to_text.py This is an extension of the original script: https://github.com/stanfordnlp/stanza-scripts/blob/master/charlm/conll17/conll2txt.py To build a new charlm for a new language 
from a conll17 dataset: - look for conll17 shared task data, possibly here: https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-1989 - python3 stanza/utils/charlm/conll17_to_text.py ~/extern_data/conll17/Bulgarian --output_directory extern_data/charlm_raw/bg/conll17 - python3 stanza/utils/charlm/make_lm_data.py --langs bg extern_data/charlm_raw extern_data/charlm/ """ import argparse import lzma import sys import os def process_file(input_filename, output_directory, compress): if not input_filename.endswith('.conllu') and not input_filename.endswith(".conllu.xz"): print("Skipping {}".format(input_filename)) return if input_filename.endswith(".xz"): open_fn = lambda x: lzma.open(x, mode='rt') output_filename = input_filename[:-3].replace(".conllu", ".txt") else: open_fn = lambda x: open(x) output_filename = input_filename.replace('.conllu', '.txt') if output_directory: output_filename = os.path.join(output_directory, os.path.split(output_filename)[1]) if compress: output_filename = output_filename + ".xz" output_fn = lambda x: lzma.open(x, mode='wt') else: output_fn = lambda x: open(x, mode='w') if os.path.exists(output_filename): print("Cowardly refusing to overwrite %s" % output_filename) return print("Converting %s to %s" % (input_filename, output_filename)) with open_fn(input_filename) as fin: sentences = [] sentence = [] for line in fin: line = line.strip() if len(line) == 0: # new sentence sentences.append(sentence) sentence = [] continue if line[0] == '#': # comment continue splitline = line.split('\t') assert(len(splitline) == 10) # correct conllu id, word = splitline[0], splitline[1] if '-' not in id: # not mwt token sentence.append(word) if sentence: sentences.append(sentence) print(" Read in {} sentences".format(len(sentences))) with output_fn(output_filename) as fout: fout.write('\n'.join([' '.join(sentence) for sentence in sentences])) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("input_directory", help="Root directory 
with conllu or conllu.xz files.") parser.add_argument("--output_directory", default=None, help="Directory to output to. Will output to input_directory if None") parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files") args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() directory = args.input_directory filenames = sorted(os.listdir(directory)) print("Files to process in {}: {}".format(directory, filenames)) print("Processing to .xz files: {}".format(args.xz_output)) if args.output_directory: os.makedirs(args.output_directory, exist_ok=True) for filename in filenames: process_file(os.path.join(directory, filename), args.output_directory, args.xz_output) ================================================ FILE: stanza/utils/charlm/dump_oscar.py ================================================ """ This script downloads and extracts the text from an Oscar crawl on HuggingFace To use, just run dump_oscar.py It will download the dataset and output all of the text to the --output directory. Files will be broken into pieces to avoid having one giant file. 
By default, files will also be compressed with xz (although this can be turned off) """ import argparse import lzma import math import os from tqdm import tqdm from datasets import get_dataset_split_names from datasets import load_dataset from stanza.models.common.constant import lang_to_langcode def parse_args(): """ A few specific arguments for the dump program Uses lang_to_langcode to process args.language, hopefully converting a variety of possible formats to the short code used by HuggingFace """ parser = argparse.ArgumentParser() parser.add_argument("language", help="Language to download") parser.add_argument("--output", default="oscar_dump", help="Path for saving files") parser.add_argument("--no_xz", dest="xz", default=True, action='store_false', help="Don't xz the files - default is to compress while writing") parser.add_argument("--prefix", default="oscar_dump", help="Prefix to use for the pieces of the dataset") parser.add_argument("--version", choices=["2019", "2023"], default="2023", help="Which version of the Oscar dataset to download") args = parser.parse_args() args.language = lang_to_langcode(args.language) return args def download_2023(args): dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd') split_names = list(dataset.keys()) def main(): args = parse_args() # this is the 2019 version. 
for 2023, you can do # dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd') language = args.language if args.version == "2019": dataset_name = "unshuffled_deduplicated_%s" % language try: split_names = get_dataset_split_names("oscar", dataset_name) except ValueError as e: raise ValueError("Language %s not available in HuggingFace Oscar" % language) from e if len(split_names) > 1: raise ValueError("Unexpected split_names: {}".format(split_names)) dataset = load_dataset("oscar", dataset_name) dataset = dataset[split_names[0]] size_in_bytes = dataset.info.size_in_bytes process_item = lambda x: x['text'] elif args.version == "2023": dataset = load_dataset("oscar-corpus/OSCAR-2301", language) split_names = list(dataset.keys()) if len(split_names) > 1: raise ValueError("Unexpected split_names: {}".format(split_names)) # it's not clear if some languages don't support size_in_bytes, # or if there was an update to datasets which now allows that # # previously we did: # dataset = dataset[split_names[0]]['text'] # size_in_bytes = sum(len(x) for x in dataset) # process_item = lambda x: x dataset = dataset[split_names[0]] size_in_bytes = dataset.info.size_in_bytes process_item = lambda x: x['text'] else: raise AssertionError("Unknown version: %s" % args.version) chunks = max(1.0, size_in_bytes // 1e8) # an overestimate id_len = max(3, math.floor(math.log10(chunks)) + 1) if args.xz: format_str = "%s_%%0%dd.txt.xz" % (args.prefix, id_len) fopen = lambda file_idx: lzma.open(os.path.join(args.output, format_str % file_idx), "wt") else: format_str = "%s_%%0%dd.txt" % (args.prefix, id_len) fopen = lambda file_idx: open(os.path.join(args.output, format_str % file_idx), "w") print("Writing dataset to %s" % args.output) print("Dataset length: {}".format(size_in_bytes)) os.makedirs(args.output, exist_ok=True) file_idx = 0 file_len = 0 total_len = 0 fout = fopen(file_idx) for item in tqdm(dataset): text = process_item(item) fout.write(text) fout.write("\n") file_len += len(text) 
file_len += 1 if file_len > 1e8: file_len = 0 fout.close() file_idx = file_idx + 1 fout = fopen(file_idx) fout.close() if __name__ == '__main__': main() ================================================ FILE: stanza/utils/charlm/make_lm_data.py ================================================ """ Create Stanza character LM train/dev/test data, by reading from txt files in each source corpus directory, shuffling, splitting and saving into multiple smaller files (50MB by default) in a target directory. This script assumes the following source directory structures: - {src_dir}/{language}/{corpus}/*.txt It will read from all source .txt files and create the following target directory structures: - {tgt_dir}/{language}/{corpus} and within each target directory, it will create the following files: - train/*.txt - dev.txt - test.txt Args: - src_root: root directory of the source. - tgt_root: root directory of the target. - langs: a list of language codes to process; if specified, languages not in this list will be ignored. Note: edit the {EXCLUDED_FOLDERS} variable to exclude more folders in the source directory. """ import argparse import glob import os from pathlib import Path import shutil import subprocess import tempfile from tqdm import tqdm EXCLUDED_FOLDERS = ['raw_corpus'] def main(): parser = argparse.ArgumentParser() parser.add_argument("src_root", default="src", help="Root directory with all source files. Expected structure is root dir -> language dirs -> package dirs -> text files to process") parser.add_argument("tgt_root", default="tgt", help="Root directory with all target files.") parser.add_argument("--langs", default="", help="A list of language codes to process. If not set, all languages under src_root will be processed.") parser.add_argument("--packages", default="", help="A list of packages to process. 
If not set, all packages under the languages found will be processed.") parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files") parser.add_argument("--split_size", default=50, type=int, help="How large to make each split, in MB") parser.add_argument("--no_make_test_file", default=True, dest="make_test_file", action="store_false", help="Don't save a test file. Honestly, we never even use it. Best for low resource languages where every bit helps") args = parser.parse_args() print("Processing files:") print(f"source root: {args.src_root}") print(f"target root: {args.tgt_root}") print("") langs = [] if len(args.langs) > 0: langs = args.langs.split(',') print("Only processing the following languages: " + str(langs)) packages = [] if len(args.packages) > 0: packages = args.packages.split(',') print("Only processing the following packages: " + str(packages)) src_root = Path(args.src_root) tgt_root = Path(args.tgt_root) lang_dirs = os.listdir(src_root) lang_dirs = [l for l in lang_dirs if l not in EXCLUDED_FOLDERS] # skip excluded lang_dirs = [l for l in lang_dirs if os.path.isdir(src_root / l)] # skip non-directory if len(langs) > 0: # filter languages if specified lang_dirs = [l for l in lang_dirs if l in langs] print(f"{len(lang_dirs)} total languages found:") print(lang_dirs) print("") split_size = int(args.split_size * 1024 * 1024) for lang in lang_dirs: lang_root = src_root / lang data_dirs = os.listdir(lang_root) if len(packages) > 0: data_dirs = [d for d in data_dirs if d in packages] data_dirs = [d for d in data_dirs if os.path.isdir(lang_root / d)] print(f"{len(data_dirs)} total corpus found for language {lang}.") print(data_dirs) print("") for dataset_name in data_dirs: src_dir = lang_root / dataset_name tgt_dir = tgt_root / lang / dataset_name if not os.path.exists(tgt_dir): os.makedirs(tgt_dir) print(f"-> Processing {lang}-{dataset_name}") prepare_lm_data(src_dir, tgt_dir, lang, 
dataset_name, args.xz_output, split_size, args.make_test_file) print("") def prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, compress, split_size, make_test_file): """ Combine, shuffle and split data into smaller files, following a naming convention. """ assert isinstance(src_dir, Path) assert isinstance(tgt_dir, Path) with tempfile.TemporaryDirectory(dir=tgt_dir) as tempdir: tgt_tmp = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp") print(f"--> Copying files into {tgt_tmp}...") # TODO: we can do this without the shell commands input_files = glob.glob(str(src_dir) + '/*.txt') + glob.glob(str(src_dir) + '/*.txt.xz') + glob.glob(str(src_dir) + '/*.txt.gz') for src_fn in tqdm(input_files): if src_fn.endswith(".txt"): cmd = f"cat {src_fn} >> {tgt_tmp}" subprocess.run(cmd, shell=True) elif src_fn.endswith(".txt.xz"): cmd = f"xzcat {src_fn} >> {tgt_tmp}" subprocess.run(cmd, shell=True) elif src_fn.endswith(".txt.gz"): cmd = f"zcat {src_fn} >> {tgt_tmp}" subprocess.run(cmd, shell=True) else: raise AssertionError("should not have found %s" % src_fn) tgt_tmp_shuffled = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp.shuffled") print(f"--> Shuffling files into {tgt_tmp_shuffled}...") cmd = f"cat {tgt_tmp} | shuf > {tgt_tmp_shuffled}" result = subprocess.run(cmd, shell=True) if result.returncode != 0: raise RuntimeError("Failed to shuffle files!") size = os.path.getsize(tgt_tmp_shuffled) / 1024 / 1024 / 1024 print(f"--> Shuffled file size: {size:.4f} GB") if size < 0.1: raise RuntimeError("Not enough data found to build a charlm. 
At least 100MB data expected") print(f"--> Splitting into smaller files of size {split_size} ...") train_dir = tgt_dir / 'train' if not os.path.exists(train_dir): # make training dir os.makedirs(train_dir) cmd = f"split -C {split_size} -a 4 -d --additional-suffix .txt {tgt_tmp_shuffled} {train_dir}/{lang}-{dataset_name}-" result = subprocess.run(cmd, shell=True) if result.returncode != 0: raise RuntimeError("Failed to split files!") total = len(glob.glob(f'{train_dir}/*.txt')) print(f"--> {total} total files generated.") if total < 3: raise RuntimeError("Something went wrong! %d file(s) produced by shuffle and split, expected at least 3" % total) dev_file = f"{tgt_dir}/dev.txt" test_file = f"{tgt_dir}/test.txt" if make_test_file: print("--> Creating dev and test files...") shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file) shutil.move(f"{train_dir}/{lang}-{dataset_name}-0001.txt", test_file) txt_files = [dev_file, test_file] + glob.glob(f'{train_dir}/*.txt') else: print("--> Creating dev file...") shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file) txt_files = [dev_file] + glob.glob(f'{train_dir}/*.txt') if compress: print("--> Compressing files...") for txt_file in tqdm(txt_files): subprocess.run(['xz', txt_file]) print("--> Cleaning up...") print(f"--> All done for {lang}-{dataset_name}.\n") if __name__ == "__main__": main() ================================================ FILE: stanza/utils/charlm/oscar_to_text.py ================================================ """ Turns an Oscar 2022 jsonl file to text YOU DO NOT NEED THIS if you use the oscar extractor which reads from HuggingFace, dump_oscar.py to run: python3 -m stanza.utils.charlm.oscar_to_text ... 
each path can be a file or a directory with multiple .jsonl files in it """ import argparse import glob import json import lzma import os import sys from stanza.models.common.utils import open_read_text def extract_file(output_directory, input_filename, use_xz): print("Extracting %s" % input_filename) if output_directory is None: output_directory, output_filename = os.path.split(input_filename) else: _, output_filename = os.path.split(input_filename) json_idx = output_filename.rfind(".jsonl") if json_idx < 0: output_filename = output_filename + ".txt" else: output_filename = output_filename[:json_idx] + ".txt" if use_xz: output_filename += ".xz" open_file = lambda x: lzma.open(x, "wt", encoding="utf-8") else: open_file = lambda x: open(x, "w", encoding="utf-8") output_filename = os.path.join(output_directory, output_filename) print("Writing content to %s" % output_filename) with open_read_text(input_filename) as fin: with open_file(output_filename) as fout: for line in fin: content = json.loads(line) content = content['content'] fout.write(content) fout.write("\n\n") def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--output", default=None, help="Output directory for saving files. 
If None, will write to the original directory") parser.add_argument("--no_xz", default=True, dest="xz", action="store_false", help="Don't use xz to compress the output files") parser.add_argument("filenames", nargs="+", help="Filenames or directories to process") args = parser.parse_args() return args def main(): """ Go through each of the given filenames or directories, convert json to .txt.xz """ args = parse_args() if args.output is not None: os.makedirs(args.output, exist_ok=True) for filename in args.filenames: if os.path.isfile(filename): extract_file(args.output, filename, args.xz) elif os.path.isdir(filename): files = glob.glob(os.path.join(filename, "*jsonl*")) files = sorted([x for x in files if os.path.isfile(x)]) print("Found %d files:" % len(files)) if len(files) > 0: print(" %s" % "\n ".join(files)) for json_filename in files: extract_file(args.output, json_filename, args.xz) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/confusion.py ================================================ from collections import defaultdict, namedtuple F1Result = namedtuple("F1Result", ['precision', 'recall', 'f1']) def condense_ner_labels(confusion, gold_labels, pred_labels): new_confusion = defaultdict(lambda: defaultdict(int)) new_gold_labels = [] new_pred_labels = [] for l1 in gold_labels: if l1.find("-") >= 0: new_l1 = l1.split("-", 1)[1] else: new_l1 = l1 if new_l1 not in new_gold_labels: new_gold_labels.append(new_l1) for l2 in pred_labels: if l2.find("-") >= 0: new_l2 = l2.split("-", 1)[1] else: new_l2 = l2 if new_l2 not in new_pred_labels: new_pred_labels.append(new_l2) old_value = confusion.get(l1, {}).get(l2, 0) new_confusion[new_l1][new_l2] = new_confusion[new_l1][new_l2] + old_value return new_confusion, new_gold_labels, new_pred_labels def format_confusion(confusion, labels=None, hide_zeroes=False, hide_blank=False, transpose=False): """ pretty print for confusion matrixes adapted from 
https://gist.github.com/zachguo/10296432 The matrix should look like this: confusion[gold][pred] """ def sort_labels(labels): """ Sorts the labels in the list, respecting BIES if all labels are BIES, putting O at the front """ labels = set(labels) if 'O' in labels: had_O = True labels.remove('O') else: had_O = False if not all(isinstance(x, str) and len(x) > 2 and x[0] in ('B', 'I', 'E', 'S') and x[1] in ('-', '_') for x in labels): labels = sorted(labels) else: # sort first by the body of the lable, then by BEIS labels = sorted(labels, key=lambda x: (x[2:], x[0])) if had_O: labels = ['O'] + labels return labels if transpose: new_confusion = defaultdict(lambda: defaultdict(int)) for label1 in confusion.keys(): for label2 in confusion[label1].keys(): new_confusion[label2][label1] = confusion[label1][label2] confusion = new_confusion if labels is None: gold_labels = set(confusion.keys()) if hide_blank: gold_labels = set(x for x in gold_labels if any(confusion[x][key] != 0 for key in confusion[x].keys())) pred_labels = set() for key in confusion.keys(): if hide_blank: new_pred_labels = set(x for x in confusion[key].keys() if confusion[key][x] != 0) else: new_pred_labels = confusion[key].keys() pred_labels = pred_labels.union(new_pred_labels) if not hide_blank: gold_labels = gold_labels.union(pred_labels) pred_labels = gold_labels gold_labels = sort_labels(gold_labels) pred_labels = sort_labels(pred_labels) else: gold_labels = labels pred_labels = labels columnwidth = max([len(str(x)) for x in pred_labels] + [5]) # 5 is value length empty_cell = " " * columnwidth # If the numbers are all ints, no need to include the .0 at the end of each entry all_ints = True for i, label1 in enumerate(gold_labels): for j, label2 in enumerate(pred_labels): if not isinstance(confusion.get(label1, {}).get(label2, 0), int): all_ints = False break if not all_ints: break if all_ints: format_cell = lambda confusion_cell: "%{0}d".format(columnwidth) % confusion_cell else: format_cell = lambda 
confusion_cell: "%{0}.1f".format(columnwidth) % confusion_cell # make sure the columnwidth can handle long numbers for i, label1 in enumerate(gold_labels): for j, label2 in enumerate(pred_labels): cell = confusion.get(label1, {}).get(label2, 0) columnwidth = max(columnwidth, len(format_cell(cell))) # if this is an NER confusion matrix (well, if it has - in the labels) # try to drop a bunch of labels to make the matrix easier to display if columnwidth * len(pred_labels) > 150: confusion, gold_labels, pred_labels = condense_ner_labels(confusion, gold_labels, pred_labels) # Print header if transpose: corner_label = "p\\t" else: corner_label = "t\\p" fst_empty_cell = (columnwidth-3)//2 * " " + corner_label + (columnwidth-3)//2 * " " if len(fst_empty_cell) < len(empty_cell): fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell header = " " + fst_empty_cell + " " for label in pred_labels: header = header + "%{0}s ".format(columnwidth) % str(label) text = [header.rstrip()] # Print rows for i, label1 in enumerate(gold_labels): row = " %{0}s ".format(columnwidth) % str(label1) for j, label2 in enumerate(pred_labels): confusion_cell = confusion.get(label1, {}).get(label2, 0) cell = format_cell(confusion_cell) if hide_zeroes: cell = cell if confusion_cell else empty_cell row = row + cell + " " text.append(row.rstrip()) return "\n".join(text) def confusion_to_accuracy(confusion_matrix): """ Given a confusion dictionary, return correct, total """ correct = 0 total = 0 for l1 in confusion_matrix.keys(): for l2 in confusion_matrix[l1].keys(): if l1 == l2: correct = correct + confusion_matrix[l1][l2] else: total = total + confusion_matrix[l1][l2] return correct, (correct + total) def confusion_to_f1(confusion_matrix): results = {} keys = set() for k in confusion_matrix.keys(): keys.add(k) for k2 in confusion_matrix.get(k).keys(): keys.add(k2) sum_f1 = 0 for k in keys: tp = 0 fn = 0 fp = 0 for k2 in keys: if k == k2: tp = confusion_matrix.get(k, 
{}).get(k, 0) else: fn = fn + confusion_matrix.get(k, {}).get(k2, 0) fp = fp + confusion_matrix.get(k2, {}).get(k, 0) if tp + fp == 0: precision = 0.0 else: precision = tp / (tp + fp) if tp + fn == 0: recall = 0.0 else: recall = tp / (tp + fn) if precision + recall == 0.0: f1 = 0.0 else: f1 = 2 * (precision * recall) / (precision + recall) results[k] = F1Result(precision, recall, f1) return results def confusion_to_macro_f1(confusion_matrix): """ Return the macro f1 for a confusion matrix. """ sum_f1 = 0.0 results = confusion_to_f1(confusion_matrix) for k in results.keys(): sum_f1 = sum_f1 + results[k].f1 return sum_f1 / len(results) def confusion_to_weighted_f1(confusion_matrix, exclude=None): results = confusion_to_f1(confusion_matrix) sum_f1 = 0.0 total_items = 0 for k in results.keys(): if exclude is not None and k in exclude: continue k_items = sum(confusion_matrix.get(k, {}).values()) total_items += k_items sum_f1 += results[k].f1 * k_items return sum_f1 / total_items ================================================ FILE: stanza/utils/conll.py ================================================ """ Utility functions for the loading and conversion of CoNLL-format files. """ import os import io from zipfile import ZipFile from stanza.models.common.doc import Document from stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, NER, START_CHAR, END_CHAR from stanza.models.common.doc import FIELD_TO_IDX, FIELD_NUM from stanza.models.common.doc import LINE_NUMBER class CoNLLError(ValueError): pass class CoNLL: @staticmethod def load_conll(f, ignore_gapping=True, keep_line_numbers=False): """ Load the file or string into the CoNLL-U format data. Input: file or string reader, where the data is in CoNLL-U format. 
Output: a tuple whose first element is a list of list of list for each token in each sentence in the data, where the innermost list represents all fields of a token; and whose second element is a list of lists for each comment in each sentence in the data. """ # f is open() or io.StringIO() doc, sent = [], [] doc_comments, sent_comments = [], [] for line_idx, line in enumerate(f): # leave whitespace such as NBSP, in case it is meaningful in the conll-u doc line = line.lstrip().rstrip(' \n\r\t') if len(line) == 0: if len(sent) > 0: doc.append(sent) sent = [] doc_comments.append(sent_comments) sent_comments = [] else: if line.startswith('#'): # read comment line sent_comments.append(line) continue array = line.split('\t') if ignore_gapping and '.' in array[0]: continue if len(array) != FIELD_NUM: raise CoNLLError(f"Cannot parse CoNLL line {line_idx+1}: expecting {FIELD_NUM} fields, {len(array)} found at line {line_idx}\n {array}") if keep_line_numbers: if array[-1] == "_" or array[-1] is None: array[-1] = "%s=%d" % (LINE_NUMBER, line_idx) else: array[-1] = "%s|%s=%d" % (array[-1], LINE_NUMBER, line_idx) sent += [array] if len(sent) > 0: doc.append(sent) doc_comments.append(sent_comments) return doc, doc_comments @staticmethod def convert_conll(doc_conll): """ Convert the CoNLL-U format input data to a dictionary format output data. Input: list of token fields loaded from the CoNLL-U format data, where the outmost list represents a list of sentences, and the inside list represents all fields of a token. Output: a list of list of dictionaries for each token in each sentence in the document. """ doc_dict = [] doc_empty = [] for sent_idx, sent_conll in enumerate(doc_conll): sent_dict = [] sent_empty = [] for token_idx, token_conll in enumerate(sent_conll): try: token_dict = CoNLL.convert_conll_token(token_conll) except ValueError as e: raise CoNLLError("Could not process sentence %d token %d:\n%s\n%s" % (sent_idx, token_idx, token_conll, str(e))) from e if '.' 
in token_dict[ID]: token_dict[ID] = tuple(int(x) for x in token_dict[ID].split(".", maxsplit=1)) sent_empty.append(token_dict) else: try: token_dict[ID] = tuple(int(x) for x in token_dict[ID].split("-", maxsplit=1)) except ValueError as e: raise CoNLLError("Could not process ID %s at sent_idx %d, token_idx %d\nEntire token dict:\n%s" % (token_dict[ID], sent_idx, token_idx, token_dict)) from e sent_dict.append(token_dict) doc_dict.append(sent_dict) doc_empty.append(sent_empty) return doc_dict, doc_empty @staticmethod def convert_dict(doc_dict): """ Convert the dictionary format input data to the CoNLL-U format output data. This is the reverse function of `convert_conll`, but does not include sentence level annotations or comments. Can call this on a Document using `CoNLL.convert_dict(doc.to_dict())` Input: dictionary format data, which is a list of list of dictionaries for each token in each sentence in the data. Output: CoNLL-U format data as a list of list of list for each token in each sentence in the data. """ doc = Document(doc_dict) text = "{:c}".format(doc) sentences = text.split("\n\n") doc_conll = [[x.split("\t") for x in sentence.split("\n")] for sentence in sentences] return doc_conll @staticmethod def convert_conll_token(token_conll): """ Convert the CoNLL-U format input token to the dictionary format output token. Input: a list of all CoNLL-U fields for the token. Output: a dictionary that maps from field name to value. 
""" token_dict = {} for field, field_idx in FIELD_TO_IDX.items(): value = token_conll[field_idx] if value == '' and field is FEATS: continue elif value != '_': if field is HEAD: token_dict[field] = int(value) else: token_dict[field] = value # special case if text is '_' if token_conll[FIELD_TO_IDX[TEXT]] == '_': token_dict[TEXT] = token_conll[FIELD_TO_IDX[TEXT]] token_dict[LEMMA] = token_conll[FIELD_TO_IDX[LEMMA]] return token_dict @staticmethod def conll2dict(input_file=None, input_str=None, ignore_gapping=True, zip_file=None, keep_line_numbers=False): """ Load the CoNLL-U format data from file or string into lists of dictionaries. """ assert any([input_file, input_str]) and not all([input_file, input_str]), 'either use input file or input string' if zip_file: assert input_file, 'must provide input_file if zip_file is set' if input_str: infile = io.StringIO(input_str) doc_conll, doc_comments = CoNLL.load_conll(infile, ignore_gapping, keep_line_numbers) elif zip_file: with ZipFile(zip_file) as zin: with zin.open(input_file) as fin: doc_conll, doc_comments = CoNLL.load_conll(io.TextIOWrapper(fin, encoding="utf-8"), ignore_gapping, keep_line_numbers) else: with open(input_file, encoding='utf-8') as fin: doc_conll, doc_comments = CoNLL.load_conll(fin, ignore_gapping, keep_line_numbers) doc_dict, doc_empty = CoNLL.convert_conll(doc_conll) return doc_dict, doc_comments, doc_empty @staticmethod def conll2doc(input_file=None, input_str=None, ignore_gapping=True, zip_file=None, keep_line_numbers=False): doc_dict, doc_comments, doc_empty = CoNLL.conll2dict(input_file, input_str, ignore_gapping, zip_file=zip_file, keep_line_numbers=keep_line_numbers) return Document(doc_dict, text=None, comments=doc_comments, empty_sentences=doc_empty) @staticmethod def conll2multi_docs(input_file=None, input_str=None, ignore_gapping=True, zip_file=None): doc_dict, doc_comments, doc_empty = CoNLL.conll2dict(input_file, input_str, ignore_gapping, zip_file=zip_file) docs = [] current_doc = [] 
current_comments = [] current_empty = [] current_doc_id = None for doc, comments, empty in zip(doc_dict, doc_comments, doc_empty): for comment in comments: if comment.startswith("# doc_id =") or comment.startswith("# newdoc id ="): doc_id = comment.split("=", maxsplit=1)[1] if len(current_doc) == 0: current_doc_id = doc_id elif doc_id != current_doc_id: new_doc = Document(current_doc, text=None, comments=current_comments, empty_sentences=current_empty) if current_doc_id != None: for i in new_doc.sentences: i.doc_id = current_doc_id.strip() docs.append(new_doc) current_doc_id = doc_id else: continue current_doc = [doc] current_comments = [comments] current_empty = [empty] break else: # no comments defined a new doc_id, so just add it to the current document current_doc.append(doc) current_comments.append(comments) current_empty.append(empty) if len(current_doc) > 0: new_doc = Document(current_doc, text=None, comments=current_comments, empty_sentences=current_empty) if current_doc_id != None: for i in new_doc.sentences: i.doc_id = current_doc_id.strip() docs.append(new_doc) current_doc_id = doc_id return docs @staticmethod def dict2conll(doc_dict, filename): """ Convert the dictionary format input data to the CoNLL-U format output data and write to a file. """ doc = Document(doc_dict) CoNLL.write_doc2conll(doc, filename) @staticmethod def write_doc2conll(doc, filename, mode='w', encoding='utf-8'): """ Writes the doc as a conll file to the given file. If passed a string, that filename will be opened. Otherwise, filename.write() will be called. 
Note that the output needs an extra \n\n at the end to be a legal output file """ if hasattr(filename, "write"): filename.write("{:C}\n\n".format(doc)) else: with open(filename, mode, encoding=encoding) as outfile: outfile.write("{:C}\n\n".format(doc)) ================================================ FILE: stanza/utils/constituency/__init__.py ================================================ ================================================ FILE: stanza/utils/constituency/check_transitions.py ================================================ import argparse from stanza.models.constituency import transition_sequence from stanza.models.constituency import tree_reader from stanza.models.constituency.parse_transitions import TransitionScheme from stanza.models.constituency.parse_tree import Tree from stanza.models.constituency.utils import verify_transitions def main(): parser = argparse.ArgumentParser() parser.add_argument('--train_file', type=str, default="data/constituency/en_ptb3_train.mrg", help='Input file for data loader.') parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()], help='Transition scheme to use. 
{}'.format(", ".join(x.name for x in TransitionScheme))) parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed') parser.add_argument('--iterations', default=30, type=int, help='How many times to iterate, such as if doing a cProfile') args = parser.parse_args() args = vars(args) train_trees = tree_reader.read_treebank(args['train_file']) unary_limit = max(t.count_unary_depth() for t in train_trees) + 1 train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], args['reversed']) root_labels = Tree.get_root_labels(train_trees) for i in range(args['iterations']): verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], "train", root_labels) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/constituency/grep_dev_logs.py ================================================ import subprocess import sys iteration = sys.argv[1] filenames = sys.argv[2:] total_score = 0.0 num_scores = 0 for filename in filenames: grep_cmd = ["grep", "Dev score.* %s[)]" % iteration, "-A1", filename] grep_result = subprocess.run(grep_cmd, stdout=subprocess.PIPE, encoding="utf-8") grep_result = grep_result.stdout.strip() if not grep_result: max_cmd = ["grep", "Dev score", filename] max_result = subprocess.run(max_cmd, stdout=subprocess.PIPE, encoding="utf-8") max_result = max_result.stdout.strip() if not max_result: print("{}: no result".format(filename)) else: max_it = max_result.split("\n")[-1] max_it = int(max_it.split(":")[0].split("(")[-1][:-1]) epoch_finished_string = "Epoch %d finished" % max_it finish_cmd = ["grep", epoch_finished_string, filename] finish_result = subprocess.run(finish_cmd, stdout=subprocess.PIPE, encoding="utf-8") finish_result = finish_result.stdout.strip() finish_time = finish_result.split(" INFO")[0] print("{}: no result. 
max iteration: {} finished at {}".format(filename, max_it, finish_time)) else: grep_result = grep_result.split("\n")[-1] score = float(grep_result.split(":")[-1]) best_iteration = int(grep_result.split(":")[-2][-6:-1]) print("{}: {} ({})".format(filename, score, best_iteration)) total_score += score num_scores += 1 if num_scores > 0: avg = total_score / num_scores print("Avg: {}".format(avg)) ================================================ FILE: stanza/utils/constituency/grep_test_logs.py ================================================ import subprocess import sys filenames = sys.argv[1:] total_score = 0.0 num_scores = 0 for filename in filenames: grep_cmd = ["grep", "F1 score.*test.*", filename] grep_result = subprocess.run(grep_cmd, stdout=subprocess.PIPE, encoding="utf-8") grep_result = grep_result.stdout.strip() if not grep_result: print("{}: no result".format(filename)) continue score = float(grep_result.split()[-1]) print("{}: {}".format(filename, score)) total_score += score num_scores += 1 if num_scores > 0: avg = total_score / num_scores print("Avg: {}".format(avg)) ================================================ FILE: stanza/utils/constituency/list_tensors.py ================================================ """ Lists all the tensors in a constituency model. Currently useful in combination with torchshow for displaying a series of tensors as they change. 
""" import sys from stanza.models.constituency.trainer import Trainer trainer = Trainer.load(sys.argv[1]) model = trainer.model for name, param in model.named_parameters(): print(name, param.requires_grad) ================================================ FILE: stanza/utils/datasets/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/common.py ================================================ import argparse from enum import Enum import glob import logging import os import re import subprocess import sys import unicodedata from stanza.models.common.short_name_to_treebank import canonical_treebank_name import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data import stanza.utils.datasets.conllu_to_text as conllu_to_text import stanza.utils.default_paths as default_paths logger = logging.getLogger('stanza') # RE to see if the index of a conllu line represents an MWT MWT_RE = re.compile("^[0-9]+[-][0-9]+") # RE to see if the index of a conllu line represents an MWT or copy node MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+") # more restrictive than an actual int as we expect certain formats in the conllu files INT_RE = re.compile("^[0-9]+$") class ModelType(Enum): TOKENIZER = 1 MWT = 2 POS = 3 LEMMA = 4 DEPPARSE = 5 class UnknownDatasetError(ValueError): def __init__(self, dataset, text): super().__init__(text) self.dataset = dataset def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")): """ Convert the conllu documents for this dataset to a .txt format This follows the old conllu_to_text.pl script, except we never used the ZH option anyway, so we didn't reimplement it here """ for dataset in shards: output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt" if not os.path.exists(output_conllu): raise FileNotFoundError("Cannot convert %s as the file cannot be 
found" % output_conllu) conllu_to_text.main([output_conllu, output_txt]) def strip_accents(word): """ Remove diacritics from words such as in the UD GRC datasets """ converted = ''.join(c for c in unicodedata.normalize('NFD', word) if unicodedata.category(c) not in ('Mn')) if len(converted) == 0: return word return converted def mwt_name(base_dir, short_name, dataset): return os.path.join(base_dir, f"{short_name}-ud-{dataset}-mwt.json") def tokenizer_conllu_name(base_dir, short_name, dataset): return os.path.join(base_dir, f"{short_name}.{dataset}.gold.conllu") def prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset): labels_filename = f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels" mwt_filename = mwt_name(tokenizer_dir, short_name, dataset) prepare_tokenizer_data.main([input_txt, input_conllu, "-o", labels_filename, "-m", mwt_filename]) def prepare_tokenizer_treebank_labels(tokenizer_dir, short_name): """ Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test """ for dataset in ("train", "dev", "test"): output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt" output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" try: prepare_tokenizer_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset) except (KeyboardInterrupt, SystemExit): raise except: print("Failed to convert %s to %s" % (output_txt, output_conllu)) raise def read_sentences_from_conllu(filename): """ Reads a conllu file as a list of list of strings Finding a blank line separates the lists """ sents = [] cache = [] with open(filename, encoding="utf-8") as infile: for line in infile: line = line.strip() if len(line) == 0: if len(cache) > 0: sents.append(cache) cache = [] continue cache.append(line) if len(cache) > 0: sents.append(cache) return sents def maybe_add_fake_dependencies(lines): """ Possibly add fake dependencies in columns 6 and 7 (counting from 0) The conllu scripts need the 
dependencies column filled out, so in the case of models we build without dependency data, we need to add those fake dependencies in order to use the eval script etc lines: a list of strings with 10 tab separated columns comments are allowed (they will be skipped) returns: the same strings, but with fake dependencies added if columns 6 and 7 were empty """ new_lines = [] root_idx = None first_idx = None for line_idx, line in enumerate(lines): if line.startswith("#"): new_lines.append(line) continue pieces = line.split("\t") if MWT_OR_COPY_RE.match(pieces[0]): new_lines.append(line) continue token_idx = int(pieces[0]) if pieces[6] != '_': if pieces[6] == '0': root_idx = token_idx new_lines.append(line) elif token_idx == 1: # note that the comments might make this not the first line # we keep track of this separately so we can either make this the root, # or set this to be the root later first_idx = line_idx new_lines.append(pieces) else: pieces[6] = "1" pieces[7] = "dep" new_lines.append("\t".join(pieces)) if first_idx is not None: if root_idx is None: new_lines[first_idx][6] = "0" new_lines[first_idx][7] = "root" else: new_lines[first_idx][6] = str(root_idx) new_lines[first_idx][7] = "dep" new_lines[first_idx] = "\t".join(new_lines[first_idx]) return new_lines def write_sentences_to_file(outfile, sents): for lines in sents: lines = maybe_add_fake_dependencies(lines) for line in lines: print(line, file=outfile) print("", file=outfile) def write_sentences_to_conllu(filename, sents): with open(filename, 'w', encoding="utf-8") as outfile: write_sentences_to_file(outfile, sents) def find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False, env_var="UDBASE"): """ For a given treebank, dataset, extension, look for the exact filename to use. Sometimes the short name we use is different from the short name used by UD. For example, Norwegian or Chinese. 
Hence the reason to not hardcode it based on treebank set fail=True to fail if the file is not found """ if treebank.startswith("UD_Korean") and treebank.endswith("_seg"): treebank = treebank[:-4] if treebank.startswith("UD_Ancient_Greek-") and (treebank.endswith("-Diacritics") or treebank.endswith("-diacritics")): treebank = treebank[:-11] filename = os.path.join(udbase_dir, treebank, f"*-ud-{dataset}.{extension}") files = glob.glob(filename) if len(files) == 0: if fail: raise FileNotFoundError("Could not find any treebank files which matched {}\nIf you have the data elsewhere, you can change the base directory for the search by changing the {} environment variable".format(filename, env_var)) else: return None elif len(files) == 1: return files[0] else: raise RuntimeError(f"Unexpected number of files matched '{udbase_dir}/{treebank}/*-ud-{dataset}.{extension}'") def mostly_underscores(filename): """ Certain treebanks have proprietary data, so the text is hidden For example: UD_Arabic-NYUAD UD_English-ESL UD_English-GUMReddit UD_Hindi_English-HIENCS UD_Japanese-BCCWJ """ underscore_count = 0 total_count = 0 for line in open(filename).readlines(): line = line.strip() if not line: continue if line.startswith("#"): continue total_count = total_count + 1 pieces = line.split("\t") if pieces[1] in ("_", "-"): underscore_count = underscore_count + 1 return underscore_count / total_count > 0.5 def num_words_in_file(conllu_file): """ Count the number of non-blank lines in a conllu file """ count = 0 with open(conllu_file) as fin: for line in fin: line = line.strip() if not line: continue if line.startswith("#"): continue count = count + 1 return count def get_test_only_ud_treebanks(udbase_dir, filtered=True): """ Looks in udbase_dir for all the treebanks which are *only* test sets, but might be big enough Filters out: - less than 10000 words - the language already has a larger treebank we can use The second filter takes quite some time, as there is a check that goes through 
all the text in the treebank """ treebanks = sorted(glob.glob(udbase_dir + "/UD_*")) # skip UD_English-GUMReddit as it is usually incorporated into UD_English-GUM treebanks = [os.path.split(t)[1] for t in treebanks] treebanks = [t for t in treebanks if t != "UD_English-GUMReddit"] # only take the ones which do have test, but don't have train treebanks = [t for t in treebanks if not find_treebank_dataset_file(t, udbase_dir, "train", "conllu")] treebanks = [t for t in treebanks if find_treebank_dataset_file(t, udbase_dir, "test", "conllu")] treebanks = [t for t in treebanks if not mostly_underscores(find_treebank_dataset_file(t, udbase_dir, "test", "conllu"))] if any(find_treebank_dataset_file(t, udbase_dir, "dev", "conllu") for t in treebanks): raise AssertionError("Found a treebank with dev and test, but no train. This violates our expectations") treebanks = [t for t in treebanks if num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "test", "conllu")) > 10000] if filtered: treebanks = [t for t in treebanks if len(get_ud_treebanks(udbase_dir, lang=t.split("-")[0])) == 0] #for t in treebanks: # print(t, # num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "test", "conllu"))) return treebanks def get_ud_treebanks(udbase_dir, lang=None, filtered=True): """ Looks in udbase_dir for all the treebanks which have both train, dev, and test If specified, lang should be exactly UD_English or however the treebanks appear in the UD release """ if lang is None: treebanks = sorted(glob.glob(udbase_dir + "/UD_*")) else: treebanks = sorted(glob.glob("%s/%s*" % (udbase_dir, lang))) # skip UD_English-GUMReddit as it is usually incorporated into UD_English-GUM treebanks = [os.path.split(t)[1] for t in treebanks] treebanks = [t for t in treebanks if t != "UD_English-GUMReddit"] if filtered: treebanks = [t for t in treebanks if (find_treebank_dataset_file(t, udbase_dir, "train", "conllu") and # this will be fixed using XV #find_treebank_dataset_file(t, udbase_dir, 
"dev", "conllu") and find_treebank_dataset_file(t, udbase_dir, "test", "conllu"))] treebanks = [t for t in treebanks if not mostly_underscores(find_treebank_dataset_file(t, udbase_dir, "train", "conllu"))] # eliminate partial treebanks (fixed with XV) for which we only have 1000 words or less # if the train set is small and the test set is large enough, we'll flip them treebanks = [t for t in treebanks if (find_treebank_dataset_file(t, udbase_dir, "dev", "conllu") or num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "train", "conllu")) > 1000 or num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "test", "conllu")) > 5000)] return treebanks def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks') return parser def main(process_treebank, model_type, add_specific_args=None): logger.info("Datasets program called with:\n" + " ".join(sys.argv)) parser = build_argparse() if add_specific_args is not None: add_specific_args(parser) args = parser.parse_args() paths = default_paths.get_default_paths() treebanks = [] for treebank in args.treebanks: if treebank.lower() in ('ud_all', 'all_ud'): ud_treebanks = get_ud_treebanks(paths["UDBASE"]) treebanks.extend(ud_treebanks) else: # If this is a known UD short name, use the official name (we need it for the paths) treebank = canonical_treebank_name(treebank) treebanks.append(treebank) for treebank in treebanks: process_treebank(treebank, model_type, paths, args) ================================================ FILE: stanza/utils/datasets/conllu_to_text.py ================================================ import argparse import re TEXT_RE = re.compile("^#\\s*text") NEWPAR_RE = re.compile("^#\\s*newpar") NEWDOC_RE = re.compile("^#\\s*newdoc") MWT_RE = re.compile("^\\d+-(\\d+)\t") WORD_RE = re.compile("^(\\d)+\t") WORD_NEWPAR_RE = re.compile("NewPar=Yes") SPACEAFTER_RE = 
re.compile("SpaceAfter=No") def print_new_paragraph_if_needed(fout, start, newdoc, newpar, output_buffer): if not start and (newdoc or newpar): if output_buffer: fout.write(output_buffer) fout.write("\n") fout.write("\n") return "" return output_buffer def print_lines_from_buffer(fout, output_buffer, max_len): while len(output_buffer) >= max_len: split_idx = None for idx in range(len(output_buffer)): if idx > max_len and split_idx is not None: break if output_buffer[idx].isspace(): split_idx = idx if split_idx is not None: fout.write(output_buffer[:split_idx]) fout.write("\n") output_buffer = output_buffer[split_idx+1:] else: fout.write(output_buffer) fout.write("\n") output_buffer = "" return output_buffer def convert_text(conllu_file, output_file): with open(conllu_file, encoding="utf-8") as fin: lines = fin.readlines() with open(output_file, "w", encoding="utf-8") as fout: newpar = False newdoc = False start = True in_mwt = False mwt_last = None def print_and_reset(output_buffer, incoming_buffer): nonlocal start, newpar, newdoc, in_mwt output_buffer = print_new_paragraph_if_needed(fout, start, newdoc, newpar, output_buffer) output_buffer += incoming_buffer output_buffer = print_lines_from_buffer(fout, output_buffer, 80) start = False newpar = False newdoc = False in_mwt = False return output_buffer output_buffer = "" incoming_buffer = "" for line in lines: line = line.strip() if not line: output_buffer = print_and_reset(output_buffer, incoming_buffer) incoming_buffer = "" if TEXT_RE.match(line): # we ignore the #text and extract the text from the tokens continue if NEWPAR_RE.match(line): newpar = True continue if NEWDOC_RE.match(line): newdoc = True continue match = MWT_RE.match(line) if match: in_mwt = True mwt_last = int(match.group(1)) pieces = line.split("\t") if WORD_NEWPAR_RE.search(pieces[9]): output_buffer = print_and_reset(output_buffer, incoming_buffer) incoming_buffer = "" fout.write(output_buffer) fout.write("\n\n") output_buffer = "" incoming_buffer 
+= pieces[1] if not SPACEAFTER_RE.search(pieces[9]): incoming_buffer += " " continue match = WORD_RE.match(line) if match: pieces = line.split("\t") word_id = int(pieces[0]) if in_mwt and word_id <= mwt_last: continue in_mwt = False if WORD_NEWPAR_RE.search(pieces[9]): output_buffer = print_and_reset(output_buffer, incoming_buffer) incoming_buffer = "" fout.write(output_buffer) fout.write("\n\n") output_buffer = "" incoming_buffer += pieces[1] if not SPACEAFTER_RE.search(pieces[9]): incoming_buffer += " " continue if output_buffer != "": fout.write(output_buffer) fout.write("\n") def main(args=None): parser = argparse.ArgumentParser() parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks") parser.add_argument('output_file', type=str, help="Plaintext file containing the raw input") args = parser.parse_args(args) convert_text(args.conllu_file, args.output_file) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/datasets/constituency/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/constituency/build_silver_dataset.py ================================================ """ Given two ensembles and a tokenized file, output the trees for which those ensembles agree and report how many of the sub-models agree on those trees. 
For example: python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_AA.txt --lang it --output_file asdf.out --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt for i in `echo f g h i j k l m n o p q r s t`; do nlprun -d a6000 "python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tok_6M_a$i.txt --lang it --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.trees --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt" -o /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.out; done for i in `echo a b c d`; do nlprun -d a6000 "python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/english/en_wiki_2023/shuf_1M.a$i --lang en --output_file /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.trees --e1 saved_models/constituency/en_ptb3_electra-large_100?_in_constituency.pt --e2 saved_models/constituency/en_ptb3_electra-large_100?_top_constituency.pt" -o /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.out; done """ import argparse import json import logging from stanza.models.common import utils from stanza.models.common.foundation_cache import FoundationCache from stanza.models.constituency import retagging from stanza.models.constituency import text_processing from stanza.models.constituency import tree_reader from stanza.models.constituency.ensemble import Ensemble from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza.constituency.trainer') def parse_args(args=None): parser = 
argparse.ArgumentParser(description="Script that uses multiple ensembles to find trees where both ensembles agree") input_group = parser.add_mutually_exclusive_group(required=True) input_group.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.') input_group.add_argument('--tree_file', type=str, default=None, help='Input file of already parsed text for reparsing with parse_text.') parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file') parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') utils.add_device_args(parser) parser.add_argument('--lang', default='en', help='Language to use') parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval') parser.add_argument('--e1', type=str, nargs='+', default=None, help="Which model(s) to load in the first ensemble") parser.add_argument('--e2', type=str, nargs='+', default=None, help="Which model(s) to load in the second ensemble") parser.add_argument('--mode', default='predict', choices=['parse_text', 'predict']) # another option would be to include the tree idx in each entry in an existing saved file # the processing could then pick up at exactly the last known idx parser.add_argument('--start_tree', type=int, default=0, help='Where to start... most useful if the previous incarnation crashed') parser.add_argument('--end_tree', type=int, default=None, help='Where to end. 
If unset, will process to the end of the file') retagging.add_retag_args(parser) args = vars(parser.parse_args()) retagging.postprocess_args(args) args['num_generate'] = 0 return args def main(): args = parse_args() utils.log_training_args(args, logger, name="ensemble") retag_pipeline = retagging.build_retag_pipeline(args) foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() logger.info("Building ensemble #1 out of %s", args['e1']) e1 = Ensemble(args, filenames=args['e1'], foundation_cache=foundation_cache) e1.to(args.get('device', None)) logger.info("Building ensemble #2 out of %s", args['e2']) e2 = Ensemble(args, filenames=args['e2'], foundation_cache=foundation_cache) e2.to(args.get('device', None)) if args['tokenized_file']: tokenized_sentences, _ = text_processing.read_tokenized_file(args['tokenized_file']) elif args['tree_file']: treebank = tree_reader.read_treebank(args['tree_file']) tokenized_sentences = [x.leaf_labels() for x in treebank] if args['lang'] == 'vi': tokenized_sentences = [[x.replace("_", " ") for x in sentence] for sentence in tokenized_sentences] logger.info("Read %d tokenized sentences", len(tokenized_sentences)) all_models = e1.models + e2.models chunk_size = 1000 with open(args['output_file'], 'w', encoding='utf-8') as fout: end_tree = len(tokenized_sentences) if args['end_tree'] is None else args['end_tree'] for chunk_start in tqdm(range(args['start_tree'], end_tree, chunk_size)): chunk = tokenized_sentences[chunk_start:chunk_start+chunk_size] logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk)) parsed1 = text_processing.parse_tokenized_sentences(args, e1, retag_pipeline, chunk) parsed1 = [x.predictions[0].tree for x in parsed1] parsed2 = text_processing.parse_tokenized_sentences(args, e2, retag_pipeline, chunk) parsed2 = [x.predictions[0].tree for x in parsed2] matching = [t for t, t2 in zip(parsed1, parsed2) if t == t2] logger.info("%d trees matched", len(matching)) 
model_counts = [0] * len(matching) for model in all_models: model_chunk = model.parse_sentences_no_grad(iter(matching), model.build_batch_from_trees, args['eval_batch_size'], model.predict) model_chunk = [x.predictions[0].tree for x in model_chunk] for idx, (t1, t2) in enumerate(zip(matching, model_chunk)): if t1 == t2: model_counts[idx] += 1 for count, tree in zip(model_counts, matching): line = {"tree": "%s" % tree, "count": count} fout.write(json.dumps(line)) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/common_trees.py ================================================ """ Look through 2 files, only output the common trees pretty basic - could use some more options """ import sys def main(): in1 = sys.argv[1] with open(in1, encoding="utf-8") as fin: lines1 = fin.readlines() in2 = sys.argv[2] with open(in2, encoding="utf-8") as fin: lines2 = fin.readlines() common = [l1 for l1, l2 in zip(lines1, lines2) if l1 == l2] for l in common: print(l.strip()) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/convert_alt.py ================================================ """ Read files of parses and the files which define the train/dev/test splits Write out the files after splitting them Sequence of operations: - read the raw lines from the input files - read the recommended splits, as per the ALT description page - separate the trees using the recommended split files - write back the trees """ def read_split_file(split_file): """ Read a split file for ALT The format of the file is expected to be a list of lines such as URL.1234 Here, we only care about the id return: a set of the ids """ with open(split_file, encoding="utf-8") as fin: lines = fin.readlines() lines = [x.strip() for x in lines] lines = [x.split()[0] for x in lines if x] if any(not x.startswith("URL.") for x in lines): raise 
ValueError("Unexpected line in %s: %s" % (split_file, x)) split = set(int(x.split(".", 1)[1]) for x in lines) return split def split_trees(all_lines, splits): """ Splits lines of the form SNT.17873.4049 (S ... then assigns them to a list based on the file id in SNT.. """ trees = [list() for _ in splits] for line in all_lines: tree_id, tree_text = line.split(maxsplit=1) tree_id = int(tree_id.split(".", 2)[1]) for split_idx, split in enumerate(splits): if tree_id in split: trees[split_idx].append(tree_text) break else: # couldn't figure out which split to put this in raise ValueError("Couldn't find which split this line goes in:\n%s" % line) return trees def read_alt_lines(input_files): """ Read the trees from the given file(s) Any trees with wide spaces are eliminated. The parse tree handling doesn't handle it well and the tokenizer won't produce tokens which are entirely wide spaces anyway The tree lines are not processed into trees, though """ all_lines = [] for input_file in input_files: with open(input_file, encoding="utf-8") as fin: all_lines.extend(fin.readlines()) all_lines = [x.strip() for x in all_lines] all_lines = [x for x in all_lines if x] original_count = len(all_lines) # there is 1 tree with wide space as an entire token, and 4 with wide spaces at the end of a token all_lines = [x for x in all_lines if not " " in x] new_count = len(all_lines) if new_count < original_count: print("Eliminated %d trees for having wide spaces in it" % ((original_count - new_count))) original_count = new_count all_lines = [x for x in all_lines if not "\\x" in x] new_count = len(all_lines) if new_count < original_count: print("Eliminated %d trees for not being correctly encoded" % ((original_count - new_count))) original_count = new_count return all_lines def convert_alt(input_files, split_files, output_files): """ Convert the ALT treebank into train/dev/test splits input_files: paths to read trees split_files: recommended splits from the ALT page output_files: where to 
write train/dev/test """ all_lines = read_alt_lines(input_files) splits = [read_split_file(split_file) for split_file in split_files] trees = split_trees(all_lines, splits) for chunk, output_file in zip(trees, output_files): print("Writing %d trees to %s" % (len(chunk), output_file)) with open(output_file, "w", encoding="utf-8") as fout: for tree in chunk: # the extra ROOT is because the ALT doesn't have this at the top of its trees fout.write("(ROOT {})\n".format(tree)) ================================================ FILE: stanza/utils/datasets/constituency/convert_arboretum.py ================================================ """ Parses a Tiger dataset to PTB Also handles problems specific for the Arboretum treebank. - validation errors in the XML: -- there is a "&" instead of an "&" early on -- there are tags "<{note}>" and "<{parentes-udeladt}>" which may or may not be relevant, but are definitely not properly xml encoded - trees with stranded nodes. 5 trees have links to words in a different tree. those trees are skipped - trees with empty nodes. 58 trees have phrase nodes with no leaves. those trees are skipped - trees with missing words. 134 trees have words in the text which aren't in the tree those trees are also skipped - trees with categories not in the category directory for example, intj... replaced with fcl? most of these are replaced with what might be a sensible replacement - trees with labels that don't have an obvious replacement these trees are eliminated, 4 total - underscores in words. those words are split into multiple words the tagging is not going to be ideal, but the first step of training a parser is usually to retag the words anyway, so this should be okay - tree 14729 is really weirdly annotated. skipped - 5373 trees total have non-projective constituents. These don't work with the stanza parser... in order to work around this, we rearrange them when possible. ((X Z) Y1 Y2 ...) -> (X Y1 Y2 Z) this rearranges 3021 trees ((X Z1 ...) 
Y1 Y2 ...) -> (X Y1 Y2 Z) this rearranges 403 trees ((X Z1 ...) (tag Y1) ...) -> (X (Y1) Z) this rearranges 1258 trees A couple examples of things which get rearranged (limited in scope and without the words to avoid breaking our license): (vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7) --> (vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9)) (vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3) --> (vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4)) This process leaves behind 691 trees. In some cases, the non-projective structure is at a higher level than the attachment. In others, there are nested non-projectivities that are not rearranged by the above pattern. A couple examples: here, the 3-7 nonprojectivity has the 7 in a nested structure (s (par (n s206_1) (pu s206_2) (fcl (fcl (pron-pers s206_3) (fcl (pron-pers s206_7) (adv s206_8) (v-fin s206_9))) (vp (v-fin s206_4) (v-inf s206_6)) (pron-pers s206_5)) (pu s206_10))) here, 11 is attached at a higher level than 12 & 13 (s (fcl (icl (np (adv s223_1) (np (n s223_2) (pp (prp s223_3) (par (adv s223_4) (prop s223_5) (pu s223_6) (prop s223_7) (conj-c s223_8) (np (adv s223_9) (prop s223_10)))))) (vp (infm s223_12) (v-inf s223_13))) (v-fin s223_11) (pu s223_14))) even if we moved _6 between 2 and 7, we'd then have a completely flat structure when moving 3..5 inside (s (fcl (xx s499_1) (np (pp (pron-pers s499_2) (prp s499_7)) (n s499_6)) (v-fin s499_3) (adv s499_4) (adv s499_5) (pu s499_8))) """ from collections import namedtuple import io import xml.etree.ElementTree as ET from tqdm import tqdm from stanza.models.constituency.parse_tree import Tree from stanza.server import tsurgeon def read_xml_file(input_filename): """ Convert an XML file into a list of trees - each becomes its own object """ print("Reading {}".format(input_filename)) with open(input_filename, encoding="utf-8") as fin: lines = fin.readlines() sentences = [] current_sentence = [] in_sentence = False for line_idx, line in enumerate(lines): if 
line.startswith(" 0: raise ValueError("Found the start of a sentence inside an existing sentence, line {}".format(line_idx)) in_sentence = True if in_sentence: current_sentence.append(line) if line.startswith(""): assert in_sentence current_sentence = [x.replace("<{parentes-udeladt}>", "") for x in current_sentence] current_sentence = [x.replace("<{note}>", "") for x in current_sentence] sentences.append("".join(current_sentence)) current_sentence = [] in_sentence = False assert len(current_sentence) == 0 xml_sentences = [] for sent_idx, text in enumerate(sentences): sentence = io.StringIO(text) try: tree = ET.parse(sentence) xml_sentences.append(tree) except ET.ParseError as e: raise ValueError("Failed to parse sentence {}".format(sent_idx)) return xml_sentences Word = namedtuple('Word', ['word', 'tag']) Node = namedtuple('Node', ['label', 'children']) class BrokenLinkError(ValueError): def __init__(self, error): super(BrokenLinkError, self).__init__(error) def process_nodes(root_id, words, nodes, visited): """ Given a root_id, a map of words, and a map of nodes, construct a Tree visited is a set of string ids and mutates over the course of the recursive call """ if root_id in visited: raise ValueError("Loop in the tree!") visited.add(root_id) if root_id in words: word = words[root_id] # big brain move: put the root_id here so we can use that to # check the sorted order when we are done word_node = Tree(label=root_id) tag_node = Tree(label=word.tag, children=word_node) return tag_node elif root_id in nodes: node = nodes[root_id] children = [process_nodes(child, words, nodes, visited) for child in node.children] return Tree(label=node.label, children=children) else: raise BrokenLinkError("Unknown id! 
{}".format(root_id)) def check_words(tree, tsurgeon_processor): """ Check that the words of a sentence are in order If they are not, this applies a tsurgeon to rearrange simple cases The tsurgeon looks at the gap between words, eg _3 to _7, and looks for the words between, such as _4 _5 _6. if those words are under a node at the same level as the 3-7 node and does not include any other nodes (such as _8), that subtree is moved to between _3 and _7 Example: (vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7) --> (vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9)) """ while True: words = tree.leaf_labels() indices = [int(w.split("_", 1)[1]) for w in words] for word_idx, word_label in enumerate(indices): if word_idx != word_label - 1: break else: # if there are no weird indices, keep the tree return tree sorted_indices = sorted(indices) if indices == sorted_indices: raise ValueError("Skipped index! This should already be accounted for {}".format(tree)) if word_idx == 0: return None prefix = words[0].split("_", 1)[0] prev_idx = word_idx - 1 prev_label = indices[prev_idx] missing_words = ["%s_%d" % (prefix, x) for x in range(prev_label + 1, word_label)] missing_words = "|".join(missing_words) #move_tregex = "%s > (__=home > (__=parent > __=grandparent)) . 
(%s > (__=move > =grandparent))" % (words[word_idx], "|".join(missing_words)) move_tregex = "%s > (__=home > (__=parent << %s $+ (__=move <<, %s <<- %s)))" % (words[word_idx], words[prev_idx], missing_words, missing_words) move_tsurgeon = "move move $+ home" modified = tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0] if modified == tree: # this only happens if the desired fix didn't happen #print("Failed to process:\n {}\n {} {}".format(tree, prev_label, word_label)) return None tree = modified def replace_words(tree, words): """ Remap the leaf words given a map of the labels we expect in the leaves """ leaves = tree.leaf_labels() new_words = [words[w].word for w in leaves] new_tree = tree.replace_words(new_words) return new_tree def process_tree(sentence): """ Convert a single ET element representing a Tiger tree to a parse tree """ sentence = sentence.getroot() sent_id = sentence.get("id") if sent_id is None: raise ValueError("Tree {} does not have an id".format(sent_id)) if len(sentence) > 1: raise ValueError("Longer than expected number of items in {}".format(sent_id)) graph = sentence.find("graph") if graph is None: raise ValueError("Unexpected tree structure in {} : top tag is not 'graph'".format(sent_id)) root_id = graph.get("root") if root_id is None: raise ValueError("Tree has no root id in {}".format(sent_id)) terminals = graph.find("terminals") if terminals is None: raise ValueError("No terminals in tree {}".format(sent_id)) # some Arboretum graphs have two sets of nonterminals, # apparently intentionally, so we ignore that possible error nonterminals = graph.find("nonterminals") if nonterminals is None: raise ValueError("No nonterminals in tree {}".format(sent_id)) # read the words. 
the words have ids, text, and tags which we care about words = {} for word in terminals: if word.tag == 'parentes-udeladt' or word.tag == 'note': continue if word.tag != "t": raise ValueError("Unexpected tree structure in {} : word with tag other than t".format(sent_id)) word_id = word.get("id") if not word_id: raise ValueError("Word had no id in {}".format(sent_id)) word_text = word.get("word") if not word_text: raise ValueError("Word had no text in {}".format(sent_id)) word_pos = word.get("pos") if not word_pos: raise ValueError("Word had no pos in {}".format(sent_id)) words[word_id] = Word(word_text, word_pos) # read the nodes. the nodes have ids, labels, and children # some of the edges are labeled "secedge". we ignore those nodes = {} for nt in nonterminals: if nt.tag != "nt": raise ValueError("Unexpected tree structure in {} : node with tag other than nt".format(sent_id)) nt_id = nt.get("id") if not nt_id: raise ValueError("NT has no id in {}".format(sent_id)) nt_label = nt.get("cat") if not nt_label: raise ValueError("NT has no label in {}".format(sent_id)) children = [] for child in nt: if child.tag != "edge" and child.tag != "secedge": raise ValueError("NT has unexpected child in {} : {}".format(sent_id, child.tag)) if child.tag == "edge": child_id = child.get("idref") if not child_id: raise ValueError("Child is missing an id in {}".format(sent_id)) children.append(child_id) nodes[nt_id] = Node(nt_label, children) if root_id not in nodes: raise ValueError("Could not find root in nodes in {}".format(sent_id)) tree = process_nodes(root_id, words, nodes, set()) return tree, words def word_sequence_missing_words(tree): """ Check if the word sequence is missing words Some trees skip labels, such as (s (fcl (pron-pers s16817_1) (v-fin s16817_2) (prp s16817_3) (pp (prp s16817_5) (par (n s16817_6) (conj-c s16817_7) (n s16817_8))) (pu s16817_9))) but in these cases, the word is present in the original text and simply not attached to the tree """ words = 
tree.leaf_labels() indices = [int(w.split("_")[1]) for w in words] indices = sorted(indices) for idx, label in enumerate(indices): if label != idx + 1: return True return False WORD_TO_PHRASE = { "art": "advp", # "en smule" is the one time this happens. it is used as an advp elsewhere "adj": "adjp", "adv": "advp", "conj": "cp", "intj": "fcl", # not sure? seems to match "hold kæft" when it shows up "n": "np", "num": "np", # would prefer something like QP from PTB "pron": "np", # ?? "prop": "np", "prp": "pp", "v": "vp", } def split_underscores(tree): assert not tree.is_leaf(), "Should never reach a leaf in this code path" if tree.is_preterminal(): return tree children = tree.children new_children = [] for child in children: if child.is_preterminal(): if '_' not in child.children[0].label: new_children.append(child) continue if child.label.split("-")[0] not in WORD_TO_PHRASE: raise ValueError("SPLITTING {}".format(child)) pieces = [] for piece in child.children[0].label.split("_"): # This may not be accurate, but we already retag the treebank anyway if len(piece) == 0: raise ValueError("A word started or ended with _") pieces.append(Tree(child.label, Tree(piece))) new_children.append(Tree(WORD_TO_PHRASE[child.label.split("-")[0]], pieces)) else: new_children.append(split_underscores(child)) return Tree(tree.label, new_children) REMAP_LABELS = { "adj": "adjp", "adv": "advp", "intj": "fcl", "n": "np", "num": "np", # again, a dedicated number node would be better, but there are only a few "num" labeled "prp": "pp", } def has_weird_constituents(tree): """ Eliminate a few trees with weird labels Eliminate p? 
there are only 3 and they have varying structure underneath Also cl, since I have no idea how to label it and it only excludes 1 tree """ labels = Tree.get_unique_constituent_labels(tree) if "p" in labels or "cl" in labels: return True return False def convert_tiger_treebank(input_filename): sentences = read_xml_file(input_filename) unfixable = 0 dangling = 0 broken_links = 0 missing_words = 0 weird_constituents = 0 trees = [] with tsurgeon.Tsurgeon() as tsurgeon_processor: for sent_idx, sentence in enumerate(tqdm(sentences)): try: tree, words = process_tree(sentence) if not tree.all_leaves_are_preterminals(): dangling += 1 continue if word_sequence_missing_words(tree): missing_words += 1 continue tree = check_words(tree, tsurgeon_processor) if tree is None: unfixable += 1 continue if has_weird_constituents(tree): weird_constituents += 1 continue tree = replace_words(tree, words) tree = split_underscores(tree) tree = tree.remap_constituent_labels(REMAP_LABELS) trees.append(tree) except BrokenLinkError as e: # the get("id") would have failed as a different error type if missing, # so we can safely use it directly like this broken_links += 1 # print("Unable to process {} because of broken links: {}".format(sentence.getroot().get("id"), e)) print("Found {} trees with empty nodes".format(dangling)) print("Found {} trees with unattached words".format(missing_words)) print("Found {} trees with confusing constituent labels".format(weird_constituents)) print("Not able to rearrange {} nodes".format(unfixable)) print("Unable to handle {} trees because of broken links, eg names in another tree".format(broken_links)) print("Parsed {} trees from {}".format(len(trees), input_filename)) return trees def main(): treebank = convert_tiger_treebank("extern_data/constituency/danish/W0084/arboretum.tiger/arboretum.tiger") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/convert_cintil.py 
================================================ import xml.etree.ElementTree as ET from stanza.models.constituency import tree_reader from stanza.utils.datasets.constituency import utils def read_xml_file(input_filename): """ Convert the CINTIL xml file to id & test Returns a list of tuples: (id, text) """ with open(input_filename, encoding="utf-8") as fin: dataset = ET.parse(fin) dataset = dataset.getroot() corpus = dataset.find("{http://nlx.di.fc.ul.pt}corpus") if not corpus: raise ValueError("Unexpected dataset structure : no 'corpus'") trees = [] for sentence in corpus: if sentence.tag != "{http://nlx.di.fc.ul.pt}sentence": raise ValueError("Unexpected sentence tag: {}".format(sentence.tag)) id_node = None raw_node = None tree_nodde = None for node in sentence: if node.tag == '{http://nlx.di.fc.ul.pt}id': id_node = node elif node.tag == '{http://nlx.di.fc.ul.pt}raw': raw_node = node elif node.tag == '{http://nlx.di.fc.ul.pt}tree': tree_node = node else: raise ValueError("Unexpected tag in sentence {}: {}".format(sentence, node.tag)) if id_node is None or raw_node is None or tree_node is None: raise ValueError("Missing node in sentence {}".format(sentence)) tree_id = "".join(id_node.itertext()) tree_text = "".join(tree_node.itertext()) trees.append((tree_id, tree_text)) return trees def convert_cintil_treebank(input_filename, train_size=0.8, dev_size=0.1): """ dev_size is the size for splitting train & dev """ trees = read_xml_file(input_filename) synthetic_trees = [] natural_trees = [] for tree_id, tree_text in trees: if tree_text.find(" _") >= 0: raise ValueError("Unexpected underscore") tree_text = tree_text.replace("_)", ")") tree_text = tree_text.replace("(A (", "(A' (") # trees don't have ROOT, but we typically use a ROOT label at the top tree_text = "(ROOT %s)" % tree_text trees = tree_reader.read_trees(tree_text) if len(trees) != 1: raise ValueError("Unexpectedly found %d trees in %s" % (len(trees), tree_id)) tree = trees[0] if 
tree_id.startswith("aTSTS"): synthetic_trees.append(tree) elif tree_id.find("TSTS") >= 0: raise ValueError("Unexpected TSTS") else: natural_trees.append(tree) print("Read %d synthetic trees" % len(synthetic_trees)) print("Read %d natural trees" % len(natural_trees)) train_trees, dev_trees, test_trees = utils.split_treebank(natural_trees, train_size, dev_size) print("Split %d trees into %d train %d dev %d test" % (len(natural_trees), len(train_trees), len(dev_trees), len(test_trees))) train_trees = synthetic_trees + train_trees print("Total lengths %d train %d dev %d test" % (len(train_trees), len(dev_trees), len(test_trees))) return train_trees, dev_trees, test_trees def main(): treebank = convert_cintil_treebank("extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/convert_ctb.py ================================================ from enum import Enum import glob import os import re import xml.etree.ElementTree as ET from stanza.models.constituency import tree_reader from stanza.utils.datasets.constituency.utils import write_dataset from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() class Version(Enum): V51 = 1 V51b = 2 V90 = 3 def filenum_to_shard_51(filenum): if filenum >= 1 and filenum <= 815: return 0 if filenum >= 1001 and filenum <= 1136: return 0 if filenum >= 886 and filenum <= 931: return 1 if filenum >= 1148 and filenum <= 1151: return 1 if filenum >= 816 and filenum <= 885: return 2 if filenum >= 1137 and filenum <= 1147: return 2 raise ValueError("Unhandled filenum %d" % filenum) def filenum_to_shard_51_basic(filenum): if filenum >= 1 and filenum <= 270: return 0 if filenum >= 440 and filenum <= 1151: return 0 if filenum >= 301 and filenum <= 325: return 1 if filenum >= 271 and filenum <= 300: return 2 if filenum >= 400 and filenum <= 439: return None raise ValueError("Unhandled filenum %d" % filenum) def 
filenum_to_shard_90(filenum): if filenum >= 1 and filenum <= 40: return 2 if filenum >= 900 and filenum <= 931: return 2 if filenum in (1018, 1020, 1036, 1044, 1060, 1061, 1072, 1118, 1119, 1132, 1141, 1142, 1148): return 2 if filenum >= 2165 and filenum <= 2180: return 2 if filenum >= 2295 and filenum <= 2310: return 2 if filenum >= 2570 and filenum <= 2602: return 2 if filenum >= 2800 and filenum <= 2819: return 2 if filenum >= 3110 and filenum <= 3145: return 2 if filenum >= 41 and filenum <= 80: return 1 if filenum >= 1120 and filenum <= 1129: return 1 if filenum >= 2140 and filenum <= 2159: return 1 if filenum >= 2280 and filenum <= 2294: return 1 if filenum >= 2550 and filenum <= 2569: return 1 if filenum >= 2775 and filenum <= 2799: return 1 if filenum >= 3080 and filenum <= 3109: return 1 if filenum >= 81 and filenum <= 900: return 0 if filenum >= 1001 and filenum <= 1017: return 0 if filenum in (1019, 1130, 1131): return 0 if filenum >= 1021 and filenum <= 1035: return 0 if filenum >= 1037 and filenum <= 1043: return 0 if filenum >= 1045 and filenum <= 1059: return 0 if filenum >= 1062 and filenum <= 1071: return 0 if filenum >= 1073 and filenum <= 1117: return 0 if filenum >= 1133 and filenum <= 1140: return 0 if filenum >= 1143 and filenum <= 1147: return 0 if filenum >= 1149 and filenum <= 2139: return 0 if filenum >= 2160 and filenum <= 2164: return 0 if filenum >= 2181 and filenum <= 2279: return 0 if filenum >= 2311 and filenum <= 2549: return 0 if filenum >= 2603 and filenum <= 2774: return 0 if filenum >= 2820 and filenum <= 3079: return 0 if filenum >= 4000 and filenum <= 7017: return 0 def collect_trees_s(root): if root.tag == 'S': yield root.text, root.attrib['ID'] for child in root: for tree in collect_trees_s(child): yield tree def collect_trees_text(root): if root.tag == 'TEXT' and len(root.text.strip()) > 0: yield root.text, None if root.tag == 'TURN' and len(root.text.strip()) > 0: yield root.text, None for child in root: for tree in 
collect_trees_text(child): yield tree id_re = re.compile("") su_re = re.compile("<(su|msg) id=([0-9a-zA-Z_=]+)>") def convert_ctb(input_dir, output_dir, dataset_name, version): input_files = glob.glob(os.path.join(input_dir, "*")) # train, dev, test datasets = [[], [], []] sorted_filenames = [] for input_filename in input_files: base_filename = os.path.split(input_filename)[1] filenum = int(os.path.splitext(base_filename)[0].split("_")[1]) sorted_filenames.append((filenum, input_filename)) sorted_filenames.sort() for filenum, filename in tqdm(sorted_filenames): if version in (Version.V51, Version.V51b): with open(filename, errors='ignore', encoding="gb2312") as fin: text = fin.read() elif version is Version.V90: with open(filename, encoding="utf-8") as fin: text = fin.read() if text.find("") >= 0 and text.find("") < 0: text = text.replace("", "") if filenum in (4205, 4208, 4289): text = text.replace("<)", "<)").replace(">)", ">)") if filenum >= 4000 and filenum <= 4411: if text.find("= 0: text = text.replace("", "") elif text.find("", "") text = "\n%s\n" % text if filenum >= 5000 and filenum <= 5558 or filenum >= 6000 and filenum <= 6700 or filenum >= 7000 and filenum <= 7017: text = su_re.sub("", text) if filenum in (6066, 6453): text = text.replace("<", "<").replace(">", ">") text = "\n%s\n" % text else: raise ValueError("Unknown CTB version %s" % version) text = id_re.sub(r'', text) text = text.replace("&", "&") try: xml_root = ET.fromstring(text) except Exception as e: print(text[:1000]) raise RuntimeError("Cannot xml process %s" % filename) from e trees = [x for x in collect_trees_s(xml_root)] if version is Version.V90 and len(trees) == 0: trees = [x for x in collect_trees_text(xml_root)] if version in (Version.V51, Version.V51b): trees = [x[0] for x in trees if filenum != 414 or x[1] != "4366"] else: trees = [x[0] for x in trees] trees = "\n".join(trees) try: trees = tree_reader.read_trees(trees, use_tqdm=False) except ValueError as e: print(text[:300]) raise 
RuntimeError("Could not process the tree text in %s" % filename) trees = [t.prune_none().simplify_labels() for t in trees] assert len(trees) > 0, "No trees in %s" % filename if version is Version.V51: shard = filenum_to_shard_51(filenum) elif version is Version.V51b: shard = filenum_to_shard_51_basic(filenum) else: shard = filenum_to_shard_90(filenum) if shard is None: continue datasets[shard].extend(trees) write_dataset(datasets, output_dir, dataset_name) ================================================ FILE: stanza/utils/datasets/constituency/convert_icepahc.py ================================================ """ Currently this doesn't function The goal is simply to demonstrate how to use tsurgeon """ from stanza.models.constituency.tree_reader import read_trees, read_treebank from stanza.server import tsurgeon TREEBANK = """ ( (IP-MAT (NP-SBJ (PRO-N Það-það)) (BEPI er-vera) (ADVP (ADV eiginlega-eiginlega)) (ADJP (NEG ekki-ekki) (ADJ-N hægt-hægur)) (IP-INF (TO að-að) (VB lýsa-lýsa)) (NP-OB1 (N-D tilfinningu$-tilfinning) (D-D $nni-hinn)) (IP-INF (TO að-að) (VB fá-fá)) (IP-INF (TO að-að) (VB taka-taka)) (NP-OB1 (N-A þátt-þáttur)) (PP (P í-í) (NP (D-D þessu-þessi))) (, ,-,) (VBPI segir-segja) (NP-SBJ (NPR-N Sverrir-sverrir) (NPR-N Ingi-ingi)) (. .-.))) """ # Output of the first tsurgeon: #(ROOT # (IP-MAT # (NP-SBJ (PRO-N Það)) # (BEPI er) # (ADVP (ADV eiginlega)) # (ADJP (NEG ekki) (ADJ-N hægt)) # (IP-INF (TO að) (VB lýsa)) # (NP-OB1 (N-D tilfinningu$) (D-D $nni)) # (IP-INF (TO að) (VB fá)) # (IP-INF (TO að) (VB taka)) # (NP-OB1 (N-A þátt)) # (PP # (P í) # (NP (D-D þessu))) # (, ,) # (VBPI segir) # (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi)) # (. 
.))) # Output of the second operation #(ROOT # (IP-MAT # (NP-SBJ (PRO-N Það)) # (BEPI er) # (ADVP (ADV eiginlega)) # (ADJP (NEG ekki) (ADJ-N hægt)) # (IP-INF (TO að) (VB lýsa)) # (NP-OB1 (N-D tilfinningunni)) # (IP-INF (TO að) (VB fá)) # (IP-INF (TO að) (VB taka)) # (NP-OB1 (N-A þátt)) # (PP # (P í) # (NP (D-D þessu))) # (, ,) # (VBPI segir) # (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi)) # (. .))) treebank = read_trees(TREEBANK) with tsurgeon.Tsurgeon(classpath="$CLASSPATH") as tsurgeon_processor: form_tregex = "/^(.+)-.+$/#1%form=word !< __" form_tsurgeon = "relabel word /^.+$/%{form}/" noun_det_tregex = "/^N-/ < /^([^$]+)[$]$/#1%noun=noun $+ (/^D-/ < /^[$]([^$]+)$/#1%det=det)" noun_det_relabel = "relabel noun /^.+$/%{noun}%{det}/" noun_det_prune = "prune det" for tree in treebank: updated_tree = tsurgeon_processor.process(tree, (form_tregex, form_tsurgeon))[0] print("{:P}".format(updated_tree)) updated_tree = tsurgeon_processor.process(updated_tree, (noun_det_tregex, noun_det_relabel, noun_det_prune))[0] print("{:P}".format(updated_tree)) ================================================ FILE: stanza/utils/datasets/constituency/convert_it_turin.py ================================================ """ Converts Turin's constituency dataset Turin University put out a freely available constituency dataset in 2011. It is not as large as VIT or ISST, but it is free, which is nice. The 2011 parsing task combines trees from several sources: http://www.di.unito.it/~tutreeb/evalita-parsingtask-11.html There is another site for Turin treebanks: http://www.di.unito.it/~tutreeb/treebanks.html Weirdly, the most recent versions of the Evalita trees are not there. The most relevant parts are the ParTUT downloads. As of Sep. 
2021: http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/JRCAcquis_It.pen http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/UDHR_It.pen http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/CC_It.pen http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/FB_It.pen http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/WIT3_It.pen We can't simply cat all these files together as there are a bunch of asterisks as comments and the files may have some duplicates. For example, the JRCAcquis piece has many duplicates. Also, some don't pass validation for one reason or another. One oddity of these data files is that the MWT are denoted by doubling the token. The token is not split as would be expected, though. We try to use stanza's MWT tokenizer for IT to split the tokens, with some rules added by hand in BIWORD_SPLITS. Two are still unsplit, though... """ import glob import os import re import sys import stanza from stanza.models.constituency import parse_tree from stanza.models.constituency import tree_reader def load_without_asterisks(in_file, encoding='utf-8'): with open(in_file, encoding=encoding) as fin: new_lines = [x if x.find("********") < 0 else "\n" for x in fin.readlines()] if len(new_lines) > 0 and not new_lines[-1].endswith("\n"): new_lines[-1] = new_lines[-1] + "\n" return new_lines CONSTITUENT_SPLIT = re.compile("[-=#+0-9]") # JRCA is almost entirely duplicates # WIT3 follows a different annotation scheme FILES_TO_ELIMINATE = ["JRCAcquis_It.pen", "WIT3_It.pen"] # assuming this is a typo REMAP_NODES = { "Sbar" : "SBAR" } REMAP_WORDS = { "-LSB-": "[", "-RSB-": "]" } # these mostly seem to be mistakes # maybe Vbar and ADVbar should be converted to something else? 
NODES_TO_ELIMINATE = ["C", "PHRASP", "PRDT", "Vbar", "parte", "ADVbar"] UNKNOWN_SPLITS = set() # a map of splits that the tokenizer or MWT doesn't handle well BIWORD_SPLITS = { "offertogli": ("offerto", "gli"), "offertegli": ("offerte", "gli"), "formatasi": ("formata", "si"), "formatosi": ("formato", "si"), "multiplexarlo": ("multiplexar", "lo"), "esibirsi": ("esibir", "si"), "pagarne": ("pagar", "ne"), "recarsi": ("recar", "si"), "trarne": ("trar", "ne"), "esserci": ("esser", "ci"), "aprirne": ("aprir", "ne"), "farle": ("far", "le"), "disporne": ("dispor", "ne"), "andargli": ("andar", "gli"), "CONSIDERARSI": ("CONSIDERAR", "SI"), "conferitegli": ("conferite", "gli"), "formatasi": ("formata", "si"), "formatosi": ("formato", "si"), "Formatisi": ("Formati", "si"), "multiplexarlo": ("multiplexar", "lo"), "esibirsi": ("esibir", "si"), "pagarne": ("pagar", "ne"), "recarsi": ("recar", "si"), "trarne": ("trar", "ne"), "temerne": ("temer", "ne"), "esserci": ("esser", "ci"), "esservi": ("esser", "vi"), "restituirne": ("restituir", "ne"), "col": ("con", "il"), "cogli": ("con", "gli"), "dirgli": ("dir", "gli"), "opporgli": ("oppor", "gli"), "eccolo": ("ecco", "lo"), "Eccolo": ("Ecco", "lo"), "Eccole": ("Ecco", "le"), "farci": ("far", "ci"), "farli": ("far", "li"), "farne": ("far", "ne"), "farsi": ("far", "si"), "farvi": ("far", "vi"), "Connettiti": ("Connetti", "ti"), "APPLICARSI": ("APPLICAR", "SI"), # This is not always two words, but if it IS two words, # it gets split like this "assicurati": ("assicura", "ti"), "Fatti": ("Fai", "te"), "ai": ("a", "i"), "Ai": ("A", "i"), "AI": ("A", "I"), "al": ("a", "il"), "Al": ("A", "il"), "AL": ("A", "IL"), "coi": ("con", "i"), "colla": ("con", "la"), "colle": ("con", "le"), "dal": ("da", "il"), "Dal": ("Da", "il"), "DAL": ("DA", "IL"), "dei": ("di", "i"), "Dei": ("Di", "i"), "DEI": ("DI", "I"), "del": ("di", "il"), "Del": ("Di", "il"), "DEL": ("DI", "IL"), "nei": ("in", "i"), "NEI": ("IN", "I"), "nel": ("in", "il"), "Nel": ("In", 
"il"), "NEL": ("IN", "IL"), "pel": ("per", "il"), "sui": ("su", "i"), "Sui": ("Su", "i"), "sul": ("su", "il"), "Sul": ("Su", "il"), ",": (",", ","), ".": (".", "."), '"': ('"', '"'), '-': ('-', '-'), '-LRB-': ('-LRB-', '-LRB-'), "garantirne": ("garantir", "ne"), "aprirvi": ("aprir", "vi"), "esimersi": ("esimer", "si"), "opporsi": ("oppor", "si"), } CAP_BIWORD = re.compile("[A-Z]+_[A-Z]+") def split_mwe(tree, pipeline): words = list(tree.leaf_labels()) found = False for idx, word in enumerate(words[:-3]): if word == words[idx+1] and word == words[idx+2] and word == words[idx+3]: raise ValueError("Oh no, 4 consecutive words") for idx, word in enumerate(words[:-2]): if word == words[idx+1] and word == words[idx+2]: doc = pipeline(word) assert len(doc.sentences) == 1 if len(doc.sentences[0].words) != 3: raise RuntimeError("Word {} not tokenized into 3 parts... thought all 3 part words were handled!".format(word)) words[idx] = doc.sentences[0].words[0].text words[idx+1] = doc.sentences[0].words[1].text words[idx+2] = doc.sentences[0].words[2].text found = True for idx, word in enumerate(words[:-1]): if word == words[idx+1]: if word in BIWORD_SPLITS: first_word = BIWORD_SPLITS[word][0] second_word = BIWORD_SPLITS[word][1] elif CAP_BIWORD.match(word): first_word, second_word = word.split("_") else: doc = pipeline(word) assert len(doc.sentences) == 1 if len(doc.sentences[0].words) == 2: first_word = doc.sentences[0].words[0].text second_word = doc.sentences[0].words[1].text else: if word not in UNKNOWN_SPLITS: UNKNOWN_SPLITS.add(word) print("Could not figure out how to split {}\n {}\n {}".format(word, " ".join(words), tree)) continue words[idx] = first_word words[idx+1] = second_word found = True if found: tree = tree.replace_words(words) return tree def load_trees(filename, pipeline): # some of the files are in latin-1 encoding rather than utf-8 try: raw_text = load_without_asterisks(filename, "utf-8") except UnicodeDecodeError: raw_text = load_without_asterisks(filename, 
"latin-1") # also, some have messed up validation (it will be logged) # hence the broken_ok=True argument trees = tree_reader.read_trees("".join(raw_text), broken_ok=True) filtered_trees = [] for tree in trees: if tree.children[0].label is None: print("Skipping a broken tree (missing label) in {}: {}".format(filename, tree)) continue try: words = tuple(tree.leaf_labels()) except ValueError: print("Skipping a broken tree (missing preterminal) in {}: {}".format(filename, tree)) continue if any('www.facebook' in pt.label for pt in tree.preterminals()): print("Skipping a tree with a weird preterminal label in {}: {}".format(filename, tree)) continue tree = tree.prune_none().simplify_labels(CONSTITUENT_SPLIT) if len(tree.children) > 1: print("Found a tree with a non-unary root! {}: {}".format(filename, tree)) continue if tree.children[0].is_preterminal(): print("Found a tree with a single preterminal node! {}: {}".format(filename, tree)) continue # The expectation is that the retagging will handle this anyway for pt in tree.preterminals(): if not pt.label: pt.label = "UNK" print("Found a tree with a blank preterminal label. Setting it to UNK. 
{}: {}".format(filename, tree)) tree = tree.remap_constituent_labels(REMAP_NODES) tree = tree.remap_words(REMAP_WORDS) tree = split_mwe(tree, pipeline) if tree is None: continue constituents = set(parse_tree.Tree.get_unique_constituent_labels(tree)) for weird_label in NODES_TO_ELIMINATE: if weird_label in constituents: break else: weird_label = None if weird_label is not None: print("Skipping a tree with a weird label {} in {}: {}".format(weird_label, filename, tree)) continue filtered_trees.append(tree) return filtered_trees def save_trees(out_file, trees): print("Saving {} trees to {}".format(len(trees), out_file)) with open(out_file, "w", encoding="utf-8") as fout: for tree in trees: fout.write(str(tree)) fout.write("\n") def convert_it_turin(input_path, output_path): pipeline = stanza.Pipeline("it", processors="tokenize, mwt", tokenize_no_ssplit=True) os.makedirs(output_path, exist_ok=True) evalita_dir = os.path.join(input_path, "evalita") evalita_test = os.path.join(evalita_dir, "evalita11_TESTgold_CONPARSE.penn") it_test = os.path.join(output_path, "it_turin_test.mrg") test_trees = load_trees(evalita_test, pipeline) save_trees(it_test, test_trees) known_text = set() for tree in test_trees: words = tuple(tree.leaf_labels()) assert words not in known_text known_text.add(words) evalita_train = os.path.join(output_path, "it_turin_train.mrg") evalita_files = glob.glob(os.path.join(evalita_dir, "*2011*penn")) turin_files = glob.glob(os.path.join(input_path, "turin", "*pen")) filenames = evalita_files + turin_files filtered_trees = [] for filename in filenames: if os.path.split(filename)[1] in FILES_TO_ELIMINATE: continue trees = load_trees(filename, pipeline) file_trees = [] for tree in trees: words = tuple(tree.leaf_labels()) if words in known_text: print("Skipping a duplicate in {}: {}".format(filename, tree)) continue known_text.add(words) file_trees.append(tree) filtered_trees.append((filename, file_trees)) print("{} contains {} usable 
trees".format(evalita_test, len(test_trees))) print(" Unique constituents in {}: {}".format(evalita_test, parse_tree.Tree.get_unique_constituent_labels(test_trees))) train_trees = [] dev_trees = [] for filename, file_trees in filtered_trees: print("{} contains {} usable trees".format(filename, len(file_trees))) print(" Unique constituents in {}: {}".format(filename, parse_tree.Tree.get_unique_constituent_labels(file_trees))) for tree in file_trees: if len(train_trees) <= len(dev_trees) * 9: train_trees.append(tree) else: dev_trees.append(tree) it_train = os.path.join(output_path, "it_turin_train.mrg") save_trees(it_train, train_trees) it_dev = os.path.join(output_path, "it_turin_dev.mrg") save_trees(it_dev, dev_trees) def main(): input_path = sys.argv[1] output_path = sys.argv[2] convert_it_turin(input_path, output_path) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/convert_it_vit.py ================================================ """Converts the proprietary VIT dataset to a format suitable for stanza There are multiple corrections in the UD version of VIT, along with recommended splits for the MWT, along with recommended splits of the sentences into train/dev/test Accordingly, it is necessary to use the UD dataset as a reference Here is a sample line of the text file we use: #ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, 
spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]] Here you can already see multiple issues when parsing: - the first word is "negli", which is split into In_ADP gli_DET in the UD version - also the first word is capitalized in the UD version - comma looks like a tempting split target, but there is a ',' in this sentence punt-',' - not shown here is '-' which is different from the - used for denoting POS par-'-' Fortunately, -[ is always an open and ] is always a close As of April 2022, the UD version of the dataset has some minor edits which are necessary for the proper functioning of this script. Otherwise, the MWT won't align correctly, some typos won't be corrected, etc. These edits are released in UD 2.10 The data itself is available from ELRA: http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/ Internally at Stanford you can contact Chris Manning or John Bauer. The processing goes as follows: - read in UD and con trees some of the con trees have broken brackets and are discarded in other cases, abbreviations were turned into single tokens in UD - extract the MWT expansions of Italian contractions, such as Negli -> In gli - attempt to align the trees from the two datasets using ngrams some trees had the sentence splitting updated sentences which can't be matched are discarded - use CoreNLP tsurgeon to update tokens in the con trees based on the information in the UD dataset - split contractions - rearrange clitics which are occasionally non-projective - replace the words in the con tree with the dep tree's words this takes advantage of spelling & capitalization fixes In 2022, there was an update to the dataset from Prof. Delmonte. This update is hopefully in current ELRA distributions now. 
If not, please contact ELRA to specifically ask for the updated version. Internally to Stanford, feel free to ask Chris or John for the updates. Look for the line below "original version with more errors" In August 2022, Prof. Delmonte made a slight update in a zip file `john.zip`. If/when that gets updated to ELRA, we will update it here. Contact Chris or John for a copy if not updated yet, or go back in git history to get the older version of the code which works with the 2022 ELRA update. Later, in September 2022, there is yet another update, New version of VIT.zip Unzip the contents into a folder $CONSTITUENCY_BASE/italian/it_vit so there should be a file $CONSTITUENCY_BASE/italian/it_vit/VITwritten/VITconstsyntNumb There are a few other updates needed to improve the annotations, but all the nagging seemed to give Prof. Delmonte a headache, so at this point we include those fixes in this script instead. See the first few tsurgeon operations in update_mwts_and_special_cases """ from collections import defaultdict, deque, namedtuple import itertools import os import re import sys from tqdm import tqdm from stanza.models.constituency.tree_reader import read_trees, UnclosedTreeError, ExtraCloseTreeError from stanza.server import tsurgeon from stanza.utils.conll import CoNLL from stanza.utils.datasets.constituency.utils import SHARDS, write_dataset import stanza.utils.default_paths as default_paths def read_constituency_sentences(fin): """ Reads the lines from the constituency treebank and splits into ID, text No further processing is done on the trees yet """ sentences = [] for line in fin: line = line.strip() # WTF why doesn't strip() remove this line = line.replace(u'\ufeff', '') if not line: continue sent_id, sent_text = line.split(maxsplit=1) # we have seen a couple different versions of this sentence header # although one file is always consistent with itself, at least if not sent_id.startswith("#ID=sent") and not sent_id.startswith("ID#sent"): raise 
ValueError("Unexpected start of sentence: |{}|".format(sent_id)) if not sent_text: raise ValueError("Empty text for |{}|".format(sent_id)) sentences.append((sent_id, sent_text)) return sentences def read_constituency_file(filename): print("Reading raw constituencies from %s" % filename) with open(filename, encoding='utf-8') as fin: return read_constituency_sentences(fin) OPEN = "-[" CLOSE = "]" DATE_RE = re.compile("^([0-9]{1,2})[_]([0-9]{2})$") INTEGER_PERCENT_RE = re.compile(r"^((?:min|plus)?[0-9]{1,3})[%]$") DECIMAL_PERCENT_RE = re.compile(r"^((?:min|plus)?[0-9]{1,3})[/_]([0-9]{1,3})[%]$") RANGE_PERCENT_RE = re.compile(r"^([0-9]{1,2}[/_][0-9]{1,2})[/]([0-9]{1,2}[/_][0-9]{1,2})[%]$") DECIMAL_RE = re.compile(r"^([0-9])[_]([0-9])$") ProcessedTree = namedtuple('ProcessedTree', ['con_id', 'dep_id', 'tree']) def raw_tree(text): """ A sentence will look like this: #ID=sent_00001 fc-[f3-[sn-[art-le, n-infrastrutture, sc-[ccom-come, sn-[n-fattore, spd-[pd-di, sn-[n-competitività]]]]]], f3-[spd-[pd-di, sn-[mw-Angela, nh-Airoldi]]], punto-.] 
Non-preterminal nodes have tags, followed by the stuff under the node, -[ The node is closed by the ] """ pieces = [] open_pieces = text.split(OPEN) for open_idx, open_piece in enumerate(open_pieces): if open_idx > 0: pieces[-1] = pieces[-1] + OPEN open_piece = open_piece.strip() if not open_piece: raise ValueError("Unexpected empty node!") close_pieces = open_piece.split(CLOSE) for close_idx, close_piece in enumerate(close_pieces): if close_idx > 0: pieces.append(CLOSE) close_piece = close_piece.strip() if not close_piece: # this is okay - multiple closes at the end of a deep bracket continue word_pieces = close_piece.split(", ") pieces.extend([x.strip() for x in word_pieces if x.strip()]) # at this point, pieces is a list with: # tag-[ for opens # tag-word for words # ] for closes # this structure converts pretty well to reading using the tree reader PIECE_MAPPING = { "agn-/ter'": "(agn ter)", "cong-'&'": "(cong &)", "da_riempire-'...'": "(da_riempire ...)", "date-1992_1993": "(date 1992/1993)", "date-'31-12-95'": "(date 31-12-95)", "date-'novantaquattro-95'":"(date novantaquattro-95)", "date-'novantaquattro-95": "(date novantaquattro-95)", "date-'novantaquattro-novantacinque'": "(date novantaquattro-novantacinque)", "dirs-':'": "(dirs :)", "dirs-'\"'": "(dirs \")", "mw-'&'": "(mw &)", "mw-'Presunto'": "(mw Presunto)", "nh-'Alain-Gauze'": "(nh Alain-Gauze)", "np-'porto_Marghera'": "(np Porto) (np Marghera)", "np-'roma-l_aquila'": "(np Roma-L'Aquila)", "np-'L_Aquila-Villa_Vomano'": "(np L'Aquila) (np -) (np Villa) (np Vomano)", "npro-'Avanti_!'": "(npro Avanti_!)", "npro-'Viacom-Paramount'": "(npro Viacom-Paramount)", "npro-'Rhone-Poulenc'": "(npro Rhone-Poulenc)", "npro-'Itar-Tass'": "(npro Itar-Tass)", "par-(-)": "(par -)", "par-','": "(par ,)", "par-'<'": "(par <)", "par-'>'": "(par >)", "par-'-'": "(par -)", "par-'\"'": "(par \")", "par-'('": "(par -LRB-)", "par-')'": "(par -RRB-)", "par-'&&'": "(par &&)", "punt-','": "(punt ,)", "punt-'-'": "(punt -)", 
"punt-';'": "(punt ;)", "punto-':'": "(punto :)", "punto-';'": "(punto ;)", "puntint-'!'": "(puntint !)", "puntint-'?'": "(puntint !)", "num-'2plus2'": "(num 2+2)", "num-/bis'": "(num bis)", "num-/ter'": "(num ter)", "num-18_00/1_00": "(num 18:00/1:00)", "num-1/500_2/000": "(num 1.500-2.000)", "num-16_1": "(num 16,1)", "num-0_1": "(num 0,1)", "num-0_3": "(num 0,3)", "num-2_7": "(num 2,7)", "num-455_68": "(num 455/68)", "num-437_5": "(num 437,5)", "num-4708_82": "(num 4708,82)", "num-16EQ517_7": "(num 16EQ517/7)", "num-2=184_90": "(num 2=184/90)", "num-3EQ429_20": "(num 3eq429/20)", "num-'1990-EQU-100'": "(num 1990-EQU-100)", "num-'500-EQU-250'": "(num 500-EQU-250)", "num-0_39%minus": "(num 0,39) (num %%) (num -)", "num-1_88/76": "(num 1-88/76)", "num-'70/80'": "(num 70,80)", "num-'18/20'": "(num 18:20)", "num-295/mila'": "(num 295mila)", "num-'295/mila'": "(num 295mila)", "num-0/07%plus": "(num 0,07) (num %%) (num plus)", "num-0/69%minus": "(num 0,69) (num %%) (num minus)", "num-0_39%minus": "(num 0,39) (num %%) (num minus)", "num-9_11/16": "(num 9-11,16)", "num-2/184_90": "(num 2=184/90)", "num-3/429_20": "(num 3eq429/20)", # TODO: remove the following num conversions if possible # this would require editing either constituency or UD "num-1:28_124": "(num 1=8/1242)", "num-1:28_397": "(num 1=8/3972)", "num-1:28_947": "(num 1=8/9472)", "num-1:29_657": "(num 1=9/6572)", "num-1:29_867": "(num 1=9/8672)", "num-1:29_874": "(num 1=9/8742)", "num-1:30_083": "(num 1=0/0833)", "num-1:30_140": "(num 1=0/1403)", "num-1:30_354": "(num 1=0/3543)", "num-1:30_453": "(num 1=0/4533)", "num-1:30_946": "(num 1=0/9463)", "num-1:31_602": "(num 1=1/6023)", "num-1:31_842": "(num 1=1/8423)", "num-1:32_087": "(num 1=2/0873)", "num-1:32_259": "(num 1=2/2593)", "num-1:33_166": "(num 1=3/1663)", "num-1:34_154": "(num 1=4/1543)", "num-1:34_556": "(num 1=4/5563)", "num-1:35_323": "(num 1=5/3233)", "num-1:36_023": "(num 1=6/0233)", "num-1:36_076": "(num 1=6/0763)", "num-1:36_651": "(num 
1=6/6513)", "n-giga_flop/s": "(n giga_flop/s)", "sect-'g-1'": "(sect g-1)", "sect-'h-1'": "(sect h-1)", "sect-'h-2'": "(sect h-2)", "sect-'h-3'": "(sect h-3)", "abbr-'a-b-c'": "(abbr a-b-c)", "abbr-d_o_a_": "(abbr DOA)", "abbr-d_l_": "(abbr DL)", "abbr-i_s_e_f_": "(abbr ISEF)", "abbr-d_p_r_": "(abbr DPR)", "abbr-D_P_R_": "(abbr DPR)", "abbr-d_m_": "(abbr dm)", "abbr-T_U_": "(abbr TU)", "abbr-F_A_M_E_": "(abbr Fame)", "dots-'...'": "(dots ...)", } new_pieces = ["(ROOT "] for piece in pieces: if piece.endswith(OPEN): new_pieces.append("(" + piece[:-2]) elif piece == CLOSE: new_pieces.append(")") elif piece in PIECE_MAPPING: new_pieces.append(PIECE_MAPPING[piece]) else: # maxsplit=1 because of words like 1990-EQU-100 tag, word = piece.split("-", maxsplit=1) if word.find("'") >= 0 or word.find("(") >= 0 or word.find(")") >= 0: raise ValueError("Unhandled weird node: {}".format(piece)) if word.endswith("_"): word = word[:-1] + "'" date_match = DATE_RE.match(word) if date_match: # 10_30 special case sent_07072 # 16_30 special case sent_07098 # 21_15 special case sent_07099 and others word = date_match.group(1) + ":" + date_match.group(2) integer_percent = INTEGER_PERCENT_RE.match(word) if integer_percent: word = integer_percent.group(1) + "_%%" range_percent = RANGE_PERCENT_RE.match(word) if range_percent: word = range_percent.group(1) + "," + range_percent.group(2) + "_%%" percent = DECIMAL_PERCENT_RE.match(word) if percent: word = percent.group(1) + "," + percent.group(2) + "_%%" decimal = DECIMAL_RE.match(word) if decimal: word = decimal.group(1) + "," + decimal.group(2) # there are words which are multiple words mashed together # with _ for some reason # also, words which end in ' are replaced with _ # fortunately, no words seem to have both # splitting like this means the tags are likely wrong, # but the conparser needs to retag anyway, so it shouldn't matter word_pieces = word.split("_") for word_piece in word_pieces: new_pieces.append("(%s %s)" % (tag, 
word_piece)) new_pieces.append(")") text = " ".join(new_pieces) trees = read_trees(text) if len(trees) > 1: raise ValueError("Unexpected number of trees!") return trees[0] def extract_ngrams(sentence, process_func, ngram_len=4): leaf_words = [x for x in process_func(sentence)] leaf_words = ["l'" if x == "l" else x for x in leaf_words] if len(leaf_words) <= ngram_len: return [tuple(leaf_words)] its = [leaf_words[i:i+len(leaf_words)-ngram_len+1] for i in range(ngram_len)] return [words for words in itertools.zip_longest(*its)] def build_ngrams(sentences, process_func, id_func, ngram_len=4): """ Turn the list of processed trees into a bunch of ngrams The returned map is from tuple to set of ids The idea being that this map can be used to search for trees to match datasets """ ngram_map = defaultdict(set) for sentence in tqdm(sentences, postfix="Extracting ngrams"): sentence_id = id_func(sentence) ngrams = extract_ngrams(sentence, process_func, ngram_len) for ngram in ngrams: ngram_map[ngram].add(sentence_id) return ngram_map # just the tokens (maybe use words? 
depends on MWT in the con dataset) DEP_PROCESS_FUNC = lambda x: [t.text.lower() for t in x.tokens] # find the comment with "sent_id" in it, take just the id itself DEP_ID_FUNC = lambda x: [c for c in x.comments if c.startswith("# sent_id")][0].split()[-1] CON_PROCESS_FUNC = lambda x: [y.lower() for y in x.leaf_labels()] def match_ngrams(sentence_ngrams, ngram_map, debug=False): """ Check if there is a SINGLE matching sentence in the ngram_map for these ngrams If an ngram shows up in multiple sentences, that is okay, but we ignore that info If an ngram shows up in just one sentence, that is considered the match If a different ngram then shows up in a different sentence, that is a problem TODO: taking the intersection of all non-empty matches might be better """ if debug: print("NGRAMS FOR DEBUG SENTENCE:") potential_match = None unknown_ngram = 0 for ngram in sentence_ngrams: con_matches = ngram_map[ngram] if debug: print("{} matched {}".format(ngram, len(con_matches))) if len(con_matches) == 0: unknown_ngram += 1 continue if len(con_matches) > 1: continue # get the one & only element from the set con_match = next(iter(con_matches)) if debug: print(" {}".format(con_match)) if potential_match is None: potential_match = con_match elif potential_match != con_match: return None if unknown_ngram > len(sentence_ngrams) / 2: return None return potential_match def match_sentences(con_tree_map, con_vit_ngrams, dep_sentences, split_name, debug_sentence=None): """ Match ngrams in the dependency sentences to the constituency sentences Then, to make sure the constituency sentence wasn't split into two in the UD dataset, this checks the ngrams in the reverse direction Some examples of things which don't match: VIT-4769 Insegnanti non vedenti, insegnanti non autosufficienti con protesi agli arti inferiori. 
this is duplicated in the original dataset, so the matching algorithm can't possibly work VIT-4796 I posti istituiti con attività di sostegno dei docenti che ottengono il trasferimento su classi di concorso; the correct con match should be sent_04829 but the brackets on that tree are broken """ con_to_dep_matches = {} dep_ngram_map = build_ngrams(dep_sentences, DEP_PROCESS_FUNC, DEP_ID_FUNC) unmatched = 0 bad_match = 0 for sentence in dep_sentences: sentence_ngrams = extract_ngrams(sentence, DEP_PROCESS_FUNC) potential_match = match_ngrams(sentence_ngrams, con_vit_ngrams, debug_sentence is not None and DEP_ID_FUNC(sentence) == debug_sentence) if potential_match is None: if unmatched < 5: print("Could not match the following sentence: {} {}".format(DEP_ID_FUNC(sentence), sentence.text)) unmatched += 1 continue if potential_match not in con_tree_map: raise ValueError("wtf") con_ngrams = extract_ngrams(con_tree_map[potential_match], CON_PROCESS_FUNC) reverse_match = match_ngrams(con_ngrams, dep_ngram_map) if reverse_match is None: #print("Matched sentence {} to sentence {} but the reverse match failed".format(sentence.text, " ".join(con_tree_map[potential_match].leaf_labels()))) bad_match += 1 continue con_to_dep_matches[potential_match] = reverse_match print("Failed to match %d sentences and found %d spurious matches in the %s section" % (unmatched, bad_match, split_name)) return con_to_dep_matches EXCEPTIONS = ["gliene", "glielo", "gliela", "eccoci"] def get_mwt(*dep_datasets): """ Get the ADP/DET MWTs from the UD dataset This class of MWT are expanded in the UD but not the constituencies """ mwt_map = {} for dataset in dep_datasets: for sentence in dataset.sentences: for token in sentence.tokens: if len(token.words) == 1: continue # words such as "accorgermene" we just skip over # those are already expanded in the constituency dataset # TODO: the clitics are actually expanded weirdly, maybe need to compensate for that if token.words[0].upos in ('VERB', 'AUX') and 
all(word.upos == 'PRON' for word in token.words[1:]): continue if token.text.lower() in EXCEPTIONS: continue if len(token.words) != 2 or token.words[0].upos != 'ADP' or token.words[1].upos != 'DET': raise ValueError("Not sure how to handle this: {}".format(token)) expansion = (token.words[0].text, token.words[1].text) if token.text in mwt_map: if mwt_map[token.text] != expansion: raise ValueError("Inconsistent MWT: {} -> {} or {}".format(token.text, expansion, mwt_map[token.text])) continue #print("Expanding {} to {}".format(token.text, expansion)) mwt_map[token.text] = expansion return mwt_map def update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor): """ Replace MWT structures with their UD equivalents, along with some other minor tsurgeon based edits original_tree: the tree as read from VIT dep_sentence: the UD dependency dataset version of this sentence """ updated_tree = original_tree operations = [] # first, remove titles or testo from the start of a sentence con_words = updated_tree.leaf_labels() if con_words[0] == "Tit'": operations.append(["/^Tit'$/=prune !, __", "prune prune"]) elif con_words[0] == "TESTO": operations.append(["/^TESTO$/=prune !, __", "prune prune"]) elif con_words[0] == "testo": operations.append(["/^testo$/ !, __ . /^:$/=prune", "prune prune"]) operations.append(["/^testo$/=prune !, __", "prune prune"]) if len(con_words) >= 2 and con_words[-2] == '...' and con_words[-1] == '.': # the most recent VIT constituency has some sentence final . after a ... # the UD dataset has a more typical ... ending instead # these lines used to say "riempire" which was rather odd operations.append(["/^[.][.][.]$/ . 
/^[.]$/=prune", "prune prune"]) # a few constituent tags are simply errors which need to be fixed if original_tree.children[0].label == 'p': # 'p' shouldn't be at root operations.append(["_ROOT_ < p=p", "relabel p cp"]) # fix one specific tree if it has an s_top in it operations.append(["s_top=stop < (in=in < più=piu)", "replace piu (q più)", "relabel in sq", "relabel stop sa"]) # sect doesn't exist as a constituent. replace it with sa operations.append(["sect=sect < num", "relabel sect sa"]) # ppas as an internal node gets removed operations.append(["ppas=ppas < (__ < __)", "excise ppas ppas"]) # now assemble a bunch of regex to split and otherwise manipulate # the MWT in the trees for token in dep_sentence.tokens: if len(token.words) == 1: continue if token.text in mwt_map: mwt_pieces = mwt_map[token.text] if len(mwt_pieces) != 2: raise NotImplementedError("Expected exactly 2 pieces of mwt for %s" % token.text) # the MWT words in the UD version will have ' when needed, # but the corresponding ' is skipped in the con version of VIT, # hence the replace("'", "") # however, all' has the ' included, because this is a # constituent treebank, not a consistent treebank search_regex = "/^(?i:%s(?:')?)$/" % token.text.replace("'", "") # tags which seem to be relevant: # avvl|ccom|php|part|partd|partda tregex = "__ !> __ <<<%d (%s=child > (__=parent $+ sn=sn))" % (token.id[0], search_regex) tsurgeons = ["insert (art %s) >0 sn" % mwt_pieces[1], "relabel child %s" % mwt_pieces[0]] operations.append([tregex] + tsurgeons) tregex = "__ !> __ <<<%d (%s=child > (__=parent !$+ sn !$+ (art < %s)))" % (token.id[0], search_regex, mwt_pieces[1]) tsurgeons = ["insert (art %s) $- parent" % mwt_pieces[1], "relabel child %s" % mwt_pieces[0]] operations.append([tregex] + tsurgeons) elif len(token.words) == 2: #print("{} not in mwt_map".format(token.text)) # apparently some trees like sent_00381 and sent_05070 # have the clitic in a non-projective manner # [vcl-essersi, vppin-sparato, 
compt-[clitdat-si # intj-figurarsi, fs-[cosu-quando, f-[ibar-[clit-si # and before you ask, there are also clitics which are # simply not there at all, rather than always attached # in a non-projective manner tregex = "__=parent < (/^(?i:%s)$/=child . (__=np !< __ . (/^clit/=clit < %s)))" % (token.text, token.words[1].text) tsurgeon = "moveprune clit $- parent" operations.append([tregex, tsurgeon]) # there are also some trees which don't have clitics # for example, trees should look like this: # [ibar-[vsup-poteva, vcl-rivelarsi], compc-[clit-si, sn-[...]]] # however, at least one such example for rivelarsi instead # looks like this, with no corresponding clit # [... vcl-rivelarsi], compc-[sn-[in-ancora]] # note that is the actual tag, not just me being pissed off # breaking down the tregex: # the child is the original MWT, not split # !. clit verifies that it is not split (and stops the tsurgeon once fixed) # !$+ checks that the parent of the MWT is the last element under parent # note that !. can leave the immediate parent to touch the clit # neighbor will be the place the new clit will be sticking out tregex = "__=parent < (/^(?i:%s)$/=child !. /^clit/) !$+ __ > (__=gp $+ __=neighbor)" % token.text tsurgeon = "insert (clit %s) >0 neighbor" % token.words[1].text operations.append([tregex, tsurgeon]) # secondary option: while most trees are like the above, # with an outer bracket around the MWT and another verb, # some go straight into the next phrase # sent_05076 # sv5-[vcl-adeguandosi, compin-[sp-[part-alle, ... tregex = "__=parent < (/^(?i:%s)$/=child !. 
/^clit/) $+ __" % token.text tsurgeon = "insert (clit %s) $- parent" % token.words[1].text operations.append([tregex, tsurgeon]) else: pass if len(operations) > 0: updated_tree = tsurgeon_processor.process(updated_tree, *operations)[0] return updated_tree, operations def update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor): """ Update a tree using the mwt_map and tsurgeon to expand some MWTs Then replace the words in the con tree with the words in the dep tree """ ud_words = [x.text for x in dep_sentence.words] updated_tree, operations = update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor) # this checks number of words try: updated_tree = updated_tree.replace_words(ud_words) except ValueError as e: raise ValueError("Failed to process {} {}:\nORIGINAL TREE\n{}\nUPDATED TREE\n{}\nUPDATED LEAVES\n{}\nUD TEXT\n{}\nTsurgeons applied:\n{}\n".format(con_id, dep_id, original_tree, updated_tree, updated_tree.leaf_labels(), ud_words, "\n".join("{}".format(op) for op in operations))) from e return updated_tree # train set: # 858: missing close parens in the UD conversion # 1169: 'che', 'poi', 'tutti', 'i', 'Paesi', 'ue', '.' -> 'per', 'tutti', 'i', 'paesi', 'Ue', '.' 
# 2375: the problem is inconsistent treatment of s_p_a_ # 05052: the heuristic to fill in a missing "si" doesn't work because there's # already another "si" immediately after # # test set: # 09764: weird punct at end # 10058: weird punct at end IGNORE_IDS = ["sent_00867", "sent_01169", "sent_02375", "sent_05052", "sent_09764", "sent_10058"] def extract_updated_dataset(con_tree_map, dep_sentence_map, split_ids, mwt_map, tsurgeon_processor): """ Update constituency trees using the information in the dependency treebank """ trees = [] for con_id, dep_id in tqdm(split_ids.items()): # skip a few trees which have non-MWT word modifications if con_id in IGNORE_IDS: continue original_tree = con_tree_map[con_id] dep_sentence = dep_sentence_map[dep_id] updated_tree = update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor) trees.append(ProcessedTree(con_id, dep_id, updated_tree)) return trees def read_updated_trees(paths, debug_sentence=None): # original version with more errors #con_filename = os.path.join(con_directory, "2011-12-20", "Archive", "VIT_newconstsynt.txt") # this is the April 2022 version #con_filename = os.path.join(con_directory, "VIT_newconstsynt.txt") # the most recent update from ELRA may look like this? 
# it's what we got, at least # con_filename = os.path.join(con_directory, "italian", "VITwritten", "VITconstsyntNumb") # needs at least UD 2.11 or this will not work con_directory = paths["CONSTITUENCY_BASE"] ud_directory = os.path.join(paths["UDBASE"], "UD_Italian-VIT") con_filename = os.path.join(con_directory, "italian", "it_vit", "VITwritten", "VITconstsyntNumb") ud_vit_train = os.path.join(ud_directory, "it_vit-ud-train.conllu") ud_vit_dev = os.path.join(ud_directory, "it_vit-ud-dev.conllu") ud_vit_test = os.path.join(ud_directory, "it_vit-ud-test.conllu") print("Reading UD train/dev/test from %s" % ud_directory) ud_train_data = CoNLL.conll2doc(input_file=ud_vit_train) ud_dev_data = CoNLL.conll2doc(input_file=ud_vit_dev) ud_test_data = CoNLL.conll2doc(input_file=ud_vit_test) ud_vit_train_map = { DEP_ID_FUNC(x) : x for x in ud_train_data.sentences } ud_vit_dev_map = { DEP_ID_FUNC(x) : x for x in ud_dev_data.sentences } ud_vit_test_map = { DEP_ID_FUNC(x) : x for x in ud_test_data.sentences } print("Getting ADP/DET expansions from UD data") mwt_map = get_mwt(ud_train_data, ud_dev_data, ud_test_data) con_sentences = read_constituency_file(con_filename) num_discarded = 0 con_tree_map = {} for idx, sentence in enumerate(tqdm(con_sentences, postfix="Processing")): try: tree = raw_tree(sentence[1]) if sentence[0].startswith("#ID="): tree_id = sentence[0].split("=")[-1] else: tree_id = sentence[0].split("#")[-1] # don't care about the raw text? 
con_tree_map[tree_id] = tree except UnclosedTreeError as e: num_discarded = num_discarded + 1 print("Discarding {} because of reading error:\n {}: {}\n {}".format(sentence[0], type(e), e, sentence[1])) except ExtraCloseTreeError as e: num_discarded = num_discarded + 1 print("Discarding {} because of reading error:\n {}: {}\n {}".format(sentence[0], type(e), e, sentence[1])) except ValueError as e: print("Discarding {} because of reading error:\n {}: {}\n {}".format(sentence[0], type(e), e, sentence[1])) num_discarded = num_discarded + 1 #raise ValueError("Could not process line %d" % idx) from e print("Discarded %d trees. Have %d trees left" % (num_discarded, len(con_tree_map))) if num_discarded > 0: raise ValueError("Oops! We thought all of the VIT trees were properly bracketed now") con_vit_ngrams = build_ngrams(con_tree_map.items(), lambda x: CON_PROCESS_FUNC(x[1]), lambda x: x[0]) # TODO: match more sentences. some are probably missing because of MWT train_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_train_data.sentences, "train", debug_sentence) dev_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_dev_data.sentences, "dev", debug_sentence) test_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_test_data.sentences, "test", debug_sentence) print("Remaining total trees: %d" % (len(train_ids) + len(dev_ids) + len(test_ids))) print(" {} train {} dev {} test".format(len(train_ids), len(dev_ids), len(test_ids))) print("Updating trees with MWT and newer tokens from UD...") # the moveprune feature requires a new corenlp release after 4.4.0 with tsurgeon.Tsurgeon(classpath="$CLASSPATH") as tsurgeon_processor: train_trees = extract_updated_dataset(con_tree_map, ud_vit_train_map, train_ids, mwt_map, tsurgeon_processor) dev_trees = extract_updated_dataset(con_tree_map, ud_vit_dev_map, dev_ids, mwt_map, tsurgeon_processor) test_trees = extract_updated_dataset(con_tree_map, ud_vit_test_map, test_ids, mwt_map, tsurgeon_processor) return train_trees, 
dev_trees, test_trees def convert_it_vit(paths, dataset_name, debug_sentence=None): """ Read the trees, then write them out to the expected output_directory """ train_trees, dev_trees, test_trees = read_updated_trees(paths, debug_sentence) train_trees = [x.tree for x in train_trees] dev_trees = [x.tree for x in dev_trees] test_trees = [x.tree for x in test_trees] output_directory = paths["CONSTITUENCY_DATA_DIR"] write_dataset([train_trees, dev_trees, test_trees], output_directory, dataset_name) def main(): paths = default_paths.get_default_paths() dataset_name = "it_vit" debug_sentence = sys.argv[1] if len(sys.argv) > 1 else None convert_it_vit(paths, dataset_name, debug_sentence) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/convert_spmrl.py ================================================ import os from stanza.models.constituency.parse_tree import Tree from stanza.models.constituency.tree_reader import read_treebank from stanza.utils.default_paths import get_default_paths SHARDS = ("train", "dev", "test") def add_root(tree): if tree.label.startswith("NN"): tree = Tree("NP", tree) if tree.label.startswith("NE"): tree = Tree("PN", tree) elif tree.label.startswith("XY"): tree = Tree("VROOT", tree) return Tree("ROOT", tree) def convert_spmrl(input_directory, output_directory, short_name): for shard in SHARDS: tree_filename = os.path.join(input_directory, shard, shard + ".German.gold.ptb") trees = read_treebank(tree_filename, tree_callback=add_root) output_filename = os.path.join(output_directory, "%s_%s.mrg" % (short_name, shard)) with open(output_filename, "w", encoding="utf-8") as fout: for tree in trees: fout.write(str(tree)) fout.write("\n") print("Wrote %d trees to %s" % (len(trees), output_filename)) if __name__ == '__main__': paths = get_default_paths() output_directory = paths["CONSTITUENCY_DATA_DIR"] input_directory = 
"extern_data/constituency/spmrl/SPMRL_SHARED_2014/GERMAN_SPMRL/gold/ptb" convert_spmrl(input_directory, output_directory, "de_spmrl") ================================================ FILE: stanza/utils/datasets/constituency/convert_starlang.py ================================================ import os import re from tqdm import tqdm from stanza.models.constituency import parse_tree from stanza.models.constituency import tree_reader TURKISH_RE = re.compile(r"[{]turkish=([^}]+)[}]") DISALLOWED_LABELS = ('DT', 'DET', 's', 'vp', 'AFVP', 'CONJ', 'INTJ', '-XXX-') def read_tree(text): """ Reads in a tree, then extracts specifically the word from the specific format used Also converts LCB/RCB as needed """ trees = tree_reader.read_trees(text) if len(trees) > 1: raise ValueError("Tree file had two trees!") tree = trees[0] labels = tree.leaf_labels() new_labels = [] for label in labels: match = TURKISH_RE.search(label) if match is None: raise ValueError("Could not find word in |{}|".format(label)) word = match.group(1) word = word.replace("-LCB-", "{").replace("-RCB-", "}") new_labels.append(word) tree = tree.replace_words(new_labels) #tree = tree.remap_constituent_labels(LABEL_MAP) con_labels = tree.get_unique_constituent_labels([tree]) if any(label in DISALLOWED_LABELS for label in con_labels): raise ValueError("found an unexpected phrasal node {}".format(label)) return tree def read_files(filenames, conversion, log): trees = [] for filename in filenames: with open(filename, encoding="utf-8") as fin: text = fin.read() try: tree = conversion(text) if tree is not None: trees.append(tree) except ValueError as e: if log: print("-----------------\nFound an error in {}: {} Original text: {}".format(filename, e, text)) return trees def read_starlang(paths, conversion=read_tree, log=True): """ Read the starlang trees, converting them using the given method. read_tree or any other conversion turns one file at a time to a sentence. 
log is whether or not to log a ValueError - the NER division has many missing labels """ if isinstance(paths, str): paths = (paths,) train_files = [] dev_files = [] test_files = [] for path in paths: tree_files = [os.path.join(path, x) for x in os.listdir(path)] train_files.extend([x for x in tree_files if x.endswith(".train")]) dev_files.extend([x for x in tree_files if x.endswith(".dev")]) test_files.extend([x for x in tree_files if x.endswith(".test")]) print("Reading %d total files" % (len(train_files) + len(dev_files) + len(test_files))) train_treebank = read_files(tqdm(train_files), conversion=conversion, log=log) dev_treebank = read_files(tqdm(dev_files), conversion=conversion, log=log) test_treebank = read_files(tqdm(test_files), conversion=conversion, log=log) return train_treebank, dev_treebank, test_treebank def main(conversion=read_tree, log=True): paths = ["extern_data/constituency/turkish/TurkishAnnotatedTreeBank-15", "extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-15", "extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-20"] train_treebank, dev_treebank, test_treebank = read_starlang(paths, conversion=conversion, log=log) print("Train: %d" % len(train_treebank)) print("Dev: %d" % len(dev_treebank)) print("Test: %d" % len(test_treebank)) print(train_treebank[0]) return train_treebank, dev_treebank, test_treebank if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/count_common_words.py ================================================ import sys from collections import Counter from stanza.models.constituency import parse_tree from stanza.models.constituency import tree_reader word_counter = Counter() count_words = lambda x: word_counter.update(x.leaf_labels()) tree_reader.read_tree_file(sys.argv[1], tree_callback=count_words) print(word_counter.most_common()[:100]) ================================================ FILE: 
stanza/utils/datasets/constituency/extract_all_silver_dataset.py ================================================ """ After running build_silver_dataset.py, this extracts the trees of all match levels at once For example python stanza/utils/datasets/constituency/extract_all_silver_dataset.py --output_prefix /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_ --parsed_trees /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_wiki_a*trees cat /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_[012345678].mrg | sort | uniq | shuf > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg shuf /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg | head -n 200000 > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_200K.mrg """ import argparse from collections import defaultdict import json def parse_args(): parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy") parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.') parser.add_argument('--output_prefix', type=str, default=None, help='Prefix to use for outputting trees') parser.add_argument('--output_suffix', type=str, default=".mrg", help='Suffix to use for outputting trees') args = parser.parse_args() return args def main(): args = parse_args() trees = defaultdict(list) for filename in args.parsed_trees: with open(filename, encoding='utf-8') as fin: for line in fin.readlines(): tree = json.loads(line) trees[tree['count']].append(tree['tree']) for score, tree_list in trees.items(): filename = "%s%s%s" % (args.output_prefix, score, args.output_suffix) with open(filename, 'w', encoding='utf-8') as fout: for tree in tree_list: fout.write(tree) fout.write('\n') if __name__ == '__main__': main() ================================================ 
FILE: stanza/utils/datasets/constituency/extract_silver_dataset.py ================================================ """ After running build_silver_dataset.py, this extracts the trees of a certain match level For example python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score 0 --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg for i in `echo 0 1 2 3 4 5 6 7 8 9 10`; do python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score $i --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_$i.mrg; done """ import argparse import json def parse_args(): parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy") parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.') parser.add_argument('--keep_score', type=int, default=None, help='Which agreement level to keep. 
None keeps all') parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file') args = parser.parse_args() return args def main(): args = parse_args() trees = [] for filename in args.parsed_trees: with open(filename, encoding='utf-8') as fin: for line in fin.readlines(): tree = json.loads(line) if args.keep_score is None or tree['count'] == args.keep_score: tree = tree['tree'] trees.append(tree) if args.output_file is None: for tree in trees: print(tree) else: with open(args.output_file, 'w', encoding='utf-8') as fout: for tree in trees: fout.write(tree) fout.write('\n') if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/prepare_con_dataset.py ================================================ """Converts raw data files from their original format (dataset dependent) into PTB trees. The operation of this script depends heavily on the dataset in question. The common result is that the data files go to data/constituency and are in PTB format. da_arboretum Ekhard Bick Arboretum, a Hybrid Treebank for Danish https://www.researchgate.net/publication/251202293_Arboretum_a_Hybrid_Treebank_for_Danish Available here for a license fee: http://catalog.elra.info/en-us/repository/browse/ELRA-W0084/ Internal to Stanford, please contact Chris Manning and/or John Bauer The file processed is the tiger xml, although there are some edits needed in order to make it functional for our parser The treebank comes as a tar.gz file, W0084.tar.gz untar this file in $CONSTITUENCY_BASE/danish then move the extracted folder to "arboretum" $CONSTITUENCY_BASE/danish/W0084/... becomes $CONSTITUENCY_BASE/danish/arboretum/... 
en_ptb3-revised is an updated version of PTB with NML and other revisions. Put LDC2015T13 in $CONSTITUENCY_BASE/english; the directory name may look like LDC2015T13_eng_news_txt_tbnk-ptb_revised (a directory named simply LDC2015T13 is also accepted) python3 -m stanza.utils.datasets.constituency.prepare_con_dataset en_ptb3-revised Sections 02-21 become the train set, section 22 the dev set, and section 23 the test set. Besides concatenating those pieces, the conversion relabels the mislabeled ADJ-PRD constituents as ADJP-PRD and, in the training trees, moves sentence-final punctuation attached directly to ROOT into the preceding constituent. @article{ptb_revised, title= {Penn Treebank Revised: English News Text Treebank LDC2015T13}, journal= {}, author= {Ann Bies and Justin Mott and Colin Warner}, year= {2015}, url= {https://doi.org/10.35111/xpjy-at91}, doi= {10.35111/xpjy-at91}, isbn= {1-58563-724-6}, dcmi= {text}, languages= {english}, language= {english}, ldc= {LDC2015T13}, } id_icon ICON: Building a Large-Scale Benchmark Constituency Treebank for the Indonesian Language Ee Suan Lim, Wei Qi Leong, Ngan Thanh Nguyen, Dea Adhista, Wei Ming Kng, William Chandra Tjhi, Ayu Purwarianti https://aclanthology.org/2023.tlt-1.5.pdf Available at https://github.com/aisingapore/seacorenlp-data git clone the repo in $CONSTITUENCY_BASE/seacorenlp so there is now a directory $CONSTITUENCY_BASE/seacorenlp/seacorenlp-data python3 -m stanza.utils.datasets.constituency.prepare_con_dataset id_icon it_turin A combination of the 2011 Evalita competition data and the ParTUT trees. More information is available in convert_it_turin it_vit The original treebank from which the VIT UD dataset was derived. The UD version has a lot of corrections, so we try to apply those as much as possible. In fact, we applied some corrections of our own back to UD based on this treebank.
The first version of UD which had those corrections is UD 2.10; versions of UD before that won't work, and later versions are expected (but not guaranteed) to work. The conversion code itself notes that it needs at least UD 2.11 (or the git version of VIT), so UD 2.11 or later is the safer choice. Set UDBASE to a path such that $UDBASE/UD_Italian-VIT is the UD version. The constituency labels are generally not very understandable, unfortunately. Some documentation is available here: https://core.ac.uk/download/pdf/223148096.pdf https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.423.5538&rep=rep1&type=pdf Available from ELRA: http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/ ja_alt Asian Language Treebank produced a treebank for Japanese: Ye Kyaw Thu, Win Pa Pa, Masao Utiyama, Andrew Finch, Eiichiro Sumita Introducing the Asian Language Treebank http://www.lrec-conf.org/proceedings/lrec2016/pdf/435_Paper.pdf Download https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/Japanese-ALT-20210218.zip and unzip it in $CONSTITUENCY_BASE/japanese; this should create a directory $CONSTITUENCY_BASE/japanese/Japanese-ALT-20210218 In this directory, also download the following: https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-train.txt https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-dev.txt https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-test.txt The parses themselves are in two files of bracketed trees, Japanese-ALT-Draft.txt and Japanese-ALT-Reviewed.txt. The first word of each line in these files is a sentence ID such as SNT.80188.1 These IDs are looked up in the three URL-...
files, telling us whether the sentence belongs in train/dev/test python3 -m stanza.utils.datasets.constituency.prepare_con_dataset ja_alt pt_cintil CINTIL treebank for Portuguese, available at ELRA: https://catalogue.elra.info/en-us/repository/browse/ELRA-W0055/ It can also be obtained from here: https://hdl.handle.net/21.11129/0000-000B-D2FE-A Produced at U Lisbon António Branco; João Silva; Francisco Costa; Sérgio Castro CINTIL TreeBank Handbook: Design options for the representation of syntactic constituency Silva, João; António Branco; Sérgio Castro; Ruben Reis Out-of-the-Box Robust Parsing of Portuguese https://portulanclarin.net/repository/extradocs/CINTIL-Treebank.pdf http://www.di.fc.ul.pt/~ahb/pubs/2011bBrancoSilvaCostaEtAl.pdf If at Stanford, ask John Bauer or Chris Manning for the data Otherwise, purchase it from ELRA or find it elsewhere if possible Either way, unzip it in $CONSTITUENCY_BASE/portuguese to the CINTIL directory so for example, the final result might be extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml python3 -m stanza.utils.datasets.constituency.prepare_con_dataset pt_cintil tr_starlang A dataset in three parts from the Starlang group in Turkey: Neslihan Kara, Büşra Marşan, et al Creating A Syntactically Felicitous Constituency Treebank For Turkish https://ieeexplore.ieee.org/document/9259873 git clone the following three repos https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15 https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15 https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20 Put them in $CONSTITUENCY_BASE/turkish python3 -m stanza.utils.datasets.constituency.prepare_con_dataset tr_starlang vlsp09 is the 2009 constituency treebank: Nguyen Phuong Thai, Vu Xuan Luong, Nguyen Thi Minh Huyen, Nguyen Van Hiep, Le Hong Phuong Building a Large Syntactically-Annotated Corpus of Vietnamese Proceedings of The Third Linguistic Annotation Workshop In conjunction with ACL-IJCNLP 2009, Suntec City, Singapore, 2009 
This can be obtained by contacting vlsp.resources@gmail.com vi_vlsp22 is the 2022 constituency treebank from the VLSP bakeoff. There is an official test set as well; you may be able to obtain both of these by contacting vlsp.resources@gmail.com The training files (named file*) and the official test files (named private*) are expected in the same directory, $CONSTITUENCY_BASE/vietnamese/VLSP_2022 NGUYEN Thi Minh Huyen, HA My Linh, VU Xuan Luong, PHAN Thi Hue, LE Van Cuong, NGUYEN Thi Luong, NGO The Quyen VLSP 2022 Challenge: Vietnamese Constituency Parsing to appear in Journal of Computer Science and Cybernetics. vi_vlsp23 is the 2023 update to the constituency treebank from the VLSP bakeoff. The vi_vlsp22 code also handles the new dataset, converting it with the updated tagset. As of late 2024, the test set is available on request at vlsp.resources@gmail.com Organize the directory as follows: $CONSTITUENCY_BASE/vietnamese/VLSP_2023 $CONSTITUENCY_BASE/vietnamese/VLSP_2023/Trainingdataset (training files named file*) $CONSTITUENCY_BASE/vietnamese/VLSP_2023/test (test files ending in .csv) zh-hans_ctb-51 is the 5.1 version of CTB. Put LDC2005T01U01_ChineseTreebank5.1 in $CONSTITUENCY_BASE/chinese python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh-hans_ctb-51 @article{xue_xia_chiou_palmer_2005, title={The Penn Chinese TreeBank: Phrase structure annotation of a large corpus}, volume={11}, DOI={10.1017/S135132490400364X}, number={2}, journal={Natural Language Engineering}, publisher={Cambridge University Press}, author={XUE, NAIWEN and XIA, FEI and CHIOU, FU-DONG and PALMER, MARTA}, year={2005}, pages={207–238}} zh-hans_ctb-51b is the same dataset, but using a smaller dev/test set. In our experiments, this split is substantially easier. zh-hans_ctb-90 is the 9.0 version of CTB. Put LDC2016T13 in $CONSTITUENCY_BASE/chinese python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh-hans_ctb-90 The splits used are the ones from the file docs/ctb9.0-file-list.txt included in the CTB 9.0 release. SPMRL adds several treebanks https://www.spmrl.org/ https://www.spmrl.org/sancl-posters2014.html Currently only German (de_spmrl) is converted, the German version being a version of the Tiger Treebank python3 -m
stanza.utils.datasets.constituency.prepare_con_dataset de_spmrl en_mctb is a multidomain test set covering five domains other than newswire https://github.com/RingoS/multi-domain-parsing-analysis Challenges to Open-Domain Constituency Parsing @inproceedings{yang-etal-2022-challenges, title = "Challenges to Open-Domain Constituency Parsing", author = "Yang, Sen and Cui, Leyang and Ning, Ruoxi and Wu, Di and Zhang, Yue", booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", month = may, year = "2022", address = "Dublin, Ireland", publisher = "Association for Computational Linguistics", url = "https://aclanthology.org/2022.findings-acl.11", doi = "10.18653/v1/2022.findings-acl.11", pages = "112--127", } This conversion replaces the top bracket from top -> ROOT and puts an extra S bracket on any roots with more than one node. """ import argparse import os import random import sys import tempfile from tqdm import tqdm from stanza.models.constituency import parse_tree import stanza.utils.default_paths as default_paths from stanza.models.constituency import tree_reader from stanza.models.constituency.parse_tree import Tree from stanza.server import tsurgeon from stanza.utils.datasets.common import UnknownDatasetError from stanza.utils.datasets.constituency import utils from stanza.utils.datasets.constituency.convert_alt import convert_alt from stanza.utils.datasets.constituency.convert_arboretum import convert_tiger_treebank from stanza.utils.datasets.constituency.convert_cintil import convert_cintil_treebank import stanza.utils.datasets.constituency.convert_ctb as convert_ctb from stanza.utils.datasets.constituency.convert_it_turin import convert_it_turin from stanza.utils.datasets.constituency.convert_it_vit import convert_it_vit from stanza.utils.datasets.constituency.convert_spmrl import convert_spmrl from stanza.utils.datasets.constituency.convert_starlang import read_starlang from stanza.utils.datasets.constituency.utils import SHARDS, 
write_dataset import stanza.utils.datasets.constituency.vtb_convert as vtb_convert import stanza.utils.datasets.constituency.vtb_split as vtb_split def process_it_turin(paths, dataset_name, *args): """ Convert the it_turin dataset """ assert dataset_name == 'it_turin' input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "italian") output_dir = paths["CONSTITUENCY_DATA_DIR"] convert_it_turin(input_dir, output_dir) def process_it_vit(paths, dataset_name, *args): # needs at least UD 2.11 or this will not work # in the meantime, the git version of VIT will suffice assert dataset_name == 'it_vit' convert_it_vit(paths, dataset_name) def process_vlsp09(paths, dataset_name, *args): """ Processes the VLSP 2009 dataset, discarding or fixing trees when needed """ assert dataset_name == 'vi_vlsp09' vlsp_path = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", "VietTreebank_VLSP_SP73", "Kho ngu lieu 10000 cay cu phap") with tempfile.TemporaryDirectory() as tmp_output_path: vtb_convert.convert_dir(vlsp_path, tmp_output_path) vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name) def process_vlsp21(paths, dataset_name, *args): """ Processes the VLSP 2021 dataset, which is just a single file """ assert dataset_name == 'vi_vlsp21' vlsp_file = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", "VLSP_2021", "VTB_VLSP21_tree.txt") if not os.path.exists(vlsp_file): raise FileNotFoundError("Could not find the 2021 dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(vlsp_file, paths["CONSTITUENCY_BASE"])) with tempfile.TemporaryDirectory() as tmp_output_path: vtb_convert.convert_files([vlsp_file], tmp_output_path) # This produces a 0 length test set, just as a placeholder until the actual test set is released vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0.9, dev_size=0.1) _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], dataset_name) with open(test_file, 
"w"): # create an empty test file - currently we don't have actual test data for VLSP 21 pass def process_vlsp22(paths, dataset_name, *args): """ Processes the VLSP 2022 dataset, which is four separate files for some reason """ assert dataset_name == 'vi_vlsp22' or dataset_name == 'vi_vlsp23' if dataset_name == 'vi_vlsp22': default_subdir = 'VLSP_2022' default_make_test_split = False updated_tagset = False elif dataset_name == 'vi_vlsp23': default_subdir = os.path.join('VLSP_2023', 'Trainingdataset') default_make_test_split = False updated_tagset = True parser = argparse.ArgumentParser() parser.add_argument('--subdir', default=default_subdir, type=str, help='Where to find the data - allows for using previous versions, if needed') parser.add_argument('--no_convert_brackets', default=True, action='store_false', dest='convert_brackets', help="Don't convert the VLSP parens RKBT & LKBT to PTB parens") parser.add_argument('--n_splits', default=None, type=int, help='Split the data into this many pieces. Relevant as there is no set training/dev split, so this allows for N models on N different dev sets') parser.add_argument('--test_split', default=default_make_test_split, action='store_true', help='Split 1/10th of the data as a test split as well. Useful for experimental results. Less relevant since there is now an official test set') parser.add_argument('--no_test_split', dest='test_split', action='store_false', help='Split 1/10th of the data as a test split as well. Useful for experimental results. 
Less relevant since there is now an official test set') parser.add_argument('--seed', default=1234, type=int, help='Random seed to use when splitting') args = parser.parse_args(args=list(*args)) if os.path.exists(args.subdir): vlsp_dir = args.subdir else: vlsp_dir = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", args.subdir) if not os.path.exists(vlsp_dir): raise FileNotFoundError("Could not find the {} dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(dataset_name, vlsp_dir, paths["CONSTITUENCY_BASE"])) vlsp_files = os.listdir(vlsp_dir) vlsp_train_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith("file") and not x.endswith(".zip")] vlsp_train_files.sort() if dataset_name == 'vi_vlsp22': vlsp_test_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith("private") and not x.endswith(".zip")] elif dataset_name == 'vi_vlsp23': vlsp_test_dir = os.path.abspath(os.path.join(vlsp_dir, os.pardir, "test")) vlsp_test_files = os.listdir(vlsp_test_dir) vlsp_test_files = [os.path.join(vlsp_test_dir, x) for x in vlsp_test_files if x.endswith(".csv")] if len(vlsp_train_files) == 0: raise FileNotFoundError("No train files (files starting with 'file') found in {}".format(vlsp_dir)) if not args.test_split and len(vlsp_test_files) == 0: raise FileNotFoundError("No test files found in {}".format(vlsp_dir)) print("Loading training files from {}".format(vlsp_dir)) print("Procesing training files:\n {}".format("\n ".join(vlsp_train_files))) with tempfile.TemporaryDirectory() as train_output_path: vtb_convert.convert_files(vlsp_train_files, train_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset) # This produces a 0 length test set, just as a placeholder until the actual test set is released if args.n_splits: test_size = 0.1 if args.test_split else 0.0 dev_size = (1.0 - test_size) / args.n_splits train_size = 1.0 - test_size - dev_size for rotation in 
range(args.n_splits): # there is a shuffle inside the split routine, # so we need to reset the random seed each time random.seed(args.seed) rotation_name = "%s-%d-%d" % (dataset_name, rotation, args.n_splits) if args.test_split: rotation_name = rotation_name + "t" vtb_split.split_files(train_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=train_size, dev_size=dev_size, rotation=(rotation, args.n_splits)) else: test_size = 0.1 if args.test_split else 0.0 dev_size = 0.1 train_size = 1.0 - test_size - dev_size if args.test_split: dataset_name = dataset_name + "t" vtb_split.split_files(train_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=train_size, dev_size=dev_size) if not args.test_split: print("Procesing test files:\n {}".format("\n ".join(vlsp_test_files))) with tempfile.TemporaryDirectory() as test_output_path: vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset) if args.n_splits: for rotation in range(args.n_splits): rotation_name = "%s-%d-%d" % (dataset_name, rotation, args.n_splits) vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=0, dev_size=0) else: vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0, dev_size=0) if not args.test_split and not args.n_splits and dataset_name == 'vi_vlsp23': print("Procesing test files and keeping ids:\n {}".format("\n ".join(vlsp_test_files))) with tempfile.TemporaryDirectory() as test_output_path: vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset, write_ids=True) vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name + "-ids", train_size=0, dev_size=0) def process_arboretum(paths, dataset_name, *args): """ Processes the Danish dataset, 
Arboretum """ assert dataset_name == 'da_arboretum' arboretum_file = os.path.join(paths["CONSTITUENCY_BASE"], "danish", "arboretum", "arboretum.tiger", "arboretum.tiger") if not os.path.exists(arboretum_file): raise FileNotFoundError("Unable to find input file for Arboretum. Expected in {}".format(arboretum_file)) treebank = convert_tiger_treebank(arboretum_file) datasets = utils.split_treebank(treebank, 0.8, 0.1) output_dir = paths["CONSTITUENCY_DATA_DIR"] output_filename = os.path.join(output_dir, "%s.mrg" % dataset_name) print("Writing {} trees to {}".format(len(treebank), output_filename)) parse_tree.Tree.write_treebank(treebank, output_filename) write_dataset(datasets, output_dir, dataset_name) def process_starlang(paths, dataset_name, *args): """ Convert the Turkish Starlang dataset to brackets """ assert dataset_name == 'tr_starlang' PIECES = ["TurkishAnnotatedTreeBank-15", "TurkishAnnotatedTreeBank2-15", "TurkishAnnotatedTreeBank2-20"] output_dir = paths["CONSTITUENCY_DATA_DIR"] chunk_paths = [os.path.join(paths["CONSTITUENCY_BASE"], "turkish", piece) for piece in PIECES] datasets = read_starlang(chunk_paths) write_dataset(datasets, output_dir, dataset_name) def process_ja_alt(paths, dataset_name, *args): """ Convert and split the ALT dataset TODO: could theoretically extend this to MY or any other similar dataset from ALT """ lang, source = dataset_name.split("_", 1) assert lang == 'ja' assert source == 'alt' PIECES = ["Japanese-ALT-Draft.txt", "Japanese-ALT-Reviewed.txt"] input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "japanese", "Japanese-ALT-20210218") input_files = [os.path.join(input_dir, input_file) for input_file in PIECES] split_files = [os.path.join(input_dir, "URL-%s.txt" % shard) for shard in SHARDS] output_dir = paths["CONSTITUENCY_DATA_DIR"] output_files = [os.path.join(output_dir, "%s_%s.mrg" % (dataset_name, shard)) for shard in SHARDS] convert_alt(input_files, split_files, output_files) def process_pt_cintil(paths, dataset_name, 
*args): """ Convert and split the PT Cintil dataset """ lang, source = dataset_name.split("_", 1) assert lang == 'pt' assert source == 'cintil' input_file = os.path.join(paths["CONSTITUENCY_BASE"], "portuguese", "CINTIL", "CINTIL-Treebank.xml") output_dir = paths["CONSTITUENCY_DATA_DIR"] datasets = convert_cintil_treebank(input_file) write_dataset(datasets, output_dir, dataset_name) def process_id_icon(paths, dataset_name, *args): lang, source = dataset_name.split("_", 1) assert lang == 'id' assert source == 'icon' input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "seacorenlp", "seacorenlp-data", "id", "constituency") input_files = [os.path.join(input_dir, x) for x in ("train.txt", "dev.txt", "test.txt")] datasets = [] for input_file in input_files: trees = tree_reader.read_tree_file(input_file) trees = [Tree("ROOT", tree) for tree in trees] datasets.append(trees) output_dir = paths["CONSTITUENCY_DATA_DIR"] write_dataset(datasets, output_dir, dataset_name) def process_ctb_51(paths, dataset_name, *args): lang, source = dataset_name.split("_", 1) assert lang == 'zh-hans' assert source == 'ctb-51' input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2005T01U01_ChineseTreebank5.1", "bracketed") output_dir = paths["CONSTITUENCY_DATA_DIR"] convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51) def process_ctb_51b(paths, dataset_name, *args): lang, source = dataset_name.split("_", 1) assert lang == 'zh-hans' assert source == 'ctb-51b' input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2005T01U01_ChineseTreebank5.1", "bracketed") output_dir = paths["CONSTITUENCY_DATA_DIR"] if not os.path.exists(input_dir): raise FileNotFoundError("CTB 5.1 location not found: %s" % input_dir) print("Loading trees from %s" % input_dir) convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51b) def process_ctb_90(paths, dataset_name, *args): lang, source = dataset_name.split("_", 1) assert lang == 
'zh-hans' assert source == 'ctb-90' input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2016T13", "ctb9.0", "data", "bracketed") output_dir = paths["CONSTITUENCY_DATA_DIR"] convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V90) def process_ptb3_revised(paths, dataset_name, *args): input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "english", "LDC2015T13_eng_news_txt_tbnk-ptb_revised") if not os.path.exists(input_dir): backup_input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "english", "LDC2015T13") if not os.path.exists(backup_input_dir): raise FileNotFoundError("Could not find ptb3-revised in either %s or %s" % (input_dir, backup_input_dir)) input_dir = backup_input_dir bracket_dir = os.path.join(input_dir, "data", "penntree") output_dir = paths["CONSTITUENCY_DATA_DIR"] # compensate for a weird mislabeling in the original dataset label_map = {"ADJ-PRD": "ADJP-PRD"} train_trees = [] for i in tqdm(range(2, 22)): new_trees = tree_reader.read_directory(os.path.join(bracket_dir, "%02d" % i)) new_trees = [t.remap_constituent_labels(label_map) for t in new_trees] train_trees.extend(new_trees) move_tregex = "_ROOT_ <1 __=home <2 /^[.]$/=move" move_tsurgeon = "move move >-1 home" print("Moving sentence final punctuation if necessary") with tsurgeon.Tsurgeon() as tsurgeon_processor: train_trees = [tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0] for tree in tqdm(train_trees)] dev_trees = tree_reader.read_directory(os.path.join(bracket_dir, "22")) dev_trees = [t.remap_constituent_labels(label_map) for t in dev_trees] test_trees = tree_reader.read_directory(os.path.join(bracket_dir, "23")) test_trees = [t.remap_constituent_labels(label_map) for t in test_trees] print("Read %d train trees, %d dev trees, and %d test trees" % (len(train_trees), len(dev_trees), len(test_trees))) datasets = [train_trees, dev_trees, test_trees] write_dataset(datasets, output_dir, dataset_name) def process_en_mctb(paths, 
dataset_name, *args): """ Converts the following blocks: dialogue.cleaned.txt forum.cleaned.txt law.cleaned.txt literature.cleaned.txt review.cleaned.txt """ base_path = os.path.join(paths["CONSTITUENCY_BASE"], "english", "multi-domain-parsing-analysis", "data", "MCTB_en") if not os.path.exists(base_path): raise FileNotFoundError("Please download multi-domain-parsing-analysis to %s" % base_path) def tree_callback(tree): if len(tree.children) > 1: tree = parse_tree.Tree("S", tree.children) return parse_tree.Tree("ROOT", [tree]) return parse_tree.Tree("ROOT", tree.children) filenames = ["dialogue.cleaned.txt", "forum.cleaned.txt", "law.cleaned.txt", "literature.cleaned.txt", "review.cleaned.txt"] for filename in filenames: trees = tree_reader.read_tree_file(os.path.join(base_path, filename), tree_callback=tree_callback) print("%d trees in %s" % (len(trees), filename)) output_filename = "%s-%s_test.mrg" % (dataset_name, filename.split(".")[0]) output_filename = os.path.join(paths["CONSTITUENCY_DATA_DIR"], output_filename) print("Writing trees to %s" % output_filename) parse_tree.Tree.write_treebank(trees, output_filename) def process_spmrl(paths, dataset_name, *args): if dataset_name != 'de_spmrl': raise ValueError("SPMRL dataset %s currently not supported" % dataset_name) output_directory = paths["CONSTITUENCY_DATA_DIR"] input_directory = os.path.join(paths["CONSTITUENCY_BASE"], "spmrl", "SPMRL_SHARED_2014", "GERMAN_SPMRL", "gold", "ptb") convert_spmrl(input_directory, output_directory, dataset_name) DATASET_MAPPING = { 'da_arboretum': process_arboretum, 'de_spmrl': process_spmrl, 'en_ptb3-revised': process_ptb3_revised, 'en_mctb': process_en_mctb, 'id_icon': process_id_icon, 'it_turin': process_it_turin, 'it_vit': process_it_vit, 'ja_alt': process_ja_alt, 'pt_cintil': process_pt_cintil, 'tr_starlang': process_starlang, 'vi_vlsp09': process_vlsp09, 'vi_vlsp21': process_vlsp21, 'vi_vlsp22': process_vlsp22, 'vi_vlsp23': process_vlsp22, # options allow for this 
'zh-hans_ctb-51': process_ctb_51, 'zh-hans_ctb-51b': process_ctb_51b, 'zh-hans_ctb-90': process_ctb_90, } def main(dataset_name, *args): paths = default_paths.get_default_paths() random.seed(1234) if dataset_name in DATASET_MAPPING: DATASET_MAPPING[dataset_name](paths, dataset_name, *args) else: raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_con_dataset") if __name__ == '__main__': if len(sys.argv) == 1: print("Known datasets:") for key in DATASET_MAPPING: print(" %s" % key) else: main(sys.argv[1], sys.argv[2:]) ================================================ FILE: stanza/utils/datasets/constituency/reduce_dataset.py ================================================ """ Cut short the training portion of a constituency dataset. One could think this script isn't necessary, as shuf | head would work, but some treebanks use multiple lines for representing trees. Thus it is necessary to actually intelligently read the trees. Run with python3 stanza/utils/datasets/constituency/reduce_dataset.py --input zh-hans_ctb-51b --output zh-hans_ctb5k """ import argparse import os import random from stanza.models.constituency import tree_reader import stanza.utils.default_paths as default_paths from stanza.utils.datasets.constituency.utils import SHARDS, write_dataset def main(): parser = argparse.ArgumentParser(description="Script that cuts a treebank down to size") parser.add_argument('--input', type=str, default=None, help='Input treebank') parser.add_argument('--output', type=str, default=None, help='Output treebank') parser.add_argument('--size', type=int, default=5000, help='How many trees') args = parser.parse_args() random.seed(1234) paths = default_paths.get_default_paths() output_directory = paths["CONSTITUENCY_DATA_DIR"] # data/constituency/en_ptb3_train.mrg input_filenames = [os.path.join(output_directory, "%s_%s.mrg" % (args.input, shard)) for shard in SHARDS] output_filenames = ["%s_%s.mrg" % (args.output, shard) for 
shard in SHARDS] shrink_datasets = [True, False, False] datasets = [] for input_filename, shrink in zip(input_filenames, shrink_datasets): treebank = tree_reader.read_treebank(input_filename) if shrink: random.shuffle(treebank) treebank = treebank[:args.size] datasets.append(treebank) write_dataset(datasets, output_directory, args.output) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/relabel_tags.py ================================================ """ Retag an S-expression tree with a new set of POS tags Also includes an option to write the new trees as bracket_labels (essentially, skipping the treebank_to_labeled_brackets step) """ import argparse import logging from stanza import Pipeline from stanza.models.constituency import retagging from stanza.models.constituency import tree_reader from stanza.models.constituency.utils import retag_trees logger = logging.getLogger('stanza') def parse_args(): parser = argparse.ArgumentParser(description="Script that retags a tree file") parser.add_argument('--lang', default='vi', type=str, help='Language') parser.add_argument('--input_file', default='data/constituency/vi_vlsp21_train.mrg', help='File to retag') parser.add_argument('--output_file', default='vi_vlsp21_train_retagged.mrg', help='Where to write the retagged trees') retagging.add_retag_args(parser) parser.add_argument('--bracket_labels', action='store_true', help='Write the trees as bracket labels instead of S-expressions') args = parser.parse_args() args = vars(args) retagging.postprocess_args(args) return args def main(): args = parse_args() retag_pipeline = retagging.build_retag_pipeline(args) train_trees = tree_reader.read_treebank(args['input_file']) logger.info("Retagging %d trees using %s", len(train_trees), args['retag_package']) train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) tree_format = "{:L}" if args['bracket_labels'] else "{}" with 
open(args['output_file'], "w") as fout: for tree in train_trees: fout.write(tree_format.format(tree)) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/selftrain.py ================================================ """ Common methods for the various self-training data collection scripts """ import logging import os import random import re import stanza from stanza.models.common import utils from stanza.models.common.bert_embedding import TextTooLongError from stanza.utils.get_tqdm import get_tqdm logger = logging.getLogger('stanza') tqdm = get_tqdm() def common_args(parser): parser.add_argument( '--output_file', default='data/constituency/vi_silver.mrg', help='Where to write the silver trees' ) parser.add_argument( '--lang', default='vi', help='Which language tools to use for tokenization and POS' ) parser.add_argument( '--num_sentences', type=int, default=-1, help='How many sentences to get per file (max)' ) parser.add_argument( '--models', default='saved_models/constituency/vi_vlsp21_inorder.pt', help='What models to use for parsing. comma-separated' ) parser.add_argument( '--package', default='default', help='Which package to load pretrain & charlm from for the parsers' ) parser.add_argument( '--output_ptb', default=False, action='store_true', help='Output trees in PTB brackets (default is a bracket language format)' ) def add_length_args(parser): parser.add_argument( '--min_len', default=5, type=int, help='Minimum length sentence to keep. None = unlimited' ) parser.add_argument( '--no_min_len', dest='min_len', action='store_const', const=None, help='No minimum length' ) parser.add_argument( '--max_len', default=100, type=int, help='Maximum length sentence to keep. 
None = unlimited' ) parser.add_argument( '--no_max_len', dest='max_len', action='store_const', const=None, help='No maximum length' ) def build_ssplit_pipe(ssplit, lang): if ssplit: return stanza.Pipeline(lang, processors="tokenize") else: return stanza.Pipeline(lang, processors="tokenize", tokenize_no_ssplit=True) def build_tag_pipe(ssplit, lang, foundation_cache=None): if ssplit: return stanza.Pipeline(lang, processors="tokenize,pos", foundation_cache=foundation_cache) else: return stanza.Pipeline(lang, processors="tokenize,pos", tokenize_no_ssplit=True, foundation_cache=foundation_cache) def build_parser_pipes(lang, models, package="default", foundation_cache=None): """ Build separate pipelines for each parser model we want to use It is highly recommended to pass in a FoundationCache to reuse bottom layers """ parser_pipes = [] for model_name in models.split(","): if os.path.exists(model_name): # if the model name exists as a file, treat it as the path to the model pipe = stanza.Pipeline(lang, processors="constituency", package=package, constituency_model_path=model_name, constituency_pretagged=True, foundation_cache=foundation_cache) else: # otherwise, assume it is a package name? pipe = stanza.Pipeline(lang, processors={"constituency": model_name}, constituency_pretagged=True, package=None, foundation_cache=foundation_cache) parser_pipes.append(pipe) return parser_pipes def split_docs(docs, ssplit_pipe, max_len=140, max_word_len=50, chunk_size=2000): """ Using the ssplit pipeline, break up the documents into sentences Filters out sentences which are too long or have words too long. 
This step is necessary because some web text has unstructured sentences which overwhelm the tagger, or even text with no whitespace which breaks the charlm in the tokenizer or tagger """ raw_sentences = 0 filtered_sentences = 0 new_docs = [] logger.info("Splitting raw docs into sentences: %d", len(docs)) for chunk_start in tqdm(range(0, len(docs), chunk_size)): chunk = docs[chunk_start:chunk_start+chunk_size] chunk = [stanza.Document([], text=t) for t in chunk] chunk = ssplit_pipe(chunk) sentences = [s for d in chunk for s in d.sentences] raw_sentences += len(sentences) sentences = [s for s in sentences if len(s.words) < max_len] sentences = [s for s in sentences if max(len(w.text) for w in s.words) < max_word_len] filtered_sentences += len(sentences) new_docs.extend([s.text for s in sentences]) logger.info("Split sentences: %d", raw_sentences) logger.info("Sentences filtered for length: %d", filtered_sentences) return new_docs # from https://stackoverflow.com/questions/2718196/find-all-chinese-text-in-a-string-using-python-and-regex ZH_RE = re.compile(u'[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]', re.UNICODE) # https://stackoverflow.com/questions/6787716/regular-expression-for-japanese-characters JA_RE = re.compile(u'[一-龠ぁ-ゔァ-ヴー々〆〤ヶ]', re.UNICODE) DEV_RE = re.compile(u'[\u0900-\u097f]', re.UNICODE) def tokenize_docs(docs, pipe, min_len, max_len): """ Turn the text in docs into a list of whitespace separated sentences docs: a list of strings pipe: a Stanza pipeline for tokenizing min_len, max_len: can be None to not filter by this attribute """ results = [] docs = [stanza.Document([], text=t) for t in docs] if len(docs) == 0: return results pipe(docs) is_zh = pipe.lang and pipe.lang.startswith("zh") is_ja = pipe.lang and pipe.lang.startswith("ja") is_vi = pipe.lang and pipe.lang.startswith("vi") for doc in docs: for sentence in doc.sentences: if min_len and len(sentence.words) < min_len: continue if max_len and len(sentence.words) > max_len: continue text = sentence.text 
if (text.find("|") >= 0 or text.find("_") >= 0 or text.find("<") >= 0 or text.find(">") >= 0 or text.find("[") >= 0 or text.find("]") >= 0 or text.find('—') >= 0): # an em dash, seems to be part of lists continue # the VI tokenizer in particular doesn't split these well if any(any(w.text.find(c) >= 0 and len(w.text) > 1 for w in sentence.words) for c in '"()'): continue text = [w.text.replace(" ", "_") for w in sentence.words] text = " ".join(text) if any(len(w.text) >= 50 for w in sentence.words): # skip sentences where some of the words are unreasonably long # could make this an argument continue if not is_zh and len(ZH_RE.findall(text)) > 250: # some Chinese sentences show up in VI Wikipedia # we want to eliminate ones which will choke the bert models continue if not is_ja and len(JA_RE.findall(text)) > 150: # some Japanese sentences also show up in VI Wikipedia # we want to eliminate ones which will choke the bert models continue if is_vi and len(DEV_RE.findall(text)) > 100: # would need some list of languages that use # Devanagari to eliminate sentences from all datasets. # Otherwise we might accidentally throw away all the # text from a language we need (although that would be obvious) continue results.append(text) return results def find_matching_trees(docs, num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=True, chunk_size=10, max_len=140, min_len=10, output_ptb=False): """ Find trees where all the parsers in parser_pipes agree docs should be a list of strings. one sentence per string or a whole block of text as long as the tag_pipe can break it into sentences num_sentences > 0 gives an upper limit on how many sentences to extract. 
If < 0, all possible sentences are extracted accepted_trees is a running tally of all the trees already built, so that we don't reuse the same sentence if we see it again """ if num_sentences < 0: tqdm_total = len(docs) else: tqdm_total = num_sentences if output_ptb: output_format = "{}" else: output_format = "{:L}" with tqdm(total=tqdm_total, leave=False) as pbar: if shuffle: random.shuffle(docs) new_trees = set() for chunk_start in range(0, len(docs), chunk_size): chunk = docs[chunk_start:chunk_start+chunk_size] chunk = [stanza.Document([], text=t) for t in chunk] if num_sentences < 0: pbar.update(len(chunk)) # first, retag the sentences tag_pipe(chunk) chunk = [d for d in chunk if len(d.sentences) > 0] if max_len is not None: # for now, we don't have a good way to deal with sentences longer than the bert maxlen chunk = [d for d in chunk if max(len(s.words) for s in d.sentences) < max_len] if len(chunk) == 0: continue parses = [] try: for pipe in parser_pipes: pipe(chunk) trees = [output_format.format(sent.constituency) for doc in chunk for sent in doc.sentences if len(sent.words) >= min_len] parses.append(trees) except TextTooLongError as e: # easiest is to skip this chunk - could theoretically save the other sentences continue for tree in zip(*parses): if len(set(tree)) != 1: continue tree = tree[0] if tree in accepted_trees: continue if tree not in new_trees: new_trees.add(tree) if num_sentences >= 0: pbar.update(1) if num_sentences >= 0 and len(new_trees) >= num_sentences: return new_trees return new_trees ================================================ FILE: stanza/utils/datasets/constituency/selftrain_it.py ================================================ """Builds a self-training dataset from an Italian data source and two models The idea is that the top down and the inorder parsers should make somewhat different errors, so hopefully the sum of an 86 f1 parser and an 85.5 f1 parser will produce some half-decent silver trees which can be used as 
self-training so that a new model can do better than either. One dataset used is PaCCSS, which has 63000 pairs of sentences: http://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/ PaCCSS-IT: A Parallel Corpus of Complex-Simple Sentences for Automatic Text Simplification Brunato, Dominique et al, 2016 https://aclanthology.org/D16-1034 Even larger is the IT section of Europarl, which has 1900000 lines https://www.statmt.org/europarl/ Europarl: A Parallel Corpus for Statistical Machine Translation Philipp Koehn https://homepages.inf.ed.ac.uk/pkoehn/publications/europarl-mtsummit05.pdf """ import argparse import logging import os import random import stanza from stanza.models.common.foundation_cache import FoundationCache from stanza.utils.datasets.constituency import selftrain from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza') def parse_args(): parser = argparse.ArgumentParser( description="Script that converts a public IT dataset to silver standard trees" ) selftrain.common_args(parser) parser.add_argument( '--input_dir', default='extern_data/italian', help='Path to the PaCCSS corpus and europarl corpus' ) parser.add_argument( '--no_europarl', default=True, action='store_false', dest='europarl', help='Use the europarl dataset. 
Turning this off makes the script a lot faster' ) parser.set_defaults(lang="it") parser.set_defaults(package="vit") parser.set_defaults(models="saved_models/constituency/it_best/it_vit_inorder_best.pt,saved_models/constituency/it_best/it_vit_topdown.pt") parser.set_defaults(output_file="data/constituency/it_silver.mrg") args = parser.parse_args() return args def get_paccss(input_dir): """ Read the paccss dataset, which is two sentences per line """ input_file = os.path.join(input_dir, "PaCCSS/data-set/PACCSS-IT.txt") with open(input_file) as fin: # the first line is a header line lines = fin.readlines()[1:] lines = [x.strip() for x in lines] lines = [x.split("\t")[:2] for x in lines if x] text = [y for x in lines for y in x] logger.info("Read %d sentences from %s", len(text), input_file) return text def get_europarl(input_dir, ssplit_pipe): """ Read the Europarl dataset This dataset needs to be tokenized and split into lines """ input_file = os.path.join(input_dir, "europarl/europarl-v7.it-en.it") with open(input_file) as fin: # the first line is a header line lines = fin.readlines()[1:] lines = [x.strip() for x in lines] lines = [x for x in lines if x] logger.info("Read %d docs from %s", len(lines), input_file) lines = selftrain.split_docs(lines, ssplit_pipe) return lines def main(): """ Combine the two datasets, parse them, and write out the results """ args = parse_args() foundation_cache = FoundationCache() ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang) tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang, foundation_cache=foundation_cache) parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, package=args.package, foundation_cache=foundation_cache) docs = get_paccss(args.input_dir) if args.europarl: docs.extend(get_europarl(args.input_dir, ssplit_pipe)) logger.info("Processing %d docs", len(docs)) new_trees = selftrain.find_matching_trees(docs, args.num_sentences, set(), tag_pipe, parser_pipes, shuffle=False, 
chunk_size=100, output_ptb=args.output_ptb) logger.info("Found %d unique trees which are the same between models" % len(new_trees)) with open(args.output_file, "w") as fout: for tree in sorted(new_trees): fout.write(tree) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/selftrain_single_file.py ================================================ """ Builds a self-training dataset from a single file. Default is to assume one document of text per line. If a line has multiple sentences, they will be split using the stanza tokenizer. """ import argparse import io import logging import os import numpy as np import stanza from stanza.utils.datasets.constituency import selftrain from stanza.utils.get_tqdm import get_tqdm logger = logging.getLogger('stanza') tqdm = get_tqdm() def parse_args(): """ Only specific argument for this script is the file to process """ parser = argparse.ArgumentParser( description="Script that converts a single file of text to silver standard trees" ) selftrain.common_args(parser) parser.add_argument( '--input_file', default="vi_part_1.aa", help='Path to the file to read' ) args = parser.parse_args() return args def read_file(input_file): """ Read lines from an input file Takes care to avoid encoding errors at the end of Oscar files. The Oscar splits sometimes break a utf-8 character in half. """ with open(input_file, "rb") as fin: text = fin.read() text = text.decode("utf-8", errors="replace") with io.StringIO(text) as fin: lines = fin.readlines() return lines def main(): args = parse_args() # TODO: make ssplit an argument ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang) tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang) parser_pipes = selftrain.build_parser_pipes(args.lang, args.models) # create a blank file. 
we will append to this file so that partial results can be used with open(args.output_file, "w") as fout: pass docs = read_file(args.input_file) logger.info("Read %d lines from %s", len(docs), args.input_file) docs = selftrain.split_docs(docs, ssplit_pipe) # breaking into chunks lets us output partial results and see the # progress in log files accepted_trees = set() if len(docs) > 10000: chunks = tqdm(np.array_split(docs, 100), disable=False) else: chunks = [docs] for chunk in chunks: new_trees = selftrain.find_matching_trees(chunk, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100) accepted_trees.update(new_trees) with open(args.output_file, "a") as fout: for tree in sorted(new_trees): fout.write(tree) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/selftrain_vi_quad.py ================================================ """ Processes the train section of VI QuAD into trees suitable for use in the conparser lm """ import argparse import json import logging import stanza from stanza.utils.datasets.constituency import selftrain logger = logging.getLogger('stanza') def parse_args(): parser = argparse.ArgumentParser( description="Script that converts vi quad to silver standard trees" ) selftrain.common_args(parser) selftrain.add_length_args(parser) parser.add_argument( '--input_file', default="extern_data/vietnamese/ViQuAD/train_ViQuAD.json", help='Path to the ViQuAD train file' ) parser.add_argument( '--tokenize_only', default=False, action='store_true', help='Tokenize instead of writing trees' ) args = parser.parse_args() return args def parse_quad(text): """ Read in a file from the VI quad dataset The train file has a specific format: the doc has a 'data' section each block in the data is a separate document (138 in the train file, for example) each block has a 'paragraphs' section each paragrah has 'qas' and 'context'. 
we care about the qas each piece of qas has 'question', which is what we actually want """ doc = json.loads(text) questions = [] for block in doc['data']: paragraphs = block['paragraphs'] for paragraph in paragraphs: qas = paragraph['qas'] for question in qas: questions.append(question['question']) return questions def read_quad(train_file): with open(train_file) as fin: text = fin.read() return parse_quad(text) def main(): """ Turn the train section of VI quad into a list of trees """ args = parse_args() docs = read_quad(args.input_file) logger.info("Read %d lines from %s", len(docs), args.input_file) if args.tokenize_only: pipe = stanza.Pipeline(args.lang, processors="tokenize") text = selftrain.tokenize_docs(docs, pipe, args.min_len, args.max_len) with open(args.output_file, "w", encoding="utf-8") as fout: for line in text: fout.write(line) fout.write("\n") else: tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang) parser_pipes = selftrain.build_parser_pipes(args.lang, args.models) # create a blank file. 
we will append to this file so that partial results can be used with open(args.output_file, "w") as fout: pass accepted_trees = set() new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100) new_trees = [tree for tree in new_trees if tree.find("(_SQ") >= 0] with open(args.output_file, "a") as fout: for tree in sorted(new_trees): fout.write(tree) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/selftrain_wiki.py ================================================ """Builds a self-training dataset from an Italian data source and two models The idea is that the top down and the inorder parsers should make somewhat different errors, so hopefully the sum of an 86 f1 parser and an 85.5 f1 parser will produce some half-decent silver trees which can be used as self-training so that a new model can do better than either. The dataset used is PaCCSS, which has 63000 pairs of sentences: http://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/ """ import argparse from collections import deque import glob import os import random from stanza.models.common.foundation_cache import FoundationCache from stanza.utils.datasets.constituency import selftrain from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() def parse_args(): parser = argparse.ArgumentParser( description="Script that converts part of a wikipedia dump to silver standard trees" ) selftrain.common_args(parser) parser.add_argument( '--input_dir', default='extern_data/vietnamese/wikipedia/text', help='Path to the wikipedia dump after processing by wikiextractor' ) parser.add_argument( '--no_shuffle', dest='shuffle', action='store_false', help="Don't shuffle files when processing the directory" ) parser.set_defaults(num_sentences=10000) args = parser.parse_args() return args def 
list_wikipedia_files(input_dir): """ Get a list of wiki files under the input_dir Recursively traverse the directory, then sort """ if not os.path.isdir(input_dir) and os.path.split(input_dir)[1].startswith("wiki_"): return [input_dir] wiki_files = [] recursive_files = deque() recursive_files.extend(glob.glob(os.path.join(input_dir, "*"))) while len(recursive_files) > 0: next_file = recursive_files.pop() if os.path.isdir(next_file): recursive_files.extend(glob.glob(os.path.join(next_file, "*"))) elif os.path.split(next_file)[1].startswith("wiki_"): wiki_files.append(next_file) wiki_files.sort() return wiki_files def read_wiki_file(filename): """ Read the text from a wiki file as a list of paragraphs. Each is its own item in the list. Lines are separated by \n\n to give hints to the stanza tokenizer. The first line after is skipped as it is usually the document title. """ with open(filename) as fin: lines = fin.readlines() docs = [] current_doc = [] line_iterator = iter(lines) line = next(line_iterator, None) while line is not None: if line.startswith(" 2: # a lot of very short documents are links to related documents # a single wikipedia can have tens of thousands of useless almost-duplicates docs.append("\n\n".join(current_doc)) current_doc = [] else: # not the start or end of a doc # hopefully this is valid text line = line.replace("()", " ") line = line.replace("( )", " ") line = line.strip() if line.find("<") >= 0 or line.find(">") >= 0: line = "" if line: current_doc.append(line) line = next(line_iterator, None) if current_doc: docs.append("\n\n".join(current_doc)) return docs def main(): args = parse_args() random.seed(1234) wiki_files = list_wikipedia_files(args.input_dir) if args.shuffle: random.shuffle(wiki_files) foundation_cache = FoundationCache() tag_pipe = selftrain.build_tag_pipe(ssplit=True, lang=args.lang, foundation_cache=foundation_cache) parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, foundation_cache=foundation_cache) # 
create a blank file. we will append to this file so that partial results can be used with open(args.output_file, "w") as fout: pass accepted_trees = set() for filename in tqdm(wiki_files, disable=False): docs = read_wiki_file(filename) new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=args.shuffle) accepted_trees.update(new_trees) with open(args.output_file, "a") as fout: for tree in sorted(new_trees): fout.write(tree) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/silver_variance.py ================================================ """ Use the concepts in "Dataset Cartography" and "Mind Your Outliers" to find trees with the least variance over a training run https://arxiv.org/pdf/2009.10795.pdf https://arxiv.org/abs/2107.02331 The idea here is that high variance trees are more likely to be wrong in the first place. Using this will filter a silver dataset to have better trees. 
for example: nlprun -d a6000 -p high "export CLASSPATH=/sailhome/horatio/CoreNLP/classes:/sailhome/horatio/CoreNLP/lib/*:$CLASSPATH; python3 stanza/utils/datasets/constituency/silver_variance.py --eval_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg saved_models/constituency/it_vit.top.each.silver0.constituency_0*0.pt --output_file filtered_silver0.mrg" -o filter.out """ import argparse import logging import numpy from stanza.models.common import utils from stanza.models.common.foundation_cache import FoundationCache from stanza.models.constituency import retagging from stanza.models.constituency import tree_reader from stanza.models.constituency.parser_training import run_dev_set from stanza.models.constituency.trainer import Trainer from stanza.models.constituency.utils import retag_trees from stanza.server.parser_eval import EvaluateParser from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza.constituency.trainer') def parse_args(args=None): parser = argparse.ArgumentParser(description="Script to filter trees by how much variance they show over multiple checkpoints of a parser training run.") parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--output_file', type=str, default=None, help='Output file after sorting by variance.') parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') utils.add_device_args(parser) # TODO: use the training scripts to pick the charlm & pretrain if needed parser.add_argument('--lang', default='it', help='Language to use') parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to 
batch when running eval') parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load") parser.add_argument('--keep', type=float, default=0.5, help="How many trees to keep after sorting by variance") parser.add_argument('--reverse', default=False, action='store_true', help='Actually, keep the high variance trees') retagging.add_retag_args(parser) args = vars(parser.parse_args()) retagging.postprocess_args(args) return args def main(): args = parse_args() retag_pipeline = retagging.build_retag_pipeline(args) foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() print("Analyzing with the following models:\n " + "\n ".join(args['models'])) treebank = tree_reader.read_treebank(args['eval_file']) logger.info("Read %d trees for analysis", len(treebank)) f1_history = [] retagged_treebank = None chunk_size = 5000 with EvaluateParser() as evaluator: for model_filename in args['models']: print("Starting processing with %s" % model_filename) trainer = Trainer.load(model_filename, args=args, foundation_cache=foundation_cache) if retag_pipeline is not None and retagged_treebank is None: retag_method = trainer.model.args['retag_method'] retag_xpos = trainer.model.args['retag_xpos'] logger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package']) retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos) logger.info("Retagging finished") current_history = [] for chunk_start in range(0, len(treebank), chunk_size): chunk = treebank[chunk_start:chunk_start+chunk_size] retagged_chunk = retagged_treebank[chunk_start:chunk_start+chunk_size] if retagged_treebank else None f1, kbestF1, treeF1 = run_dev_set(trainer.model, retagged_chunk, chunk, args, evaluator) current_history.extend(treeF1) f1_history.append(current_history) f1_history = numpy.array(f1_history) f1_variance = numpy.var(f1_history, axis=0) f1_sorted = sorted([(x, idx) for idx, x in 
enumerate(f1_variance)], reverse=args['reverse']) num_keep = int(len(f1_sorted) * args['keep']) with open(args['output_file'], "w", encoding="utf-8") as fout: for _, idx in f1_sorted[:num_keep]: fout.write(str(treebank[idx])) fout.write("\n") if __name__ == "__main__": main() ================================================ FILE: stanza/utils/datasets/constituency/split_holdout.py ================================================ """ Split a constituency dataset randomly into 90/10 splits TODO: add a function to rotate the pieces of the split so that each training instance gets seen once """ import argparse import os import random from stanza.models.constituency import tree_reader from stanza.utils.datasets.constituency.utils import copy_dev_test from stanza.utils.default_paths import get_default_paths def write_trees(base_path, dataset_name, trees): output_path = os.path.join(base_path, "%s_train.mrg" % dataset_name) with open(output_path, "w", encoding="utf-8") as fout: for tree in trees: fout.write("%s\n" % tree) def main(): parser = argparse.ArgumentParser(description="Split a standard dataset into 90/10 proportions of train so there is held out training data") parser.add_argument('--dataset', type=str, default="id_icon", help='dataset to split') parser.add_argument('--base_dataset', type=str, default=None, help='output name for base dataset') parser.add_argument('--holdout_dataset', type=str, default=None, help='output name for holdout dataset') parser.add_argument('--ratio', type=float, default=0.1, help='Number of trees to hold out') parser.add_argument('--seed', type=int, default=1234, help='Random seed') args = parser.parse_args() if args.base_dataset is None: args.base_dataset = args.dataset + "-base" print("--base_dataset not set, using %s" % args.base_dataset) if args.holdout_dataset is None: args.holdout_dataset = args.dataset + "-holdout" print("--holdout_dataset not set, using %s" % args.holdout_dataset) base_path = 
get_default_paths()["CONSTITUENCY_DATA_DIR"] copy_dev_test(base_path, args.dataset, args.base_dataset) copy_dev_test(base_path, args.dataset, args.holdout_dataset) train_file = os.path.join(base_path, "%s_train.mrg" % args.dataset) print("Reading %s" % train_file) trees = tree_reader.read_tree_file(train_file) base_train = [] holdout_train = [] random.seed(args.seed) for tree in trees: if random.random() < args.ratio: holdout_train.append(tree) else: base_train.append(tree) write_trees(base_path, args.base_dataset, base_train) write_trees(base_path, args.holdout_dataset, holdout_train) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/split_weighted_ensemble.py ================================================ """ Read in a dataset and split the train portion into pieces One chunk of the train will be the original dataset. Others will be a sampling from the original dataset of the same size, but sampled with replacement, with the goal being to get a random distribution of trees with some reweighting of the original trees. 
""" import argparse import os import random from stanza.models.constituency import tree_reader from stanza.models.constituency.parse_tree import Tree from stanza.utils.datasets.constituency.utils import copy_dev_test from stanza.utils.default_paths import get_default_paths def main(): parser = argparse.ArgumentParser(description="Split a standard dataset into 1 base section and N-1 random redraws of training data") parser.add_argument('--dataset', type=str, default="id_icon", help='dataset to split') parser.add_argument('--seed', type=int, default=1234, help='Random seed') parser.add_argument('--num_splits', type=int, default=5, help='Number of splits') args = parser.parse_args() random.seed(args.seed) base_path = get_default_paths()["CONSTITUENCY_DATA_DIR"] train_file = os.path.join(base_path, "%s_train.mrg" % args.dataset) print("Reading %s" % train_file) train_trees = tree_reader.read_tree_file(train_file) # For datasets with low numbers of certain constituents in the train set, # we could easily find ourselves in a situation where all of the trees # with a specific constituent have been randomly shuffled away from # a random shuffle # An example of this is there are 3 total trees with SQ in id_icon # Therefore, we have to take a little care to guarantee at least one tree # for each constituent type is in a random slice # TODO: this doesn't compensate for transition schemes with compound transitions, # such as in_order_compound. 
could do a similar boosting with one per transition type constituents = sorted(Tree.get_unique_constituent_labels(train_trees)) con_to_trees = {con: list() for con in constituents} for tree in train_trees: tree_cons = Tree.get_unique_constituent_labels(tree) for con in tree_cons: con_to_trees[con].append(tree) for con in constituents: print("%d trees with %s" % (len(con_to_trees[con]), con)) for i in range(args.num_splits): dataset_name = "%s-random-%d" % (args.dataset, i) copy_dev_test(base_path, args.dataset, dataset_name) if i == 0: train_dataset = train_trees else: train_dataset = [] for con in constituents: train_dataset.extend(random.choices(con_to_trees[con], k=2)) needed_trees = len(train_trees) - len(train_dataset) if needed_trees > 0: print("%d trees already chosen. Adding %d more" % (len(train_dataset), needed_trees)) train_dataset.extend(random.choices(train_trees, k=needed_trees)) output_filename = os.path.join(base_path, "%s_train.mrg" % dataset_name) print("Writing {} trees to {}".format(len(train_dataset), output_filename)) Tree.write_treebank(train_dataset, output_filename) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/tokenize_wiki.py ================================================ """ A short script to use a Stanza tokenizer to extract tokenized sentences from Wikipedia The first step is to convert a Wikipedia dataset using Prof. Attardi's wikiextractor: https://github.com/attardi/wikiextractor This script then writes out sentences, one per line, whitespace separated Some common issues with the tokenizer are accounted for by discarding those lines. Also, to account for languages such as VI where whitespace occurs within words, spaces are replaced with _ This should not cause any confusion, as any line with a natural _ in has already been discarded. 
for i in `echo A B C D E F G H I J K`; do nlprun "python3 stanza/utils/datasets/constituency/tokenize_wiki.py --output_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.txt --lang it --max_len 120 --input_dir /u/nlp/data/Wikipedia/itwiki/B$i --tokenizer_model saved_models/tokenize/it_combined_tokenizer.pt --download_method None" -o /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.out; done """ import argparse import logging import stanza from stanza.models.common.bert_embedding import load_tokenizer, filter_data from stanza.utils.datasets.constituency import selftrain_wiki from stanza.utils.datasets.constituency.selftrain import add_length_args, tokenize_docs from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() def parse_args(): parser = argparse.ArgumentParser( description="Script that converts part of a wikipedia dump to silver standard trees" ) parser.add_argument( '--output_file', default='vi_wiki_tokenized.txt', help='Where to write the tokenized lines' ) parser.add_argument( '--lang', default='vi', help='Which language tools to use for tokenization and POS' ) input_group = parser.add_mutually_exclusive_group(required=True) input_group.add_argument( '--input_dir', default=None, help='Path to the wikipedia dump after processing by wikiextractor' ) input_group.add_argument( '--input_file', default=None, help='Path to a single file of the wikipedia dump after processing by wikiextractor' ) parser.add_argument( '--bert_tokenizer', default=None, help='Which bert tokenizer (if any) to use to filter long sentences' ) parser.add_argument( '--tokenizer_model', default=None, help='Use this model instead of the current Stanza tokenizer for this language' ) parser.add_argument( '--download_method', default=None, help='Download pipeline models using this method (defaults to downloading updates from HF)' ) add_length_args(parser) args = parser.parse_args() return args def main(): args = 
parse_args() if args.input_dir is not None: files = selftrain_wiki.list_wikipedia_files(args.input_dir) elif args.input_file is not None: files = [args.input_file] else: raise ValueError("Need to specify at least one file or directory!") if args.bert_tokenizer: tokenizer = load_tokenizer(args.bert_tokenizer) print("Max model length: %d" % tokenizer.model_max_length) pipeline_args = {} if args.tokenizer_model: pipeline_args["tokenize_model_path"] = args.tokenizer_model if args.download_method: pipeline_args["download_method"] = args.download_method pipe = stanza.Pipeline(args.lang, processors="tokenize", **pipeline_args) with open(args.output_file, "w", encoding="utf-8") as fout: for filename in tqdm(files): docs = selftrain_wiki.read_wiki_file(filename) text = tokenize_docs(docs, pipe, args.min_len, args.max_len) if args.bert_tokenizer: filtered = filter_data(args.bert_tokenizer, [x.split() for x in text], tokenizer, logging.DEBUG) text = [" ".join(x) for x in filtered] for line in text: fout.write(line) fout.write("\n") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/treebank_to_labeled_brackets.py ================================================ """ Converts a PTB file to a format where all the brackets have labels on the start and end bracket. 
Such a file should be suitable for training an LM """ import argparse import logging import sys from stanza.models.constituency import tree_reader from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() logger = logging.getLogger('stanza.constituency') def main(): parser = argparse.ArgumentParser( description="Script that converts a PTB treebank into a labeled bracketed file suitable for LM training" ) parser.add_argument( 'ptb_file', help='Where to get the original PTB format treebank' ) parser.add_argument( 'label_file', help='Where to write the labeled bracketed file' ) parser.add_argument( '--separator', default="_", help='What separator to use in place of spaces', ) parser.add_argument( '--no_separator', dest='separator', action='store_const', const=None, help="Don't use a separator" ) args = parser.parse_args() treebank = tree_reader.read_treebank(args.ptb_file) logger.info("Writing %d trees to %s", len(treebank), args.label_file) tree_format = "{:%sL}\n" % args.separator if args.separator else "{:L}\n" with open(args.label_file, "w", encoding="utf-8") as fout: for tree in tqdm(treebank): fout.write(tree_format.format(tree)) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/utils.py ================================================ """ Utilities for the processing of constituency treebanks """ import os import shutil from stanza.models.constituency import parse_tree SHARDS = ("train", "dev", "test") def copy_dev_test(base_path, input_dataset, output_dataset): shutil.copy2(os.path.join(base_path, "%s_dev.mrg" % input_dataset), os.path.join(base_path, "%s_dev.mrg" % output_dataset)) shutil.copy2(os.path.join(base_path, "%s_test.mrg" % input_dataset), os.path.join(base_path, "%s_test.mrg" % output_dataset)) def write_dataset(datasets, output_dir, dataset_name): for dataset, shard in zip(datasets, SHARDS): output_filename = os.path.join(output_dir, "%s_%s.mrg" % (dataset_name, shard)) 
print("Writing {} trees to {}".format(len(dataset), output_filename)) parse_tree.Tree.write_treebank(dataset, output_filename) def split_treebank(treebank, train_size, dev_size): """ Split a treebank deterministically """ train_end = int(len(treebank) * train_size) dev_end = int(len(treebank) * (train_size + dev_size)) return treebank[:train_end], treebank[train_end:dev_end], treebank[dev_end:] ================================================ FILE: stanza/utils/datasets/constituency/vtb_convert.py ================================================ """ Script for processing the VTB files and turning their trees into the desired tree syntax The VTB original trees are stored in the directory: VietTreebank_VLSP_SP73/Kho ngu lieu 10000 cay cu phap The script requires two arguments: 1. Original directory storing the original trees 2. New directory storing the converted trees """ import argparse import os from collections import defaultdict from stanza.models.constituency.tree_reader import read_trees, MixedTreeError, UnlabeledTreeError REMAPPING = { '(ADV-MDP': '(RP-MDP', '(MPD': '(MDP', '(MP ': '(NP ', '(MP(': '(NP(', '(Np(': '(NP(', '(Np (': '(NP (', '(NLOC': '(NP-LOC', '(N-P-LOC': '(NP-LOC', '(N-p-loc': '(NP-LOC', '(NPDOB': '(NP-DOB', '(NPSUB': '(NP-SUB', '(NPTMP': '(NP-TMP', '(PPLOC': '(PP-LOC', '(SBA ': '(SBAR ', '(SBA-': '(SBAR-', '(SBA(': '(SBAR(', '(SBAS': '(SBAR', '(SABR': '(SBAR', '(SE-SPL': '(S-SPL', '(SBARR': '(SBAR', 'PPADV': 'PP-ADV', '(PR (': '(PP (', '(PPP': '(PP', 'VP0ADV': 'VP-ADV', '(S1': '(S', '(S2': '(S', '(S3': '(S', 'BP-SUB': 'NP-SUB', 'APPPD': 'AP-PPD', 'APPRD': 'AP-PPD', 'Np--H': 'Np-H', '(WPNP': '(WHNP', '(WHRPP': '(WHRP', # the one mistagged PV is on a prepositional phrase # (the subtree there maybe needs an SBAR as well, but who's counting) '(PV': '(PP', '(Mpd': '(MDP', # this only occurs on "bao giờ", "when" # that seems to be WHNP when under an SBAR, but WHRP otherwise '(Whadv ': '(WHRP ', # Whpr Occurs in two places: on "sao" in a context 
which is always WHRP, # and on "nào", which Vy says is more like a preposition '(Whpr (Pro-h nào))': '(WHPP (Pro-h nào))', '(Whpr (Pro-h Sao))': '(WHRP (Pro-h Sao))', # This is very clearly an NP: (Tp-tmp (N-h hiện nay)) # which is only ever in NP-TMP contexts '(Tp-tmp': '(NP-TMP', # This occurs once, in the context of (Yp (SYM @)) # The other times (SYM @) shows up, it's always NP '(Yp': '(NP', } def unify_label(tree): for old, new in REMAPPING.items(): tree = tree.replace(old, new) return tree def count_paren_parity(tree): """ Checks if the tree is properly closed :param tree: tree as a string :return: True if closed otherwise False """ count = 0 for char in tree: if char == '(': count += 1 elif char == ')': count -= 1 return count def is_valid_line(line): """ Check if a line being read is a valid constituent The idea is that some "trees" are just a long list of words with no tree structure and need to be eliminated. :param line: constituent being read :return: True if it has open OR closing parenthesis. 
""" if line.startswith('(') or line.endswith(')'): return True return False # not clear if TP is supposed to be NP or PP - needs a native speaker to decode WEIRD_LABELS = sorted(set(["WP", "YP", "SNP", "STC", "UPC", "(TP", "Xp", "XP", "WHVP", "WHPR", "NO", "WHADV", "(SC (", "(VOC (", "(Adv (", "(SP (", "ADV-MDP", "(SPL", "(ADV (", "(V-MWE ("] + list(REMAPPING.keys()))) # the 2023 dataset has TP and WHADV as actual labels # furthermore, trees with NO were cleaned up and one of the test trees has NORD as a word WEIRD_LABELS_2023 = sorted(set(["WP", "YP", "SNP", "STC", "UPC", "Xp", "XP", "WHVP", "WHPR", "(SC (", "(VOC (", "(Adv (", "(SP (", "ADV-MDP", "(SPL", "(ADV (", "(V-MWE ("] + list(REMAPPING.keys()))) def convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False, updated_tagset=False, write_ids=False): """ :param orig_file: original directory storing original trees :param new_file: new directory storing formatted constituency trees This function writes new trees to the corresponding files in new_file """ if updated_tagset: weird_labels = WEIRD_LABELS_2023 else: weird_labels = WEIRD_LABELS errors = defaultdict(list) with open(orig_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer: content = reader.readlines() # Tree string will only be written if the currently read # tree is a valid tree. 
It will not be written if it # does not have a '(' that signifies the presence of constituents tree = "" tree_id = None reading_tree = False for line_idx, line in enumerate(content): line = ' '.join(line.split()) if line == '': continue elif line == '' or line.startswith("") tree_id = int(tree_id[:-1]) elif line == '' and reading_tree: # one tree in 25432.prd is not valid because # it is just a bunch of blank lines if tree.strip() == '(ROOT': errors["empty"].append("Empty tree in {} line {}".format(orig_file, line_idx)) continue tree += ')\n' parity = count_paren_parity(tree) if parity > 0: errors["unclosed"].append("Unclosed tree from {} line {}: |{}|".format(orig_file, line_idx, tree)) continue if parity < 0: errors["extra_parens"].append("Extra parens at end of tree from {} line {} for having extra parens: {}".format(orig_file, line_idx, tree)) continue if convert_brackets: tree = tree.replace("RBKT", "-RRB-").replace("LBKT", "-LRB-") try: # test that the tree can be read in properly processed_trees = read_trees(tree) if len(processed_trees) > 1: errors["multiple"].append("Multiple trees in one xml annotation from {} line {}".format(orig_file, line_idx)) continue if len(processed_trees) == 0: errors["empty"].append("Empty tree in {} line {}".format(orig_file, line_idx)) continue if not processed_trees[0].all_leaves_are_preterminals(): errors["untagged_leaf"].append("Tree with non-preterminal leaves in {} line {}: {}".format(orig_file, line_idx, tree)) continue # Unify the labels if fix_errors: tree = unify_label(tree) # TODO: this block eliminates 3 trees from VLSP-22 # maybe those trees can be salvaged? 
bad_label = False for weird_label in weird_labels: if tree.find(weird_label) >= 0: bad_label = True errors[weird_label].append("Weird label {} from {} line {}: {}".format(weird_label, orig_file, line_idx, tree)) break if bad_label: continue if write_ids: if tree_id is None: errors["missing_id"].append("Missing ID from {} at line {}".format(orig_file, line_idx)) writer.write("") else: writer.write("\n" % tree_id) writer.write(tree) if write_ids: writer.write("\n") reading_tree = False tree = "" tree_id = None except MixedTreeError: errors["mixed"].append("Mixed leaves and constituents from {} line {}: {}".format(orig_file, line_idx, tree)) except UnlabeledTreeError: errors["unlabeled"].append("Unlabeled nodes in tree from {} line {}: {}".format(orig_file, line_idx, tree)) else: # content line if is_valid_line(line) and reading_tree: tree += line elif reading_tree: errors["invalid"].append("Invalid tree error in {} line {}: |{}|, rejected because of line |{}|".format(orig_file, line_idx, tree, line)) reading_tree = False return errors def convert_files(file_list, new_dir, verbose=False, fix_errors=True, convert_brackets=False, updated_tagset=False, write_ids=False): errors = defaultdict(list) for filename in file_list: base_name, _ = os.path.splitext(os.path.split(filename)[-1]) new_path = os.path.join(new_dir, base_name) new_file_path = f'{new_path}.mrg' # Convert the tree and write to new_file_path new_errors = convert_file(filename, new_file_path, fix_errors, convert_brackets, updated_tagset, write_ids) for e in new_errors: errors[e].extend(new_errors[e]) if len(errors.keys()) == 0: print("All errors were fixed!") else: print("Found the following errors:") keys = sorted(errors.keys()) if verbose: for e in keys: print("--------- %10s -------------" % e) print("\n\n".join(errors[e])) print() print() for e in keys: print("%s: %d" % (e, len(errors[e]))) def convert_dir(orig_dir, new_dir): file_list = os.listdir(orig_dir) # Only convert .prd files, skip the .raw files 
from VLSP 2009 file_list = [os.path.join(orig_dir, f) for f in file_list if os.path.splitext(f)[1] != '.raw'] convert_files(file_list, new_dir) def main(): """ Converts files from the 2009 version of VLSP to .mrg files Process args, loop through each file in the directory and convert to the desired tree format """ parser = argparse.ArgumentParser( description="Script that converts a VTB Tree into the desired format", ) parser.add_argument( 'orig_dir', help='The location of the original directory storing original trees ' ) parser.add_argument( 'new_dir', help='The location of new directory storing the new formatted trees' ) args = parser.parse_args() org_dir = args.org_dir new_dir = args.new_dir convert_dir(org_dir, new_dir) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/constituency/vtb_split.py ================================================ """ From a directory of files with VTB Trees, split into train/dev/test set with a split of 70/15/15 The script requires two arguments 1. org_dir: the original directory obtainable from running vtb_convert.py 2. 
split_dir: the directory where the train/dev/test splits will be stored """ import os import argparse import random def create_shuffle_list(org_dir): """ This function creates the random order with which we use to loop through the files :param org_dir: original directory storing the files that store the trees :return: list of file names randomly shuffled """ file_names = sorted(os.listdir(org_dir)) random.shuffle(file_names) return file_names def create_paths(split_dir, short_name): """ This function creates the necessary paths for the train/dev/test splits :param split_dir: directory that stores the splits :return: train path, dev path, test path """ if not short_name: short_name = "" elif not short_name.endswith("_"): short_name = short_name + "_" train_path = os.path.join(split_dir, '%strain.mrg' % short_name) dev_path = os.path.join(split_dir, '%sdev.mrg' % short_name) test_path = os.path.join(split_dir, '%stest.mrg' % short_name) return train_path, dev_path, test_path def get_num_samples(org_dir, file_names): """ Function for obtaining the number of samples :param org_dir: original directory storing the tree files :param file_names: list of file names in the directory :return: number of samples """ count = 0 # Loop through the files, which then loop through the trees for filename in file_names: # Skip files that are not .mrg if not filename.endswith('.mrg'): continue # File is .mrg. 
Start processing file_dir = os.path.join(org_dir, filename) with open(file_dir, 'r', encoding='utf-8') as reader: content = reader.readlines() for line in content: count += 1 return count def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None): os.makedirs(split_dir, exist_ok=True) if train_size + dev_size >= 1.0: print("Not making a test slice with the given ratios: train {} dev {}".format(train_size, dev_size)) # Create a random shuffle list of the file names in the original directory file_names = create_shuffle_list(org_dir) # Create train_path, dev_path, test_path train_path, dev_path, test_path = create_paths(split_dir, short_name) # Set up the number of samples for each train/dev/test set # TODO: if we ever wanted to split files with in them, # this particular code would need some updating to pay attention to the ids num_samples = get_num_samples(org_dir, file_names) print("Found {} total lines in {}".format(num_samples, org_dir)) stop_train = int(num_samples * train_size) if train_size + dev_size >= 1.0: stop_dev = num_samples output_limits = (stop_train, stop_dev) output_names = (train_path, dev_path) print("Splitting {} train, {} dev".format(stop_train, stop_dev - stop_train)) elif train_size + dev_size > 0.0: stop_dev = int(num_samples * (train_size + dev_size)) output_limits = (stop_train, stop_dev, num_samples) output_names = (train_path, dev_path, test_path) print("Splitting {} train, {} dev, {} test".format(stop_train, stop_dev - stop_train, num_samples - stop_dev)) else: stop_dev = 0 output_limits = (num_samples,) output_names = (test_path,) print("Copying all {} lines to test".format(num_samples)) # Count how much stuff we've written. 
# We will switch to the next output file when we're written enough count = 0 trees = [] for filename in file_names: if not filename.endswith('.mrg'): continue with open(os.path.join(org_dir, filename), encoding='utf-8') as reader: new_trees = reader.readlines() new_trees = [x.strip() for x in new_trees] new_trees = [x for x in new_trees if x] trees.extend(new_trees) # rotate the train & dev sections, leave the test section the same if rotation is not None and rotation[0] > 0: rotation_start = len(trees) * rotation[0] // rotation[1] rotation_end = stop_dev # if there are no test trees, rotation_end: will be empty anyway trees = trees[rotation_start:rotation_end] + trees[:rotation_start] + trees[rotation_end:] tree_iter = iter(trees) for write_path, count_limit in zip(output_names, output_limits): with open(write_path, 'w', encoding='utf-8') as writer: # Loop through the files, which then loop through the trees and write to write_path while count < count_limit: next_tree = next(tree_iter, None) if next_tree is None: raise RuntimeError("Ran out of trees before reading all of the expected trees") # Write to write_dir writer.write(next_tree) writer.write("\n") count += 1 def main(): """ Main function for the script Process args, loop through each tree in each file in the directory and write the trees to the train/dev/test split with a split of 70/15/15 """ parser = argparse.ArgumentParser( description="Script that splits a list of files of vtb trees into train/dev/test sets", ) parser.add_argument( 'org_dir', help='The location of the original directory storing correctly formatted vtb trees ' ) parser.add_argument( 'split_dir', help='The location of new directory storing the train/dev/test set' ) args = parser.parse_args() org_dir = args.org_dir split_dir = args.split_dir random.seed(1234) split_files(org_dir, split_dir) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/contract_mwt.py 
================================================ import sys def contract_mwt(infile, outfile, ignore_gapping=True): """ Simplify the gold tokenizer data for use as MWT processor test files The simplifications are to remove the expanded MWTs, and in the case of ignore_gapping=True, remove any copy words for the dependencies """ with open(outfile, 'w', encoding='utf-8') as fout: with open(infile, 'r', encoding='utf-8') as fin: idx = 0 mwt_begin = 0 mwt_end = -1 for line in fin: line = line.strip() if line.startswith('#'): print(line, file=fout) continue elif len(line) <= 0: print(line, file=fout) idx = 0 mwt_begin = 0 mwt_end = -1 continue line = line.split('\t') # ignore gapping word if ignore_gapping and '.' in line[0]: continue idx += 1 if '-' in line[0]: mwt_begin, mwt_end = [int(x) for x in line[0].split('-')] print("{}\t{}\t{}".format(idx, "\t".join(line[1:-1]), "MWT=Yes" if line[-1] == '_' else line[-1] + "|MWT=Yes"), file=fout) idx -= 1 elif mwt_begin <= idx <= mwt_end: continue else: print("{}\t{}".format(idx, "\t".join(line[1:])), file=fout) if __name__ == '__main__': contract_mwt(sys.argv[1], sys.argv[2]) ================================================ FILE: stanza/utils/datasets/coref/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/coref/balance_languages.py ================================================ """ balance_concat.py create a test set from a dev set which is language balanced """ import json from collections import defaultdict from random import Random # fix random seed for reproducability R = Random(42) with open("./corefud_concat_v1_0_langid.train.json", 'r') as df: raw = json.load(df) # calculate type of each class; then, we will select the one # which has the LOWEST counts as the sample rate lang_counts = defaultdict(int) for i in raw: lang_counts[i["lang"]] += 1 min_lang_count = min(lang_counts.values()) # sample 20% of the smallest amount for test 
set # this will look like an absurdly small number, but # remember this is DOCUMENTS not TOKENS or UTTERANCES # so its actually decent # also its per language test_set_size = int(0.1*min_lang_count) # sampling input by language raw_by_language = defaultdict(list) for i in raw: raw_by_language[i["lang"]].append(i) languages = list(set(raw_by_language.keys())) train_set = [] test_set = [] for i in languages: length = list(range(len(raw_by_language[i]))) choices = R.sample(length, test_set_size) for indx,i in enumerate(raw_by_language[i]): if indx in choices: test_set.append(i) else: train_set.append(i) with open("./corefud_concat_v1_0_langid-bal.train.json", 'w') as df: json.dump(train_set, df, indent=2) with open("./corefud_concat_v1_0_langid-bal.test.json", 'w') as df: json.dump(test_set, df, indent=2) # raw_by_language["en"] ================================================ FILE: stanza/utils/datasets/coref/convert_hebrew_iahlt.py ================================================ """Convert the coref annotation of IAHLT to the Stanza coref format This dataset is available at https://github.com/IAHLT/coref Download it via git clone to $COREF_BASE/hebrew, so for example on the cluster: cd /u/nlp/data/coref/ mkdir hebrew cd hebrew git clone git@github.com:IAHLT/coref.git Then run python3 stanza/utils/datasets/coref/convert_hebrew_iahlt.py The scores for models built from the dataset are pretty lousy in general, but seem to be in line with the scores obtained by other people working on this data. For example, the authors said they had a 52 F1, whereas if we use roberta-xlm, we get 50. """ import argparse from collections import defaultdict, namedtuple import json import os import stanza from stanza.utils.default_paths import get_default_paths from stanza.utils.get_tqdm import get_tqdm from stanza.utils.datasets.coref.utils import process_document tqdm = get_tqdm() CorefDoc = namedtuple("CorefDoc", ['doc_id', 'sentences', 'coref_spans']) # TODO: binary search for speed? 
def search_mention_start(doc, mention_start): for sent_idx, sentence in enumerate(doc.sentences): if mention_start < doc.sentences[sent_idx].tokens[-1].end_char: break else: raise ValueError for word_idx, word in enumerate(sentence.words): if word.end_char is None: print("Found weirdness on sentence:\n|%s|" % sentence.text) print(word.parent) return None, None if mention_start < word.end_char: break else: raise ValueError return sent_idx, word_idx def search_mention_end(doc, mention_end): for sent_idx, sentence in enumerate(doc.sentences): if sent_idx + 1 == len(doc.sentences) or mention_end < doc.sentences[sent_idx+1].tokens[0].start_char: break for word_idx, word in enumerate(sentence.words): if word_idx + 1 == len(sentence.words) or mention_end < sentence.words[word_idx+1].start_char: break return sent_idx, word_idx def extract_doc(tokenizer, lines): # 16, 1, 5 for the train, dev, test sets broken = 0 tok_error = 0 singletons = 0 one_words = 0 processed_docs = [] for line_idx, line in enumerate(tqdm(lines)): all_clusters = defaultdict(list) doc_id = line['metadata']['doc_id'] text = line['text'] clusters = line['clusters'] doc = tokenizer(text) for cluster_idx, cluster in enumerate(clusters): found_mentions = [] for mention_idx, mention in enumerate(cluster['mentions']): mention_start = mention[0] mention_end = mention[1] start_sent, start_word = search_mention_start(doc, mention_start) if start_sent is None or start_word is None: tok_error += 1 continue end_sent, end_word = search_mention_end(doc, mention_end) assert end_sent >= start_sent if start_sent != end_sent: broken += 1 continue assert end_word >= start_word if end_word == start_word: one_words += 1 found_mentions.append((start_sent, start_word, end_word)) #if cluster_idx == 0 and line_idx == 0: # expanded_start = max(0, mention_start - 10) # expanded_end = min(len(text), mention_end + 10) # print("EXTRACTING MENTION: %d %d" % (mention[0], mention[1])) # print(" context: |%s|" % 
text[expanded_start:expanded_end]) # print(" mention[0]:mention[1]: |%s|" % text[mention[0]:mention[1]]) # print(" search text: |%s|" % text[mention_start:mention_end]) # extracted_words = doc.sentences[start_sent].words[start_word:end_word+1] # extracted_text = " ".join([x.text for x in extracted_words]) # print(" extracted words: |%s|" % extracted_text) # print(" endpoints: %d %d" % (mention_start, mention_end)) # print(" number of extracted words: %d" % len(extracted_words)) # print(" first word endpoints: %d %d" % (extracted_words[0].start_char, extracted_words[0].end_char)) # print(" last word endpoints: %d %d" % (extracted_words[-1].start_char, extracted_words[-1].end_char)) if len(found_mentions) == 0: continue elif len(found_mentions) == 1: # the number of singletons, after discarding mentions that # crossed a sentence boundary according to Stanza, is # 5, 0, 1 # so clearly the dataset does not intentionally have # (many?) singletons in it singletons += 1 continue else: all_clusters[cluster_idx] = found_mentions # maybe we need to update the interface - there can be MWT in Hebrew sentences = [[word.text for word in sent.words] for sent in doc.sentences] coref_spans = defaultdict(list) for cluster_idx in all_clusters: for sent_idx, start_word, end_word in all_clusters[cluster_idx]: coref_spans[sent_idx].append((cluster_idx, start_word, end_word)) processed_docs.append(CorefDoc(doc_id, sentences, coref_spans)) print("Found %d broken across two sentences, %d tok errors, %d singleton mentions, %d one_word mentions" % (broken, tok_error, singletons, one_words)) return processed_docs def read_doc(tokenizer, filename): with open(filename, encoding="utf-8") as fin: lines = fin.readlines() lines = [json.loads(line) for line in lines] return extract_doc(tokenizer, lines) def write_json_file(output_filename, dataset): with open(output_filename, "w", encoding="utf-8") as fout: json.dump(dataset, fout, indent=2, ensure_ascii=False) def main(args=None): paths = 
get_default_paths() parser = argparse.ArgumentParser( prog='Convert Hebrew IAHLT data', ) parser.add_argument('--output_directory', default=None, type=str, help='Where to output the data (defaults to %s)' % paths['COREF_DATA_DIR']) args = parser.parse_args(args=args) coref_output_path = args.output_directory if args.output_directory else paths['COREF_DATA_DIR'] print("Will write IAHLT dataset to %s" % coref_output_path) coref_input_path = paths["COREF_BASE"] hebrew_base_path = os.path.join(coref_input_path, "hebrew", "coref", "train_val_test") tokenizer = stanza.Pipeline("he", processors="tokenize", package="default_accurate") pipe = stanza.Pipeline("he", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True) input_files = ["coref-5-heb_train.jsonl", "coref-5-heb_val.jsonl", "coref-5-heb_test.jsonl"] output_files = ["he_iahlt.train.json", "he_iahlt.dev.json", "he_iahlt.test.json"] for input_filename, output_filename in zip(input_files, output_files): input_filename = os.path.join(hebrew_base_path, input_filename) assert os.path.exists(input_filename) docs = read_doc(tokenizer, input_filename) dataset = [process_document(pipe, doc.doc_id, "", doc.sentences, doc.coref_spans, None, lang="he") for doc in tqdm(docs)] output_filename = os.path.join(coref_output_path, output_filename) write_json_file(output_filename, dataset) return output_files if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/convert_hebrew_mixed.py ================================================ """ Build a dataset mixed with IAHLT Hebrew and UD Coref We find that the IAHLT dataset by itself, trained using Stanza 1.11 with xlm-roberta-large and a lora finetuning layer, gets 49.7 F1. This is a bit lower than the value the IAHLT group originally had, as they reported 52. 
Interestingly, we find that mixing in the 1.3 UD Coref improves results, getting 51.7 under the same parameters This script runs the IAHLT conversion and the UD Coref conversion, then combines the files into one big training file """ import json import os import shutil import tempfile from stanza.utils.datasets.coref import convert_hebrew_iahlt from stanza.utils.datasets.coref import convert_udcoref from stanza.utils.default_paths import get_default_paths def main(): paths = get_default_paths() coref_output_path = paths['COREF_DATA_DIR'] with tempfile.TemporaryDirectory() as temp_dir_path: hebrew_filenames = convert_hebrew_iahlt.main(["--output_directory", temp_dir_path]) udcoref_filenames = convert_udcoref.main(["--project", "gerrom", "--output_directory", temp_dir_path]) with open(os.path.join(temp_dir_path, hebrew_filenames[0]), encoding="utf-8") as fin: hebrew_train = json.load(fin) udcoref_train_filename = os.path.join(temp_dir_path, udcoref_filenames[0]) with open(udcoref_train_filename, encoding="utf-8") as fin: print("Reading extra udcoref json data from %s" % udcoref_train_filename) udcoref_train = json.load(fin) mixed_train = hebrew_train + udcoref_train with open(os.path.join(coref_output_path, "he_mixed.train.json"), "w", encoding="utf-8") as fout: json.dump(mixed_train, fout, indent=2, ensure_ascii=False) shutil.copyfile(os.path.join(temp_dir_path, hebrew_filenames[1]), os.path.join(coref_output_path, "he_mixed.dev.json")) shutil.copyfile(os.path.join(temp_dir_path, hebrew_filenames[2]), os.path.join(coref_output_path, "he_mixed.test.json")) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/convert_hindi.py ================================================ import argparse import json from operator import itemgetter import os import stanza from stanza.utils.default_paths import get_default_paths from stanza.utils.get_tqdm import get_tqdm from stanza.utils.datasets.coref.utils import 
process_document tqdm = get_tqdm() def flatten_spans(coref_spans): """ Put span IDs on each span, then flatten them into a single list sorted by first word """ # put span indices on the spans # [[[38, 39], [42, 43], [41, 41], [180, 180], [300, 300]], [[60, 68], # --> # [[[0, 38, 39], [0, 42, 43], [0, 41, 41], [0, 180, 180], [0, 300, 300]], [[1, 60, 68], ... coref_spans = [[[span_idx, x, y] for x, y in span] for span_idx, span in enumerate(coref_spans)] # flatten list # --> # [[0, 38, 39], [0, 42, 43], [0, 41, 41], [0, 180, 180], [0, 300, 300], [1, 60, 68], ... coref_spans = [y for x in coref_spans for y in x] # sort by the first word index # --> # [[0, 38, 39], [0, 41, 41], [0, 42, 43], [1, 60, 68], [0, 180, 180], [0, 300, 300], ... coref_spans = sorted(coref_spans, key=itemgetter(1)) return coref_spans def remove_nulls(coref_spans, sentences): """ Removes the "" and "NULL" words from the sentences Also, reindex the spans by the number of words removed. So, we might get something like [[0, 2], [31, 33], [134, 136], [161, 162]] -> [[0, 2], [30, 32], [129, 131], [155, 156]] """ word_map = [] word_idx = 0 map_idx = 0 new_sentences = [] for sentence in sentences: new_sentence = [] for word in sentence: word_map.append(map_idx) word_idx += 1 if word != '' and word != 'NULL': new_sentence.append(word) map_idx += 1 new_sentences.append(new_sentence) new_spans = [] for mention in coref_spans: new_mention = [] for span in mention: span = [word_map[x] for x in span] new_mention.append(span) new_spans.append(new_mention) return new_spans, new_sentences def arrange_spans_by_sentence(coref_spans, sentences): sentence_spans = [] current_index = 0 span_idx = 0 for sentence in sentences: current_sentence_spans = [] end_index = current_index + len(sentence) while span_idx < len(coref_spans) and coref_spans[span_idx][1] < end_index: new_span = [coref_spans[span_idx][0], coref_spans[span_idx][1] - current_index, coref_spans[span_idx][2] - current_index] 
current_sentence_spans.append(new_span) span_idx += 1 sentence_spans.append(current_sentence_spans) current_index = end_index return sentence_spans def convert_dataset_section(pipe, section, use_cconj_heads): """ Reprocess the original data into a format compatible with previous conversion utilities - remove blank and NULL words - rearrange the spans into spans per sentence instead of a list of indices for each span - process the document using a Hindi pipeline """ processed_section = [] for idx, doc in enumerate(tqdm(section)): doc_id = doc['doc_key'] part_id = "" sentences = doc['sentences'] sentence_speakers = doc['speakers'] coref_spans = doc['clusters'] coref_spans, sentences = remove_nulls(coref_spans, sentences) coref_spans = flatten_spans(coref_spans) coref_spans = arrange_spans_by_sentence(coref_spans, sentences) processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=use_cconj_heads) processed_section.append(processed) return processed_section def remove_nulls_dataset_section(section): processed_section = [] for doc in section: sentences = doc['sentences'] coref_spans = doc['clusters'] coref_spans, sentences = remove_nulls(coref_spans, sentences) doc['sentences'] = sentences doc['clusters'] = coref_spans processed_section.append(doc) return processed_section def read_json_file(filename): with open(filename, encoding="utf-8") as fin: dataset = [] for line in fin: line = line.strip() if not line: continue dataset.append(json.loads(line)) return dataset def write_json_file(output_filename, converted_section): with open(output_filename, "w", encoding="utf-8") as fout: json.dump(converted_section, fout, indent=2) def main(): parser = argparse.ArgumentParser( prog='Convert Hindi Coref Data', ) parser.add_argument('--no_use_cconj_heads', dest='use_cconj_heads', action='store_false', help="Don't use the conjunction-aware transformation") parser.add_argument('--remove_nulls', action='store_true', help="The 
only action is to remove the NULLs and blank tokens") args = parser.parse_args() paths = get_default_paths() coref_input_path = paths["COREF_BASE"] hindi_base_path = os.path.join(coref_input_path, "hindi", "dataset") sections = ("train", "dev", "test") if args.remove_nulls: for section in sections: input_filename = os.path.join(hindi_base_path, "%s.hindi.jsonlines" % section) dataset = read_json_file(input_filename) dataset = remove_nulls_dataset_section(dataset) output_filename = os.path.join(hindi_base_path, "hi_deeph.%s.nonulls.json" % section) with open(output_filename, "w", encoding="utf-8") as fout: for doc in dataset: json.dump(doc, fout, ensure_ascii=False) fout.write("\n") else: pipe = stanza.Pipeline("hi", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True) os.makedirs(paths["COREF_DATA_DIR"], exist_ok=True) for section in sections: input_filename = os.path.join(hindi_base_path, "%s.hindi.jsonlines" % section) dataset = read_json_file(input_filename) output_filename = os.path.join(paths["COREF_DATA_DIR"], "hi_deeph.%s.json" % section) converted_section = convert_dataset_section(pipe, dataset, use_cconj_heads=args.use_cconj_heads) write_json_file(output_filename, converted_section) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/convert_ontonotes.py ================================================ """ convert_ontonotes.py This script is used to convert the OntoNotes dataset into a format that can be used by Stanza's coreference resolution model. The script uses the datasets package to download the OntoNotes dataset and then processes the dataset using Stanza's coreference resolution pipeline. The processed dataset is then saved in a JSON file. If you want to simply process the official OntoNotes dataset... 1. install the `datasets` package: `pip install datasets` 2. make folders! 
(or those adjusted to taste through scripts/config.sh) - extern_data/coref/english/en_ontonotes - data/coref 3. run this script: python -m stanza.utils.datasets.coref.convert_ontonotes If you happen to have singleton-annotated coref chains... 1. install the `datasets` package: `pip install datasets` 2. make folders! (or those adjusted to taste through scripts/config.sh) - extern_data/coref/english/en_ontonotes - data/coref 3. get the singleton-annotated coref chains in conll format from the Splice repo https://github.com/yilunzhu/splice/raw/refs/heads/main/data/ontonotes5_mentions.zip 4. place the singleton-annotated coref chains in the folder `extern_data/coref/english/en_ontonotes` $ ls ./extern_data/coref/english/en_ontonotes dev_sg_pred.english.v4_gold_conll test_sg_pred.english.v4_gold_conll train_sg.english.v4_gold_conll 5. run this script: python -m stanza.utils.datasets.coref.convert_ontonotes --use_singletons Your results will appear in ./data/coref/, and you can be off to the races with training! Note that this script invokes Stanza itself to run tokenization, POS tagging, lemmatization, and dependency parsing on the data.
""" import json import os from pathlib import Path import argparse import stanza from stanza.models.constituency import tree_reader from stanza.utils.default_paths import get_default_paths from stanza.utils.get_tqdm import get_tqdm from stanza.utils.datasets.coref.utils import process_document from stanza.utils.conll import CoNLL from collections import defaultdict tqdm = get_tqdm() def read_paragraphs(section): for doc in section: part_id = None paragraph = [] for sentence in doc['sentences']: if part_id is None: part_id = sentence['part_id'] elif part_id != sentence['part_id']: yield doc['document_id'], part_id, paragraph paragraph = [] part_id = sentence['part_id'] paragraph.append(sentence) if paragraph != []: yield doc['document_id'], part_id, paragraph def convert_dataset_section(pipe, section, override_singleton_chains=None): processed_section = [] section = list(x for x in read_paragraphs(section)) # we need to do this because apparently the singleton annotations # don't use the same numbering scheme as the ontonotes annotations # so there will be chain id conflicts max_chain_id = sorted([ chain_id for i in section for j in i[2] for chain_id, _, _ in j["coref_spans"] ])[-1] # this dictionary will map singleton chains' "special" ids # to the OntoNotes IDs sg_to_ontonotes_cluster_id_map = defaultdict( lambda: len(sg_to_ontonotes_cluster_id_map)+max_chain_id+1 ) for idx, (doc_id, part_id, paragraph) in enumerate(tqdm(section)): sentences = [x['words'] for x in paragraph] truly_coref_spans = [x['coref_spans'] for x in paragraph] # the problem to solve here is that the singleton chains' # IDs don't match the coref chains' ids # # and, what the labels calls a "singleton" may not actually # be one because the "singleton" seems like it includes all # NPs which may or may not be a singleton coref_spans = [] if override_singleton_chains: singleton_chains = override_singleton_chains[doc_id][part_id] for singleton_pred, coref_pred in zip(singleton_chains, 
truly_coref_spans): sentence_coref_preds = [] # these are sentence level predictions, which we will # disambiguate: if a subspan of "singleton" exists in the # truly coref sets, we realise its not a singleton and # then ignore it coref_pred_locs = set([tuple(i[1:]) for i in coref_pred]) for id,start,end in singleton_pred: if (start,end) not in coref_pred_locs: # this is truly a singleton sentence_coref_preds.append([ sg_to_ontonotes_cluster_id_map[id], start, end ]) sentence_coref_preds += coref_pred coref_spans.append(sentence_coref_preds) else: coref_spans = truly_coref_spans sentence_speakers = [x['speaker'] for x in paragraph] processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers) processed_section.append(processed) return processed_section def extract_chains_from_chunk(chunk): """give a chunk of the gold conll, extract the coref chains remember, the indicies are front and back *inclusive*, zero indexed and a span that takes one word only is annotated [id, n, n] (i.e. 
we don't fencepost by +1) Arguments --------- chunk : List[str] list of strings, each string is a line in the conll file Returns ------- final_chains : List[Tuple[int, int, int ]] list of chains, each chain is a list of [id, open_location, close_location] """ chains = [sentence.split(" ")[-1].strip() for sentence in chunk] chains = [[] if i == '-' else i.split("|") for i in chains] opens = defaultdict(list) closes = defaultdict(list) for indx, elem in enumerate(chains): # for each one, check if its an open, close, or both for i in elem: id = int(i.strip("(").strip(")")) if (i[0]=="("): opens[id].append(indx) if (i[-1]==")"): closes[id].append(indx) # and now, we chain the ids' opens and closes together # into the shape of [id, open_location, close_location] opens = dict(opens) closes = dict(closes) final_chains = [] for key, open_indx in opens.items(): for o,c in zip(sorted(open_indx), sorted(closes[key])): final_chains.append([key, o,c]) return final_chains def extract_chains_from_conll(gold_coref_conll): """extract the coref chains from the gold conll file Arguments -------- gold_coref_conll : str path to the gold conll file, with coreference chains Returns ------- final_chunks : Dict[str, List[List[List[Tuple[int, int, int]]]]] dictionary of document_id to list of paragraphs into list of coref chains in OntoNotes style, keyed by document ID """ with open(gold_coref_conll, 'r') as df: gold_coref_conll = df.readlines() # we want to first separate the document into sentence-level # chunks; we assume that the ordering of the sentences are correct in the # gold document sections = [] section = [] chunk = [] for i in gold_coref_conll: if len(i.split(" ")) < 10: if len(chunk) > 0: section.append(chunk) elif i.startswith("#end document"): # this is a new paragraph sections.append(section) section = [] chunk = [] else: chunk.append(i) # finally, we process each chunk and *index them by ID* final_chunks = defaultdict(list) for section in sections: section_chains = [] for 
chunk in section: section_chains.append(extract_chains_from_chunk(chunk)) final_chunks[chunk[0].split(" ")[0]].append(section_chains) final_chunks = dict(final_chunks) return final_chunks SECTION_NAMES = {"train": "train", "dev": "validation", "test": "test"} OVERRIDE_CONLL_PATHS = {"en_ontonotes": { "train": "train_sg.english.v4_gold_conll", "validation": "dev_sg_pred.english.v4_gold_conll", "test": "test_sg_pred.english.v4_gold_conll" }} def process_dataset(short_name, ontonotes_path, coref_output_path, use_singletons=False): try: from datasets import load_dataset except ImportError as e: raise ImportError("Please install the datasets package to process OntoNotes coref with Stanza") if short_name == 'en_ontonotes': config_name = 'english_v4' elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'): config_name = 'chinese_v4' elif short_name == 'ar_ontonotes': config_name = 'arabic_v4' else: raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name) pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True) # if the cache directory doesn't yet exist, we make it # we store the cache in a separate subfolder to distinguish from the # possible Singleton conlls that maybe in the folder (Path(ontonotes_path) / "cache").mkdir(exist_ok=True) dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=str(Path(ontonotes_path) / "cache"), trust_remote_code=True) for section, hf_name in SECTION_NAMES.items(): # for section, hf_name in [("test", "test")]: print("Processing %s" % section) if use_singletons: singletons_path = (Path(ontonotes_path) / OVERRIDE_CONLL_PATHS[short_name][hf_name]) if not singletons_path.exists(): raise FileNotFoundError( "Could not find singleton annotated coref chains " "in conll format\nensure you have placed them in the folder %s" % singletons_path ) # if, for instance, Amir have given us some singleton annotated coref chains in conll files, # we 
will use those instead of the ones that OntoNotes has converted_section = convert_dataset_section(pipe, dataset[hf_name], extract_chains_from_conll( str(singletons_path) )) else: converted_section = convert_dataset_section(pipe, dataset[hf_name]) output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section)) with open(output_filename, "w", encoding="utf-8") as fout: json.dump(converted_section, fout, indent=2) def main(): parser = argparse.ArgumentParser(prog="convert_ontonotes.py", description="Convert OntoNotes dataset to Stanza's coreference format") parser.add_argument("--use_singletons", default=False, action="store_true", help="Use singleton annotated coref chains") args = parser.parse_args() paths = get_default_paths() coref_input_path = paths['COREF_BASE'] ontonotes_path = os.path.join(coref_input_path, "english", "en_ontonotes") coref_output_path = paths['COREF_DATA_DIR'] process_dataset("en_ontonotes", ontonotes_path, coref_output_path, args.use_singletons) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/convert_tamil.py ================================================ """ Convert the AU-KBC coreference dataset from Prof. 
Sobha https://aclanthology.org/2020.wildre-1.4/ Located in /u/nlp/data/coref/tamil on the Stanford cluster """ import argparse import glob import json from operator import itemgetter import os import random import re import stanza from stanza.utils.datasets.coref.utils import process_document from stanza.utils.default_paths import get_default_paths from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() begin_re = re.compile(r"B-([0-9]+)") in_re = re.compile(r"I-([0-9]+)") def write_json_file(output_filename, converted_section): with open(output_filename, "w", encoding="utf-8") as fout: json.dump(converted_section, fout, indent=2) def read_doc(filename): """ Returns the sentences and the coref markings from this filename sentences: a list of list of words corefs: a list of list of clusters, which were tagged B-num and I-num in the dataset """ with open(filename, encoding="utf-8") as fin: lines = fin.readlines() all_words = [] all_coref = [] current_words = [] current_coref = [] for line in lines: line = line.strip() if not line: all_words.append(current_words) all_coref.append(current_coref) current_words = [] current_coref = [] continue pieces = line.split("\t") current_words.append(pieces[3]) current_coref.append(pieces[-1]) if current_words: all_words.append(current_words) all_coref.append(current_coref) return all_words, all_coref def convert_clusters(filename, corefs): sentence_clusters = [] # current_clusters will be a list of (cluster id, start idx) for sent_idx, sentence_coref in enumerate(corefs): current_clusters = [] processed = [] for word_idx, word_coref in enumerate(sentence_coref): new_clusters = [] if word_coref == '-': pieces = [] else: pieces = word_coref.split(";") for piece in pieces: if not piece.startswith("I-") and not piece.startswith("B-"): raise ValueError("Unexpected coref format %s in document %s" % (word_coref, filename)) if piece.startswith("B-"): new_clusters.append((int(piece[2:]), word_idx)) else: assert piece.startswith("I-") 
cluster_id = int(piece[2:]) # this will keep the first cluster found # the effect of this is that when two clusters overlap, # and they happen to be the same cluster id, # they will be nested instead of overlapping past each other for idx, previous_cluster in enumerate(current_clusters): if previous_cluster[0] == cluster_id: break else: raise ValueError("Cluster %s does not continue an existing cluster in %s" % (piece, filename)) new_clusters.append(previous_cluster) del current_clusters[idx] for cluster, start_idx in current_clusters: processed.append((cluster, start_idx, word_idx-1)) current_clusters = new_clusters for cluster, start_idx in current_clusters: processed.append((cluster, start_idx, len(sentence_coref)-1)) # sort by the first word index processed = sorted(processed, key=itemgetter(1)) # TODO: cluster IDs are starting at 1, not 0. # that may or may not be relevant sentence_clusters.append(processed) return sentence_clusters def main(): parser = argparse.ArgumentParser( prog='Convert Tamil Coref Data', ) parser.add_argument('--no_use_cconj_heads', dest='use_cconj_heads', action='store_false', help="Don't use the conjunction-aware transformation") args = parser.parse_args() random.seed(1234) paths = get_default_paths() coref_input_path = paths["COREF_BASE"] tamil_base_path = os.path.join(coref_input_path, "tamil", "coref_ta_corrected") tamil_glob = os.path.join(tamil_base_path, "*txt") filenames = sorted(glob.glob(tamil_glob)) docs = [read_doc(x) for x in filenames] raw_sentences = [doc[0] for doc in docs] sentence_clusters = [convert_clusters(filename, doc[1]) for filename, doc in zip(filenames, docs)] pipe = stanza.Pipeline("ta", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True) train, dev, test = [], [], [] for filename, sentences, coref_spans in tqdm(zip(filenames, raw_sentences, sentence_clusters), total=len(filenames)): doc_id = filename part_id = " " sentence_speakers = [[""] * len(sent) for sent in 
sentences] processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=args.use_cconj_heads) location = random.choices((train, dev, test), weights = (0.8, 0.1, 0.1))[0] location.append(processed) output_filename = os.path.join(paths["COREF_DATA_DIR"], "ta_kbc.train.json") write_json_file(output_filename, train) output_filename = os.path.join(paths["COREF_DATA_DIR"], "ta_kbc.dev.json") write_json_file(output_filename, dev) output_filename = os.path.join(paths["COREF_DATA_DIR"], "ta_kbc.test.json") write_json_file(output_filename, test) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/convert_udcoref.py ================================================ from collections import defaultdict import json import os import re import glob from stanza.utils.default_paths import get_default_paths from stanza.utils.get_tqdm import get_tqdm from stanza.utils.datasets.coref.utils import find_cconj_head from stanza.utils.conll import CoNLL import warnings from random import Random import argparse augment_random = Random(7) split_random = Random(8) tqdm = get_tqdm() IS_UDCOREF_FORMAT = True UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1 def process_documents(docs, augment=False): # docs = sections processed_section = [] for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)): # drop the last token 10% of the time if augment: for i in doc.sentences: if len(i.words) > 1: if augment_random.random() < 0.1: i.tokens = i.tokens[:-1] i.words = i.words[:-1] # extract the entities # get sentence words and lengths sentences = [[j.text for j in i.all_words] for i in doc.sentences] sentence_lens = [len(x.all_words) for x in doc.sentences] cased_words = [] for x in sentences: if augment: # modify case of the first word with 50% chance if augment_random.random() < 0.5: x[0] = x[0].lower() for y in x: cased_words.append(y) sent_id = [y for idx, sent_len in enumerate(sentence_lens) 
for y in [idx] * sent_len] word_total = 0 heads = [] # TODO: does SD vs UD matter? deprel = [] for sentence in doc.sentences: for word in sentence.all_words: deprel.append(word.deprel) if not word.head or word.head == 0: heads.append("null") else: heads.append(word.head - 1 + word_total) word_total += len(sentence.all_words) span_clusters = defaultdict(list) word_clusters = defaultdict(list) head2span = [] is_zero = [] word_total = 0 SPANS = re.compile(r"(\(\w+|[%\w]+\))") do_ctn = False # if we broke in the loop for parsed_sentence in doc.sentences: # spans regex # parse the misc column, leaving on "Entity" entries misc = [[k.split("=") for k in j if k.split("=")[0] == "Entity"] for i in parsed_sentence.all_words for j in [i.misc.split("|") if i.misc else []]] # and extract the Entity entry values entities = [i[0][1] if len(i) > 0 else None for i in misc] # extract reference information refs = [SPANS.findall(i) if i else [] for i in entities] # and calculate spans: the basic rule is (e... 
begins a reference # and ) without e before ends the most recent reference # every single time we get a closing element, we pop it off # the refdict and insert the pair to final_refs refdict = defaultdict(list) final_refs = defaultdict(list) last_ref = None for indx, i in enumerate(refs): for j in i: # this is the beginning of a reference if j[0] == "(": refdict[j[1+UDCOREF_ADDN:]].append(indx) last_ref = j[1+UDCOREF_ADDN:] # at the end of a reference, if we got exxxxx, that ends # a particular refereenc; otherwise, it ends the last reference elif j[-1] == ")" and j[UDCOREF_ADDN:-1].isnumeric(): if (not UDCOREF_ADDN) or j[0] == "e": try: final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx)) except IndexError: # this is probably zero anaphora continue elif j[-1] == ")": final_refs[last_ref].append((refdict[last_ref].pop(-1), indx)) last_ref = None final_refs = dict(final_refs) # convert it to the right format (specifically, in (ref, start, end) tuples) coref_spans = [] for k, v in final_refs.items(): for i in v: coref_spans.append([int(k), i[0], i[1]]) sentence_upos = [x.upos for x in parsed_sentence.all_words] sentence_heads = [x.head - 1 if x.head and x.head > 0 else None for x in parsed_sentence.all_words] sentence_text = [x.text for x in parsed_sentence.all_words] # if "_" in sentence_text and sentence_text.index("_") in [j for i in coref_spans for j in i]: # import ipdb # ipdb.set_trace() for span in coref_spans: zero = False if sentence_text[span[1]] == "_" and span[1] == span[2]: is_zero.append([span[0], True]) zero = True # oo! that's a zero coref, we should merge it forwards # i.e. we pick the next word as the head! span = [span[0], span[1]+1, span[2]+1] # crap! 
there's two zeros right next to each other # we are sad and confused so we give up in this case if len(sentence_text) > span[1] and sentence_text[span[1]] == "_": warnings.warn("Found two zeros next to each other in sequence; we are confused and therefore giving up.") do_ctn = True break else: is_zero.append([span[0], False]) # input is expected to be start word, end word + 1 # counting from 0 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive span_start = span[1] + word_total span_end = span[2] + word_total + 1 # if its a zero coref (i.e. coref, but the head in None), we call # the beginning of the span (i.e. the zero itself) the head if zero: candidate_head = span[1] else: try: candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) except RecursionError: candidate_head = span[1] if candidate_head is None: for candidate_head in range(span[1], span[2] + 1): # stanza uses 0 to mark the head, whereas OntoNotes is counting # words from 0, so we have to subtract 1 from the stanza heads #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) # treat the head of the phrase as the first word that has a head outside the phrase if (parsed_sentence.all_words[candidate_head].head is not None) and ( parsed_sentence.all_words[candidate_head].head - 1 < span[1] or parsed_sentence.all_words[candidate_head].head - 1 > span[2] ): break else: # if none have a head outside the phrase (circular??) 
# then just take the first word candidate_head = span[1] #print("----> %d" % candidate_head) candidate_head += word_total span_clusters[span[0]].append((span_start, span_end)) word_clusters[span[0]].append(candidate_head) head2span.append((candidate_head, span_start, span_end)) if do_ctn: break word_total += len(parsed_sentence.all_words) if do_ctn: continue span_clusters = sorted([sorted(values) for _, values in span_clusters.items()]) word_clusters = sorted([sorted(values) for _, values in word_clusters.items()]) head2span = sorted(head2span) is_zero = [i for _,i in sorted(is_zero)] # remove zero tokens "_" from cased_words and adjust indices accordingly zero_positions = [i for i, w in enumerate(cased_words) if w == "_"] if zero_positions: old_to_new = {} new_idx = 0 for old_idx, w in enumerate(cased_words): if w != "_": old_to_new[old_idx] = new_idx new_idx += 1 cased_words = [w for w in cased_words if w != "_"] sent_id = [sent_id[i] for i in sorted(old_to_new.keys())] deprel = [deprel[i] for i in sorted(old_to_new.keys())] heads = [heads[i] for i in sorted(old_to_new.keys())] try: span_clusters = [ [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster] for cluster in span_clusters ] except (KeyError, TypeError) as _: # two errors, either end-1 = -1, or start/end is None warnings.warn("Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. 
We are giving up.") continue word_clusters = [ [old_to_new[h] for h in cluster] for cluster in word_clusters ] head2span = [ (old_to_new[h], old_to_new[s], old_to_new[e - 1] + 1) for h, s, e in head2span ] processed = { "document_id": doc_id, "cased_words": cased_words, "sent_id": sent_id, "part_id": idx, # "pos": pos, "deprel": deprel, "head": heads, "span_clusters": span_clusters, "word_clusters": word_clusters, "head2span": head2span, "lang": lang, "is_zero": is_zero } processed_section.append(processed) return processed_section def process_dataset(short_name, coref_output_path, split_test, train_files, dev_files): section_names = ('train', 'dev') section_filenames = [train_files, dev_files] sections = [] test_sections = [] for section, filenames in zip(section_names, section_filenames): input_file = [] for load in filenames: lang = load.split("/")[-1].split("_")[0] print("Ingesting %s from %s of lang %s" % (section, load, lang)) docs = CoNLL.conll2multi_docs(load, ignore_gapping=False) # sections = docs[:10] print(" Ingested %d documents" % len(docs)) if split_test and section == 'train': test_section = [] train_section = [] for i in docs: # reseed for each doc so that we can attempt to keep things stable in the event # of different file orderings or some change to the number of documents split_random = Random(i.sentences[0].doc_id + i.sentences[0].text) if split_random.random() < split_test: test_section.append((i, i.sentences[0].doc_id, lang)) else: train_section.append((i, i.sentences[0].doc_id, lang)) if len(test_section) == 0 and len(train_section) >= 2: idx = split_random.randint(0, len(train_section) - 1) test_section = [train_section[idx]] train_section = train_section[:idx] + train_section[idx+1:] print(" Splitting %d documents from %s for test" % (len(test_section), load)) input_file.extend(train_section) test_sections.append(test_section) else: for i in docs: input_file.append((i, i.sentences[0].doc_id, lang)) print("Ingested %d total documents" % 
len(input_file)) sections.append(input_file) if split_test: section_names = ('train', 'dev', 'test') full_test_section = [] for filename, test_section in zip(filenames, test_sections): # TODO: could write dataset-specific test sections as well full_test_section.extend(test_section) sections.append(full_test_section) output_filenames = [] for section_data, section_name in zip(sections, section_names): converted_section = process_documents(section_data, augment=(section_name=="train")) os.makedirs(coref_output_path, exist_ok=True) output_filenames.append("%s.%s.json" % (short_name, section_name)) output_filename = os.path.join(coref_output_path, output_filenames[-1]) with open(output_filename, "w", encoding="utf-8") as fout: json.dump(converted_section, fout, indent=2) return output_filenames def get_dataset_by_language(coref_input_path, langs): conll_path = os.path.join(coref_input_path, "CorefUD-1.3-public", "data") train_filenames = [] dev_filenames = [] for lang in langs: train_filenames.extend(glob.glob(os.path.join(conll_path, "*%s*" % lang, "*train.conllu"))) dev_filenames.extend(glob.glob(os.path.join(conll_path, "*%s*" % lang, "*dev.conllu"))) train_filenames = sorted(train_filenames) dev_filenames = sorted(dev_filenames) return train_filenames, dev_filenames def main(args=None): paths = get_default_paths() parser = argparse.ArgumentParser( prog='Convert UDCoref Data', ) parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set') parser.add_argument('--output_directory', default=None, type=str, help='Where to output the data (defaults to %s)' % paths['COREF_DATA_DIR']) group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion") group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian") group.add_argument('--languages', 
type=str, help="Only use these specific languages from the coref directory") args = parser.parse_args(args=args) coref_input_path = paths['COREF_BASE'] coref_output_path = args.output_directory if args.output_directory else paths['COREF_DATA_DIR'] if args.languages: langs = args.languages.split(",") project = "_".join(langs) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project: if args.project == 'baltoslavic': project = "baltoslavic_udcoref" langs = ('Polish', 'Russian', 'Czech', 'Old_Church_Slavonic', 'Lithuanian') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'hungarian': project = "hu_udcoref" langs = ('Hungarian',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'gerrom': project = "gerrom_udcoref" langs = ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'germanic': project = "germanic_udcoref" langs = ('English', 'German', 'Norwegian') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'norwegian': project = "norwegian_udcoref" langs = ('Norwegian',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'turkish': project = "turkish_udcoref" langs = ('Turkish',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'korean': project = "korean_udcoref" langs = ('Korean',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'hindi': project = "hindi_udcoref" langs = ('Hindi',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'ancient_greek': project = "ancient_greek_udcoref" langs = ('Ancient_Greek',) train_filenames, dev_filenames = 
get_dataset_by_language(coref_input_path, langs) elif args.project == 'ancient_hebrew': project = "ancient_hebrew_udcoref" langs = ('Ancient_Hebrew',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) else: project = args.directory conll_path = os.path.join(coref_input_path, project) if not os.path.exists(conll_path) and os.path.exists(project): conll_path = args.directory train_filenames = sorted(glob.glob(os.path.join(conll_path, f"*train.conllu"))) dev_filenames = sorted(glob.glob(os.path.join(conll_path, f"*dev.conllu"))) return process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/convert_udcoref_1.2.py ================================================ from collections import defaultdict import json import os import re import glob from stanza.utils.default_paths import get_default_paths from stanza.utils.get_tqdm import get_tqdm from stanza.utils.datasets.coref.utils import find_cconj_head from stanza.utils.conll import CoNLL from random import Random import argparse augment_random = Random(7) split_random = Random(8) tqdm = get_tqdm() IS_UDCOREF_FORMAT = True UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1 def process_documents(docs, augment=False): processed_section = [] for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)): # drop the last token 10% of the time if augment: for i in doc.sentences: if len(i.words) > 1: if augment_random.random() < 0.1: i.tokens = i.tokens[:-1] i.words = i.words[:-1] # extract the entities # get sentence words and lengths sentences = [[j.text for j in i.words] for i in doc.sentences] sentence_lens = [len(x.words) for x in doc.sentences] cased_words = [] for x in sentences: if augment: # modify case of the first word with 50% chance if augment_random.random() < 0.5: x[0] = x[0].lower() for y in x: cased_words.append(y) sent_id = [y for idx, 
sent_len in enumerate(sentence_lens) for y in [idx] * sent_len] word_total = 0 heads = [] # TODO: does SD vs UD matter? deprel = [] for sentence in doc.sentences: for word in sentence.words: deprel.append(word.deprel) if word.head == 0: heads.append("null") else: heads.append(word.head - 1 + word_total) word_total += len(sentence.words) span_clusters = defaultdict(list) word_clusters = defaultdict(list) head2span = [] word_total = 0 SPANS = re.compile(r"(\(\w+|[%\w]+\))") for parsed_sentence in doc.sentences: # spans regex # parse the misc column, leaving on "Entity" entries misc = [[k.split("=") for k in j if k.split("=")[0] == "Entity"] for i in parsed_sentence.words for j in [i.misc.split("|") if i.misc else []]] # and extract the Entity entry values entities = [i[0][1] if len(i) > 0 else None for i in misc] # extract reference information refs = [SPANS.findall(i) if i else [] for i in entities] # and calculate spans: the basic rule is (e... begins a reference # and ) without e before ends the most recent reference # every single time we get a closing element, we pop it off # the refdict and insert the pair to final_refs refdict = defaultdict(list) final_refs = defaultdict(list) last_ref = None for indx, i in enumerate(refs): for j in i: # this is the beginning of a reference if j[0] == "(": refdict[j[1+UDCOREF_ADDN:]].append(indx) last_ref = j[1+UDCOREF_ADDN:] # at the end of a reference, if we got exxxxx, that ends # a particular refereenc; otherwise, it ends the last reference elif j[-1] == ")" and j[UDCOREF_ADDN:-1].isnumeric(): if (not UDCOREF_ADDN) or j[0] == "e": try: final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx)) except IndexError: # this is probably zero anaphora continue elif j[-1] == ")": final_refs[last_ref].append((refdict[last_ref].pop(-1), indx)) last_ref = None final_refs = dict(final_refs) # convert it to the right format (specifically, in (ref, start, end) tuples) coref_spans = [] for k, v in 
final_refs.items(): for i in v: coref_spans.append([int(k), i[0], i[1]]) sentence_upos = [x.upos for x in parsed_sentence.words] sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words] for span in coref_spans: # input is expected to be start word, end word + 1 # counting from 0 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive span_start = span[1] + word_total span_end = span[2] + word_total + 1 candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if candidate_head is None: for candidate_head in range(span[1], span[2] + 1): # stanza uses 0 to mark the head, whereas OntoNotes is counting # words from 0, so we have to subtract 1 from the stanza heads #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) # treat the head of the phrase as the first word that has a head outside the phrase if (parsed_sentence.words[candidate_head].head - 1 < span[1] or parsed_sentence.words[candidate_head].head - 1 > span[2]): break else: # if none have a head outside the phrase (circular??) 
# then just take the first word candidate_head = span[1] #print("----> %d" % candidate_head) candidate_head += word_total span_clusters[span[0]].append((span_start, span_end)) word_clusters[span[0]].append(candidate_head) head2span.append((candidate_head, span_start, span_end)) word_total += len(parsed_sentence.words) span_clusters = sorted([sorted(values) for _, values in span_clusters.items()]) word_clusters = sorted([sorted(values) for _, values in word_clusters.items()]) head2span = sorted(head2span) processed = { "document_id": doc_id, "cased_words": cased_words, "sent_id": sent_id, "part_id": idx, # "pos": pos, "deprel": deprel, "head": heads, "span_clusters": span_clusters, "word_clusters": word_clusters, "head2span": head2span, "lang": lang } processed_section.append(processed) return processed_section def process_dataset(short_name, coref_output_path, split_test, train_files, dev_files): section_names = ('train', 'dev') section_filenames = [train_files, dev_files] sections = [] test_sections = [] for section, filenames in zip(section_names, section_filenames): input_file = [] for load in filenames: lang = load.split("/")[-1].split("_")[0] print("Ingesting %s from %s of lang %s" % (section, load, lang)) docs = CoNLL.conll2multi_docs(load) print(" Ingested %d documents" % len(docs)) if split_test and section == 'train': test_section = [] train_section = [] for i in docs: # reseed for each doc so that we can attempt to keep things stable in the event # of different file orderings or some change to the number of documents split_random = Random(i.sentences[0].doc_id + i.sentences[0].text) if split_random.random() < split_test: test_section.append((i, i.sentences[0].doc_id, lang)) else: train_section.append((i, i.sentences[0].doc_id, lang)) if len(test_section) == 0 and len(train_section) >= 2: idx = split_random.randint(0, len(train_section) - 1) test_section = [train_section[idx]] train_section = train_section[:idx] + train_section[idx+1:] print(" Splitting %d 
documents from %s for test" % (len(test_section), load)) input_file.extend(train_section) test_sections.append(test_section) else: for i in docs: input_file.append((i, i.sentences[0].doc_id, lang)) print("Ingested %d total documents" % len(input_file)) sections.append(input_file) if split_test: section_names = ('train', 'dev', 'test') full_test_section = [] for filename, test_section in zip(filenames, test_sections): # TODO: could write dataset-specific test sections as well full_test_section.extend(test_section) sections.append(full_test_section) for section_data, section_name in zip(sections, section_names): converted_section = process_documents(section_data, augment=(section_name=="train")) os.makedirs(coref_output_path, exist_ok=True) output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section_name)) with open(output_filename, "w", encoding="utf-8") as fout: json.dump(converted_section, fout, indent=2) def get_dataset_by_language(coref_input_path, langs): conll_path = os.path.join(coref_input_path, "CorefUD-1.2-public", "data") train_filenames = [] dev_filenames = [] for lang in langs: train_filenames.extend(glob.glob(os.path.join(conll_path, "*%s*" % lang, "*train.conllu"))) dev_filenames.extend(glob.glob(os.path.join(conll_path, "*%s*" % lang, "*dev.conllu"))) train_filenames = sorted(train_filenames) dev_filenames = sorted(dev_filenames) return train_filenames, dev_filenames def main(): paths = get_default_paths() parser = argparse.ArgumentParser( prog='Convert UDCoref Data', ) parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set') group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion") group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian") args = parser.parse_args() 
coref_input_path = paths['COREF_BASE'] coref_output_path = paths['COREF_DATA_DIR'] if args.project: if args.project == 'slavic': project = "slavic_udcoref" langs = ('Polish', 'Russian', 'Czech') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'hungarian': project = "hu_udcoref" langs = ('Hungarian',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'gerrom': project = "gerrom_udcoref" langs = ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'germanic': project = "germanic_udcoref" langs = ('English', 'German', 'Norwegian') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'norwegian': project = "norwegian_udcoref" langs = ('Norwegian',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) else: project = args.directory conll_path = os.path.join(coref_input_path, project) if not os.path.exists(conll_path) and os.path.exists(project): conll_path = args.directory train_filenames = sorted(glob.glob(os.path.join(conll_path, f"*train.conllu"))) dev_filenames = sorted(glob.glob(os.path.join(conll_path, f"*dev.conllu"))) process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/coref/utils.py ================================================ from collections import defaultdict from functools import lru_cache class DynamicDepth(): """ Implements a cache + dynamic programming to find the relative depth of every word in a subphrase given the head word for every word. 
""" def get_parse_depths(self, heads, start, end): """Return the relative depth for every word Args: heads (list): List where each entry is the index of that entry's head word in the dependency parse start (int): starting index of the heads for the subphrase end (int): ending index of the heads for the subphrase Returns: list: Relative depth in the dependency parse for every word """ self.heads = heads[start:end] self.relative_heads = [h - start if h else -100 for h in self.heads] # -100 to deal with 'none' headwords depths = [self._get_depth_recursive(h) for h in range(len(self.relative_heads))] return depths @lru_cache(maxsize=None) def _get_depth_recursive(self, index): """Recursively get the depths of every index using a cache and recursion Args: index (int): Index of the word for which to calculate the relative depth Returns: int: Relative depth of the word at the index """ # if the head for the current index is outside the scope, this index is a relative root if self.relative_heads[index] >= len(self.relative_heads) or self.relative_heads[index] < 0: return 0 return self._get_depth_recursive(self.relative_heads[index]) + 1 def find_cconj_head(heads, upos, start, end): """ Finds how far each word is from the head of a span, then uses the closest CCONJ to the head as the new head If no CCONJ is present, returns None """ # use head information to extract parse depth dynamicDepth = DynamicDepth() depth = dynamicDepth.get_parse_depths(heads, start, end) depth_limit = 2 # return first 'CCONJ' token above depth limit, if exists # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC cc_indexes = [i for i in range(end - start) if upos[i+start] == 'CCONJ' and depth[i] < depth_limit] if cc_indexes: return cc_indexes[0] + start return None def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True, lang=None): """ doc_id: a string naming the document part_id: if the document has a 
particular subpart (can be blank) sentences: a list of list of string representing the raw text coref_spans: a list of lists one list per sentence each sentence has a list of spans, where each span is (span_index, span_start, span_end) the indices are relative to 0 for that particular sentence, and if the span is exactly 1 word long, span_start == span_end sentence_speakers: a list of list of string representing who said each word. can all be blank if there are no known speakers """ sentence_lens = [len(x) for x in sentences] if sentence_speakers is None: sentence_speakers = [" " for _ in sentences] if all(isinstance(x, list) for x in sentence_speakers): speaker = [y for x in sentence_speakers for y in x] else: speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len] cased_words = [y for x in sentences for y in x] sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len] # use the trees to get the xpos tags # alternatively, could translate the pos_tags field, # but those have numbers, which is annoying #tree_text = "\n".join(x['parse_tree'] for x in paragraph) #trees = tree_reader.read_trees(tree_text) #pos = [x.label for tree in trees for x in tree.yield_preterminals()] # actually, the downstream code doesn't use pos at all. maybe we can skip? doc = pipe(sentences) word_total = 0 heads = [] # TODO: does SD vs UD matter? 
deprel = [] for sentence in doc.sentences: for word in sentence.words: deprel.append(word.deprel) if word.head == 0: heads.append("null") else: heads.append(word.head - 1 + word_total) word_total += len(sentence.words) span_clusters = defaultdict(list) word_clusters = defaultdict(list) head2span = [] word_total = 0 for sent_idx, (parsed_sentence, ontonotes_words) in enumerate(zip(doc.sentences, sentences)): sentence_upos = [x.upos for x in parsed_sentence.words] sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words] for span in coref_spans[sent_idx]: # input is expected to be start word, end word + 1 # counting from 0 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive span_start = span[1] + word_total span_end = span[2] + word_total + 1 candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None if candidate_head is None: for candidate_head in range(span[1], span[2] + 1): # stanza uses 0 to mark the head, whereas OntoNotes is counting # words from 0, so we have to subtract 1 from the stanza heads #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) # treat the head of the phrase as the first word that has a head outside the phrase if (parsed_sentence.words[candidate_head].head - 1 < span[1] or parsed_sentence.words[candidate_head].head - 1 > span[2]): break else: # if none have a head outside the phrase (circular??) 
# then just take the first word candidate_head = span[1] #print("----> %d" % candidate_head) candidate_head += word_total span_clusters[span[0]].append((span_start, span_end)) word_clusters[span[0]].append(candidate_head) head2span.append((candidate_head, span_start, span_end)) word_total += len(ontonotes_words) span_clusters = sorted([sorted(values) for _, values in span_clusters.items()]) word_clusters = sorted([sorted(values) for _, values in word_clusters.items()]) head2span = sorted(head2span) processed = { "document_id": doc_id, "part_id": part_id, "cased_words": cased_words, "sent_id": sent_id, "speaker": speaker, #"pos": pos, "deprel": deprel, "head": heads, "span_clusters": span_clusters, "word_clusters": word_clusters, "head2span": head2span, } if part_id is not None: processed["part_id"] = part_id if lang is not None: processed["lang"] = lang return processed ================================================ FILE: stanza/utils/datasets/corenlp_segmenter_dataset.py ================================================ """ Output a treebank's sentences in a form that can be processed by the CoreNLP CRF Segmenter Run it as python3 -m stanza.utils.datasets.corenlp_segmenter_dataset such as python3 -m stanza.utils.datasets.corenlp_segmenter_dataset UD_Chinese-GSDSimp --output_dir $CHINESE_SEGMENTER_HOME """ import argparse import os import sys import tempfile import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank import stanza.utils.default_paths as default_paths from stanza.models.common.constant import treebank_to_short_name def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('treebanks', type=str, nargs='*', default=["UD_Chinese-GSDSimp"], help='Which treebanks to run on') parser.add_argument('--output_dir', type=str, default='.', help='Where to put the results') return parser def write_segmenter_file(output_filename, dataset): with open(output_filename, "w") as 
fout: for sentence in dataset: sentence = [x for x in sentence if not x.startswith("#")] sentence = [x for x in [y.strip() for y in sentence] if x] # eliminate MWE, although Chinese currently doesn't have any sentence = [x for x in sentence if x.split("\t")[0].find("-") < 0] text = " ".join(x.split("\t")[1] for x in sentence) fout.write(text) fout.write("\n") def process_treebank(treebank, model_type, paths, output_dir): with tempfile.TemporaryDirectory() as tokenizer_dir: paths = dict(paths) paths["TOKENIZE_DATA_DIR"] = tokenizer_dir short_name = treebank_to_short_name(treebank) # first we process the tokenization data args = argparse.Namespace() args.augment = False args.prepare_labels = False prepare_tokenizer_treebank.process_treebank(treebank, model_type, paths, args) # TODO: these names should be refactored train_file = f"{tokenizer_dir}/{short_name}.train.gold.conllu" dev_file = f"{tokenizer_dir}/{short_name}.dev.gold.conllu" test_file = f"{tokenizer_dir}/{short_name}.test.gold.conllu" train_set = common.read_sentences_from_conllu(train_file) dev_set = common.read_sentences_from_conllu(dev_file) test_set = common.read_sentences_from_conllu(test_file) train_out = os.path.join(output_dir, f"{short_name}.train.seg.txt") test_out = os.path.join(output_dir, f"{short_name}.test.seg.txt") write_segmenter_file(train_out, train_set + dev_set) write_segmenter_file(test_out, test_set) def main(): parser = build_argparse() args = parser.parse_args() paths = default_paths.get_default_paths() for treebank in args.treebanks: process_treebank(treebank, common.ModelType.TOKENIZER, paths, args.output_dir) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/depparse/check_results.py ================================================ """ A small script to report the dev/test scores from a depparse run, along with averaging multiple runs at once. Uses the expected log format from the depparse. Will not work otherwise. 
""" import argparse import re import sys dev_re = re.compile(".*INFO: step ([0-9]+).*dev_score = ([.0-9]+).*") def main(): parser = argparse.ArgumentParser(description="Grep through a list of files looking for the final results or best results up to a point") parser.add_argument("filenames", nargs="+", help="Files to check") parser.add_argument("--step", default=None, type=int, help="If set, stop checking at this step") args = parser.parse_args() filenames = args.filenames if len(filenames) == 0: return dev_scores = [] test_scores = [] best_step = None for filename in filenames: with open(filename, encoding="utf-8") as fin: lines = fin.readlines() dev_score = None test_score = None for line in lines: if line.find("Parser score") >= 0: score = float(line.strip().split()[-1]) if "dev" in line: dev_score = score elif "test" in line: test_score = score else: raise AssertionError("Did the parser score layout change? Got an unexpected score line in %s" % filename) best_step = None dev_match = dev_re.match(line) if dev_match: step = int(dev_match.groups()[0]) if args.step is not None and step > args.step: break score = float(dev_match.groups()[1]) * 100 if dev_score is None or score > dev_score: dev_score = score best_step = step if dev_score is None: dev_score = "N/A" else: dev_scores.append(dev_score) dev_score = "%.2f" % dev_score if test_score is None: test_score = "N/A" else: test_scores.append(test_score) test_score = "%.2f" % test_score if best_step is not None: print("%s %s (%d)" % (filename, dev_score, best_step)) else: print("%s %s %s" % (filename, dev_score, test_score)) if len(dev_scores) > 0: dev_score = sum(dev_scores) / len(dev_scores) print("Avg dev score: %.2f" % dev_score) if len(test_scores) > 0: test_score = sum(test_scores) / len(test_scores) print("Avg test score: %.2f" % test_score) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/__init__.py 
================================================ ================================================ FILE: stanza/utils/datasets/ner/build_en_combined.py ================================================ """ Builds a combined model out of OntoNotes, WW, and CoNLL. This is done with three layers in the multi_ner column: First layer is OntoNotes only. Other datasets have that left as blank. Second layer is the 9 class WW dataset. OntoNotes is reduced to 9 classes for this column. Third column is the CoNLL dataset. OntoNotes and WW are both projected to this. """ import json import os import shutil from stanza.utils import default_paths from stanza.utils.datasets.ner.simplify_en_worldwide import process_label from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide from stanza.utils.datasets.ner.utils import combine_files def convert_ontonotes_file(filename, short_name): assert "en_ontonotes." in filename if not os.path.exists(filename): raise FileNotFoundError("Cannot convert missing file %s" % filename) new_filename = filename.replace("en_ontonotes.", short_name + ".ontonotes.") with open(filename) as fin: doc = json.load(fin) for sentence in doc: is_start = False for word in sentence: text = word['text'] ner = word['ner'] s9 = simplify_ontonotes_to_worldwide(ner) _, s4, is_start = process_label((text, s9), is_start) word['multi_ner'] = (ner, s9, s4) with open(new_filename, "w") as fout: json.dump(doc, fout, indent=2) def convert_worldwide_file(filename, short_name): assert "en_worldwide-9class." 
in filename if not os.path.exists(filename): raise FileNotFoundError("Cannot convert missing file %s" % filename) new_filename = filename.replace("en_worldwide-9class.", short_name + ".worldwide-9class.") with open(filename) as fin: doc = json.load(fin) for sentence in doc: is_start = False for word in sentence: text = word['text'] ner = word['ner'] _, s4, is_start = process_label((text, ner), is_start) word['multi_ner'] = ("-", ner, s4) with open(new_filename, "w") as fout: json.dump(doc, fout, indent=2) def convert_conll03_file(filename, short_name): assert "en_conll03." in filename if not os.path.exists(filename): raise FileNotFoundError("Cannot convert missing file %s" % filename) new_filename = filename.replace("en_conll03.", short_name + ".conll03.") with open(filename) as fin: doc = json.load(fin) for sentence in doc: for word in sentence: ner = word['ner'] word['multi_ner'] = ("-", "-", ner) with open(new_filename, "w") as fout: json.dump(doc, fout, indent=2) def build_combined_dataset(base_output_path, short_name): convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.train.json"), short_name) convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.dev.json"), short_name) convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.test.json"), short_name) convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.train.json"), short_name) convert_conll03_file(os.path.join(base_output_path, "en_conll03.train.json"), short_name) combine_files(os.path.join(base_output_path, "%s.train.json" % short_name), os.path.join(base_output_path, "en_combined.ontonotes.train.json"), os.path.join(base_output_path, "en_combined.worldwide-9class.train.json"), os.path.join(base_output_path, "en_combined.conll03.train.json")) shutil.copyfile(os.path.join(base_output_path, "en_combined.ontonotes.dev.json"), os.path.join(base_output_path, "%s.dev.json" % short_name)) shutil.copyfile(os.path.join(base_output_path, 
"en_combined.ontonotes.test.json"), os.path.join(base_output_path, "%s.test.json" % short_name)) def main(): paths = default_paths.get_default_paths() base_output_path = paths["NER_DATA_DIR"] build_combined_dataset(base_output_path, "en_combined") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/check_for_duplicates.py ================================================ """ A simple tool to check if there are duplicates in a set of NER files It's surprising how many datasets have a bunch of duplicates... """ def read_sentences(filename): """ Read the sentences (without tags) from a BIO file """ sentences = [] with open(filename) as fin: lines = fin.readlines() current_sentence = [] for line in lines: line = line.strip() if not line: if current_sentence: sentences.append(tuple(current_sentence)) current_sentence = [] continue word = line.split("\t")[0] current_sentence.append(word) if len(current_sentence) > 0: sentences.append(tuple(current_sentence)) return sentences def check_for_duplicates(output_filenames, fail=False, check_self=False, print_all=False): """ Checks for exact duplicates in a list of NER files """ sentence_map = {} for output_filename in output_filenames: duplicates = 0 sentences = read_sentences(output_filename) for sentence in sentences: other_file = sentence_map.get(sentence, None) if other_file is not None and (check_self or other_file != output_filename): if fail: raise ValueError("Duplicate sentence '{}', first in {}, also in {}".format("".join(sentence), sentence_map[sentence], output_filename)) else: if duplicates == 0 and not print_all: print("First duplicate:") if duplicates == 0 or print_all: print("{}\nFound in {} and {}".format(sentence, other_file, output_filename)) duplicates = duplicates + 1 sentence_map[sentence] = output_filename if duplicates > 0: print("%d duplicates found in %s" % (duplicates, output_filename)) ================================================ FILE: 
stanza/utils/datasets/ner/combine_ner_datasets.py ================================================ import argparse from stanza.utils.default_paths import get_default_paths from stanza.utils.datasets.ner.utils import combine_dataset SHARDS = ("train", "dev", "test") def main(args=None): ner_data_dir = get_default_paths()['NER_DATA_DIR'] parser = argparse.ArgumentParser() parser.add_argument('--output_dataset', type=str, help='What dataset to output') parser.add_argument('input_datasets', type=str, nargs='+', help='Which datasets to input') parser.add_argument('--input_dir', type=str, default=ner_data_dir, help='Which directory to find the datasets') parser.add_argument('--output_dir', type=str, default=ner_data_dir, help='Which directory to write the dataset') args = parser.parse_args(args) input_dir = args.input_dir output_dir = args.output_dir combine_dataset(input_dir, output_dir, args.input_datasets, args.output_dataset) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/compare_entities.py ================================================ """ Report the fraction of NER entities in one file which are present in another. 
Purpose: show the coverage of one file on another, such as reporting the number of entities in one dataset on another """ import argparse from stanza.utils.datasets.ner.utils import read_json_entities def parse_args(): parser = argparse.ArgumentParser(description="Report the coverage of one NER file on another.") parser.add_argument('--train', type=str, nargs="+", required=True, help='File to use to collect the known entities (not necessarily train).') parser.add_argument('--test', type=str, nargs="+", required=True, help='File for which we want to know the ratio of known entities') args = parser.parse_args() return args def report_known_entities(train_file, test_file): train_entities = read_json_entities(train_file) test_entities = read_json_entities(test_file) train_entities = set(x[0] for x in train_entities) total_score = sum(1 for x in test_entities if x[0] in train_entities) print(train_file, test_file, total_score / len(test_entities)) def main(): args = parse_args() for train_idx, train_file in enumerate(args.train): if train_idx > 0: print() for test_file in args.test: report_known_entities(train_file, test_file) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/conll_to_iob.py ================================================ """ Process a conll file into BIO Includes the ability to process a file from a text file or a text file within a zip Main program extracts a piece of the zip file from the Danish DDT dataset """ import io import zipfile from zipfile import ZipFile from stanza.utils.conll import CoNLL def process_conll(input_file, output_file, zip_file=None, conversion=None, attr_prefix="name", allow_empty=False): """ Process a single file from DDT zip_filename: path to ddt.zip in_filename: which piece to read out_filename: where to write the result label: which attribute to get from the misc field """ if not attr_prefix.endswith("="): attr_prefix = attr_prefix + "=" doc = 
CoNLL.conll2doc(input_file=input_file, zip_file=zip_file) with open(output_file, "w", encoding="utf-8") as fout: for sentence_idx, sentence in enumerate(doc.sentences): for token_idx, token in enumerate(sentence.tokens): misc = token.misc.split("|") for attr in misc: if attr.startswith(attr_prefix): ner = attr.split("=", 1)[1] break else: # name= not found if allow_empty: ner = "O" else: raise ValueError("Could not find ner tag in document {}, sentence {}, token {}".format(input_file, sentence_idx, token_idx)) if ner != "O" and conversion is not None: if isinstance(conversion, dict): bio, label = ner.split("-", 1) if label in conversion: label = conversion[label] ner = "%s-%s" % (bio, label) else: ner = conversion(ner) fout.write("%s\t%s\n" % (token.text, ner)) fout.write("\n") def main(): process_conll(zip_file="extern_data/ner/da_ddt/ddt.zip", input_file="ddt.train.conllu", output_file="data/ner/da_ddt.train.bio") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/convert_amt.py ================================================ """ Converts a .json file from AMT to a .bio format and then a .json file To ignore Facility and Product, turn NORP into miscellaneous: python3 stanza/utils/datasets/ner/convert_amt.py --input_path /u/nlp/data/ner/stanza/en_amt/output.manifest --ignore Product,Facility --remap NORP=Miscellaneous To turn all labels into the 4 class used in conll03: python3 stanza/utils/datasets/ner/convert_amt.py --input_path /u/nlp/data/ner/stanza/en_amt/output.manifest --ignore Product,Facility --remap NORP=MISC,Miscellaneous=MISC,Location=LOC,Person=PER,Organization=ORG """ import argparse import copy import json from operator import itemgetter import sys from tqdm import tqdm import stanza from stanza.utils.datasets.ner.utils import write_sentences import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file def read_json(input_filename): """ Read the json file and extract the NER 
labels Will not return lines which are not labeled Return format is a list of lines where each line is a tuple: (text, labels) labels is a list of maps, {'label':..., 'startOffset':..., 'endOffset':...} """ docs = [] blank = 0 unlabeled = 0 broken = 0 with open(input_filename, encoding="utf-8") as fin: for line_idx, line in enumerate(fin): doc = json.loads(line) if sorted(doc.keys()) == ['source']: unlabeled += 1 continue if 'source' not in doc: blank += 1 continue source = doc['source'] entities = None for k in doc.keys(): if k == 'source' or k.endswith('metadata'): continue if 'annotations' not in doc[k]: continue annotations = doc[k]['annotations'] if 'entities' not in annotations: continue if 'entities' in annotations: if entities is not None: raise ValueError("Found a map with multiple annotations at line %d" % line_idx) entities = annotations['entities'] # entities is now a map such as # [{'label': 'Location', 'startOffset': 0, 'endOffset': 6}, # {'label': 'Location', 'startOffset': 11, 'endOffset': 23}, # {'label': 'NORP', 'startOffset': 66, 'endOffset': 74}, # {'label': 'NORP', 'startOffset': 191, 'endOffset': 214}] if entities is None: unlabeled += 1 continue is_broken = any(any(x not in entity for x in ('label', 'startOffset', 'endOffset')) for entity in entities) if is_broken: broken += 1 if broken == 1: print("Found an entity which was missing either label, startOffset, or endOffset") print(entities) docs.append((source, entities)) print("Found %d labeled lines. 
%d lines were blank, %d lines were broken, and %d lines were unlabeled" % (len(docs), blank, broken, unlabeled)) return docs def remove_ignored_labels(docs, ignored): if not ignored: return docs ignored = set(ignored.split(",")) # drop all labels which match something in ignored # otherwise leave everything the same new_docs = [(doc[0], [x for x in doc[1] if x['label'] not in ignored]) for doc in docs] return new_docs def remap_labels(docs, remap): if not remap: return docs remappings = {} for remapping in remap.split(","): pieces = remapping.split("=") remappings[pieces[0]] = pieces[1] print(remappings) new_docs = [] for doc in docs: entities = copy.deepcopy(doc[1]) for entity in entities: entity['label'] = remappings.get(entity['label'], entity['label']) new_doc = (doc[0], entities) new_docs.append(new_doc) return new_docs def remove_nesting(docs): """ Currently the NER tool does not handle nesting, so we just throw away nested entities In the event of entites which exactly overlap, the first one in the list wins """ new_docs = [] nested = 0 exact = 0 total = 0 for doc in docs: source, labels = doc # sort by startOffset, -endOffset labels = sorted(labels, key=lambda x: (x['startOffset'], -x['endOffset'])) new_labels = [] for label in labels: total += 1 # note that this works trivially for an empty list for other in reversed(new_labels): if label['startOffset'] == other['startOffset'] and label['endOffset'] == other['endOffset']: exact += 1 break if label['startOffset'] < other['endOffset']: #print("Ignoring nested entity: {} |{}| vs {} |{}|".format(label, source[label['startOffset']:label['endOffset']], other, source[other['startOffset']:other['endOffset']])) nested += 1 break else: # yes, this is meant to be a for-else new_labels.append(label) new_docs.append((source, new_labels)) print("Ignored %d exact and %d nested labels out of %d entries" % (exact, nested, total)) return new_docs def process_doc(source, labels, pipe): """ Given a source text and a list of 
labels, tokenize the text, then assign labels based on the spans defined """ doc = pipe(source) sentences = doc.sentences for sentence in sentences: for token in sentence.tokens: token.ner = "O" for label in labels: ner = label['label'] start_offset = label['startOffset'] end_offset = label['endOffset'] for sentence in sentences: if (sentence.tokens[0].start_char <= start_offset and sentence.tokens[-1].end_char >= end_offset): # found the sentence! break else: # for-else again! deal with it continue start_token = None end_token = None for token_idx, token in enumerate(sentence.tokens): if token.start_char <= start_offset and token.end_char > start_offset: # ideally we'd have start_char == start_offset, but maybe our # tokenization doesn't match the tokenization of the annotators start_token = token start_token.ner = "B-" + ner elif start_token is not None: if token.start_char >= end_offset and token_idx > 0: end_token = sentence.tokens[token_idx-1] break if token.end_char == end_offset and token_idx > 0 and token.text in (',', '.'): end_token = sentence.tokens[token_idx-1] break token.ner = "I-" + ner if token.end_char >= end_offset and end_token is None: end_token = token break if start_token is None or end_token is None: raise AssertionError("This should not happen") return [[(token.text, token.ner) for token in sentence.tokens] for sentence in sentences] def main(args): """ Read in a .json file of labeled data from AMT, write out a converted .bio file Enforces that there is only one set of labels on a sentence (TODO: add an option to skip certain sets of labels) """ docs = read_json(args.input_path) if len(docs) == 0: print("Error: no documents found in the input file!") return docs = remove_ignored_labels(docs, args.ignore) docs = remap_labels(docs, args.remap) docs = remove_nesting(docs) pipe = stanza.Pipeline(args.language, processors="tokenize") sentences = [] for doc in tqdm(docs): sentences.extend(process_doc(*doc, pipe)) print("Found %d total sentences 
(may be more than #docs if a doc has more than one sentence)" % len(sentences)) bio_filename = args.output_path write_sentences(args.output_path, sentences) print("Sentences written to %s" % args.output_path) if bio_filename.endswith(".bio"): json_filename = bio_filename[:-4] + ".json" else: json_filename = bio_filename + ".json" prepare_ner_file.process_dataset(bio_filename, json_filename) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--language', type=str, default="en", help="Language to process") parser.add_argument('--input_path', type=str, default="output.manifest", help="Where to find the files") parser.add_argument('--output_path', type=str, default="data/ner/en_amt.test.bio", help="Where to output the results") parser.add_argument('--json_output_path', type=str, default=None, help="Where to output .json. Best guess will be made if there is no .json file") parser.add_argument('--ignore', type=str, default=None, help="Ignore these labels: comma separated list without B- or I-") parser.add_argument('--remap', type=str, default=None, help="Remap labels: comma separated list of X=Y") args = parser.parse_args() main(args) ================================================ FILE: stanza/utils/datasets/ner/convert_ar_aqmar.py ================================================ """ A script to randomly shuffle the input files in the AQMAR dataset and produce train/dev/test for stanza The sentences themselves are shuffled, not the data files This script reads the input files directly from the .zip """ from collections import Counter import random import zipfile from stanza.utils.datasets.ner.utils import write_dataset def read_sentences(infile): """ Read sentences from an open file """ sents = [] cache = [] for line in infile: if isinstance(line, bytes): line = line.decode() line = line.rstrip() if len(line) == 0: if len(cache) > 0: sents.append(cache) cache = [] continue array = line.split() assert len(array) == 2 w, t = array 
cache.append([w, t]) if len(cache) > 0: sents.append(cache) cache = [] return sents def normalize_tags(sents): new_sents = [] # normalize tags for sent in sents: new_sentence = [] for i, pair in enumerate(sent): w, t = pair if t.startswith('O'): new_t = 'O' elif t.startswith('I-'): type = t[2:] if type.startswith('MIS'): new_t = 'I-MISC' elif type.startswith('-'): # handle I--ORG new_t = 'I-' + type[1:] else: new_t = t elif t.startswith('B-'): type = t[2:] if type.startswith('MIS'): new_t = 'B-MISC' elif type.startswith('ENGLISH') or type.startswith('SPANISH'): new_t = 'O' else: new_t = t else: new_t = 'O' # modify original tag new_sentence.append((sent[i][0], new_t)) new_sents.append(new_sentence) return new_sents def convert_shuffle(base_input_path, base_output_path, short_name): """ Convert AQMAR to a randomly shuffled dataset base_input_path is the zip file. base_output_path is the output directory """ if not zipfile.is_zipfile(base_input_path): raise FileNotFoundError("Expected %s to be the zipfile with AQMAR in it" % base_input_path) with zipfile.ZipFile(base_input_path) as zin: namelist = zin.namelist() annotation_files = [x for x in namelist if x.endswith(".txt") and not "/" in x] annotation_files = sorted(annotation_files) # although not necessary for good results, this does put # things in the same order the shell was alphabetizing files # when the original models were created for Stanza assert annotation_files[2] == 'Computer.txt' assert annotation_files[3] == 'Computer_Software.txt' annotation_files[2], annotation_files[3] = annotation_files[3], annotation_files[2] if len(annotation_files) != 28: raise RuntimeError("Expected exactly 28 labeled .txt files in %s but got %d" % (base_input_path, len(annotation_files))) sentences = [] for in_filename in annotation_files: with zin.open(in_filename) as infile: new_sentences = read_sentences(infile) print(f"{len(new_sentences)} sentences read from {in_filename}") new_sentences = normalize_tags(new_sentences) 
sentences.extend(new_sentences) all_tags = Counter([p[1] for sent in sentences for p in sent]) print("All tags after normalization:") print(list(all_tags.keys())) num = len(sentences) train_num = int(num*0.7) dev_num = int(num*0.15) random.seed(1234) random.shuffle(sentences) train_sents = sentences[:train_num] dev_sents = sentences[train_num:train_num+dev_num] test_sents = sentences[train_num+dev_num:] shuffled_dataset = [train_sents, dev_sents, test_sents] write_dataset(shuffled_dataset, base_output_path, short_name) ================================================ FILE: stanza/utils/datasets/ner/convert_bn_daffodil.py ================================================ """ Convert a Bengali NER dataset to our internal .json format The dataset is here: https://github.com/Rifat1493/Bengali-NER/tree/master/Input """ import argparse import os import random import tempfile from stanza.utils.datasets.ner.utils import read_tsv, write_dataset def redo_time_tags(sentences): """ Replace all TIM, TIM with B-TIM, I-TIM A brief use of Google Translate suggests the time phrases are generally one phrase, so we don't want to turn this into B-TIM, B-TIM """ new_sentences = [] for sentence in sentences: new_sentence = [] prev_time = False for word, tag in sentence: if tag == 'TIM': if prev_time: new_sentence.append((word, "I-TIM")) else: prev_time = True new_sentence.append((word, "B-TIM")) else: prev_time = False new_sentence.append((word, tag)) new_sentences.append(new_sentence) return new_sentences def strip_words(dataset): return [[(x[0].strip().replace('\ufeff', ''), x[1]) for x in sentence] for sentence in dataset] def filter_blank_words(train_file, train_filtered_file): """ As of July 2022, this dataset has blank words with O labels, which is not ideal This method removes those lines """ with open(train_file, encoding="utf-8") as fin: with open(train_filtered_file, "w", encoding="utf-8") as fout: for line in fin: if line.strip() == 'O': continue fout.write(line) def 
filter_broken_tags(train_sentences): """ Eliminate any sentences where any of the tags were empty """ return [x for x in train_sentences if not any(y[1] is None for y in x)] def filter_bad_words(train_sentences): """ Not bad words like poop, but characters that don't exist These characters look like n and l in emacs, but they are really 0xF06C and 0xF06E """ return [[x for x in sentence if not x[0] in ("", "")] for sentence in train_sentences] def read_datasets(in_directory): """ Reads & splits the train data, reads the test data There is no validation data, so we split the training data into two pieces and use the smaller piece as the dev set Also performeed is a conversion of TIM -> B-TIM, I-TIM """ # make sure we always get the same shuffle & split random.seed(1234) train_file = os.path.join(in_directory, "Input", "train_data.txt") with tempfile.TemporaryDirectory() as tempdir: train_filtered_file = os.path.join(tempdir, "train.txt") filter_blank_words(train_file, train_filtered_file) train_sentences = read_tsv(train_filtered_file, text_column=0, annotation_column=1, keep_broken_tags=True) train_sentences = filter_broken_tags(train_sentences) train_sentences = filter_bad_words(train_sentences) train_sentences = redo_time_tags(train_sentences) train_sentences = strip_words(train_sentences) test_file = os.path.join(in_directory, "Input", "test_data.txt") test_sentences = read_tsv(test_file, text_column=0, annotation_column=1, keep_broken_tags=True) test_sentences = filter_broken_tags(test_sentences) test_sentences = filter_bad_words(test_sentences) test_sentences = redo_time_tags(test_sentences) test_sentences = strip_words(test_sentences) random.shuffle(train_sentences) split_len = len(train_sentences) * 9 // 10 dev_sentences = train_sentences[split_len:] train_sentences = train_sentences[:split_len] datasets = (train_sentences, dev_sentences, test_sentences) return datasets def convert_dataset(in_directory, out_directory): """ Reads the datasets using 
read_datasets, then write them back out """ datasets = read_datasets(in_directory) write_dataset(datasets, out_directory, "bn_daffodil") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bangla/Bengali-NER", help="Where to find the files") parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner", help="Where to output the results") args = parser.parse_args() convert_dataset(args.input_path, args.output_path) ================================================ FILE: stanza/utils/datasets/ner/convert_bsf_to_beios.py ================================================ import argparse import logging import os import glob from collections import namedtuple import re from typing import Tuple from tqdm import tqdm from random import choices, shuffle BsfInfo = namedtuple('BsfInfo', 'id, tag, start_idx, end_idx, token') log = logging.getLogger(__name__) log.setLevel(logging.INFO) def format_token_as_beios(token: str, tag: str) -> list: t_words = token.split() res = [] if len(t_words) == 1: res.append(token + ' S-' + tag) else: res.append(t_words[0] + ' B-' + tag) for t_word in t_words[1: -1]: res.append(t_word + ' I-' + tag) res.append(t_words[-1] + ' E-' + tag) return res def format_token_as_iob(token: str, tag: str) -> list: t_words = token.split() res = [] if len(t_words) == 1: res.append(token + ' B-' + tag) else: res.append(t_words[0] + ' B-' + tag) for t_word in t_words[1:]: res.append(t_word + ' I-' + tag) return res def convert_bsf(data: str, bsf_markup: str, converter: str = 'beios') -> str: """ Convert data file with NER markup in Brat Standoff Format to BEIOS or IOB format. :param converter: iob or beios converter to use for document :param data: tokenized data to be converted. 
Each token separated with a space :param bsf_markup: Brat Standoff Format markup :return: data in BEIOS or IOB format https://en.wikipedia.org/wiki/Inside–outside–beginning_(tagging) """ def join_simple_chunk(chunk: str) -> list: if len(chunk.strip()) == 0: return [] # keep the newlines, but discard the non-newline whitespace tokens = re.split(r'(\n)|\s', chunk.strip()) # the re will return None for splits which were not caught in a group tokens = [x for x in tokens if x is not None] return [token + ' O' if len(token.strip()) > 0 else token for token in tokens] converters = {'beios': format_token_as_beios, 'iob': format_token_as_iob} res = [] markup = parse_bsf(bsf_markup) prev_idx = 0 m_ln: BsfInfo for m_ln in markup: res += join_simple_chunk(data[prev_idx:m_ln.start_idx]) convert_f = converters[converter] res.extend(convert_f(m_ln.token, m_ln.tag)) prev_idx = m_ln.end_idx if prev_idx < len(data) - 1: res += join_simple_chunk(data[prev_idx:]) return '\n'.join(res) def parse_bsf(bsf_data: str) -> list: """ Convert textual bsf representation to a list of named entities. 
:param bsf_data: data in the format 'T9 PERS 778 783 токен' :return: list of named tuples for each line of the data representing a single named entity token """ if len(bsf_data.strip()) == 0: return [] ln_ptrn = re.compile(r'(T\d+)\s(\w+)\s(\d+)\s(\d+)\s(.+?)(?=T\d+\s\w+\s\d+\s\d+|$)', flags=re.DOTALL) result = [] for m in ln_ptrn.finditer(bsf_data.strip()): bsf = BsfInfo(m.group(1), m.group(2), int(m.group(3)), int(m.group(4)), m.group(5).strip()) result.append(bsf) return result CORPUS_NAME = 'Ukrainian-languk' def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = 'beios', doc_delim: str = '\n', train_test_split_file: str = None) -> None: """ :param doc_delim: delimiter to be used between documents :param src_dir_path: path to directory with BSF marked files :param dst_dir_path: where to save output data :param converter: `beios` or `iob` output formats :param train_test_split_file: path to file containing train/test lists of file names :return: """ ann_path = os.path.join(src_dir_path, '*.tok.ann') ann_files = glob.glob(ann_path) ann_files.sort() tok_path = os.path.join(src_dir_path, '*.tok.txt') tok_files = glob.glob(tok_path) tok_files.sort() corpus_folder = os.path.join(dst_dir_path, CORPUS_NAME) if not os.path.exists(corpus_folder): os.makedirs(corpus_folder) if len(ann_files) == 0 or len(tok_files) == 0: raise FileNotFoundError(f'Token and annotation files are not found at specified path {ann_path}') if len(ann_files) != len(tok_files): raise RuntimeError(f'Mismatch between Annotation and Token files. 
Ann files: {len(ann_files)}, token files: {len(tok_files)}') train_set = [] dev_set = [] test_set = [] data_sets = [train_set, dev_set, test_set] split_weights = (8, 1, 1) if train_test_split_file is not None: train_names, dev_names, test_names = read_languk_train_test_split(train_test_split_file) log.info(f'Found {len(tok_files)} files in data folder "{src_dir_path}"') for (tok_fname, ann_fname) in tqdm(zip(tok_files, ann_files), total=len(tok_files), unit='file'): if tok_fname[:-3] != ann_fname[:-3]: tqdm.write(f'Token and Annotation file names do not match ann={ann_fname}, tok={tok_fname}') continue with open(tok_fname) as tok_file, open(ann_fname) as ann_file: token_data = tok_file.read() ann_data = ann_file.read() out_data = convert_bsf(token_data, ann_data, converter) if train_test_split_file is None: target_dataset = choices(data_sets, split_weights)[0] else: target_dataset = train_set fkey = os.path.basename(tok_fname)[:-4] if fkey in dev_names: target_dataset = dev_set elif fkey in test_names: target_dataset = test_set target_dataset.append(out_data) log.info(f'Data is split as following: train={len(train_set)}, dev={len(dev_set)}, test={len(test_set)}') # writing data to {train/dev/test}.bio files names = ['train', 'dev', 'test'] if doc_delim != '\n': doc_delim = '\n' + doc_delim + '\n' for idx, name in enumerate(names): fname = os.path.join(corpus_folder, name + '.bio') with open(fname, 'w') as f: f.write(doc_delim.join(data_sets[idx])) log.info('Writing to ' + fname) log.info('All done') def read_languk_train_test_split(file_path: str, dev_split: float = 0.1) -> Tuple: """ Read predefined split of train and test files in data set. 
Originally located under doc/dev-test-split.txt :param file_path: path to dev-test-split.txt file (should include file name with extension) :param dev_split: 0 to 1 float value defining how much to allocate to dev split :return: tuple of (train, dev, test) each containing list of files to be used for respective data sets """ log.info(f'Trying to read train/dev/test split from file "{file_path}". Dev allocation = {dev_split}') train_files, test_files, dev_files = [], [], [] container = test_files with open(file_path, 'r') as f: for ln in f: ln = ln.strip() if ln == 'DEV': container = train_files elif ln == 'TEST': container = test_files elif ln == '': pass else: container.append(ln) # split in file only contains train and test split. # For Stanza training we need train, dev, test # We will take part of train as dev set # This way anyone using test set outside of this code base can be sure that there was no data set polution shuffle(train_files) dev_files = train_files[: int(len(train_files) * dev_split)] train_files = train_files[int(len(train_files) * dev_split):] assert len(set(train_files).intersection(set(dev_files))) == 0 log.info(f'Files in each set: train={len(train_files)}, dev={len(dev_files)}, test={len(test_files)}') return train_files, dev_files, test_files if __name__ == '__main__': logging.basicConfig() parser = argparse.ArgumentParser(description='Convert lang-uk NER data set from BSF format to BEIOS format compatible with Stanza NER model training requirements.\n' 'Original data set should be downloaded from https://github.com/lang-uk/ner-uk\n' 'For example, create a directory extern_data/lang_uk, then run "git clone git@github.com:lang-uk/ner-uk.git') parser.add_argument('--src_dataset', type=str, default='extern_data/ner/lang-uk/ner-uk/data', help='Dir with lang-uk dataset "data" folder (https://github.com/lang-uk/ner-uk)') parser.add_argument('--dst', type=str, default='data/ner', help='Where to store the converted dataset') 
parser.add_argument('-c', type=str, default='beios', help='`beios` or `iob` formats to be used for output') parser.add_argument('--doc_delim', type=str, default='\n', help='Delimiter to be used to separate documents in the output data') parser.add_argument('--split_file', type=str, help='Name of a file containing Train/Test split (files in train and test set)') parser.print_help() args = parser.parse_args() convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim, train_test_split_file=args.split_file) ================================================ FILE: stanza/utils/datasets/ner/convert_bsnlp.py ================================================ import argparse import glob import os import logging import random import re import stanza logger = logging.getLogger('stanza') AVAILABLE_LANGUAGES = ("bg", "cs", "pl", "ru") def normalize_bg_entity(text, entity, raw): entity = entity.strip() # sanity check that the token is in the original text if text.find(entity) >= 0: return entity # some entities have quotes, but the quotes are different from those in the data file # for example: # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_458.txt # 'Съвета "Общи въпроси"' # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1002.txt # 'Съвет "Общи въпроси"' if sum(1 for x in entity if x == '"') == 2: quote_entity = entity.replace('"', '“') if text.find(quote_entity) >= 0: logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw)) return quote_entity quote_entity = entity.replace('"', '„', 1).replace('"', '“') if text.find(quote_entity) >= 0: logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw)) return quote_entity if sum(1 for x in entity if x == '"') == 1: quote_entity = entity.replace('"', '„', 1) if text.find(quote_entity) >= 0: logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw)) return quote_entity if entity.find("'") >= 0: quote_entity = entity.replace("'", "’") 
if text.find(quote_entity) >= 0: logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw)) return quote_entity lower_idx = text.lower().find(entity.lower()) if lower_idx >= 0: fixed_entity = text[lower_idx:lower_idx+len(entity)] logger.info("lowercase match found. Searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw)) return fixed_entity substitution_pairs = { # this exact error happens in: # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_67.txt 'Съвет по общи въпроси': 'Съвета по общи въпроси', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_214.txt 'Сумимото Мицуи файненшъл груп': 'Сумитомо Мицуи файненшъл груп', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_245.txt 'С и Д': 'С&Д', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_348.txt 'законопроекта за излизане на Великобритания за излизане от Европейския съюз': 'законопроекта за излизане на Великобритания от Европейския съюз', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_771.txt 'Унивеситета в Есекс': 'Университета в Есекс', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_779.txt 'Съвет за сигурност на ООН': 'Съвета за сигурност на ООН', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_787.txt 'Федерика Могерини': 'Федереика Могерини', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_938.txt 'Уайстейбъл': 'Уайтстейбъл', 'Партията за независимост на Обединеното кралство': 'Партията на независимостта на Обединеното кралство', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_972.txt 'Европейска банка за възстановяване и развитие': 'Европейската банка за възстановяване и развитие', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1065.txt 'Харолд Уилсон': 'Харолд Уилсън', 'Манчестърски университет': 'Манчестърския университет', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1096.txt 'Обединеното кралство в променящата се Европа': 'Обединеното кралство в променяща се Европа', # 
training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1175.txt 'The Daily Express': 'Daily Express', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1186.txt 'демократичната юнионистка партия': 'демократична юнионистка партия', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1192.txt 'Европейската агенция за безопасността на полетите': 'Европейската агенция за сигурността на полетите', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1219.txt 'пресцентъра на Външно министертво': 'пресцентъра на Външно министерство', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1281.txt 'Европейска агенциа за безопасността на полетите': 'Европейската агенция за сигурността на полетите', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1415.txt 'Хонк Конг': 'Хонг Конг', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1663.txt 'Лейбъристка партия': 'Лейбъристката партия', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1963.txt 'Найджъл Фараж': 'Найджъл Фарадж', 'Фараж': 'Фарадж', # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1773.txt has an entity which is mixed Cyrillic and Ascii 'Tescо': 'Tesco', } if entity in substitution_pairs and text.find(substitution_pairs[entity]) >= 0: fixed_entity = substitution_pairs[entity] logger.info("searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw)) return fixed_entity # oops, can't find it anywhere # want to raise ValueError but there are just too many in the train set for BG logger.error("Could not find '%s' in %s" % (entity, raw)) def fix_bg_typos(text, raw_filename): typo_pairs = { # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_202.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters 'brexit_bg.txt_file_202.txt': ('Вlооmbеrg', 'Bloomberg'), # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_261.txt has a typo: Telegaph instead of Telegraph 'brexit_bg.txt_file_261.txt': ('Telegaph', 'Telegraph'), # 
training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_574.txt has a typo: politicalskrapbook instead of politicalscrapbook 'brexit_bg.txt_file_574.txt': ('politicalskrapbook', 'politicalscrapbook'), # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_861.txt has a mix of cyrillic and ascii 'brexit_bg.txt_file_861.txt': ('Съвета „Общи въпроси“', 'Съветa "Общи въпроси"'), # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_992.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters 'brexit_bg.txt_file_992.txt': ('The Guardiаn', 'The Guardian'), # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1856.txt has a typo: Southerb instead of Southern 'brexit_bg.txt_file_1856.txt': ('Southerb', 'Southern'), } filename = os.path.split(raw_filename)[1] if filename in typo_pairs: replacement = typo_pairs.get(filename) text = text.replace(replacement[0], replacement[1]) return text def get_sentences(language, pipeline, annotated, raw): if language == 'bg': normalize_entity = normalize_bg_entity fix_typos = fix_bg_typos else: raise AssertionError("Please build a normalize_%s_entity and fix_%s_typos first" % language) annotated_sentences = [] with open(raw) as fin: lines = fin.readlines() if len(lines) < 5: raise ValueError("Unexpected format in %s" % raw) text = "\n".join(lines[4:]) text = fix_typos(text, raw) entities = {} with open(annotated) as fin: # first line header = fin.readline().strip() if len(header.split("\t")) > 1: raise ValueError("Unexpected missing header line in %s" % annotated) for line in fin: pieces = line.strip().split("\t") if len(pieces) < 3 or len(pieces) > 4: raise ValueError("Unexpected annotation format in %s" % annotated) entity = normalize_entity(text, pieces[0], raw) if not entity: continue if entity in entities: if entities[entity] != pieces[2]: # would like to make this an error, but it actually happens and it's not clear how to fix # annotated/nord_stream/bg/nord_stream_bg.txt_file_119.out logger.warn("found multiple 
definitions for %s in %s" % (pieces[0], annotated)) entities[entity] = pieces[2] else: entities[entity] = pieces[2] tokenized = pipeline(text) # The benefit of doing these one at a time, instead of all at once, # is that nested entities won't clobber previously labeled entities. # For example, the file # training_pl_cs_ru_bg_rc1/annotated/bg/brexit_bg.txt_file_994.out # has each of: # Северна Ирландия # Република Ирландия # Ирландия # By doing the larger ones first, we can detect and skip the ones # we already labeled when we reach the shorter one regexes = [re.compile(re.escape(x)) for x in sorted(entities.keys(), key=len, reverse=True)] bad_sentences = set() for regex in regexes: for match in regex.finditer(text): start_char, end_char = match.span() # this is inefficient, but for something only run once, it shouldn't matter start_token = None start_sloppy = False end_token = None end_sloppy = False for token in tokenized.iter_tokens(): if token.start_char <= start_char and token.end_char > start_char: start_token = token if token.start_char != start_char: start_sloppy = True if token.start_char <= end_char and token.end_char >= end_char: end_token = token if token.end_char != end_char: end_sloppy = True break if start_token is None or end_token is None: raise RuntimeError("Match %s did not align with any tokens in %s" % (match.group(0), raw)) if not start_token.sent is end_token.sent: bad_sentences.add(start_token.sent.id) bad_sentences.add(end_token.sent.id) logger.warn("match %s spanned sentences %d and %d in document %s" % (match.group(0), start_token.sent.id, end_token.sent.id, raw)) continue # ids start at 1, not 0, so we have to subtract 1 # then the end token is included, so we add back the 1 # TODO: verify that this is correct if the language has MWE - cs, pl, for example tokens = start_token.sent.tokens[start_token.id[0]-1:end_token.id[0]] if all(token.ner for token in tokens): # skip matches which have already been made # this has the nice side effect 
of not complaining if # a smaller match is found after a larger match # earlier set the NER on those tokens continue if start_sloppy and end_sloppy: bad_sentences.add(start_token.sent.id) logger.warn("match %s matched in the middle of a token in %s" % (match.group(0), raw)) continue if start_sloppy: bad_sentences.add(end_token.sent.id) logger.warn("match %s started matching in the middle of a token in %s" % (match.group(0), raw)) #print(start_token) #print(end_token) #print(start_char, end_char) continue if end_sloppy: bad_sentences.add(start_token.sent.id) logger.warn("match %s ended matching in the middle of a token in %s" % (match.group(0), raw)) #print(start_token) #print(end_token) #print(start_char, end_char) continue match_text = match.group(0) if match_text not in entities: raise RuntimeError("Matched %s, which is not in the entities from %s" % (match_text, annotated)) ner_tag = entities[match_text] tokens[0].ner = "B-" + ner_tag for token in tokens[1:]: token.ner = "I-" + ner_tag for sentence in tokenized.sentences: if not sentence.id in bad_sentences: annotated_sentences.append(sentence) return annotated_sentences def write_sentences(output_filename, annotated_sentences): logger.info("Writing %d sentences to %s" % (len(annotated_sentences), output_filename)) with open(output_filename, "w") as fout: for sentence in annotated_sentences: for token in sentence.tokens: ner_tag = token.ner if not ner_tag: ner_tag = "O" fout.write("%s\t%s\n" % (token.text, ner_tag)) fout.write("\n") def convert_bsnlp(language, base_input_path, output_filename, split_filename=None): """ Converts the BSNLP dataset for the given language. If only one output_filename is provided, all of the output goes to that file. If split_filename is provided as well, 15% of the output chosen randomly goes there instead. The dataset has no dev set, so this helps divide the data into train/dev/test. Note that the custom error fixes are only done for BG currently. 
Please manually correct the data as appropriate before using this for another language. """ if language not in AVAILABLE_LANGUAGES: raise ValueError("The current BSNLP datasets only include the following languages: %s" % ",".join(AVAILABLE_LANGUAGES)) if language != "bg": raise ValueError("There were quite a few data fixes needed to get the data correct for BG. Please work on similar fixes before using the model for %s" % language.upper()) pipeline = stanza.Pipeline(language, processors="tokenize") random.seed(1234) annotated_path = os.path.join(base_input_path, "annotated", "*", language, "*") annotated_files = sorted(glob.glob(annotated_path)) raw_path = os.path.join(base_input_path, "raw", "*", language, "*") raw_files = sorted(glob.glob(raw_path)) # if the instructions for downloading the data from the # process_ner_dataset script are followed, there will be two test # directories of data and a separate training directory of data. if len(annotated_files) == 0 and len(raw_files) == 0: logger.info("Could not find files in %s" % annotated_path) annotated_path = os.path.join(base_input_path, "annotated", language, "*") logger.info("Trying %s instead" % annotated_path) annotated_files = sorted(glob.glob(annotated_path)) raw_path = os.path.join(base_input_path, "raw", language, "*") raw_files = sorted(glob.glob(raw_path)) if len(annotated_files) != len(raw_files): raise ValueError("Unexpected differences in the file lists between %s and %s" % (annotated_files, raw_files)) for i, j in zip(annotated_files, raw_files): if os.path.split(i)[1][:-4] != os.path.split(j)[1][:-4]: raise ValueError("Unexpected differences in the file lists: found %s instead of %s" % (i, j)) annotated_sentences = [] if split_filename: split_sentences = [] for annotated, raw in zip(annotated_files, raw_files): new_sentences = get_sentences(language, pipeline, annotated, raw) if not split_filename or random.random() < 0.85: annotated_sentences.extend(new_sentences) else: 
split_sentences.extend(new_sentences) write_sentences(output_filename, annotated_sentences) if split_filename: write_sentences(split_filename, split_sentences) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--language', type=str, default="bg", help="Language to process") parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bsnlp2019", help="Where to find the files") parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner/bg_bsnlp.test.csv", help="Where to output the results") parser.add_argument('--dev_path', type=str, default=None, help="A secondary output path - 15% of the data will go here") args = parser.parse_args() convert_bsnlp(args.language, args.input_path, args.output_path, args.dev_path) ================================================ FILE: stanza/utils/datasets/ner/convert_en_conll03.py ================================================ """ Downloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json Some online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF: https://huggingface.co/datasets/conll2003 """ import os from stanza.utils.default_paths import get_default_paths from stanza.utils.datasets.ner.utils import write_dataset TAG_TO_ID = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8} ID_TO_TAG = {y: x for x, y in TAG_TO_ID.items()} def convert_dataset_section(section): sentences = [] for item in section: words = item['tokens'] tags = [ID_TO_TAG[x] for x in item['ner_tags']] sentences.append(list(zip(words, tags))) return sentences def process_dataset(short_name, conll_path, ner_output_path): try: from datasets import load_dataset except ImportError as e: raise ImportError("Please install the datasets package to process CoNLL03 with Stanza") dataset = load_dataset('conll2003', cache_dir=conll_path) datasets = [convert_dataset_section(x) for x in 
[dataset['train'], dataset['validation'], dataset['test']]] write_dataset(datasets, ner_output_path, short_name) def main(): paths = get_default_paths() ner_input_path = paths['NERBASE'] conll_path = os.path.join(ner_input_path, "english", "en_conll03") ner_output_path = paths['NER_DATA_DIR'] process_dataset("en_conll03", conll_path, ner_output_path) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/convert_fire_2013.py ================================================ """ Converts the FIRE 2013 dataset to TSV http://au-kbc.org/nlp/NER-FIRE2013/index.html The dataset is in six tab separated columns. The columns are word tag chunk ner1 ner2 ner3 This script keeps just the word and the ner1. It is quite possible that using the tag would help """ import argparse import glob import os import random def normalize(e1, e2, e3): if e1 == 'o': return "O" if e2 != 'o' and e1[:2] != e2[:2]: raise ValueError("Found a token with conflicting position tags %s,%s" % (e1, e2)) if e3 != 'o' and e2 == 'o': raise ValueError("Found a token with tertiary label but no secondary label %s,%s,%s" % (e1, e2, e3)) if e3 != 'o' and (e1[:2] != e2[:2] or e1[:2] != e3[:2]): raise ValueError("Found a token with conflicting position tags %s,%s,%s" % (e1, e2, e3)) if e1[2:] in ('ORGANIZATION', 'FACILITIES'): return e1 if e1[2:] == 'ENTERTAINMENT' and e2[2:] != 'SPORTS' and e2[2:] != 'CINEMA': return e1 if e1[2:] == 'DISEASE' and e2 == 'o': return e1 if e1[2:] == 'PLANTS' and e2[2:] != 'PARTS': return e1 if e1[2:] == 'PERSON' and e2[2:] == 'INDIVIDUAL': return e1 if e1[2:] == 'LOCATION' and e2[2:] == 'PLACE': return e1 if e1[2:] in ('DATE', 'TIME', 'YEAR'): string = e1[:2] + 'DATETIME' return string return "O" def read_fileset(filenames): # first, read the sentences from each data file sentences = [] for filename in filenames: with open(filename) as fin: next_sentence = [] for line in fin: line = line.strip() if not line: # lots of 
single line "sentences" in the dataset if next_sentence: if len(next_sentence) > 1: sentences.append(next_sentence) next_sentence = [] else: next_sentence.append(line) if next_sentence and len(next_sentence) > 1: sentences.append(next_sentence) return sentences def write_fileset(output_csv_file, sentences): with open(output_csv_file, "w") as fout: for sentence in sentences: for line in sentence: pieces = line.split("\t") if len(pieces) != 6: raise ValueError("Found %d pieces instead of the expected 6" % len(pieces)) if pieces[3] == 'o' and (pieces[4] != 'o' or pieces[5] != 'o'): raise ValueError("Inner NER labeled but the top layer was O") fout.write("%s\t%s\n" % (pieces[0], normalize(pieces[3], pieces[4], pieces[5]))) fout.write("\n") def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file): random.seed(1234) filenames = glob.glob(os.path.join(input_path, "*")) # won't be numerically sorted... shouldn't matter filenames = sorted(filenames) random.shuffle(filenames) sentences = read_fileset(filenames) random.shuffle(sentences) train_cutoff = int(0.8 * len(sentences)) dev_cutoff = int(0.9 * len(sentences)) train_sentences = sentences[:train_cutoff] dev_sentences = sentences[train_cutoff:dev_cutoff] test_sentences = sentences[dev_cutoff:] random.shuffle(train_sentences) random.shuffle(dev_sentences) random.shuffle(test_sentences) assert len(train_sentences) > 0 assert len(dev_sentences) > 0 assert len(test_sentences) > 0 write_fileset(train_csv_file, train_sentences) write_fileset(dev_csv_file, dev_sentences) write_fileset(test_csv_file, test_sentences) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/FIRE2013/hindi_train", help="Directory with raw files to read") parser.add_argument('--train_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.train.csv", help="Where to put the train file") parser.add_argument('--dev_file', type=str, 
default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the dev file") parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the test file") args = parser.parse_args() convert_fire_2013(args.input_path, args.train_file, args.dev_file, args.test_file) ================================================ FILE: stanza/utils/datasets/ner/convert_he_iahlt.py ================================================ from collections import defaultdict import os import re from stanza.utils.conll import CoNLL import stanza.utils.default_paths as default_paths from stanza.utils.datasets.ner.utils import write_dataset def output_entities(sentence): for word in sentence.words: misc = word.misc if misc is None: continue pieces = misc.split("|") for piece in pieces: if piece.startswith("Entity="): entity = piece.split("=", maxsplit=1)[1] print(" " + entity) break def extract_single_sentence(sentence): current_entity = [] words = [] for word in sentence.words: text = word.text misc = word.misc if misc is None: pieces = [] else: pieces = misc.split("|") closes = [] first_entity = False for piece in pieces: if piece.startswith("Entity="): entity = piece.split("=", maxsplit=1)[1] entity_pieces = re.split(r"([()])", entity) entity_pieces = [x for x in entity_pieces if x] # remove blanks from re.split entity_idx = 0 while entity_idx < len(entity_pieces): if entity_pieces[entity_idx] == '(': assert len(entity_pieces) > entity_idx + 1, "Opening an unspecified entity" if len(current_entity) == 0: first_entity = True current_entity.append(entity_pieces[entity_idx + 1]) entity_idx += 2 elif entity_pieces[entity_idx] == ')': assert entity_idx != 0, "Closing an unspecified entity" closes.append(entity_pieces[entity_idx-1]) entity_idx += 1 else: # the entities themselves get added or removed via the () entity_idx += 1 if len(current_entity) == 0: entity = 'O' else: entity = current_entity[0] entity = "B-" + entity 
if first_entity else "I-" + entity words.append((text, entity)) assert len(current_entity) >= len(closes), "Too many closes for the current open entities" for close_entity in closes: # TODO: check the close is closing the right thing assert close_entity == current_entity[-1], "Closed the wrong entity: %s vs %s" % (close_entity, current_entity[-1]) current_entity = current_entity[:-1] return words def extract_sentences(doc): sentences = [] for sentence in doc.sentences: try: words = extract_single_sentence(sentence) sentences.append(words) except AssertionError as e: print("Skipping sentence %s ... %s" % (sentence.sent_id, str(e))) output_entities(sentence) return sentences def convert_iahlt(udbase, output_dir, short_name): shards = ("train", "dev", "test") ud_datasets = ["UD_Hebrew-IAHLTwiki", "UD_Hebrew-IAHLTknesset"] base_filenames = ["he_iahltwiki-ud-%s.conllu", "he_iahltknesset-ud-%s.conllu"] datasets = defaultdict(list) for ud_dataset, base_filename in zip(ud_datasets, base_filenames): ud_dataset_path = os.path.join(udbase, ud_dataset) for shard in shards: filename = os.path.join(ud_dataset_path, base_filename % shard) doc = CoNLL.conll2doc(filename) sentences = extract_sentences(doc) print("Read %d sentences from %s" % (len(sentences), filename)) datasets[shard].extend(sentences) datasets = [datasets[x] for x in shards] write_dataset(datasets, output_dir, short_name) def main(): paths = default_paths.get_default_paths() udbase = paths["UDBASE_GIT"] output_directory = paths["NER_DATA_DIR"] convert_iahlt(udbase, output_directory, "he_iahlt") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/convert_hy_armtdp.py ================================================ """ Convert a ArmTDP-NER dataset to BIO format The dataset is here: https://github.com/myavrum/ArmTDP-NER.git """ import argparse import os import json import re import stanza import random from tqdm import tqdm from stanza import 
DownloadMethod, Pipeline import stanza.utils.default_paths as default_paths def read_data(path: str) -> list: """ Reads the Armenian named entity recognition dataset Returns a list of dictionaries. Each dictionary contains information about a paragraph (text, labels, etc.) """ with open(path, 'r') as file: paragraphs = [json.loads(line) for line in file] return paragraphs def filter_unicode_broken_characters(text: str) -> str: """ Removes all unicode characters in text """ return re.sub(r'\\u[A-Za-z0-9]{4}', '', text) def get_label(tok_start_char: int, tok_end_char: int, labels: list) -> list: """ Returns the label that corresponds to the given token """ for label in labels: if label[0] <= tok_start_char and label[1] >= tok_end_char: return label return [] def format_sentences(paragraphs: list, nlp_hy: Pipeline) -> list: """ Takes a list of paragraphs and returns a list of sentences, where each sentence is a list of tokens along with their respective entity tags. """ sentences = [] for paragraph in tqdm(paragraphs): doc = nlp_hy(filter_unicode_broken_characters(paragraph['text'])) for sentence in doc.sentences: sentence_ents = [] entity = [] for token in sentence.tokens: label = get_label(token.start_char, token.end_char, paragraph['labels']) if label: entity.append(token.text) if token.end_char == label[1]: sentence_ents.append({'tokens': entity, 'tag': label[2]}) entity = [] else: sentence_ents.append({'tokens': [token.text], 'tag': 'O'}) sentences.append(sentence_ents) return sentences def convert_to_bioes(sentences: list) -> list: """ Returns a list of strings where each string represents a sentence in BIOES format """ beios_sents = [] for sentence in tqdm(sentences): sentence_toc = '' for ent in sentence: if ent['tag'] == 'O': sentence_toc += ent['tokens'][0] + '\tO' + '\n' else: if len(ent['tokens']) == 1: sentence_toc += ent['tokens'][0] + '\tS-' + ent['tag'] + '\n' else: sentence_toc += ent['tokens'][0] + '\tB-' + ent['tag'] + '\n' for token in 
ent['tokens'][1:-1]: sentence_toc += token + '\tI-' + ent['tag'] + '\n' sentence_toc += ent['tokens'][-1] + '\tE-' + ent['tag'] + '\n' beios_sents.append(sentence_toc) return beios_sents def write_sentences_to_file(sents, filename): print(f"Writing {len(sents)} sentences to {filename}") with open(filename, 'w') as outfile: for sent in sents: outfile.write(sent + '\n\n') def train_test_dev_split(sents, base_output_path, short_name, train_fraction=0.7, dev_fraction=0.15): """ Splits a list of sentences into training, dev, and test sets, and writes each set to a separate file with write_sentences_to_file """ num = len(sents) train_num = int(num * train_fraction) dev_num = int(num * dev_fraction) if train_fraction + dev_fraction > 1.0: raise ValueError( "Train and dev fractions added up to more than 1: {} {} {}".format(train_fraction, dev_fraction)) random.shuffle(sents) train_sents = sents[:train_num] dev_sents = sents[train_num:train_num + dev_num] test_sents = sents[train_num + dev_num:] batches = [train_sents, dev_sents, test_sents] filenames = [f'{short_name}.train.tsv', f'{short_name}.dev.tsv', f'{short_name}.test.tsv'] for batch, filename in zip(batches, filenames): write_sentences_to_file(batch, os.path.join(base_output_path, filename)) def convert_dataset(base_input_path, base_output_path, short_name, download_method=DownloadMethod.DOWNLOAD_RESOURCES): nlp_hy = stanza.Pipeline(lang='hy', processors='tokenize', download_method=download_method) paragraphs = read_data(os.path.join(base_input_path, 'ArmNER-HY.json1')) tagged_sentences = format_sentences(paragraphs, nlp_hy) beios_sentences = convert_to_bioes(tagged_sentences) train_test_dev_split(beios_sentences, base_output_path, short_name) if __name__ == '__main__': paths = default_paths.get_default_paths() parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default=os.path.join(paths["NERBASE"], "armenian", "ArmTDP-NER"), help="Path to input file") 
parser.add_argument('--output_path', type=str, default=paths["NER_DATA_DIR"], help="Path to the output directory") parser.add_argument('--short_name', type=str, default="hy_armtdp", help="Name to identify the dataset and the model") parser.add_argument('--download_method', type=str, default=DownloadMethod.DOWNLOAD_RESOURCES, help="Download method for initializing the Pipeline. Default downloads the Armenian pipeline, --download_method NONE does not. Options: %s" % DownloadMethod._member_names_) args = parser.parse_args() convert_dataset(args.input_path, args.output_path, args.short_name, args.download_method) ================================================ FILE: stanza/utils/datasets/ner/convert_ijc.py ================================================ import argparse import random import sys """ Converts IJC data to a TSV format. So far, tested on Hindi. Not checked on any of the other languages. """ def convert_tag(tag): """ Project the classes IJC used to 4 classes with more human-readable names The trained result is a pile, as I inadvertently taught my daughter to call horrible things, but leaving them with the original classes is also a pile """ if not tag: return "O" if tag == "NEP": return "PER" if tag == "NEO": return "ORG" if tag == "NEL": return "LOC" return "MISC" def read_single_file(input_file, bio_format=True): """ Reads an IJC NER file and returns a list of list of lines """ sentences = [] lineno = 0 with open(input_file) as fin: current_sentence = [] in_ner = False in_sentence = False printed_first = False nesting = 0 for line in fin: lineno = lineno + 1 line = line.strip() if not line: continue if line.startswith(""): assert not current_sentence, "File %s had an unexpected tag" % input_file continue if line.startswith(""): # Would like to assert that empty sentences don't exist, but alas, they do # assert current_sentence, "File %s has an empty sentence at %d" % (input_file, lineno) # AssertionError: File .../hi_ijc/training-hindi/193.naval.utf8 has 
an empty sentence at 74 if current_sentence: sentences.append(current_sentence) current_sentence = [] continue if line == "))": assert in_sentence, "File %s closed a sentence when there was no open sentence at %d" % (input_file, lineno) nesting = nesting - 1 if nesting < 0: in_sentence = False nesting = 0 elif nesting == 0: in_ner = False continue pieces = line.split("\t") if pieces[0] == '0': assert pieces[1] == '((', "File %s has an unexpected first line at %d" % (input_file, lineno) in_sentence = True continue if pieces[1] == '((': nesting = nesting + 1 if nesting == 1: if len(pieces) < 4: tag = None else: assert pieces[3][0] == '<' and pieces[3][-1] == '>', "File %s has an unexpected tag format at %d: %s" % (input_file, lineno, pieces[3]) ne, tag = pieces[3][1:-1].split('=', 1) assert pieces[3] == "<%s=%s>" % (ne, tag), "File %s has an unexpected tag format at %d: %s" % (input_file, lineno, pieces[3]) in_ner = True printed_first = False tag = convert_tag(tag) elif in_ner and tag: if bio_format: if printed_first: current_sentence.append((pieces[1], "I-" + tag)) else: current_sentence.append((pieces[1], "B-" + tag)) printed_first = True else: current_sentence.append((pieces[1], tag)) else: current_sentence.append((pieces[1], "O")) assert not current_sentence, "File %s is unclosed!" 
% input_file return sentences def read_ijc_files(input_files, bio_format=True): sentences = [] for input_file in input_files: sentences.extend(read_single_file(input_file, bio_format)) return sentences def convert_ijc(input_files, csv_file, bio_format=True): sentences = read_ijc_files(input_files, bio_format) with open(csv_file, "w") as fout: for sentence in sentences: for word in sentence: fout.write("%s\t%s\n" % word) fout.write("\n") def convert_split_ijc(input_files, train_csv, dev_csv): """ Randomly splits the given list of input files into a train/dev with 85/15 split The original datasets only have train & test """ random.seed(1234) train_files = [] dev_files = [] for filename in input_files: if random.random() < 0.85: train_files.append(filename) else: dev_files.append(filename) if len(train_files) == 0 or len(dev_files) == 0: raise RuntimeError("Not enough files to split into train & dev") convert_ijc(train_files, train_csv) convert_ijc(dev_files, dev_csv) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner/hi_ijc.test.csv", help="Where to output the results") parser.add_argument('input_files', metavar='N', nargs='+', help='input files to process') args = parser.parse_args() convert_ijc(args.input_files, args.output_path, False) ================================================ FILE: stanza/utils/datasets/ner/convert_kk_kazNERD.py ================================================ """ Convert a Kazakh NER dataset to our internal .json format The dataset is here: https://github.com/IS2AI/KazNERD/tree/main/KazNERD """ import argparse import os import shutil # import random from stanza.utils.datasets.ner.utils import convert_bio_to_json, SHARDS def convert_dataset(in_directory, out_directory, short_name): """ Reads in train, validation, and test data and converts them to .json file """ filenames = ("IOB2_train.txt", "IOB2_valid.txt", "IOB2_test.txt") for shard, filename 
in zip(SHARDS, filenames): input_filename = os.path.join(in_directory, filename) output_filename = os.path.join(out_directory, "%s.%s.bio" % (short_name, shard)) shutil.copy(input_filename, output_filename) convert_bio_to_json(out_directory, out_directory, short_name, "bio") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default="/nlp/scr/aaydin/kazNERD/NER", help="Where to find the files") parser.add_argument('--output_path', type=str, default="/nlp/scr/aaydin/kazNERD/data/ner", help="Where to output the results") args = parser.parse_args() # in_path = '/nlp/scr/aaydin/kazNERD/NER' # out_path = '/nlp/scr/aaydin/kazNERD/NER/output' # convert_dataset(in_path, out_path) convert_dataset(args.input_path, args.output_path, "kk_kazNERD") ================================================ FILE: stanza/utils/datasets/ner/convert_lst20.py ================================================ """ Converts the Thai LST20 dataset to a format usable by Stanza's NER model The dataset in the original format has a few tag errors which we automatically fix (or at worst cover up) """ import os from stanza.utils.datasets.ner.utils import convert_bio_to_json def convert_lst20(paths, short_name, include_space_char=True): assert short_name == "th_lst20" SHARDS = ("train", "eval", "test") BASE_OUTPUT_PATH = paths["NER_DATA_DIR"] input_split = [(os.path.join(paths["NERBASE"], "thai", "LST20_Corpus", x), x) for x in SHARDS] if not include_space_char: short_name = short_name + "_no_ws" for input_folder, split_type in input_split: text_list = [text for text in os.listdir(input_folder) if text[0] == 'T'] if split_type == "eval": split_type = "dev" output_path = os.path.join(BASE_OUTPUT_PATH, "%s.%s.bio" % (short_name, split_type)) print(output_path) with open(output_path, 'w', encoding='utf-8') as fout: for text in text_list: lst = [] with open(os.path.join(input_folder, text), 'r', encoding='utf-8') as fin: lines = fin.readlines() for 
line_idx, line in enumerate(lines): x = line.strip().split('\t') if len(x) > 1: if x[0] == '_' and not include_space_char: continue else: word, tag = x[0], x[2] if tag == "MEA_BI": tag = "B_MEA" if tag == "OBRN_B": tag = "B_BRN" if tag == "ORG_I": tag = "I_ORG" if tag == "PER_I": tag = "I_PER" if tag == "LOC_I": tag = "I_LOC" if tag == "B" and line_idx + 1 < len(lines): x_next = lines[line_idx+1].strip().split('\t') if len(x_next) > 1: tag_next = x_next[2] if "I_" in tag_next or "E_" in tag_next: tag = tag + tag_next[1:] else: tag = "O" else: tag = "O" if "_" in tag: tag = tag.replace("_", "-") if "ABB" in tag or tag == "DDEM" or tag == "I" or tag == "__": tag = "O" fout.write('{}\t{}'.format(word, tag)) fout.write('\n') else: fout.write('\n') convert_bio_to_json(BASE_OUTPUT_PATH, BASE_OUTPUT_PATH, short_name) ================================================ FILE: stanza/utils/datasets/ner/convert_mr_l3cube.py ================================================ """ Reads one piece of the MR L3Cube dataset The dataset is structured as a long list of words already in IOB format The sentences have an ID which changes when a new sentence starts The tags are labeled BNEM instead of B-NEM, so we update that. 
(Could theoretically remap the tags to names more typical of other datasets as well) """ def convert(input_file): """ Converts one file of the dataset Return: a list of list of pairs, (text, tag) """ with open(input_file, encoding="utf-8") as fin: lines = fin.readlines() sentences = [] current_sentence = [] prev_sent_id = None for idx, line in enumerate(lines): # first line of each of the segments is the header if idx == 0: continue line = line.strip() if not line: continue pieces = line.split("\t") if len(pieces) != 3: raise ValueError("Unexpected number of pieces at line %d of %s" % (idx, input_file)) text, ner, sent_id = pieces if ner != 'O': # ner symbols are written as BNEM, BNED, etc in this dataset ner = ner[0] + "-" + ner[1:] if not prev_sent_id: prev_sent_id = sent_id if sent_id != prev_sent_id: prev_sent_id = sent_id if len(current_sentence) == 0: raise ValueError("This should not happen!") sentences.append(current_sentence) current_sentence = [] current_sentence.append((text, ner)) if current_sentence: sentences.append(current_sentence) print("Read %d sentences in %d lines from %s" % (len(sentences), len(lines), input_file)) return sentences ================================================ FILE: stanza/utils/datasets/ner/convert_my_ucsy.py ================================================ """ Processes the three pieces of the NER dataset we received from UCSY. Requires the Myanmar tokenizer to exist, since the text is not already tokenized. There are three files sent to us from UCSY, one each for train, dev, test This script expects them to be in the ner directory with the names $NERBASE/my_ucsy/Myanmar_NER_train.txt $NERBASE/my_ucsy/Myanmar_NER_dev.txt $NERBASE/my_ucsy/Myanmar_NER_test.txt The files are in the following format: unsegmentedtext@LABEL|unsegmentedtext@LABEL|... 
with one sentence per line Solution: - break the text up into fragments by splitting on | - extract the labels - segment each block of text using the MY tokenizer We could take two approaches to breaking up the blocks. One would be to combine all chunks, then segment an entire sentence at once. This would require some logic to re-chunk the resulting pieces. Instead, we resegment each individual chunk by itself. This loses the information from the neighboring chunks, but guarantees there are no screwups where segmentation crosses segment boundaries and is simpler to code. Of course, experimenting with the alternate approach might be better. There is one stray label of SB in the training data, so we throw out that entire sentence. """ import os from tqdm import tqdm import stanza from stanza.utils.datasets.ner.check_for_duplicates import check_for_duplicates SPLITS = ("train", "dev", "test") def convert_file(input_filename, output_filename, pipe): with open(input_filename) as fin: lines = fin.readlines() all_labels = set() with open(output_filename, "w") as fout: for line in tqdm(lines): pieces = line.split("|") texts = [] labels = [] skip_sentence = False for piece in pieces: piece = piece.strip() if not piece: continue text, label = piece.rsplit("@", maxsplit=1) text = text.strip() if not text: continue if label == 'SB': skip_sentence = True break texts.append(text) labels.append(label) if skip_sentence: continue text = "\n\n".join(texts) doc = pipe(text) assert len(doc.sentences) == len(texts) for sentence, label in zip(doc.sentences, labels): all_labels.add(label) for word_idx, word in enumerate(sentence.words): if label == "O": output_label = "O" elif word_idx == 0: output_label = "B-" + label else: output_label = "I-" + label fout.write("%s\t%s\n" % (word.text, output_label)) fout.write("\n\n") print("Finished processing {} Labels found: {}".format(input_filename, sorted(all_labels))) def convert_my_ucsy(base_input_path, base_output_path): 
os.makedirs(base_output_path, exist_ok=True) pipe = stanza.Pipeline("my", processors="tokenize", tokenize_no_ssplit=True) output_filenames = [os.path.join(base_output_path, "my_ucsy.%s.bio" % split) for split in SPLITS] for split, output_filename in zip(SPLITS, output_filenames): input_filename = os.path.join(base_input_path, "Myanmar_NER_%s.txt" % split) if not os.path.exists(input_filename): raise FileNotFoundError("Necessary file for my_ucsy does not exist: %s" % input_filename) convert_file(input_filename, output_filename, pipe) ================================================ FILE: stanza/utils/datasets/ner/convert_nkjp.py ================================================ import argparse import json import os import random import tarfile import tempfile from tqdm import tqdm # could import lxml here, but that would involve adding lxml as a # dependency to the stanza package # another alternative would be to try & catch ImportError try: from lxml import etree except ImportError: import xml.etree.ElementTree as etree NAMESPACE = "http://www.tei-c.org/ns/1.0" MORPH_FILE = "ann_morphosyntax.xml" NER_FILE = "ann_named.xml" SEGMENTATION_FILE = "ann_segmentation.xml" def parse_xml(path): if not os.path.exists(path): return None et = etree.parse(path) rt = et.getroot() return rt def get_node_id(node): # get the id from the xml node return node.get('{http://www.w3.org/XML/1998/namespace}id') def extract_entities_from_subfolder(subfolder, nkjp_dir): # read the ner annotation from a subfolder, assign it to paragraphs subfolder_entities = extract_unassigned_subfolder_entities(subfolder, nkjp_dir) par_id_to_segs = assign_entities(subfolder, subfolder_entities, nkjp_dir) return par_id_to_segs def extract_unassigned_subfolder_entities(subfolder, nkjp_dir): """ Build and return a map from par_id to extracted entities """ ner_path = os.path.join(nkjp_dir, subfolder, NER_FILE) rt = parse_xml(ner_path) if rt is None: return None subfolder_entities = {} ner_pars = 
rt.findall("{%s}TEI/{%s}text/{%s}body/{%s}p" % (NAMESPACE, NAMESPACE, NAMESPACE, NAMESPACE)) for par in ner_pars: par_entities = {} _, par_id = get_node_id(par).split("_") ner_sents = par.findall("{%s}s" % NAMESPACE) for ner_sent in ner_sents: corresp = ner_sent.get("corresp") _, ner_sent_id = corresp.split("#morph_") par_entities[ner_sent_id] = extract_entities_from_sentence(ner_sent) subfolder_entities[par_id] = par_entities return subfolder_entities def extract_entities_from_sentence(ner_sent): # extracts all the entity dicts from the sentence # we assume that an entity cannot span across sentences segs = ner_sent.findall("./{%s}seg" % NAMESPACE) sent_entities = {} for i, seg in enumerate(segs): ent_id = get_node_id(seg) targets = [ptr.get("target") for ptr in seg.findall("./{%s}ptr" % NAMESPACE)] orth = seg.findall("./{%s}fs/{%s}f[@name='orth']/{%s}string" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].text ner_type = seg.findall("./{%s}fs/{%s}f[@name='type']/{%s}symbol" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].get("value") ner_subtype_node = seg.findall("./{%s}fs/{%s}f[@name='subtype']/{%s}symbol" % (NAMESPACE, NAMESPACE, NAMESPACE)) if ner_subtype_node: ner_subtype = ner_subtype_node[0].get("value") else: ner_subtype = None entity = {"ent_id": ent_id, "index": i, "orth": orth, "ner_type": ner_type, "ner_subtype": ner_subtype, "targets": targets} sent_entities[ent_id] = entity cleared_entities = clear_entities(sent_entities) return cleared_entities def clear_entities(entities): # eliminates entities which extend beyond our scope resolve_entities(entities) entities_list = sorted(list(entities.values()), key=lambda ent: ent["index"]) entities = eliminate_overlapping_entities(entities_list) for entity in entities: targets = entity["targets"] entity["targets"] = [t.split("morph_")[1] for t in targets] return entities def resolve_entities(entities): # assign morphological level targets to entities resolved_targets = {entity_id: resolve_entity(entity, entities) for 
entity_id, entity in entities.items()} for entity_id in entities: entities[entity_id]["targets"] = resolved_targets[entity_id] def resolve_entity(entity, entities): # translate targets defined in terms of entities, into morphological units # works recurrently targets = entity["targets"] resolved = [] for target in targets: if target.startswith("named_"): target_entity = entities[target] resolved.extend(resolve_entity(target_entity, entities)) else: resolved.append(target) return resolved def eliminate_overlapping_entities(entities_list): # we eliminate entities which are at least partially contained in one ocurring prior to them # this amounts to removing overlap subsumed = set([]) for sub_i, sub in enumerate(entities_list): for over in entities_list[:sub_i]: if any([target in over["targets"] for target in sub["targets"]]): subsumed.add(sub["ent_id"]) return [entity for entity in entities_list if entity["ent_id"] not in subsumed] def assign_entities(subfolder, subfolder_entities, nkjp_dir): # recovers all the segments from a subfolder, and annotates it with NER morph_path = os.path.join(nkjp_dir, subfolder, MORPH_FILE) rt = parse_xml(morph_path) morph_pars = rt.findall("{%s}TEI/{%s}text/{%s}body/{%s}p" % (NAMESPACE, NAMESPACE, NAMESPACE, NAMESPACE)) par_id_to_segs = {} for par in morph_pars: _, par_id = get_node_id(par).split("_") morph_sents = par.findall("{%s}s" % NAMESPACE) sent_id_to_segs = {} for morph_sent in morph_sents: _, sent_id = get_node_id(morph_sent).split("_") segs = morph_sent.findall("{%s}seg" % NAMESPACE) sent_segs = {} for i, seg in enumerate(segs): _, seg_id = get_node_id(seg).split("morph_") orth = seg.findall("{%s}fs/{%s}f[@name='orth']/{%s}string" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].text token = {"seg_id": seg_id, "i": i, "orth": orth, "text": orth, "tag": "_", "ner": "O", # This will be overwritten "ner_subtype": None, } sent_segs[seg_id] = token sent_id_to_segs[sent_id] = sent_segs par_id_to_segs[par_id] = sent_id_to_segs if 
subfolder_entities is None: return None for par_key in subfolder_entities: par_ents = subfolder_entities[par_key] for sent_key in par_ents: sent_entities = par_ents[sent_key] for entity in sent_entities: targets = entity["targets"] iob = "B" ner_label = entity["ner_type"] matching_tokens = sorted([par_id_to_segs[par_key][sent_key][target] for target in targets], key=lambda x:x["i"]) for token in matching_tokens: full_label = f"{iob}-{ner_label}" token["ner"] = full_label token["ner_subtype"] = entity["ner_subtype"] iob = "I" return par_id_to_segs def load_xml_nkjp(nkjp_dir): subfolder_to_annotations = {} subfolders = sorted(os.listdir(nkjp_dir)) for subfolder in tqdm([name for name in subfolders if os.path.isdir(os.path.join(nkjp_dir, name))]): out = extract_entities_from_subfolder(subfolder, nkjp_dir) if out: subfolder_to_annotations[subfolder] = out else: print(subfolder, "has no ann_named.xml file") return subfolder_to_annotations def split_dataset(dataset, shuffle=True, train_fraction=0.9, dev_fraction=0.05, test_section=True): random.seed(987654321) if shuffle: random.shuffle(dataset) if not test_section: dev_fraction = 1 - train_fraction train_size = int(train_fraction * len(dataset)) dev_size = int(dev_fraction * len(dataset)) train = dataset[:train_size] dev = dataset[train_size: train_size + dev_size] test = dataset[train_size + dev_size:] return { 'train': train, 'dev': dev, 'test': test } def convert_nkjp(nkjp_path, output_dir): """Converts NKJP NER data into IOB json format. nkjp_dir is the path to directory where NKJP files are located. 
""" # Load XML NKJP print("Reading data from %s" % nkjp_path) if os.path.isfile(nkjp_path) and (nkjp_path.endswith(".tar.gz") or nkjp_path.endswith(".tgz")): with tempfile.TemporaryDirectory() as nkjp_dir: print("Temporarily extracting %s to %s" % (nkjp_path, nkjp_dir)) with tarfile.open(nkjp_path, "r:gz") as tar_in: tar_in.extractall(nkjp_dir) subfolder_to_entities = load_xml_nkjp(nkjp_dir) elif os.path.isdir(nkjp_path): subfolder_to_entities = load_xml_nkjp(nkjp_path) else: raise FileNotFoundError("Cannot find either unpacked dataset or gzipped file") converted = [] for subfolder_name, pars in subfolder_to_entities.items(): for par_id, par in pars.items(): paragraph_identifier = f"{subfolder_name}|{par_id}" par_tokens = [] for _, sent in par.items(): tokens = sent.values() srt = sorted(tokens, key=lambda tok:tok["i"]) for token in srt: _ = token.pop("i") _ = token.pop("seg_id") par_tokens.append(token) par_tokens[0]["paragraph_id"] = paragraph_identifier converted.append(par_tokens) split = split_dataset(converted) for split_name, split in split.items(): if split: with open(os.path.join(output_dir, f"pl_nkjp.{split_name}.json"), "w", encoding="utf-8") as f: json.dump(split, f, ensure_ascii=False, indent=2) def main(): parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default="/u/nlp/data/ner/stanza/polish/NKJP-PodkorpusMilionowy-1.2.tar.gz", help="Where to find the files") parser.add_argument('--output_path', type=str, default="data/ner", help="Where to output the results") args = parser.parse_args() convert_nkjp(args.input_path, args.output_path) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/convert_nner22.py ================================================ """ Converts the Thai NNER22 dataset to a format usable by Stanza's NER model The dataset is already written in json format, so we will convert into a compatible json format. 
The dataset in the original format has nested NER format which we will only extract the first layer of NER tag and write it in the format accepted by current Stanza model """ import os import logging import json def convert_nner22(paths, short_name, include_space_char=True): assert short_name == "th_nner22" SHARDS = ("train", "dev", "test") BASE_INPUT_PATH = os.path.join(paths["NERBASE"], "thai", "Thai-NNER", "data", "scb-nner-th-2022", "postproc") if not include_space_char: short_name = short_name + "_no_ws" for shard in SHARDS: input_path = os.path.join(BASE_INPUT_PATH, "%s.json" % (shard)) output_path = os.path.join(paths["NER_DATA_DIR"], "%s.%s.json" % (short_name, shard)) logging.info("Output path for %s split at %s" % (shard, output_path)) data = json.load(open(input_path)) documents = [] for i in range(len(data)): token, entities = data[i]["tokens"], data[i]["entities"] token_length, sofar = len(token), 0 document, ner_dict = [], {} for entity in entities: start, stop = entity["span"] if stop > sofar: ner = entity["entity_type"].upper() sofar = stop for j in range(start, stop): if j == start: ner_tag = "B-" + ner elif j == stop - 1: ner_tag = "E-" + ner else: ner_tag = "I-" + ner ner_dict[j] = (ner_tag, token[j]) for k in range(token_length): dict_add = {} if k not in ner_dict: dict_add["ner"], dict_add["text"] = "O", token[k] else: dict_add["ner"], dict_add["text"] = ner_dict[k] document.append(dict_add) documents.append(document) with open(output_path, "w") as outfile: json.dump(documents, outfile, indent=1) logging.info("%s.%s.json file successfully created" % (short_name, shard)) ================================================ FILE: stanza/utils/datasets/ner/convert_nytk.py ================================================ import glob import os def convert_nytk(base_input_path, base_output_path, short_name): for shard in ('train', 'dev', 'test'): if shard == 'dev': base_input_subdir = os.path.join(base_input_path, "data/train-devel-test/devel") else: 
base_input_subdir = os.path.join(base_input_path, "data/train-devel-test", shard) shard_lines = [] base_input_glob = base_input_subdir + "/*/no-morph/*" subpaths = glob.glob(base_input_glob) print("Reading %d input files from %s" % (len(subpaths), base_input_glob)) for input_filename in subpaths: if len(shard_lines) > 0: shard_lines.append("") with open(input_filename) as fin: lines = fin.readlines() if lines[0].strip() != '# global.columns = FORM LEMMA UPOS XPOS FEATS CONLL:NER': raise ValueError("Unexpected format in %s" % input_filename) lines = [x.strip().split("\t") for x in lines[1:]] lines = ["%s\t%s" % (x[0], x[5]) if len(x) > 1 else "" for x in lines] shard_lines.extend(lines) bio_filename = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard)) with open(bio_filename, "w") as fout: print("Writing %d lines to %s" % (len(shard_lines), bio_filename)) for line in shard_lines: fout.write(line) fout.write("\n") ================================================ FILE: stanza/utils/datasets/ner/convert_ontonotes.py ================================================ """ Downloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json Some online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF: https://huggingface.co/datasets/conll2003 """ import os from stanza.utils.default_paths import get_default_paths from stanza.utils.datasets.ner.utils import write_dataset ID_TO_TAG = ["O", "B-PERSON", "I-PERSON", "B-NORP", "I-NORP", "B-FAC", "I-FAC", "B-ORG", "I-ORG", "B-GPE", "I-GPE", "B-LOC", "I-LOC", "B-PRODUCT", "I-PRODUCT", "B-DATE", "I-DATE", "B-TIME", "I-TIME", "B-PERCENT", "I-PERCENT", "B-MONEY", "I-MONEY", "B-QUANTITY", "I-QUANTITY", "B-ORDINAL", "I-ORDINAL", "B-CARDINAL", "I-CARDINAL", "B-EVENT", "I-EVENT", "B-WORK_OF_ART", "I-WORK_OF_ART", "B-LAW", "I-LAW", "B-LANGUAGE", "I-LANGUAGE",] def convert_dataset_section(config_name, section): sentences = [] for doc in section: # the nt_ sentences (New Testament) 
in the HF version of OntoNotes # have blank named_entities, even though there was no original .name file # that corresponded with these annotations if config_name.startswith("english") and doc['document_id'].startswith("pt/nt"): continue for sentence in doc['sentences']: words = sentence['words'] tags = [ID_TO_TAG[x] for x in sentence['named_entities']] sentences.append(list(zip(words, tags))) return sentences def process_dataset(short_name, conll_path, ner_output_path): try: from datasets import load_dataset except ImportError as e: raise ImportError("Please install the datasets package to process CoNLL03 with Stanza") if short_name == 'en_ontonotes': # there is an english_v12, but it is filled with junk annotations # for example, near the end: # And John_O, I realize config_name = 'english_v4' elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'): config_name = 'chinese_v4' elif short_name == 'ar_ontonotes': config_name = 'arabic_v4' else: raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name) dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=conll_path) datasets = [convert_dataset_section(config_name, x) for x in [dataset['train'], dataset['validation'], dataset['test']]] write_dataset(datasets, ner_output_path, short_name) def main(): paths = get_default_paths() ner_input_path = paths['NERBASE'] conll_path = os.path.join(ner_input_path, "english", "en_ontonotes") ner_output_path = paths['NER_DATA_DIR'] process_dataset("en_ontonotes", conll_path, ner_output_path) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/convert_rgai.py ================================================ """ This script converts the Hungarian files available at u-szeged https://rgai.inf.u-szeged.hu/node/130 """ import os import tempfile # we reuse this to split the data randomly from stanza.utils.datasets.ner.split_wikiner import split_wikiner def read_rgai_file(filename, 
separator): with open(filename, encoding="latin-1") as fin: lines = fin.readlines() lines = [x.strip() for x in lines] for idx, line in enumerate(lines): if not line: continue pieces = lines[idx].split(separator) if len(pieces) != 2: raise ValueError("Line %d is in an unexpected format! Expected exactly two pieces when split on %s" % (idx, separator)) # some of the data has '0' (the digit) instead of 'O' (the letter) if pieces[-1] == '0': pieces[-1] = "O" lines[idx] = "\t".join(pieces) print("Read %d lines from %s" % (len(lines), filename)) return lines def get_rgai_data(base_input_path, use_business, use_criminal): assert use_business or use_criminal, "Must specify one or more sections of the dataset to use" dataset_lines = [] if use_business: business_file = os.path.join(base_input_path, "hun_ner_corpus.txt") lines = read_rgai_file(business_file, "\t") dataset_lines.extend(lines) if use_criminal: # There are two different annotation schemes, Context and # NoContext. NoContext seems to fit better with the # business_file's annotation scheme, since the scores are much # higher when NoContext and hun_ner are combined criminal_file = os.path.join(base_input_path, "HVGJavNENoContext") lines = read_rgai_file(criminal_file, " ") dataset_lines.extend(lines) return dataset_lines def convert_rgai(base_input_path, base_output_path, short_name, use_business, use_criminal): all_data_file = tempfile.NamedTemporaryFile(delete=False) try: raw_data = get_rgai_data(base_input_path, use_business, use_criminal) for line in raw_data: all_data_file.write(line.encode()) all_data_file.write("\n".encode()) all_data_file.close() split_wikiner(base_output_path, all_data_file.name, prefix=short_name) finally: os.unlink(all_data_file.name) ================================================ FILE: stanza/utils/datasets/ner/convert_sindhi_siner.py ================================================ """ Converts the raw data from SiNER to .json for the Stanza NER system 
https://aclanthology.org/2020.lrec-1.361.pdf """ from stanza.utils.datasets.ner.utils import write_dataset def fix_sentence(sentence): """ Fix some of the mistags in the dataset This covers 11 sentences: 1 P-PERSON, 2 with line breaks in the middle of the tag, and 8 with no B- or I- """ new_sentence = [] for word_idx, word in enumerate(sentence): if word[1] == 'P-PERSON': new_sentence.append((word[0], 'B-PERSON')) elif word[1] == 'B-OT"': new_sentence.append((word[0], 'B-OTHERS')) elif word[1] == 'B-T"': new_sentence.append((word[0], 'B-TITLE')) elif word[1] in ('GPE', 'LOC', 'OTHERS'): if len(new_sentence) > 0 and new_sentence[-1][1][:2] in ('B-', 'I-') and new_sentence[-1][1][2:] == word[1]: # one example... no idea if it should be a break or # not, but the last word translates to "Corporation", # so probably not: ميٽرو پوليٽن ڪارپوريشن new_sentence.append((word[0], 'I-' + word[1])) else: new_sentence.append((word[0], 'B-' + word[1])) else: new_sentence.append(word) return new_sentence def convert_sindhi_siner(in_filename, out_directory, short_name, train_frac=0.8, dev_frac=0.1): """ Read lines from the dataset, crudely separate sentences based on . or !, and write the dataset """ with open(in_filename, encoding="utf-8") as fin: lines = fin.readlines() lines = [x.strip().split("\t") for x in lines] lines = [(x[0].strip(), x[1].strip()) for x in lines if len(x) == 2] print("Read %d words from %s" % (len(lines), in_filename)) sentences = [] prev_idx = 0 for sent_idx, line in enumerate(lines): # maybe also handle line[0] == '،', "Arabic comma"? 
if line[0] in ('.', '!'): sentences.append(lines[prev_idx:sent_idx+1]) prev_idx=sent_idx+1 # in case the file doesn't end with punctuation, grab the last few lines if prev_idx < len(lines): sentences.append(lines[prev_idx:]) print("Found %d sentences before splitting" % len(sentences)) sentences = [fix_sentence(x) for x in sentences] assert not any('"' in x[1] or x[1].startswith("P-") or x[1] in ("GPE", "LOC", "OTHERS") for sentence in sentences for x in sentence) train_len = int(len(sentences) * train_frac) dev_len = int(len(sentences) * (train_frac+dev_frac)) train_sentences = sentences[:train_len] dev_sentences = sentences[train_len:dev_len] test_sentences = sentences[dev_len:] datasets = (train_sentences, dev_sentences, test_sentences) write_dataset(datasets, out_directory, short_name, suffix="bio") ================================================ FILE: stanza/utils/datasets/ner/convert_starlang_ner.py ================================================ """ Convert the starlang trees to a NER dataset Has to hide quite a few trees with missing NER labels """ import re from stanza.models.constituency import tree_reader import stanza.utils.datasets.constituency.convert_starlang as convert_starlang TURKISH_WORD_RE = re.compile(r"[{]turkish=([^}]+)[}]") TURKISH_LABEL_RE = re.compile(r"[{]namedEntity=([^}]+)[}]") def read_tree(text): """ Reads in a tree, then extracts the word and the NER One problem is that it is unknown if there are cases of two separate items occurring consecutively Note that this is quite similar to the convert_starlang script for constituency. 
""" trees = tree_reader.read_trees(text) if len(trees) > 1: raise ValueError("Tree file had two trees!") tree = trees[0] words = [] for label in tree.leaf_labels(): match = TURKISH_WORD_RE.search(label) if match is None: raise ValueError("Could not find word in |{}|".format(label)) word = match.group(1) word = word.replace("-LCB-", "{").replace("-RCB-", "}") match = TURKISH_LABEL_RE.search(label) if match is None: raise ValueError("Could not find ner in |{}|".format(label)) tag = match.group(1) if tag == 'NONE' or tag == "null": tag = 'O' words.append((word, tag)) return words def read_starlang(paths): return convert_starlang.read_starlang(paths, conversion=read_tree, log=False) def main(): train, dev, test = convert_starlang.main(conversion=read_tree, log=False) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/count_entities.py ================================================ import argparse from collections import defaultdict import json from stanza.models.common.doc import Document from stanza.utils.datasets.ner.utils import list_doc_entities def parse_args(): parser = argparse.ArgumentParser(description="Report the coverage of one NER file on another.") parser.add_argument('filename', type=str, nargs='+', help='File(s) to count') args = parser.parse_args() return args def count_entities(*filenames): entity_collection = defaultdict(list) for filename in filenames: with open(filename) as fin: doc = Document(json.load(fin)) num_tokens = sum(1 for sentence in doc.sentences for token in sentence.tokens) print("Number of tokens in %s: %d" % (filename, num_tokens)) entities = list_doc_entities(doc) for ent in entities: entity_collection[ent[1]].append(ent[0]) keys = sorted(entity_collection.keys()) for k in keys: print(k, len(entity_collection[k])) def main(): args = parse_args() count_entities(*args.filename) if __name__ == '__main__': main() ================================================ FILE: 
stanza/utils/datasets/ner/json_to_bio.py ================================================ """ If you want to convert .json back to .bio for some reason, this will do it for you """ import argparse import json import os from stanza.models.common.doc import Document from stanza.models.ner.utils import process_tags from stanza.utils.default_paths import get_default_paths def convert_json_to_bio(input_filename, output_filename): with open(input_filename, encoding="utf-8") as fin: doc = Document(json.load(fin)) sentences = [[(word.text, word.ner) for word in sentence.tokens] for sentence in doc.sentences] sentences = process_tags(sentences, "bioes") with open(output_filename, "w", encoding="utf-8") as fout: for sentence in sentences: for word in sentence: fout.write("%s\t%s\n" % word) fout.write("\n") def main(args=None): ner_data_dir = get_default_paths()['NER_DATA_DIR'] parser = argparse.ArgumentParser() parser.add_argument('--input_filename', type=str, default="data/ner/en_foreign-4class.test.json", help='Convert an individual file') parser.add_argument('--input_dir', type=str, default=ner_data_dir, help='Which directory to find the dataset, if using --input_dataset') parser.add_argument('--input_dataset', type=str, help='Convert an entire dataset') parser.add_argument('--output_suffix', type=str, default='bioes', help='suffix for output filenames') args = parser.parse_args(args) if args.input_dataset: input_filenames = [os.path.join(args.input_dir, "%s.%s.json" % (args.input_dataset, shard)) for shard in ("train", "dev", "test")] else: input_filenames = [args.input_filename] for input_filename in input_filenames: output_filename = os.path.splitext(input_filename)[0] + "." 
+ args.output_suffix print("%s -> %s" % (input_filename, output_filename)) convert_json_to_bio(input_filename, output_filename) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/misc_to_date.py ================================================ # for the Worldwide dataset, automatically switch the Misc tags to Date when Stanza Ontonotes thinks it's a Date # this keeps our annotation scheme for dates (eg, not "3 months ago") while hopefully switching them all to Date # # maybe some got missed # also, there are a few with some nested entities. printed out warnings and edited those by hand # # just need to run this with the Worldwide dataset in the ner path # it will automatically convert as many as it can import os from tqdm import tqdm import stanza from stanza.utils.datasets.ner.utils import read_tsv from stanza.utils.default_paths import get_default_paths paths = get_default_paths() BASE_PATH = os.path.join(paths["NERBASE"], "en_foreign") input_dir = os.path.join(BASE_PATH, "en-foreign-newswire") pipe = stanza.Pipeline("en", processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": "ontonotes_bert"}) filenames = [] def ner_tags(pipe, sentence): doc = pipe([sentence]) tags = [token.ner for sentence in doc.sentences for token in sentence.tokens] return tags for root, dirs, files in os.walk(input_dir): if root[-6:] == "REVIEW": batch_files = os.listdir(root) for filename in batch_files: file_path = os.path.join(root, filename) filenames.append(file_path) for filename in tqdm(filenames): try: data = read_tsv(filename, text_column=0, annotation_column=1, skip_comments=False, keep_all_columns=True) with open(filename, 'w', encoding='utf-8') as fout: warned_file = False for sentence in data: # segments delimited by spaces, effectively sentences tokens = [x[0] for x in sentence] labels = [x[1] for x in sentence] if any(x.endswith("Misc") for x in labels): stanza_tags = ner_tags(pipe, tokens) 
in_date = False for i, stanza_tag in enumerate(stanza_tags): if stanza_tag[2:] == "DATE" and labels[i] != "O": if len(sentence[i]) > 2: if not warned_file: print("Warning: file %s has nested tags being altered" % filename) warned_file = True # put DATE tags where Stanza thinks there are DATEs # as long as we already had a MISC (or something else, I suppose) if in_date and not stanza_tag[0].startswith("B") and not stanza_tag[0].startswith("S"): sentence[i][1] = "I-Date" else: sentence[i][1] = "B-Date" in_date = True elif in_date: # make sure new tags start with B- instead of I- # honestly it's not clear if, in these cases, # we should be switching the following tags to # DATE as well. will have to experiment some in_date = False if labels[i].startswith("I-"): sentence[i][1] = "B-" + labels[i][2:] for word in sentence: fout.write("\t".join(word)) fout.write("\n") fout.write("\n") except AssertionError: print("Could not process %s" % filename) ================================================ FILE: stanza/utils/datasets/ner/ontonotes_multitag.py ================================================ """ Combines OntoNotes and WW into a single dataset with OntoNotes used for dev & test The resulting dataset has two layers saved in the multi_ner column. WW is kept as 9 classes, with the tag put in either the first or second layer depending on the flags. OntoNotes is converted to one column for 18 and one column for 9 classes. 
""" import argparse import json import os import shutil from stanza.utils import default_paths from stanza.utils.datasets.ner.utils import combine_files from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide def convert_ontonotes_file(filename, simplify, bigger_first): assert "en_ontonotes" in filename if not os.path.exists(filename): raise FileNotFoundError("Cannot convert missing file %s" % filename) new_filename = filename.replace("en_ontonotes", "en_ontonotes-multi") with open(filename) as fin: doc = json.load(fin) for sentence in doc: for word in sentence: ner = word['ner'] if simplify: simplified = simplify_ontonotes_to_worldwide(ner) else: simplified = "-" if bigger_first: word['multi_ner'] = (ner, simplified) else: word['multi_ner'] = (simplified, ner) with open(new_filename, "w") as fout: json.dump(doc, fout, indent=2) def convert_worldwide_file(filename, bigger_first): assert "en_worldwide-9class" in filename if not os.path.exists(filename): raise FileNotFoundError("Cannot convert missing file %s" % filename) new_filename = filename.replace("en_worldwide-9class", "en_worldwide-9class-multi") with open(filename) as fin: doc = json.load(fin) for sentence in doc: for word in sentence: ner = word['ner'] if bigger_first: word['multi_ner'] = ("-", ner) else: word['multi_ner'] = (ner, "-") with open(new_filename, "w") as fout: json.dump(doc, fout, indent=2) def build_multitag_dataset(base_output_path, short_name, simplify, bigger_first): convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.train.json"), simplify, bigger_first) convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.dev.json"), simplify, bigger_first) convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.test.json"), simplify, bigger_first) convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.train.json"), bigger_first) convert_worldwide_file(os.path.join(base_output_path, 
"en_worldwide-9class.dev.json"), bigger_first) convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.test.json"), bigger_first) combine_files(os.path.join(base_output_path, "%s.train.json" % short_name), os.path.join(base_output_path, "en_ontonotes-multi.train.json"), os.path.join(base_output_path, "en_worldwide-9class-multi.train.json")) shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.dev.json"), os.path.join(base_output_path, "%s.dev.json" % short_name)) shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.test.json"), os.path.join(base_output_path, "%s.test.json" % short_name)) def main(): parser = argparse.ArgumentParser() parser.add_argument('--no_simplify', dest='simplify', action='store_false', help='By default, this script will simplify the OntoNotes 18 classes to the 8 WorldWide classes in a second column. Turning that off will leave that column blank. Initial experiments with that setting were very bad, though') parser.add_argument('--no_bigger_first', dest='bigger_first', action='store_false', help='By default, this script will put the 18 class tags in the first column and the 8 in the second. This flips the order') args = parser.parse_args() paths = default_paths.get_default_paths() base_output_path = paths["NER_DATA_DIR"] build_multitag_dataset(base_output_path, "en_ontonotes-ww-multi", args.simplify, args.bigger_first) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/prepare_ner_dataset.py ================================================ """Converts raw data files into json files usable by the training script. 
Currently it supports converting WikiNER datasets, available here: https://figshare.com/articles/dataset/Learning_multilingual_named_entity_recognition_from_Wikipedia/5462500 - download the language of interest to {Language}-WikiNER - then run prepare_ner_dataset.py French-WikiNER A gold re-edit of WikiNER for French is here: - https://huggingface.co/datasets/danrun/WikiNER-fr-gold/tree/main - https://arxiv.org/abs/2411.00030 Danrun Cao, Nicolas Béchet, Pierre-François Marteau - download to $NERBASE/wikiner-fr-gold/wikiner-fr-gold.conll prepare_ner_dataset.py fr_wikinergold French WikiNER and its gold re-edit can be mixed together with prepare_ner_dataset.py fr_wikinermixed - the data for both WikiNER and WikiNER-fr-gold needs to be in the right place first Also, Finnish Turku dataset, available here: - https://turkunlp.org/fin-ner.html - https://github.com/TurkuNLP/turku-ner-corpus git clone the repo into $NERBASE/finnish you will now have a directory $NERBASE/finnish/turku-ner-corpus - prepare_ner_dataset.py fi_turku FBK in Italy produced an Italian dataset. - KIND: an Italian Multi-Domain Dataset for Named Entity Recognition Paccosi T. and Palmero Aprosio A. LREC 2022 - https://arxiv.org/abs/2112.15099 The processing here is for a combined .tsv file they sent us. - prepare_ner_dataset.py it_fbk There is a newer version of the data available here: https://github.com/dhfbk/KIND TODO: update to the newer version of the data IJCNLP 2008 produced a few Indian language NER datasets. description: http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=3 download: http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=5 The models produced from these datasets have extremely low recall, unfortunately. - prepare_ner_dataset.py hi_ijc FIRE 2013 also produced NER datasets for Indian languages. http://au-kbc.org/nlp/NER-FIRE2013/index.html The datasets are password locked. For Stanford users, contact Chris Manning for license details. 
For external users, please contact the organizers for more information. - prepare_ner_dataset.py hi-fire2013 HiNER is another Hindi dataset option https://github.com/cfiltnlp/HiNER - HiNER: A Large Hindi Named Entity Recognition Dataset Murthy, Rudra and Bhattacharjee, Pallab and Sharnagat, Rahul and Khatri, Jyotsana and Kanojia, Diptesh and Bhattacharyya, Pushpak There are two versions: hi_hinercollapsed and hi_hiner The collapsed version has just PER, LOC, ORG - convert data as follows: cd $NERBASE mkdir hindi cd hindi git clone git@github.com:cfiltnlp/HiNER.git python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hi_hiner python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hi_hinercollapsed IL-NER has four datasets: HI, OR, TE, UR https://github.com/ltrc/IL-NER - Fine-tuning Pre-trained Named Entity Recognition Models For Indian Languages Bahad, Sankalp and Mishra, Pruthwik and Krishnamurthy, Parameswari and Sharma, Dipti Convert the data as follows: cd $NERBASE mkdir indic cd indic git clone git@github.com:ltrc/IL-NER.git python3 -m stanza.utils.datasets.ner.prepare_ner_dataset or_ilner suralk/multiNER contains three languages, EN, SI, and TA https://github.com/suralk/multiNER https://arxiv.org/abs/2412.02056 - Ranathunga, Surangika, et al. 
A Multi-way Parallel Named Entity Annotated Corpus for English, Tamil and Sinhala The tags are in BIO format, with the same 4 tags as CoNLL Convert the data as follows: cd $NERBASE mkdir mixed cd mixed git clone git@github.com:suralk/multiNER.git python3 -m stanza.utils.datasets.ner.prepare_ner_dataset ta_suralk Ukrainian NER is provided by lang-uk, available here: https://github.com/lang-uk/ner-uk git clone the repo to $NERBASE/lang-uk There should be a subdirectory $NERBASE/lang-uk/ner-uk/data at that point Conversion script graciously provided by Andrii Garkavyi @gawy - prepare_ner_dataset.py uk_languk There are two Hungarian datasets available here: https://rgai.inf.u-szeged.hu/node/130 http://www.lrec-conf.org/proceedings/lrec2006/pdf/365_pdf.pdf We combined them and gave them the label hu_rgai You can also build individual pieces with hu_rgai_business or hu_rgai_criminal Create a subdirectory of $NERBASE, $NERBASE/hu_rgai, and download both of the pieces and unzip them in that directory. - prepare_ner_dataset.py hu_rgai Another Hungarian dataset is here: - https://github.com/nytud/NYTK-NerKor - git clone the entire thing in your $NERBASE directory to operate on it - prepare_ner_dataset.py hu_nytk The two Hungarian datasets can be combined with hu_combined TODO: verify that there is no overlap in text - prepare_ner_dataset.py hu_combined BSNLP publishes NER datasets for Eastern European languages. - In 2019 they published BG, CS, PL, RU. - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html - In 2021 they added some more data, but the test sets were not publicly available as of April 2021. Therefore, currently the model is made from 2019. In 2021, the link to the 2021 task is here: http://bsnlp.cs.helsinki.fi/shared-task.html - The below method processes the 2019 version of the corpus. It has specific adjustments for the BG section, which has quite a few typos or mis-annotations in it.
Other languages probably need similar work in order to function optimally. - make a directory $NERBASE/bsnlp2019 - download the "training data are available HERE" and "test data are available HERE" to this subdirectory - unzip those files in that directory - we use the code name "bg_bsnlp19". Other languages from bsnlp 2019 can be supported by adding the appropriate functionality in convert_bsnlp.py. - prepare_ner_dataset.py bg_bsnlp19 NCHLT produced NER datasets for many African languages. Unfortunately, it is difficult to make use of many of these, as there is no corresponding UD data from which to build a tokenizer or other tools. - Afrikaans: https://repo.sadilar.org/handle/20.500.12185/299 - isiNdebele: https://repo.sadilar.org/handle/20.500.12185/306 - isiXhosa: https://repo.sadilar.org/handle/20.500.12185/312 - isiZulu: https://repo.sadilar.org/handle/20.500.12185/319 - Sepedi: https://repo.sadilar.org/handle/20.500.12185/328 - Sesotho: https://repo.sadilar.org/handle/20.500.12185/334 - Setswana: https://repo.sadilar.org/handle/20.500.12185/341 - Siswati: https://repo.sadilar.org/handle/20.500.12185/346 - Tsivenda: https://repo.sadilar.org/handle/20.500.12185/355 - Xitsonga: https://repo.sadilar.org/handle/20.500.12185/362 Agree to the license, download the zip, and unzip it in $NERBASE/NCHLT UCSY built a Myanmar dataset. They have not made it publicly available, but they did make it available to Stanford for research purposes. Contact Chris Manning or John Bauer for the data files if you are Stanford affiliated. 
- https://arxiv.org/abs/1903.04739 - Syllable-based Neural Named Entity Recognition for Myanmar Language by Hsu Myat Mo and Khin Mar Soe Hanieh Poostchi et al produced a Persian NER dataset: - git@github.com:HaniehP/PersianNER.git - https://github.com/HaniehP/PersianNER - Hanieh Poostchi, Ehsan Zare Borzeshi, Mohammad Abdous, and Massimo Piccardi, "PersoNER: Persian Named-Entity Recognition" - Hanieh Poostchi, Ehsan Zare Borzeshi, and Massimo Piccardi, "BiLSTM-CRF for Persian Named-Entity Recognition; ArmanPersoNERCorpus: the First Entity-Annotated Persian Dataset" - Conveniently, this dataset is already in BIO format. It does not have a dev split, though. git clone the above repo, unzip ArmanPersoNERCorpus.zip, and this script will split the first train fold into a dev section. SUC3 is a Swedish NER dataset provided by Språkbanken - https://spraakbanken.gu.se/en/resources/suc3 - The splitting tool is generously provided by Emil Stenstrom https://github.com/EmilStenstrom/suc_to_iob - Download the .bz2 file at this URL and put it in $NERBASE/sv_suc3shuffle It is not necessary to unzip it. - Gustafson-Capková, Sophia and Britt Hartmann, 2006, Manual of the Stockholm Umeå Corpus version 2.0. Stockholm University. - Östling, Robert, 2013, Stagger an Open-Source Part of Speech Tagger for Swedish Northern European Journal of Language Technology 3: 1–18 DOI 10.3384/nejlt.2000-1533.1331 - The shuffled dataset can be converted with dataset code prepare_ner_dataset.py sv_suc3shuffle - If you fill out the license form and get the official data, you can get the official splits by putting the provided zip file in $NERBASE/sv_suc3licensed. 
Again, not necessary to unzip it python3 -m stanza.utils.datasets.ner.prepare_ner_dataset sv_suc3licensed DDT is a reformulation of the Danish Dependency Treebank as an NER dataset - https://danlp-alexandra.readthedocs.io/en/latest/docs/datasets.html#dane - direct download link as of late 2021: https://danlp.alexandra.dk/304bd159d5de/datasets/ddt.zip - https://aclanthology.org/2020.lrec-1.565.pdf DaNE: A Named Entity Resource for Danish Rasmus Hvingelby, Amalie Brogaard Pauli, Maria Barrett, Christina Rosted, Lasse Malm Lidegaard, Anders Søgaard - place ddt.zip in $NERBASE/da_ddt/ddt.zip python3 -m stanza.utils.datasets.ner.prepare_ner_dataset da_ddt NorNE is the Norwegian Dependency Treebank with NER labels - LREC 2020 NorNE: Annotating Named Entities for Norwegian Fredrik Jørgensen, Tobias Aasmoe, Anne-Stine Ruud Husevåg, Lilja Øvrelid, and Erik Velldal - both Bokmål and Nynorsk - This dataset is in a git repo: https://github.com/ltgoslo/norne Clone it into $NERBASE git clone git@github.com:ltgoslo/norne.git python3 -m stanza.utils.datasets.ner.prepare_ner_dataset nb_norne python3 -m stanza.utils.datasets.ner.prepare_ner_dataset nn_norne tr_starlang is a set of constituency trees for Turkish The words in this dataset (usually) have NER labels as well A dataset in three parts from the Starlang group in Turkey: Neslihan Kara, Büşra Marşan, et al Creating A Syntactically Felicitous Constituency Treebank For Turkish https://ieeexplore.ieee.org/document/9259873 git clone the following three repos https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15 https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15 https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20 Put them in $CONSTITUENCY_HOME/turkish (yes, the constituency home) python3 -m stanza.utils.datasets.ner.prepare_ner_dataset tr_starlang GermEval2014 is a German NER dataset https://sites.google.com/site/germeval2014ner/data https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J 
Download the files in that directory NER-de-train.tsv NER-de-dev.tsv NER-de-test.tsv put them in $NERBASE/germeval2014 then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset de_germeval2014 The UD Japanese GSD dataset has a conversion by Megagon Labs https://github.com/megagonlabs/UD_Japanese-GSD https://github.com/megagonlabs/UD_Japanese-GSD/tags - r2.9-NE has the NE tagged files inside a "spacy" folder in the download - expected directory for this data: unzip the .zip of the release into $NERBASE/ja_gsd so it should wind up in $NERBASE/ja_gsd/UD_Japanese-GSD-r2.9-NE python3 -m stanza.utils.datasets.ner.prepare_ner_dataset ja_gsd L3Cube is a Marathi dataset - https://arxiv.org/abs/2204.06029 https://arxiv.org/pdf/2204.06029.pdf https://github.com/l3cube-pune/MarathiNLP - L3Cube-MahaNER: A Marathi Named Entity Recognition Dataset and BERT models Parth Patil, Aparna Ranade, Maithili Sabane, Onkar Litake, Raviraj Joshi Clone the repo into $NERBASE/marathi git clone git@github.com:l3cube-pune/MarathiNLP.git Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset mr_l3cube Daffodil University produced a Bangla NER dataset - https://github.com/Rifat1493/Bengali-NER - https://ieeexplore.ieee.org/document/8944804 - Bengali Named Entity Recognition: A survey with deep learning benchmark Md Jamiur Rahman Rifat, Sheikh Abujar, Sheak Rashed Haider Noori, Syed Akhter Hossain Clone the repo into a "bangla" subdirectory of $NERBASE cd $NERBASE/bangla git clone git@github.com:Rifat1493/Bengali-NER.git Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset bn_daffodil LST20 is a Thai NER dataset from 2020 - https://arxiv.org/abs/2008.05055 The Annotation Guideline of LST20 Corpus Prachya Boonkwan, Vorapon Luantangsrisuk, Sitthaa Phaholphinyo, Kanyanat Kriengket, Dhanon Leenoi, Charun Phrombut, Monthika Boriboon, Krit Kosawat, Thepchai Supnithi - This script processes a version which can be downloaded here after registration: 
https://aiforthai.in.th/index.php - There is another version downloadable from HuggingFace The script will likely need some modification to be compatible with the HuggingFace version - Download the data into $NERBASE/thai/LST20_Corpus There should be "train", "eval", "test" directories after downloading - Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset th_lst20 Thai-NNER is another Thai NER dataset, from 2022 - https://github.com/vistec-AI/Thai-NNER - https://aclanthology.org/2022.findings-acl.116/ Thai Nested Named Entity Recognition Corpus Weerayut Buaphet, Can Udomcharoenchaikit, Peerat Limkonchotiwat, Attapol Rutherford, and Sarana Nutanong - git clone the data to $NERBASE/thai - On the git repo, there should be a link to a more complete version of the dataset. For example, in Sep. 2023 it is here: https://github.com/vistec-AI/Thai-NNER#dataset The Google drive it goes to has "postproc". Put the train.json, dev.json, and test.json in $NERBASE/thai/Thai-NNER/data/scb-nner-th-2022/postproc/ - Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset th_nner22 NKJP is a Polish NER dataset - http://nkjp.pl/index.php?page=0&lang=1 About the Project - http://zil.ipipan.waw.pl/DistrNKJP Wikipedia subcorpus used to train charlm model - http://clip.ipipan.waw.pl/NationalCorpusOfPolish?action=AttachFile&do=view&target=NKJP-PodkorpusMilionowy-1.2.tar.gz Annotated subcorpus to train NER model. Download and extract to $NERBASE/Polish-NKJP or leave the gzip in $NERBASE/polish/...
kk_kazNERD is a Kazakh dataset published in 2021 - https://github.com/IS2AI/KazNERD - https://arxiv.org/abs/2111.13419 KazNERD: Kazakh Named Entity Recognition Dataset Rustem Yeshpanov, Yerbolat Khassanov, Huseyin Atakan Varol - in $NERBASE, make a "kazakh" directory, then git clone the repo there mkdir -p $NERBASE/kazakh cd $NERBASE/kazakh git clone git@github.com:IS2AI/KazNERD.git - Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset kk_kazNERD Masakhane NER is a set of NER datasets for African languages - MasakhaNER: Named Entity Recognition for African Languages Adelani, David Ifeoluwa; Abbott, Jade; Neubig, Graham; D’souza, Daniel; Kreutzer, Julia; Lignos, Constantine; Palen-Michel, Chester; Buzaaba, Happy; Rijhwani, Shruti; Ruder, Sebastian; Mayhew, Stephen; Azime, Israel Abebe; Muhammad, Shamsuddeen H.; Emezue, Chris Chinenye; Nakatumba-Nabende, Joyce; Ogayo, Perez; Anuoluwapo, Aremu; Gitau, Catherine; Mbaye, Derguene; Alabi, Jesujoba; Yimam, Seid Muhie; Gwadabe, Tajuddeen Rabiu; Ezeani, Ignatius; Niyongabo, Rubungo Andre; Mukiibi, Jonathan; Otiende, Verrah; Orife, Iroro; David, Davis; Ngom, Samba; Adewumi, Tosin; Rayson, Paul; Adeyemi, Mofetoluwa; Muriuki, Gerald; Anebi, Emmanuel; Chukwuneke, Chiamaka; Odu, Nkiruka; Wairagala, Eric Peter; Oyerinde, Samuel; Siro, Clemencia; Bateesa, Tobius Saul; Oloyede, Temilola; Wambui, Yvonne; Akinode, Victor; Nabagereka, Deborah; Katusiime, Maurice; Awokoya, Ayodele; MBOUP, Mouhamadane; Gebreyohannes, Dibora; Tilaye, Henok; Nwaike, Kelechi; Wolde, Degaga; Faye, Abdoulaye; Sibanda, Blessing; Ahia, Orevaoghene; Dossou, Bonaventure F.
P.; Ogueji, Kelechi; DIOP, Thierno Ibrahima; Diallo, Abdoulaye; Akinfaderin, Adewale; Marengereke, Tendai; Osei, Salomey - https://github.com/masakhane-io/masakhane-ner - git clone the repo to $NERBASE - Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset lcode_masakhane - You can use the full language name, the 3 letter language code, or in the case of languages with a 2 letter language code, the 2 letter code for lcode. The tool will throw an error if the language is not supported in Masakhane. SiNER is a Sindhi NER dataset - https://aclanthology.org/2020.lrec-1.361/ SiNER: A Large Dataset for Sindhi Named Entity Recognition Wazir Ali, Junyu Lu, Zenglin Xu - It is available via a git repository https://github.com/AliWazir/SiNER-dataset As of Nov. 2022, there were a few changes to the dataset to update a couple of instances of broken tags & tokenization - Clone the repo to $NERBASE/sindhi mkdir $NERBASE/sindhi cd $NERBASE/sindhi git clone git@github.com:AliWazir/SiNER-dataset.git - Then, prepare the dataset with this script: python3 -m stanza.utils.datasets.ner.prepare_ner_dataset sd_siner en_sample is the toy dataset included with stanza-train https://github.com/stanfordnlp/stanza-train this is not meant for any kind of actual NER use ArmTDP-NER is an Armenian NER dataset - https://github.com/myavrum/ArmTDP-NER.git ArmTDP-NER: The corpus was developed by the ArmTDP team led by Marat M. Yavrumyan at Yerevan State University through a collaboration between "Armenia National SDG Innovation Lab" and "UC Berkeley's Armenian Linguists' network".
- in $NERBASE, make an "armenian" directory, then git clone the repo there mkdir -p $NERBASE/armenian cd $NERBASE/armenian git clone https://github.com/myavrum/ArmTDP-NER.git - Then run python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hy_armtdp en_conll03 is the classic 2003 4 class CoNLL dataset - The version we use is posted on HuggingFace - https://huggingface.co/datasets/conll2003 - The prepare script will download from HF using the datasets package, then convert to json - Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition Tjong Kim Sang, Erik F. and De Meulder, Fien - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conll03 en_conll03ww is CoNLL 03 with Worldwide added to the training data. - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conll03ww en_conllpp is a test set from 2020 newswire - https://arxiv.org/abs/2212.09747 - https://github.com/ShuhengL/acl2023_conllpp - Do CoNLL-2003 Named Entity Taggers Still Work Well in 2023?
Shuheng Liu, Alan Ritter - git clone the repo in $NERBASE - then run python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conllpp en_ontonotes is the OntoNotes 5 on HuggingFace - https://huggingface.co/datasets/conll2012_ontonotesv5 - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_ontonotes - this downloads the "v12" version of the data en_worldwide-4class is an English non-US newswire dataset - annotated by MLTwist and Aya Data, with help from Datasaur, collected at Stanford - work to be published at EMNLP Findings - the 4 class version is converted to the 4 classes in conll, then split into train/dev/test - clone https://github.com/stanfordnlp/en-worldwide-newswire into $NERBASE/en_worldwide en_worldwide-9class is an English non-US newswire dataset - annotated by MLTwist and Aya Data, with help from Datasaur, collected at Stanford - work to be published at EMNLP Findings - the 9 class version is not edited - clone https://github.com/stanfordnlp/en-worldwide-newswire into $NERBASE/en_worldwide zh-hans_ontonotes is the ZH split of the OntoNotes dataset - https://catalog.ldc.upenn.edu/LDC2013T19 - https://huggingface.co/datasets/conll2012_ontonotesv5 - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py zh-hans_ontonotes - this downloads the "v4" version of the data AQMAR is a small dataset of Arabic Wikipedia articles - http://www.cs.cmu.edu/~ark/ArabicNER/ - Recall-Oriented Learning of Named Entities in Arabic Wikipedia Behrang Mohit, Nathan Schneider, Rishav Bhowmick, Kemal Oflazer, and Noah A. Smith. In Proceedings of the 13th Conference of the European Chapter of the Association for Computational Linguistics, Avignon, France, April 2012. 
- download the .zip file there and put it in $NERBASE/arabic/AQMAR - there is a challenge for it here: https://www.topcoder.com/challenges/f3cf483e-a95c-4a7e-83e8-6bdd83174d38 - alternatively, we just randomly split it ourselves - currently, running the following reproduces the random split: python3 stanza/utils/datasets/ner/prepare_ner_dataset.py ar_aqmar IAHLT contains NER for Hebrew in the knesset treebank - as of UD 2.14, it is only in the git repo - download that git repo to $UDBASE_GIT: https://github.com/UniversalDependencies/UD_Hebrew-IAHLTknesset - change to the dev branch in that repo python3 stanza/utils/datasets/ner/prepare_ner_dataset.py he_iahlt ang_ewt is an Old English dataset available here: https://github.com/dmetola/Old_English-OEDT/tree/main As more information, including a citation, will be added here - install in NERBASE: mkdir $NERBASE/ang cd $NERBASE/ang git clone git@github.com:dmetola/Old_English-OEDT.git - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py ang_ewt """ import glob import os import json import random import re import shutil import sys import tempfile from stanza.models.common.constant import treebank_to_short_name, lcode2lang, lang_to_langcode, two_to_three_letters from stanza.models.ner.utils import to_bio2, bio2_to_bioes import stanza.utils.default_paths as default_paths from stanza.utils.datasets.common import UnknownDatasetError from stanza.utils.datasets.ner.preprocess_wikiner import preprocess_wikiner from stanza.utils.datasets.ner.split_wikiner import split_wikiner, split_wikiner_data import stanza.utils.datasets.ner.build_en_combined as build_en_combined import stanza.utils.datasets.ner.conll_to_iob as conll_to_iob import stanza.utils.datasets.ner.convert_ar_aqmar as convert_ar_aqmar import stanza.utils.datasets.ner.convert_bn_daffodil as convert_bn_daffodil import stanza.utils.datasets.ner.convert_bsf_to_beios as convert_bsf_to_beios import stanza.utils.datasets.ner.convert_bsnlp as convert_bsnlp import 
stanza.utils.datasets.ner.convert_en_conll03 as convert_en_conll03 import stanza.utils.datasets.ner.convert_fire_2013 as convert_fire_2013 import stanza.utils.datasets.ner.convert_he_iahlt as convert_he_iahlt import stanza.utils.datasets.ner.convert_ijc as convert_ijc import stanza.utils.datasets.ner.convert_kk_kazNERD as convert_kk_kazNERD import stanza.utils.datasets.ner.convert_lst20 as convert_lst20 import stanza.utils.datasets.ner.convert_nner22 as convert_nner22 import stanza.utils.datasets.ner.convert_mr_l3cube as convert_mr_l3cube import stanza.utils.datasets.ner.convert_my_ucsy as convert_my_ucsy import stanza.utils.datasets.ner.convert_ontonotes as convert_ontonotes import stanza.utils.datasets.ner.convert_rgai as convert_rgai import stanza.utils.datasets.ner.convert_nytk as convert_nytk import stanza.utils.datasets.ner.convert_starlang_ner as convert_starlang_ner import stanza.utils.datasets.ner.convert_nkjp as convert_nkjp import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file import stanza.utils.datasets.ner.convert_sindhi_siner as convert_sindhi_siner import stanza.utils.datasets.ner.ontonotes_multitag as ontonotes_multitag import stanza.utils.datasets.ner.simplify_en_worldwide as simplify_en_worldwide import stanza.utils.datasets.ner.suc_to_iob as suc_to_iob import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob import stanza.utils.datasets.ner.convert_hy_armtdp as convert_hy_armtdp from stanza.utils.datasets.ner.utils import convert_bioes_to_bio, convert_bio_to_json, get_tags, read_tsv, write_sentences, write_dataset, random_shuffle_by_prefixes, read_prefix_file, combine_files SHARDS = ('train', 'dev', 'test') def process_turku(paths, short_name): assert short_name == 'fi_turku' base_input_path = os.path.join(paths["NERBASE"], "finnish", "turku-ner-corpus", "data", "conll") base_output_path = paths["NER_DATA_DIR"] for shard in SHARDS: input_filename = os.path.join(base_input_path, '%s.tsv' % shard) if not 
os.path.exists(input_filename): raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename)) output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(input_filename, output_filename) def process_it_fbk(paths, short_name): assert short_name == "it_fbk" base_input_path = os.path.join(paths["NERBASE"], short_name) csv_file = os.path.join(base_input_path, "all-wiki-split.tsv") if not os.path.exists(csv_file): raise FileNotFoundError("Cannot find the FBK dataset in its expected location: {}".format(csv_file)) base_output_path = paths["NER_DATA_DIR"] split_wikiner(base_output_path, csv_file, prefix=short_name, suffix="io", shuffle=False, train_fraction=0.8, dev_fraction=0.1) convert_bio_to_json(base_output_path, base_output_path, short_name, suffix="io") def process_suralk_multiner(paths, short_name): lang_filenames = { "en": "Final_English.txt", "si": "Final_Sinhala.txt", "ta": "Final_Tamil.txt", } lang, ending = short_name.split("_") assert ending == "suralk" assert lang in lang_filenames, "suralk/multiNER only supports %s" % (", ".join(lang_filenames.keys())) suralk_path = os.path.join(paths["NERBASE"], "mixed", "multiNER", "nerannotateddatasets.zip") if not os.path.exists(suralk_path): raise FileNotFoundError("Expected to find the suralk/multiNER dataset in %s" % suralk_path) sentences = read_tsv(lang_filenames[lang], text_column=0, annotation_column=1, separator=None, zip_filename=suralk_path) print("Read %d sentences from %s::%s" % (len(sentences), suralk_path, lang_filenames[lang])) base_output_path = paths["NER_DATA_DIR"] split_wikiner_data(base_output_path, sentences, prefix=short_name, suffix="bio", shuffle=True) convert_bio_to_json(base_output_path, base_output_path, short_name, suffix="bio") def process_il_ner(paths, short_name): joiner = chr(0x200c) def fix_tag(tag): if tag == '-': return 'O' if tag.endswith("'"): # not sure the correct fix, but we filed an 
issue, so hopefully they fix it return "O" if tag.endswith("NIMI") or tag.endswith("NET"): return tag[:2] + "NETI" tag = tag.replace(joiner, "").upper() if tag.startswith("-"): return 'B%s' % tag return tag def fix_line(line): if line == 'O': return '-\tO' return line lang_paths = { "hi": "Hindi", "or": "Odia", "te": "Telugu", "ur": "Urdu", } lang, ending = short_name.split("_") assert ending == "ilner" assert lang in lang_paths, "IL-NER only supports %s" % (", ".join(lang_paths.keys())) ilner_path = os.path.join(paths["NERBASE"], "indic", "IL-NER") if not os.path.exists(ilner_path): raise FileNotFounderror("Cannot find the IL-NER dataset in its expected location: {}".format(ilner_path)) ilner_path = os.path.join(ilner_path, "Datasets", lang_paths[lang]) if not os.path.exists(ilner_path): raise FileNotFoundError("IL-NER not in the layout expected: directory not found {}".format(ilner_path)) filenames = os.listdir(ilner_path) base_output_path = paths["NER_DATA_DIR"] for shard in SHARDS: input_filenames = [x for x in filenames if shard in x] if len(input_filenames) == 0: raise FileNotFoundError("No %s file in %s" % (shard, ilner_path)) if len(input_filenames) > 1: raise FileNotFoundError("Unexpected multiple files for %s in %s: %s" % (shard, ilner_path, input_filenames)) input_filename = os.path.join(ilner_path, input_filenames[0]) int_filename = os.path.join(base_output_path, '%s.%s.tsv' % (short_name, shard)) output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) sentences = read_tsv(input_filename, text_column=0, annotation_column=1, remap_tag_fn=fix_tag, remap_line=fix_line) print("Loaded %d sentences from %s" % (len(sentences), input_filename)) write_sentences(int_filename, sentences) prepare_ner_file.process_dataset(int_filename, output_filename) def process_languk(paths, short_name): assert short_name == 'uk_languk' base_input_path = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'data') base_output_path = 
paths["NER_DATA_DIR"] train_test_split_fname = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'doc', 'dev-test-split.txt') convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path, train_test_split_file=train_test_split_fname) for shard in SHARDS: input_filename = os.path.join(base_output_path, convert_bsf_to_beios.CORPUS_NAME, "%s.bio" % shard) if not os.path.exists(input_filename): raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename)) output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(input_filename, output_filename) def process_ijc(paths, short_name): """ Splits the ijc Hindi dataset in train, dev, test The original data had train & test splits, so we randomly divide the files in train to make a dev set. The expected location of the IJC data is hi_ijc. This method should be possible to use for other languages, but we have very little support for the other languages of IJC at the moment. 
""" base_input_path = os.path.join(paths["NERBASE"], short_name) base_output_path = paths["NER_DATA_DIR"] test_files = [os.path.join(base_input_path, "test-data-hindi.txt")] test_csv_file = os.path.join(base_output_path, short_name + ".test.csv") print("Converting test input %s to space separated file in %s" % (test_files[0], test_csv_file)) convert_ijc.convert_ijc(test_files, test_csv_file) train_input_path = os.path.join(base_input_path, "training-hindi", "*utf8") train_files = glob.glob(train_input_path) train_csv_file = os.path.join(base_output_path, short_name + ".train.csv") dev_csv_file = os.path.join(base_output_path, short_name + ".dev.csv") print("Converting training input from %s to space separated files in %s and %s" % (train_input_path, train_csv_file, dev_csv_file)) convert_ijc.convert_split_ijc(train_files, train_csv_file, dev_csv_file) for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS): output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(csv_file, output_filename) def process_fire_2013(paths, dataset): """ Splits the FIRE 2013 dataset into train, dev, test The provided datasets are all mixed together at this point, so it is not possible to recreate the original test conditions used in the bakeoff """ short_name = treebank_to_short_name(dataset) langcode, _ = short_name.split("_") short_name = "%s_fire2013" % langcode if not langcode in ("hi", "en", "ta", "bn", "mal"): raise UnkonwnDatasetError(dataset, "Language %s not one of the FIRE 2013 languages" % langcode) language = lcode2lang[langcode].lower() # for example, FIRE2013/hindi_train base_input_path = os.path.join(paths["NERBASE"], "FIRE2013", "%s_train" % language) base_output_path = paths["NER_DATA_DIR"] train_csv_file = os.path.join(base_output_path, "%s.train.csv" % short_name) dev_csv_file = os.path.join(base_output_path, "%s.dev.csv" % short_name) test_csv_file = os.path.join(base_output_path, 
"%s.test.csv" % short_name) convert_fire_2013.convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file) for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS): output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(csv_file, output_filename) def process_wikiner(paths, dataset): short_name = treebank_to_short_name(dataset) base_input_path = os.path.join(paths["NERBASE"], dataset) base_output_path = paths["NER_DATA_DIR"] expected_filename = "aij*wikiner*" input_files = [x for x in glob.glob(os.path.join(base_input_path, expected_filename)) if not x.endswith("bz2")] if len(input_files) == 0: raw_input_path = os.path.join(base_input_path, "raw") input_files = [x for x in glob.glob(os.path.join(raw_input_path, expected_filename)) if not x.endswith("bz2")] if len(input_files) > 1: raise FileNotFoundError("Found too many raw wikiner files in %s: %s" % (raw_input_path, ", ".join(input_files))) elif len(input_files) > 1: raise FileNotFoundError("Found too many raw wikiner files in %s: %s" % (base_input_path, ", ".join(input_files))) if len(input_files) == 0: raise FileNotFoundError("Could not find any raw wikiner files in %s or %s" % (base_input_path, raw_input_path)) csv_file = os.path.join(base_output_path, short_name + "_csv") print("Converting raw input %s to space separated file in %s" % (input_files[0], csv_file)) try: preprocess_wikiner(input_files[0], csv_file) except UnicodeDecodeError: preprocess_wikiner(input_files[0], csv_file, encoding="iso8859-1") # this should create train.bio, dev.bio, and test.bio print("Splitting %s to %s" % (csv_file, base_output_path)) split_wikiner(base_output_path, csv_file, prefix=short_name) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_french_wikiner_gold(paths, dataset): short_name = treebank_to_short_name(dataset) base_input_path = os.path.join(paths["NERBASE"], "wikiner-fr-gold") 
base_output_path = paths["NER_DATA_DIR"] input_filename = os.path.join(base_input_path, "wikiner-fr-gold.conll") if not os.path.exists(input_filename): raise FileNotFoundError("Could not find the expected input file %s for dataset %s" % (input_filename, base_input_path)) print("Reading %s" % input_filename) sentences = read_tsv(input_filename, text_column=0, annotation_column=2, separator=" ") print("Read %d sentences" % len(sentences)) tags = [y for sentence in sentences for x, y in sentence] tags = sorted(set(tags)) print("Found the following tags:\n%s" % tags) expected_tags = ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'E-LOC', 'E-MISC', 'E-ORG', 'E-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O', 'S-LOC', 'S-MISC', 'S-ORG', 'S-PER'] assert tags == expected_tags output_filename = os.path.join(base_output_path, "%s.full.bioes" % short_name) print("Writing BIOES to %s" % output_filename) write_sentences(output_filename, sentences) print("Splitting %s to %s" % (output_filename, base_output_path)) split_wikiner(base_output_path, output_filename, prefix=short_name, suffix="bioes") convert_bioes_to_bio(base_output_path, base_output_path, short_name) convert_bio_to_json(base_output_path, base_output_path, short_name, suffix="bioes") def process_french_wikiner_mixed(paths, dataset): """ Build both the original and gold edited versions of WikiNER, then mix them First we eliminate any duplicates (with one exception), then we combine the data There are two main ways we could have done this: - mix it together without any restrictions - use the multi_ner mechanism to build a dataset which represents two prediction heads The second method seems to give slightly better results than the first method, but neither beat just using a transformer on the gold set alone On the randomly selected test set, using WV and charlm but not a transformer (this was on a previously published version of the dataset): one prediction head: INFO: Score by entity: Prec. Rec. 
F1 89.32 89.26 89.29 INFO: Score by token: Prec. Rec. F1 89.43 86.88 88.14 INFO: Weighted f1 for non-O tokens: 0.878855 two prediction heads: INFO: Score by entity: Prec. Rec. F1 89.83 89.76 89.79 INFO: Score by token: Prec. Rec. F1 89.17 88.15 88.66 INFO: Weighted f1 for non-O tokens: 0.885675 On a randomly selected dev set, using transformer: gold: INFO: Score by entity: Prec. Rec. F1 93.63 93.98 93.81 INFO: Score by token: Prec. Rec. F1 92.80 92.79 92.80 INFO: Weighted f1 for non-O tokens: 0.927548 mixed: INFO: Score by entity: Prec. Rec. F1 93.54 93.82 93.68 INFO: Score by token: Prec. Rec. F1 92.99 92.51 92.75 INFO: Weighted f1 for non-O tokens: 0.926964 """ short_name = treebank_to_short_name(dataset) process_french_wikiner_gold(paths, "fr_wikinergold") process_wikiner(paths, "French-WikiNER") base_output_path = paths["NER_DATA_DIR"] with open(os.path.join(base_output_path, "fr_wikinergold.train.json")) as fin: gold_train = json.load(fin) with open(os.path.join(base_output_path, "fr_wikinergold.dev.json")) as fin: gold_dev = json.load(fin) with open(os.path.join(base_output_path, "fr_wikinergold.test.json")) as fin: gold_test = json.load(fin) gold = gold_train + gold_dev + gold_test print("%d total sentences in the gold relabeled dataset (randomly split)" % len(gold)) gold = {tuple([x["text"] for x in sentence]): sentence for sentence in gold} print(" (%d after dedup)" % len(gold)) original = (read_tsv(os.path.join(base_output_path, "fr_wikiner.train.bio"), text_column=0, annotation_column=1) + read_tsv(os.path.join(base_output_path, "fr_wikiner.dev.bio"), text_column=0, annotation_column=1) + read_tsv(os.path.join(base_output_path, "fr_wikiner.test.bio"), text_column=0, annotation_column=1)) print("%d total sentences in the original wiki" % len(original)) original_words = {tuple([x[0] for x in sentence]) for sentence in original} print(" (%d after dedup)" % len(original_words)) missing = [sentence for sentence in gold if sentence not in original_words] for 
sentence in missing: # the capitalization of WisiGoths and OstroGoths is different # between the original and the new in some cases goths = tuple([x.replace("Goth", "goth") for x in sentence]) if goths != sentence and goths in original_words: original_words.add(sentence) missing = [sentence for sentence in gold if sentence not in original_words] # currently this dataset doesn't find two sentences # one was dropped by the filter for incompletely tagged lines # the other is probably not a huge deal to have one duplicate print("Missing %d sentences" % len(missing)) assert len(missing) <= 2 for sent in missing: print(sent) skipped = 0 silver = [] silver_used = set() for sentence in original: words = tuple([x[0] for x in sentence]) tags = tuple([x[1] for x in sentence]) if words in gold or words in silver_used: skipped += 1 continue tags = to_bio2(tags) tags = bio2_to_bioes(tags) sentence = [{"text": x, "ner": y, "multi_ner": ["-", y]} for x, y in zip(words, tags)] silver.append(sentence) silver_used.add(words) print("Using %d sentences from the original wikiner alongside the gold annotated train set" % len(silver)) print("Skipped %d sentences" % skipped) gold_train = [[{"text": x["text"], "ner": x["ner"], "multi_ner": [x["ner"], "-"]} for x in sentence] for sentence in gold_train] gold_dev = [[{"text": x["text"], "ner": x["ner"], "multi_ner": [x["ner"], "-"]} for x in sentence] for sentence in gold_dev] gold_test = [[{"text": x["text"], "ner": x["ner"], "multi_ner": [x["ner"], "-"]} for x in sentence] for sentence in gold_test] mixed_train = gold_train + silver print("Total sentences in the mixed training set: %d" % len(mixed_train)) output_filename = os.path.join(base_output_path, "%s.train.json" % short_name) with open(output_filename, 'w', encoding='utf-8') as fout: json.dump(mixed_train, fout, indent=1) output_filename = os.path.join(base_output_path, "%s.dev.json" % short_name) with open(output_filename, 'w', encoding='utf-8') as fout: json.dump(gold_dev, fout, 
indent=1) output_filename = os.path.join(base_output_path, "%s.test.json" % short_name) with open(output_filename, 'w', encoding='utf-8') as fout: json.dump(gold_test, fout, indent=1) def get_rgai_input_path(paths): return os.path.join(paths["NERBASE"], "hu_rgai") def process_rgai(paths, short_name): base_output_path = paths["NER_DATA_DIR"] base_input_path = get_rgai_input_path(paths) if short_name == 'hu_rgai': use_business = True use_criminal = True elif short_name == 'hu_rgai_business': use_business = True use_criminal = False elif short_name == 'hu_rgai_criminal': use_business = False use_criminal = True else: raise UnknownDatasetError(short_name, "Unknown subset of hu_rgai data: %s" % short_name) convert_rgai.convert_rgai(base_input_path, base_output_path, short_name, use_business, use_criminal) convert_bio_to_json(base_output_path, base_output_path, short_name) def get_nytk_input_path(paths): return os.path.join(paths["NERBASE"], "NYTK-NerKor") def process_nytk(paths, short_name): """ Process the NYTK dataset """ assert short_name == "hu_nytk" base_output_path = paths["NER_DATA_DIR"] base_input_path = get_nytk_input_path(paths) convert_nytk.convert_nytk(base_input_path, base_output_path, short_name) convert_bio_to_json(base_output_path, base_output_path, short_name) def concat_files(output_file, *input_files): input_lines = [] for input_file in input_files: with open(input_file) as fin: lines = fin.readlines() if not len(lines): raise ValueError("Empty input file: %s" % input_file) if not lines[-1]: lines[-1] = "\n" elif lines[-1].strip(): lines.append("\n") input_lines.append(lines) with open(output_file, "w") as fout: for lines in input_lines: for line in lines: fout.write(line) def process_hu_combined(paths, short_name): assert short_name == "hu_combined" base_output_path = paths["NER_DATA_DIR"] rgai_input_path = get_rgai_input_path(paths) nytk_input_path = get_nytk_input_path(paths) with tempfile.TemporaryDirectory() as tmp_output_path: 
convert_rgai.convert_rgai(rgai_input_path, tmp_output_path, "hu_rgai", True, True) convert_nytk.convert_nytk(nytk_input_path, tmp_output_path, "hu_nytk") for shard in SHARDS: rgai_input = os.path.join(tmp_output_path, "hu_rgai.%s.bio" % shard) nytk_input = os.path.join(tmp_output_path, "hu_nytk.%s.bio" % shard) output_file = os.path.join(base_output_path, "hu_combined.%s.bio" % shard) concat_files(output_file, rgai_input, nytk_input) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_bsnlp(paths, short_name): """ Process files downloaded from http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html If you download the training and test data zip files and unzip them without rearranging in any way, the layout is somewhat weird. Training data goes into a specific subdirectory, but the test data goes into the top level directory. """ base_input_path = os.path.join(paths["NERBASE"], "bsnlp2019") base_train_path = os.path.join(base_input_path, "training_pl_cs_ru_bg_rc1") base_test_path = base_input_path base_output_path = paths["NER_DATA_DIR"] output_train_filename = os.path.join(base_output_path, "%s.train.csv" % short_name) output_dev_filename = os.path.join(base_output_path, "%s.dev.csv" % short_name) output_test_filename = os.path.join(base_output_path, "%s.test.csv" % short_name) language = short_name.split("_")[0] convert_bsnlp.convert_bsnlp(language, base_test_path, output_test_filename) convert_bsnlp.convert_bsnlp(language, base_train_path, output_train_filename, output_dev_filename) for shard, csv_file in zip(SHARDS, (output_train_filename, output_dev_filename, output_test_filename)): output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(csv_file, output_filename) NCHLT_LANGUAGE_MAP = { "af": "NCHLT Afrikaans Named Entity Annotated Corpus", # none of the following have UD datasets as of 2.8. 
Until they # exist, we assume the language codes NCHTL are sufficient "nr": "NCHLT isiNdebele Named Entity Annotated Corpus", "nso": "NCHLT Sepedi Named Entity Annotated Corpus", "ss": "NCHLT Siswati Named Entity Annotated Corpus", "st": "NCHLT Sesotho Named Entity Annotated Corpus", "tn": "NCHLT Setswana Named Entity Annotated Corpus", "ts": "NCHLT Xitsonga Named Entity Annotated Corpus", "ve": "NCHLT Tshivenda Named Entity Annotated Corpus", "xh": "NCHLT isiXhosa Named Entity Annotated Corpus", "zu": "NCHLT isiZulu Named Entity Annotated Corpus", } def process_nchlt(paths, short_name): language = short_name.split("_")[0] if not language in NCHLT_LANGUAGE_MAP: raise UnknownDatasetError(short_name, "Language %s not part of NCHLT" % language) short_name = "%s_nchlt" % language base_input_path = os.path.join(paths["NERBASE"], "NCHLT", NCHLT_LANGUAGE_MAP[language], "*Full.txt") input_files = glob.glob(base_input_path) if len(input_files) == 0: raise FileNotFoundError("Cannot find NCHLT dataset in '%s' Did you remember to download the file?" % base_input_path) if len(input_files) > 1: raise ValueError("Unexpected number of files matched '%s' There should only be one" % base_input_path) base_output_path = paths["NER_DATA_DIR"] split_wikiner(base_output_path, input_files[0], prefix=short_name, remap={"OUT": "O"}) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_my_ucsy(paths, short_name): assert short_name == "my_ucsy" language = "my" base_input_path = os.path.join(paths["NERBASE"], short_name) base_output_path = paths["NER_DATA_DIR"] convert_my_ucsy.convert_my_ucsy(base_input_path, base_output_path) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_fa_arman(paths, short_name): """ Converts fa_arman dataset The conversion is quite simple, actually. 
Just need to split the train file and then convert bio -> json """ assert short_name == "fa_arman" language = "fa" base_input_path = os.path.join(paths["NERBASE"], "PersianNER") train_input_file = os.path.join(base_input_path, "train_fold1.txt") test_input_file = os.path.join(base_input_path, "test_fold1.txt") if not os.path.exists(train_input_file) or not os.path.exists(test_input_file): full_corpus_file = os.path.join(base_input_path, "ArmanPersoNERCorpus.zip") if os.path.exists(full_corpus_file): raise FileNotFoundError("Please unzip the file {}".format(full_corpus_file)) raise FileNotFoundError("Cannot find the arman corpus in the expected directory: {}".format(base_input_path)) base_output_path = paths["NER_DATA_DIR"] test_output_file = os.path.join(base_output_path, "%s.test.bio" % short_name) split_wikiner(base_output_path, train_input_file, prefix=short_name, train_fraction=0.8, test_section=False) shutil.copy2(test_input_file, test_output_file) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_sv_suc3licensed(paths, short_name): """ The .zip provided for SUC3 includes train/dev/test splits already This extracts those splits without needing to unzip the original file """ assert short_name == "sv_suc3licensed" language = "sv" train_input_file = os.path.join(paths["NERBASE"], short_name, "SUC3.0.zip") if not os.path.exists(train_input_file): raise FileNotFoundError("Cannot find the officially licensed SUC3 dataset in %s" % train_input_file) base_output_path = paths["NER_DATA_DIR"] suc_conll_to_iob.process_suc3(train_input_file, short_name, base_output_path) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_sv_suc3shuffle(paths, short_name): """ Uses an externally provided script to read the SUC3 XML file, then splits it """ assert short_name == "sv_suc3shuffle" language = "sv" train_input_file = os.path.join(paths["NERBASE"], short_name, "suc3.xml.bz2") if not os.path.exists(train_input_file): 
train_input_file = train_input_file[:-4] if not os.path.exists(train_input_file): raise FileNotFoundError("Unable to find the SUC3 dataset in {}.bz2".format(train_input_file)) base_output_path = paths["NER_DATA_DIR"] train_output_file = os.path.join(base_output_path, "sv_suc3shuffle.bio") suc_to_iob.main([train_input_file, train_output_file]) split_wikiner(base_output_path, train_output_file, prefix=short_name) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_da_ddt(paths, short_name): """ Processes Danish DDT dataset This dataset is in a conll file with the "name" attribute in the misc column for the NER tag. This function uses a script to convert such CoNLL files to .bio """ assert short_name == "da_ddt" language = "da" IN_FILES = ("ddt.train.conllu", "ddt.dev.conllu", "ddt.test.conllu") base_output_path = paths["NER_DATA_DIR"] OUT_FILES = [os.path.join(base_output_path, "%s.%s.bio" % (short_name, shard)) for shard in SHARDS] zip_file = os.path.join(paths["NERBASE"], "da_ddt", "ddt.zip") if os.path.exists(zip_file): for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS): conll_to_iob.process_conll(in_filename, out_filename, zip_file) else: for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS): in_filename = os.path.join(paths["NERBASE"], "da_ddt", in_filename) if not os.path.exists(in_filename): raise FileNotFoundError("Could not find zip in expected location %s and could not file %s file in %s" % (zip_file, shard, in_filename)) conll_to_iob.process_conll(in_filename, out_filename) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_norne(paths, short_name): """ Processes Norwegian NorNE Can handle either Bokmål or Nynorsk Converts GPE_LOC and GPE_ORG to GPE """ language, name = short_name.split("_", 1) assert language in ('nb', 'nn') assert name == 'norne' if language == 'nb': IN_FILES = ("nob/no_bokmaal-ud-train.conllu", "nob/no_bokmaal-ud-dev.conllu", 
"nob/no_bokmaal-ud-test.conllu") else: IN_FILES = ("nno/no_nynorsk-ud-train.conllu", "nno/no_nynorsk-ud-dev.conllu", "nno/no_nynorsk-ud-test.conllu") base_output_path = paths["NER_DATA_DIR"] OUT_FILES = [os.path.join(base_output_path, "%s.%s.bio" % (short_name, shard)) for shard in SHARDS] CONVERSION = { "GPE_LOC": "GPE", "GPE_ORG": "GPE" } for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS): in_filename = os.path.join(paths["NERBASE"], "norne", "ud", in_filename) if not os.path.exists(in_filename): raise FileNotFoundError("Could not find %s file in %s" % (shard, in_filename)) conll_to_iob.process_conll(in_filename, out_filename, conversion=CONVERSION) convert_bio_to_json(base_output_path, base_output_path, short_name) def process_ja_gsd(paths, short_name): """ Convert ja_gsd from MegagonLabs for example, can download from https://github.com/megagonlabs/UD_Japanese-GSD/releases/tag/r2.9-NE """ language, name = short_name.split("_", 1) assert language == 'ja' assert name == 'gsd' base_output_path = paths["NER_DATA_DIR"] output_files = [os.path.join(base_output_path, "%s.%s.bio" % (short_name, shard)) for shard in SHARDS] search_path = os.path.join(paths["NERBASE"], "ja_gsd", "UD_Japanese-GSD-r2.*-NE") versions = glob.glob(search_path) max_version = None base_input_path = None version_re = re.compile("GSD-r2.([0-9]+)-NE$") for ver in versions: match = version_re.search(ver) if not match: continue ver_num = int(match.groups(1)[0]) if max_version is None or ver_num > max_version: max_version = ver_num base_input_path = ver if base_input_path is None: raise FileNotFoundError("Could not find any copies of the NE conversion of ja_gsd here: {}".format(search_path)) print("Most recent version found: {}".format(base_input_path)) input_files = ["ja_gsd-ud-train.ne.conllu", "ja_gsd-ud-dev.ne.conllu", "ja_gsd-ud-test.ne.conllu"] def conversion(x): if x[0] == 'L': return 'E' + x[1:] if x[0] == 'U': return 'S' + x[1:] # B, I unchanged return x for 
in_filename, out_filename, shard in zip(input_files, output_files, SHARDS): in_path = os.path.join(base_input_path, in_filename) if not os.path.exists(in_path): in_spacy = os.path.join(base_input_path, "spacy", in_filename) if not os.path.exists(in_spacy): raise FileNotFoundError("Could not find %s file in %s or %s" % (shard, in_path, in_spacy)) in_path = in_spacy conll_to_iob.process_conll(in_path, out_filename, conversion=conversion, allow_empty=True, attr_prefix="NE") convert_bio_to_json(base_output_path, base_output_path, short_name) def process_starlang(paths, short_name): """ Process a Turkish dataset from Starlang """ assert short_name == 'tr_starlang' PIECES = ["TurkishAnnotatedTreeBank-15", "TurkishAnnotatedTreeBank2-15", "TurkishAnnotatedTreeBank2-20"] chunk_paths = [os.path.join(paths["CONSTITUENCY_BASE"], "turkish", piece) for piece in PIECES] datasets = convert_starlang_ner.read_starlang(chunk_paths) write_dataset(datasets, paths["NER_DATA_DIR"], short_name) def remap_germeval_tag(tag): """ Simplify tags for GermEval2014 using a simple rubric all tags become their parent tag OTH becomes MISC """ if tag == "O": return tag if tag[1:5] == "-LOC": return tag[:5] if tag[1:5] == "-PER": return tag[:5] if tag[1:5] == "-ORG": return tag[:5] if tag[1:5] == "-OTH": return tag[0] + "-MISC" raise ValueError("Unexpected tag: %s" % tag) def process_de_germeval2014(paths, short_name): """ Process the TSV of the GermEval2014 dataset """ in_directory = os.path.join(paths["NERBASE"], "germeval2014") base_output_path = paths["NER_DATA_DIR"] datasets = [] for shard in SHARDS: in_file = os.path.join(in_directory, "NER-de-%s.tsv" % shard) sentences = read_tsv(in_file, 1, 2, remap_tag_fn=remap_germeval_tag) datasets.append(sentences) tags = get_tags(datasets) print("Found the following tags: {}".format(sorted(tags))) write_dataset(datasets, base_output_path, short_name) def process_hiner(paths, short_name): in_directory = os.path.join(paths["NERBASE"], "hindi", "HiNER", 
"data", "original") convert_bio_to_json(in_directory, paths["NER_DATA_DIR"], short_name, suffix="conll", shard_names=("train", "validation", "test")) def process_hinercollapsed(paths, short_name): in_directory = os.path.join(paths["NERBASE"], "hindi", "HiNER", "data", "collapsed") convert_bio_to_json(in_directory, paths["NER_DATA_DIR"], short_name, suffix="conll", shard_names=("train", "validation", "test")) def process_lst20(paths, short_name, include_space_char=True): convert_lst20.convert_lst20(paths, short_name, include_space_char) def process_nner22(paths, short_name, include_space_char=True): convert_nner22.convert_nner22(paths, short_name, include_space_char) def process_mr_l3cube(paths, short_name): base_output_path = paths["NER_DATA_DIR"] in_directory = os.path.join(paths["NERBASE"], "marathi", "MarathiNLP", "L3Cube-MahaNER", "IOB") input_files = ["train_iob.txt", "valid_iob.txt", "test_iob.txt"] input_files = [os.path.join(in_directory, x) for x in input_files] for input_file in input_files: if not os.path.exists(input_file): raise FileNotFoundError("Could not find the expected piece of the l3cube dataset %s" % input_file) datasets = [convert_mr_l3cube.convert(input_file) for input_file in input_files] write_dataset(datasets, base_output_path, short_name) def process_bn_daffodil(paths, short_name): in_directory = os.path.join(paths["NERBASE"], "bangla", "Bengali-NER") out_directory = paths["NER_DATA_DIR"] convert_bn_daffodil.convert_dataset(in_directory, out_directory) def process_pl_nkjp(paths, short_name): out_directory = paths["NER_DATA_DIR"] candidates = [os.path.join(paths["NERBASE"], "Polish-NKJP"), os.path.join(paths["NERBASE"], "polish", "Polish-NKJP"), os.path.join(paths["NERBASE"], "polish", "NKJP-PodkorpusMilionowy-1.2.tar.gz"),] for in_path in candidates: if os.path.exists(in_path): break else: raise FileNotFoundError("Could not find %s Looked in %s" % (short_name, " ".join(candidates))) convert_nkjp.convert_nkjp(in_path, out_directory) def 
process_kk_kazNERD(paths, short_name): in_directory = os.path.join(paths["NERBASE"], "kazakh", "KazNERD", "KazNERD") out_directory = paths["NER_DATA_DIR"] convert_kk_kazNERD.convert_dataset(in_directory, out_directory, short_name) def process_masakhane(paths, dataset_name): """ Converts Masakhane NER datasets to Stanza's .json format If we let N be the length of the first sentence, the NER files (in version 2, at least) are all of the form word tag ... word tag (blank line for sentence break) word tag ... Once the dataset is git cloned in $NERBASE, the directory structure is $NERBASE/masakhane-ner/MasakhaNER2.0/data/$lcode/{train,dev,test}.txt The only tricky thing here is that for some languages, we treat the 2 letter lcode as canonical thanks to UD, but Masakhane NER uses 3 letter lcodes for all languages. """ language, dataset = dataset_name.split("_") lcode = lang_to_langcode(language) if lcode in two_to_three_letters: masakhane_lcode = two_to_three_letters[lcode] else: masakhane_lcode = lcode mn_directory = os.path.join(paths["NERBASE"], "masakhane-ner") if not os.path.exists(mn_directory): raise FileNotFoundError("Cannot find Masakhane NER repo. 
Please check the setting of NERBASE or clone the repo to %s" % mn_directory) data_directory = os.path.join(mn_directory, "MasakhaNER2.0", "data") if not os.path.exists(data_directory): raise FileNotFoundError("Apparently found the repo at %s but the expected directory structure is not there - was looking for %s" % (mn_directory, data_directory)) in_directory = os.path.join(data_directory, masakhane_lcode) if not os.path.exists(in_directory): raise UnknownDatasetError(dataset_name, "Found the Masakhane repo, but there was no %s in the repo at path %s" % (dataset_name, in_directory)) convert_bio_to_json(in_directory, paths["NER_DATA_DIR"], "%s_masakhane" % lcode, "txt") def process_sd_siner(paths, short_name): in_directory = os.path.join(paths["NERBASE"], "sindhi", "SiNER-dataset") if not os.path.exists(in_directory): raise FileNotFoundError("Cannot find SiNER checkout in $NERBASE/sindhi Please git clone to repo in that directory") in_filename = os.path.join(in_directory, "SiNER-dataset.txt") if not os.path.exists(in_filename): in_filename = os.path.join(in_directory, "SiNER dataset.txt") if not os.path.exists(in_filename): raise FileNotFoundError("Found an SiNER directory at %s but the directory did not contain the dataset" % in_directory) convert_sindhi_siner.convert_sindhi_siner(in_filename, paths["NER_DATA_DIR"], short_name) def process_en_worldwide_4class(paths, short_name): simplify_en_worldwide.main(args=['--simplify']) in_directory = os.path.join(paths["NERBASE"], "en_worldwide", "4class") out_directory = paths["NER_DATA_DIR"] destination_file = os.path.join(paths["NERBASE"], "en_worldwide", "en-worldwide-newswire", "regions.txt") prefix_map = read_prefix_file(destination_file) random_shuffle_by_prefixes(in_directory, out_directory, short_name, prefix_map) def process_en_worldwide_9class(paths, short_name): simplify_en_worldwide.main(args=['--no_simplify']) in_directory = os.path.join(paths["NERBASE"], "en_worldwide", "9class") out_directory = 
paths["NER_DATA_DIR"] destination_file = os.path.join(paths["NERBASE"], "en_worldwide", "en-worldwide-newswire", "regions.txt") prefix_map = read_prefix_file(destination_file) random_shuffle_by_prefixes(in_directory, out_directory, short_name, prefix_map) def process_en_ontonotes(paths, short_name): ner_input_path = paths['NERBASE'] ontonotes_path = os.path.join(ner_input_path, "english", "en_ontonotes") ner_output_path = paths['NER_DATA_DIR'] convert_ontonotes.process_dataset("en_ontonotes", ontonotes_path, ner_output_path) def process_zh_ontonotes(paths, short_name): ner_input_path = paths['NERBASE'] ontonotes_path = os.path.join(ner_input_path, "chinese", "zh_ontonotes") ner_output_path = paths['NER_DATA_DIR'] convert_ontonotes.process_dataset(short_name, ontonotes_path, ner_output_path) def process_en_conll03(paths, short_name): ner_input_path = paths['NERBASE'] conll_path = os.path.join(ner_input_path, "english", "en_conll03") ner_output_path = paths['NER_DATA_DIR'] convert_en_conll03.process_dataset("en_conll03", conll_path, ner_output_path) def process_en_conll03_worldwide(paths, short_name): """ Adds the training data for conll03 and worldwide together """ print("============== Preparing CoNLL 2003 ===================") process_en_conll03(paths, "en_conll03") print("========== Preparing 4 Class Worldwide ================") process_en_worldwide_4class(paths, "en_worldwide-4class") print("============== Combined Train Data ====================") input_files = [os.path.join(paths['NER_DATA_DIR'], "en_conll03.train.json"), os.path.join(paths['NER_DATA_DIR'], "en_worldwide-4class.train.json")] output_file = os.path.join(paths['NER_DATA_DIR'], "%s.train.json" % short_name) combine_files(output_file, *input_files) shutil.copyfile(os.path.join(paths['NER_DATA_DIR'], "en_conll03.dev.json"), os.path.join(paths['NER_DATA_DIR'], "%s.dev.json" % short_name)) shutil.copyfile(os.path.join(paths['NER_DATA_DIR'], "en_conll03.test.json"), os.path.join(paths['NER_DATA_DIR'], 
"%s.test.json" % short_name)) def process_en_ontonotes_ww_multi(paths, short_name): """ Combine the worldwide data with the OntoNotes data in a multi channel format """ print("=============== Preparing OntoNotes ===============") process_en_ontonotes(paths, "en_ontonotes") print("========== Preparing 9 Class Worldwide ================") process_en_worldwide_9class(paths, "en_worldwide-9class") # TODO: pass in options? ontonotes_multitag.build_multitag_dataset(paths['NER_DATA_DIR'], short_name, True, True) def process_en_combined(paths, short_name): """ Combine WW, OntoNotes, and CoNLL into a 3 channel dataset """ print("================= Preparing OntoNotes =================") process_en_ontonotes(paths, "en_ontonotes") print("========== Preparing 9 Class Worldwide ================") process_en_worldwide_9class(paths, "en_worldwide-9class") print("=============== Preparing CoNLL 03 ====================") process_en_conll03(paths, "en_conll03") build_en_combined.build_combined_dataset(paths['NER_DATA_DIR'], short_name) def process_en_conllpp(paths, short_name): """ This is ONLY a test set the test set has entities start with I- instead of B- unless they are in the middle of a sentence, but that should be find, as process_tags in the NER model converts those to B- in a BIOES conversion """ base_input_path = os.path.join(paths["NERBASE"], "acl2023_conllpp", "dataset", "conllpp.txt") base_output_path = paths["NER_DATA_DIR"] sentences = read_tsv(base_input_path, 0, 3, separator=None) sentences = [sent for sent in sentences if len(sent) > 1 or sent[0][0] != '-DOCSTART-'] write_dataset([sentences], base_output_path, short_name, shard_names=["test"], shards=["test"]) def process_armtdp(paths, short_name): assert short_name == 'hy_armtdp' base_input_path = os.path.join(paths["NERBASE"], "armenian", "ArmTDP-NER") base_output_path = paths["NER_DATA_DIR"] convert_hy_armtdp.convert_dataset(base_input_path, base_output_path, short_name) for shard in SHARDS: input_filename = 
os.path.join(base_output_path, f'{short_name}.{shard}.tsv') if not os.path.exists(input_filename): raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename)) output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(input_filename, output_filename) def process_toy_dataset(paths, short_name): convert_bio_to_json(os.path.join(paths["NERBASE"], "English-SAMPLE"), paths["NER_DATA_DIR"], short_name) def process_ar_aqmar(paths, short_name): base_input_path = os.path.join(paths["NERBASE"], "arabic", "AQMAR", "AQMAR_Arabic_NER_corpus-1.0.zip") base_output_path = paths["NER_DATA_DIR"] convert_ar_aqmar.convert_shuffle(base_input_path, base_output_path, short_name) def process_he_iahlt(paths, short_name): assert short_name == 'he_iahlt' # for now, need to use UDBASE_GIT until IAHLTknesset is added to UD udbase = paths["UDBASE_GIT"] base_output_path = paths["NER_DATA_DIR"] convert_he_iahlt.convert_iahlt(udbase, base_output_path, "he_iahlt") def process_ang_ewt(paths, short_name): assert short_name == 'ang_ewt' base_input_path = os.path.join(paths["NERBASE"], "ang", "Old_English-OEDT") convert_bio_to_json(base_input_path, paths["NER_DATA_DIR"], short_name) DATASET_MAPPING = { "ang_ewt": process_ang_ewt, "ar_aqmar": process_ar_aqmar, "bn_daffodil": process_bn_daffodil, "da_ddt": process_da_ddt, "de_germeval2014": process_de_germeval2014, "en_conll03": process_en_conll03, "en_conll03ww": process_en_conll03_worldwide, "en_conllpp": process_en_conllpp, "en_ontonotes": process_en_ontonotes, "en_ontonotes-ww-multi": process_en_ontonotes_ww_multi, "en_combined": process_en_combined, "en_worldwide-4class": process_en_worldwide_4class, "en_worldwide-9class": process_en_worldwide_9class, "fa_arman": process_fa_arman, "fi_turku": process_turku, "fr_wikinergold": process_french_wikiner_gold, "fr_wikinermixed": process_french_wikiner_mixed, "hi_hiner": process_hiner, "hi_hinercollapsed": 
process_hinercollapsed, "hi_ijc": process_ijc, "he_iahlt": process_he_iahlt, "hu_nytk": process_nytk, "hu_combined": process_hu_combined, "hy_armtdp": process_armtdp, "it_fbk": process_it_fbk, "ja_gsd": process_ja_gsd, "kk_kazNERD": process_kk_kazNERD, "mr_l3cube": process_mr_l3cube, "my_ucsy": process_my_ucsy, "pl_nkjp": process_pl_nkjp, "sd_siner": process_sd_siner, "sv_suc3licensed": process_sv_suc3licensed, "sv_suc3shuffle": process_sv_suc3shuffle, "tr_starlang": process_starlang, "th_lst20": process_lst20, "th_nner22": process_nner22, "zh-hans_ontonotes": process_zh_ontonotes, } SUFFIX_MAPPING = { "_ilner": process_il_ner, "_suralk": process_suralk_multiner, } def main(dataset_name): paths = default_paths.get_default_paths() print("Processing %s" % dataset_name) random.seed(1234) if dataset_name in DATASET_MAPPING: DATASET_MAPPING[dataset_name](paths, dataset_name) elif dataset_name in ('uk_languk', 'Ukranian_languk', 'Ukranian-languk'): process_languk(paths, 'uk_languk') elif dataset_name.endswith("FIRE2013") or dataset_name.endswith("fire2013"): process_fire_2013(paths, dataset_name) elif dataset_name.endswith('WikiNER'): process_wikiner(paths, dataset_name) elif dataset_name.startswith('hu_rgai'): process_rgai(paths, dataset_name) elif dataset_name.endswith("_bsnlp19"): process_bsnlp(paths, dataset_name) elif dataset_name.endswith("_nchlt"): process_nchlt(paths, dataset_name) elif dataset_name in ("nb_norne", "nn_norne"): process_norne(paths, dataset_name) elif dataset_name == 'en_sample': process_toy_dataset(paths, dataset_name) elif dataset_name.lower().endswith("_masakhane"): process_masakhane(paths, dataset_name) else: for ending in SUFFIX_MAPPING: if dataset_name.endswith(ending): SUFFIX_MAPPING[ending](paths, dataset_name) break else: raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_ner_dataset") print("Done processing %s" % dataset_name) if __name__ == '__main__': main(sys.argv[1]) 
================================================ FILE: stanza/utils/datasets/ner/prepare_ner_file.py ================================================ """ This script converts NER data from the CoNLL03 format to the latest CoNLL-U format. The script assumes that in the input column format data, the token is always in the first column, while the NER tag is always in the last column. """ import argparse import json MIN_NUM_FIELD = 2 MAX_NUM_FIELD = 5 DOC_START_TOKEN = '-DOCSTART-' def parse_args(): parser = argparse.ArgumentParser(description="Convert the conll03 format data into conllu format.") parser.add_argument('input', help='Input conll03 format data filename.') parser.add_argument('output', help='Output json filename.') args = parser.parse_args() return args def main(): args = parse_args() process_dataset(args.input, args.output) def process_dataset(input_filename, output_filename): sentences = load_conll03(input_filename) print("{} examples loaded from {}".format(len(sentences), input_filename)) document = [] for (words, tags) in sentences: sent = [] for w, t in zip(words, tags): sent += [{'text': w, 'ner': t}] document += [sent] with open(output_filename, 'w', encoding="utf-8") as outfile: json.dump(document, outfile, indent=1) print("Generated json file {}".format(output_filename)) # TODO: make skip_doc_start an argument def load_conll03(filename, skip_doc_start=True): cached_lines = [] examples = [] with open(filename, encoding="utf-8") as infile: for line in infile: line = line.strip() if skip_doc_start and DOC_START_TOKEN in line: continue if len(line) > 0: array = line.split("\t") if len(array) < MIN_NUM_FIELD: array = line.split() if len(array) < MIN_NUM_FIELD: continue else: cached_lines.append(line) elif len(cached_lines) > 0: example = process_cache(cached_lines) examples.append(example) cached_lines = [] if len(cached_lines) > 0: examples.append(process_cache(cached_lines)) return examples def process_cache(cached_lines): tokens = [] ner_tags = [] 
for line in cached_lines: array = line.split("\t") if len(array) < MIN_NUM_FIELD: array = line.split() assert len(array) >= MIN_NUM_FIELD and len(array) <= MAX_NUM_FIELD, "Got unexpected line length: {}".format(array) tokens.append(array[0]) ner_tags.append(array[-1]) return (tokens, ner_tags) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/preprocess_wikiner.py ================================================ """ Converts the WikiNER data format to a format usable by our processing tools python preprocess_wikiner input output """ import sys def preprocess_wikiner(input_file, output_file, encoding="utf-8"): with open(input_file, encoding=encoding) as fin: with open(output_file, "w", encoding="utf-8") as fout: for line in fin: line = line.strip() if not line: fout.write("-DOCSTART- O\n") fout.write("\n") continue words = line.split() for word in words: pieces = word.split("|") text = pieces[0] tag = pieces[-1] # some words look like Daniel_Bernoulli|I-PER # but the original .pl conversion script didn't take that into account subtext = text.split("_") if tag.startswith("B-") and len(subtext) > 1: fout.write("{} {}\n".format(subtext[0], tag)) for chunk in subtext[1:]: fout.write("{} I-{}\n".format(chunk, tag[2:])) else: for chunk in subtext: fout.write("{} {}\n".format(chunk, tag)) fout.write("\n") if __name__ == '__main__': preprocess_wikiner(sys.argv[1], sys.argv[2]) ================================================ FILE: stanza/utils/datasets/ner/simplify_en_worldwide.py ================================================ import argparse import os import tempfile import stanza from stanza.utils.default_paths import get_default_paths from stanza.utils.datasets.ner.utils import read_tsv from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() PUNCTUATION = """!"#%&'()*+, -./:;<=>?@[\\]^_`{|}~""" MONEY_WORDS = {"million", "billion", "trillion", "millions", "billions", "trillions", "hundred", 
"hundreds", "lakh", "crore", # south asian english "tens", "of", "ten", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "couple"} # Doesn't include Money but this case is handled explicitly for processing LABEL_TRANSLATION = { "Date": None, "Misc": "MISC", "Product": "MISC", "NORP": "MISC", "Facility": "LOC", "Location": "LOC", "Person": "PER", "Organization": "ORG", } def isfloat(num): try: float(num) return True except ValueError: return False def process_label(line, is_start=False): """ Converts our stuff to conll labels event, product, work of art, norp -> MISC take out dates - can use Stanza to identify them as dates and eliminate them money requires some special care facility -> location (there are examples of Bridge and Hospital in the data) the version of conll we used to train CoreNLP NER is here: Overall plan: Collapse Product, NORP, Money (extract only the symbols), into misc. Collapse Facilities into LOC Deletes Dates Rule for currency is that we take out labels for the numbers that return True for isfloat() Take out words that categorize money (Million, Billion, Trillion, Thousand, Hundred, Ten, Nine, Eight, Seven, Six, Five, Four, Three, Two, One) Take out punctuation characters If we remove the 'B' tag, then move it to the first remaining tag. Replace tags with 'O' is_start parameter signals whether or not this current line is the new start of a tag. 
Needed for when the previous line analyzed is the start of a MONEY tag but is removed because it is a non symbol- need to set the starting token that is a symbol to the B-MONEY tag when it might have previously been I-MONEY """ if not line: return [] token = line[0] biggest_label = line[1] position, label_name = biggest_label[:2], biggest_label[2:] if label_name == "Money": if token.lower() in MONEY_WORDS or token in PUNCTUATION or isfloat(token): # remove this tag label_name = "O" is_start = True position = "" else: # keep money tag label_name = "MISC" if is_start: position = "B-" is_start = False elif not label_name or label_name == "O": pass elif label_name in LABEL_TRANSLATION: label_name = LABEL_TRANSLATION[label_name] if label_name is None: position = "" label_name = "O" is_start = False else: raise ValueError("Oops, missed a label: %s" % label_name) return [token, position + label_name, is_start] def write_new_file(save_dir, input_path, old_file, simplify): starts_b = False with open(input_path, "r+", encoding="utf-8") as iob: new_filename = (os.path.splitext(old_file)[0] + ".4class.tsv") if simplify else old_file with open(os.path.join(save_dir, new_filename), 'w', encoding='utf-8') as fout: for i, line in enumerate(iob): if i == 0 or i == 1: # skip over the URL and subsequent space line. 
continue line = line.strip() if not line: fout.write("\n") continue label = line.split("\t") if simplify: try: edited = process_label(label, is_start=starts_b) # processed label line labels except ValueError as e: raise ValueError("Error in %s at line %d" % (input_path, i)) from e assert edited starts_b = edited[-1] fout.write("\t".join(edited[:-1])) fout.write("\n") else: fout.write("%s\t%s\n" % (label[0], label[1])) def copy_and_simplify(base_path, simplify): with tempfile.TemporaryDirectory(dir=base_path) as tempdir: # Condense Labels input_dir = os.path.join(base_path, "en-worldwide-newswire") final_dir = os.path.join(base_path, "4class" if simplify else "9class") os.makedirs(tempdir, exist_ok=True) os.makedirs(final_dir, exist_ok=True) for root, dirs, files in os.walk(input_dir): if root[-6:] == "REVIEW": batch_files = os.listdir(root) for filename in batch_files: file_path = os.path.join(root, filename) write_new_file(final_dir, file_path, filename, simplify) def main(args=None): BASE_PATH = "C:\\Users\\SystemAdmin\\PycharmProjects\\General Code\\stanza source code" if not os.path.exists(BASE_PATH): paths = get_default_paths() BASE_PATH = os.path.join(paths["NERBASE"], "en_worldwide") parser = argparse.ArgumentParser() parser.add_argument('--base_path', type=str, default=BASE_PATH, help="Where to find the raw data") parser.add_argument('--simplify', default=False, action='store_true', help='Simplify to 4 classes... 
otherwise, keep all classes') parser.add_argument('--no_simplify', dest='simplify', action='store_false', help="Don't simplify to 4 classes") args = parser.parse_args(args=args) copy_and_simplify(args.base_path, args.simplify) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/simplify_ontonotes_to_worldwide.py ================================================ """ Simplify an existing ner json with the OntoNotes 18 class scheme to the Worldwide scheme Simplified classes used in the Worldwide dataset are: Date Facility Location Misc Money NORP Organization Person Product vs OntoNotes classes: CARDINAL DATE EVENT FAC GPE LANGUAGE LAW LOC MONEY NORP ORDINAL ORG PERCENT PERSON PRODUCT QUANTITY TIME WORK_OF_ART """ import argparse import glob import json import os from stanza.utils.default_paths import get_default_paths WORLDWIDE_ENTITY_MAPPING = { "CARDINAL": None, "ORDINAL": None, "PERCENT": None, "QUANTITY": None, "TIME": None, "DATE": "Date", "EVENT": "Misc", "FAC": "Facility", "GPE": "Location", "LANGUAGE": "NORP", "LAW": "Misc", "LOC": "Location", "MONEY": "Money", "NORP": "NORP", "ORG": "Organization", "PERSON": "Person", "PRODUCT": "Product", "WORK_OF_ART": "Misc", # identity map in case this is called on the Worldwide half of the tags "Date": "Date", "Facility": "Facility", "Location": "Location", "Misc": "Misc", "Money": "Money", "Organization":"Organization", "Person": "Person", "Product": "Product", } def simplify_ontonotes_to_worldwide(entity): if not entity or entity == "O": return "O" ent_iob, ent_type = entity.split("-", maxsplit=1) if ent_type in WORLDWIDE_ENTITY_MAPPING: if not WORLDWIDE_ENTITY_MAPPING[ent_type]: return "O" return ent_iob + "-" + WORLDWIDE_ENTITY_MAPPING[ent_type] raise ValueError("Unhandled entity: %s" % ent_type) def convert_file(in_file, out_file): with open(in_file) as fin: gold_doc = json.load(fin) for sentence in gold_doc: for word in sentence: if 'ner' not in word: 
continue word['ner'] = simplify_ontonotes_to_worldwide(word['ner']) with open(out_file, "w", encoding="utf-8") as fout: json.dump(gold_doc, fout, indent=2) def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--input_dataset', type=str, default='en_ontonotes', help='which files to convert') parser.add_argument('--output_dataset', type=str, default='en_ontonotes-8class', help='which files to write out') parser.add_argument('--ner_data_dir', type=str, default=get_default_paths()["NER_DATA_DIR"], help='which directory has the data') args = parser.parse_args() input_files = glob.glob(os.path.join(args.ner_data_dir, args.input_dataset + ".*")) for input_file in input_files: output_file = os.path.split(input_file)[1][len(args.input_dataset):] output_file = os.path.join(args.ner_data_dir, args.output_dataset + output_file) print("Converting %s to %s" % (input_file, output_file)) convert_file(input_file, output_file) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/split_wikiner.py ================================================ """ Preprocess the WikiNER dataset, by 1) normalizing tags; 2) split into train (70%), dev (15%), test (15%) datasets. 
""" import os import random import warnings from collections import Counter def read_sentences(filename, encoding): sents = [] cache = [] skipped = 0 skip = False with open(filename, encoding=encoding) as infile: for i, line in enumerate(infile): line = line.rstrip() if len(line) == 0: if len(cache) > 0: if not skip: sents.append(cache) else: skipped += 1 skip = False cache = [] continue array = line.split() if len(array) != 2: skip = True warnings.warn("Format error at line {}: {}".format(i+1, line)) continue w, t = array cache.append([w, t]) if len(cache) > 0: if not skip: sents.append(cache) else: skipped += 1 cache = [] print("Skipped {} examples due to formatting issues.".format(skipped)) return sents def write_sentences_to_file(sents, filename): print(f"Writing {len(sents)} sentences to {filename}") with open(filename, 'w', encoding='utf-8') as outfile: for sent in sents: for pair in sent: print(f"{pair[0]}\t{pair[1]}", file=outfile) print("", file=outfile) def remap_labels(sents, remap): new_sentences = [] for sentence in sents: new_sent = [] for word in sentence: new_sent.append([word[0], remap.get(word[1], word[1])]) new_sentences.append(new_sent) return new_sentences def split_wikiner_data(directory, sents, prefix="", suffix="bio", remap=None, shuffle=True, train_fraction=0.7, dev_fraction=0.15, test_section=True): random.seed(1234) if remap: sents = remap_labels(sents, remap) # split num = len(sents) train_num = int(num*train_fraction) if test_section: dev_num = int(num*dev_fraction) if train_fraction + dev_fraction > 1.0: raise ValueError("Train and dev fractions added up to more than 1: {} {} {}".format(train_fraction, dev_fraction)) else: dev_num = num - train_num if shuffle: random.shuffle(sents) train_sents = sents[:train_num] dev_sents = sents[train_num:train_num+dev_num] if test_section: test_sents = sents[train_num+dev_num:] batches = [train_sents, dev_sents, test_sents] filenames = [f'train.{suffix}', f'dev.{suffix}', f'test.{suffix}'] else: 
batches = [train_sents, dev_sents] filenames = [f'train.{suffix}', f'dev.{suffix}'] if prefix: filenames = ['%s.%s' % (prefix, f) for f in filenames] for batch, filename in zip(batches, filenames): write_sentences_to_file(batch, os.path.join(directory, filename)) def split_wikiner(directory, *in_filenames, encoding="utf-8", **kwargs): sents = [] for filename in in_filenames: new_sents = read_sentences(filename, encoding) print(f"{len(new_sents)} sentences read from {filename}.") sents.extend(new_sents) split_wikiner_data(directory, sents, **kwargs) if __name__ == "__main__": in_filename = 'raw/wp2.txt' directory = "." split_wikiner(directory, in_filename) ================================================ FILE: stanza/utils/datasets/ner/suc_conll_to_iob.py ================================================ """ Process the licensed version of SUC3 to BIO The main program processes the expected location, or you can pass in a specific zip or filename to read """ from io import TextIOWrapper from zipfile import ZipFile def extract(infile, outfile): """ Convert the infile to an outfile Assumes the files are already open (this allows you to pass in a zipfile reader, for example) The SUC3 format is like conll, but with the tags in tabs 10 and 11 """ lines = infile.readlines() sentences = [] cur_sentence = [] for idx, line in enumerate(lines): line = line.strip() if not line: # if we're currently reading a sentence, append it to the list if cur_sentence: sentences.append(cur_sentence) cur_sentence = [] continue pieces = line.split("\t") if len(pieces) < 12: raise ValueError("Unexpected line length in the SUC3 dataset at %d" % idx) if pieces[10] == 'O': cur_sentence.append((pieces[1], "O")) else: cur_sentence.append((pieces[1], "%s-%s" % (pieces[10], pieces[11]))) if cur_sentence: sentences.append(cur_sentence) for sentence in sentences: for word in sentence: outfile.write("%s\t%s\n" % word) outfile.write("\n") return len(sentences) def extract_from_zip(zip_filename, 
in_filename, out_filename): """ Process a single file from SUC3 zip_filename: path to SUC3.0.zip in_filename: which piece to read out_filename: where to write the result """ with ZipFile(zip_filename) as zin: with zin.open(in_filename) as fin: with open(out_filename, "w") as fout: num = extract(TextIOWrapper(fin, encoding="utf-8"), fout) print("Processed %d sentences from %s:%s to %s" % (num, zip_filename, in_filename, out_filename)) return num def process_suc3(zip_filename, short_name, out_dir): extract_from_zip(zip_filename, "SUC3.0/corpus/conll/suc-train.conll", "%s/%s.train.bio" % (out_dir, short_name)) extract_from_zip(zip_filename, "SUC3.0/corpus/conll/suc-dev.conll", "%s/%s.dev.bio" % (out_dir, short_name)) extract_from_zip(zip_filename, "SUC3.0/corpus/conll/suc-test.conll", "%s/%s.test.bio" % (out_dir, short_name)) def main(): process_suc3("extern_data/ner/sv_suc3/SUC3.0.zip", "data/ner") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/suc_to_iob.py ================================================ """ Conversion tool to transform SUC3's xml format to IOB Copyright 2017-2022, Emil Stenström Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from bz2 import BZ2File from xml.etree.ElementTree import iterparse import argparse from collections import Counter import sys def parse(fp, skiptypes=[]): root = None ne_prefix = "" ne_type = "O" name_prefix = "" name_type = "O" for event, elem in iterparse(fp, events=("start", "end")): if root is None: root = elem if event == "start": if elem.tag == "name": _type = name_type_to_label(elem.attrib["type"]) if ( _type not in skiptypes and not (_type == "ORG" and ne_type == "LOC") ): name_type = _type name_prefix = "B-" elif elem.tag == "ne": _type = ne_type_to_label(elem.attrib["type"]) if "/" in _type: _type = ne_type_to_label(_type[_type.index("/") + 1:]) if _type not in skiptypes: ne_type = _type ne_prefix = "B-" elif elem.tag == "w": if name_type == "PER" and elem.attrib["pos"] == "NN": name_type = "O" name_prefix = "" elif event == "end": if elem.tag == "sentence": yield elif elem.tag == "name": name_type = "O" name_prefix = "" elif elem.tag == "ne": ne_type = "O" ne_prefix = "" elif elem.tag == "w": if name_type != "O" and name_type != "OTH": yield elem.text, name_prefix, name_type elif ne_type != "O": yield elem.text, ne_prefix, ne_type else: yield elem.text, "", "O" if ne_type != "O": ne_prefix = "I-" if name_type != "O": name_prefix = "I-" root.clear() def ne_type_to_label(ne_type): mapping = { "PRS": "PER", } return mapping.get(ne_type, ne_type) def name_type_to_label(name_type): mapping = { "inst": "ORG", "product": "OBJ", "other": "OTH", "place": "LOC", "myth": "PER", "person": "PER", "event": "EVN", "work": "WRK", "animal": "PER", } return mapping.get(name_type) def main(args=None): parser = argparse.ArgumentParser() parser.add_argument( "infile", help=""" Input for that contains the full SUC 
3.0 XML. Can be the bz2-zipped version or the xml version. """ ) parser.add_argument( "outfile", nargs="?", help=""" Output file for IOB format. Optional - will print to stdout otherwise """ ) parser.add_argument( "--skiptypes", help="Entity types that should be skipped in output.", nargs="+", default=[] ) parser.add_argument( "--stats_only", help="Show statistics of found labels at the end of output.", action='store_true', default=False ) args = parser.parse_args(args) MAGIC_BZ2_FILE_START = b"\x42\x5a\x68" fp = open(args.infile, "rb") is_bz2 = (fp.read(len(MAGIC_BZ2_FILE_START)) == MAGIC_BZ2_FILE_START) if is_bz2: fp = BZ2File(args.infile, "rb") else: fp = open(args.infile, "rb") if args.outfile is not None: fout = open(args.outfile, "w", encoding="utf-8") else: fout = sys.stdout type_stats = Counter() for token in parse(fp, skiptypes=args.skiptypes): if not token: if not args.stats_only: fout.write("\n") else: word, prefix, label = token if args.stats_only: type_stats[label] += 1 else: fout.write("%s\t%s%s\n" % (word, prefix, label)) if args.stats_only: fout.write(str(type_stats) + "\n") fp.close() if args.outfile is not None: fout.close() if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/ner/utils.py ================================================ """ Utils for the processing of NER datasets These can be invoked from either the specific dataset scripts or the entire prepare_ner_dataset.py script """ from collections import defaultdict import io import json import os import random import zipfile from stanza.models.common.doc import Document import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file SHARDS = ('train', 'dev', 'test') def bioes_to_bio(tags): new_tags = [] in_entity = False for tag in tags: if tag == 'O': new_tags.append(tag) in_entity = False elif in_entity and (tag.startswith("B-") or tag.startswith("S-")): # TODO: does the tag have to match the previous tag? 
# eg, does B-LOC B-PER in BIOES need a B-PER or is I-PER sufficient? new_tags.append('B-' + tag[2:]) else: new_tags.append('I-' + tag[2:]) in_entity = True return new_tags def convert_bioes_to_bio(base_input_path, base_output_path, short_name): """ Convert BIOES files back to BIO (not BIO2) Useful for preparing datasets for CoreNLP, which doesn't do great with the more highly split classes """ for shard in SHARDS: input_filename = os.path.join(base_input_path, '%s.%s.bioes' % (short_name, shard)) output_filename = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard)) input_sentences = read_tsv(input_filename, text_column=0, annotation_column=1) new_sentences = [] for sentence in input_sentences: tags = [x[1] for x in sentence] tags = bioes_to_bio(tags) sentence = [(x[0], y) for x, y in zip(sentence, tags)] new_sentences.append(sentence) write_sentences(output_filename, new_sentences) def convert_bio_to_json(base_input_path, base_output_path, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS): """ Convert BIO files to json It can often be convenient to put the intermediate BIO files in the same directory as the output files, in which case you can pass in same path for both base_input_path and base_output_path. 
This also will rewrite a BIOES as json """ for input_shard, output_shard in zip(shard_names, shards): input_filename = os.path.join(base_input_path, '%s.%s.%s' % (short_name, input_shard, suffix)) if not os.path.exists(input_filename): alt_filename = os.path.join(base_input_path, '%s.%s' % (input_shard, suffix)) if os.path.exists(alt_filename): input_filename = alt_filename else: raise FileNotFoundError('Cannot find %s component of %s in %s or %s' % (output_shard, short_name, input_filename, alt_filename)) output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, output_shard)) print("Converting %s to %s" % (input_filename, output_filename)) prepare_ner_file.process_dataset(input_filename, output_filename) def get_tags(datasets): """ return the set of tags used in these datasets datasets is expected to be train, dev, test but could be any list """ tags = set() for dataset in datasets: for sentence in dataset: for word, tag in sentence: tags.add(tag) return tags def write_sentences(output_filename, dataset): """ Write exactly one output file worth of dataset """ os.makedirs(os.path.split(output_filename)[0], exist_ok=True) with open(output_filename, "w", encoding="utf-8") as fout: for sent_idx, sentence in enumerate(dataset): for word_idx, word in enumerate(sentence): if len(word) > 2: word = word[:2] try: fout.write("%s\t%s\n" % word) except TypeError: raise TypeError("Unable to process sentence %d word %d of file %s" % (sent_idx, word_idx, output_filename)) fout.write("\n") def write_dataset(datasets, output_dir, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS): """ write all three pieces of a dataset to output_dir datasets should be 3 lists: train, dev, test each list should be a list of sentences each sentence is a list of pairs: word, tag after writing to .bio files, the files will be converted to .json """ for shard, dataset in zip(shard_names, datasets): output_filename = os.path.join(output_dir, "%s.%s.%s" % (short_name, shard, 
suffix)) write_sentences(output_filename, dataset) convert_bio_to_json(output_dir, output_dir, short_name, suffix, shard_names=shard_names, shards=shards) def write_multitag_json(output_filename, dataset): json_dataset = [] for sentence in dataset: json_sentence = [] for word in sentence: word = {'text': word[0], 'ner': word[1], 'multi_ner': word[2]} json_sentence.append(word) json_dataset.append(json_sentence) with open(output_filename, 'w', encoding='utf-8') as fout: json.dump(json_dataset, fout, indent=2) def write_multitag_dataset(datasets, output_dir, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS): for shard, dataset in zip(shard_names, datasets): output_filename = os.path.join(output_dir, "%s.%s.%s" % (short_name, shard, suffix)) write_sentences(output_filename, dataset) for shard, dataset in zip(shard_names, datasets): output_filename = os.path.join(output_dir, "%s.%s.json" % (short_name, shard)) write_multitag_json(output_filename, dataset) def read_tsv(filename, text_column, annotation_column, remap_tag_fn=None, remap_line=None, skip_comments=True, keep_broken_tags=False, keep_all_columns=False, separator="\t", zip_filename=None): """ Read sentences from a TSV file Returns a list of list of (word, tag) If keep_broken_tags==True, then None is returned for a missing. 
Otherwise, an IndexError is thrown """ if zip_filename is not None: with zipfile.ZipFile(zip_filename) as zin: with zin.open(filename) as fin: fin = io.TextIOWrapper(fin, encoding='utf-8') lines = fin.readlines() else: with open(filename, encoding="utf-8") as fin: lines = fin.readlines() lines = [x.strip() for x in lines] sentences = [] current_sentence = [] for line_idx, line in enumerate(lines): if not line: if current_sentence: sentences.append(current_sentence) current_sentence = [] continue if skip_comments and line.startswith("#"): continue if remap_line is not None: line = remap_line(line) pieces = line.split(separator) try: word = pieces[text_column] except IndexError as e: raise IndexError("Filename %s: could not find word index %d at line %d |%s|" % (filename, text_column, line_idx, line)) from e if word == '\x96': # this happens in GermEval2014 for some reason continue try: tag = pieces[annotation_column] except IndexError as e: if keep_broken_tags: tag = None else: raise IndexError("Filename %s: could not find tag index %d at line %d |%s|" % (filename, annotation_column, line_idx, line)) from e if remap_tag_fn is not None: tag = remap_tag_fn(tag) if keep_all_columns: pieces[annotation_column] = tag current_sentence.append(pieces) else: current_sentence.append((word, tag)) if current_sentence: sentences.append(current_sentence) return sentences def random_shuffle_directory(input_dir, output_dir, short_name): input_files = os.listdir(input_dir) input_files = sorted(input_files) random_shuffle_files(input_dir, input_files, output_dir, short_name) def random_shuffle_files(input_dir, input_files, output_dir, short_name): """ Shuffle the files into different chunks based on their filename The first piece of the filename, split by ".", is used as a random seed. 
This will make it so that adding new files or using a different annotation scheme (assuming that's encoding in pieces of the filename) won't change the distibution of the files """ input_keys = {} for f in input_files: seed = f.split(".")[0] if seed in input_keys: raise ValueError("Multiple files with the same prefix: %s and %s" % (input_keys[seed], f)) input_keys[seed] = f assert len(input_keys) == len(input_files) train_files = [] dev_files = [] test_files = [] for filename in input_files: seed = filename.split(".")[0] # "salt" the filenames when using as a seed # definitely not because of a dumb bug in the original implementation seed = seed + ".txt.4class.tsv" random.seed(seed, 2) location = random.random() if location < 0.7: train_files.append(filename) elif location < 0.8: dev_files.append(filename) else: test_files.append(filename) print("Train files: %d Dev files: %d Test files: %d" % (len(train_files), len(dev_files), len(test_files))) assert len(train_files) + len(dev_files) + len(test_files) == len(input_files) file_lists = [train_files, dev_files, test_files] datasets = [] for files in file_lists: dataset = [] for filename in files: dataset.extend(read_tsv(os.path.join(input_dir, filename), 0, 1)) datasets.append(dataset) write_dataset(datasets, output_dir, short_name) return len(train_files), len(dev_files), len(test_files) def random_shuffle_by_prefixes(input_dir, output_dir, short_name, prefix_map): input_files = os.listdir(input_dir) input_files = sorted(input_files) file_divisions = defaultdict(list) for filename in input_files: for division in prefix_map.keys(): for prefix in prefix_map[division]: if filename.startswith(prefix): break else: # for/else is intentional continue break else: # yes, stop asking raise ValueError("Could not assign %s to any of the divisions in the prefix_map" % filename) #print("Assigning %s to %s because of %s" % (filename, division, prefix)) file_divisions[division].append(filename) num_train_files = 0 num_dev_files = 0 
num_test_files = 0 for division in file_divisions.keys(): print() print("Processing %d files from %s" % (len(file_divisions[division]), division)) d_train, d_dev, d_test = random_shuffle_files(input_dir, file_divisions[division], output_dir, "%s-%s" % (short_name, division)) num_train_files += d_train num_dev_files += d_dev num_test_files += d_test print() print("After shuffling: Train files: %d Dev files: %d Test files: %d" % (num_train_files, num_dev_files, num_test_files)) dataset_divisions = ["%s-%s" % (short_name, division) for division in file_divisions] combine_dataset(output_dir, output_dir, dataset_divisions, short_name) def combine_dataset(input_dir, output_dir, input_datasets, output_dataset): datasets = [] for shard in SHARDS: full_dataset = [] for input_dataset in input_datasets: input_filename = "%s.%s.json" % (input_dataset, shard) input_path = os.path.join(input_dir, input_filename) with open(input_path, encoding="utf-8") as fin: dataset = json.load(fin) converted = [[(word['text'], word['ner']) for word in sentence] for sentence in dataset] full_dataset.extend(converted) datasets.append(full_dataset) write_dataset(datasets, output_dir, output_dataset) def read_prefix_file(destination_file): """ Read a prefix file such as the one for the Worldwide dataset the format should be africa: af_ ... asia: cn_ ... """ destination = None known_prefixes = set() prefixes = [] prefix_map = {} with open(destination_file, encoding="utf-8") as fin: for line in fin: line = line.strip() if line.startswith("#"): continue if not line: continue if line.endswith(":"): if destination is not None: prefix_map[destination] = prefixes prefixes = [] destination = line[:-1].strip().lower().replace(" ", "_") else: if not destination: raise RuntimeError("Found a prefix before the first label was assigned when reading %s" % destination_file) prefixes.append(line) if line in known_prefixes: raise RuntimeError("Found the same prefix twice! 
%s" % line) known_prefixes.add(line) if destination and prefixes: prefix_map[destination] = prefixes return prefix_map def read_json_entities(filename): """ Read entities from a file, return a list of (text, label) Should work on both BIOES and BIO """ with open(filename) as fin: doc = Document(json.load(fin)) return list_doc_entities(doc) def list_doc_entities(doc): """ Return a list of (text, label) Should work on both BIOES and BIO """ entities = [] for sentence in doc.sentences: current_entity = [] previous_label = None for token in sentence.tokens: if token.ner == 'O' or token.ner.startswith("E-"): if token.ner.startswith("E-"): current_entity.append(token.text) if current_entity: assert previous_label is not None entities.append((current_entity, previous_label)) current_entity = [] previous_label = None elif token.ner.startswith("I-"): if previous_label is not None and previous_label != 'O' and previous_label != token.ner[2:]: if current_entity: assert previous_label is not None entities.append((current_entity, previous_label)) current_entity = [] previous_label = token.ner[2:] current_entity.append(token.text) elif token.ner.startswith("B-") or token.ner.startswith("S-"): if current_entity: assert previous_label is not None entities.append((current_entity, previous_label)) current_entity = [] previous_label = None current_entity.append(token.text) previous_label = token.ner[2:] if token.ner.startswith("S-"): assert previous_label is not None entities.append(current_entity) current_entity = [] previous_label = None else: raise RuntimeError("Expected BIO(ES) format in the json file!") previous_label = token.ner[2:] if current_entity: assert previous_label is not None entities.append((current_entity, previous_label)) entities = [(tuple(x[0]), x[1]) for x in entities] return entities def combine_files(output_filename, *input_filenames): """ Combine multiple NER json files into one NER file """ doc = [] for filename in input_filenames: with open(filename) as fin: 
new_doc = json.load(fin) doc.extend(new_doc) with open(output_filename, "w") as fout: json.dump(doc, fout, indent=2) ================================================ FILE: stanza/utils/datasets/pos/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/pos/convert_trees_to_pos.py ================================================ """ Turns a constituency treebank into a POS dataset with the tags as the upos column The constituency treebank first has to be converted from the original data to PTB style trees. This script converts trees from the CONSTITUENCY_DATA_DIR folder to a conllu dataset in the POS_DATA_DIR folder. Note that this doesn't pay any attention to whether or not the tags actually are upos. Also not possible: using this for tokenization. TODO: upgrade the POS model to handle xpos datasets with no upos, then make upos/xpos an option here To run this: python3 stanza/utils/training/run_pos.py vi_vlsp22 """ import argparse import os import shutil import sys from stanza.models.constituency import tree_reader import stanza.utils.default_paths as default_paths from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() SHARDS = ("train", "dev", "test") def convert_file(in_file, out_file, upos): print("Reading %s" % in_file) trees = tree_reader.read_tree_file(in_file) print("Writing %s" % out_file) with open(out_file, "w") as fout: for tree in tqdm(trees): tree = tree.simplify_labels() text = " ".join(tree.leaf_labels()) fout.write("# text = %s\n" % text) for pt_idx, pt in enumerate(tree.yield_preterminals()): # word index fout.write("%d\t" % (pt_idx+1)) # word fout.write("%s\t" % pt.children[0].label) # don't know the lemma fout.write("_\t") # always put the tag, whatever it is, in the upos (for now) if upos: fout.write("%s\t_\t" % pt.label) else: fout.write("_\t%s\t" % pt.label) # don't have any features fout.write("_\t") # so word 0 fake dep on root, everyone else fake dep on 
previous word fout.write("%d\t" % pt_idx) if pt_idx == 0: fout.write("root") else: fout.write("dep") fout.write("\t_\t_\n") fout.write("\n") def convert_treebank(short_name, upos, output_name, paths): in_dir = paths["CONSTITUENCY_DATA_DIR"] in_files = [os.path.join(in_dir, "%s_%s.mrg" % (short_name, shard)) for shard in SHARDS] for in_file in in_files: if not os.path.exists(in_file): raise FileNotFoundError("Cannot find expected datafile %s" % in_file) out_dir = paths["POS_DATA_DIR"] if not os.path.exists(out_dir): os.makedirs(out_dir) if output_name is None: output_name = short_name out_files = [os.path.join(out_dir, "%s.%s.in.conllu" % (output_name, shard)) for shard in SHARDS] gold_files = [os.path.join(out_dir, "%s.%s.gold.conllu" % (output_name, shard)) for shard in SHARDS] for in_file, out_file in zip(in_files, out_files): convert_file(in_file, out_file, upos) for out_file, gold_file in zip(out_files, gold_files): shutil.copy2(out_file, gold_file) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("dataset", help="Which dataset to process from trees to POS") parser.add_argument("--upos", action="store_true", default=False, help="Store tags on the UPOS") parser.add_argument("--xpos", dest="upos", action="store_false", help="Store tags on the XPOS") parser.add_argument("--output_name", default=None, help="What name to give the output dataset. If blank, will use the dataset arg") args = parser.parse_args() paths = default_paths.get_default_paths() convert_treebank(args.dataset, args.upos, args.output_name, paths) ================================================ FILE: stanza/utils/datasets/pos/remove_columns.py ================================================ """ Remove xpos and feats from each file given at the command line. Useful to strip unwanted tags when combining files of two different types (or two different stages in the annotation process). Super rudimentary right now. 
Will be upgraded if needed """ import sys from stanza.utils.conll import CoNLL def remove_columns(filename): doc = CoNLL.conll2doc(filename) for sentence in doc.sentences: for word in sentence.words: word.xpos = None word.feats = None CoNLL.write_doc2conll(doc, filename) if __name__ == '__main__': for filename in sys.argv[1:]: remove_columns(filename) ================================================ FILE: stanza/utils/datasets/prepare_depparse_treebank.py ================================================ """ A script to prepare all depparse datasets. Prepares each of train, dev, test. Example: python -m stanza.utils.datasets.prepare_depparse_treebank {TREEBANK} Example: python -m stanza.utils.datasets.prepare_depparse_treebank UD_English-EWT """ from enum import Enum import glob import logging import os from stanza.models import tagger from stanza.models.common.constant import treebank_to_short_name from stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError from stanza.resources.default_packages import default_charlms, pos_charlms import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank from stanza.utils.training.common import build_pos_wordvec_args from stanza.utils.training.common import add_charlm_args, build_charlm_args, choose_charlm logger = logging.getLogger('stanza') class Tags(Enum): """Tags parameter values.""" GOLD = 1 PREDICTED = 2 # fmt: off def add_specific_args(parser) -> None: """Add specific args.""" parser.add_argument("--gold", dest='tag_method', action='store_const', const=Tags.GOLD, default=Tags.PREDICTED, help='Use gold tags for building the depparse data') parser.add_argument("--predicted", dest='tag_method', action='store_const', const=Tags.PREDICTED, help='Use predicted tags for building the depparse data') parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') 
parser.add_argument('--tagger_model', type=str, default=None, help='Tagger save file to use. If not specified, order searched will be saved/models, then $STANZA_RESOURCES_DIR') parser.add_argument('--save_dir', type=str, default=os.path.join('saved_models', 'pos'), help='Where to look for recently trained POS models') parser.add_argument('--no_download_tagger', default=True, dest='download_tagger', action='store_false', help="Don't try to automatically download a tagger for retagging the dependencies. Will fail to make silver tags if there is no tagger model to be found") add_charlm_args(parser) # fmt: on def choose_tagger_model(short_language, dataset, tagger_model, args): """ Preferentially chooses a retrained tagger model, but tries to download one if that doesn't exist """ logger.debug("Looking for tagger for lang |%s| dataset |%s|. Suggested model |%s|. Looking first in |%s|.", short_language, dataset, tagger_model, args.save_dir) if tagger_model: return tagger_model candidates = glob.glob(os.path.join(args.save_dir, "%s_%s_*.pt" % (short_language, dataset))) if len(candidates) == 1: return candidates[0] if len(candidates) > 1: for ending in ("_trans_tagger.pt", "_charlm_tagger.pt", "_nocharlm_tagger.pt"): best_candidates = [x for x in candidates if x.endswith(ending)] if len(best_candidates) == 1: return best_candidates[0] if len(best_candidates) > 1: raise FileNotFoundError("Could not choose among the candidate taggers... please pick one with --tagger_model: {}".format(best_candidates)) raise FileNotFoundError("Could not choose among the candidate taggers... please pick one with --tagger_model: {}".format(candidates)) if not args.download_tagger: return None # TODO: just create a Pipeline for the retagging instead? 
pos_path = os.path.join(DEFAULT_MODEL_DIR, short_language, "pos", dataset + ".pt") if os.path.exists(pos_path): return pos_path try: download_list = download(lang=short_language, package=None, processors={"pos": dataset}) except UnknownLanguageError as e: raise FileNotFoundError("The language %s appears to be a language new to Stanza. Unfortunately, that means there are no taggers available for retagging the dependency dataset. Furthermore, there are no tagger models for this language found in %s. You can specify a different directory for already trained tagger models with --save_dir, specify an exact tagger model name with --tagger_model, or use gold tags with --gold" % (short_language, args.save_dir)) from e for processor, name in download_list: if processor == 'pos': pos_path = os.path.join(DEFAULT_MODEL_DIR, short_language, "pos", name + ".pt") return pos_path else: raise FileNotFoundError("Could not figure out which model file to use for %s. Just tried to download to %s the models %s" % (short_language, args.save_dir, download_list)) def process_treebank(treebank, model_type, paths, args) -> None: """Process treebank.""" if args.tag_method is Tags.GOLD: prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["DEPPARSE_DATA_DIR"]) elif args.tag_method is Tags.PREDICTED: short_name = treebank_to_short_name(treebank) short_language, dataset = short_name.split("_", 1) # fmt: off base_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] # fmt: on # perhaps download a tagger if one doesn't already exist tagger_model = choose_tagger_model(short_language, dataset, args.tagger_model, args) if tagger_model is None: raise FileNotFoundError("Cannot find a tagger for language %s, dataset %s - you can specify one with the --tagger_model flag") else: logger.info("Using tagger model in %s for %s_%s", tagger_model, short_language, dataset) tagger_dir, tagger_name = 
os.path.split(tagger_model) base_args = base_args + ['--save_dir', tagger_dir, '--save_name', tagger_name] # word vector file for POS if args.wordvec_pretrain_file: base_args += ["--wordvec_pretrain_file", args.wordvec_pretrain_file] else: base_args = base_args + build_pos_wordvec_args(short_language, dataset, []) # charlm for POS charlm = choose_charlm(short_language, dataset, args.charlm, default_charlms, pos_charlms) charlm_args = build_charlm_args(short_language, charlm) base_args = base_args + charlm_args def retag_dataset(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu" retagged = f"{dest_dir}/{short_name}.{dest_file}.conllu" # fmt: off tagger_args = ["--eval_file", original, "--output_file", retagged] # fmt: on tagger_args = base_args + tagger_args logger.info("Running tagger to retag {} to {}\n Args: {}".format(original, retagged, tagger_args)) tagger.main(tagger_args) prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["DEPPARSE_DATA_DIR"], retag_dataset) else: raise ValueError("Unknown tags method: {}".format(args.tag_method)) def main() -> None: """Call Process Treebank.""" common.main(process_treebank, common.ModelType.DEPPARSE, add_specific_args) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/prepare_lemma_classifier.py ================================================ import os import sys from stanza.utils.datasets.common import find_treebank_dataset_file, UnknownDatasetError from stanza.utils.default_paths import get_default_paths from stanza.models.lemma_classifier import prepare_dataset from stanza.models.common.short_name_to_treebank import short_name_to_treebank from stanza.utils.conll import CoNLL SECTIONS = ("train", "dev", "test") def process_treebank(paths, short_name, word, upos, allowed_lemmas, sections=SECTIONS): treebank = short_name_to_treebank(short_name) 
udbase_dir = paths["UDBASE"] output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"] os.makedirs(output_dir, exist_ok=True) output_filenames = [] for section in sections: filename = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True) output_filename = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section)) args = ["--conll_path", filename, "--target_word", word, "--target_upos", upos, "--output_path", output_filename] if allowed_lemmas is not None: args.extend(["--allowed_lemmas", allowed_lemmas]) prepare_dataset.main(args) output_filenames.append(output_filename) return output_filenames def process_en_combined(paths, short_name): udbase_dir = paths["UDBASE"] output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"] os.makedirs(output_dir, exist_ok=True) train_treebanks = ["UD_English-EWT", "UD_English-GUM", "UD_English-GUMReddit", "UD_English-LinES"] test_treebanks = ["UD_English-PUD", "UD_English-Pronouns"] target_word = "'s" target_upos = ["AUX"] sentences = [ [], [], [] ] for treebank in train_treebanks: for section_idx, section in enumerate(SECTIONS): filename = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True) doc = CoNLL.conll2doc(filename) processor = prepare_dataset.DataProcessor(target_word=target_word, target_upos=target_upos, allowed_lemmas=".*") new_sentences = processor.process_document(doc, save_name=None) print("Read %d sentences from %s" % (len(new_sentences), filename)) sentences[section_idx].extend(new_sentences) for treebank in test_treebanks: section = "test" filename = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True) doc = CoNLL.conll2doc(filename) processor = prepare_dataset.DataProcessor(target_word=target_word, target_upos=target_upos, allowed_lemmas=".*") new_sentences = processor.process_document(doc, save_name=None) print("Read %d sentences from %s" % (len(new_sentences), filename)) sentences[2].extend(new_sentences) for section, section_sentences in 
zip(SECTIONS, sentences): output_filename = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section)) prepare_dataset.DataProcessor.write_output_file(output_filename, target_upos, section_sentences) print("Wrote %s sentences to %s" % (len(section_sentences), output_filename)) def process_ja_gsd(paths, short_name): # this one looked promising, but only has 10 total dev & test cases # 行っ VERB Counter({'行う': 60, '行く': 38}) # could possibly do # ない AUX Counter({'ない': 383, '無い': 99}) # なく AUX Counter({'無い': 53, 'ない': 42}) # currently this one has enough in the dev & test data # and functions well # だ AUX Counter({'だ': 237, 'た': 67}) word = "だ" upos = "AUX" allowed_lemmas = None process_treebank(paths, short_name, word, upos, allowed_lemmas) def process_fa_perdt(paths, short_name): word = "شد" upos = "VERB" allowed_lemmas = "کرد|شد" process_treebank(paths, short_name, word, upos, allowed_lemmas) def process_hi_hdtb(paths, short_name): word = "के" upos = "ADP" allowed_lemmas = "का|के" process_treebank(paths, short_name, word, upos, allowed_lemmas) def process_ar_padt(paths, short_name): word = "أن" upos = "SCONJ" allowed_lemmas = "أَن|أَنَّ" process_treebank(paths, short_name, word, upos, allowed_lemmas) def process_el_gdt(paths, short_name): """ All of the Greek lemmas for these words are εγώ or μου τους PRON Counter({'μου': 118, 'εγώ': 32}) μας PRON Counter({'μου': 89, 'εγώ': 32}) του PRON Counter({'μου': 82, 'εγώ': 8}) της PRON Counter({'μου': 80, 'εγώ': 2}) σας PRON Counter({'μου': 34, 'εγώ': 24}) μου PRON Counter({'μου': 45, 'εγώ': 10}) """ word = "τους|μας|του|της|σας|μου" upos = "PRON" allowed_lemmas = None process_treebank(paths, short_name, word, upos, allowed_lemmas) DATASET_MAPPING = { "ar_padt": process_ar_padt, "el_gdt": process_el_gdt, "en_combined": process_en_combined, "fa_perdt": process_fa_perdt, "hi_hdtb": process_hi_hdtb, "ja_gsd": process_ja_gsd, } def main(dataset_name): paths = get_default_paths() print("Processing %s" % dataset_name) # 
obviously will want to multiplex to multiple languages / datasets if dataset_name in DATASET_MAPPING: DATASET_MAPPING[dataset_name](paths, dataset_name) else: raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_lemma_classifier.py") print("Done processing %s" % dataset_name) if __name__ == '__main__': main(sys.argv[1]) ================================================ FILE: stanza/utils/datasets/prepare_lemma_treebank.py ================================================ """ A script to prepare all lemma datasets. For example, do python -m stanza.utils.datasets.prepare_lemma_treebank TREEBANK such as python -m stanza.utils.datasets.prepare_lemma_treebank UD_English-EWT and it will prepare each of train, dev, test """ from stanza.models.common.constant import treebank_to_short_name import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier def add_specific_args(parser) -> None: parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', default=True, help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer") def check_lemmas(train_file): """ Check if a treebank has any lemmas in it For example, in Vietnamese-VTB, all the words and lemmas are exactly the same in Telugu-MTG, all the lemmas are blank """ # could eliminate a few languages immediately based on UD 2.7 # but what if a later dataset includes lemmas? 
#if short_language in ('vi', 'fro', 'th'): # return False with open(train_file, encoding="utf-8") as fin: for line in fin: line = line.strip() if not line or line.startswith("#"): continue pieces = line.split("\t") word = pieces[1].lower().strip() lemma = pieces[2].lower().strip() if not lemma or lemma == '_' or lemma == '-': continue if word == lemma: continue return True return False def process_treebank(treebank, model_type, paths, args): if treebank.startswith("UD_"): udbase_dir = paths["UDBASE"] input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu") if not input_conllu: input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu", fail=True) augment = check_lemmas(input_conllu) if not augment: print("No lemma information found in %s. Not augmenting the dataset" % train_conllu) else: # TODO: check the data to see if there are lemmas or not augment = True prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["LEMMA_DATA_DIR"], augment=augment) short_name = treebank_to_short_name(treebank) if args.lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING: prepare_lemma_classifier.main(short_name) def main(): common.main(process_treebank, common.ModelType.LEMMA, add_specific_args) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/prepare_mwt_treebank.py ================================================ """ A script to prepare all MWT datasets. 
For example, do python -m stanza.utils.datasets.prepare_mwt_treebank TREEBANK such as python -m stanza.utils.datasets.prepare_mwt_treebank UD_English-EWT and it will prepare each of train, dev, test """ import argparse import os import shutil import tempfile from stanza.utils.conll import CoNLL from stanza.models.common.constant import treebank_to_short_name import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank from stanza.utils.datasets.contract_mwt import contract_mwt # languages where the MWTs are always a composition of the words themselves KNOWN_COMPOSABLE_MWTS = {"en"} # ... but partut is not put together that way MWT_EXCEPTIONS = {"en_partut"} def copy_conllu(tokenizer_dir, mwt_dir, short_name, dataset, particle): input_conllu_tokenizer = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" input_conllu_mwt = f"{mwt_dir}/{short_name}.{dataset}.{particle}.conllu" shutil.copyfile(input_conllu_tokenizer, input_conllu_mwt) def check_mwt_composition(filename): print("Checking the MWTs in %s" % filename) doc = CoNLL.conll2doc(filename) for sent_idx, sentence in enumerate(doc.sentences): for token_idx, token in enumerate(sentence.tokens): if len(token.words) > 1: expected = "".join(x.text for x in token.words) if token.text != expected: raise ValueError("Unexpected token composition in filename %s sentence %d id %s token %d: %s instead of %s" % (filename, sent_idx, sentence.sent_id, token_idx, token.text, expected)) def process_treebank(treebank, model_type, paths, args): short_name = treebank_to_short_name(treebank) mwt_dir = paths["MWT_DATA_DIR"] os.makedirs(mwt_dir, exist_ok=True) with tempfile.TemporaryDirectory() as tokenizer_dir: paths = dict(paths) paths["TOKENIZE_DATA_DIR"] = tokenizer_dir # first we process the tokenization data tokenizer_args = argparse.Namespace() tokenizer_args.augment = False tokenizer_args.prepare_labels = True 
prepare_tokenizer_treebank.process_treebank(treebank, model_type, paths, tokenizer_args) copy_conllu(tokenizer_dir, mwt_dir, short_name, "train", "in") copy_conllu(tokenizer_dir, mwt_dir, short_name, "dev", "gold") copy_conllu(tokenizer_dir, mwt_dir, short_name, "test", "gold") for shard in ("train", "dev", "test"): source_filename = common.mwt_name(tokenizer_dir, short_name, shard) dest_filename = common.mwt_name(mwt_dir, short_name, shard) print("Copying from %s to %s" % (source_filename, dest_filename)) shutil.copyfile(source_filename, dest_filename) language = short_name.split("_", 1)[0] if language in KNOWN_COMPOSABLE_MWTS and short_name not in MWT_EXCEPTIONS: print("Language %s is known to have all MWT composed of exactly its word pieces. Checking..." % language) check_mwt_composition(f"{mwt_dir}/{short_name}.train.in.conllu") check_mwt_composition(f"{mwt_dir}/{short_name}.dev.gold.conllu") check_mwt_composition(f"{mwt_dir}/{short_name}.test.gold.conllu") contract_mwt(f"{mwt_dir}/{short_name}.dev.gold.conllu", f"{mwt_dir}/{short_name}.dev.in.conllu") contract_mwt(f"{mwt_dir}/{short_name}.test.gold.conllu", f"{mwt_dir}/{short_name}.test.in.conllu") def main(): common.main(process_treebank, common.ModelType.MWT) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/prepare_pos_treebank.py ================================================ """ A script to prepare all pos datasets. 
For example, do python -m stanza.utils.datasets.prepare_pos_treebank TREEBANK such as python -m stanza.utils.datasets.prepare_pos_treebank UD_English-EWT and it will prepare each of train, dev, test """ import os import shutil import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank def copy_conllu_file_or_zip(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.zip" copied = f"{dest_dir}/{short_name}.{dest_file}.zip" if os.path.exists(original): print("Copying from %s to %s" % (original, copied)) shutil.copyfile(original, copied) else: prepare_tokenizer_treebank.copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name) def process_treebank(treebank, model_type, paths, args): prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["POS_DATA_DIR"], postprocess=copy_conllu_file_or_zip) def main(): common.main(process_treebank, common.ModelType.POS) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/prepare_tokenizer_data.py ================================================ import argparse import json import os import re import sys from collections import Counter """ Data is output in 4 files: a file containing the mwt information a file containing the words and sentences in conllu format a file containing the raw text of each paragraph a file of 0,1,2 indicating word break or sentence break on a character level for the raw text 1: end of word 2: end of sentence """ PARAGRAPH_BREAK = re.compile(r'\n\s*\n') def is_para_break(index, text): """ Detect if a paragraph break can be found, and return the length of the paragraph break sequence. 
""" if text[index] == '\n': para_break = PARAGRAPH_BREAK.match(text, index) if para_break: break_len = len(para_break.group(0)) return True, break_len return False, 0 def find_next_word(index, text, word, output): """ Locate the next word in the text. In case a paragraph break is found, also write paragraph break to labels. """ idx = 0 word_sofar = '' while index < len(text) and idx < len(word): para_break, break_len = is_para_break(index, text) if para_break: # multiple newlines found, paragraph break if len(word_sofar) > 0: assert re.match(r'^\s+$', word_sofar), 'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(word_sofar) word_sofar = '' output.write('\n\n') index += break_len - 1 elif re.match(r'^\s$', text[index]) and not re.match(r'^\s$', word[idx]): # whitespace found, and whitespace is not part of a word word_sofar += text[index] else: # non-whitespace char, or a whitespace char that's part of a word word_sofar += text[index] assert text[index].replace('\n', ' ') == word[idx], "Character mismatch: raw text contains |%s| but the next word is |%s|." 
% (word_sofar, word) idx += 1 index += 1 return index, word_sofar def main(args): parser = argparse.ArgumentParser() parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input") parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks") parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)") parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)") args = parser.parse_args(args=args) with open(args.plaintext_file, 'r', encoding='utf-8') as f: text = ''.join(f.readlines()) textlen = len(text) if args.output is None: output = sys.stdout else: outdir = os.path.split(args.output)[0] os.makedirs(outdir, exist_ok=True) output = open(args.output, 'w') index = 0 # character offset in rawtext mwt_expansions = [] with open(args.conllu_file, 'r', encoding='utf-8') as f: buf = '' mwtbegin = 0 mwtend = -1 expanded = [] last_comments = "" for line in f: line = line.strip() if len(line): if line[0] == "#": # comment, don't do anything if len(last_comments) == 0: last_comments = line continue line = line.split('\t') if '.' 
in line[0]: # the tokenizer doesn't deal with ellipsis continue word = line[1] if '-' in line[0]: # multiword token mwtbegin, mwtend = [int(x) for x in line[0].split('-')] lastmwt = word expanded = [] elif mwtbegin <= int(line[0]) < mwtend: expanded += [word] continue elif int(line[0]) == mwtend: expanded += [word] expanded = [x.lower() for x in expanded] # evaluation doesn't care about case mwt_expansions += [(lastmwt, tuple(expanded))] if lastmwt[0].islower() and not expanded[0][0].islower(): print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr) mwtbegin = 0 mwtend = -1 lastmwt = None continue if len(buf): output.write(buf) index, word_found = find_next_word(index, text, word, output) buf = '0' * (len(word_found)-1) + ('1' if '-' not in line[0] else '3') else: # sentence break found if len(buf): assert int(buf[-1]) >= 1 output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1)) buf = '' last_comments = '' status_line = "" if args.output: output.close() status_line = 'Tokenizer labels written to %s\n ' % args.output mwts = Counter(mwt_expansions) if args.mwt_output is None: print('MWTs:', mwts) else: with open(args.mwt_output, 'w') as f: json.dump(list(mwts.items()), f, indent=2) status_line = status_line + '{} unique MWTs found in data. 
MWTs written to {}'.format(len(mwts), args.mwt_output) print(status_line) if __name__ == '__main__': main(sys.argv[1:]) ================================================ FILE: stanza/utils/datasets/prepare_tokenizer_treebank.py ================================================ """ Prepares train, dev, test for a treebank For example, do python -m stanza.utils.datasets.prepare_tokenizer_treebank TREEBANK such as python -m stanza.utils.datasets.prepare_tokenizer_treebank UD_English-EWT and it will prepare each of train, dev, test There are macros for preparing all of the UD treebanks at once: python -m stanza.utils.datasets.prepare_tokenizer_treebank ud_all python -m stanza.utils.datasets.prepare_tokenizer_treebank all_ud Both are present because I kept forgetting which was the correct one There are a few special case handlings of treebanks in this file: - all Vietnamese treebanks have special post-processing to handle some of the difficult spacing issues in Vietnamese text - treebanks with train and test but no dev split have the train data randomly split into two pieces - however, instead of splitting very tiny treebanks, we skip those """ import argparse import glob import io import os import random import re import sys import tempfile import zipfile from collections import Counter from stanza.models.common.constant import treebank_to_short_name import stanza.utils.datasets.common as common from stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu, write_sentences_to_file, INT_RE, MWT_RE, MWT_OR_COPY_RE import stanza.utils.datasets.tokenization.convert_ml_cochin as convert_ml_cochin import stanza.utils.datasets.tokenization.convert_my_alt as convert_my_alt import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best import stanza.utils.datasets.tokenization.convert_th_lst20 as convert_th_lst20 import 
stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu" copied = f"{dest_dir}/{short_name}.{dest_file}.conllu" print("Copying from %s to %s" % (original, copied)) # do this instead of shutil.copyfile in case there are manipulations needed # for example, we might need to add fake dependencies (TODO: still needed?) sents = read_sentences_from_conllu(original) write_sentences_to_conllu(copied, sents) def copy_conllu_treebank(treebank, model_type, paths, dest_dir, postprocess=None, augment=True): """ This utility method copies only the conllu files to the given destination directory. Both POS, lemma, and depparse annotators need this. """ os.makedirs(dest_dir, exist_ok=True) short_name = treebank_to_short_name(treebank) short_language = short_name.split("_")[0] with tempfile.TemporaryDirectory() as tokenizer_dir: paths = dict(paths) paths["TOKENIZE_DATA_DIR"] = tokenizer_dir # first we process the tokenization data args = argparse.Namespace() args.augment = augment args.prepare_labels = False process_treebank(treebank, model_type, paths, args) os.makedirs(dest_dir, exist_ok=True) if postprocess is None: postprocess = copy_conllu_file # now we copy the processed conllu data files postprocess(tokenizer_dir, "train.gold", dest_dir, "train.in", short_name) postprocess(tokenizer_dir, "dev.gold", dest_dir, "dev.in", short_name) postprocess(tokenizer_dir, "test.gold", dest_dir, "test.in", short_name) def split_conllu_file(treebank, input_conllu, train_output_conllu, dev_output_conllu, test_output_conllu): # set the seed for each data file so that the results are the same # regardless of how many treebanks are processed at once random.seed(1234) # read and shuffle conllu data sents = read_sentences_from_conllu(input_conllu) random.shuffle(sents) n_dev = int(len(sents) * XV_RATIO) assert n_dev >= 1, "Dev 
sentence number less than one." n_test = int(len(sents) * XV_RATIO) assert n_test >= 1, "Test sentence number less than one." n_train = len(sents) - n_dev - n_test # split conllu data dev_sents = sents[:n_dev] test_sents = sents[n_dev:n_dev+n_test] train_sents = sents[n_dev+n_test:] print("Train/dev/test split not present. Randomly splitting file from %s to %s, %s, %s" % (input_conllu, train_output_conllu, dev_output_conllu, test_output_conllu)) print(f"{len(sents)} total sentences found: {n_train} in train, {n_dev} in dev, {n_test} in test") # write conllu write_sentences_to_conllu(train_output_conllu, train_sents) write_sentences_to_conllu(dev_output_conllu, dev_sents) write_sentences_to_conllu(test_output_conllu, test_sents) return True def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu): # set the seed for each data file so that the results are the same # regardless of how many treebanks are processed at once random.seed(1234) # read and shuffle conllu data sents = read_sentences_from_conllu(train_input_conllu) random.shuffle(sents) n_dev = int(len(sents) * XV_RATIO) assert n_dev >= 1, "Dev sentence number less than one." n_train = len(sents) - n_dev # split conllu data dev_sents = sents[:n_dev] train_sents = sents[n_dev:] print("Train/dev split not present. Randomly splitting train file from %s to %s and %s" % (train_input_conllu, train_output_conllu, dev_output_conllu)) print(f"{len(sents)} total sentences found: {n_train} in train, {n_dev} in dev") # write conllu write_sentences_to_conllu(train_output_conllu, train_sents) write_sentences_to_conllu(dev_output_conllu, dev_sents) return True def has_space_after_no(piece): if not piece or piece == "_": return False if piece == "SpaceAfter=No": return True tags = piece.split("|") return any(t == "SpaceAfter=No" for t in tags) def remove_space_after_no(piece, fail_if_missing=True): """ Removes a SpaceAfter=No annotation from a single piece of a single word. 
In other words, given a list of conll lines, first call split("\t"), then call this on the -1 column """ # |SpaceAfter is in UD_Romanian-Nonstandard... seems fitting if piece == "SpaceAfter=No" or piece == "|SpaceAfter=No": piece = "_" elif piece.startswith("SpaceAfter=No|"): piece = piece.replace("SpaceAfter=No|", "") elif piece.find("|SpaceAfter=No") > 0: piece = piece.replace("|SpaceAfter=No", "") elif fail_if_missing: raise ValueError("Could not find SpaceAfter=No in the given notes field") return piece def add_space_after_no(piece, fail_if_found=True): if piece == '_': return "SpaceAfter=No" else: if fail_if_found: if has_space_after_no(piece): raise ValueError("Given notes field already contained SpaceAfter=No") return piece + "|SpaceAfter=No" def augment_telugu(sents): """ Add a few sentences with modified punctuation to Telugu_MTG The Telugu-MTG dataset has punctuation separated from the text in almost all cases, which makes the tokenizer not learn how to process that correctly. All of the Telugu sentences end with their sentence final punctuation being separated. Furthermore, all commas are separated. We change that on some subset of the sentences to make the tools more generalizable on wild text. """ new_sents = [] for sentence in sents: if not sentence[1].startswith("# text"): raise ValueError("Expected the second line of %s to start with # text" % sentence[0]) if not sentence[2].startswith("# translit"): raise ValueError("Expected the second line of %s to start with # translit" % sentence[0]) if sentence[1].endswith(". . .") or sentence[1][-1] not in ('.', '?', '!'): continue if sentence[1][-1] in ('.', '?', '!') and sentence[1][-2] != ' ' and sentence[1][-3:] != ' ..' and sentence[1][-4:] != ' ...': raise ValueError("Sentence %s does not end with space-punctuation, which is against our assumptions for the te_mtg treebank. 
Please check the augment method to see if it is still needed" % sentence[0]) if random.random() < 0.1: new_sentence = list(sentence) new_sentence[1] = new_sentence[1][:-2] + new_sentence[1][-1] new_sentence[2] = new_sentence[2][:-2] + new_sentence[2][-1] new_sentence[-2] = new_sentence[-2] + "|SpaceAfter=No" new_sents.append(new_sentence) if sentence[1].find(",") > 1 and random.random() < 0.1: new_sentence = list(sentence) index = sentence[1].find(",") new_sentence[1] = sentence[1][:index-1] + sentence[1][index:] index = sentence[1].find(",") new_sentence[2] = sentence[2][:index-1] + sentence[2][index:] for idx, word in enumerate(new_sentence): if idx < 4: # skip sent_id, text, transliteration, and the first word continue if word.split("\t")[1] == ',': new_sentence[idx-1] = new_sentence[idx-1] + "|SpaceAfter=No" break new_sents.append(new_sentence) return sents + new_sents COMMA_SEPARATED_RE = re.compile(" ([a-zA-Z]+)[,] ([a-zA-Z]+) ") def augment_comma_separations(sents, ratio=0.03): """Find some fraction of the sentences which match "asdf, zzzz" and squish them to "asdf,zzzz" This leaves the tokens and all of the other data the same. The only change made is to change SpaceAfter=No for the "," token and adjust the #text line, with the assumption that the conllu->txt conversion will correctly handle this change. This was particularly an issue for Spanish-AnCora, but it's reasonable to think it could happen to any dataset. Currently this just operates on commas and ascii letters to avoid accidentally squishing anything that shouldn't be squished. UD_Spanish-AnCora 2.7 had a problem is with this sentence: # orig_file_sentence 143#5 In this sentence, there was a comma smashed next to a token. Fixing just this one sentence is not sufficient to tokenize "asdf,zzzz" as desired, so we also augment by some fraction where we have squished "asdf, zzzz" into "asdf,zzzz". 
This exact example was later fixed in UD 2.8, but it should still potentially be useful for compensating for typos. """ new_sents = [] for sentence in sents: for text_idx, text_line in enumerate(sentence): # look for the line that starts with "# text". # keep going until we find it, or silently ignore it # if the dataset isn't in that format if text_line.startswith("# text"): break else: continue match = COMMA_SEPARATED_RE.search(sentence[text_idx]) if match and random.random() < ratio: for idx, word in enumerate(sentence): if word.startswith("#"): continue # find() doesn't work because we wind up finding substrings if word.split("\t")[1] != match.group(1): continue if sentence[idx+1].split("\t")[1] != ',': continue if sentence[idx+2].split("\t")[1] != match.group(2): continue break if idx == len(sentence) - 1: # this can happen with MWTs. we may actually just # want to skip MWTs anyway, so no big deal continue # now idx+1 should be the line with the comma in it comma = sentence[idx+1] pieces = comma.split("\t") assert pieces[1] == ',' pieces[-1] = add_space_after_no(pieces[-1]) comma = "\t".join(pieces) new_sent = sentence[:idx+1] + [comma] + sentence[idx+2:] text_offset = sentence[text_idx].find(match.group(1) + ", " + match.group(2)) text_len = len(match.group(1) + ", " + match.group(2)) new_text = sentence[text_idx][:text_offset] + match.group(1) + "," + match.group(2) + sentence[text_idx][text_offset+text_len:] new_sent[text_idx] = new_text new_sents.append(new_sent) print("Added %d new sentences with asdf, zzzz -> asdf,zzzz" % len(new_sents)) return sents + new_sents def augment_move_comma(sents, ratio=0.02): """ Move the comma from after a word to before the next word some fraction of the time We looks for this exact pattern: w1, w2 and replace it with w1 ,w2 The idea is that this is a relatively common typo, but the tool won't learn how to tokenize it without some help. Note that this modification replaces the original text. 
""" new_sents = [] num_operations = 0 for sentence in sents: if random.random() > ratio: new_sents.append(sentence) continue found = False for word_idx, word in enumerate(sentence): if word.startswith("#"): continue if word_idx == 0 or word_idx >= len(sentence) - 2: continue pieces = word.split("\t") if pieces[1] == ',' and not has_space_after_no(pieces[-1]): # found a comma with a space after it prev_word = sentence[word_idx-1] if not has_space_after_no(prev_word.split("\t")[-1]): # unfortunately, the previous word also had a # space after it. does not fit what we are # looking for continue # also, want to skip instances near MWT or copy nodes, # since those are harder to rearrange next_word = sentence[word_idx+1] if MWT_OR_COPY_RE.match(next_word.split("\t")[0]): continue if MWT_OR_COPY_RE.match(prev_word.split("\t")[0]): continue # at this point, the previous word has no space and the comma does found = True break if not found: new_sents.append(sentence) continue new_sentence = list(sentence) pieces = new_sentence[word_idx].split("\t") pieces[-1] = add_space_after_no(pieces[-1]) new_sentence[word_idx] = "\t".join(pieces) pieces = new_sentence[word_idx-1].split("\t") prev_word = pieces[1] pieces[-1] = remove_space_after_no(pieces[-1]) new_sentence[word_idx-1] = "\t".join(pieces) next_word = new_sentence[word_idx+1].split("\t")[1] for text_idx, text_line in enumerate(sentence): # look for the line that starts with "# text". # keep going until we find it, or silently ignore it # if the dataset isn't in that format if text_line.startswith("# text"): old_chunk = prev_word + ", " + next_word new_chunk = prev_word + " ," + next_word word_idx = text_line.find(old_chunk) if word_idx < 0: raise RuntimeError("Unexpected #text line which did not contain the original text to be modified. 
Looking for\n" + old_chunk + "\n" + text_line) new_text_line = text_line[:word_idx] + new_chunk + text_line[word_idx+len(old_chunk):] new_sentence[text_idx] = new_text_line break new_sents.append(new_sentence) num_operations = num_operations + 1 print("Swapped 'w1, w2' for 'w1 ,w2' %d times" % num_operations) return new_sents def augment_apos(sents): """ If there are no instances of ’ in the dataset, but there are instances of ', we replace some fraction of ' with ’ so that the tokenizer will recognize it. # TODO: we could do it the other way around as well """ has_unicode_apos = False has_ascii_apos = False for sent_idx, sent in enumerate(sents): if len(sent) == 0: raise AssertionError("Got a blank sentence in position %d!" % sent_idx) for line in sent: if line.startswith("# text"): if line.find("'") >= 0: has_ascii_apos = True if line.find("’") >= 0: has_unicode_apos = True break else: raise ValueError("Cannot find '# text' in sentences %d. First line: %s" % (sent_idx, sent[0])) if has_unicode_apos or not has_ascii_apos: return sents new_sents = [] for sent in sents: if random.random() > 0.05: new_sents.append(sent) continue new_sent = [] for line in sent: if line.startswith("# text"): new_sent.append(line.replace("'", "’")) elif line.startswith("#"): new_sent.append(line) else: pieces = line.split("\t") pieces[1] = pieces[1].replace("'", "’") new_sent.append("\t".join(pieces)) new_sents.append(new_sent) return new_sents def augment_ellipses(sents): """ Replaces a fraction of '...' 
with '…' """ has_ellipses = False has_unicode_ellipses = False for sent in sents: for line in sent: if line.startswith("#"): continue pieces = line.split("\t") if pieces[1] == '...': has_ellipses = True elif pieces[1] == '…': has_unicode_ellipses = True if has_unicode_ellipses or not has_ellipses: return sents new_sents = [] num_updated = 0 for sent in sents: if random.random() > 0.1: new_sents.append(sent) continue found = False new_sent = [] for line in sent: if line.startswith("#"): new_sent.append(line) else: pieces = line.split("\t") if pieces[1] == '...': pieces[1] = '…' found = True new_sent.append("\t".join(pieces)) new_sents.append(new_sent) if found: num_updated = num_updated + 1 print("Changed %d sentences to use fancy unicode ellipses" % num_updated) return new_sents # https://en.wikipedia.org/wiki/Quotation_mark QUOTES = ['"', '“', '”', '«', '»', '「', '」', '《', '》', '„', '″'] QUOTES_RE = re.compile("(.?)[" + "".join(QUOTES) + "](.+)[" + "".join(QUOTES) + "](.?)") # Danish does '«' the other way around from most European languages START_QUOTES = ['"', '“', '”', '«', '»', '「', '《', '„', '„', '″'] END_QUOTES = ['"', '“', '”', '»', '«', '」', '》', '”', '“', '″'] def augment_quotes(sents, ratio=0.15): """ Go through the sentences and replace a fraction of sentences with alternate quotes TODO: for certain languages we may want to make some language-specific changes eg Danish, don't add «...» """ assert len(START_QUOTES) == len(END_QUOTES) counts = Counter() new_sents = [] for sent in sents: if random.random() > ratio: new_sents.append(sent) continue # count if there are exactly 2 quotes in this sentence # this is for convenience - otherwise we need to figure out which pairs go together count_quotes = sum(1 for x in sent if (not x.startswith("#") and x.split("\t")[1] in QUOTES)) if count_quotes != 2: new_sents.append(sent) continue # choose a pair of quotes from the candidates quote_idx = random.choice(range(len(START_QUOTES))) start_quote = 
START_QUOTES[quote_idx] end_quote = END_QUOTES[quote_idx] counts[start_quote + end_quote] = counts[start_quote + end_quote] + 1 new_sent = [] saw_start = False for line in sent: if line.startswith("#"): new_sent.append(line) continue pieces = line.split("\t") if pieces[1] in QUOTES: if saw_start: # Note that we don't change the lemma. Presumably it's # set to the correct lemma for a quote for this treebank pieces[1] = end_quote else: pieces[1] = start_quote saw_start = True new_sent.append("\t".join(pieces)) else: new_sent.append(line) for text_idx, text_line in enumerate(new_sent): # look for the line that starts with "# text". # keep going until we find it, or silently ignore it # if the dataset isn't in that format if text_line.startswith("# text"): replacement = "\\1%s\\2%s\\3" % (start_quote, end_quote) new_text_line = QUOTES_RE.sub(replacement, text_line) new_sent[text_idx] = new_text_line new_sents.append(new_sent) # we go through this to make it simpler to execute on Windows # rather than nagging the user to set utf-8 out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) print("Augmented {} quotes: {}".format(sum(counts.values()), counts), file=out) out.detach() return new_sents def find_text_idx(sentence): """ Return the index of the # text line or -1 """ for idx, line in enumerate(sentence): if line.startswith("# text"): return idx return -1 DIGIT_RE = re.compile("[0-9]") def change_indices(line, delta): """ Adjust all indices in the given sentence by delta. 
Useful when removing a word, for example """ if line.startswith("#"): return line pieces = line.split("\t") if MWT_RE.match(pieces[0]): indices = pieces[0].split("-") pieces[0] = "%d-%d" % (int(indices[0]) + delta, int(indices[1]) + delta) line = "\t".join(pieces) return line if MWT_OR_COPY_RE.match(pieces[0]): index_pieces = pieces[0].split(".", maxsplit=1) pieces[0] = "%d.%s" % (int(index_pieces[0]) + delta, index_pieces[1]) elif not INT_RE.match(pieces[0]): raise NotImplementedError("Unknown index type: %s" % pieces[0]) else: pieces[0] = str(int(pieces[0]) + delta) if pieces[6] != '_': # copy nodes don't have basic dependencies in the es_ancora treebank dep = int(pieces[6]) if dep != 0: pieces[6] = str(int(dep) + delta) if pieces[8] != '_': dep_pieces = pieces[8].split(":", maxsplit=1) if DIGIT_RE.search(dep_pieces[1]): raise NotImplementedError("Need to handle multiple additional deps:\n%s" % line) if int(dep_pieces[0]) != 0: pieces[8] = str(int(dep_pieces[0]) + delta) + ":" + dep_pieces[1] line = "\t".join(pieces) return line def augment_initial_punct(sents, ratio=0.20): """ If a sentence starts with certain punct marks, occasionally use the same sentence without the initial punct. Currently this just handles ¿ This helps languages such as CA and ES where the models go awry when the initial ¿ is missing. 
""" new_sents = [] for sent in sents: if random.random() > ratio: continue text_idx = find_text_idx(sent) text_line = sent[text_idx] if text_line.count("¿") != 1: # only handle sentences with exactly one ¿ continue # find the first line with actual text for idx, line in enumerate(sent): if line.startswith("#"): continue break if idx >= len(sent) - 1: raise ValueError("Unexpectedly an entire sentence is comments") pieces = line.split("\t") if pieces[1] != '¿': continue if has_space_after_no(pieces[-1]): replace_text = "¿" else: replace_text = "¿ " new_sent = sent[:idx] + sent[idx+1:] new_sent[text_idx] = text_line.replace(replace_text, "") # now need to update all indices new_sent = [change_indices(x, -1) for x in new_sent] new_sents.append(new_sent) if len(new_sents) > 0: print("Added %d sentences with the leading ¿ removed" % len(new_sents)) return sents + new_sents def augment_brackets(sents, ratio=0.1): """ If there are no sentences with [], transform some () into [] """ new_sents = [] for sent in sents: text_idx = find_text_idx(sent) text_line = sent[text_idx] if text_line.count("[") > 0 or text_line.count("]") > 0: # found a square bracket, so, never mind return sents for sent in sents: if random.random() > ratio: continue text_idx = find_text_idx(sent) text_line = sent[text_idx] if text_line.count("(") == 0 and text_line.count(")") == 0: continue text_line = text_line.replace("(", "[").replace(")", "]") new_sent = list(sent) new_sent[text_idx] = text_line for idx, line in enumerate(new_sent): if line.startswith("#"): continue pieces = line.split("\t") if pieces[1] == '(': pieces[1] = '[' elif pieces[1] == ')': pieces[1] = ']' new_sent[idx] = "\t".join(pieces) new_sents.append(new_sent) if len(new_sents) > 0: print("Added %d sentences with parens replaced with square brackets" % len(new_sents)) return sents + new_sents def augment_punct(sents): """ If there are no instances of ’ in the dataset, but there are instances of ', we replace some fraction of ' with ’ 
so that the tokenizer will recognize it. Also augments with ... / … """ new_sents = augment_apos(sents) new_sents = augment_quotes(new_sents) new_sents = augment_move_comma(new_sents) new_sents = augment_comma_separations(new_sents) new_sents = augment_initial_punct(new_sents) new_sents = augment_ellipses(new_sents) new_sents = augment_brackets(new_sents) return new_sents def remove_accents_from_words(sents): new_sents = [] for sent in sents: new_sent = [] for line in sent: if line.startswith("#"): new_sent.append(line) else: pieces = line.split("\t") pieces[1] = common.strip_accents(pieces[1]) new_sent.append("\t".join(pieces)) new_sents.append(new_sent) return new_sents def augment_accents(sents): return sents + remove_accents_from_words(sents) def write_augmented_dataset(input_conllu, output_conllu, augment_function): # set the seed for each data file so that the results are the same # regardless of how many treebanks are processed at once random.seed(1234) # read and shuffle conllu data sents = read_sentences_from_conllu(input_conllu) # the actual meat of the function - produce new sentences new_sents = augment_function(sents) write_sentences_to_conllu(output_conllu, new_sents) def remove_spaces_from_sentences(sents): """ Makes sure every word in the list of sentences has SpaceAfter=No. Returns a new list of sentences """ new_sents = [] for sentence in sents: new_sentence = [] for word in sentence: if word.startswith("#"): new_sentence.append(word) continue pieces = word.split("\t") if pieces[-1] == "_": pieces[-1] = "SpaceAfter=No" elif pieces[-1].find("SpaceAfter=No") >= 0: pass else: raise ValueError("oops") word = "\t".join(pieces) new_sentence.append(word) new_sents.append(new_sentence) return new_sents def remove_spaces(input_conllu, output_conllu): """ Turns a dataset into something appropriate for building a segmenter. For example, this works well on the Korean datasets. 
""" sents = read_sentences_from_conllu(input_conllu) new_sents = remove_spaces_from_sentences(sents) write_sentences_to_conllu(output_conllu, new_sents) def build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu): """ Builds a combined dataset out of multiple Korean datasets. Currently this uses GSD and Kaist. If a segmenter-appropriate dataset was requested, spaces are removed. TODO: we need to handle the difference in xpos tags somehow. """ gsd_conllu = common.find_treebank_dataset_file("UD_Korean-GSD", udbase_dir, dataset, "conllu") kaist_conllu = common.find_treebank_dataset_file("UD_Korean-Kaist", udbase_dir, dataset, "conllu") sents = read_sentences_from_conllu(gsd_conllu) + read_sentences_from_conllu(kaist_conllu) segmenter = short_name.endswith("_seg") if segmenter: sents = remove_spaces_from_sentences(sents) write_sentences_to_conllu(output_conllu, sents) def build_combined_korean(udbase_dir, tokenizer_dir, short_name): for dataset in ("train", "dev", "test"): output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset) build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu) def build_combined_italian_dataset(paths, model_type, dataset): udbase_dir = paths["UDBASE"] if dataset == 'train': # could maybe add ParTUT, but that dataset has a slightly different xpos set # (no DE or I) # and I didn't feel like sorting through the differences # TODO: for that dataset, can try adding it without the xpos or feats on ParTUT treebanks = [ "UD_Italian-ISDT", "UD_Italian-VIT", ] if model_type is not common.ModelType.TOKENIZER: treebanks.extend([ "UD_Italian-TWITTIRO", "UD_Italian-PoSTWITA" ]) print("Building {} dataset out of {}".format(model_type, " ".join(treebanks))) sents = [] for treebank in treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) sents.extend(read_sentences_from_conllu(conllu_file)) else: istd_conllu = 
common.find_treebank_dataset_file("UD_Italian-ISDT", udbase_dir, dataset, "conllu") sents = read_sentences_from_conllu(istd_conllu) return sents def check_gum_ready(udbase_dir): gum_conllu = common.find_treebank_dataset_file("UD_English-GUMReddit", udbase_dir, "train", "conllu") if common.mostly_underscores(gum_conllu): raise ValueError("Cannot process UD_English-GUMReddit in its current form. There should be a download script available in the directory which will help integrate the missing proprietary values. Please run that script to update the data, then try again.") def build_combined_english_dataset(paths, model_type, dataset): """ en_combined is currently EWT, GUM, PUD, Pronouns, and handparsed """ udbase_dir = paths["UDBASE_GIT"] check_gum_ready(udbase_dir) if dataset == 'train': # TODO: include more UD treebanks, possibly with xpos removed # UD_English-ParTUT - xpos are different # also include "external" treebanks such as PTB # NOTE: in order to get the best results, make sure each of these treebanks have the latest edits applied train_treebanks = ["UD_English-EWT", "UD_English-GUM", "UD_English-GUMReddit"] test_treebanks = ["UD_English-PUD", "UD_English-Pronouns"] sents = [] for treebank in train_treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) for treebank in test_treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) else: ewt_conllu = common.find_treebank_dataset_file("UD_English-EWT", udbase_dir, dataset, "conllu") sents = read_sentences_from_conllu(ewt_conllu) return sents def add_english_sentence_final_punctuation(handparsed_sentences): """ Add a period 
to the end of a sentence with no punct at the end. The next-to-last word has SpaceAfter=No added as well. Possibly English-specific because of the xpos. Could be upgraded to handle multiple languages by passing in the xpos as an argument """ new_sents = [] for sent in handparsed_sentences: root_id = None max_id = None last_punct = False for line in sent: if line.startswith("#"): continue pieces = line.split("\t") if MWT_OR_COPY_RE.match(pieces[0]): continue if pieces[6] == '0': root_id = pieces[0] max_id = int(pieces[0]) last_punct = pieces[3] == 'PUNCT' if not last_punct: new_sent = list(sent) pieces = new_sent[-1].split("\t") pieces[-1] = add_space_after_no(pieces[-1]) new_sent[-1] = "\t".join(pieces) new_sent.append("%d\t.\t.\tPUNCT\t.\t_\t%s\tpunct\t%s:punct\t_" % (max_id+1, root_id, root_id)) new_sents.append(new_sent) else: new_sents.append(sent) return new_sents def build_extra_combined_french_dataset(paths, model_type, dataset): """ Extra sentences we don't want augmented for French - currently, handparsed lemmas """ handparsed_dir = paths["HANDPARSED_DIR"] sents = [] if dataset == 'train': if model_type is common.ModelType.LEMMA: handparsed_path = os.path.join(handparsed_dir, "french-lemmas", "fr_lemmas.conllu") handparsed_sentences = read_sentences_from_conllu(handparsed_path) print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path)) sents.extend(handparsed_sentences) handparsed_path = os.path.join(handparsed_dir, "french-lemmas", "french1st_6thGrade.conllu") handparsed_sentences = read_sentences_from_conllu(handparsed_path) print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path)) sents.extend(handparsed_sentences) return sents def build_extra_combined_german_dataset(paths, model_type, dataset): """ Extra sentences we don't want augmented for German Currently, this is just the lemmas from Wiktionary """ handparsed_dir = paths["HANDPARSED_DIR"] sents = [] if dataset == 'train': if model_type is 
common.ModelType.LEMMA: handparsed_path = os.path.join(handparsed_dir, "german-lemmas-wiki", "de_wiki_lemmas.conllu") handparsed_sentences = read_sentences_from_conllu(handparsed_path) print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path)) sents.extend(handparsed_sentences) return sents def build_extra_combined_english_dataset(paths, model_type, dataset): """ Extra sentences we don't want augmented """ handparsed_dir = paths["HANDPARSED_DIR"] sents = [] if dataset == 'train': handparsed_path = os.path.join(handparsed_dir, "english-handparsed", "english.conll") handparsed_sentences = read_sentences_from_conllu(handparsed_path) handparsed_sentences = add_english_sentence_final_punctuation(handparsed_sentences) sents.extend(handparsed_sentences) print("Loaded %d sentences from %s" % (len(sents), handparsed_path)) if model_type is common.ModelType.LEMMA: handparsed_path = os.path.join(handparsed_dir, "english-lemmas", "en_lemmas.conllu") handparsed_sentences = read_sentences_from_conllu(handparsed_path) print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path)) sents.extend(handparsed_sentences) handparsed_path = os.path.join(handparsed_dir, "english-lemmas-verbs", "irregularVerbs-noNnoAdj.conllu") handparsed_sentences = read_sentences_from_conllu(handparsed_path) print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path)) sents.extend(handparsed_sentences) handparsed_path = os.path.join(handparsed_dir, "english-lemmas-adj", "en_adj.conllu") handparsed_sentences = read_sentences_from_conllu(handparsed_path) print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path)) sents.extend(handparsed_sentences) return sents def build_extra_combined_italian_dataset(paths, model_type, dataset): """ Extra data - the MWT data for Italian """ handparsed_dir = paths["HANDPARSED_DIR"] if dataset != 'train': return [] extra_italian = os.path.join(handparsed_dir, "italian-mwt", 
"italian.mwt") if not os.path.exists(extra_italian): raise FileNotFoundError("Cannot find the extra dataset 'italian.mwt' which includes various multi-words retokenized, expected {}".format(extra_italian)) extra_sents = read_sentences_from_conllu(extra_italian) for sentence in extra_sents: if not sentence[2].endswith("_") or not MWT_RE.match(sentence[2]): raise AssertionError("Unexpected format of the italian.mwt file. Has it already be modified to have SpaceAfter=No everywhere?") sentence[2] = sentence[2][:-1] + "SpaceAfter=No" print("Loaded %d sentences from %s" % (len(extra_sents), extra_italian)) return extra_sents def replace_semicolons(sentences): """ Spanish GSD and AnCora have different standards for semicolons. GSD has semicolons at the end of sentences, AnCora has them in the middle as clause separators. Consecutive sentences in GSD do not seem to be related, so there is no combining that can be done. The easiest solution is to replace sentence final semicolons with "." in GSD """ new_sents = [] count = 0 for sentence in sentences: for text_idx, text_line in enumerate(sentence): if text_line.startswith("# text"): break else: raise ValueError("Expected every sentence in GSD to have a # text field") if not text_line.endswith(";"): new_sents.append(sentence) continue count = count + 1 new_sent = list(sentence) new_sent[text_idx] = text_line[:-1] + "." new_sent[-1] = new_sent[-1].replace(";", ".") count = count + 1 new_sents.append(new_sent) print("Updated %d sentences to replace sentence-final ; with ." 
% count) return new_sents def strip_column(sents, column): """ Removes a specified column from the given dataset Particularly useful when mixing two different POS formalisms in the same tagger """ new_sents = [] for sentence in sents: new_sent = [] for word in sentence: if word.startswith("#"): new_sent.append(word) continue pieces = word.split("\t") pieces[column] = "_" new_sent.append("\t".join(pieces)) new_sents.append(new_sent) return new_sents def strip_xpos(sents): """ Removes all xpos from the given dataset Particularly useful when mixing two different POS formalisms in the same tagger """ return strip_column(sents, 4) def strip_feats(sents): """ Removes all features from the given dataset Particularly useful when mixing two different POS formalisms in the same tagger """ return strip_column(sents, 5) def build_combined_japanese_dataset(paths, model_type, dataset): """ GSD with a handparsed dataset of some short verb phrases """ udbase_dir = paths["UDBASE"] handparsed_dir = paths["HANDPARSED_DIR"] treebank = "UD_Japanese-GSD" conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) gsd_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(gsd_sents), conllu_file)) if dataset == 'train': extra_japanese = os.path.join(handparsed_dir, "japanese-handparsed", "spaces-ready-checked.conllu") if not os.path.exists(extra_japanese): raise FileNotFoundError("Cannot find the extra dataset which includes various verb patterns, expected {}".format(extra_japanese)) extra_sents = read_sentences_from_conllu(extra_japanese) print("Read %d sentences from %s" % (len(extra_sents), extra_japanese)) if model_type == common.ModelType.POS: documents = {} documents[treebank] = gsd_sents documents['handparsed'] = extra_sents return documents else: sents = gsd_sents + extra_sents return sents else: return gsd_sents def build_combined_albanian_dataset(paths, model_type, dataset): """ sq_combined is STAF as the 
base, with TSA added for some things """ udbase_dir = paths["UDBASE"] udbase_git_dir = paths["UDBASE_GIT"] handparsed_dir = paths["HANDPARSED_DIR"] treebanks = ["UD_Albanian-STAF", "UD_Albanian-TSA"] if dataset == 'train' and model_type == common.ModelType.POS: documents = {} conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, "train", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) documents[treebanks[0]] = new_sents # we use udbase_git_dir for TSA because of an updated MWT scheme conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, "test", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) new_sents = strip_xpos(new_sents) new_sents = strip_feats(new_sents) documents[treebanks[1]] = new_sents return documents if dataset == 'train' and model_type is not common.ModelType.DEPPARSE: sents = [] conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, "train", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, "test", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) return sents conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, "conllu", fail=True) sents = read_sentences_from_conllu(conllu_file) return sents def build_combined_german_dataset(paths, model_type, dataset): """ de_combined is currently GSD, with lemma information from Wiktionary the lemma information is added in build_extra_combined_german_dataset TODO: quite a bit of HDT we could possibly use """ udbase_dir = paths["UDBASE"] treebanks = ["UD_German-GSD"] treebank = treebanks[0] conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, 
"conllu", fail=True) sents = read_sentences_from_conllu(conllu_file) return sents def build_combined_spanish_dataset(paths, model_type, dataset): """ es_combined is AnCora and GSD put together For POS training, we put the different datasets into a zip file so that we can keep the conllu files separate and remove the xpos from the non-AnCora training files. It is necessary to remove the xpos because GSD and PUD both use different xpos schemes from AnCora, and the tagger can use additional data files as training data without a specific column if that column is entirely blank TODO: consider mixing in PUD? """ udbase_dir = paths["UDBASE"] handparsed_dir = paths["HANDPARSED_DIR"] treebanks = ["UD_Spanish-AnCora", "UD_Spanish-GSD"] if dataset == 'train' and model_type == common.ModelType.POS: documents = {} for treebank in treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) if not treebank.endswith("AnCora"): new_sents = strip_xpos(new_sents) documents[treebank] = new_sents return documents if dataset == 'train': sents = [] for treebank in treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) if treebank.endswith("GSD"): new_sents = replace_semicolons(new_sents) sents.extend(new_sents) if model_type in (common.ModelType.TOKENIZER, common.ModelType.MWT, common.ModelType.LEMMA): extra_spanish = os.path.join(handparsed_dir, "spanish-mwt", "adjectives.conllu") if not os.path.exists(extra_spanish): raise FileNotFoundError("Cannot find the extra dataset 'adjectives.conllu' which includes various multi-words retokenized, expected {}".format(extra_spanish)) extra_sents = read_sentences_from_conllu(extra_spanish) print("Read %d sentences from %s" % (len(extra_sents), extra_spanish)) 
sents.extend(extra_sents) else: conllu_file = common.find_treebank_dataset_file("UD_Spanish-AnCora", udbase_dir, dataset, "conllu", fail=True) sents = read_sentences_from_conllu(conllu_file) return sents def build_combined_french_dataset(paths, model_type, dataset): udbase_dir = paths["UDBASE"] handparsed_dir = paths["HANDPARSED_DIR"] if dataset == 'train': train_treebanks = ["UD_French-GSD", "UD_French-ParisStories", "UD_French-Rhapsodie", "UD_French-Sequoia"] sents = [] for treebank in train_treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=True) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) extra_french = os.path.join(handparsed_dir, "french-handparsed", "handparsed_deps.conllu") if not os.path.exists(extra_french): raise FileNotFoundError("Cannot find the extra dataset 'handparsed_deps.conllu' which includes various dependency fixes, expected {}".format(extra_italian)) extra_sents = read_sentences_from_conllu(extra_french) print("Read %d sentences from %s" % (len(extra_sents), extra_french)) sents.extend(extra_sents) else: gsd_conllu = common.find_treebank_dataset_file("UD_French-GSD", udbase_dir, dataset, "conllu") sents = read_sentences_from_conllu(gsd_conllu) return sents def build_combined_hebrew_dataset(paths, model_type, dataset): """ Combines the IAHLT treebank with an updated form of HTB where the annotation style more closes matches IAHLT Currently the updated HTB is not in UD, so you will need to clone git@github.com:IAHLT/UD_Hebrew.git to $UDBASE_GIT dev and test sets will be those from IAHLT """ udbase_dir = paths["UDBASE"] udbase_git_dir = paths["UDBASE_GIT"] treebanks = ["UD_Hebrew-IAHLTwiki", "UD_Hebrew-IAHLTknesset"] if dataset == 'train': sents = [] for treebank in treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) new_sents = 
read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) # if/when this gets ported back to UD, switch to getting both datasets from UD hebrew_git_dir = os.path.join(udbase_git_dir, "UD_Hebrew") if not os.path.exists(hebrew_git_dir): raise FileNotFoundError("Please download git@github.com:IAHLT/UD_Hebrew.git to %s (based on $UDBASE_GIT)" % hebrew_git_dir) conllu_file = os.path.join(hebrew_git_dir, "he_htb-ud-train.conllu") if not os.path.exists(conllu_file): raise FileNotFoundError("Found %s but inexplicably there was no %s" % (hebrew_git_dir, conllu_file)) new_sents = read_sentences_from_conllu(conllu_file) print("Read %d sentences from %s" % (len(new_sents), conllu_file)) sents.extend(new_sents) else: conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, "conllu", fail=True) sents = read_sentences_from_conllu(conllu_file) return sents COMBINED_FNS = { "de_combined": build_combined_german_dataset, "en_combined": build_combined_english_dataset, "es_combined": build_combined_spanish_dataset, "fr_combined": build_combined_french_dataset, "he_combined": build_combined_hebrew_dataset, "it_combined": build_combined_italian_dataset, "ja_combined": build_combined_japanese_dataset, "sq_combined": build_combined_albanian_dataset, } # some extra data for the combined models without augmenting COMBINED_EXTRA_FNS = { "de_combined": build_extra_combined_german_dataset, "en_combined": build_extra_combined_english_dataset, "fr_combined": build_extra_combined_french_dataset, "it_combined": build_extra_combined_italian_dataset, } def build_combined_dataset(paths, short_name, model_type, augment): random.seed(1234) tokenizer_dir = paths["TOKENIZE_DATA_DIR"] build_fn = COMBINED_FNS[short_name] extra_fn = COMBINED_EXTRA_FNS.get(short_name, None) for dataset in ("train", "dev", "test"): output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset) sents = 
build_fn(paths, model_type, dataset) if isinstance(sents, dict): if dataset == 'train' and augment: for filename in list(sents.keys()): sents[filename] = augment_punct(sents[filename]) output_zip = os.path.splitext(output_conllu)[0] + ".zip" with zipfile.ZipFile(output_zip, "w") as zout: for filename in list(sents.keys()): with zout.open(filename + ".conllu", "w") as zfout: with io.TextIOWrapper(zfout, encoding='utf-8', newline='') as fout: write_sentences_to_file(fout, sents[filename]) else: if dataset == 'train' and augment: sents = augment_punct(sents) if extra_fn is not None: sents.extend(extra_fn(paths, model_type, dataset)) write_sentences_to_conllu(output_conllu, sents) BIO_DATASETS = ("en_craft", "en_genia", "en_mimic") def build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, augment): """ Process the en bio datasets Creates a dataset by combining the en_combined data with one of the bio sets """ random.seed(1234) name, bio_dataset = short_name.split("_") assert name == 'en' for dataset in ("train", "dev", "test"): output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset) if dataset == 'train': sents = build_combined_english_dataset(paths, model_type, dataset) if dataset == 'train' and augment: sents = augment_punct(sents) else: sents = [] bio_file = os.path.join(paths["BIO_UD_DIR"], "UD_English-%s" % bio_dataset.upper(), "en_%s-ud-%s.conllu" % (bio_dataset.lower(), dataset)) new_sents = read_sentences_from_conllu(bio_file) print("Read %d sentences from %s" % (len(new_sents), bio_file)) sents.extend(new_sents) write_sentences_to_conllu(output_conllu, sents) def build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment): """ Build the GUM dataset by combining GUMReddit It checks to make sure GUMReddit is filled out using the included script """ check_gum_ready(udbase_dir) random.seed(1234) output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, 
dataset) treebanks = ["UD_English-GUM", "UD_English-GUMReddit"] sents = [] for treebank in treebanks: conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) sents.extend(read_sentences_from_conllu(conllu_file)) if dataset == 'train' and augment: sents = augment_punct(sents) write_sentences_to_conllu(output_conllu, sents) def build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, augment): for dataset in ("train", "dev", "test"): build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment) def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, dataset, augment=True, input_conllu=None, output_conllu=None): if input_conllu is None: input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) if output_conllu is None: output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset) print("Reading from %s and writing to %s" % (input_conllu, output_conllu)) if short_name == "te_mtg" and dataset == 'train' and augment: write_augmented_dataset(input_conllu, output_conllu, augment_telugu) elif short_name.startswith("ko_") and short_name.endswith("_seg"): remove_spaces(input_conllu, output_conllu) elif short_name.startswith("grc_") and short_name.endswith("-diacritics"): write_augmented_dataset(input_conllu, output_conllu, augment_accents) elif dataset == 'train' and augment: write_augmented_dataset(input_conllu, output_conllu, augment_punct) else: sents = read_sentences_from_conllu(input_conllu) write_sentences_to_conllu(output_conllu, sents) def process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, augment=True): """ Process a normal UD treebank with train/dev/test splits SL-SSJ and other datasets with inline modifications all use this code path as well. 
""" prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "train", augment) prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "dev", augment) prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "test", augment) XV_RATIO = 0.2 def process_test_only_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language): """ Process a large UD treebank with only a test Return False if the treebank is too small """ train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu") dev_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu") test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu") if train_input_conllu or dev_input_conllu: return False train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "train") dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "dev") test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "test") if common.num_words_in_file(test_input_conllu) <= 10000: return False if not split_conllu_file(treebank=treebank, input_conllu=test_input_conllu, train_output_conllu=train_output_conllu, dev_output_conllu=dev_output_conllu, test_output_conllu=test_output_conllu): return False return True def process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language): """ Process a UD treebank with only train/test splits For example, in UD 2.7: UD_Buryat-BDT UD_Galician-TreeGal UD_Indonesian-CSUI UD_Kazakh-KTB UD_Kurmanji-MG UD_Latin-Perseus UD_Livvi-KKPP UD_North_Sami-Giella UD_Old_Russian-RNC UD_Sanskrit-Vedic UD_Slovenian-SST UD_Upper_Sorbian-UFAL UD_Welsh-CCG Returns True if successful, False if not """ train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu") test_input_conllu = 
common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu") train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "train") dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "dev") test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "test") if (common.num_words_in_file(train_input_conllu) <= 1000 and common.num_words_in_file(test_input_conllu) > 5000): train_input_conllu, test_input_conllu = test_input_conllu, train_input_conllu if not split_train_file(treebank=treebank, train_input_conllu=train_input_conllu, train_output_conllu=train_output_conllu, dev_output_conllu=dev_output_conllu): return False # the test set is already fine # currently we do not do any augmentation of these partial treebanks prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "test", augment=False, input_conllu=test_input_conllu, output_conllu=test_output_conllu) return True def add_specific_args(parser): parser.add_argument('--no_augment', action='store_false', dest='augment', default=True, help='Augment the dataset in various ways') parser.add_argument('--no_prepare_labels', action='store_false', dest='prepare_labels', default=True, help='Prepare tokenizer and MWT labels. 
Expensive, but obviously necessary for training those models.') convert_th_lst20.add_lst20_args(parser) convert_vi_vlsp.add_vlsp_args(parser) def process_treebank(treebank, model_type, paths, args): """ Processes a single treebank into train, dev, test parts Includes processing for a few external tokenization datasets: vi_vlsp, th_orchid, th_best Also, there is no specific mechanism for UD_Arabic-NYUAD or similar treebanks, which need integration with LDC datsets """ udbase_dir = paths["UDBASE"] tokenizer_dir = paths["TOKENIZE_DATA_DIR"] handparsed_dir = paths["HANDPARSED_DIR"] short_name = treebank_to_short_name(treebank) short_language = short_name.split("_")[0] os.makedirs(tokenizer_dir, exist_ok=True) success = False if short_name == "my_alt": convert_my_alt.convert_my_alt(paths["CONSTITUENCY_BASE"], tokenizer_dir) elif short_name == "vi_vlsp": convert_vi_vlsp.convert_vi_vlsp(paths["STANZA_EXTERN_DIR"], tokenizer_dir, args) elif short_name == "th_orchid": convert_th_orchid.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir) elif short_name == "th_lst20": convert_th_lst20.convert(paths["STANZA_EXTERN_DIR"], tokenizer_dir, args) elif short_name == "th_best": convert_th_best.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir) elif short_name == "ml_cochin": convert_ml_cochin.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir) elif short_name.startswith("ko_combined"): build_combined_korean(udbase_dir, tokenizer_dir, short_name) elif short_name in COMBINED_FNS: # eg "it_combined", "en_combined", etc build_combined_dataset(paths, short_name, model_type, args.augment) elif short_name in BIO_DATASETS: build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, args.augment) elif short_name.startswith("en_gum"): # we special case GUM because it should include a filled-out GUMReddit print("Preparing data for %s: %s, %s" % (treebank, short_name, short_language)) build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, args.augment) else: # 
check that we can find the train file where we expect it train_conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=False) if not train_conllu_file: # maybe this dataset has a huge test set we can split? test_conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu", fail=True) print("Checking data for %s: %s, %s to see if the test dataset is large enough" % (treebank, short_name, short_language)) success = process_test_only_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language) else: print("Preparing data for %s: %s, %s" % (treebank, short_name, short_language)) if not common.find_treebank_dataset_file(treebank, udbase_dir, "dev", "conllu", fail=False): success = process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language) else: process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment) if success and (model_type is common.ModelType.TOKENIZER or model_type is common.ModelType.MWT): if not short_name in ('th_orchid', 'th_lst20'): common.convert_conllu_to_txt(tokenizer_dir, short_name) if args.prepare_labels: common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name) def main(): common.main(process_treebank, common.ModelType.TOKENIZER, add_specific_args) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/pretrain/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/pretrain/word_in_pretrain.py ================================================ """ Simple tool to query a word vector file to see if certain words are in that file """ import argparse import os from stanza.models.common.pretrain import Pretrain from stanza.resources.common import DEFAULT_MODEL_DIR, download def main(): parser = argparse.ArgumentParser() group = 
parser.add_mutually_exclusive_group(required=True) group.add_argument("--pretrain", default=None, type=str, help="Where to read the converted PT file") group.add_argument("--package", default=None, type=str, help="Use a pretrain package instead") parser.add_argument("--download_json", default=False, action='store_true', help="Download the json even if it already exists") parser.add_argument("words", type=str, nargs="+", help="Which words to search for") args = parser.parse_args() if args.pretrain: pt = Pretrain(args.pretrain) else: lang, package = args.package.split("_", 1) download(lang=lang, package=None, processors={"pretrain": package}, download_json=args.download_json) pt_filename = os.path.join(DEFAULT_MODEL_DIR, lang, "pretrain", "%s.pt" % package) pt = Pretrain(pt_filename) for word in args.words: print("{}: {}".format(word, word in pt.vocab)) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/datasets/random_split_conllu.py ================================================ """ Randomly split a file into train, dev, and test sections Specifically used in the case of building a tagger from the initial POS tagging provided by Isra, but obviously can be used to split any conllu file """ import argparse import os import random from stanza.models.common.doc import Document from stanza.utils.conll import CoNLL from stanza.utils.default_paths import get_default_paths def random_split(doc, weights, remove_xpos=False, remove_feats=False): """ weights: a tuple / list of (train, dev, test) weights """ train_doc = ([], []) dev_doc = ([], []) test_doc = ([], []) splits = [train_doc, dev_doc, test_doc] for sentence in doc.sentences: sentence_dict = sentence.to_dict() if remove_xpos: for x in sentence_dict: x.pop('xpos', None) if remove_feats: for x in sentence_dict: x.pop('feats', None) split = random.choices(splits, weights)[0] split[0].append(sentence_dict) split[1].append(sentence.comments) splits = [Document(split[0], 
comments=split[1]) for split in splits] return splits def main(): parser = argparse.ArgumentParser() parser.add_argument('--filename', default='extern_data/sindhi/upos/sindhi_upos.conllu', help='Which file to split') parser.add_argument('--train', type=float, default=0.8, help='Fraction of the data to use for train') parser.add_argument('--dev', type=float, default=0.1, help='Fraction of the data to use for dev') parser.add_argument('--test', type=float, default=0.1, help='Fraction of the data to use for test') parser.add_argument('--seed', default='1234', help='Random seed to use') parser.add_argument('--short_name', default='sd_isra', help='Dataset name to use when writing output files') parser.add_argument('--no_remove_xpos', default=True, action='store_false', dest='remove_xpos', help='By default, we remove the xpos from the dataset') parser.add_argument('--no_remove_feats', default=True, action='store_false', dest='remove_feats', help='By default, we remove the feats from the dataset') parser.add_argument('--output_directory', default=get_default_paths()["POS_DATA_DIR"], help="Where to put the split conllu") args = parser.parse_args() weights = (args.train, args.dev, args.test) doc = CoNLL.conll2doc(args.filename) random.seed(args.seed) splits = random_split(doc, weights, args.remove_xpos, args.remove_feats) for split_doc, split_name in zip(splits, ("train", "dev", "test")): filename = os.path.join(args.output_directory, "%s.%s.in.conllu" % (args.short_name, split_name)) print("Outputting %d sentences to %s" % (len(split_doc.sentences), filename)) CoNLL.write_doc2conll(split_doc, filename) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/sentiment/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/sentiment/add_constituency.py ================================================ """ For a dataset produced by 
prepare_sentiment_dataset, add constituency parses. Obviously this will only work on languages that have a constituency parser """ import argparse import os import stanza from stanza.models.classifiers.data import read_dataset from stanza.models.classifiers.utils import WVType from stanza.models.mwt.utils import resplit_mwt from stanza.utils.datasets.sentiment import prepare_sentiment_dataset from stanza.utils.datasets.sentiment import process_utils import stanza.utils.default_paths as default_paths SHARDS = ("train", "dev", "test") def main(): parser = argparse.ArgumentParser() # TODO: allow multiple files? parser.add_argument('dataset', type=str, help="Dataset (or a single file) to process") parser.add_argument('--output', type=str, help="Write the processed data here instead of clobbering") parser.add_argument('--constituency_package', type=str, default=None, help="Constituency model to use for parsing") parser.add_argument('--constituency_model', type=str, default=None, help="Specific model file to use for parsing") parser.add_argument('--retag_package', type=str, default=None, help="Which tagger to use for retagging") parser.add_argument('--split_mwt', action='store_true', help="Split MWT from the original sentences if the language has MWT") parser.add_argument('--lang', type=str, default=None, help="Which language the dataset/file is in. 
If not specified, will try to use the dataset name") args = parser.parse_args() if os.path.exists(args.dataset): expected_files = [args.dataset] if args.output: output_files = [args.output] else: output_files = expected_files if not args.lang: _, filename = os.path.split(args.dataset) args.lang = filename.split("_")[0] print("Guessing lang=%s based on the filename %s" % (args.lang, filename)) else: paths = default_paths.get_default_paths() # TODO: one of the side effects of the tass2020 dataset is to make a bunch of extra files # Perhaps we could have the prepare_sentiment_dataset script return a list of those files expected_files = [os.path.join(paths['SENTIMENT_DATA_DIR'], '%s.%s.json' % (args.dataset, shard)) for shard in SHARDS] if args.output: output_files = [os.path.join(paths['SENTIMENT_DATA_DIR'], '%s.%s.json' % (args.output, shard)) for shard in SHARDS] else: output_files = expected_files for filename in expected_files: if not os.path.exists(filename): print("Cannot find expected dataset file %s - rebuilding dataset" % filename) prepare_sentiment_dataset.main(args.dataset) break if not args.lang: args.lang, _ = args.dataset.split("_", 1) print("Guessing lang=%s based on the dataset name" % args.lang) pipeline_args = {"lang": args.lang, "processors": "tokenize,pos,constituency", "tokenize_pretokenized": True, "pos_batch_size": 50, "pos_tqdm": True, "constituency_tqdm": True} package = {} if args.constituency_package is not None: package["constituency"] = args.constituency_package if args.retag_package is not None: package["pos"] = args.retag_package if package: pipeline_args["package"] = package if args.constituency_model is not None: pipeline_args["constituency_model_path"] = args.constituency_model pipe = stanza.Pipeline(**pipeline_args) if args.split_mwt: # TODO: allow for different tokenize packages mwt_pipe = stanza.Pipeline(lang=args.lang, processors="tokenize") if "mwt" in mwt_pipe.processors: print("This language has MWT. 
Will resplit any MWTs found in the dataset") else: print("--split_mwt was requested, but %s does not support MWT!" % args.lang) args.split_mwt = False for filename, output_filename in zip(expected_files, output_files): dataset = read_dataset(filename, WVType.OTHER, 1) text = [x.text for x in dataset] if args.split_mwt: print("Resplitting MWT in %d sentences from %s" % (len(dataset), filename)) doc = resplit_mwt(text, mwt_pipe) print("Parsing %d sentences from %s" % (len(dataset), filename)) doc = pipe(doc) else: print("Parsing %d sentences from %s" % (len(dataset), filename)) doc = pipe(text) assert len(dataset) == len(doc.sentences) for datum, sentence in zip(dataset, doc.sentences): datum.constituency = sentence.constituency process_utils.write_list(output_filename, dataset) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/sentiment/convert_italian_poetry_classification.py ================================================ """ A short tool to turn a labeled dataset of the format Prof. Delmonte provided into a stanza input file for the classifier. Data is expected to be in the sentiment italian subdirectory (see below) Only writes a test set. Use it as an eval file for a trained model. 
""" import os import stanza from stanza.models.classifiers.data import SentimentDatum from stanza.utils.datasets.sentiment import process_utils import stanza.utils.default_paths as default_paths def main(): paths = default_paths.get_default_paths() dataset_name = "it_vit_sentences_poetry" poetry_filename = os.path.join(paths["SENTIMENT_BASE"], "italian", "sentence_classification", "poetry", "testset_300_labeled.txt") if not os.path.exists(poetry_filename): raise FileNotFoundError("Expected to find the labeled file in %s" % poetry_filename) print("Reading the labeled poetry from %s" % poetry_filename) tokenizer = stanza.Pipeline("it", processors="tokenize", tokenize_no_ssplit=True) dataset = [] with open(poetry_filename, encoding="utf-8") as fin: for line_num, line in enumerate(fin): line = line.strip() if not line: continue line = line.replace(u'\ufeff', '') pieces = line.split(maxsplit=1) # first column is the label # remainder of the text is the raw text label = pieces[0].strip() if label not in ('0', '1'): if label == "viene" and line_num == 257: print("Skipping known missing label at line 257") continue assert isinstance(label, str) ords = ",".join(str(ord(x)) for x in label) raise ValueError("Unexpected label |%s| (%s) for line %d" % (label, ords, line_num)) # tokenize the text into words # we could make this faster by stacking it, but the input file is quite short anyway text = pieces[1] doc = tokenizer(text) words = [x.text for x in doc.sentences[0].words] dataset.append(SentimentDatum(label, words)) print("Read %d lines from %s" % (len(dataset), poetry_filename)) output_filename = "%s.test.json" % dataset_name output_path = os.path.join(paths["SENTIMENT_DATA_DIR"], output_filename) print("Writing output to %s" % output_path) process_utils.write_list(output_path, dataset) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/sentiment/convert_italian_sentence_classification.py 
================================================ """ Converts a file of labels on constituency trees for the it_vit dataset The labels are for whether or not a sentence is written in a standard S-V-O order. The intent is to see how much a constituency parser can improve over a regular transformer classifier. This file is provided by Prof. Delmonte as part of a classification project. Contact John Bauer for more details. Technically this should be "classifier" instead of "sentiment" """ import os from stanza.models.classifiers.data import SentimentDatum from stanza.utils.datasets.sentiment import process_utils from stanza.utils.datasets.constituency.convert_it_vit import read_updated_trees import stanza.utils.default_paths as default_paths def label_trees(label_map, trees): new_trees = [] for tree in trees: if tree.con_id not in label_map: raise ValueError("%s not labeled" % tree.con_id) label = label_map[tree.con_id] new_trees.append(SentimentDatum(label, tree.tree.leaf_labels(), tree.tree)) return new_trees def read_label_map(label_filename): with open(label_filename, encoding="utf-8") as fin: lines = fin.readlines() lines = [x.strip() for x in lines] lines = [x.split() for x in lines if x] label_map = {} for line_idx, line in enumerate(lines): k = line[0].split("#")[1] v = line[1] # compensate for an off-by-one error in the labels for ids 12 through 129 # we went back and forth a few times but i couldn't explain the error, # so whatever, just compensate for it on the conversion side k_idx = int(k.split("_")[1]) if k_idx != line_idx + 1: if k_idx >= 12 and k_idx <= 129: k = "sent_%05d" % (k_idx - 1) else: raise ValueError("Unexpected key offset for line {}: {}".format(line_idx, line)) if v == "neg": v = "0" elif v == "pos": v = "1" else: raise ValueError("Unexpected label %s for key %s" % (v, k)) if k in label_map: raise ValueError("Duplicate key %s: new value %s, old value %s" % (k, v, label_map[k])) label_map[k] = v return label_map def main(): paths = 
default_paths.get_default_paths() dataset_name = "it_vit_sentences" label_filename = os.path.join(paths["SENTIMENT_BASE"], "italian", "sentence_classification", "classified") if not os.path.exists(label_filename): raise FileNotFoundError("Expected to find the labeled file in %s" % label_filename) label_map = read_label_map(label_filename) # this will produce three lists of trees with their con_id attached train_trees, dev_trees, test_trees = read_updated_trees(paths) train_trees = label_trees(label_map, train_trees) dev_trees = label_trees(label_map, dev_trees) test_trees = label_trees(label_map, test_trees) dataset = (train_trees, dev_trees, test_trees) process_utils.write_dataset(dataset, paths["SENTIMENT_DATA_DIR"], dataset_name) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/sentiment/prepare_sentiment_dataset.py ================================================ """Prepare a single dataset or a combination dataset for the sentiment project Manipulates various downloads from their original form to a form usable by the classifier model Explanations for the existing datasets are below. After processing the dataset, you can train with the run_sentiment script python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset python3 -m stanza.utils.training.run_sentiment English ------- SST (Stanford Sentiment Treebank) https://nlp.stanford.edu/sentiment/ https://github.com/stanfordnlp/sentiment-treebank The git repo includes fixed tokenization and sentence splits, along with a partial conversion to updated PTB tokenization standards. The first step is to git clone the SST to here: $SENTIMENT_BASE/sentiment-treebank eg: cd $SENTIMENT_BASE git clone git@github.com:stanfordnlp/sentiment-treebank.git There are a few different usages of SST. The scores most commonly reported are for SST-2, positive and negative only. 
To get a version of this: python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sst2 python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sst2roots The model we distribute is a three class model (+, 0, -) with some smaller datasets added for better coverage. See "sstplus" below. MELD https://github.com/SenticNet/MELD/tree/master/data/MELD https://github.com/SenticNet/MELD https://arxiv.org/pdf/1810.02508.pdf MELD: A Multimodal Multi-Party Dataset for Emotion Recognition in Conversation. ACL 2019. S. Poria, D. Hazarika, N. Majumder, G. Naik, E. Cambria, R. Mihalcea. An Emotion Corpus of Multi-Party Conversations. Chen, S.Y., Hsu, C.C., Kuo, C.C. and Ku, L.W. Copy the three files in the repo into $SENTIMENT_BASE/MELD TODO: make it so you git clone the repo instead There are train/dev/test splits, so you can build a model out of just this corpus. The first step is to convert to the classifier data format: python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_meld However, in general we simply include this in the sstplus model rather than releasing a separate model. Arguana http://argumentation.bplaced.net/arguana/data http://argumentation.bplaced.net/arguana-data/arguana-tripadvisor-annotated-v2.zip http://argumentation.bplaced.net/arguana-publications/papers/wachsmuth14a-cicling.pdf A Review Corpus for Argumentation Analysis. CICLing 2014 Henning Wachsmuth, Martin Trenkmann, Benno Stein, Gregor Engels, Tsvetomira Palarkarska Download the zip file and unzip it in $SENTIMENT_BASE/arguana This is included in the sstplus model. airline A Kaggle corpus for sentiment detection on airline tweets. We include this in sstplus as well. https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment Download Tweets.csv and put it in $SENTIMENT_BASE/airline SLSD https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences From Group to Individual Labels using Deep Features. KDD 2015 Kotzias et al.
Put the contents of the zip file in $SENTIMENT_BASE/slsd The sstplus model includes this as training data en_sstplus This is a three class model built from SST, along with the additional English data sources above for coverage of additional domains. python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sstplus en_corona A kaggle covid-19 text classification dataset https://www.kaggle.com/datasets/datatattle/covid-19-nlp-text-classification python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_corona German ------ de_sb10k This used to be here: https://www.spinningbytes.com/resources/germansentiment/ Now it appears to have moved here? https://github.com/oliverguhr/german-sentiment https://dl.acm.org/doi/pdf/10.1145/3038912.3052611 Leveraging Large Amounts of Weakly Supervised Data for Multi-Language Sentiment Classification WWW '17: Proceedings of the 26th International Conference on World Wide Web Jan Deriu, Aurelien Lucchi, Valeria De Luca, Aliaksei Severyn, Simon Müller, Mark Cieliebak, Thomas Hofmann, Martin Jaggi The current prep script works on the old version of the data. TODO: update to work on the git repo python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset de_sb10k de_scare http://romanklinger.de/scare/ The Sentiment Corpus of App Reviews with Fine-grained Annotations in German LREC 2016 Mario Sänger, Ulf Leser, Steffen Kemmerer, Peter Adolphs, and Roman Klinger Download the data and put it in $SENTIMENT_BASE/german/scare There should be two subdirectories once you are done: scare_v1.0.0 scare_v1.0.0_text We wound up not including this in the default German model. It might be worth revisiting in the future.
de_usage https://www.romanklinger.de/usagecorpus/ http://www.lrec-conf.org/proceedings/lrec2014/summaries/85.html The USAGE Review Corpus for Fine Grained Multi Lingual Opinion Analysis Roman Klinger and Philipp Cimiano Again, not included in the default German model Chinese ------- zh-hans_ren This used to be here: http://a1-www.is.tokushima-u.ac.jp/member/ren/Ren-CECps1.0/Ren-CECps1.0.html That page doesn't seem to respond as of 2022, and I can't find it elsewhere. The following will be available starting in 1.4.1: Spanish ------- tass2020 - http://tass.sepln.org/2020/?page_id=74 - Download the following 5 files: task1.2-test-gold.tsv Task1-train-dev.zip tass2020-test-gold.zip Test1.1.zip test1.2.zip Put them in a directory $SENTIMENT_BASE/spanish/tass2020 python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset es_tass2020 Vietnamese ---------- vi_vsfc I found a corpus labeled VSFC here: https://drive.google.com/drive/folders/1xclbjHHK58zk2X6iqbvMPS2rcy9y9E0X It doesn't seem to have a license or paper associated with it, but happy to put those details here if relevant. Download the files to $SENTIMENT_BASE/vietnamese/_UIT-VSFC python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset vi_vsfc Marathi ------- mr_l3cube https://github.com/l3cube-pune/MarathiNLP https://arxiv.org/abs/2103.11408 L3CubeMahaSent: A Marathi Tweet-based Sentiment Analysis Dataset Atharva Kulkarni, Meet Mandhane, Manali Likhitkar, Gayatri Kshirsagar, Raviraj Joshi git clone the repo in $SENTIMENT_BASE cd $SENTIMENT_BASE git clone git@github.com:l3cube-pune/MarathiNLP.git python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset mr_l3cube Hindi ----- odiagenai https://huggingface.co/datasets/OdiaGenAI/sentiment_analysis_hindi Uses datasets package from HF, so that needs to be installed python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset hi_odiagenai This dataset has 2497 sentences in a train section. 
We randomly split them to make a usable dataset Italian ------- it_sentipolc16 from here: http://www.di.unito.it/~tutreeb/sentipolc-evalita16/data.html paper describing the evaluation and the results: http://ceur-ws.org/Vol-1749/paper_026.pdf download the training and test zip files to $SENTIMENT_BASE/italian/sentipolc16 unzip them there so you should have $SENTIMENT_BASE/italian/sentipolc16/test_set_sentipolc16_gold2000.csv $SENTIMENT_BASE/italian/sentipolc16/training_set_sentipolc16.csv python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset it_sentipolc16 this script splits the training data into dev & train, keeps the test the same The conversion allows for 4 ways of handling the "mixed" class: treat it as the same as neutral, treat it as a separate class, only distinguish positive or not positive, only distinguish negative or not negative for more details: python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset it_sentipolc16 --help another option not implemented yet: absita18 http://sag.art.uniroma2.it/absita/data/ """ import os import random import sys import stanza import stanza.utils.default_paths as default_paths from stanza.utils.datasets.sentiment import process_airline from stanza.utils.datasets.sentiment import process_arguana_xml from stanza.utils.datasets.sentiment import process_corona from stanza.utils.datasets.sentiment import process_es_tass2020 from stanza.utils.datasets.sentiment import process_it_sentipolc16 from stanza.utils.datasets.sentiment import process_MELD from stanza.utils.datasets.sentiment import process_ren_chinese from stanza.utils.datasets.sentiment import process_sb10k from stanza.utils.datasets.sentiment import process_scare from stanza.utils.datasets.sentiment import process_slsd from stanza.utils.datasets.sentiment import process_sst from stanza.utils.datasets.sentiment import process_usage_german from stanza.utils.datasets.sentiment import process_vsfc_vietnamese from 
stanza.utils.datasets.sentiment import process_utils from tqdm import tqdm def convert_sst_general(paths, dataset_name, version): in_directory = paths['SENTIMENT_BASE'] sst_dir = os.path.join(in_directory, "sentiment-treebank") train_phrases = process_sst.get_phrases(version, "train.txt", sst_dir) dev_phrases = process_sst.get_phrases(version, "dev.txt", sst_dir) test_phrases = process_sst.get_phrases(version, "test.txt", sst_dir) out_directory = paths['SENTIMENT_DATA_DIR'] dataset = [train_phrases, dev_phrases, test_phrases] process_utils.write_dataset(dataset, out_directory, dataset_name) def convert_sst2(paths, dataset_name, *args): """ Create a 2 class SST dataset (neutral items are dropped) """ convert_sst_general(paths, dataset_name, "binary") def convert_sst2roots(paths, dataset_name, *args): """ Create a 2 class SST dataset using only the roots """ convert_sst_general(paths, dataset_name, "binaryroot") def convert_sst3(paths, dataset_name, *args): """ Create a 3 class SST dataset using only the roots """ convert_sst_general(paths, dataset_name, "threeclass") def convert_sst3roots(paths, dataset_name, *args): """ Create a 3 class SST dataset using only the roots """ convert_sst_general(paths, dataset_name, "threeclassroot") def convert_sstplus(paths, dataset_name, *args): """ Create a 3 class SST dataset with a few other small datasets added """ train_phrases = [] in_directory = paths['SENTIMENT_BASE'] train_phrases.extend(process_arguana_xml.get_tokenized_phrases(os.path.join(in_directory, "arguana"))) train_phrases.extend(process_MELD.get_tokenized_phrases("train", os.path.join(in_directory, "MELD"))) train_phrases.extend(process_slsd.get_tokenized_phrases(os.path.join(in_directory, "slsd"))) train_phrases.extend(process_airline.get_tokenized_phrases(os.path.join(in_directory, "airline"))) sst_dir = os.path.join(in_directory, "sentiment-treebank") train_phrases.extend(process_sst.get_phrases("threeclass", "train.txt", sst_dir)) 
train_phrases.extend(process_sst.get_phrases("threeclass", "extra-train.txt", sst_dir)) train_phrases.extend(process_sst.get_phrases("threeclass", "checked-extra-train.txt", sst_dir)) dev_phrases = process_sst.get_phrases("threeclass", "dev.txt", sst_dir) test_phrases = process_sst.get_phrases("threeclass", "test.txt", sst_dir) out_directory = paths['SENTIMENT_DATA_DIR'] dataset = [train_phrases, dev_phrases, test_phrases] process_utils.write_dataset(dataset, out_directory, dataset_name) def convert_meld(paths, dataset_name, *args): """ Convert the MELD dataset to train/dev/test files """ in_directory = os.path.join(paths['SENTIMENT_BASE'], "MELD") out_directory = paths['SENTIMENT_DATA_DIR'] process_MELD.main(in_directory, out_directory, dataset_name) def convert_corona(paths, dataset_name, *args): """ Convert the kaggle covid dataset to train/dev/test files """ process_corona.main(*args) def convert_scare(paths, dataset_name, *args): in_directory = os.path.join(paths['SENTIMENT_BASE'], "german", "scare") out_directory = paths['SENTIMENT_DATA_DIR'] process_scare.main(in_directory, out_directory, dataset_name) def convert_de_usage(paths, dataset_name, *args): in_directory = os.path.join(paths['SENTIMENT_BASE'], "USAGE") out_directory = paths['SENTIMENT_DATA_DIR'] process_usage_german.main(in_directory, out_directory, dataset_name) def convert_sb10k(paths, dataset_name, *args): """ Essentially runs the sb10k script twice with different arguments to produce the de_sb10k dataset stanza.utils.datasets.sentiment.process_sb10k --csv_filename extern_data/sentiment/german/sb-10k/de_full/de_test.tsv --out_dir $SENTIMENT_DATA_DIR --short_name de_sb10k --split test --sentiment_column 2 --text_column 3 stanza.utils.datasets.sentiment.process_sb10k --csv_filename extern_data/sentiment/german/sb-10k/de_full/de_train.tsv --out_dir $SENTIMENT_DATA_DIR --short_name de_sb10k --split train_dev --sentiment_column 2 --text_column 3 """ column_args = ["--sentiment_column", "2", 
"--text_column", "3"] process_sb10k.main(["--csv_filename", os.path.join(paths['SENTIMENT_BASE'], "german", "sb-10k", "de_full", "de_test.tsv"), "--out_dir", paths['SENTIMENT_DATA_DIR'], "--short_name", dataset_name, "--split", "test", *column_args]) process_sb10k.main(["--csv_filename", os.path.join(paths['SENTIMENT_BASE'], "german", "sb-10k", "de_full", "de_train.tsv"), "--out_dir", paths['SENTIMENT_DATA_DIR'], "--short_name", dataset_name, "--split", "train_dev", *column_args]) def convert_vi_vsfc(paths, dataset_name, *args): in_directory = os.path.join(paths['SENTIMENT_BASE'], "vietnamese", "_UIT-VSFC") out_directory = paths['SENTIMENT_DATA_DIR'] process_vsfc_vietnamese.main(in_directory, out_directory, dataset_name) def convert_hi_odiagenai(paths, dataset_name, *args): out_directory = paths['SENTIMENT_DATA_DIR'] os.makedirs(out_directory, exist_ok=True) import datasets ds = datasets.load_dataset("OdiaGenAI/sentiment_analysis_hindi") nlp = stanza.Pipeline("hi", processors='tokenize') mapping = {"pos": 2, "neu": 1, "neg": 0} train = [] dev = [] test = [] for datum in tqdm(ds['train']): random_slice = random.randint(0, 9) if random_slice == 0: random_slice = dev elif random_slice == 1: random_slice = test else: random_slice = train datum = process_utils.process_datum(nlp, datum['text'], mapping, datum['label']) random_slice.append(datum) dataset = [train, dev, test] process_utils.write_dataset(dataset, out_directory, dataset_name) def convert_mr_l3cube(paths, dataset_name, *args): # csv_filename = 'extern_data/sentiment/MarathiNLP/L3CubeMahaSent Dataset/tweets-train.csv' MAPPING = {"-1": "0", "0": "1", "1": "2"} out_directory = paths['SENTIMENT_DATA_DIR'] os.makedirs(out_directory, exist_ok=True) in_directory = os.path.join(paths['SENTIMENT_BASE'], "MarathiNLP", "L3CubeMahaSent Dataset") input_files = ['tweets-train.csv', 'tweets-valid.csv', 'tweets-test.csv'] input_files = [os.path.join(in_directory, x) for x in input_files] datasets = 
[process_utils.read_snippets(csv_filename, sentiment_column=1, text_column=0, tokenizer_language="mr", mapping=MAPPING, delimiter=',', quotechar='"', skip_first_line=True) for csv_filename in input_files] process_utils.write_dataset(datasets, out_directory, dataset_name) def convert_es_tass2020(paths, dataset_name, *args): process_es_tass2020.convert_tass2020(paths['SENTIMENT_BASE'], paths['SENTIMENT_DATA_DIR'], dataset_name) def convert_it_sentipolc16(paths, dataset_name, *args): in_directory = os.path.join(paths['SENTIMENT_BASE'], "italian", "sentipolc16") out_directory = paths['SENTIMENT_DATA_DIR'] process_it_sentipolc16.main(in_directory, out_directory, dataset_name, *args) def convert_ren(paths, dataset_name, *args): in_directory = os.path.join(paths['SENTIMENT_BASE'], "chinese", "RenCECps") out_directory = paths['SENTIMENT_DATA_DIR'] process_ren_chinese.main(in_directory, out_directory, dataset_name) DATASET_MAPPING = { "de_sb10k": convert_sb10k, "de_scare": convert_scare, "de_usage": convert_de_usage, "en_corona": convert_corona, "en_sst2": convert_sst2, "en_sst2roots": convert_sst2roots, "en_sst3": convert_sst3, "en_sst3roots": convert_sst3roots, "en_sstplus": convert_sstplus, "en_meld": convert_meld, "es_tass2020": convert_es_tass2020, "hi_odiagenai": convert_hi_odiagenai, "it_sentipolc16": convert_it_sentipolc16, "mr_l3cube": convert_mr_l3cube, "vi_vsfc": convert_vi_vsfc, "zh-hans_ren": convert_ren, } def main(dataset_name, *args): paths = default_paths.get_default_paths() random.seed(1234) if dataset_name in DATASET_MAPPING: DATASET_MAPPING[dataset_name](paths, dataset_name, *args) else: raise ValueError(f"dataset {dataset_name} currently not handled") if __name__ == '__main__': main(sys.argv[1], sys.argv[2:]) ================================================ FILE: stanza/utils/datasets/sentiment/process_MELD.py ================================================ """ MELD is a dataset of Friends (the TV show) utterances. 
The ratings include judgment based on the visuals, so it might be harder than expected to directly extract from the text. However, it should broaden the scope of the model and doesn't seem to hurt performance. https://github.com/SenticNet/MELD/tree/master/data/MELD https://github.com/SenticNet/MELD https://arxiv.org/pdf/1810.02508.pdf Files in the MELD repo are csv, with quotes in "..." if they contained commas themselves. Accordingly, we use the csv module to read the files and output them in the format Run using python3 convert_MELD.py MELD/train_sent_emo.csv train.txt etc """ import csv import os import sys from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils def get_phrases(in_filename): """ Get the phrases from a single CSV filename """ with open(in_filename, newline='', encoding='windows-1252') as fin: cin = csv.reader(fin, delimiter=',', quotechar='"') lines = list(cin) phrases = [] for line in lines[1:]: sentiment = line[4] if sentiment == 'negative': sentiment = '0' elif sentiment == 'neutral': sentiment = '1' elif sentiment == 'positive': sentiment = '2' else: raise ValueError("Unknown sentiment: {}".format(sentiment)) utterance = line[1].replace("Â", "") phrases.append(SentimentDatum(sentiment, utterance)) return phrases def get_tokenized_phrases(split, in_directory): """ split in train,dev,test """ in_filename = os.path.join(in_directory, "%s_sent_emo.csv" % split) phrases = get_phrases(in_filename) phrases = process_utils.get_ptb_tokenized_phrases(phrases) print("Found {} phrases in MELD {}".format(len(phrases), split)) return phrases def main(in_directory, out_directory, short_name): os.makedirs(out_directory, exist_ok=True) for split in ("train", "dev", "test"): phrases = get_tokenized_phrases(split, in_directory) process_utils.write_list(os.path.join(out_directory, "%s.%s.json" % (short_name, split)), phrases) if __name__ == '__main__': in_directory = sys.argv[1] 
out_directory = sys.argv[2] short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_airline.py ================================================ """ Airline tweets from Kaggle from https://www.kaggle.com/crowdflower/twitter-airline-sentiment/data# Some ratings seem questionable, but it doesn't hurt performance much, if at all Files in the airline repo are csv, with quotes in "..." if they contained commas themselves. Accordingly, we use the csv module to read the files and output them in the format Run using python3 convert_airline.py Tweets.csv train.json If the first word is an @, it is removed, and after that, leading @ or # are removed. For example: @AngledLuffa you must hate having Mox Opal #banned -> you must hate having Mox Opal banned """ import csv import os import sys from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils def get_phrases(in_directory): in_filename = os.path.join(in_directory, "Tweets.csv") with open(in_filename, newline='') as fin: cin = csv.reader(fin, delimiter=',', quotechar='"') lines = list(cin) phrases = [] for line in lines[1:]: sentiment = line[1] if sentiment == 'negative': sentiment = '0' elif sentiment == 'neutral': sentiment = '1' elif sentiment == 'positive': sentiment = '2' else: raise ValueError("Unknown sentiment: {}".format(sentiment)) # some of the tweets have \n in them utterance = line[10].replace("\n", " ") phrases.append(SentimentDatum(sentiment, utterance)) return phrases def get_tokenized_phrases(in_directory): phrases = get_phrases(in_directory) phrases = process_utils.get_ptb_tokenized_phrases(phrases) phrases = [SentimentDatum(x.sentiment, process_utils.clean_tokenized_tweet(x.text)) for x in phrases] print("Found {} phrases in the airline corpus".format(len(phrases))) return phrases def main(in_directory, out_directory, short_name): 
phrases = get_tokenized_phrases(in_directory) os.makedirs(out_directory, exist_ok=True) out_filename = os.path.join(out_directory, "%s.train.json" % short_name) # filter leading @United, @American, etc from the tweets process_utils.write_list(out_filename, phrases) # something like this would count @s if you cared enough to count # would need to update for SentimentDatum() #ats = Counter() #for line in lines: # ats.update([x for x in line.split() if x[0] == '@']) if __name__ == '__main__': in_directory = sys.argv[1] out_directory = sys.argv[2] short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_arguana_xml.py ================================================ from collections import namedtuple import glob import os import sys import xml.etree.ElementTree as ET from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils ArguanaSentimentDatum = namedtuple('ArguanaSentimentDatum', ['begin', 'end', 'rating']) """ Extracts positive, neutral, and negative phrases from the ArguAna hotel review corpus Run as follows: python3 parse_arguana_xml.py split/training data/sentiment ArguAna can be downloaded here: http://argumentation.bplaced.net/arguana/data http://argumentation.bplaced.net/arguana-data/arguana-tripadvisor-annotated-v2.zip """ def get_phrases(filename): tree = ET.parse(filename) fragments = [] root = tree.getroot() body = None for child in root: if child.tag == '{http:///uima/cas.ecore}Sofa': body = child.attrib['sofaString'] elif child.tag == '{http:///de/aitools/ie/uima/type/arguana.ecore}Fact': fragments.append(ArguanaSentimentDatum(begin=int(child.attrib['begin']), end=int(child.attrib['end']), rating="1")) elif child.tag == '{http:///de/aitools/ie/uima/type/arguana.ecore}Opinion': if child.attrib['polarity'] == 'negative': rating = "0" elif child.attrib['polarity'] == 'positive': 
rating = "2" else: raise ValueError("Unexpected polarity found in {}".format(filename)) fragments.append(ArguanaSentimentDatum(begin=int(child.attrib['begin']), end=int(child.attrib['end']), rating=rating)) phrases = [SentimentDatum(fragment.rating, body[fragment.begin:fragment.end]) for fragment in fragments] #phrases = [phrase.replace("\n", " ") for phrase in phrases] return phrases def get_phrases_from_directory(directory): phrases = [] inpath = os.path.join(directory, "arguana-tripadvisor-annotated-v2", "split", "training", "*", "*xmi") for filename in glob.glob(inpath): phrases.extend(get_phrases(filename)) return phrases def get_tokenized_phrases(in_directory): phrases = get_phrases_from_directory(in_directory) phrases = process_utils.get_ptb_tokenized_phrases(phrases) print("Found {} phrases in arguana".format(len(phrases))) return phrases def main(in_directory, out_directory, short_name): phrases = get_tokenized_phrases(in_directory) process_utils.write_list(os.path.join(out_directory, "%s.train.json" % short_name), phrases) if __name__ == "__main__": in_directory = sys.argv[1] out_directory = sys.argv[2] short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_corona.py ================================================ """ Processes a kaggle covid-19 text classification dataset The original description of the dataset is here: https://www.kaggle.com/datasets/datatattle/covid-19-nlp-text-classification There are two files in the archive, Corona_NLP_train.csv and Corona_NLP_test.csv Unzip the files in archive.zip to $SENTIMENT_BASE/english/corona/Corona_NLP_train.csv There is no dedicated dev set, so we randomly split train/dev (using a specific seed, so that the split always comes out the same) """ import argparse import os import random import stanza import stanza.utils.datasets.sentiment.process_utils as process_utils from stanza.utils.default_paths 
import get_default_paths # TODO: could give an option to keep the 'extremely' MAPPING = {'extremely positive': "2", 'positive': "2", 'neutral': "1", 'negative': "0", 'extremely negative': "0"} def main(args=None): default_paths = get_default_paths() sentiment_base_dir = default_paths["SENTIMENT_BASE"] default_in_dir = os.path.join(sentiment_base_dir, "english", "corona") default_out_dir = default_paths["SENTIMENT_DATA_DIR"] parser = argparse.ArgumentParser() parser.add_argument('--in_dir', type=str, default=default_in_dir, help='Where to get the input files') parser.add_argument('--out_dir', type=str, default=default_out_dir, help='Where to write the output files') parser.add_argument('--short_name', type=str, default="en_corona", help='short name to use when writing files') args = parser.parse_args(args=args) TEXT_COLUMN = 4 SENTIMENT_COLUMN = 5 train_csv = os.path.join(args.in_dir, "Corona_NLP_train.csv") test_csv = os.path.join(args.in_dir, "Corona_NLP_test.csv") nlp = stanza.Pipeline("en", processors='tokenize') train_snippets = process_utils.read_snippets(train_csv, SENTIMENT_COLUMN, TEXT_COLUMN, 'en', MAPPING, delimiter=",", quotechar='"', skip_first_line=True, nlp=nlp, encoding="latin1") test_snippets = process_utils.read_snippets(test_csv, SENTIMENT_COLUMN, TEXT_COLUMN, 'en', MAPPING, delimiter=",", quotechar='"', skip_first_line=True, nlp=nlp, encoding="latin1") print("Read %d train snippets" % len(train_snippets)) print("Read %d test snippets" % len(test_snippets)) random.seed(1234) random.shuffle(train_snippets) os.makedirs(args.out_dir, exist_ok=True) process_utils.write_splits(args.out_dir, train_snippets, (process_utils.Split("%s.train.json" % args.short_name, 0.9), process_utils.Split("%s.dev.json" % args.short_name, 0.1))) process_utils.write_list(os.path.join(args.out_dir, "%s.test.json" % args.short_name), test_snippets) if __name__ == '__main__': main() ================================================ FILE: 
stanza/utils/datasets/sentiment/process_es_tass2020.py ================================================ """ Convert the TASS 2020 dataset, available here: http://tass.sepln.org/2020/?page_id=74 There are two parts to the dataset, but only part 1 has the gold annotations available. Download: Task 1 train & dev sets Task 1.1 test set Task 1.2 test set Task 1.1 test set gold standard Task 1.2 test set gold standard (.tsv, not .zip) No need to unzip any of the files. The extraction script reads the expected paths directly from the zip files. There are two subtasks in TASS 2020. One is split among 5 Spanish speaking countries, and the other is combined across all of the countries. Here we combine all of the data into one output file. Also, each of the subparts are output into their own files, such as p2.json, p1.mx.json, etc """ import os import zipfile import stanza from stanza.models.classifiers.data import SentimentDatum import stanza.utils.default_paths as default_paths from stanza.utils.datasets.sentiment.process_utils import write_dataset, write_list def convert_label(label): """ N/NEU/P or error """ if label == "N": return 0 if label == "NEU": return 1 if label == "P": return 2 raise ValueError("Unexpected label %s" % label) def read_test_labels(fin): """ Read a tab (or space) separated list of id/label pairs """ label_map = {} for line_idx, line in enumerate(fin): if isinstance(line, bytes): line = line.decode("utf-8") pieces = line.split() if len(pieces) < 2: continue if len(pieces) > 2: raise ValueError("Unexpected format at line %d: all label lines should be len==2\n%s" % (line_idx, line)) datum_id, label = pieces try: label = convert_label(label) except ValueError: raise ValueError("Unexpected test label %s at line %d\n%s" % (label, line_idx, line)) label_map[datum_id] = label return label_map def open_read_test_labels(filename, zip_filename=None): """ Open either a text or zip file, then read the labels """ if zip_filename is None: with open(filename, 
encoding="utf-8") as fin: test_labels = read_test_labels(fin) print("Read %d lines from %s" % (len(test_labels), filename)) return test_labels with zipfile.ZipFile(zip_filename) as zin: with zin.open(filename) as fin: test_labels = read_test_labels(fin) print("Read %d lines from %s - %s" % (len(test_labels), zip_filename, filename)) return test_labels def read_sentences(fin): """ Read ids and text from the given file """ lines = [] for line_idx, line in enumerate(fin): line = line.decode("utf-8") pieces = line.split(maxsplit=1) if len(pieces) < 2: continue lines.append(pieces) return lines def open_read_sentences(filename, zip_filename): """ Opens a file and then reads the sentences Only applies to files inside zips, as all of the sentence files in this dataset are inside a zip """ with zipfile.ZipFile(zip_filename) as zin: with zin.open(filename) as fin: test_sentences = read_sentences(fin) print("Read %d texts from %s - %s" % (len(test_sentences), zip_filename, filename)) return test_sentences def combine_test_set(sentences, labels): """ Combines the labels and sentences from two pieces of the test set Matches the ID from the label files and the text files """ combined = [] if len(sentences) != len(labels): raise ValueError("Lengths of sentences and labels should match!") for sent_id, text in sentences: label = labels.get(sent_id, None) if label is None: raise KeyError("Cannot find a test label from the ID: %s" % sent_id) # not tokenized yet - we can do tokenization in batches combined.append(SentimentDatum(label, text)) return combined DATASET_PIECES = ("cr", "es", "mx", "pe", "uy") def tokenize(sentiment_data, pipe): """ Takes a list of (label, text) and returns a list of SentimentDatum with tokenized text Only the first 'sentence' is used - ideally the pipe has ssplit turned off """ docs = [x.text for x in sentiment_data] in_docs = [stanza.Document([], text=d) for d in docs] out_docs = pipe(in_docs) sentiment_data = [SentimentDatum(datum.sentiment, [y.text for 
y in doc.sentences[0].tokens]) # list of text tokens for each doc for datum, doc in zip(sentiment_data, out_docs)] return sentiment_data def read_test_set(label_zip_filename, label_filename, sentence_zip_filename, sentence_filename, pipe): """ Read and tokenize an entire test set given the label and sentence filenames """ test_labels = open_read_test_labels(label_filename, label_zip_filename) test_sentences = open_read_sentences(sentence_filename, sentence_zip_filename) sentiment_data = combine_test_set(test_sentences, test_labels) return tokenize(sentiment_data, pipe) return sentiment_data def read_train_file(zip_filename, filename, pipe): """ Read and tokenize a train set All of the train data is inside one zip. We read it one piece at a time """ sentiment_data = [] with zipfile.ZipFile(zip_filename) as zin: with zin.open(filename) as fin: for line_idx, line in enumerate(fin): if isinstance(line, bytes): line = line.decode("utf-8") pieces = line.split(maxsplit=1) if len(pieces) < 2: continue pieces = pieces[1].rsplit(maxsplit=1) if len(pieces) < 2: continue text, label = pieces try: label = convert_label(label) except ValueError: raise ValueError("Unexpected train label %s at line %d\n%s" % (label, line_idx, line)) sentiment_data.append(SentimentDatum(label, text)) print("Read %d texts from %s - %s" % (len(sentiment_data), zip_filename, filename)) sentiment_data = tokenize(sentiment_data, pipe) return sentiment_data def convert_tass2020(in_directory, out_directory, dataset_name): """ Read all of the data from in_directory/spanish/tass2020, write it to out_directory/dataset_name... 
""" in_directory = os.path.join(in_directory, "spanish", "tass2020") pipe = stanza.Pipeline(lang="es", processors="tokenize", tokenize_no_ssplit=True) test_11 = {} test_11_labels_zip = os.path.join(in_directory, "tass2020-test-gold.zip") test_11_sentences_zip = os.path.join(in_directory, "Test1.1.zip") for piece in DATASET_PIECES: inner_label_filename = piece + ".tsv" inner_sentence_filename = os.path.join("Test1.1", piece.upper() + ".tsv") test_11[piece] = read_test_set(test_11_labels_zip, inner_label_filename, test_11_sentences_zip, inner_sentence_filename, pipe) test_12_label_filename = os.path.join(in_directory, "task1.2-test-gold.tsv") test_12_sentences_zip = os.path.join(in_directory, "test1.2.zip") test_12_sentences_filename = "test1.2/task1.2.tsv" test_12 = read_test_set(None, test_12_label_filename, test_12_sentences_zip, test_12_sentences_filename, pipe) train_dev_zip = os.path.join(in_directory, "Task1-train-dev.zip") dev = {} train = {} for piece in DATASET_PIECES: dev_filename = os.path.join("dev", piece + ".tsv") dev[piece] = read_train_file(train_dev_zip, dev_filename, pipe) for piece in DATASET_PIECES: train_filename = os.path.join("train", piece + ".tsv") train[piece] = read_train_file(train_dev_zip, train_filename, pipe) all_test = test_12 + [item for piece in test_11.values() for item in piece] all_dev = [item for piece in dev.values() for item in piece] all_train = [item for piece in train.values() for item in piece] print("Total train items: %8d" % len(all_train)) print("Total dev items: %8d" % len(all_dev)) print("Total test items: %8d" % len(all_test)) write_dataset((all_train, all_dev, all_test), out_directory, dataset_name) output_file = os.path.join(out_directory, "%s.test.p2.json" % dataset_name) write_list(output_file, test_12) for piece in DATASET_PIECES: output_file = os.path.join(out_directory, "%s.test.p1.%s.json" % (dataset_name, piece)) write_list(output_file, test_11[piece]) def main(paths): in_directory = paths['SENTIMENT_BASE'] 
out_directory = paths['SENTIMENT_DATA_DIR'] convert_tass2020(in_directory, out_directory, "es_tass2020") if __name__ == '__main__': paths = default_paths.get_default_paths() main(paths) ================================================ FILE: stanza/utils/datasets/sentiment/process_it_sentipolc16.py ================================================ """ Process the SentiPolc dataset from Evalita Can be run as a standalone script or as a module from prepare_sentiment_dataset An option controls how to split up the positive/negative/neutral/mixed classes """ import argparse from enum import Enum import os import random import sys import stanza from stanza.utils.datasets.sentiment import process_utils import stanza.utils.default_paths as default_paths class Mode(Enum): COMBINED = 1 SEPARATE = 2 POSITIVE = 3 NEGATIVE = 4 def main(in_dir, out_dir, short_name, *args): parser = argparse.ArgumentParser() parser.add_argument('--mode', default=Mode.COMBINED, type=lambda x: Mode[x.upper()], help='How to handle mixed vs neutral. {}'.format(", ".join(x.name for x in Mode))) parser.add_argument('--name', default=None, type=str, help='Use a different name to save the dataset. 
Useful for keeping POSITIVE & NEGATIVE separate') args = parser.parse_args(args=list(*args)) if args.name is not None: short_name = args.name nlp = stanza.Pipeline("it", processors='tokenize') if args.mode == Mode.COMBINED: mapping = { ('0', '0'): "1", # neither negative nor positive: neutral ('1', '0'): "2", # positive, not negative: positive ('0', '1'): "0", # negative, not positive: negative ('1', '1'): "1", # mixed combined with neutral } elif args.mode == Mode.SEPARATE: mapping = { ('0', '0'): "1", # neither negative nor positive: neutral ('1', '0'): "2", # positive, not negative: positive ('0', '1'): "0", # negative, not positive: negative ('1', '1'): "3", # mixed as a different class } elif args.mode == Mode.POSITIVE: mapping = { ('0', '0'): "0", # neutral -> not positive ('1', '0'): "1", # positive -> positive ('0', '1'): "0", # negative -> not positive ('1', '1'): "1", # mixed -> positive } elif args.mode == Mode.NEGATIVE: mapping = { ('0', '0'): "0", # neutral -> not negative ('1', '0'): "0", # positive -> not negative ('0', '1'): "1", # negative -> negative ('1', '1'): "1", # mixed -> negative } print("Using {} scheme to handle the 4 values. 
Mapping: {}".format(args.mode, mapping)) print("Saving to {} using the short name {}".format(out_dir, short_name)) test_filename = os.path.join(in_dir, "test_set_sentipolc16_gold2000.csv") test_snippets = process_utils.read_snippets(test_filename, (2,3), 8, "it", mapping, delimiter=",", skip_first_line=False, quotechar='"', nlp=nlp) train_filename = os.path.join(in_dir, "training_set_sentipolc16.csv") train_snippets = process_utils.read_snippets(train_filename, (2,3), 8, "it", mapping, delimiter=",", skip_first_line=True, quotechar='"', nlp=nlp) random.shuffle(train_snippets) dev_len = len(train_snippets) // 10 dev_snippets = train_snippets[:dev_len] train_snippets = train_snippets[dev_len:] dataset = (train_snippets, dev_snippets, test_snippets) process_utils.write_dataset(dataset, out_dir, short_name) if __name__ == '__main__': paths = default_paths.get_default_paths() random.seed(1234) in_directory = os.path.join(paths['SENTIMENT_BASE'], "italian", "sentipolc16") out_directory = paths['SENTIMENT_DATA_DIR'] main(in_directory, out_directory, "it_sentipolc16", sys.argv[1:]) ================================================ FILE: stanza/utils/datasets/sentiment/process_ren_chinese.py ================================================ import glob import os import random import sys import xml.etree.ElementTree as ET from collections import namedtuple import stanza from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils """ This processes a Chinese corpus, hosted here: http://a1-www.is.tokushima-u.ac.jp/member/ren/Ren-CECps1.0/Ren-CECps1.0.html The authors want a signed document saying you won't redistribute the corpus. The corpus format is a bunch of .xml files, with sentences labeled with various emotions and an overall polarity. 
Polarity is labeled as follows: 消极: negative 中性: neutral 积极: positive """ def get_phrases(filename): tree = ET.parse(filename) fragments = [] root = tree.getroot() for child in root: if child.tag == 'paragraph': for subchild in child: if subchild.tag == 'sentence': text = subchild.attrib['S'].strip() if len(text) <= 2: continue polarity = None for inner in subchild: if inner.tag == 'Polarity': polarity = inner break if polarity is None: print("Found sentence with no polarity in {}: {}".format(filename, text)) continue if polarity.text == '消极': sentiment = "0" elif polarity.text == '中性': sentiment = "1" elif polarity.text == '积极': sentiment = "2" else: raise ValueError("Unknown polarity {} in {}".format(polarity.text, filename)) fragments.append(SentimentDatum(sentiment, text)) return fragments def read_snippets(xml_directory): sentences = [] for filename in glob.glob(xml_directory + '/xml/cet_*xml'): sentences.extend(get_phrases(filename)) nlp = stanza.Pipeline('zh', processors='tokenize') snippets = [] for sentence in sentences: doc = nlp(sentence.text) text = [token.text for sentence in doc.sentences for token in sentence.tokens] snippets.append(SentimentDatum(sentence.sentiment, text)) random.shuffle(snippets) return snippets def main(xml_directory, out_directory, short_name): snippets = read_snippets(xml_directory) print("Found {} phrases".format(len(snippets))) os.makedirs(out_directory, exist_ok=True) process_utils.write_splits(out_directory, snippets, (process_utils.Split("%s.train.json" % short_name, 0.8), process_utils.Split("%s.dev.json" % short_name, 0.1), process_utils.Split("%s.test.json" % short_name, 0.1))) if __name__ == "__main__": random.seed(1234) xml_directory = sys.argv[1] out_directory = sys.argv[2] short_name = sys.argv[3] main(xml_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_sb10k.py ================================================ """ Processes the SB10k 
dataset The original description of the dataset and corpus_v1.0.tsv is here: https://www.spinningbytes.com/resources/germansentiment/ Download script is here: https://github.com/aritter/twitter_download The problem with this file is that many of the tweets with labels no longer exist. Roughly 1/3 as of June 2020. You can contact the authors for the complete dataset. There is a paper describing some experiments run on the dataset here: https://dl.acm.org/doi/pdf/10.1145/3038912.3052611 """ import argparse import os import random from enum import Enum import stanza.utils.datasets.sentiment.process_utils as process_utils class Split(Enum): TRAIN_DEV_TEST = 1 TRAIN_DEV = 2 TEST = 3 MAPPING = {'positive': "2", 'neutral': "1", 'negative': "0"} def main(args=None): parser = argparse.ArgumentParser() parser.add_argument('--csv_filename', type=str, default=None, help='CSV file to read in') parser.add_argument('--out_dir', type=str, default=None, help='Where to write the output files') parser.add_argument('--sentiment_column', type=int, default=2, help='Column with the sentiment') parser.add_argument('--text_column', type=int, default=3, help='Column with the text') parser.add_argument('--short_name', type=str, default="sb10k", help='short name to use when writing files') parser.add_argument('--split', type=lambda x: Split[x.upper()], default=Split.TRAIN_DEV_TEST, help="How to split the resulting data") args = parser.parse_args(args=args) snippets = process_utils.read_snippets(args.csv_filename, args.sentiment_column, args.text_column, 'de', MAPPING) print(len(snippets)) random.shuffle(snippets) os.makedirs(args.out_dir, exist_ok=True) if args.split is Split.TRAIN_DEV_TEST: process_utils.write_splits(args.out_dir, snippets, (process_utils.Split("%s.train.json" % args.short_name, 0.8), process_utils.Split("%s.dev.json" % args.short_name, 0.1), process_utils.Split("%s.test.json" % args.short_name, 0.1))) elif args.split is Split.TRAIN_DEV: 
process_utils.write_splits(args.out_dir, snippets, (process_utils.Split("%s.train.json" % args.short_name, 0.9), process_utils.Split("%s.dev.json" % args.short_name, 0.1))) elif args.split is Split.TEST: process_utils.write_list(os.path.join(args.out_dir, "%s.test.json" % args.short_name), snippets) else: raise ValueError("Unknown split method {}".format(args.split)) if __name__ == '__main__': random.seed(1234) main() ================================================ FILE: stanza/utils/datasets/sentiment/process_scare.py ================================================ """ SCARE is a dataset of German text with sentiment annotations. http://romanklinger.de/scare/ To run the script, pass in the directory where scare was unpacked. It should have subdirectories scare_v1.0.0 and scare_v1.0.0_text You need to fill out a license agreement to not redistribute the data in order to get the data, but the process is not onerous. Although it sounds interesting, there are unfortunately a lot of very short items. 
Not sure the long items will be enough """ import csv import glob import os import sys import stanza from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils def get_scare_snippets(nlp, csv_dir_path, text_id_map, filename_pattern="*.csv"): """ Read snippets from the given CSV directory """ num_short_items = 0 snippets = [] csv_files = glob.glob(os.path.join(csv_dir_path, filename_pattern)) for csv_filename in csv_files: with open(csv_filename, newline='') as fin: cin = csv.reader(fin, delimiter='\t', quotechar='"') lines = list(cin) for line in lines: ann_id, begin, end, sentiment = [line[i] for i in [1, 2, 3, 6]] begin = int(begin) end = int(end) if sentiment.lower() == 'unknown': continue elif sentiment.lower() == 'positive': sentiment = 2 elif sentiment.lower() == 'neutral': sentiment = 1 elif sentiment.lower() == 'negative': sentiment = 0 else: raise ValueError("Tell John he screwed up and this is why he can't have Mox Opal: {}".format(sentiment)) if ann_id not in text_id_map: print("Found snippet which can't be found: {}-{}".format(csv_filename, ann_id)) continue snippet = text_id_map[ann_id][begin:end] doc = nlp(snippet) text = [token.text for sentence in doc.sentences for token in sentence.tokens] num_tokens = sum(len(sentence.tokens) for sentence in doc.sentences) if num_tokens < 4: num_short_items = num_short_items + 1 snippets.append(SentimentDatum(sentiment, text)) print("Number of short items: {}".format(num_short_items)) return snippets def main(in_directory, out_directory, short_name): os.makedirs(out_directory, exist_ok=True) input_path = os.path.join(in_directory, "scare_v1.0.0_text", "annotations", "*txt") text_files = glob.glob(input_path) if len(text_files) == 0: raise FileNotFoundError("Did not find any input files in %s" % input_path) else: print("Found %d input files in %s" % (len(text_files), input_path)) text_id_map = {} for filename in text_files: with open(filename) as 
fin: for line in fin.readlines(): line = line.strip() if not line: continue key, value = line.split(maxsplit=1) if key in text_id_map: raise ValueError("Duplicate key {}".format(key)) text_id_map[key] = value print("Found %d total sentiment ratings" % len(text_id_map)) nlp = stanza.Pipeline('de', processors='tokenize') snippets = get_scare_snippets(nlp, os.path.join(in_directory, "scare_v1.0.0", "annotations"), text_id_map) print(len(snippets)) process_utils.write_list(os.path.join(out_directory, "%s.train.json" % short_name), snippets) if __name__ == '__main__': in_directory = sys.argv[1] out_directory = sys.argv[2] short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_slsd.py ================================================ """ A small dataset of 1500 positive and 1500 negative sentences. Supposedly has no neutral sentences by design https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences https://archive.ics.uci.edu/ml/machine-learning-databases/00331/ See the existing readme for citation requirements etc Files in the slsd repo were one line per annotation, with labels 0 for negative and 1 for positive. No neutral labels existed. Accordingly, we rearrange the text and adjust the label to fit the 0/1/2 paradigm. Text is retokenized using PTBTokenizer. 
process_slsd.py """ import os import sys from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils def get_phrases(in_directory): in_filenames = [os.path.join(in_directory, 'amazon_cells_labelled.txt'), os.path.join(in_directory, 'imdb_labelled.txt'), os.path.join(in_directory, 'yelp_labelled.txt')] lines = [] for filename in in_filenames: lines.extend(open(filename, newline='')) phrases = [] for line in lines: line = line.strip() sentiment = line[-1] utterance = line[:-1] utterance = utterance.replace("!.", "!") utterance = utterance.replace("?.", "?") if sentiment == '0': sentiment = '0' elif sentiment == '1': sentiment = '2' else: raise ValueError("Unknown sentiment: {}".format(sentiment)) phrases.append(SentimentDatum(sentiment, utterance)) return phrases def get_tokenized_phrases(in_directory): phrases = get_phrases(in_directory) phrases = process_utils.get_ptb_tokenized_phrases(phrases) print("Found %d phrases in slsd" % len(phrases)) return phrases def main(in_directory, out_directory, short_name): phrases = get_tokenized_phrases(in_directory) out_filename = os.path.join(out_directory, "%s.train.json" % short_name) os.makedirs(out_directory, exist_ok=True) process_utils.write_list(out_filename, phrases) if __name__ == '__main__': in_directory = sys.argv[1] out_directory = sys.argv[2] short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_sst.py ================================================ import argparse import os import subprocess from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils import stanza.utils.default_paths as default_paths TREEBANK_FILES = ["train.txt", "dev.txt", "test.txt", "extra-train.txt", "checked-extra-train.txt"] ARGUMENTS = { "fiveclass": [], "root": ["-root_only"], "binary": 
["-ignore_labels", "2", "-remap_labels", "1=0,2=-1,3=1,4=1"], "binaryroot": ["-root_only", "-ignore_labels", "2", "-remap_labels", "1=0,2=-1,3=1,4=1"], "threeclass": ["-remap_labels", "0=0,1=0,2=1,3=2,4=2"], "threeclassroot": ["-root_only", "-remap_labels", "0=0,1=0,2=1,3=2,4=2"], } def get_subtrees(input_file, *args): """ Use the CoreNLP OutputSubtrees tool to convert the input file to a bunch of phrases Returns a list of the SentimentDatum namedtuple """ # TODO: maybe can convert this to use the python tree? cmd = ["java", "edu.stanford.nlp.trees.OutputSubtrees", "-input", input_file] if len(args) > 0: cmd = cmd + list(args) print (" ".join(cmd)) results = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") lines = results.stdout.split("\n") lines = [x.strip() for x in lines] lines = [x for x in lines if x] lines = [x.split(maxsplit=1) for x in lines] phrases = [SentimentDatum(x[0], x[1].split()) for x in lines] return phrases def get_phrases(dataset, treebank_file, input_dir): extra_args = ARGUMENTS[dataset] input_file = os.path.join(input_dir, "fiveclass", treebank_file) if not os.path.exists(input_file): raise FileNotFoundError(input_file) phrases = get_subtrees(input_file, *extra_args) print("Found {} phrases in SST {} {}".format(len(phrases), treebank_file, dataset)) return phrases def convert_version(dataset, treebank_file, input_dir, output_dir): """ Convert the fiveclass files to a specific format Uses the ARGUMENTS specific for the format wanted """ phrases = get_phrases(dataset, treebank_file, input_dir) output_file = os.path.join(output_dir, "en_sst.%s.%s.json" % (dataset, treebank_file.split(".")[0])) process_utils.write_list(output_file, phrases) def parse_args(): """ Actually, the only argument used right now is the formats to convert """ parser = argparse.ArgumentParser() parser.add_argument('sections', type=str, nargs='*', help='Which transformations to use: {}'.format(" ".join(ARGUMENTS.keys()))) args 
= parser.parse_args() if not args.sections: args.sections = list(ARGUMENTS.keys()) return args def main(): args = parse_args() paths = default_paths.get_default_paths() input_dir = os.path.join(paths["SENTIMENT_BASE"], "sentiment-treebank") output_dir = paths["SENTIMENT_DATA_DIR"] os.makedirs(output_dir, exist_ok=True) for section in args.sections: for treebank_file in TREEBANK_FILES: convert_version(section, treebank_file, input_dir, output_dir) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/sentiment/process_usage_german.py ================================================ """ USAGE is produced by the same people as SCARE. USAGE has a German and English part. This script parses the German part. Run the script as process_usage_german.py path Here, path should be where USAGE was unpacked. It will have the documents, files, etc subdirectories. https://www.romanklinger.de/usagecorpus/ """ import csv import glob import os import sys import stanza from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils def main(in_directory, out_directory, short_name): os.makedirs(out_directory, exist_ok=True) nlp = stanza.Pipeline('de', processors='tokenize') num_short_items = 0 snippets = [] csv_files = glob.glob(os.path.join(in_directory, "files/de*csv")) for csv_filename in csv_files: with open(csv_filename, newline='') as fin: cin = csv.reader(fin, delimiter='\t', quotechar=None) lines = list(cin) for index, line in enumerate(lines): begin, end, snippet, sentiment = [line[i] for i in [2, 3, 4, 6]] begin = int(begin) end = int(end) if len(snippet) != end - begin: raise ValueError("Error found in {} line {}. 
Expected {} got {}".format(csv_filename, index, (end-begin), len(snippet))) if sentiment.lower() == 'unknown': continue elif sentiment.lower() == 'positive': sentiment = 2 elif sentiment.lower() == 'neutral': sentiment = 1 elif sentiment.lower() == 'negative': sentiment = 0 else: raise ValueError("Tell John he screwed up and this is why he can't have Mox Opal: {}".format(sentiment)) doc = nlp(snippet) text = [token.text for sentence in doc.sentences for token in sentence.tokens] num_tokens = sum(len(sentence.tokens) for sentence in doc.sentences) if num_tokens < 4: num_short_items = num_short_items + 1 snippets.append(SentimentDatum(sentiment, text)) print("Total snippets found for USAGE: %d" % len(snippets)) process_utils.write_list(os.path.join(out_directory, "%s.train.json" % short_name), snippets) if __name__ == '__main__': in_directory = sys.argv[1] out_directory = sys.argv[2] short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/sentiment/process_utils.py ================================================ import csv import glob import json import os import tempfile from collections import namedtuple from tqdm import tqdm import stanza from stanza.models.classifiers.data import SentimentDatum Split = namedtuple('Split', ['filename', 'weight']) SHARDS = ("train", "dev", "test") def write_list(out_filename, dataset): """ Write a list of items to the given output file Expected: list(SentimentDatum) """ formatted_dataset = [line._asdict() for line in dataset] # Rather than write the dataset at once, we write one line at a time # Using `indent` puts each word on a separate line, which is rather noisy, # but not formatting at all makes one long line out of an entire dataset, # which is impossible to read #json.dump(formatted_dataset, fout, indent=2, ensure_ascii=False) with open(out_filename, 'w') as fout: fout.write("[\n") for idx, line in enumerate(formatted_dataset): 
fout.write(" ") json.dump(line, fout, ensure_ascii=False) if idx < len(formatted_dataset) - 1: fout.write(",") fout.write("\n") fout.write("]\n") def write_dataset(dataset, out_directory, dataset_name): """ Write train, dev, test as .json files for a given dataset dataset: 3 lists of sentiment tuples """ for shard, phrases in zip(SHARDS, dataset): output_file = os.path.join(out_directory, "%s.%s.json" % (dataset_name, shard)) write_list(output_file, phrases) def write_splits(out_directory, snippets, splits): """ Write the given list of items to the split files in the specified output directory """ total_weight = sum(split.weight for split in splits) divs = [] subtotal = 0.0 for split in splits: divs.append(int(len(snippets) * subtotal / total_weight)) subtotal = subtotal + split.weight # the last div will be guaranteed to be the full thing - no math used divs.append(len(snippets)) for i, split in enumerate(splits): filename = os.path.join(out_directory, split.filename) print("Writing {}:{} to {}".format(divs[i], divs[i+1], filename)) write_list(filename, snippets[divs[i]:divs[i+1]]) def clean_tokenized_tweet(line): line = list(line) if len(line) > 3 and line[0] == 'RT' and line[1][0] == '@' and line[2] == ':': line = line[3:] elif len(line) > 4 and line[0] == 'RT' and line[1] == '@' and line[3] == ':': line = line[4:] elif line[0][0] == '@': line = line[1:] for i in range(len(line)): if line[i][0] == '@' or line[i][0] == '#': line[i] = line[i][1:] line = [x for x in line if x and not x.startswith("http:") and not x.startswith("https:")] return line def get_ptb_tokenized_phrases(dataset): """ Use the PTB tokenizer to retokenize the phrases Not clear which is better, "Nov." or "Nov ." 
strictAcronym=true makes it do the latter tokenizePerLine=true should make it only pay attention to one line at a time Phrases will be returned as lists of words rather than one string """ with tempfile.TemporaryDirectory() as tempdir: phrase_filename = os.path.join(tempdir, "phrases.txt") #phrase_filename = "asdf.txt" with open(phrase_filename, "w", encoding="utf-8") as fout: for item in dataset: # extra newlines are so the tokenizer treats the lines # as separate sentences fout.write("%s\n\n\n" % (item.text)) tok_filename = os.path.join(tempdir, "tokenized.txt") os.system('java edu.stanford.nlp.process.PTBTokenizer -options "strictAcronym=true,tokenizePerLine=true" -preserveLines %s > %s' % (phrase_filename, tok_filename)) with open(tok_filename, encoding="utf-8") as fin: tokenized = fin.readlines() tokenized = [x.strip() for x in tokenized] tokenized = [x for x in tokenized if x] phrases = [SentimentDatum(x.sentiment, y.split()) for x, y in zip(dataset, tokenized)] return phrases def process_datum(nlp, text, mapping, sentiment): doc = nlp(text.strip()) converted_sentiment = mapping.get(sentiment, None) if converted_sentiment is None: raise ValueError("Value {} not in mapping at line {} of {}".format(sentiment, idx, csv_filename)) text = [] for sentence in doc.sentences: text.extend(token.text for token in sentence.tokens) text = clean_tokenized_tweet(text) return SentimentDatum(converted_sentiment, text) def read_snippets(csv_filename, sentiment_column, text_column, tokenizer_language, mapping, delimiter='\t', quotechar=None, skip_first_line=False, nlp=None, encoding="utf-8"): """ Read in a single CSV file and return a list of SentimentDatums """ if nlp is None: nlp = stanza.Pipeline(tokenizer_language, processors='tokenize') with open(csv_filename, newline='', encoding=encoding) as fin: if skip_first_line: next(fin) cin = csv.reader(fin, delimiter=delimiter, quotechar=quotechar) lines = list(cin) # Read in the data and parse it snippets = [] for idx, line in 
enumerate(tqdm(lines)): try: if isinstance(sentiment_column, int): sentiment = line[sentiment_column].lower() else: sentiment = tuple([line[x] for x in sentiment_column]) except IndexError as e: raise IndexError("Columns {} did not exist at line {}: {}".format(sentiment_column, idx, line)) from e text = line[text_column] datum = process_datum(nlp, text, mapping, sentiment) snippets.append(datum) return snippets ================================================ FILE: stanza/utils/datasets/sentiment/process_vsfc_vietnamese.py ================================================ """ VSFC sentiment dataset is available at https://drive.google.com/drive/folders/1xclbjHHK58zk2X6iqbvMPS2rcy9y9E0X The format is extremely similar to ours - labels are 0,1,2. Text needs to be tokenized, though. Also, the files are split into two pieces, labels and text. """ import os import sys from tqdm import tqdm import stanza from stanza.models.classifiers.data import SentimentDatum import stanza.utils.datasets.sentiment.process_utils as process_utils import stanza.utils.default_paths as default_paths def combine_columns(in_directory, dataset, nlp): directory = os.path.join(in_directory, dataset) sentiment_file = os.path.join(directory, "sentiments.txt") with open(sentiment_file) as fin: sentiment = fin.readlines() text_file = os.path.join(directory, "sents.txt") with open(text_file) as fin: text = fin.readlines() text = [[token.text for sentence in nlp(line.strip()).sentences for token in sentence.tokens] for line in tqdm(text)] phrases = [SentimentDatum(s.strip(), t) for s, t in zip(sentiment, text)] return phrases def main(in_directory, out_directory, short_name): nlp = stanza.Pipeline('vi', processors='tokenize') for shard in ("train", "dev", "test"): phrases = combine_columns(in_directory, shard, nlp) output_file = os.path.join(out_directory, "%s.%s.json" % (short_name, shard)) process_utils.write_list(output_file, phrases) if __name__ == '__main__': paths = 
default_paths.get_default_paths() if len(sys.argv) <= 1: in_directory = os.path.join(paths['SENTIMENT_BASE'], "vietnamese", "_UIT-VSFC") else: in_directory = sys.argv[1] if len(sys.argv) <= 2: out_directory = paths['SENTIMENT_DATA_DIR'] else: out_directory = sys.argv[2] if len(sys.argv) <= 3: short_name = 'vi_vsfc' else: short_name = sys.argv[3] main(in_directory, out_directory, short_name) ================================================ FILE: stanza/utils/datasets/thai_syllable_dict_generator.py ================================================ import glob import pathlib import argparse def create_dictionary(dataset_dir, save_dir): syllables = set() for p in pathlib.Path(dataset_dir).rglob("*.ssg"): # iterate through all files with open(p) as f: # for each file sentences = f.readlines() for i in range(len(sentences)): sentences[i] = sentences[i].replace("\n", "") sentences[i] = sentences[i].replace("", "~") sentences[i] = sentences[i].split("~") # create list of all syllables syllables = syllables.union(sentences[i]) print(len(syllables)) # Filter out syllables with English words import re a = [] for s in syllables: print("---") if bool(re.match("^[\u0E00-\u0E7F]*$", s)) and s != "" and " " not in s: a.append(s) else: pass a = set(a) a = dict(zip(list(a), range(len(a)))) import json print(a) print(len(a)) with open(save_dir, "w") as fp: json.dump(a, fp) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--dataset_dir', type=str, default="syllable_segmentation_data", help="Directory for syllable dataset") parser.add_argument('--save_dir', type=str, default="thai-syllable.json", help="Directory for generated file") args = parser.parse_args() create_dictionary(args.dataset_dir, args.save_dir) ================================================ FILE: stanza/utils/datasets/tokenization/__init__.py ================================================ ================================================ FILE: 
stanza/utils/datasets/tokenization/convert_ml_cochin.py
================================================
"""
Convert a Malayalam NER dataset to a tokenization dataset using
the additional labeling provided by TTec's Indian partners

This is still WIP - ongoing discussion with TTec and the team at UFAL
doing the UD Malayalam dataset - but if someone wants the data to
recreate it, feel free to contact Prof. Manning or John Bauer

Data was annotated through Datasaur by TTec - possibly another team
involved, will double check with the annotators.

#1 current issue with the data is a difference in annotation style
observed by the UFAL group.  I believe TTec is working on reannotating
this.

Discussing the first sentence in the first split file:

> I am not sure about the guidelines that the annotators followed, but
> I would not have split നാമജപത്തോടുകൂടി as നാമ --- ജപത്തോടുകൂടി. Because
> they are not multiple syntactic words. I would have done it like
> നാമജപത്തോടു --- കൂടി as കൂടി ('with') can be tagged as ADP. I agree with
> the second MWT വ്യത്യസ്തം --- കൂടാതെ.
>
> In Malayalam, we do have many words which potentially can be treated
> as compounds and split but sometimes it becomes difficult to make
> that decision as the etymology or the word formation process is
> unclear. So for the Malayalam UD annotations I stayed away from
> doing it because I didn't find it necessary and moreover the
> guidelines say that the words should be split into syntactic words
> and not into morphemes.

As for using this script, create a directory extern_data/malayalam/cochin_ner/
The original NER dataset from Cochin University goes there:
  extern_data/malayalam/cochin_ner/final_ner.txt
The relabeled data from TTEC goes in
  extern_data/malayalam/cochin_ner/relabeled_tsv/malayalam_File_1.txt.tsv etc etc

This can be invoked from the command line, or it can be used as part of
  stanza/utils/datasets/prepare_tokenizer_treebank.py ml_cochin
in which case the conll splits will be turned into tokenizer labels as well
"""

from difflib import SequenceMatcher
import os
import random
import sys

import stanza.utils.default_paths as default_paths

def read_words(filename):
    """
    Read the first whitespace-separated field of every line of filename

    Blank lines become "" so that callers can use them as sentence breaks
    """
    with open(filename, encoding="utf-8") as fin:
        text = fin.readlines()
        text = [x.strip().split()[0] if x.strip() else "" for x in text]
        return text

def read_original_text(input_dir):
    """
    Read the words of the original NER file, input_dir/final_ner.txt
    """
    original_file = os.path.join(input_dir, "final_ner.txt")
    return read_words(original_file)

def list_relabeled_files(relabeled_dir):
    """
    List the malayalam_File_<N>.txt.tsv files in relabeled_dir, sorted by N numerically

    Asserts that no other files are present in the directory
    """
    tsv_files = os.listdir(relabeled_dir)
    assert all(x.startswith("malayalam_File_") and x.endswith(".txt.tsv") for x in tsv_files)
    # "malayalam_File_12.txt.tsv" -> "malayalam_File_12" -> "12"
    tsv_files = sorted(tsv_files, key = lambda filename: int(filename.split(".")[0].split("_")[2]))
    return tsv_files

def find_word(original_text, target, start_index, end_index):
    """
    Return True if target occurs in original_text[start_index:end_index]
    """
    for word in original_text[start_index:end_index]:
        if word == target:
            return True
    return False

def scan_file(original_text, current_index, tsv_file):
    """
    Extract sentences from one relabeled file, aligned against the original text

    The relabeled words are aligned to original_text with SequenceMatcher.
    Blank lines separate sentences.  Words inside a matching block, and any
    ',' or '.', are kept as-is.  Outside the matching blocks, a span
    enclosed by '@' markers is treated as an MWT.  If it has 2 or 3 pieces
    and their concatenation appears in the original text between the
    surrounding matching blocks, it is kept as a (text, pieces) tuple;
    otherwise the whole sentence is marked bad and dropped.  Any other
    unmatched word is kept as-is.

    Returns (current_index, sentences), where each sentence is a list of
    str words and (mwt_text, pieces) tuples.  current_index is returned
    unchanged; NOTE(review): it appears unused here.

    Raises ValueError if the file does not match the original text at all,
    or if an MWT span is longer than 6 pieces
    """
    relabeled_text = read_words(tsv_file)
    # for now, at least, we ignore these markers
    # relabeled_indices maps a position in the filtered list back to the
    # line number in the tsv file, for error messages
    relabeled_indices = [idx for idx, x in enumerate(relabeled_text) if x != '$' and x != '^']
    relabeled_text = [x for x in relabeled_text if x != '$' and x != '^']
    # the 4th positional argument is autojunk=False, so frequent words are not treated as junk
    diffs = SequenceMatcher(None, original_text, relabeled_text, False)

    # each block is (a, b, size): original_text[a:a+size] == relabeled_text[b:b+size]
    blocks = diffs.get_matching_blocks()
    # get_matching_blocks always ends with a zero-size sentinel block
    assert blocks[-1].size == 0
    if len(blocks) == 1:
        raise ValueError("Could not find a match between %s and the original text" % tsv_file)

    sentences = []
    current_sentence = []

    in_mwt = False
    bad_sentence = False
    current_mwt = []
    block_index = 0
    current_block = blocks[0]
    for tsv_index, next_word in enumerate(relabeled_text):
        if not next_word:
            # blank line: end of sentence.  An MWT still open here is
            # discarded and its sentence marked bad
            if in_mwt:
                current_mwt = []
                in_mwt = False
                bad_sentence = True
                print("Unclosed MWT found at %s line %d" % (tsv_file, tsv_index))
            if current_sentence:
                if not bad_sentence:
                    sentences.append(current_sentence)
                bad_sentence = False
                current_sentence = []
            continue

        # tsv_index will now be inside the current block or before the current block
        while tsv_index >= blocks[block_index].b + current_block.size:
            block_index += 1
            current_block = blocks[block_index]
        #print(tsv_index, current_block.b, current_block.size)

        if next_word == ',' or next_word == '.':
            # many of these punctuations were added by the relabelers
            current_sentence.append(next_word)
            continue

        if tsv_index >= current_block.b and tsv_index < current_block.b + current_block.size:
            # ideal case: in a matching block
            current_sentence.append(next_word)
            continue

        # in between blocks... need to handle re-spelled words and MWTs
        if not in_mwt and next_word == '@':
            in_mwt = True
            continue

        if not in_mwt:
            current_sentence.append(next_word)
            continue

        if in_mwt and next_word == '@' and (tsv_index + 1 < len(relabeled_text) and relabeled_text[tsv_index+1] == '@'):
            # we'll stop the MWT next time around
            continue

        if in_mwt and next_word == '@':
            # closing '@': decide whether the collected pieces form an acceptable MWT
            if block_index > 0 and (len(current_mwt) == 2 or len(current_mwt) == 3):
                mwt = "".join(current_mwt)
                # the unmatched stretch of the original text between the previous block and this one
                start_original = blocks[block_index-1].a + blocks[block_index-1].size
                end_original = current_block.a
                if find_word(original_text, mwt, start_original, end_original):
                    current_sentence.append((mwt, current_mwt))
                else:
                    print("%d word MWT %s at %s %d. Should be somewhere in %d %d" % (len(current_mwt), mwt, tsv_file, relabeled_indices[tsv_index], start_original, end_original))
                    bad_sentence = True
            elif len(current_mwt) > 6:
                raise ValueError("Unreasonably long MWT span in %s at line %d" % (tsv_file, relabeled_indices[tsv_index]))
            elif len(current_mwt) > 3:
                print("%d word sequence, stop being lazy - %s %d" % (len(current_mwt), tsv_file, relabeled_indices[tsv_index]))
                bad_sentence = True
            else:
                # short MWT, but it was at the start of a file, and we don't want to search the whole file for the item
                # TODO, could maybe search the 10 words or so before the start of the block?
                bad_sentence = True
            current_mwt = []
            in_mwt = False
            continue

        # now we know we are in an MWT...  TODO
        current_mwt.append(next_word)

    # the file may end without a trailing blank line
    if len(current_sentence) > 0 and not bad_sentence:
        sentences.append(current_sentence)

    return current_index, sentences

def split_sentences(sentences):
    """
    Randomly assign each sentence to train (~80%), dev (~10%) or test (~10%)

    Uses the global random module, so the result depends on its seed
    """
    train = []
    dev = []
    test = []

    for sentence in sentences:
        rand = random.random()
        if rand < 0.8:
            train.append(sentence)
        elif rand < 0.9:
            dev.append(sentence)
        else:
            test.append(sentence)

    return train, dev, test

def main(input_dir, tokenizer_dir, relabeled_dir="relabeled_tsv", split_data=True):
    """
    Convert the relabeled files into ml_cochin.<shard>.gold.conllu files in tokenizer_dir

    input_dir: base directory; the data is read from input_dir/malayalam/cochin_ner
    relabeled_dir: subdirectory of that path holding the relabeled tsv files
    split_data: if True, write train/dev/test splits; otherwise everything goes to train

    NOTE(review): tokenizer_dir is not created here, so it must already exist
    """
    random.seed(1006)

    input_dir = os.path.join(input_dir, "malayalam", "cochin_ner")
    relabeled_dir = os.path.join(input_dir, relabeled_dir)
    tsv_files = list_relabeled_files(relabeled_dir)

    original_text = read_original_text(input_dir)
    print("Original text len: %d" %len(original_text))
    current_index = 0
    sentences = []
    for tsv_file in tsv_files:
        print(tsv_file)
        current_index, new_sentences = scan_file(original_text, current_index, os.path.join(relabeled_dir, tsv_file))
        sentences.extend(new_sentences)

    print("Found %d sentences" % len(sentences))

    if split_data:
        splits = split_sentences(sentences)
        SHARDS = ("train", "dev", "test")
    else:
        splits = [sentences]
        SHARDS = ["train"]

    for split, shard in zip(splits, SHARDS):
        output_filename = os.path.join(tokenizer_dir, "ml_cochin.%s.gold.conllu" % shard)
        print("Writing %d sentences to %s" % (len(split), output_filename))
        with open(output_filename, "w", encoding="utf-8") as fout:
            for sentence in split:
                word_idx = 1
                for token in sentence:
                    # the dependency columns are placeholders: word 1 is the
                    # root and every other word attaches to word 1
                    if isinstance(token, str):
                        fake_dep = "\t0\troot" if word_idx == 1 else "\t1\tdep"
                        fout.write("%d\t%s" % (word_idx, token) + "\t_" * 4 + fake_dep + "\t_\t_\n")
                        word_idx += 1
                    else:
                        # (mwt_text, pieces): a range line, then one line per piece
                        text = token[0]
                        mwt = token[1]
                        fout.write("%d-%d\t%s" % (word_idx, word_idx + len(mwt) - 1, text) + "\t_" * 8 + "\n")
                        for piece in mwt:
                            fake_dep = "\t0\troot" if word_idx == 1 else "\t1\tdep"
                            fout.write("%d\t%s" % (word_idx, piece) + "\t_" * 4 + fake_dep + "\t_\t_\n")
                            word_idx += 1
                fout.write("\n")

if __name__ == '__main__':
    sys.stdout.reconfigure(encoding='utf-8')
    paths = default_paths.get_default_paths()
    tokenizer_dir = paths["TOKENIZE_DATA_DIR"]
    input_dir = paths["STANZA_EXTERN_DIR"]
    main(input_dir, tokenizer_dir, "relabeled_tsv_v2", False)
================================================
FILE: stanza/utils/datasets/tokenization/convert_my_alt.py
================================================
"""Converts the Myanmar ALT corpus to a tokenizer dataset.

The ALT corpus is in the form of constituency trees, which basically
means there is no guidance on where the whitespace belongs.  However,
in Myanmar writing, whitespace is apparently not actually required
anywhere.  The plan will be to make sentences where there is no
whitespace at all, along with a random selection of sentences where
some whitespace is randomly inserted.
The treebank is available here: https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/ The following files describe the splits of the data: https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-train.txt https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-dev.txt https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-test.txt and this is the actual treebank: https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/my-alt-190530.zip Download each of the files, then unzip the my-alt zip in place. The expectation is this will produce a file my-alt-190530/data The default expected path to the Myanmar data is extern_data/constituency/myanmar/my_alt/my-alt-190530/data """ import os import random from stanza.models.constituency.tree_reader import read_trees def read_split(input_dir, section): """ Reads the split description for train, dev, or test Format (at least for the Myanmar section of ALT) is: one description per line each line is URL. we actually don't care about the URL itself all we want is the number, which we use to split up the tree file later Returns a set of numbers (as strings) """ urls = set() filename = os.path.join(input_dir, "myanmar", "my_alt", "URL-%s.txt" % section) with open(filename) as fin: lines = fin.readlines() for line in lines: line = line.strip() if not line or not line.startswith("URL"): continue # split into URL.100161 and a bunch of description we don't care about line = line.split(maxsplit=1) # get just the number line = line[0].split(".") assert len(line) == 2 assert line[0] == 'URL' urls.add(line[1]) return urls SPLITS = ("train", "dev", "test") def read_dataset_splits(input_dir): """ Call read_split for train, dev, and test Returns three sets: train, dev, test in order """ url_splits = [read_split(input_dir, section) for section in SPLITS] for url_split, split in zip(url_splits, SPLITS): print("Split %s has %d files in it" % (split, len(url_split))) return url_splits def read_alt_treebank(constituency_input_dir): """ 
Read the splits, read the trees, and split the trees based on the split descriptions Trees in ALT are: The tree id will look like SNT.. All we care about from this id is the url_id, which we crossreference in the splits to figure out which split the tree is in. The tree itself we don't process much, although we do convert it to a ParseTree The result is three lists: train, dev, test trees """ train_split, dev_split, test_split = read_dataset_splits(constituency_input_dir) datafile = os.path.join(constituency_input_dir, "myanmar", "my_alt", "my-alt-190530", "data") print("Reading trees from %s" % datafile) with open(datafile) as fin: tree_lines = fin.readlines() train_trees = [] dev_trees = [] test_trees = [] for idx, tree_line in enumerate(tree_lines): tree_line = tree_line.strip() if not tree_line: continue dataset, tree_text = tree_line.split(maxsplit=1) dataset = dataset.split(".", 2)[1] trees = read_trees(tree_text) if len(trees) != 1: raise ValueError("Unexpected number of trees in line %d: %d" % (idx, len(trees))) tree = trees[0] if dataset in train_split: train_trees.append(tree) elif dataset in dev_split: dev_trees.append(tree) elif dataset in test_split: test_trees.append(tree) else: raise ValueError("Could not figure out which split line %d belongs to" % idx) return train_trees, dev_trees, test_trees def write_sentence(fout, words, spaces): """ Write a sentence based on the list of words. 
spaces is a fraction of the words which should randomly have spaces If 0.0, none of the words will have spaces This is because the Myanmar language doesn't require spaces, but spaces always separate words """ full_text = "".join(words) fout.write("# text = %s\n" % full_text) for word_idx, word in enumerate(words): fake_dep = "root" if word_idx == 0 else "dep" fout.write("%d\t%s\t%s" % ((word_idx+1), word, word)) fout.write("\t_\t_\t_") fout.write("\t%d\t%s" % (word_idx, fake_dep)) fout.write("\t_\t") if random.random() > spaces: fout.write("SpaceAfter=No") else: fout.write("_") fout.write("\n") fout.write("\n") def write_dataset(filename, trees, split): """ Write all of the trees to the given filename """ count = 0 with open(filename, "w") as fout: # TODO: make some fraction have random spaces inserted for tree in trees: count = count + 1 words = tree.leaf_labels() write_sentence(fout, words, spaces=0.0) # We include a small number of spaces to teach the model # that spaces always separate a word if split == 'train' and random.random() < 0.1: count = count + 1 write_sentence(fout, words, spaces=0.05) print("Wrote %d sentences from %d trees to %s" % (count, len(trees), filename)) def convert_my_alt(constituency_input_dir, tokenizer_dir): """ Read and then convert the Myanmar ALT treebank """ random.seed(1234) tree_splits = read_alt_treebank(constituency_input_dir) output_filenames = [os.path.join(tokenizer_dir, "my_alt.%s.gold.conllu") % split for split in SPLITS] for filename, trees, split in zip(output_filenames, tree_splits, SPLITS): write_dataset(filename, trees, split) def main(): convert_my_alt("extern_data/constituency", "data/tokenize") if __name__ == "__main__": main() ================================================ FILE: stanza/utils/datasets/tokenization/convert_text_files.py ================================================ """ Given a text file and a file with one word per line, convert the text file Sentence splits should be represented as blank lines 
at the end of a sentence. """ import argparse import os import random from stanza.models.tokenization.utils import match_tokens_with_text import stanza.utils.datasets.common as common def read_tokens_file(token_file): """ Returns a list of list of tokens Each sentence is a list of tokens """ sentences = [] current_sentence = [] with open(token_file, encoding="utf-8") as fin: for line in fin: line = line.strip() if not line: if current_sentence: sentences.append(current_sentence) current_sentence = [] else: current_sentence.append(line) if current_sentence: sentences.append(current_sentence) return sentences def read_sentences_file(sentence_file): sentences = [] with open(sentence_file, encoding="utf-8") as fin: for line in fin: line = line.strip() if not line: continue sentences.append(line) return sentences def process_raw_file(text_file, token_file, sentence_file, base_sent_idx=0): """ Process a text file separated into a list of tokens using match_tokens_with_text from the tokenizer The tokens are one per line in the token_file The tokens in the token_file must add up to the text_file modulo whitespace. Sentences are also one per line in the sentence_file These must also add up to text_file The return format is a list of list of conllu lines representing the sentences. 
The only fields set will be the token index, the token text, and possibly SpaceAfter=No where SpaceAfter=No is true if the next token started with no whitespace in the text file """ with open(text_file, encoding="utf-8") as fin: text = fin.read() tokens = read_tokens_file(token_file) tokens = [[token for sentence in tokens for token in sentence]] tokens_doc = match_tokens_with_text(tokens, text) assert len(tokens_doc.sentences) == 1 assert len(tokens_doc.sentences[0].tokens) == len(tokens[0]) sentences = read_sentences_file(sentence_file) sentences_doc = match_tokens_with_text([sentences], text) assert len(sentences_doc.sentences) == 1 assert len(sentences_doc.sentences[0].tokens) == len(sentences) start_token_idx = 0 sentences = [] for sent_idx, sentence in enumerate(sentences_doc.sentences[0].tokens): tokens = [] tokens.append("# sent_id = %d" % (base_sent_idx + sent_idx + 1)) tokens.append("# text = %s" % text[sentence.start_char:sentence.end_char].replace("\n", " ")) token_idx = 0 while token_idx + start_token_idx < len(tokens_doc.sentences[0].tokens): token = tokens_doc.sentences[0].tokens[token_idx + start_token_idx] if token.start_char >= sentence.end_char: # have reached the end of this sentence # continue with the next sentence start_token_idx += token_idx break if token_idx + start_token_idx == len(tokens_doc.sentences[0].tokens) - 1: # definitely the end of the document space_after = True elif token.end_char == tokens_doc.sentences[0].tokens[token_idx + start_token_idx + 1].start_char: space_after = False else: space_after = True token = [str(token_idx+1), token.text] + ["_"] * 7 + ["_" if space_after else "SpaceAfter=No"] assert len(token) == 10, "Token length: %d" % len(token) token = "\t".join(token) tokens.append(token) token_idx += 1 sentences.append(tokens) return sentences def extract_sentences(dataset_files): sentences = [] for text_file, token_file, sentence_file in dataset_files: print("Extracting sentences from %s and tokens from %s from the 
text file %s" % (sentence_file, token_file, text_file)) sentences.extend(process_raw_file(text_file, token_file, sentence_file, len(sentences))) return sentences def split_sentences(sentences, train_split=0.8, dev_split=0.1): """ Splits randomly without shuffling """ generator = random.Random(1234) train = [] dev = [] test = [] for sentence in sentences: r = generator.random() if r < train_split: train.append(sentence) elif r < train_split + dev_split: dev.append(sentence) else: test.append(sentence) return (train, dev, test) def find_dataset_files(input_path, token_prefix, sentence_prefix): files = os.listdir(input_path) print("Found %d files in %s" % (len(files), input_path)) if len(files) > 0: if len(files) < 20: print("Files:", end="\n ") else: print("First few files:", end="\n ") print("\n ".join(files[:20])) token_files = {} sentence_files = {} text_files = [] for filename in files: if filename.endswith(".zip"): continue if filename.startswith(token_prefix): short_filename = filename[len(token_prefix):] if short_filename.startswith("_"): short_filename = short_filename[1:] token_files[short_filename] = filename elif filename.startswith(sentence_prefix): short_filename = filename[len(sentence_prefix):] if short_filename.startswith("_"): short_filename = short_filename[1:] sentence_files[short_filename] = filename else: text_files.append(filename) dataset_files = [] for filename in text_files: if filename not in token_files: raise FileNotFoundError("When looking in %s, found %s as a text file, but did not find a corresponding tokens file at %s_%s Please give an input directory which has only the text files, tokens files, and sentences files" % (input_path, filename, token_prefix, filename)) if filename not in sentence_files: raise FileNotFoundError("When looking in %s, found %s as a text file, but did not find a corresponding sentences file at %s_%s Please give an input directory which has only the text files, tokens files, and sentences files" % (input_path, 
filename, sentence_prefix, filename)) text_file = os.path.join(input_path, filename) token_file = os.path.join(input_path, token_files[filename]) sentence_file = os.path.join(input_path, sentence_files[filename]) dataset_files.append((text_file, token_file, sentence_file)) return dataset_files SHARDS = ("train", "dev", "test") def main(): parser = argparse.ArgumentParser() parser.add_argument('--token_prefix', type=str, default="tkns", help="Prefix for the token files") parser.add_argument('--sentence_prefix', type=str, default="stns", help="Prefix for the token files") parser.add_argument('--input_path', type=str, default="extern_data/sindhi/tokenization", help="Where to find all of the input files. Files with the prefix tkns_ will be treated as token files, files with the prefix stns_ will be treated as sentence files, and all others will be the text files.") parser.add_argument('--output_path', type=str, default="data/tokenize", help="Where to output the results") parser.add_argument('--dataset', type=str, default="sd_isra", help="What name to give this dataset") args = parser.parse_args() dataset_files = find_dataset_files(args.input_path, args.token_prefix, args.sentence_prefix) tokenizer_dir = args.output_path short_name = args.dataset # todo: convert a full name? sentences = extract_sentences(dataset_files) splits = split_sentences(sentences) os.makedirs(args.output_path, exist_ok=True) for dataset, shard in zip(splits, SHARDS): output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, shard) common.write_sentences_to_conllu(output_conllu, dataset) common.convert_conllu_to_txt(tokenizer_dir, short_name) common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/tokenization/convert_th_best.py ================================================ """Parses the BEST Thai dataset. That is to say, the dataset named BEST. 
We have not yet figured out which segmentation standard we prefer. Note that the version of BEST we used actually had some strange sentence splits according to a native Thai speaker. Not sure how to fix that. Options include doing it automatically or finding some knowledgable annotators to resplit it for us (or just not using BEST) This outputs the tokenization results in a conll format similar to that of the UD treebanks, so we pretend to be a UD treebank for ease of compatibility with the stanza tools. BEST can be downloaded from here: https://aiforthai.in.th/corpus.php python3 -m stanza.utils.datasets.tokenization.process_best extern_data/thai/best data/tokenize ./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 """ import glob import os import random import re import sys try: from pythainlp import sent_tokenize except ImportError: pass from stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines, write_dataset_best, write_dataset def clean_line(line): line = line.replace("html>", "html|>") # news_00089.txt line = line.replace("", "") line = line.replace("", "") # specific error that occurs in encyclopedia_00095.txt line = line.replace("Penn", "|Penn>") # news_00058.txt line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") # news_00015.txt line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) # news_00024.txt line = re.sub("([^|<>]+)", "\\1", line) # news_00055.txt line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) # news_00008.txt and other news articles line = re.sub("([0-9])", "|\\1", line) line = line.replace(" ", "|") line = line.replace("", "") line = line.replace("", "") line = line.strip() return line def clean_word(word): # novel_00078.txt if word == '': return 'พี่มน' if word.startswith("") and word.endswith(""): return word[4:-5] if 
word.startswith("") and word.endswith(""): return word[4:-5] if word.startswith("") and word.endswith(""): return word[6:-7] """ if word.startswith(""): return word[4:] if word.endswith(""): return word[:-5] """ if word.startswith(""): return word[4:] if word.endswith(""): return word[:-5] if word.startswith(""): return word[6:] if word.endswith(""): return word[:-7] if word == '<': return word return word def read_data(input_dir): # data for test sets test_files = [os.path.join(input_dir, 'TEST_100K_ANS.txt')] print(test_files) # data for train and dev sets subdirs = [os.path.join(input_dir, 'article'), os.path.join(input_dir, 'encyclopedia'), os.path.join(input_dir, 'news'), os.path.join(input_dir, 'novel')] files = [] for subdir in subdirs: if not os.path.exists(subdir): raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) files.extend(glob.glob(os.path.join(subdir, '*.txt'))) test_documents = [] for filename in test_files: print("File name:", filename) with open(filename) as fin: processed_lines = [] for line in fin.readlines(): line = clean_line(line) words = line.split("|") words = [clean_word(x) for x in words] for word in words: if len(word) > 1 and word[0] == '<': raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) words = [x for x in words if x] processed_lines.append(words) processed_lines = reprocess_lines(processed_lines) paragraphs = convert_processed_lines(processed_lines) test_documents.extend(paragraphs) print("Test document finished.") documents = [] for filename in files: with open(filename) as fin: print("File:", filename) processed_lines = [] for line in fin.readlines(): line = clean_line(line) words = line.split("|") words = [clean_word(x) for x in words] for word in words: if len(word) > 1 and word[0] == '<': raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) words = [x for x in words if x] processed_lines.append(words) processed_lines = 
reprocess_lines(processed_lines) paragraphs = convert_processed_lines(processed_lines) documents.extend(paragraphs) print("All documents finished.") return documents, test_documents def main(*args): random.seed(1000) if not args: args = sys.argv[1:] input_dir = args[0] full_input_dir = os.path.join(input_dir, "thai", "best") if os.path.exists(full_input_dir): # otherwise hopefully the user gave us the full path? input_dir = full_input_dir output_dir = args[1] documents, test_documents = read_data(input_dir) print("Finished reading data.") write_dataset_best(documents, test_documents, output_dir, "best") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/tokenization/convert_th_lst20.py ================================================ """Processes the tokenization section of the LST20 Thai dataset The dataset is available here: https://aiforthai.in.th/corpus.php The data should be installed under ${EXTERN_DATA}/thai/LST20_Corpus python3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data data/tokenize Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test. 
./scripts/run_tokenize.sh UD_Thai-lst20 --dropout 0.05 --unit_dropout 0.05 """ import argparse import glob import os import sys from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section, convert_processed_lines, reprocess_lines def read_document(lines, spaces_after, split_clauses): document = [] sentence = [] for line in lines: line = line.strip() if not line: if sentence: if spaces_after: sentence[-1] = (sentence[-1][0], True) document.append(sentence) sentence = [] else: pieces = line.split("\t") # there are some nbsp in tokens in lst20, but the downstream tools expect spaces pieces = [p.replace("\xa0", " ") for p in pieces] if split_clauses and pieces[0] == '_' and pieces[3] == 'O': if sentence: # note that we don't need to check spaces_after # the "token" is a space anyway sentence[-1] = (sentence[-1][0], True) document.append(sentence) sentence = [] elif pieces[0] == '_': sentence[-1] = (sentence[-1][0], True) else: sentence.append((pieces[0], False)) if sentence: if spaces_after: sentence[-1] = (sentence[-1][0], True) document.append(sentence) sentence = [] # TODO: is there any way to divide up a single document into paragraphs? return [[document]] def retokenize_document(lines): processed_lines = [] sentence = [] for line in lines: line = line.strip() if not line: if sentence: processed_lines.append(sentence) sentence = [] else: pieces = line.split("\t") if pieces[0] == '_': sentence.append(' ') else: sentence.append(pieces[0]) if sentence: processed_lines.append(sentence) processed_lines = reprocess_lines(processed_lines) paragraphs = convert_processed_lines(processed_lines) return paragraphs def read_data(input_dir, section, resegment, spaces_after, split_clauses): glob_path = os.path.join(input_dir, section, "*.txt") filenames = glob.glob(glob_path) print(" Found {} files in {}".format(len(filenames), glob_path)) if len(filenames) == 0: raise FileNotFoundError("Could not find any files for the {} section. 
Is LST20 installed in {}?".format(section, input_dir)) documents = [] for filename in filenames: with open(filename) as fin: lines = fin.readlines() if resegment: document = retokenize_document(lines) else: document = read_document(lines, spaces_after, split_clauses) documents.extend(document) return documents def add_lst20_args(parser): parser.add_argument('--no_lst20_resegment', action='store_false', dest="lst20_resegment", default=True, help='When processing th_lst20 tokenization, use pythainlp to resegment the text. The other option is to keep the original sentence segmentation. Currently our model is not good at that') parser.add_argument('--lst20_spaces_after', action='store_true', dest="lst20_spaces_after", default=False, help='When processing th_lst20 without pythainlp, put spaces after each sentence. This better fits the language but gets lower scores for some reason') parser.add_argument('--split_clauses', action='store_true', dest="split_clauses", default=False, help='When processing th_lst20 without pythainlp, turn spaces which are labeled as between clauses into sentence splits') def parse_lst20_args(): parser = argparse.ArgumentParser() parser.add_argument('input_dir', help="Directory to use when processing lst20") parser.add_argument('output_dir', help="Directory to use when saving lst20") add_lst20_args(parser) return parser.parse_args() def convert(input_dir, output_dir, args): input_dir = os.path.join(input_dir, "thai", "LST20_Corpus") if not os.path.exists(input_dir): raise FileNotFoundError("Could not find LST20 corpus in {}".format(input_dir)) for (in_section, out_section) in (("train", "train"), ("eval", "dev"), ("test", "test")): print("Processing %s" % out_section) documents = read_data(input_dir, in_section, args.lst20_resegment, args.lst20_spaces_after, args.split_clauses) print(" Read in %d documents" % len(documents)) write_section(output_dir, "lst20", out_section, documents) def main(): args = parse_lst20_args() convert(args.input_dir, 
args.output_dir, args) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/tokenization/convert_th_orchid.py ================================================ """Parses the xml conversion of orchid https://github.com/korakot/thainlp/blob/master/xmlchid.xml For example, if you put the data file in the above link in extern_data/thai/orchid/xmlchid.xml you would then run python3 -m stanza.utils.datasets.tokenization.convert_th_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize Because there is no definitive train/dev/test split that we have found so far, we randomly shuffle the data on a paragraph level and split it 80/10/10. A random seed is chosen so that the splits are reproducible. The datasets produced have a similar format to the UD datasets, so we give it a fake UD name to make life easier for the downstream tools. Training on this dataset seems to work best with low dropout numbers. For example: python3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05 This results in a model with dev set scores: th_orchid 87.98 70.94 test set scores: 91.60 72.43 Apparently the random split produced a test set easier than the dev set. """ import os import random import sys import xml.etree.ElementTree as ET from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset # line "122819" has some error in the tokenization of the musical notation # line "209380" is also messed up # others have @ followed by a part of speech, which is clearly wrong skipped_lines = { "122819", "209380", "227769", "245992", "347163", "409708", "431227", } escape_sequences = { '': '(', '': ')', '': '^', '': '.', '': '-', '': '*', '': '"', '': '/', '': ':', '': '=', '': ',', '': ';', '': '<', '': '>', '': '&', '': '{', '': '}', '': "'", '': '+', '': '#', '': '$', '': '@', '': '?', '': '!', 'app
  • ances': 'appliances', 'intel
  • gence': 'intelligence', "'": "/'", '<100>': '100', } allowed_sequences = { '', '', '', '', '', '
  • ', '<---vp', '<---', '<----', } def read_data(input_filename): print("Reading {}".format(input_filename)) tree = ET.parse(input_filename) documents = parse_xml(tree) print("Number of documents: {}".format(len(documents))) print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) return documents def parse_xml(tree): # we will put each paragraph in a separate block in the output file # we won't pay any attention to the document boundaries unless we # later find out it was necessary # a paragraph will be a list of sentences # a sentence is a list of words, where each word is a string documents = [] root = tree.getroot() for document in root: # these should all be documents if document.tag != 'document': raise ValueError("Unexpected orchid xml layout: {}".format(document.tag)) paragraphs = [] for paragraph in document: if paragraph.tag != 'paragraph': raise ValueError("Unexpected orchid xml layout: {} under {}".format(paragraph.tag, document.tag)) sentences = [] for sentence in paragraph: if sentence.tag != 'sentence': raise ValueError("Unexpected orchid xml layout: {} under {}".format(sentence.tag, document.tag)) if sentence.attrib['line_num'] in skipped_lines: continue words = [] for word_idx, word in enumerate(sentence): if word.tag != 'word': raise ValueError("Unexpected orchid xml layout: {} under {}".format(word.tag, sentence.tag)) word = word.attrib['surface'] word = escape_sequences.get(word, word) if word == '': if word_idx == 0: raise ValueError("Space character was the first token in a sentence: {}".format(sentence.attrib['line_num'])) else: words[-1] = (words[-1][0], True) continue if len(word) > 1 and word[0] == '<' and word not in allowed_sequences: raise ValueError("Unknown escape sequence {}".format(word)) words.append((word, False)) if len(words) == 0: continue words[-1] = (words[-1][0], True) sentences.append(words) paragraphs.append(sentences) documents.append(paragraphs) return documents def main(*args): 
random.seed(1007) if not args: args = sys.argv[1:] input_filename = args[0] if os.path.isdir(input_filename): input_filename = os.path.join(input_filename, "thai", "orchid", "xmlchid.xml") output_dir = args[1] documents = read_data(input_filename) write_dataset(documents, output_dir, "orchid") if __name__ == '__main__': main() ================================================ FILE: stanza/utils/datasets/tokenization/convert_vi_vlsp.py ================================================ import os punctuation_set = (',', '.', '!', '?', ')', ':', ';', '”', '…', '...') def find_spaces(sentence): # TODO: there are some sentences where there is only one quote, # and some of them should be attached to the previous word instead # of the next word. Training should work this way, though odd_quotes = False spaces = [] for word_idx, word in enumerate(sentence): space = True # Quote period at the end of a sentence needs to be attached # to the rest of the text. Some sentences have `"... text` # in the middle, though, so look for that if word_idx < len(sentence) - 2 and sentence[word_idx+1] == '"': if sentence[word_idx+2] == '.': space = False elif word_idx == len(sentence) - 3 and sentence[word_idx+2] == '...': space = False if word_idx < len(sentence) - 1: if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...','/', '%'): space = False if word in ('(', '“', '/'): space = False if word == '"': if odd_quotes: # already saw one quote. put this one at the end of the PREVIOUS word # note that we know there must be at least one word already odd_quotes = False spaces[word_idx-1] = False else: odd_quotes = True space = False spaces.append(space) return spaces def add_vlsp_args(parser): parser.add_argument('--include_pos_data', action='store_true', default=False, help='To include or not POS training dataset for tokenization training. The path to POS dataset is expected to be in the same dir with WS path. 
For example, extern_dir/vietnamese/VLSP2013-POS-data') parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text') def write_file(vlsp_include_spaces, output_filename, sentences, shard): with open(output_filename, "w") as fout: check_headlines = False for sent_idx, sentence in enumerate(sentences): fout.write("# sent_id = %s.%d\n" % (shard, sent_idx)) orig_text = " ".join(sentence) #check if the previous line is a headline (no ending mark at the end) then make this sentence a new par if check_headlines: fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx)) check_headlines = False if sentence[len(sentence) - 1] not in punctuation_set: check_headlines = True if vlsp_include_spaces: fout.write("# text = %s\n" % orig_text) else: spaces = find_spaces(sentence) full_text = "" for word, space in zip(sentence, spaces): # could be made more efficient, but shouldn't matter full_text = full_text + word if space: full_text = full_text + " " fout.write("# text = %s\n" % full_text) fout.write("# orig_text = %s\n" % orig_text) for word_idx, word in enumerate(sentence): fake_dep = "root" if word_idx == 0 else "dep" fout.write("%d\t%s\t%s" % ((word_idx+1), word, word)) fout.write("\t_\t_\t_") fout.write("\t%d\t%s" % (word_idx, fake_dep)) fout.write("\t_\t") if vlsp_include_spaces or spaces[word_idx]: fout.write("_") else: fout.write("SpaceAfter=No") fout.write("\n") fout.write("\n") def convert_pos_dataset(file_path): """ This function is to process the pos dataset """ file = open(file_path, "r") document = file.readlines() sentences = [] sent = [] for line in document: if line == "\n" and len(sent)>1: if sent not in sentences: sentences.append(sent) sent = [] elif line != "\n": sent.append(line.split("\t")[0].replace("_"," ").strip()) return sentences def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, 
split_filename=None, split_shard=None, pos_data = None): with open(input_filename) as fin: lines = fin.readlines() sentences = [] set_sentences = set() for line in lines: if len(line.replace("_", " ").split())>1: words = line.split() #one syllable lines are eliminated if len(words) == 1 and len(words[0].split("_")) == 1: continue else: words = [w.replace("_", " ") for w in words] #only add sentences that hasn't been added before if words not in sentences: sentences.append(words) set_sentences.add(' '.join(words)) if split_filename is not None: # even this is a larger dev set than the train set split_point = int(len(sentences) * 0.95) #check pos_data that aren't overlapping with current VLSP WS dataset sentences_pos = [] if pos_data is None else [sent for sent in pos_data if ' '.join(sent) not in set_sentences] print("Added ", len(sentences_pos), " sentences from POS dataset.") write_file(vlsp_include_spaces, output_filename, sentences[:split_point]+sentences_pos, shard) write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard) else: write_file(vlsp_include_spaces, output_filename, sentences, shard) def convert_vi_vlsp(extern_dir, tokenizer_dir, args): input_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-WS-data") input_pos_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-POS-data") input_train_filename = os.path.join(input_path, "VLSP2013_WS_train_gold.txt") input_test_filename = os.path.join(input_path, "VLSP2013_WS_test_gold.txt") input_pos_filename = os.path.join(input_pos_path, "VLSP2013_POS_train_BI_POS_Column.txt.goldSeg") if not os.path.exists(input_train_filename): raise FileNotFoundError("Cannot find train set for VLSP at %s" % input_train_filename) if not os.path.exists(input_test_filename): raise FileNotFoundError("Cannot find test set for VLSP at %s" % input_test_filename) pos_data = None if args.include_pos_data: if not os.path.exists(input_pos_filename): raise FileNotFoundError("Cannot find pos dataset for 
VLSP at %" % input_pos_filename) else: pos_data = convert_pos_dataset(input_pos_filename) output_train_filename = os.path.join(tokenizer_dir, "vi_vlsp.train.gold.conllu") output_dev_filename = os.path.join(tokenizer_dir, "vi_vlsp.dev.gold.conllu") output_test_filename = os.path.join(tokenizer_dir, "vi_vlsp.test.gold.conllu") convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev", pos_data) convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, "test") ================================================ FILE: stanza/utils/datasets/tokenization/process_thai_tokenization.py ================================================ import os import random try: from pythainlp import sent_tokenize except ImportError: pass def write_section(output_dir, dataset_name, section, documents): """ Writes a list of documents for tokenization, including a file in conll format The Thai datasets generally have no MWT (apparently not relevant for Thai) output_dir: the destination directory for the output files dataset_name: orchid, BEST, lst20, etc section: train/dev/test documents: a nested list of documents, paragraphs, sentences, words words is a list of (word, space_follows) """ with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w') as fout: fout.write("[]\n") text_out = open(os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section)), 'w') label_out = open(os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section)), 'w') for document in documents: for paragraph in document: for sentence_idx, sentence in enumerate(paragraph): for word_idx, word in enumerate(sentence): # TODO: split with newlines to make it more readable? 
text_out.write(word[0]) for i in range(len(word[0]) - 1): label_out.write("0") if word_idx == len(sentence) - 1: label_out.write("2") else: label_out.write("1") if word[1] and (sentence_idx != len(paragraph) - 1 or word_idx != len(sentence) - 1): text_out.write(' ') label_out.write('0') text_out.write("\n\n") label_out.write("\n\n") text_out.close() label_out.close() with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout: for document in documents: for paragraph in document: new_par = True for sentence in paragraph: for word_idx, word in enumerate(sentence): # SpaceAfter is left blank if there is space after the word if word[1] and new_par: space = 'NewPar=Yes' elif word[1]: space = '_' elif new_par: space = 'SpaceAfter=No|NewPar=Yes' else: space = 'SpaceAfter=No' new_par = False # Note the faked dependency structure: the conll reading code # needs it even if it isn't being used in any way fake_dep = 'root' if word_idx == 0 else 'dep' fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space)) fout.write('\n') def write_dataset(documents, output_dir, dataset_name): """ Shuffle a list of documents, write three sections """ random.shuffle(documents) num_train = int(len(documents) * 0.8) num_dev = int(len(documents) * 0.1) os.makedirs(output_dir, exist_ok=True) write_section(output_dir, dataset_name, 'train', documents[:num_train]) write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:]) def write_dataset_best(documents, test_documents, output_dir, dataset_name): """ Shuffle a list of documents, write three sections """ random.shuffle(documents) num_train = int(len(documents) * 0.85) num_dev = int(len(documents) * 0.15) os.makedirs(output_dir, exist_ok=True) write_section(output_dir, dataset_name, 'train', documents[:num_train]) 
write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) write_section(output_dir, dataset_name, 'test', test_documents) def reprocess_lines(processed_lines): """ Reprocesses lines using pythainlp to cut up sentences into shorter sentences. Many of the lines in BEST seem to be multiple Thai sentences concatenated, according to native Thai speakers. Input: a list of lines, where each line is a list of words. Space characters can be included as words Output: a new list of lines, resplit using pythainlp """ reprocessed_lines = [] for line in processed_lines: text = "".join(line) try: chunks = sent_tokenize(text) except NameError as e: raise NameError("Sentences cannot be reprocessed without first installing pythainlp") from e # Check that the total text back is the same as the text in if sum(len(x) for x in chunks) != len(text): raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) chunk_lengths = [len(x) for x in chunks] current_length = 0 new_line = [] for word in line: if len(word) + current_length < chunk_lengths[0]: new_line.append(word) current_length = current_length + len(word) elif len(word) + current_length == chunk_lengths[0]: new_line.append(word) reprocessed_lines.append(new_line) new_line = [] chunk_lengths = chunk_lengths[1:] current_length = 0 else: remaining_len = chunk_lengths[0] - current_length new_line.append(word[:remaining_len]) reprocessed_lines.append(new_line) word = word[remaining_len:] chunk_lengths = chunk_lengths[1:] while len(word) > chunk_lengths[0]: new_line = [word[:chunk_lengths[0]]] reprocessed_lines.append(new_line) word = word[chunk_lengths[0]:] chunk_lengths = chunk_lengths[1:] new_line = [word] current_length = len(word) reprocessed_lines.append(new_line) return reprocessed_lines def convert_processed_lines(processed_lines): """ Convert a list of sentences into documents suitable for the output methods in this module. 
Input: a list of lines, including space words Output: a list of documents, each document containing a list of sentences Each sentence is a list of words: (text, space_follows) Space words will be eliminated. """ paragraphs = [] sentences = [] for words in processed_lines: # turn the words into a sentence if len(words) > 1 and " " == words[0]: words = words[1:] elif len(words) == 1 and " " == words[0]: words = [] sentence = [] for word in words: word = word.strip() if not word: if len(sentence) == 0: print(word) raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) sentence[-1] = (sentence[-1][0], True) else: sentence.append((word, False)) # blank lines are very rare in best, but why not treat them as a paragraph break if len(sentence) == 0: paragraphs.append([sentences]) sentences = [] continue sentence[-1] = (sentence[-1][0], True) sentences.append(sentence) paragraphs.append([sentences]) return paragraphs ================================================ FILE: stanza/utils/datasets/vietnamese/__init__.py ================================================ ================================================ FILE: stanza/utils/datasets/vietnamese/renormalize.py ================================================ """ Script to renormalize diacritics for Vietnamese text from BARTpho https://github.com/VinAIResearch/BARTpho/blob/main/VietnameseToneNormalization.md https://github.com/VinAIResearch/BARTpho/blob/main/LICENSE MIT License Copyright (c) 2021 VinAI Research Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission 
notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import argparse
import os

# Maps a two-letter oa / oe / uy sequence with the tone mark on the first
# vowel to the same sequence with the tone mark on the second vowel,
# in lowercase, capitalized, and uppercase forms.
DICT_MAP = {
    "òa": "oà",
    "Òa": "Oà",
    "ÒA": "OÀ",
    "óa": "oá",
    "Óa": "Oá",
    "ÓA": "OÁ",
    "ỏa": "oả",
    "Ỏa": "Oả",
    "ỎA": "OẢ",
    "õa": "oã",
    "Õa": "Oã",
    "ÕA": "OÃ",
    "ọa": "oạ",
    "Ọa": "Oạ",
    "ỌA": "OẠ",
    "òe": "oè",
    "Òe": "Oè",
    "ÒE": "OÈ",
    "óe": "oé",
    "Óe": "Oé",
    "ÓE": "OÉ",
    "ỏe": "oẻ",
    "Ỏe": "Oẻ",
    "ỎE": "OẺ",
    "õe": "oẽ",
    "Õe": "Oẽ",
    "ÕE": "OẼ",
    "ọe": "oẹ",
    "Ọe": "Oẹ",
    "ỌE": "OẸ",
    "ùy": "uỳ",
    "Ùy": "Uỳ",
    "ÙY": "UỲ",
    "úy": "uý",
    "Úy": "Uý",
    "ÚY": "UÝ",
    "ủy": "uỷ",
    "Ủy": "Uỷ",
    "ỦY": "UỶ",
    "ũy": "uỹ",
    "Ũy": "Uỹ",
    "ŨY": "UỸ",
    "ụy": "uỵ",
    "Ụy": "Uỵ",
    "ỤY": "UỴ",
}


def replace_all(text):
    """
    Return text with every DICT_MAP key replaced by its value

    The replacements are applied one after another, in DICT_MAP order.
    """
    for i, j in DICT_MAP.items():
        text = text.replace(i, j)
    return text


def convert_file(org_file, new_file):
    """
    Renormalize org_file line by line into new_file (both read/written as UTF-8)
    """
    with open(org_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer:
        content = reader.readlines()
        for line in content:
            new_line = replace_all(line)
            writer.write(new_line)


def convert_files(file_list, new_dir):
    """
    Renormalize each file in file_list into new_dir, keeping its base name
    """
    for file_name in file_list:
        base_name = os.path.split(file_name)[-1]
        new_file_path = os.path.join(new_dir, base_name)
        convert_file(file_name, new_file_path)


def convert_dir(org_dir, new_dir, suffix):
    """
    Renormalize the files directly in org_dir whose extension equals suffix

    new_dir is created if needed.  suffix is compared against
    os.path.splitext, so it includes the leading dot (e.g. '.txt').
    Subdirectories are not searched.
    """
    os.makedirs(new_dir, exist_ok=True)
    file_list = os.listdir(org_dir)
    file_list = [os.path.join(org_dir, f) for f in file_list if os.path.splitext(f)[1] == suffix]
    convert_files(file_list, new_dir)


def main():
    """
    Command line entry point

    If orig is a file, converted is used as the output file path;
    otherwise orig is treated as a directory and converted as the
    output directory.
    """
    parser = argparse.ArgumentParser(
        description='Script that renormalizes diacritics'
    )

    parser.add_argument(
        'orig',
        help='Location of the original directory'
    )

    parser.add_argument(
        'converted',
        help='The location of new directory'
    )

    parser.add_argument(
        '--suffix',
        type=str,
        default='.txt',
        help='Which suffix to look for when renormalizing a directory'
    )

    args = parser.parse_args()

    if os.path.isfile(args.orig):
        convert_file(args.orig, args.converted)
    else:
        convert_dir(args.orig, args.converted, args.suffix)

if __name__ == '__main__':
    main()
================================================ FILE: stanza/utils/default_paths.py ================================================ import os def get_default_paths(): """ Gets base paths for the data directories If DATA_ROOT is set in the environment, use that as the root otherwise use "./data" individual paths can also be set in the environment """ DATA_ROOT = os.environ.get("DATA_ROOT", "data") defaults = { "TOKENIZE_DATA_DIR": DATA_ROOT + "/tokenize", "MWT_DATA_DIR": DATA_ROOT + "/mwt", "LEMMA_DATA_DIR": DATA_ROOT + "/lemma", "POS_DATA_DIR": DATA_ROOT + "/pos", "DEPPARSE_DATA_DIR": DATA_ROOT + "/depparse", "ETE_DATA_DIR": DATA_ROOT + "/ete", "NER_DATA_DIR": DATA_ROOT + "/ner", "CHARLM_DATA_DIR": DATA_ROOT + "/charlm", "SENTIMENT_DATA_DIR": DATA_ROOT + "/sentiment", "CONSTITUENCY_DATA_DIR": DATA_ROOT + "/constituency", "COREF_DATA_DIR": DATA_ROOT + "/coref", "LEMMA_CLASSIFIER_DATA_DIR": DATA_ROOT + "/lemma_classifier", # Set directories to store external word vector data "WORDVEC_DIR": "extern_data/wordvec", # TODO: not sure what other people actually have # TODO: also, could make this automatically update to the latest "UDBASE": "extern_data/ud2/ud-treebanks-v2.11", "UDBASE_GIT": "extern_data/ud2/git", "NERBASE": "extern_data/ner", "CONSTITUENCY_BASE": "extern_data/constituency", "SENTIMENT_BASE": "extern_data/sentiment", "COREF_BASE": "extern_data/coref", # there's a stanford github, stanfordnlp/handparsed-treebank, # with some data for different languages "HANDPARSED_DIR": "extern_data/handparsed-treebank",
# directory with the contents of https://nlp.stanford.edu/projects/stanza/bio/ # on the cluster, for example, /u/nlp/software/stanza/bio_ud "BIO_UD_DIR": "extern_data/bio", # data root for other general input files, such as VI_VLSP "STANZA_EXTERN_DIR": "extern_data", } paths = { "DATA_ROOT" : DATA_ROOT } for k, v in defaults.items(): paths[k] = os.environ.get(k, v) return paths ================================================ FILE: stanza/utils/get_tqdm.py ================================================ import sys def get_tqdm(): """ Return a tqdm appropriate for the situation imports tqdm depending on if we're at a console, redir to a file, notebook, etc from @tcrimi at https://github.com/tqdm/tqdm/issues/506 This replaces `import tqdm`, so for example, you do this: from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() then do this when you want a scroll bar or regular iterator depending on context: tqdm(list) If there is no tty, the returned tqdm will always be disabled unless disable=False is specifically set. 
""" ipy_str = "" try: from IPython import get_ipython ipy_str = str(type(get_ipython())) except ImportError: pass if 'zmqshell' in ipy_str: from tqdm import tqdm_notebook as tqdm return tqdm if 'terminal' in ipy_str: from tqdm import tqdm return tqdm if sys.stderr is not None and hasattr(sys.stderr, "isatty") and sys.stderr.isatty(): from tqdm import tqdm return tqdm from tqdm import tqdm def hidden_tqdm(*args, **kwargs): if "disable" in kwargs: return tqdm(*args, **kwargs) kwargs["disable"] = True return tqdm(*args, **kwargs) return hidden_tqdm ================================================ FILE: stanza/utils/helper_func.py ================================================ def make_table(header, content, column_width=None): ''' Input: header -> List[str]: table header content -> List[List[str]]: table content column_width -> int: table column width; set to None for dynamically calculated widths Output: table_str -> str: well-formatted string for the table ''' table_str = '' len_column, len_row = len(header), len(content) + 1 if column_width is None: # dynamically decide column widths lens = [[len(str(h)) for h in header]] lens += [[len(str(x)) for x in row] for row in content] column_widths = [max(c)+3 for c in zip(*lens)] else: column_widths = [column_width] * len_column table_str += '=' * (sum(column_widths) + 1) + '\n' table_str += '|' for i, item in enumerate(header): table_str += ' ' + str(item).ljust(column_widths[i] - 2) + '|' table_str += '\n' table_str += '-' * (sum(column_widths) + 1) + '\n' for line in content: table_str += '|' for i, item in enumerate(line): table_str += ' ' + str(item).ljust(column_widths[i] - 2) + '|' table_str += '\n' table_str += '=' * (sum(column_widths) + 1) + '\n' return table_str ================================================ FILE: stanza/utils/languages/__init__.py ================================================ ================================================ FILE: stanza/utils/languages/kazakh_transliteration.py 
================================================ """ Kazakh Transliteration: Cyrillic Kazakh --> Latin Kazakh """ import argparse import os from re import M import string import sys from stanza.models.common.utils import open_read_text, get_tqdm tqdm = get_tqdm() """ This dictionary isn't used in the code, just put this here in case you want to implement it more efficiently and in case the need to look up the unicode encodings for these letters might arise. Some letters are mapped to multiple latin letters, for these, I separated the unicde with a '%' delimiter between the two unicode characters. """ alph_map = { '\u0410' # А : '\u0041', # A '\u0430' # а : '\u0061', # a '\u04D8' # Ә : '\u00c4', # Ä '\u04D9' # ә : '\u00e4', # ä '\u0411' # Б : '\u0042', # B '\u0431' # б : '\u0062', # b '\u0412' # В : '\u0056', # V '\u0432' # в : '\u0076', # v '\u0413' # Г : '\u0047', # G '\u0433' # г : '\u0067', # g '\u0492' # Ғ : '\u011e', # Ğ '\u0493' # ғ : '\u011f', # ğ '\u0414' # Д : '\u0044', # D '\u0434' # д : '\u0064', # d '\u0415' # Е : '\u0045', # E '\u0435' # е : '\u0065', # e '\u0401' # Ё : '\u0130%\u006f', # İo '\u0451' # ё : '\u0069%\u006f', #io '\u0416' # Ж : '\u004a', # J '\u0436' # ж : '\u006a', # j '\u0417' # З : '\u005a', # Z '\u0437' # з : '\u007a', # z '\u0418' # И : '\u0130', # İ '\u0438' # и : '\u0069', # i '\u0419' # Й : '\u0130', # İ '\u0439' # й : '\u0069', # i '\u041A' # К : '\u004b', # K '\u043A' # к : '\u006b', # k '\u049A' # Қ : '\u0051', # Q '\u049B' # қ : '\u0071', # q '\u041B' # Л : '\u004c', # L '\u043B' # л : '\u006c', # l '\u041C' # М : '\u004d', # M '\u043C' # м : '\u006d', # m '\u041D' # Н : '\u004e', # N '\u043D' # н : '\u006e', # n '\u04A2' # Ң : '\u00d1', # Ñ '\u04A3' # ң : '\u00f1', # ñ '\u041E' # О : '\u004f', # O '\u043E' # о : '\u006f', # o '\u04E8' # Ө : '\u00d6', # Ö '\u04E9' # ө : '\u00f6', # ö '\u041F' # П : '\u0050', # P '\u043F' # п : '\u0070', # p '\u0420' # Р : '\u0052', # R '\u0440' # р : '\u0072', # r '\u0421' # С : '\u0053', # S 
'\u0441' # с : '\u0073', # s '\u0422' # Т : '\u0054', # T '\u0442' # т : '\u0074', # t '\u0423' # У : '\u0055', # U '\u0443' # у : '\u0075', # u '\u04B0' # Ұ : '\u016a', # Ū '\u04B1' # ұ : '\u016b', # ū '\u04AE' # Ү : '\u00dc', # Ü '\u04AF' # ү : '\u00fc', # ü '\u0424' # Ф : '\u0046', # F '\u0444' # ф : '\u0066', # f '\u0425' # Х : '\u0048', # H '\u0445' # х : '\u0068', # h '\u04BA' # Һ : '\u0048', # H '\u04BB' # һ : '\u0068', # h '\u0426' # Ц : '\u0043', # C '\u0446' # ц : '\u0063', # c '\u0427' # Ч : '\u00c7', # Ç '\u0447' # ч : '\u00e7', # ç '\u0428' # Ш : '\u015e', # Ş '\u0448' # ш : '\u015f', # ş '\u0429' # Щ : '\u015e%\u00e7', # Şç '\u0449' # щ : '\u015f%\u00e7', # şç '\u042A' # Ъ : '', # Empty String '\u044A' # ъ : '', # Empty String \u '\u042B' # Ы : '\u0059', # Y '\u044B' # ы : '\u0079', # y '\u0406' # І : '\u0130', # İ '\u0456' # і : '\u0069', # i '\u042C' # Ь : '', # Empty String '\u044C' # ь : '', # Empty String '\u042D' # Э : '\u0045', # E '\u044D' # э : '\u0065', # e '\u042E' # Ю : '\u0130%\u0075', # İu '\u044E' # ю : '\u0069%\u0075', # iu '\u042F' # Я : '\u0130%\u0061', # İa '\u044F' # я : '\u0069%\u0061' # ia } kazakh_alph = "АаӘәБбВвГгҒғДдЕеЁёЖжЗзИиЙйКкҚқЛлМмНнҢңОоӨөПпРрСсТтУуҰұҮүФфХхҺһЦцЧчШшЩщЪъЫыІіЬьЭэЮюЯя" latin_alph = "AaÄäBbVvGgĞğDdEeİoioJjZzİiİiKkQqLlMmNnÑñOoÖöPpRrSsTtUuŪūÜüFfHhHhCcÇ窺ŞçşçYyİiEeİuiuİaia" mult_mapping = "ЁёЩщЮюЯя" empty_mapping = "ЪъЬь" """ ϵ : Ukrainian letter for 'ё' ə : Russian utf-8 encoding for Kazakh 'ә' ó : 2016 Kazakh Latin adopted this instead of 'ö' ã : 1 occurrence in the dataset -- mapped to 'a' """ russian_alph = "ϵəóã" russian_counterpart = "ioäaöa" def create_dic(source_alph, target_alph, mult_mapping, empty_mapping): res = {} idx = 0 for i in range(len(source_alph)): l_s = source_alph[i] if l_s in mult_mapping: res[l_s] = target_alph[idx] + target_alph[idx+1] idx += 1 elif l_s in empty_mapping: res[l_s] = '' idx -= 1 else: res[l_s] = target_alph[idx] idx += 1 res['ϵ'] = 'io' res['ə'] = 'ä' res['ó'] = 'ö' 
res['ã'] = 'a' print(res) return res supp_alph = "IWwXx0123456789–«»—" def transliterate(source): output = "" tr_dict = create_dic(kazakh_alph, latin_alph, mult_mapping, empty_mapping) punc = string.punctuation white_spc = string.whitespace for c in source: if c in punc or c in white_spc: output += c elif c in latin_alph or c in supp_alph: output += c elif c in tr_dict: output += tr_dict[c] else: print(f"Transliteration Error: {c}") return output if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('input_file', type=str, nargs="+", help="Files to process") parser.add_argument('--output_dir', type=str, default=None, help="Directory to output results") args = parser.parse_args() tr_dict = create_dic(kazakh_alph, latin_alph, mult_mapping, empty_mapping) for filename in tqdm(args.input_file): if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) directory, basename = os.path.split(filename) output_name = os.path.join(args.output_dir, basename) if output_name.endswith(".xz"): output_name = output_name[:-3] output_name = output_name + ".trans" else: output_name = filename + ".trans" tqdm.write("Transliterating %s to %s" % (filename, output_name)) with open_read_text(filename) as f_in: data = f_in.read() with open(output_name, 'w') as f_out: punc = string.punctuation white_spc = string.whitespace for c in tqdm(data, leave=False): if c in tr_dict: f_out.write(tr_dict[c]) else: f_out.write(c) print("Process Completed Successfully!") ================================================ FILE: stanza/utils/lemma/__init__.py ================================================ ================================================ FILE: stanza/utils/lemma/count_ambiguous_lemmas.py ================================================ """ Read in a UD file, report any word/verb pairs which get lemmatized to different lemmas """ from collections import Counter, defaultdict import sys from stanza.utils.conll import CoNLL filename = sys.argv[1] 
print(filename) lemma_counters = defaultdict(Counter) doc = CoNLL.conll2doc(input_file=filename) for sentence in doc.sentences: for word in sentence.words: text = word.text upos = word.upos lemma = word.lemma lemma_counters[(text, upos)][lemma] += 1 keys = lemma_counters.keys() keys = sorted(keys, reverse=True, key=lambda x: sum(lemma_counters[x][y] for y in lemma_counters[x])) for text, upos in keys: if len(lemma_counters[(text, upos)]) > 1: print(text, upos, lemma_counters[(text, upos)]) ================================================ FILE: stanza/utils/max_mwt_length.py ================================================ import sys import json def max_mwt_length(filenames): max_len = 0 for filename in filenames: with open(filename) as f: d = json.load(f) max_len = max([max_len] + [len(" ".join(x[0][1])) for x in d]) return max_len if __name__ == '__main__': print(max_max_jlength(sys.argv[1:])) ================================================ FILE: stanza/utils/ner/__init__.py ================================================ ================================================ FILE: stanza/utils/ner/flair_ner_tag_dataset.py ================================================ """ Test a flair model on a 4 class dataset """ import argparse import json from flair.data import Sentence from flair.models import SequenceTagger from stanza.models.ner.utils import process_tags from stanza.models.ner.scorer import score_by_entity, score_by_token def test_file(eval_file, tagger): with open(eval_file) as fin: gold_doc = json.load(fin) gold_doc = [[(x['text'], x['ner']) for x in sentence] for sentence in gold_doc] gold_doc = process_tags(gold_doc, 'bioes') pred_doc = [] for gold_sentence in gold_doc: pred_sentence = [[x[0], 'O'] for x in gold_sentence] flair_sentence = Sentence(" ".join(x[0] for x in pred_sentence), use_tokenizer=False) tagger.predict(flair_sentence) for entity in flair_sentence.get_spans('ner'): tag = entity.tag tokens = entity.tokens start_idx = tokens[0].idx - 1 
end_idx = tokens[-1].idx if len(tokens) == 1: pred_sentence[start_idx][1] = "S-" + tag else: pred_sentence[start_idx][1] = "B-" + tag pred_sentence[end_idx - 1][1] = "E-" + tag for idx in range(start_idx+1, end_idx - 1): pred_sentence[idx][1] = "I-" + tag pred_doc.append(pred_sentence) pred_tags = [[x[1] for x in sentence] for sentence in pred_doc] gold_tags = [[x[1] for x in sentence] for sentence in gold_doc] print("RESULTS ON: %s" % eval_file) _, _, f_micro, _ = score_by_entity(pred_tags, gold_tags) score_by_token(pred_tags, gold_tags) return f_micro def main(): parser = argparse.ArgumentParser() parser.add_argument('--ner_model', type=str, default=None, help='Which NER model to test') parser.add_argument('filename', type=str, nargs='*', help='which files to test') args = parser.parse_args() if args.ner_model is None: ner_models = ["ner-fast", "ner", "ner-large"] else: ner_models = [args.ner_model] if not args.filename: args.filename = ["data/ner/en_conll03.test.json", "data/ner/en_worldwide-4class.test.json", "data/ner/en_worldwide-4class-africa.test.json", "data/ner/en_worldwide-4class-asia.test.json", "data/ner/en_worldwide-4class-indigenous.test.json", "data/ner/en_worldwide-4class-latam.test.json", "data/ner/en_worldwide-4class-middle_east.test.json"] print("Processing the files: %s" % ",".join(args.filename)) results = [] model_results = {} for ner_model in ner_models: model_results[ner_model] = [] # load tagger #tagger = SequenceTagger.load("ner-fast") print("-----------------------------") print("Running %s" % ner_model) print("-----------------------------") tagger = SequenceTagger.load(ner_model) for filename in args.filename: f_micro = test_file(filename, tagger) f_micro = "%.2f" % (f_micro * 100) results.append((ner_model, filename, f_micro)) model_results[ner_model].append(f_micro) for result in results: print(result) for model in model_results.keys(): result = [model] + model_results[model] print(" & ".join(result)) if __name__ == '__main__': 
main() ================================================ FILE: stanza/utils/ner/paying_annotators.py ================================================ import json import os def get_worker_subs(json_string): """ Gets the AWS worker IDs from the annotation file in output folder. Returns a list of the AWS worker subs """ subs = [] # json.loads() works on JSON strings, json.load() is for JSON files job_data = json.loads(json_string) for i in range(len(job_data["answers"])): subs.append(job_data["answers"][i]["workerMetadata"]["identityData"]["sub"]) return subs def track_tasks(input_path, worker_map=None): """ Takes a path to a folder containing the worker annotation metadata from AWS Sagemaker labeling job and a dictionary mapping AWS worker subs to their names or identification tags and returns a dictionary mapping the names/identification tags to the number of labeling tasks completed. If no worker map is provided, this function returns a dictionary mapping the worker "sub" fields to the number of tasks they completed. 
:param input_path: string of the path to the directory containing the worker annotation sub-directories :param worker_map: dictionary mapping AWS worker subs to the worker identifications :return: dictionary mapping worker identifications to the number of tasks completed """ tracker = {} res = {} for direc in os.listdir(input_path): subdir_path = os.path.join(input_path, direc) subdir = os.listdir(subdir_path) json_file_path = os.path.join(subdir_path, subdir[0]) with open(json_file_path) as json_file: json_string = json_file.read() subs = get_worker_subs(json_string) for sub in subs: tracker[sub] = tracker.get(sub, 0) + 1 if worker_map: for sub in tracker: worker = worker_map[sub] res[worker] = tracker[sub] return res return tracker def main(): # sample from completed labeling job print(track_tasks('..\\tests\\ner\\aws_labeling_copy', worker_map={ "7efc17ac-3397-4472-afe5-89184ad145d0": "Worker1", "afce8c28-969c-4e73-a20f-622ef122f585": "Worker2", "91f6236e-63c6-4a84-8fd6-1efbab6dedab": "Worker3", "6f202e93-e6b6-4e1d-8f07-0484b9a9093a": "Worker4", "2b674d33-f656-44b0-8f90-d70a1ab71ec2": "Worker5" } )) # sample from completed labeling job -- no worker map provided print(track_tasks('..\\tests\\ner\\aws_labeling_copy')) return if __name__ == "__main__": main() ================================================ FILE: stanza/utils/ner/spacy_ner_tag_dataset.py ================================================ """ Test a spacy model on a 4 class dataset """ import argparse import json import spacy from spacy.tokens import Doc from stanza.models.ner.utils import process_tags from stanza.models.ner.scorer import score_by_entity, score_by_token from stanza.utils.confusion import format_confusion from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide from stanza.utils.get_tqdm import get_tqdm tqdm = get_tqdm() """ Simplified classes used in the Worldwide dataset are: Date Facility Location Misc Money NORP Organization Person 
Product vs OntoNotes classes: CARDINAL DATE EVENT FAC GPE LANGUAGE LAW LOC MONEY NORP ORDINAL ORG PERCENT PERSON PRODUCT QUANTITY TIME WORK_OF_ART """ def test_file(eval_file, tagger, simplify): with open(eval_file) as fin: gold_doc = json.load(fin) gold_doc = [[(x['text'], x['ner']) for x in sentence] for sentence in gold_doc] gold_doc = process_tags(gold_doc, 'bioes') if simplify: for doc in gold_doc: for idx, word in enumerate(doc): if word[1] != "O": word = [word[0], simplify_ontonotes_to_worldwide(word[1])] doc[idx] = word ignore_tags = "Date,DATE" if simplify else None original_text = [[x[0] for x in gold_sentence] for gold_sentence in gold_doc] pred_doc = [] for sentence in tqdm(original_text): spacy_sentence = Doc(tagger.vocab, sentence) spacy_sentence = tagger(spacy_sentence) entities = ["O" if not token.ent_type_ else "%s-%s" % (token.ent_iob_, token.ent_type_) for token in spacy_sentence] if simplify: entities = [simplify_ontonotes_to_worldwide(x) for x in entities] pred_sentence = [[token.text, entity] for token, entity in zip(spacy_sentence, entities)] pred_doc.append(pred_sentence) pred_doc = process_tags(pred_doc, 'bioes') pred_tags = [[x[1] for x in sentence] for sentence in pred_doc] gold_tags = [[x[1] for x in sentence] for sentence in gold_doc] print("RESULTS ON: %s" % eval_file) _, _, f_micro, _ = score_by_entity(pred_tags, gold_tags, ignore_tags=ignore_tags) _, _, _, confusion = score_by_token(pred_tags, gold_tags, ignore_tags=ignore_tags) print("NER token confusion matrix:\n{}".format(format_confusion(confusion, hide_blank=True, transpose=True))) return f_micro def main(): parser = argparse.ArgumentParser() parser.add_argument('--ner_model', type=str, default=None, help='Which spacy model to test') parser.add_argument('filename', type=str, nargs='*', help='which files to test') parser.add_argument('--simplify', default=False, action='store_true', help='Simplify classes to the 8 class Worldwide model') args = parser.parse_args() if 
args.ner_model is None: ner_models = ['en_core_web_sm', 'en_core_web_trf'] else: ner_models = [args.ner_model] if not args.filename: args.filename = ["data/ner/en_ontonotes-8class.test.json", "data/ner/en_worldwide-8class.test.json", "data/ner/en_worldwide-8class-africa.test.json", "data/ner/en_worldwide-8class-asia.test.json", "data/ner/en_worldwide-8class-indigenous.test.json", "data/ner/en_worldwide-8class-latam.test.json", "data/ner/en_worldwide-8class-middle_east.test.json"] print("Processing the files: %s" % ",".join(args.filename)) results = [] model_results = {} for ner_model in ner_models: model_results[ner_model] = [] # load tagger print("-----------------------------") print("Running %s" % ner_model) print("-----------------------------") tagger = spacy.load(ner_model, disable=["tagger", "parser", "attribute_ruler", "lemmatizer"]) for filename in args.filename: f_micro = test_file(filename, tagger, args.simplify) f_micro = "%.2f" % (f_micro * 100) results.append((ner_model, filename, f_micro)) model_results[ner_model].append(f_micro) for result in results: print(result) for model in model_results.keys(): result = [model] + model_results[model] print(" & ".join(result)) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/pretrain/__init__.py ================================================ ================================================ FILE: stanza/utils/pretrain/compare_pretrains.py ================================================ import sys import numpy as np from stanza.models.common.pretrain import Pretrain pt1_filename = sys.argv[1] pt2_filename = sys.argv[2] pt1 = Pretrain(pt1_filename) pt2 = Pretrain(pt2_filename) vocab1 = pt1.vocab vocab2 = pt2.vocab common_words = [x for x in vocab1 if x in vocab2] print("%d shared words, out of %d in %s and %d in %s" % (len(common_words), len(vocab1), pt1_filename, len(vocab2), pt2_filename)) eps = 0.0001 total_norm = 0.0 total_close = 0 words_different = [] 
for word, idx in vocab1._unit2id.items(): if word not in vocab2: continue v1 = pt1.emb[idx] v2 = pt2.emb[pt2.vocab[word]] norm = np.linalg.norm(v1 - v2) if norm < eps: total_close += 1 else: total_norm += norm if len(words_different) < 10: words_different.append("|%s|" % word) #print(word, idx, pt2.vocab[word]) #print(v1) #print(v2) if total_close < len(common_words): avg_norm = total_norm / (len(common_words) - total_close) print("%d vectors were close. Average difference of the others: %f" % (total_close, avg_norm)) print("The first few different words were:\n %s" % "\n ".join(words_different)) else: print("All %d vectors were close!" % total_close) for word, idx in vocab1._unit2id.items(): if word not in vocab2: continue if pt2.vocab[word] != idx: break else: print("All indices are the same") ================================================ FILE: stanza/utils/select_backoff.py ================================================ import sys backoff_models = { "UD_Breton-KEB": "ga_idt", "UD_Czech-PUD": "cs_pdt", "UD_English-PUD": "en_ewt", "UD_Faroese-OFT": "nn_nynorsk", "UD_Finnish-PUD": "fi_tdt", "UD_Japanese-Modern": "ja_gsd", "UD_Naija-NSC": "en_ewt", "UD_Swedish-PUD": "sv_talbanken" } print(backoff_models[sys.argv[1]]) ================================================ FILE: stanza/utils/training/__init__.py ================================================ ================================================ FILE: stanza/utils/training/common.py ================================================ import argparse import glob import logging import os import pathlib import random import sys from enum import Enum try: from udtools.udeval import build_evaluation_table except ImportError: from udtools.src.udtools.udeval import build_evaluation_table from stanza.resources.default_packages import default_charlms, lemma_charlms, tokenizer_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS from stanza.resources.default_packages import no_pretrain_languages, 
pos_pretrains, depparse_pretrains, default_pretrains from stanza.models.common.constant import treebank_to_short_name from stanza.models.common.utils import ud_scores from stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError from stanza.utils.datasets import common import stanza.utils.default_paths as default_paths logger = logging.getLogger('stanza') class Mode(Enum): TRAIN = 1 SCORE_DEV = 2 SCORE_TEST = 3 SCORE_TRAIN = 4 class ArgumentParserWithExtraHelp(argparse.ArgumentParser): def __init__(self, sub_argparse, *args, **kwargs): super().__init__(*args, **kwargs) # forwards all unused arguments self.sub_argparse = sub_argparse def print_help(self, file=None): super().print_help(file=file) def format_help(self): help_text = super().format_help() if self.sub_argparse is not None: sub_text = self.sub_argparse.format_help().split("\n") first_line = -1 for line_idx, line in enumerate(sub_text): if line.strip().startswith("usage:"): first_line = line_idx elif first_line >= 0 and not line.strip(): first_line = line_idx break help_text = help_text + "\n\nmodel arguments:" + "\n".join(sub_text[first_line:]) return help_text def build_argparse(sub_argparse=None): parser = ArgumentParserWithExtraHelp(sub_argparse=sub_argparse, formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--save_output', dest='save_output', default=False, action='store_true', help="Save output - default is to use a temp directory.") parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. 
Use all_ud or ud_all for all UD treebanks') parser.add_argument('--train', dest='mode', default=Mode.TRAIN, action='store_const', const=Mode.TRAIN, help='Run in train mode') parser.add_argument('--score_dev', dest='mode', action='store_const', const=Mode.SCORE_DEV, help='Score the dev set') parser.add_argument('--score_test', dest='mode', action='store_const', const=Mode.SCORE_TEST, help='Score the test set') parser.add_argument('--score_train', dest='mode', action='store_const', const=Mode.SCORE_TRAIN, help='Score the train set as a test set. Currently only implemented for some models') # These arguments need to be here so we can identify if the model already exists in the user-specified home # TODO: when all of the model scripts handle their own names, can eliminate this argument parser.add_argument('--save_dir', type=str, default=None, help="Root dir for saving models. If set, will override the model's default.") parser.add_argument('--save_name', type=str, default=None, help="Base name for saving models. If set, will override the model's default.") parser.add_argument('--charlm_only', action='store_true', default=False, help='When asking for ud_all, filter the ones which have charlms') parser.add_argument('--transformer_only', action='store_true', default=False, help='When asking for ud_all, filter the ones for languages where we have transformers') parser.add_argument('--force', dest='force', action='store_true', default=False, help='Retrain existing models') return parser def add_charlm_args(parser): parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. 
Set to None to turn off charlm for languages with a default charlm') parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package") def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None, args=None): """ A main program for each of the run_xyz scripts It collects the arguments and runs the main method for each dataset provided. It also tries to look for an existing model and not overwrite it unless --force is provided model_name can be a callable expecting the args - the charlm, for example, needs this feature, since it makes both forward and backward models """ if args is None: logger.info("Training program called with:\n" + " ".join(sys.argv)) args = sys.argv[1:] else: logger.info("Training program called with:\n" + " ".join(args)) paths = default_paths.get_default_paths() parser = build_argparse(sub_argparse) if add_specific_args is not None: add_specific_args(parser) if '--extra_args' in sys.argv: idx = sys.argv.index('--extra_args') extra_args = sys.argv[idx+1:] command_args = parser.parse_args(sys.argv[:idx]) else: command_args, extra_args = parser.parse_known_args(args=args) # Pass this through to the underlying model as well as use it here # we don't put --save_name here for the awkward situation of # --save_name being specified for an invocation with multiple treebanks if command_args.save_dir: extra_args.extend(["--save_dir", command_args.save_dir]) # if --no_seed is added to the args, we actually want to pick a seed here # that way, save file names will be consistent... 
# otherwise, it might try to use different save names when using the # train and dev sets, if the random seed is used as part of the save name while '--no_seed' in extra_args: idx = extra_args.index('--no_seed') random_seed = random.randint(0, 1000000000) logger.info("Using random seed %d", random_seed) extra_args[idx:idx+1] = ['--seed', str(random_seed)] if callable(model_name): model_name = model_name(command_args) mode = command_args.mode treebanks = [] for treebank in command_args.treebanks: # this is a really annoying typo to make if you copy/paste a # UD directory name on the cluster and your job dies 30s after # being queued for an hour if treebank.endswith("/"): treebank = treebank[:-1] if treebank.lower() in ('ud_all', 'all_ud'): ud_treebanks = common.get_ud_treebanks(paths["UDBASE"]) if choose_charlm_method is not None and command_args.charlm_only: logger.info("Filtering ud_all treebanks to only those which can use charlm for this model") ud_treebanks = [x for x in ud_treebanks if choose_charlm_method(*treebank_to_short_name(x).split("_", 1), 'default') is not None] if command_args.transformer_only: logger.info("Filtering ud_all treebanks to only those which can use a transformer for this model") ud_treebanks = [x for x in ud_treebanks if treebank_to_short_name(x).split("_")[0] in TRANSFORMERS] logger.info("Expanding %s to %s", treebank, " ".join(ud_treebanks)) treebanks.extend(ud_treebanks) else: treebanks.append(treebank) for treebank_idx, treebank in enumerate(treebanks): if treebank_idx > 0: logger.info("=========================================") short_name = treebank_to_short_name(treebank) logger.debug("%s: %s" % (treebank, short_name)) save_name_args = [] if model_name != 'ete': # ete is several models at once, so we don't set --save_name # theoretically we could handle a parametrized save_name if command_args.save_name: save_name = command_args.save_name # if there's more than 1 treebank, we can't save them all to this save_name # we have to 
override that value for each treebank if len(treebanks) > 1: save_name_dir, save_name_filename = os.path.split(save_name) save_name_filename = "%s_%s" % (short_name, save_name_filename) save_name = os.path.join(save_name_dir, save_name_filename) logger.info("Save file for %s model for %s: %s", short_name, treebank, save_name) save_name_args = ['--save_name', save_name] # some run scripts can build the model filename # in order to check for models that are already created elif build_model_filename is None: save_name = "%s_%s.pt" % (short_name, model_name) logger.info("Save file for %s model: %s", short_name, save_name) save_name_args = ['--save_name', save_name] else: save_name_args = [] if mode == Mode.TRAIN and not command_args.force: if build_model_filename is not None: model_path = build_model_filename(paths, short_name, command_args, extra_args) elif command_args.save_dir: model_path = os.path.join(command_args.save_dir, save_name) else: save_dir = os.path.join("saved_models", model_dir) save_name_args.extend(["--save_dir", save_dir]) model_path = os.path.join(save_dir, save_name) if model_path is None: # this can happen with the identity lemmatizer, for example pass elif os.path.exists(model_path): logger.info("%s: %s exists, skipping!" % (treebank, model_path)) continue else: logger.info("%s: %s does not exist, training new model" % (treebank, model_path)) run_treebank(mode, paths, treebank, short_name, command_args, extra_args + save_name_args) def run_eval_script(gold_conllu_file, system_conllu_file, evals=None): """ Wrapper for lemma scorer. 
"""
    # NOTE(review): despite the "lemma scorer" summary, this is the generic
    # scorer: the run_eval_script_* helpers below use it for tokenizer, MWT,
    # POS and depparse results as well.
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    if evals is None:
        # no metric list given: return the full verbose evaluation table
        return ud_eval.build_evaluation_table(evaluation, verbose=True, counts=False, enhanced=False)
    else:
        # two rows: the requested metric names, then their F1 scores as percentages
        results = [evaluation[key].f1 for key in evals]
        # each column is at least 5 characters wide, or as wide as the longest metric name
        max_len = max(5, max(len(e) for e in evals))
        evals_string = " ".join(("{:>%d}" % max_len).format(e) for e in evals)
        results_string = " ".join(("{:%d.2f}" % max_len).format(100 * x) for x in results)
        return evals_string + "\n" + results_string


def run_eval_script_tokens(eval_gold, eval_pred):
    """Score tokenization: the Tokens, Sentences and Words F1."""
    return run_eval_script(eval_gold, eval_pred, evals=["Tokens", "Sentences", "Words"])


def run_eval_script_mwt(eval_gold, eval_pred):
    """Score MWT expansion: the Words F1."""
    return run_eval_script(eval_gold, eval_pred, evals=["Words"])


def run_eval_script_pos(eval_gold, eval_pred):
    """Score tagging: the UPOS, XPOS, UFeats and AllTags F1."""
    return run_eval_script(eval_gold, eval_pred, evals=["UPOS", "XPOS", "UFeats", "AllTags"])


def run_eval_script_depparse(eval_gold, eval_pred):
    """Score dependency parsing: the UAS, LAS, CLAS, MLAS and BLEX F1."""
    return run_eval_script(eval_gold, eval_pred, evals=["UAS", "LAS", "CLAS", "MLAS", "BLEX"])


def find_wordvec_pretrain(language, default_pretrains, dataset_pretrains=None, dataset=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Return the path of the word vector pretrain (.pt) to use for a language.

    Search order:
      1. the default pretrain for `language` from `default_pretrains`,
         overridden by `dataset_pretrains[language][dataset]` when both
         dataset arguments are given.  It is downloaded into `model_dir`
         if it is not already there.
      2. otherwise, the single .pt file in {model_dir}/{language}/pretrain/,
         downloading the default package for the language if there is none.

    Raises FileNotFoundError if no pretrain can be found, or if step 2
    finds more than one candidate.
    """
    # try to get the default pretrain for the language,
    # but allow the package specific value to override it if that is set
    default_pt = default_pretrains.get(language, None)
    if dataset is not None and dataset_pretrains is not None:
        default_pt = dataset_pretrains.get(language, {}).get(dataset, default_pt)

    if default_pt is not None:
        default_pt_path = '{}/{}/pretrain/{}.pt'.format(model_dir, language, default_pt)
        if not os.path.exists(default_pt_path):
            logger.info("Default pretrain should be {} Attempting to download".format(default_pt_path))
            try:
                download(lang=language, package=None, processors={"pretrain": default_pt}, model_dir=model_dir)
            except UnknownLanguageError:
                # if there's a pretrain in the directory, hiding this
                # error will let us find that pretrain later
                pass
        if os.path.exists(default_pt_path):
            if dataset is not None and dataset_pretrains is not None and language in dataset_pretrains and dataset in dataset_pretrains[language]:
                logger.info(f"Using default pretrain for {language}:{dataset}, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file")
            else:
                logger.info(f"Using default pretrain for language {language}, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file")
            return default_pt_path

    # no usable default: fall back to the one pretrain in the language's directory
    pretrain_path = '{}/{}/pretrain/*.pt'.format(model_dir, language)
    pretrains = glob.glob(pretrain_path)
    if len(pretrains) == 0:
        # we already tried to download the default pretrain once
        # and it didn't work. maybe the default language package
        # will have something?
        logger.warning(f"Cannot figure out which pretrain to use for '{language}'. Will download the default package and hope for the best")
        try:
            download(lang=language, model_dir=model_dir)
        except UnknownLanguageError as e:
            # this is a very unusual situation
            # basically, there was a language which we started to add
            # to the resources, but then didn't release the models
            # as part of resources.json
            raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path} No pretrains in the system for this language. Please prepare an embedding as a .pt and use --wordvec_pretrain_file to specify a .pt file to use") from e
        pretrains = glob.glob(pretrain_path)
    if len(pretrains) == 0:
        raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path} Try 'stanza.download(\"{language}\")' to get a default pretrain or use --wordvec_pretrain_file to specify a .pt file to use")
    if len(pretrains) > 1:
        # ambiguous: refuse to guess between several pretrains
        raise FileNotFoundError(f"Too many pretrains to choose from in {pretrain_path} Must specify an exact path to a --wordvec_pretrain_file")
    pt = pretrains[0]
    logger.info(f"Using pretrain found in {pt} To use a different pretrain, specify --wordvec_pretrain_file")
    return pt


def choose_depparse_pretrain(language, dataset):
    """
    Return the depparse pretrain path for language/dataset, or None for
    languages listed in no_pretrain_languages.
    """
    if language in no_pretrain_languages:
        return None
    return find_wordvec_pretrain(language, default_pretrains, depparse_pretrains, dataset)


def find_charlm_file(direction, language, charlm, model_dir=DEFAULT_MODEL_DIR):
    """
    Return the path to the forward or backward charlm if it exists for the given package

    If we can figure out the package, but can't find it anywhere, we try to download it

    Locations are checked in this order:
      saved_models/charlm/{language}_{charlm}_{direction}_charlm.pt
      {model_dir}/{language}/{direction}_charlm/{charlm}.pt
    Raises FileNotFoundError if neither exists after the download attempt.
    """
    saved_path = 'saved_models/charlm/{}_{}_{}_charlm.pt'.format(language, charlm, direction)
    if os.path.exists(saved_path):
        logger.info(f'Using model {saved_path} for {direction} charlm')
        return saved_path

    resource_path = '{}/{}/{}_charlm/{}.pt'.format(model_dir, language, direction, charlm)
    if os.path.exists(resource_path):
        logger.info(f'Using model {resource_path} for {direction} charlm')
        return resource_path

    try:
        download(lang=language, package=None, processors={f"{direction}_charlm": charlm}, model_dir=model_dir)
        if os.path.exists(resource_path):
            logger.info(f'Downloaded model, using model {resource_path} for {direction} charlm')
            return resource_path
    except ValueError as e:
        raise FileNotFoundError(f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work") from e

    # the download call returned, but the expected file still is not there
    raise FileNotFoundError(f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work")


def build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR, use_backward_model=True):
    """
    If specified, return forward and backward charlm args

    Returns [] when charlm is falsy (e.g. None or "").  Otherwise returns
    --charlm_forward_file, plus --charlm_backward_file unless
    use_backward_model is False.  With base_args=True, --charlm and
    --charlm_shorthand {language}_{charlm} are prepended.

    If the requested charlm is not found and its name starts with
    "{language}_", the name without that prefix is tried as well.
    Raises FileNotFoundError if neither name can be found.
    """
    if charlm:
        try:
            forward = find_charlm_file('forward', language, charlm, model_dir=model_dir)
            if use_backward_model:
                backward = find_charlm_file('backward', language, charlm, model_dir=model_dir)
        except FileNotFoundError as e:
            # if we couldn't find sd_isra when training an SD model,
            # for example, but isra exists, we try to download the
            # shorter model name
            if charlm.startswith(language + "_"):
                short_charlm = charlm[len(language)+1:]
                try:
                    forward = find_charlm_file('forward', language, short_charlm, model_dir=model_dir)
                    if use_backward_model:
                        backward = find_charlm_file('backward', language, short_charlm, model_dir=model_dir)
                except FileNotFoundError as e2:
                    raise FileNotFoundError("Tried to find charlm %s, which doesn't exist. Also tried %s, but didn't find that either" % (charlm, short_charlm)) from e
                logger.warning("Was asked to find charlm %s, which does not exist. Did find %s though", charlm, short_charlm)
            else:
                raise

        char_args = ['--charlm_forward_file', forward]
        if use_backward_model:
            char_args += ['--charlm_backward_file', backward]
        if not base_args:
            return char_args
        # NOTE: the shorthand keeps the originally requested charlm name,
        # even when the shorter fallback name was the one actually found
        return ['--charlm', '--charlm_shorthand', f'{language}_{charlm}'] + char_args

    return []


def choose_charlm(language, dataset, charlm, language_charlms, dataset_charlms):
    """
    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm

    Any other value is returned unchanged.  For "default", an entry for this
    dataset in dataset_charlms[language] wins, even if it is "" or None;
    otherwise language_charlms[language] is used, or None if that is unset
    or empty.
    """
    default_charlm = language_charlms.get(language, None)
    specific_charlm = dataset_charlms.get(language, {}).get(dataset, None)

    if charlm is None:
        return None
    elif charlm != "default":
        return charlm
    elif dataset in dataset_charlms.get(language, {}):
        # this way, a "" or None result gets honored
        # thus treating "not in the map" as a way for dataset_charlms to signal to use the default
        return specific_charlm
    elif default_charlm:
        return default_charlm
    else:
        return None


def choose_pos_charlm(short_language, dataset, charlm):
    """
    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm

    Dataset-specific defaults come from pos_charlms.
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, pos_charlms)


def choose_depparse_charlm(short_language, dataset, charlm):
    """
    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm

    Dataset-specific defaults come from depparse_charlms.
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, depparse_charlms)


def choose_lemma_charlm(short_language, dataset, charlm):
    """
    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm

    Dataset-specific defaults come from lemma_charlms.
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, lemma_charlms)


def choose_tokenizer_charlm(short_language, dataset, charlm):
    """
    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm

    Dataset-specific defaults come from tokenizer_charlms.
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, tokenizer_charlms)


def choose_transformer(short_language, command_args, extra_args, warn=True, layers=False):
    """
    Choose a transformer using the default options for this language

    Returns [] unless command_args.use_bert is set and --bert_model is not
    already in extra_args.  In that case returns --bert_model from
    TRANSFORMERS and, if layers is True, --bert_hidden_layers from
    TRANSFORMER_LAYERS (unless that flag is already in extra_args).
    If the language has no default transformer, an error is logged
    (when warn is True) and [] is returned.
    """
    bert_args = []
    if command_args is not None and command_args.use_bert and '--bert_model' not in extra_args:
        if short_language in TRANSFORMERS:
            bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]
            if layers and short_language in TRANSFORMER_LAYERS and '--bert_hidden_layers' not in extra_args:
                bert_args.extend(['--bert_hidden_layers', str(TRANSFORMER_LAYERS.get(short_language))])
        elif warn:
            logger.error("Transformer requested, but no default transformer for %s Specify one using --bert_model" % short_language)
    return bert_args


def build_pos_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """Resolve the POS charlm (see choose_pos_charlm) and build its command line args."""
    charlm = choose_pos_charlm(short_language, dataset, charlm)
    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)
    return charlm_args


def build_lemma_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """Resolve the lemmatizer charlm (see choose_lemma_charlm) and build its command line args."""
    charlm = choose_lemma_charlm(short_language, dataset, charlm)
    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)
    return charlm_args


def build_depparse_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """Resolve the depparse charlm (see choose_depparse_charlm) and build its command line args."""
    charlm = choose_depparse_charlm(short_language, dataset, charlm)
    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)
    return charlm_args


def build_tokenizer_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """
    Resolve the tokenizer charlm (see choose_tokenizer_charlm) and build its
    command line args.  Only the forward charlm is passed to the tokenizer.
    """
    charlm = choose_tokenizer_charlm(short_language, dataset, charlm)
    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir, use_backward_model=False)
    return charlm_args


def build_wordvec_args(short_language, dataset, extra_args, task_pretrains):
    """
    Return the word vector args for one task.

    Returns [] if extra_args already has --wordvec_pretrain_file or
    --no_pretrain, ["--no_pretrain"] for languages in no_pretrain_languages,
    and otherwise --wordvec_pretrain_file with the pretrain chosen by
    find_wordvec_pretrain.  task_pretrains only overrides the language
    default when it has an entry for this exact dataset.
    """
    if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args:
        return []

    if short_language in no_pretrain_languages:
        # we couldn't find word vectors for a few languages...:
        # coptic, naija, old russian, turkish german, swedish sign language
        logger.warning("No known word vectors for language {} If those vectors can be found, please update the training scripts.".format(short_language))
        return ["--no_pretrain"]
    else:
        if short_language in task_pretrains and dataset in task_pretrains[short_language]:
            dataset_pretrains = task_pretrains
        else:
            dataset_pretrains = {}
        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset)
        return ["--wordvec_pretrain_file", wordvec_pretrain]


def build_pos_wordvec_args(short_language, dataset, extra_args):
    """Word vector args for the POS tagger, using pos_pretrains for dataset-specific overrides."""
    return build_wordvec_args(short_language, dataset, extra_args, pos_pretrains)


def build_depparse_wordvec_args(short_language, dataset, extra_args):
    """Word vector args for the depparser, using depparse_pretrains for dataset-specific overrides."""
    return build_wordvec_args(short_language, dataset, extra_args, depparse_pretrains)


================================================
FILE: stanza/utils/training/compose_ete_results.py
================================================
"""
Turn the ETE results into markdown

Parses blocks like this from the model eval script

2022-01-14 01:23:34 INFO: End to end results for af_afribooms models on af_afribooms test data:
Metric     | Precision |    Recall |  F1 Score | AligndAcc
-----------+-----------+-----------+-----------+-----------
Tokens     |     99.93 |     99.92 |     99.93 |
Sentences  |    100.00 |    100.00 |    100.00 |
Words      |     99.93 |     99.92 |     99.93 |
UPOS       |     97.97 |     97.96 |     97.97 |     98.04
XPOS       |     93.98 |     93.97 |     93.97 |     94.04
UFeats     |     97.23 |     97.22 |     97.22 |     97.29
AllTags    |     93.89 |     93.88 |     93.88 |     93.95
Lemmas     |     97.40 |     97.39 |     97.39 |     97.46
UAS        |     87.39 |     87.38 |     87.38 |     87.45
LAS        |     83.57 |     83.56 |     83.57 |     83.63
CLAS       |     76.88 |     76.45 |     76.66 |     76.52
MLAS       |     72.28 |     71.87 |     72.07 |     71.94
BLEX       |     73.20 |     72.79 |     73.00 |     72.86

Turns them into a markdown table.

Included is an attempt to mark the default packages with a green check.
""" import argparse from stanza.models.common.constant import pretty_langcode_to_lang from stanza.models.common.short_name_to_treebank import short_name_to_treebank from stanza.utils.training.run_ete import RESULTS_STRING from stanza.resources.default_packages import default_treebanks EXPECTED_ORDER = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"] parser = argparse.ArgumentParser() parser.add_argument("filenames", type=str, nargs="+", help="Which file(s) to read") args = parser.parse_args() lines = [] for filename in args.filenames: with open(filename) as fin: lines.extend(fin.readlines()) blocks = [] index = 0 while index < len(lines): line = lines[index] if line.find(RESULTS_STRING) < 0: index = index + 1 continue line = line[line.find(RESULTS_STRING) + len(RESULTS_STRING):].strip() short_name = line.split()[0] # skip the header of the expected output index = index + 1 line = lines[index] pieces = line.split("|") assert pieces[0].strip() == 'Metric', "output format changed?" assert pieces[3].strip() == 'F1 Score', "output format changed?" index = index + 1 line = lines[index] assert line.startswith("-----"), "output format changed?" index = index + 1 block = lines[index:index+13] assert len(block) == 13 index = index + 13 block = [x.split("|") for x in block] assert all(x[0].strip() == y for x, y in zip(block, EXPECTED_ORDER)), "output format changed?" 
lcode, short_dataset = short_name.split("_", 1) language = pretty_langcode_to_lang(lcode) treebank = short_name_to_treebank(short_name) long_dataset = treebank.split("-")[-1] checkmark = "" if default_treebanks[lcode] == short_dataset: checkmark = '' block = [language, "[%s](%s)" % (long_dataset, "https://github.com/UniversalDependencies/%s" % treebank), lcode, checkmark] + [x[3].strip() for x in block] blocks.append(block) PREFIX = ["​Macro Avg", "​", "​", ""] avg = [sum(float(x[i]) for x in blocks) / len(blocks) for i in range(len(PREFIX), len(EXPECTED_ORDER) + len(PREFIX))] avg = PREFIX + ["%.2f" % x for x in avg] blocks = sorted(blocks) blocks = [avg] + blocks chart = ["|%s|" % " | ".join(x) for x in blocks] for line in chart: print(line) ================================================ FILE: stanza/utils/training/remove_constituency_optimizer.py ================================================ """Saved a huge, bloated model with an optimizer? Use this to remove it, greatly shrinking the model size This tries to find reasonable defaults for word vectors and charlm (which need to be loaded so that the model knows the matrix sizes) so ideally all that needs to be run is python3 stanza/utils/training/remove_constituency_optimizer.py python3 stanza/utils/training/remove_constituency_optimizer.py da_arboretum ... 
This can also be used to load and save models as part of an update to the serialized format """ import argparse import logging import os from stanza.models import constituency_parser from stanza.models.common.constant import treebank_to_short_name from stanza.resources.default_packages import default_charlms, default_pretrains from stanza.utils.training import common logger = logging.getLogger('stanza') def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm') parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package") parser.add_argument('--load_dir', type=str, default="saved_models/constituency", help="Root dir for getting the models to resave.") parser.add_argument('--save_dir', type=str, default="resaved_models/constituency", help="Root dir for resaving the models.") parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. 
Use all_ud or ud_all for all UD treebanks') args = parser.parse_args() return args def main(): """ For each of the models specified, load and resave the model The resaved model will have the optimizer removed """ args = parse_args() os.makedirs(args.save_dir, exist_ok=True) for treebank in args.treebanks: logger.info("PROCESSING %s", treebank) short_name = treebank_to_short_name(treebank) language, dataset = short_name.split("_", maxsplit=1) logger.info("%s: %s %s", short_name, language, dataset) if not args.wordvec_pretrain_file: # will throw an error if the pretrain can't be found wordvec_pretrain = common.find_wordvec_pretrain(language, default_pretrains) wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] else: wordvec_args = [] charlm = common.choose_charlm(language, dataset, args.charlm, default_charlms, {}) charlm_args = common.build_charlm_args(language, charlm, base_args=False) base_name = '{}_constituency.pt'.format(short_name) load_name = os.path.join(args.load_dir, base_name) save_name = os.path.join(args.save_dir, base_name) resave_args = ['--mode', 'remove_optimizer', '--load_name', load_name, '--save_name', save_name, '--save_dir', ".", '--shorthand', short_name] resave_args = resave_args + wordvec_args + charlm_args constituency_parser.main(resave_args) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/training/run_charlm.py ================================================ """ Trains or scores a charlm model. 
""" import logging import os from stanza.models import charlm from stanza.utils.training import common from stanza.utils.training.common import Mode logger = logging.getLogger('stanza') def add_charlm_args(parser): """ Extra args for the charlm: forward/backward """ parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model") parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help="Train a forward language model") parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help="Train a backward language model") def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language, dataset_name = short_name.split("_", 1) train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train") dev_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "dev.txt") if not os.path.exists(dev_file) and os.path.exists(dev_file + ".xz"): dev_file = dev_file + ".xz" test_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "test.txt") if not os.path.exists(test_file) and os.path.exists(test_file + ".xz"): test_file = test_file + ".xz" # python -m stanza.models.charlm --train_dir $train_dir --eval_file $dev_file \ # --direction $direction --shorthand $short --mode train $args # python -m stanza.models.charlm --eval_file $dev_file \ # --direction $direction --shorthand $short --mode predict $args # python -m stanza.models.charlm --eval_file $test_file \ # --direction $direction --shorthand $short --mode predict $args direction = command_args.direction default_args = ['--%s' % direction, '--shorthand', short_name] if mode == Mode.TRAIN: train_args = ['--mode', 'train'] if '--train_dir' not in extra_args: train_args += ['--train_dir', train_dir] if '--eval_file' not in extra_args: train_args += ['--eval_file', dev_file] train_args = train_args + 
default_args + extra_args logger.info("Running train step with args: %s", train_args) charlm.main(train_args) if mode == Mode.SCORE_DEV: dev_args = ['--mode', 'predict'] if '--eval_file' not in extra_args: dev_args += ['--eval_file', dev_file] dev_args = dev_args + default_args + extra_args logger.info("Running dev step with args: %s", dev_args) charlm.main(dev_args) if mode == Mode.SCORE_TEST: test_args = ['--mode', 'predict'] if '--eval_file' not in extra_args: test_args += ['--eval_file', test_file] test_args = test_args + default_args + extra_args logger.info("Running test step with args: %s", test_args) charlm.main(test_args) def get_model_name(args): """ The charlm saves forward and backward charlms to the same dir, but with different filenames """ return "%s_charlm" % args.direction def main(): common.main(run_treebank, "charlm", get_model_name, add_charlm_args, charlm.build_argparse()) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_constituency.py ================================================ """ Trains or scores a constituency model. Currently a suuuuper preliminary script. 
Example of how to run on multiple parsers at the same time on the Stanford workqueue: for i in `echo 1000 1001 1002 1003 1004`; do nlprun -d a6000 "python3 stanza/utils/training/run_constituency.py vi_vlsp23 --use_bert --stage1_bert_finetun --save_name vi_vlsp23_$i.pt --seed $i --epochs 200 --force" -o vi_vlsp23_$i.out; done """ import logging import os from stanza.models import constituency_parser from stanza.models.constituency.retagging import RETAG_METHOD from stanza.utils.datasets.constituency import prepare_con_dataset from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain from stanza.resources.default_packages import default_charlms, default_pretrains logger = logging.getLogger('stanza') def add_constituency_args(parser): add_charlm_args(parser) parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') parser.add_argument('--parse_text', dest='mode', action='store_const', const="parse_text", help='Parse a text file') def build_wordvec_args(short_language, dataset, extra_args): if '--wordvec_pretrain_file' not in extra_args: # will throw an error if the pretrain can't be found wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains) wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] else: wordvec_args = [] return wordvec_args def build_default_args(paths, short_language, dataset, command_args, extra_args): if short_language in RETAG_METHOD: retag_args = ["--retag_method", RETAG_METHOD[short_language]] else: retag_args = [] wordvec_args = build_wordvec_args(short_language, dataset, extra_args) charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {}) charlm_args = build_charlm_args(short_language, charlm, base_args=False) bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=True, layers=True) 
default_args = retag_args + wordvec_args + charlm_args + bert_args return default_args def build_model_filename(paths, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) train_args = ["--shorthand", short_name, "--mode", "train"] train_args = train_args + default_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: train_args.extend(["--save_dir", command_args.save_dir]) args = constituency_parser.parse_args(train_args) save_name = constituency_parser.build_model_filename(args) return save_name def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): constituency_dir = paths["CONSTITUENCY_DATA_DIR"] short_language, dataset = short_name.split("_") train_file = os.path.join(constituency_dir, f"{short_name}_train.mrg") dev_file = os.path.join(constituency_dir, f"{short_name}_dev.mrg") test_file = os.path.join(constituency_dir, f"{short_name}_test.mrg") if not os.path.exists(train_file) or not os.path.exists(dev_file) or not os.path.exists(test_file): logger.warning(f"The data for {short_name} is missing or incomplete. Attempting to rebuild...") try: prepare_con_dataset.main(short_name) except: logger.error(f"Unable to build the data. 
Please correctly build the files in {train_file}, {dev_file}, {test_file} and then try again.") raise default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) if mode == Mode.TRAIN: train_args = ['--train_file', train_file, '--eval_file', dev_file, '--shorthand', short_name, '--mode', 'train'] train_args = train_args + default_args + extra_args logger.info("Running train step with args: {}".format(train_args)) constituency_parser.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ['--eval_file', dev_file, '--shorthand', short_name, '--mode', 'predict'] dev_args = dev_args + default_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) constituency_parser.main(dev_args) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ['--eval_file', test_file, '--shorthand', short_name, '--mode', 'predict'] test_args = test_args + default_args + extra_args logger.info("Running test step with args: {}".format(test_args)) constituency_parser.main(test_args) if mode == "parse_text": text_args = ['--shorthand', short_name, '--mode', 'parse_text'] text_args = text_args + default_args + extra_args logger.info("Processing text with args: {}".format(text_args)) constituency_parser.main(text_args) def main(): common.main(run_treebank, "constituency", "constituency", add_constituency_args, sub_argparse=constituency_parser.build_argparse(), build_model_filename=build_model_filename) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_depparse.py ================================================ import io import logging import os from stanza.models import parser from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer from stanza.utils.training.common import build_depparse_wordvec_args from 
stanza.resources.default_packages import default_charlms, depparse_charlms logger = logging.getLogger('stanza') def add_depparse_args(parser): add_charlm_args(parser) parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') # TODO: refactor with run_pos def build_model_filename(paths, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) # TODO: can avoid downloading the charlm at this point, since we # might not even be training charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm) bert_args = choose_transformer(short_language, command_args, extra_args, warn=False) train_args = ["--shorthand", short_name, "--mode", "train"] # TODO: also, this downloads the wordvec, which we might not want to do yet train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: train_args.extend(["--save_dir", command_args.save_dir]) args = parser.parse_args(train_args) save_name = parser.model_file_name(args) return save_name def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) # TODO: refactor these blocks? 
depparse_dir = paths["DEPPARSE_DATA_DIR"] train_file = f"{depparse_dir}/{short_name}.train.in.conllu" dev_in_file = f"{depparse_dir}/{short_name}.dev.in.conllu" dev_pred_file = f"{depparse_dir}/{short_name}.dev.pred.conllu" test_in_file = f"{depparse_dir}/{short_name}.test.in.conllu" test_pred_file = f"{depparse_dir}/{short_name}.test.pred.conllu" eval_file = None if '--eval_file' in extra_args: eval_file = extra_args[extra_args.index('--eval_file') + 1] charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm) bert_args = choose_transformer(short_language, command_args, extra_args) if mode == Mode.TRAIN: zip_train_file = os.path.splitext(train_file)[0] + ".zip" if os.path.exists(train_file) and os.path.exists(zip_train_file): logger.error("POS TRAIN FILE %s and %s both exist... this is very confusing, skipping %s" % (train_file, zip_train_file, short_name)) return if os.path.exists(zip_train_file): train_file = zip_train_file if not os.path.exists(train_file): logger.error("TRAIN FILE NOT FOUND: %s ... 
skipping" % train_file) return # some languages need reduced batch size if short_name == 'de_hdt': # 'UD_German-HDT' batch_size = "1300" elif short_name in ('hr_set', 'fi_tdt', 'ru_taiga', 'cs_cltt', 'gl_treegal', 'lv_lvtb', 'ro_simonero'): # 'UD_Croatian-SET', 'UD_Finnish-TDT', 'UD_Russian-Taiga', # 'UD_Czech-CLTT', 'UD_Galician-TreeGal', 'UD_Latvian-LVTB' 'Romanian-SiMoNERo' batch_size = "3000" else: batch_size = "5000" train_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--train_file", train_file, "--eval_file", eval_file if eval_file else dev_in_file, "--batch_size", batch_size, "--lang", short_language, "--shorthand", short_name, "--mode", "train"] train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args train_args = train_args + extra_args logger.info("Running train depparse for {} with args {}".format(treebank, train_args)) parser.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--eval_file", eval_file if eval_file else dev_in_file, "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] if command_args.save_output: dev_args.extend(["--output_file", dev_pred_file]) dev_args = dev_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args dev_args = dev_args + extra_args logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args)) _, dev_doc = parser.main(dev_args) if '--no_gold_labels' not in extra_args: if not command_args.save_output: dev_pred_file = "{:C}\n\n".format(dev_doc) dev_pred_file = io.StringIO(dev_pred_file) results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file) logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) if command_args.save_output: logger.info("Output saved to %s", dev_pred_file) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ["--wordvec_dir", 
paths["WORDVEC_DIR"], "--eval_file", eval_file if eval_file else test_in_file, "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] if command_args.save_output: test_args.extend(["--output_file", test_pred_file]) test_args = test_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args test_args = test_args + extra_args logger.info("Running test depparse for {} with args {}".format(treebank, test_args)) _, test_doc = parser.main(test_args) if '--no_gold_labels' not in extra_args: if not command_args.save_output: test_pred_file = "{:C}\n\n".format(test_doc) test_pred_file = io.StringIO(test_pred_file) results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file) logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) if command_args.save_output: logger.info("Output saved to %s", test_pred_file) def main(): common.main(run_treebank, "depparse", "parser", add_depparse_args, sub_argparse=parser.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_depparse_charlm) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_ete.py ================================================ """ Runs a pipeline end-to-end, reports conll scores. 
For example, you can do python3 stanza/utils/training/run_ete.py it_isdt --score_test You can run on all models at once: python3 stanza/utils/training/run_ete.py ud_all --score_test You can also run one model on a different model's data: python3 stanza/utils/training/run_ete.py it_isdt --score_dev --test_data it_vit python3 stanza/utils/training/run_ete.py it_isdt --score_test --test_data it_vit Running multiple models with a --test_data flag will run them all on the same data: python3 stanza/utils/training/run_ete.py it_combined it_isdt it_vit --score_test --test_data it_vit If run with no dataset arguments, then the dataset used is the train data, which may or may not be useful. """ import logging import os import tempfile from stanza.models import identity_lemmatizer from stanza.models import lemmatizer from stanza.models import mwt_expander from stanza.models import parser from stanza.models import tagger from stanza.models import tokenizer from stanza.models.common.constant import treebank_to_short_name from stanza.utils.training import common from stanza.utils.training.common import Mode, build_tokenizer_charlm_args, build_pos_charlm_args, build_lemma_charlm_args, build_depparse_charlm_args, build_pos_wordvec_args, build_depparse_wordvec_args from stanza.utils.training.run_lemma import check_lemmas from stanza.utils.training.run_mwt import check_mwt logger = logging.getLogger('stanza') # a constant so that the script which looks for these results knows what to look for RESULTS_STRING = "End to end results for" def add_args(parser): parser.add_argument('--test_data', default=None, type=str, help='Which data to test on, if not using the default data for this model') common.add_charlm_args(parser) def run_ete(paths, dataset, short_name, command_args, extra_args): short_language, package = short_name.split("_", 1) tokenize_dir = paths["TOKENIZE_DATA_DIR"] mwt_dir = paths["MWT_DATA_DIR"] lemma_dir = paths["LEMMA_DATA_DIR"] ete_dir = paths["ETE_DATA_DIR"] 
wordvec_dir = paths["WORDVEC_DIR"] # run models in the following order: # tokenize # mwt, if exists # pos # lemma, if exists # depparse # the output of each step is either kept or discarded based on the # value of command_args.save_output if command_args and command_args.test_data: test_short_name = treebank_to_short_name(command_args.test_data) else: test_short_name = short_name # TOKENIZE step # the raw data to process starts in tokenize_dir # retokenize it using the saved model tokenizer_type = "--txt_file" tokenizer_file = f"{tokenize_dir}/{test_short_name}.{dataset}.txt" tokenizer_output = f"{ete_dir}/{short_name}.{dataset}.tokenizer.conllu" tokenizer_args = ["--mode", "predict", tokenizer_type, tokenizer_file, "--lang", short_language, "--conll_file", tokenizer_output, "--shorthand", short_name] tokenizer_charlm_args = build_tokenizer_charlm_args(short_language, package, command_args.charlm) tokenizer_args = tokenizer_args + tokenizer_charlm_args + extra_args logger.info("----- TOKENIZER ----------") logger.info("Running tokenizer step with args: {}".format(tokenizer_args)) tokenizer.main(tokenizer_args) # If the data has any MWT in it, there should be an MWT model # trained, so run that. Otherwise, we skip MWT mwt_train_file = f"{mwt_dir}/{short_name}.train.in.conllu" logger.info("----- MWT ----------") if check_mwt(mwt_train_file): mwt_output = f"{ete_dir}/{short_name}.{dataset}.mwt.conllu" mwt_args = ['--eval_file', tokenizer_output, '--output_file', mwt_output, '--lang', short_language, '--shorthand', short_name, '--mode', 'predict'] mwt_args = mwt_args + extra_args logger.info("Running mwt step with args: {}".format(mwt_args)) mwt_expander.main(mwt_args) else: logger.info("No MWT in training data. 
Skipping") mwt_output = tokenizer_output # Run the POS step # TODO: add batch args # TODO: add transformer args logger.info("----- POS ----------") pos_output = f"{ete_dir}/{short_name}.{dataset}.pos.conllu" pos_args = ['--wordvec_dir', wordvec_dir, '--eval_file', mwt_output, '--output_file', pos_output, '--lang', short_language, '--shorthand', short_name, '--mode', 'predict', # the MWT is not preserving the tags, # so we don't ask the tagger to report a score # the ETE will score the whole thing at the end '--no_gold_labels'] pos_charlm_args = build_pos_charlm_args(short_language, package, command_args.charlm) pos_args = pos_args + build_pos_wordvec_args(short_language, package, extra_args) + pos_charlm_args + extra_args logger.info("Running pos step with args: {}".format(pos_args)) tagger.main(pos_args) # Run the LEMMA step. If there are no lemmas in the training # data, use the identity lemmatizer. logger.info("----- LEMMA ----------") lemma_train_file = f"{lemma_dir}/{short_name}.train.in.conllu" lemma_output = f"{ete_dir}/{short_name}.{dataset}.lemma.conllu" lemma_args = ['--eval_file', pos_output, '--output_file', lemma_output, '--shorthand', short_name, '--mode', 'predict'] if check_lemmas(lemma_train_file): lemma_charlm_args = build_lemma_charlm_args(short_language, package, command_args.charlm) lemma_args = lemma_args + lemma_charlm_args + extra_args logger.info("Running lemmatizer step with args: {}".format(lemma_args)) lemmatizer.main(lemma_args) else: lemma_args = lemma_args + extra_args logger.info("No lemmas in training data") logger.info("Running identity lemmatizer step with args: {}".format(lemma_args)) identity_lemmatizer.main(lemma_args) # Run the DEPPARSE step. This is the last step # Note that we do NOT use the depparse directory's data. That is # because it has either gold tags, or predicted tags based on # retagging using gold tokenization, and we aren't sure which at # this point in the process. 
# TODO: add batch args logger.info("----- DEPPARSE ----------") depparse_output = f"{ete_dir}/{short_name}.{dataset}.depparse.conllu" depparse_args = ['--wordvec_dir', wordvec_dir, '--eval_file', lemma_output, '--output_file', depparse_output, '--lang', short_name, '--shorthand', short_name, '--mode', 'predict', # we don't ask the parser to report a score either '--no_gold_labels'] depparse_charlm_args = build_depparse_charlm_args(short_language, package, command_args.charlm) depparse_args = depparse_args + build_depparse_wordvec_args(short_language, package, extra_args) + depparse_charlm_args + extra_args logger.info("Running depparse step with args: {}".format(depparse_args)) parser.main(depparse_args) logger.info("----- EVALUATION ----------") gold_file = f"{tokenize_dir}/{test_short_name}.{dataset}.gold.conllu" ete_file = depparse_output results = common.run_eval_script(gold_file, ete_file) logger.info("{} {} models on {} {} data:\n{}".format(RESULTS_STRING, short_name, test_short_name, dataset, results)) def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): if mode == Mode.TRAIN: dataset = 'train' elif mode == Mode.SCORE_DEV: dataset = 'dev' elif mode == Mode.SCORE_TEST: dataset = 'test' if not command_args.save_output: with tempfile.TemporaryDirectory() as ete_dir: paths = dict(paths) paths["ETE_DATA_DIR"] = ete_dir run_ete(paths, dataset, short_name, command_args, extra_args) else: os.makedirs(paths["ETE_DATA_DIR"], exist_ok=True) run_ete(paths, dataset, short_name, command_args, extra_args) def main(): common.main(run_treebank, "ete", "ete", add_args) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_lemma.py ================================================ """ This script allows for training or testing on dev / test of the UD lemmatizer. If run with a single treebank name, it will train or test that treebank. 
If run with ud_all or all_ud, it will iterate over all UD treebanks it can find. Mode can be set to train&dev with --train, to dev set only with --score_dev, and to test set only with --score_test. Treebanks are specified as a list. all_ud or ud_all means to look for all UD treebanks. Extra arguments are passed to the lemmatizer. In case the run script itself is shadowing arguments, you can specify --extra_args as a parameter to mark where the lemmatizer arguments start. """ import logging import os from stanza.models import identity_lemmatizer from stanza.models import lemmatizer from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm from stanza.utils.datasets.prepare_lemma_treebank import check_lemmas import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier logger = logging.getLogger('stanza') def add_lemma_args(parser): add_charlm_args(parser) parser.add_argument('--lemma_classifier', dest='lemma_classifier', action='store_true', default=None, help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used") parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used") def build_model_filename(paths, short_name, command_args, extra_args): """ Figure out what the model savename will be, taking into account the model settings. Useful for figuring out if the model already exists None will represent that there is no expected save_name """ short_language, dataset = short_name.split("_", 1) lemma_dir = paths["LEMMA_DATA_DIR"] train_file = f"{lemma_dir}/{short_name}.train.in.conllu" if not os.path.exists(train_file): logger.debug("Treebank %s is not prepared for training the lemmatizer. 
Could not find any training data at %s Cannot figure out the expected save_name without looking at the data, but a later step in the process will skip the training anyway" % (short_name, train_file)) return None has_lemmas = check_lemmas(train_file) if not has_lemmas: return None # TODO: can avoid downloading the charlm at this point, since we # might not even be training charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm) train_args = ["--train_file", train_file, "--shorthand", short_name, "--mode", "train"] train_args = train_args + charlm_args + extra_args args = lemmatizer.parse_args(train_args) save_name = lemmatizer.build_model_filename(args) return save_name def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) lemma_dir = paths["LEMMA_DATA_DIR"] train_file = f"{lemma_dir}/{short_name}.train.in.conllu" dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu" dev_pred_file = f"{lemma_dir}/{short_name}.dev.pred.conllu" test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu" test_pred_file = f"{lemma_dir}/{short_name}.test.pred.conllu" charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm) if not os.path.exists(train_file): logger.error("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Skipping..." % (treebank, train_file)) return has_lemmas = check_lemmas(train_file) if not has_lemmas: logger.info("Treebank " + treebank + " (" + short_name + ") has no lemmas. 
Using identity lemmatizer") if mode == Mode.TRAIN or mode == Mode.SCORE_DEV: train_args = ["--train_file", train_file, "--eval_file", dev_in_file, "--gold_file", dev_in_file, "--shorthand", short_name] if command_args.save_output: train_args.extend(["--output_file", dev_pred_file]) logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args)) identity_lemmatizer.main(train_args) elif mode == Mode.SCORE_TEST: train_args = ["--train_file", train_file, "--eval_file", test_in_file, "--gold_file", test_in_file, "--shorthand", short_name] if command_args.save_output: train_args.extend(["--output_file", test_pred_file]) logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args)) identity_lemmatizer.main(train_args) else: if mode == Mode.TRAIN: # ('UD_Czech-PDT', 'UD_Russian-SynTagRus', 'UD_German-HDT') if short_name in ('cs_pdt', 'ru_syntagrus', 'de_hdt'): num_epochs = "30" else: num_epochs = "60" train_args = ["--train_file", train_file, "--eval_file", dev_in_file, "--shorthand", short_name, "--num_epoch", num_epochs, "--mode", "train"] train_args = train_args + charlm_args + extra_args logger.info("Running train lemmatizer for {} with args {}".format(treebank, train_args)) lemmatizer.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ["--eval_file", dev_in_file, "--shorthand", short_name, "--mode", "predict"] if command_args.save_output: train_args.extend(["--output_file", dev_pred_file]) dev_args = dev_args + charlm_args + extra_args logger.info("Running dev lemmatizer for {} with args {}".format(treebank, dev_args)) lemmatizer.main(dev_args) if mode == Mode.SCORE_TEST: test_args = ["--eval_file", test_in_file, "--shorthand", short_name, "--mode", "predict"] if command_args.save_output: train_args.extend(["--output_file", test_pred_file]) test_args = test_args + charlm_args + extra_args logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args)) 
lemmatizer.main(test_args) use_lemma_classifier = command_args.lemma_classifier if use_lemma_classifier is None: use_lemma_classifier = command_args.charlm is not None use_lemma_classifier = use_lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING if use_lemma_classifier and mode == Mode.TRAIN: # some installations may not have transformers, # so we bury the lemma_classifier import in the codepath # which actually needs it from stanza.models.lemma import attach_lemma_classifier from stanza.utils.training import run_lemma_classifier lc_charlm_args = ['--no_charlm'] if command_args.charlm is None else ['--charlm', command_args.charlm] lemma_classifier_args = [treebank] + lc_charlm_args if command_args.force: lemma_classifier_args.append('--force') run_lemma_classifier.main(lemma_classifier_args) save_name = build_model_filename(paths, short_name, command_args, extra_args) # TODO: use a temp path for the lemma_classifier or keep it somewhere attach_args = ['--input', save_name, '--output', save_name, '--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name] attach_lemma_classifier.main(attach_args) # now we rerun the dev set - the HI in particular demonstrates some good improvement lemmatizer.main(dev_args) def main(): common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_lemma_classifier.py ================================================ import os from stanza.models.lemma_classifier import evaluate_models from stanza.models.lemma_classifier import train_lstm_model from stanza.models.lemma_classifier import train_transformer_model from stanza.models.lemma_classifier.constants import ModelType from stanza.resources.default_packages import default_pretrains, TRANSFORMERS 
from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm, find_wordvec_pretrain def add_lemma_args(parser): add_charlm_args(parser) parser.add_argument('--model_type', default=ModelType.LSTM, type=lambda x: ModelType[x.upper()], help='Model type to use. {}'.format(", ".join(x.name for x in ModelType))) def build_model_filename(paths, short_name, command_args, extra_args): return os.path.join("saved_models", "lemma_classifier", short_name + "_lemma_classifier.pt") def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) base_args = [] if '--save_name' not in extra_args: base_args += ['--save_name', build_model_filename(paths, short_name, command_args, extra_args)] embedding_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm) if '--wordvec_pretrain_file' not in extra_args: wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, {}, dataset) embedding_args += ["--wordvec_pretrain_file", wordvec_pretrain] bert_args = [] if command_args.model_type is ModelType.TRANSFORMER: if '--bert_model' not in extra_args: if short_language in TRANSFORMERS: bert_args = ['--bert_model', TRANSFORMERS.get(short_language)] else: raise ValueError("--bert_model not specified, so cannot figure out which transformer to use for language %s" % short_language) extra_train_args = [] if command_args.force: extra_train_args.append('--force') if mode == Mode.TRAIN: train_args = [] if "--train_file" not in extra_args: train_file = os.path.join("data", "lemma_classifier", "%s.train.lemma" % short_name) train_args += ['--train_file', train_file] if "--eval_file" not in extra_args: eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name) train_args += ['--eval_file', eval_file] train_args = base_args + train_args + extra_args + extra_train_args if command_args.model_type 
== ModelType.LSTM: train_args = embedding_args + train_args train_lstm_model.main(train_args) else: model_type_args = ["--model_type", command_args.model_type.name.lower()] train_args = bert_args + model_type_args + train_args train_transformer_model.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: eval_args = [] if "--eval_file" not in extra_args: eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name) eval_args += ['--eval_file', eval_file] model_type_args = ["--model_type", command_args.model_type.name.lower()] eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args evaluate_models.main(eval_args) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: eval_args = [] if "--eval_file" not in extra_args: eval_file = os.path.join("data", "lemma_classifier", "%s.test.lemma" % short_name) eval_args += ['--eval_file', eval_file] model_type_args = ["--model_type", command_args.model_type.name.lower()] eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args evaluate_models.main(eval_args) def main(args=None): common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/training/run_mwt.py ================================================ """ This script allows for training or testing on dev / test of the UD mwt tools. If run with a single treebank name, it will train or test that treebank. If run with ud_all or all_ud, it will iterate over all UD treebanks it can find. Mode can be set to train&dev with --train, to dev set only with --score_dev, and to test set only with --score_test. Treebanks are specified as a list. all_ud or ud_all means to look for all UD treebanks. 
Extra arguments are passed to mwt. In case the run script itself is shadowing arguments, you can specify --extra_args as a parameter to mark where the mwt arguments start. """ import io import logging import math from stanza.models import mwt_expander from stanza.models.common.doc import Document from stanza.utils.conll import CoNLL from stanza.utils.training import common from stanza.utils.training.common import Mode from stanza.utils.max_mwt_length import max_mwt_length logger = logging.getLogger('stanza') def check_mwt(filename): """ Checks whether or not there are MWTs in the given conll file """ doc = CoNLL.conll2doc(filename) data = doc.get_mwt_expansions(False) return len(data) > 0 def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language = short_name.split("_")[0] mwt_dir = paths["MWT_DATA_DIR"] train_file = f"{mwt_dir}/{short_name}.train.in.conllu" dev_in_file = f"{mwt_dir}/{short_name}.dev.in.conllu" dev_gold_file = f"{mwt_dir}/{short_name}.dev.gold.conllu" dev_output_file = f"{mwt_dir}/{short_name}.dev.pred.conllu" test_in_file = f"{mwt_dir}/{short_name}.test.in.conllu" test_gold_file = f"{mwt_dir}/{short_name}.test.gold.conllu" test_output_file = f"{mwt_dir}/{short_name}.test.pred.conllu" train_json = f"{mwt_dir}/{short_name}-ud-train-mwt.json" dev_json = f"{mwt_dir}/{short_name}-ud-dev-mwt.json" test_json = f"{mwt_dir}/{short_name}-ud-test-mwt.json" eval_file = None if '--eval_file' in extra_args: eval_file = extra_args[extra_args.index('--eval_file') + 1] gold_file = None if '--gold_file' in extra_args: gold_file = extra_args[extra_args.index('--gold_file') + 1] if not check_mwt(train_file): logger.info("No training MWTS found for %s. Skipping" % treebank) return if not check_mwt(dev_in_file) and mode == Mode.TRAIN: logger.info("No dev MWTS found for %s. 
Training only the deterministic MWT expander" % treebank) extra_args.append('--dict_only') if mode == Mode.TRAIN: max_mwt_len = math.ceil(max_mwt_length([train_json, dev_json]) * 1.1 + 1) logger.info("Max len: %f" % max_mwt_len) train_args = ['--train_file', train_file, '--eval_file', eval_file if eval_file else dev_in_file, '--gold_file', gold_file if gold_file else dev_gold_file, '--lang', short_language, '--shorthand', short_name, '--mode', 'train', '--max_dec_len', str(max_mwt_len)] train_args = train_args + extra_args logger.info("Running train step with args: {}".format(train_args)) mwt_expander.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ['--eval_file', eval_file if eval_file else dev_in_file, '--gold_file', gold_file if gold_file else dev_gold_file, '--lang', short_language, '--shorthand', short_name, '--mode', 'predict'] if command_args.save_output: dev_args.extend(['--output_file', dev_output_file]) dev_args = dev_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) _, dev_doc = mwt_expander.main(dev_args) if not command_args.save_output: dev_output_file = "{:C}\n\n".format(dev_doc) dev_output_file = io.StringIO(dev_output_file) results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file) logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) if mode == Mode.SCORE_TEST: test_args = ['--eval_file', eval_file if eval_file else test_in_file, '--gold_file', gold_file if gold_file else test_gold_file, '--lang', short_language, '--shorthand', short_name, '--mode', 'predict'] if command_args.save_output: test_args.extend(['--output_file', test_output_file]) test_args = test_args + extra_args logger.info("Running test step with args: {}".format(test_args)) _, test_doc = mwt_expander.main(test_args) if not command_args.save_output: test_output_file = "{:C}\n\n".format(test_doc) test_output_file = io.StringIO(test_output_file) results = 
common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file) logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) def main(): common.main(run_treebank, "mwt", "mwt_expander", sub_argparse=mwt_expander.build_argparse()) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_ner.py ================================================ """ Trains or scores an NER model. Will attempt to guess the appropriate word vector file if none is specified, and will use the charlms specified in the resources for a given dataset or language if possible. Example command line: python3 -m stanza.utils.training.run_ner.py hu_combined This script expects the prepared data to be in data/ner/{lang}_{dataset}.train.json, {lang}_{dataset}.dev.json, {lang}_{dataset}.test.json If those files don't exist, it will make an attempt to rebuild them using the prepare_ner_dataset script. However, this will fail if the data is not already downloaded. More information on where to find most of the datasets online is in that script. Some of the datasets have licenses which must be agreed to, so no attempt is made to automatically download the data. 
""" import logging import os from stanza.models import ner_tagger from stanza.resources.common import DEFAULT_MODEL_DIR from stanza.utils.datasets.ner import prepare_ner_dataset from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain from stanza.resources.default_packages import default_charlms, default_pretrains, ner_charlms, ner_pretrains # extra arguments specific to a particular dataset DATASET_EXTRA_ARGS = { "da_ddt": [ "--dropout", "0.6" ], "fa_arman": [ "--dropout", "0.6" ], "vi_vlsp": [ "--dropout", "0.6", "--word_dropout", "0.1", "--locked_dropout", "0.1", "--char_dropout", "0.1" ], } logger = logging.getLogger('stanza') def add_ner_args(parser): add_charlm_args(parser) parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') def build_pretrain_args(language, dataset, charlm="default", command_args=None, extra_args=None, model_dir=DEFAULT_MODEL_DIR): """ Returns one list with the args for this language & dataset's charlm and pretrained embedding """ charlm = choose_charlm(language, dataset, charlm, default_charlms, ner_charlms) charlm_args = build_charlm_args(language, charlm, model_dir=model_dir) wordvec_args = [] if '--wordvec_pretrain_file' not in extra_args and '--no_pretrain' not in extra_args: # will throw an error if the pretrain can't be found wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains, ner_pretrains, dataset, model_dir=model_dir) wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] bert_args = common.choose_transformer(language, command_args, extra_args, warn=False) return charlm_args + wordvec_args + bert_args # TODO: refactor? 
tagger and depparse should be pretty similar def build_model_filename(paths, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) # TODO: can avoid downloading the charlm at this point, since we # might not even be training pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, command_args, extra_args) dataset_args = DATASET_EXTRA_ARGS.get(short_name, []) train_args = ["--shorthand", short_name, "--mode", "train"] train_args = train_args + pretrain_args + dataset_args + extra_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: train_args.extend(["--save_dir", command_args.save_dir]) args = ner_tagger.parse_args(train_args) save_name = ner_tagger.model_file_name(args) return save_name # Technically NER datasets are not necessarily treebanks # (usually not, in fact) # However, to keep the naming consistent, we leave the # method which does the training as run_treebank # TODO: rename treebank -> dataset everywhere def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): ner_dir = paths["NER_DATA_DIR"] language, dataset = short_name.split("_") train_file = os.path.join(ner_dir, f"{treebank}.train.json") dev_file = os.path.join(ner_dir, f"{treebank}.dev.json") test_file = os.path.join(ner_dir, f"{treebank}.test.json") # if any files are missing, try to rebuild the dataset # if that still doesn't work, we have to throw an error missing_file = [x for x in (train_file, dev_file, test_file) if not os.path.exists(x)] if len(missing_file) > 0: logger.warning(f"The data for {treebank} is missing or incomplete. 
Cannot find {missing_file} Attempting to rebuild...") try: prepare_ner_dataset.main(treebank) except Exception as e: raise FileNotFoundError(f"An exception occurred while trying to build the data for {treebank} At least one portion of the data was missing: {missing_file} Please correctly build these files and then try again.") from e pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, command_args, extra_args) if mode == Mode.TRAIN: # VI example arguments: # --wordvec_pretrain_file ~/stanza_resources/vi/pretrain/vtb.pt # --train_file data/ner/vi_vlsp.train.json # --eval_file data/ner/vi_vlsp.dev.json # --lang vi # --shorthand vi_vlsp # --mode train # --charlm --charlm_shorthand vi_conll17 # --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1 dataset_args = DATASET_EXTRA_ARGS.get(short_name, []) train_args = ['--train_file', train_file, '--eval_file', dev_file, '--shorthand', short_name, '--mode', 'train'] train_args = train_args + pretrain_args + dataset_args + extra_args logger.info("Running train step with args: {}".format(train_args)) ner_tagger.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ['--eval_file', dev_file, '--shorthand', short_name, '--mode', 'predict'] dev_args = dev_args + pretrain_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) ner_tagger.main(dev_args) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ['--eval_file', test_file, '--shorthand', short_name, '--mode', 'predict'] test_args = test_args + pretrain_args + extra_args logger.info("Running test step with args: {}".format(test_args)) ner_tagger.main(test_args) def main(): common.main(run_treebank, "ner", "nertagger", add_ner_args, ner_tagger.build_argparse(), build_model_filename=build_model_filename) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_pos.py ================================================ 
import io import logging import os from stanza.models import tagger from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain, build_pos_wordvec_args logger = logging.getLogger('stanza') def add_pos_args(parser): add_charlm_args(parser) parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') def build_model_filename(paths, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) # TODO: can avoid downloading the charlm at this point, since we # might not even be training charlm_args = build_pos_charlm_args(short_language, dataset, command_args.charlm) bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=False) train_args = ["--shorthand", short_name, "--mode", "train"] # TODO: also, this downloads the wordvec, which we might not want to do yet train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: train_args.extend(["--save_dir", command_args.save_dir]) args = tagger.parse_args(train_args) save_name = tagger.model_file_name(args) return save_name def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) pos_dir = paths["POS_DATA_DIR"] train_file = f"{pos_dir}/{short_name}.train.in.conllu" if short_name == 'vi_vlsp22': train_file += f";{pos_dir}/vi_vtb.train.in.conllu" dev_in_file = f"{pos_dir}/{short_name}.dev.in.conllu" dev_pred_file = f"{pos_dir}/{short_name}.dev.pred.conllu" test_in_file = f"{pos_dir}/{short_name}.test.in.conllu" test_pred_file = f"{pos_dir}/{short_name}.test.pred.conllu" charlm_args = build_pos_charlm_args(short_language, dataset, 
command_args.charlm) bert_args = common.choose_transformer(short_language, command_args, extra_args) eval_file = None if '--eval_file' in extra_args: eval_file = extra_args[extra_args.index('--eval_file') + 1] if mode == Mode.TRAIN: train_pieces = [] for train_piece in train_file.split(";"): zip_piece = os.path.splitext(train_piece)[0] + ".zip" if os.path.exists(train_piece) and os.path.exists(zip_piece): logger.error("POS TRAIN FILE %s and %s both exist... this is very confusing, skipping %s" % (train_piece, zip_piece, short_name)) return if os.path.exists(train_piece): train_pieces.append(train_piece) else: # not os.path.exists(train_piece): if os.path.exists(zip_piece): train_pieces.append(zip_piece) continue logger.error("TRAIN FILE NOT FOUND: %s ... skipping" % train_piece) return train_file = ";".join(train_pieces) train_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--train_file", train_file, "--lang", short_language, "--shorthand", short_name, "--mode", "train"] if eval_file is None: train_args += ['--eval_file', dev_in_file] train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args train_args = train_args + extra_args logger.info("Running train POS for {} with args {}".format(treebank, train_args)) tagger.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] if eval_file is None: dev_args += ['--eval_file', dev_in_file] if command_args.save_output: dev_args.extend(["--output_file", dev_pred_file]) dev_args = dev_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args dev_args = dev_args + extra_args logger.info("Running dev POS for {} with args {}".format(treebank, dev_args)) _, dev_doc = tagger.main(dev_args) if not command_args.save_output: dev_pred_file = "{:C}\n\n".format(dev_doc) dev_pred_file = io.StringIO(dev_pred_file) 
results = common.run_eval_script_pos(eval_file if eval_file else dev_in_file, dev_pred_file) logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) if command_args.save_output: logger.info("Output saved to %s", dev_pred_file) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] if eval_file is None: test_args += ['--eval_file', test_in_file] if command_args.save_output: dev_args.extend(["--output_file", test_pred_file]) test_args = test_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args test_args = test_args + extra_args logger.info("Running test POS for {} with args {}".format(treebank, test_args)) _, test_doc = tagger.main(test_args) if not command_args.save_output: test_pred_file = "{:C}\n\n".format(test_doc) test_pred_file = io.StringIO(test_pred_file) results = common.run_eval_script_pos(eval_file if eval_file else test_in_file, test_pred_file) logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) if command_args.save_output: logger.info("Output saved to %s", test_pred_file) def main(): common.main(run_treebank, "pos", "tagger", add_pos_args, tagger.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_pos_charlm) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_sentiment.py ================================================ """ Trains or tests a sentiment model using the classifier package The prep script has separate entries for the root-only version of SST, which is what people typically use to test. 
When training a model for SST which uses all the data, the root-only version is used for dev and test """ import logging import os from stanza.models import classifier from stanza.utils.training import common from stanza.utils.training.common import Mode, build_charlm_args, choose_charlm, find_wordvec_pretrain from stanza.resources.default_packages import default_charlms, default_pretrains logger = logging.getLogger('stanza') # TODO: refactor with ner & conparse def add_sentiment_args(parser): parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm') parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package") parser.add_argument('--use_charlm', action='store_true', help='If --use_bert is set, charlm will be turned off. This turns it on anyway') parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') ALTERNATE_DATASET = { "en_sst2": "en_sst2roots", "en_sstplus": "en_sst3roots", } def build_default_args(paths, short_language, dataset, command_args, extra_args): if '--wordvec_pretrain_file' not in extra_args: # will throw an error if the pretrain can't be found wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains) wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] else: wordvec_args = [] if command_args.use_bert and not command_args.use_charlm: charlm = None else: charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {}) charlm_args = build_charlm_args(short_language, charlm, base_args=False) bert_args = common.choose_transformer(short_language, command_args, extra_args) default_args = wordvec_args + charlm_args + bert_args return default_args def build_model_filename(paths, 
short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) train_args = ["--shorthand", short_name] train_args = train_args + default_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: train_args.extend(["--save_dir", command_args.save_dir]) args = classifier.parse_args(train_args + extra_args) save_name = classifier.build_model_filename(args) return save_name def run_dataset(mode, paths, treebank, short_name, command_args, extra_args): sentiment_dir = paths["SENTIMENT_DATA_DIR"] short_language, dataset = short_name.split("_", 1) train_file = os.path.join(sentiment_dir, f"{short_name}.train.json") other_name = ALTERNATE_DATASET.get(short_name, short_name) dev_file = os.path.join(sentiment_dir, f"{other_name}.dev.json") test_file = os.path.join(sentiment_dir, f"{other_name}.test.json") for filename in (train_file, dev_file, test_file): if not os.path.exists(filename): raise FileNotFoundError("Cannot find %s" % filename) default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) if mode == Mode.TRAIN: train_args = ['--train_file', train_file, '--dev_file', dev_file, '--test_file', test_file, '--shorthand', short_name, '--wordvec_type', 'word2vec', # TODO: chinese is fasttext '--extra_wordvec_method', 'SUM'] train_args = train_args + default_args + extra_args logger.info("Running train step with args: {}".format(train_args)) classifier.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ['--no_train', '--test_file', dev_file, '--shorthand', short_name, '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext dev_args = dev_args + default_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) classifier.main(dev_args) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args 
= ['--no_train', '--test_file', test_file, '--shorthand', short_name, '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext test_args = test_args + default_args + extra_args logger.info("Running test step with args: {}".format(test_args)) classifier.main(test_args) def main(): common.main(run_dataset, "classifier", "classifier", add_sentiment_args, classifier.build_argparse(), build_model_filename=build_model_filename) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/training/run_tokenizer.py ================================================ """ This script allows for training or testing on dev / test of the UD tokenizer. If run with a single treebank name, it will train or test that treebank. If run with ud_all or all_ud, it will iterate over all UD treebanks it can find. Mode can be set to train&dev with --train, to dev set only with --score_dev, and to test set only with --score_test. Treebanks are specified as a list. all_ud or ud_all means to look for all UD treebanks. Extra arguments are passed to tokenizer. In case the run script itself is shadowing arguments, you can specify --extra_args as a parameter to mark where the tokenizer arguments start. Default behavior is to discard the output and just print the results. 
To keep the results instead, use --save_output """ import io import logging import math import os from stanza.models import tokenizer from stanza.models.common.doc import Document from stanza.utils.avg_sent_len import avg_sent_len from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_tokenizer_charlm_args logger = logging.getLogger('stanza') def add_tokenizer_args(parser): add_charlm_args(parser) def build_model_filename(paths, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) # TODO: can avoid downloading the charlm at this point, since we # might not even be training charlm_args = build_tokenizer_charlm_args(short_language, dataset, command_args.charlm) train_args = ["--shorthand", short_name, "--mode", "train"] train_args = train_args + charlm_args + extra_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: train_args.extend(["--save_dir", command_args.save_dir]) args = tokenizer.parse_args(train_args) save_name = tokenizer.model_file_name(args) return save_name def uses_dictionary(short_language): """ Some of the languages (as shown here) have external dictionaries We found this helped the overall tokenizer performance If these can't be found, they can be extracted from the previous iteration of models """ if short_language in ('ja', 'th', 'zh', 'zh-hans', 'zh-hant'): return True return False def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): tokenize_dir = paths["TOKENIZE_DATA_DIR"] short_language, dataset = short_name.split("_", 1) label_type = "--label_file" label_file = f"{tokenize_dir}/{short_name}-ud-train.toklabels" dev_type = "--txt_file" dev_file = f"{tokenize_dir}/{short_name}.dev.txt" test_type = "--txt_file" test_file = f"{tokenize_dir}/{short_name}.test.txt" train_type = "--txt_file" train_file = f"{tokenize_dir}/{short_name}.train.txt" 
train_dev_args = ["--dev_txt_file", dev_file, "--dev_label_file", f"{tokenize_dir}/{short_name}-ud-dev.toklabels"] if short_language == "zh" or short_language.startswith("zh-"): extra_args = ["--skip_newline"] + extra_args train_gold = f"{tokenize_dir}/{short_name}.train.gold.conllu" dev_gold = f"{tokenize_dir}/{short_name}.dev.gold.conllu" test_gold = f"{tokenize_dir}/{short_name}.test.gold.conllu" train_mwt = f"{tokenize_dir}/{short_name}-ud-train-mwt.json" dev_mwt = f"{tokenize_dir}/{short_name}-ud-dev-mwt.json" test_mwt = f"{tokenize_dir}/{short_name}-ud-test-mwt.json" train_pred = f"{tokenize_dir}/{short_name}.train.pred.conllu" dev_pred = f"{tokenize_dir}/{short_name}.dev.pred.conllu" test_pred = f"{tokenize_dir}/{short_name}.test.pred.conllu" charlm_args = build_tokenizer_charlm_args(short_language, dataset, command_args.charlm) if mode == Mode.TRAIN: seqlen = str(math.ceil(avg_sent_len(label_file) * 3 / 100) * 100) train_args = ([label_type, label_file, train_type, train_file, "--lang", short_language, "--max_seqlen", seqlen, "--mwt_json_file", dev_mwt] + train_dev_args + ["--dev_conll_gold", dev_gold, "--shorthand", short_name]) if uses_dictionary(short_language): train_args = train_args + ["--use_dictionary"] train_args = train_args + charlm_args + extra_args logger.info("Running train step with args: {}".format(train_args)) tokenizer.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ["--mode", "predict", dev_type, dev_file, "--lang", short_language, "--shorthand", short_name, "--mwt_json_file", dev_mwt] if command_args.save_output: dev_args.extend(["--conll_file", dev_pred]) dev_args = dev_args + charlm_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) _, dev_doc = tokenizer.main(dev_args) # TODO: log these results? 
The original script logged them to # echo $results $args >> ${TOKENIZE_DATA_DIR}/${short}.results if not command_args.save_output: dev_pred = "{:C}\n\n".format(Document(dev_doc)) dev_pred = io.StringIO(dev_pred) results = common.run_eval_script_tokens(dev_gold, dev_pred) logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ["--mode", "predict", test_type, test_file, "--lang", short_language, "--shorthand", short_name, "--mwt_json_file", test_mwt] if command_args.save_output: test_args.extend(["--conll_file", test_pred]) test_args = test_args + charlm_args + extra_args logger.info("Running test step with args: {}".format(test_args)) _, test_doc = tokenizer.main(test_args) if not command_args.save_output: test_pred = "{:C}\n\n".format(Document(test_doc)) test_pred = io.StringIO(test_pred) results = common.run_eval_script_tokens(test_gold, test_pred) logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) if mode == Mode.SCORE_TRAIN: test_args = ["--mode", "predict", test_type, train_file, "--lang", short_language, "--shorthand", short_name, "--mwt_json_file", train_mwt] if command_args.save_output: test_args.extend(["--conll_file", train_pred]) test_args = test_args + charlm_args + extra_args logger.info("Running test step with args: {}".format(test_args)) _, train_doc = tokenizer.main(test_args) if not command_args.save_output: train_pred = "{:C}\n\n".format(Document(train_doc)) train_pred = io.StringIO(train_pred) results = common.run_eval_script_tokens(train_gold, train_pred) logger.info("Finished running train set as a test on\n{}\n{}".format(treebank, results)) def main(): common.main(run_treebank, "tokenize", "tokenizer", add_tokenizer_args, sub_argparse=tokenizer.build_argparse(), build_model_filename=build_model_filename) if __name__ == "__main__": main() ================================================ FILE: 
stanza/utils/training/separate_ner_pretrain.py ================================================ """ Loads NER models & separates out the word vectors to base & delta The model will then be resaved without the base word vector, greatly reducing the size of the model This may be useful for any external users of stanza who have an NER model they wish to reuse without retraining If you know which pretrain was used to build an NER model, you can provide that pretrain. Otherwise, you can give a directory of pretrains and the script will test each one. In the latter case, the name of the pretrain needs to look like lang_dataset_pretrain.pt """ import argparse from collections import defaultdict import logging import os import numpy as np import torch import torch.nn as nn from stanza import Pipeline from stanza.models.common.constant import lang_to_langcode from stanza.models.common.pretrain import Pretrain, PretrainedWordVocab from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX from stanza.models.ner.trainer import Trainer logger = logging.getLogger('stanza') logger.setLevel(logging.ERROR) DEBUG = False EPS = 0.0001 def main(): parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default='saved_models/ner', help='Where to find NER models (dir or filename)') parser.add_argument('--output_path', type=str, default='saved_models/shrunk', help='Where to write shrunk NER models (dir)') parser.add_argument('--pretrain_path', type=str, default='saved_models/pretrain', help='Where to find pretrains (dir or filename)') args = parser.parse_args() # get list of NER models to shrink if os.path.isdir(args.input_path): ner_model_dir = args.input_path ners = os.listdir(ner_model_dir) if len(ners) == 0: raise FileNotFoundError("No ner models found in {}".format(args.input_path)) else: if not os.path.isfile(args.input_path): raise FileNotFoundError("No ner model found at path {}".format(args.input_path)) ner_model_dir, ners = 
os.path.split(args.input_path) ners = [ners] # get map from language to candidate pretrains if os.path.isdir(args.pretrain_path): pt_model_dir = args.pretrain_path pretrains = os.listdir(pt_model_dir) lang_to_pretrain = defaultdict(list) for pt in pretrains: lang_to_pretrain[pt.split("_")[0]].append(pt) else: pt_model_dir, pretrains = os.path.split(pt_model_dir) pretrains = [pretrains] lang_to_pretrain = defaultdict(lambda: pretrains) # shrunk models will all go in this directory new_dir = args.output_path os.makedirs(new_dir, exist_ok=True) final_pretrains = [] missing_pretrains = [] no_finetune = [] # for each model, go through the various pretrains # until we find one that works or none of them work for ner_model in ners: ner_path = os.path.join(ner_model_dir, ner_model) expected_ending = "_nertagger.pt" if not ner_model.endswith(expected_ending): raise ValueError("Unexpected name: {}".format(ner_model)) short_name = ner_model[:-len(expected_ending)] lang, package = short_name.split("_", maxsplit=1) print("===============================================") print("Processing lang %s package %s" % (lang, package)) # this may look funny - basically, the pipeline has machinery # to make sure the model has everything it needs to load, # including downloading other pieces if needed pipe = Pipeline(lang, processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": package}, ner_model_path=ner_path) ner_processor = pipe.processors['ner'] print("Loaded NER processor: {}".format(ner_processor)) trainer = ner_processor.trainers[0] vocab = trainer.model.vocab word_vocab = vocab['word'] num_vectors = trainer.model.word_emb.weight.shape[0] # sanity check, make sure the model loaded matches the # language from the model's filename lcode = lang_to_langcode(trainer.args['lang']) if lang != lcode and not (lcode == 'zh' and lang == 'zh-hans'): raise ValueError("lang not as expected: {} vs {} ({})".format(lang, trainer.args['lang'], lcode)) ner_pretrains = 
sorted(set(lang_to_pretrain[lang] + lang_to_pretrain[lcode])) for pt_model in ner_pretrains: pt_path = os.path.join(pt_model_dir, pt_model) print("Attempting pretrain: {}".format(pt_path)) pt = Pretrain(filename=pt_path) print(" pretrain shape: {}".format(pt.emb.shape)) print(" embedding in ner model shape: {}".format(trainer.model.word_emb.weight.shape)) if pt.emb.shape[1] != trainer.model.word_emb.weight.shape[1]: print(" DIMENSION DOES NOT MATCH. SKIPPING") continue N = min(pt.emb.shape[0], trainer.model.word_emb.weight.shape[0]) if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]: # If the vocab was exactly the same, that's a good # sign this pretrain was used, just with a different size # In such a case, we can reuse the rest of the pretrain # Minor issue: some vectors which were trained will be # lost in the case of |pt| < |model.word_emb| if all(word_vocab.id2unit(x) == word_vocab.id2unit(x) for x in range(N)): print(" Attempting to use pt vectors to replace ner model's vectors") else: print(" NUM VECTORS DO NOT MATCH. WORDS DO NOT MATCH. SKIPPING") continue if pt.emb.shape[0] < trainer.model.word_emb.weight.shape[0]: print(" WARNING: if any vectors beyond {} were fine tuned, that fine tuning will be lost".format(N)) device = next(trainer.model.parameters()).device delta = trainer.model.word_emb.weight[:N, :] - pt.emb.to(device)[:N, :] delta = delta.detach() delta_norms = torch.linalg.norm(delta, dim=1).cpu().numpy() if np.sum(delta_norms < 0) > 0: raise ValueError("This should not be - a norm was less than 0!") num_matching = np.sum(delta_norms < EPS) if num_matching > N / 2: print(" Accepted! 
%d of %d vectors match for %s" % (num_matching, N, pt_path)) if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]: print(" Setting model vocab to match the pretrain") word_vocab = pt.vocab vocab['word'] = word_vocab trainer.args['word_emb_dim'] = pt.emb.shape[1] break else: print(" %d of %d vectors matched for %s - SKIPPING" % (num_matching, N, pt_path)) vocab_same = sum(x in pt.vocab for x in word_vocab) print(" %d words were in both vocabs" % vocab_same) # this is expensive, and in practice doesn't happen, # but theoretically we might have missed a mostly matching pt # if the vocab had been scrambled if DEBUG: rearranged_count = 0 for x in word_vocab: if x not in pt.vocab: continue x_id = word_vocab.unit2id(x) x_vec = trainer.model.word_emb.weight[x_id, :] pt_id = pt.vocab.unit2id(x) pt_vec = pt.emb[pt_id, :] if (x_vec.detach().cpu() - pt_vec).norm() < EPS: rearranged_count += 1 print(" %d vectors were close when ignoring id ordering" % rearranged_count) else: print("COULD NOT FIND A MATCHING PT: {}".format(ner_processor)) missing_pretrains.append(ner_model) continue # build a delta vector & embedding assert 'delta' not in vocab.keys() delta_vectors = [delta[i].cpu() for i in range(4)] delta_vocab = [] for i in range(4, len(delta_norms)): if delta_norms[i] > 0.0: delta_vocab.append(word_vocab.id2unit(i)) delta_vectors.append(delta[i].cpu()) trainer.model.unsaved_modules.append("word_emb") if len(delta_vocab) == 0: print("No vectors were changed! 
Perhaps this model was trained without finetune.") no_finetune.append(ner_model) else: print("%d delta vocab" % len(delta_vocab)) print("%d vectors in the delta set" % len(delta_vectors)) delta_vectors = np.stack(delta_vectors) delta_vectors = torch.from_numpy(delta_vectors) assert delta_vectors.shape[0] == len(delta_vocab) + len(VOCAB_PREFIX) print(delta_vectors.shape) delta_vocab = PretrainedWordVocab(delta_vocab, lang=word_vocab.lang, lower=word_vocab.lower) vocab['delta'] = delta_vocab trainer.model.delta_emb = nn.Embedding(delta_vectors.shape[0], delta_vectors.shape[1], PAD_ID) trainer.model.delta_emb.weight.data.copy_(delta_vectors) new_path = os.path.join(new_dir, ner_model) trainer.save(new_path) final_pretrains.append((ner_model, pt_model)) print() if len(final_pretrains) > 0: print("Final pretrain mappings:") for i in final_pretrains: print(i) if len(missing_pretrains) > 0: print("MISSING EMBEDDINGS:") for i in missing_pretrains: print(i) if len(no_finetune) > 0: print("NOT FINE TUNED:") for i in no_finetune: print(i) if __name__ == '__main__': main() ================================================ FILE: stanza/utils/visualization/README ================================================ # Overview The code in this directory contains tooling required for Semgrex and Ssurgeon visualization. Searching dependency graphs and manipulating them can be a time consuming and challenging task to get right. Semgrex is a system for searching dependency graphs and Ssurgeon is a system for manipulating the output of Semgrex. The compact language used by these systems allows for easy command line or API processing of dependencies. We now offer Semgrex and Ssurgeon through a web interface, now accessible via Streamlit with visualizations. ## How to run visualizations through Streamlit Streamlit can be used to visualize Semgrex and Ssurgeon results and process files. Here are instructions for setting up a Streamlit webpage: 1. install Streamlit. `pip install streamlit` 2. 
install Stanford CoreNLP if you have not. You can find an installation here: https://stanfordnlp.github.io/CoreNLP/download.html 3. set the $CLASSPATH environment variable to your local installation of CoreNLP. 4. install streamlit, spacy, and ipython. You can use the "visualization" stanza setup option for that 5. Run `streamlit run stanza/utils/visualization/semgrex_app.py --theme.backgroundColor "#FFFFFF"` This should begin a Streamlit runtime application on your local machine that can be interacted with. For instructions on how to use Ssurgeon and Semgrex, refer to these helpful pages: https://aclanthology.org/2023.tlt-1.7.pdf https://nlp.stanford.edu/nlp/javadoc/javanlp-3.5.0/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html https://stanfordnlp.github.io/stanza/client_regex.html https://stanfordnlp.github.io/CoreNLP/corenlp-server.html#query-tokensregex-tokensregex ================================================ FILE: stanza/utils/visualization/__init__.py ================================================ ================================================ FILE: stanza/utils/visualization/conll_deprel_visualization.py ================================================ from stanza.models.common.constant import is_right_to_left import spacy import argparse from spacy import displacy from spacy.tokens import Doc from stanza.utils import conll from stanza.utils.visualization import dependency_visualization as viz def conll_to_visual(conll_file, pipeline, sent_count=10, display_all=False): """ Takes in a conll file and visualizes it by converting the conll file to a Stanza Document object and visualizing it with the visualize_doc method. Input should be a proper conll file. The pipeline for the conll file to be processed in must be provided as well. Optionally, the sent_count argument can be tweaked to display a different amount of sentences. To display all of the sentences in a conll file, the display_all argument can optionally be set to True. 
BEWARE: setting this argument for a large conll file may result in too many renderings, resulting in a crash. """ # convert conll file to doc doc = conll.CoNLL.conll2doc(conll_file) if display_all: viz.visualize_doc(conll.CoNLL.conll2doc(conll_file), pipeline) else: # visualize a given number of sentences visualization_options = {"compact": True, "bg": "#09a3d5", "color": "white", "distance": 100, "font": "Source Sans Pro", "offset_x": 30, "arrow_spacing": 20} # see spaCy visualization settings doc for more options nlp = spacy.blank("en") sentences_to_visualize, rtl, num_sentences = [], is_right_to_left(pipeline), len(doc.sentences) for i in range(sent_count): if i >= num_sentences: # case where there are less sentences than amount requested break sentence = doc.sentences[i] words, lemmas, heads, deps, tags = [], [], [], [], [] sentence_words = sentence.words if rtl: # rtl languages will be visually rendered from right to left as well sentence_words = reversed(sentence.words) sent_len = len(sentence.words) for word in sentence_words: words.append(word.text) lemmas.append(word.lemma) deps.append(word.deprel) tags.append(word.upos) if rtl and word.head == 0: # word heads are off-by-1 in spaCy doc inits compared to Stanza heads.append(sent_len - word.id) elif rtl and word.head != 0: heads.append(sent_len - word.head) elif not rtl and word.head == 0: heads.append(word.id - 1) elif not rtl and word.head != 0: heads.append(word.head - 1) document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags) sentences_to_visualize.append(document_result) print(sentences_to_visualize) for line in sentences_to_visualize: # render all sentences through displaCy displacy.render(line, style="dep", options=visualization_options) def main(): parser = argparse.ArgumentParser() parser.add_argument('--conll_file', type=str, default="C:\\Users\\Alex\\stanza\\demo\\en_test.conllu.txt", help="File path of the CoNLL file to visualize dependencies of") 
parser.add_argument('--pipeline', type=str, default="en", help="Language code of the language pipeline to use (ex: 'en' for English)") parser.add_argument('--sent_count', type=int, default=10, help="Number of sentences to visualize from CoNLL file") parser.add_argument('--display_all', type=bool, default=False, help="Whether or not to visualize all of the sentences from the file. Overrides sent_count if set to True") args = parser.parse_args() conll_to_visual(args.conll_file, args.pipeline, args.sent_count, args.display_all) return if __name__ == "__main__": main() ================================================ FILE: stanza/utils/visualization/constants.py ================================================ """ Constants used for visualization tooling """ # Ssurgeon constants SAMPLE_SSURGEON_DOC = """ # sent_id = 271 # text = Hers is easy to clean. # previous = What did the dealer like about Alex's car? # comment = extraction/raising via "tough extraction" and clausal subject 1 Hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nsubj _ _ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _ 3 easy easy ADJ JJ Degree=Pos 0 root _ _ 4 to to PART TO _ 5 mark _ _ 5 clean clean VERB VB VerbForm=Inf 3 csubj _ SpaceAfter=No 6 . . PUNCT . _ 5 punct _ _ """ # Semgrex constants DEFAULT_SAMPLE_TEXT = "Banning opal removed artifact decks from the meta." 
DEFAULT_SEMGREX_QUERY = "{pos:NN}=object original order while next_word_index <= len(sentence.words) - 1 and sentence.words[next_word_index].text.isascii(): to_append.append(sentence.words[next_word_index].text[::-1]) next_word_index += 1 to_append = reversed(to_append) for token in to_append: words.append(token) already_found = True elif rtl and word.text.isascii() and already_found: # skip over already collected words continue else: # arabic chars words.append(word.text) already_found = False document = Doc(model.vocab, words=words) # tag all NER tokens found for ent in sentence.ents: if select and ent.type not in select: continue found_indexes = [] for token in ent.tokens: found_indexes.append(token.id[0] - 1) if not rtl: to_add = Span(document, found_indexes[0], found_indexes[-1] + 1, ent.type) else: # RTL languages need the override char to flip order to_add = Span(document, found_indexes[0], found_indexes[-1] + 1, RTL_OVERRIDE + ent.type[::-1]) display_ents.append(to_add) document.set_ents(display_ents) documents.append(document) # Visualize doc objects visualization_options = {"ents": select} if colors: visualization_options["colors"] = visualization_colors for document in documents: displacy.render(document, style='ent', options=visualization_options) def visualize_ner_str(text, pipe, select=None, colors=None): """ Takes in a text string and visualizes the named entities within the text. Required args also include a pipeline code, the two-letter code for a language defined by Universal Dependencies (ex: "en" for English). Lastly, the user must provide an NLP pipeline - we recommend Stanza (ex: pipe = stanza.Pipeline('en')). Optionally, the 'select' argument allows for specific NER tags to be highlighted; the 'color' argument allows for specific NER tags to have certain color(s). 
""" doc = pipe(text) visualize_ner_doc(doc, pipe.lang, select, colors) def visualize_strings(texts, language_code, select=None, colors=None): """ Takes in a list of strings and a language code (Stanza defines these, ex: 'en' for English) to visualize all of the strings' named entities. The strings are processed by the Stanza pipeline and the named entities are displayed. Each text is separated by a delimiting line. Optionally, the 'select' argument may be configured to only visualize given named entities (ex: select=['ORG', 'PERSON']). The optional colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be represented as a string (ex: "blue"), a color hex value (ex: #aa9cfc), or as a linear gradient of color values (ex: "linear-gradient(90deg, #aa9cfc, #fc9ce7)"). """ lang_pipe = stanza.Pipeline(language_code, processors="tokenize,ner") for text in texts: visualize_ner_str(text, lang_pipe, select=select, colors=colors) def visualize_docs(docs, language_code, select=None, colors=None): """ Takes in a list of doc and a language code (Stanza defines these, ex: 'en' for English) to visualize all of the strings' named entities. Each text is separated by a delimiting line. Optionally, the 'select' argument may be configured to only visualize given named entities (ex: select=['ORG', 'PERSON']). The optional colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be represented as a string (ex: "blue"), a color hex value (ex: #aa9cfc), or as a linear gradient of color values (ex: "linear-gradient(90deg, #aa9cfc, #fc9ce7)"). """ for doc in docs: visualize_ner_doc(doc, language_code, select=select, colors=colors) def main(): en_strings = ['''Samuel Jackson, a Christian man from Utah, went to the JFK Airport for a flight to New York. He was thinking of attending the US Open, his favorite tennis tournament besides Wimbledon. 
That would be a dream trip, certainly not possible since it is $5000 attendance and 5000 miles away. On the way there, he watched the Super Bowl for 2 hours and read War and Piece by Tolstoy for 1 hour. In New York, he crossed the Brooklyn Bridge and listened to the 5th symphony of Beethoven as well as "All I want for Christmas is You" by Mariah Carey.''', "Barack Obama was born in Hawaii. He was elected President of the United States in 2008"] zh_strings = ['''来自犹他州的基督徒塞缪尔杰克逊前往肯尼迪机场搭乘航班飞往纽约。 他正在考虑参加美国公开赛,这是除了温布尔登之外他最喜欢的网球赛事。 那将是一次梦想之旅,当然不可能,因为它的出勤费为 5000 美元,距离 5000 英里。 在去的路上,他看了 2 个小时的超级碗比赛,看了 1 个小时的托尔斯泰的《战争与碎片》。 在纽约,他穿过布鲁克林大桥,聆听了贝多芬的第五交响曲以及 玛丽亚凯莉的“圣诞节我想要的就是你”。''', "我觉得罗家费德勒住在加州, 在美国里面。"] ar_strings = [ ".أعيش في سان فرانسيسكو ، كاليفورنيا. اسمي أليكس وأنا ألتحق بجامعة ستانفورد. أنا أدرس علوم الكمبيوتر وأستاذي هو كريس مانينغ" , "اسمي أليكس ، أنا من الولايات المتحدة.", '''صامويل جاكسون ، رجل مسيحي من ولاية يوتا ، ذهب إلى مطار جون كنيدي في رحلة إلى نيويورك. كان يفكر في حضور بطولة الولايات المتحدة المفتوحة للتنس ، بطولة التنس المفضلة لديه إلى جانب بطولة ويمبلدون. ستكون هذه رحلة الأحلام ، وبالتأكيد ليست ممكنة لأنها تبلغ 5000 دولار للحضور و 5000 ميل. في الطريق إلى هناك ، شاهد Super Bowl لمدة ساعتين وقرأ War and Piece by Tolstoy لمدة ساعة واحدة. 
في نيويورك ، عبر جسر بروكلين واستمع إلى السيمفونية الخامسة لبيتهوفن وكذلك "كل ما أريده في عيد الميلاد هو أنت" لماريا كاري.'''] visualize_strings(en_strings, "en") visualize_strings(zh_strings, "zh", colors={"PERSON": "yellow", "DATE": "red", "GPE": "blue"}) visualize_strings(zh_strings, "zh", select=['PERSON', 'DATE']) visualize_strings(ar_strings, "ar", colors={"PER": "pink", "LOC": "linear-gradient(90deg, #aa9cfc, #fc9ce7)", "ORG": "yellow"}) if __name__ == "__main__": main() ================================================ FILE: stanza/utils/visualization/semgrex_app.py ================================================ import os import sys import streamlit as st import streamlit.components.v1 as components import stanza.utils.visualization.ssurgeon_visualizer as ssv import logging from stanza.utils.visualization.semgrex_visualizer import visualize_search_str from stanza.utils.visualization.semgrex_visualizer import edit_html_overflow from stanza.utils.visualization.constants import * from stanza.utils.conll import CoNLL from stanza.server.ssurgeon import * from stanza.pipeline.core import Pipeline from io import StringIO import os from typing import List, Tuple, Any import argparse def get_semgrex_text_and_query() -> Tuple[str, str]: """ Gets user input for the Semgrex text and queries to process. @return: A tuple containing the user's input text and their input queries """ input_txt = st.text_area( "Text to analyze", DEFAULT_SAMPLE_TEXT, placeholder=DEFAULT_SAMPLE_TEXT, ) input_queries = st.text_area( "Semgrex search queries (separate each query with a comma)", DEFAULT_SEMGREX_QUERY, placeholder=DEFAULT_SEMGREX_QUERY, ) return input_txt, input_queries def get_file_input() -> List[str]: """ Allows user to submit files for analysis. @return: List of strings containing the file contents of each submitted file. The i-th element of res is the string representing the i-th file uploaded. 
""" st.markdown("""**Alternatively, upload file(s) to analyze.**""") uploaded_files = st.file_uploader( "button_label", accept_multiple_files=True, label_visibility="collapsed" ) res = [] for file in uploaded_files: stringio = StringIO(file.getvalue().decode("utf-8")) string_data = stringio.read() res.append(string_data) return res def get_semgrex_window_input() -> Tuple[bool, int, int]: """ Allows user to specify a specific window of Semgrex hits to visualize. Works similar to Python splicing. @return: A tuple containing a bool representing whether or not the user wants to visualize a splice of the visualizations, and two ints representing the start and end indices of the splice. """ show_window = st.checkbox( "Visualize a specific window of Semgrex search hits?", help="""If you want to visualize all search results, leave this unmarked.""", ) start_window, end_window = None, None if show_window: start_window = st.number_input( "Which search hit should visualizations start from?", help="""If you want to visualize the first 10 search results, set this to 0.""", min_value=0, ) end_window = st.number_input( "Which search hit should visualizations stop on?", help="""If you want to visualize the first 10 search results, set this to 11. The 11th result will NOT be displayed.""", value=11, min_value=start_window + 1, ) return show_window, start_window, end_window def get_pos_input() -> bool: """ Prompts client for whether they want to see xpos tags instead of upos. """ use_xpos = st.checkbox("Would you like to visualize xpos tags?", help="The default visualization options use upos tags for part-of-speech labeling. If xpos tags aren't available for the sentence, displays upos.") return use_xpos def get_input() -> Tuple[str, str, List[str], Tuple[bool, int, int, bool]]: """ Tie together all inputs to query user for all possible inputs. 
""" input_txt, input_queries = get_semgrex_text_and_query() client_files = get_file_input() # this is already converted to string format window_input = get_semgrex_window_input() visualize_xpos = get_pos_input() return input_txt, input_queries, client_files, window_input, visualize_xpos def run_semgrex_process( input_txt: str, input_queries: str, client_files: List[str], show_window: bool, clicked: bool, pipe: Any, start_window: int, end_window: int, visualize_xpos: bool, show_success: bool = True ) -> None: """ Run Semgrex search on the input text/files with input query and serve the HTML on the app. @param input_txt: Text to analyze and draw sentences from. @param input_queries: Semgrex queries to parse the input with. @param client_files: Alternative to input text, we can parse the content of files for scaled analysis. @param show_window: Whether or not the user wants a splice of the visualizations @param clicked: Whether or not the button has been clicked to run Semgrex search @param pipe: NLP pipeline to process input with @param start_window: If displaying a splice of visualizations, this is the start idx @param end_window: If displaying a splice of visualizations, this is the end idx @param visualize_xpos: Set to true if using xpos tags for part of speech labels, otherwise use upos tags """ if clicked: # process inputs, reject bad ones if not input_txt and not client_files: st.error("Please provide a text input or upload files for analysis.") elif input_txt and client_files: st.error( "Please only choose to visualize your input text or your uploaded files, not both." 
) elif not input_queries: st.error("Please provide a set of Semgrex queries.") else: # no input errors try: with st.spinner("Processing..."): queries = [ query.strip() for query in input_queries.split(",") ] # separate queries into individual parts if client_files: html_strings, begin_viz_idx, end_viz_idx = [], 0, float("inf") if show_window: begin_viz_idx, end_viz_idx = ( start_window - 1, end_window - 1, ) for client_file in client_files: client_file_html_strings = visualize_search_str( client_file, queries, "en", start_match=begin_viz_idx, end_match=end_viz_idx, pipe=pipe, visualize_xpos=visualize_xpos ) html_strings += client_file_html_strings else: # just input text, no files if show_window: html_strings = visualize_search_str( input_txt, queries, "en", start_match=start_window - 1, end_match=end_window - 1, pipe=pipe, visualize_xpos=visualize_xpos ) else: html_strings = visualize_search_str( input_txt, queries, "en", end_match=float("inf"), pipe=pipe, visualize_xpos=visualize_xpos ) if len(html_strings) == 0: st.write("No Semgrex match hits!") # Render successful Semgrex results for s in html_strings: s_no_overflow = edit_html_overflow(s) components.html( s_no_overflow, height=200, width=1000, scrolling=True ) if show_success: if len(html_strings) == 1: st.success( f"Completed! Visualized {len(html_strings)} Semgrex search hit." ) else: st.success( f"Completed! Visualized {len(html_strings)} Semgrex search hits." ) except OSError: st.error( "Your text input or your provided Semgrex queries are incorrect. Please try again." ) def semgrex_state(): """ Contains the Semgrex portion of the webpage. This contains the markdown and calls to the processes which run when a query is made. When the `Load Semgrex search visualization` button is pressed, the function `run_semgrex_process` is called inside this function and the rendered visual is placed onto the webpage. """ # Title Markdown for page header st.title("Displaying Semgrex Queries") html_string = ( "

    Enter a text below, along with your Semgrex query of choice.

    " ) st.markdown(html_string, unsafe_allow_html=True) input_txt, input_queries, client_files, window_input, visualize_xpos = get_input() show_window, start_window, end_window = window_input clicked = st.button( "Load Semgrex search visualization", help="""Semgrex search visualizations only display sentences with a query match. Non-matching sentences are not shown.""", ) # use the on_click param run_semgrex_process( input_txt=input_txt, input_queries=input_queries, client_files=client_files, show_window=show_window, clicked=clicked, pipe=st.session_state["pipeline"], start_window=start_window, end_window=end_window, visualize_xpos=visualize_xpos ) def ssurgeon_state(): """ Contains the ssurgeon state for the webpage. This contains the markdown and calls the processes that run Ssurgeon operations. When the text boxes, buttons, or other interactable features are edited by the user, this function runs with the updated page state and conducts operations (e.g. runs a Ssurgeon operation on a submitted file) """ st.title("Displaying Ssurgeon Results") # Textbox for input to SSurgeon (text) input_txt = st.text_area( "Text to analyze", SAMPLE_SSURGEON_DOC, placeholder=SAMPLE_SSURGEON_DOC, ) # Textbox for input queries to SSurgeon (commands + queries) semgrex_input_queries = st.text_area( "Semgrex search queries (separate each query with a comma)", "{}=source >nsubj {} >csubj=bad {}", placeholder="""{}=source >nsubj {} >csubj=bad {}""", ) ssurgeon_input_queries = st.text_area( "Ssurgeon commands", "relabelNamedEdge -edge bad -reln advcl", placeholder="relabelNamedEdge -edge bad -reln advcl" ) # File uploading box st.markdown("""**Alternatively, upload file(s) to edit.**""") uploaded_files = st.file_uploader( "", accept_multiple_files=True, label_visibility="collapsed" ) res = [] # Convert uploaded files to strings for processing for file in uploaded_files: stringio = StringIO(file.getvalue().decode("utf-8")) string_data = stringio.read() res.append(string_data) # Input 
button to trigger processing phase clicked = st.button( "Load Ssurgeon visualization", ) clicked_for_file_edit = st.button( "Edit File" ) # Once the user requests the Ssurgeon operation, run this block: if clicked: try: with st.spinner("Processing..."): semgrex_queries = semgrex_input_queries # separate queries into individual parts ssurgeon_queries = [ssurgeon_input_queries] # use SSurgeon to edit the deprel, get the HTML for new relations html_strings = ssv.visualize_ssurgeon_deprel_adjusted_str_input(input_txt, semgrex_queries, ssurgeon_queries) doc = CoNLL.conll2doc(input_str=input_txt) string_txt = " ".join([word.text for sentence in doc.sentences for word in sentence.words]) # Render pre-edited input html_string = ( "

    Previous deprel visualization:

    " ) st.markdown(html_string, unsafe_allow_html=True) components.html( run_semgrex_process(input_txt=string_txt, input_queries=semgrex_queries, clicked=clicked, show_window=False, client_files=[], pipe=st.session_state["pipeline"], start_window=1, end_window=11, visualize_xpos=False, show_success=False) ) if len(html_strings) == 0: st.write("No Semgrex match hits!") # Render edited outputs for s in html_strings: html_string = ( "

    Edited deprel visualization:

    " ) st.markdown(html_string, unsafe_allow_html=True) s_no_overflow = edit_html_overflow(s) components.html( s_no_overflow, height=200, width=1000, scrolling=True ) except OSError: st.error( "Your text input or your provided Semgrex/Ssurgeon queries are incorrect. Please try again." ) # If the input is a file instead of raw text, process the file with Ssurgeon and give an output # that can be downloaded by the client if clicked_for_file_edit: # files are in res if len(res) == 0: st.error("You must provide files for analysis.") with st.spinner("Editing..."): single_file = res[0] doc = CoNLL.conll2doc(input_str=single_file) ssurgeon_response = process_doc_one_operation(doc, semgrex_input_queries, [ssurgeon_input_queries]) updated_doc = convert_response_to_doc(doc, ssurgeon_response) output = CoNLL.doc2conll(updated_doc)[0] output_str = "\n".join(output) st.download_button("Download your edited file", data=output_str, file_name="SSurgeon.conll") def main(): parser = argparse.ArgumentParser() parser.add_argument( "--CLASSPATH", type=str, default=os.environ.get("CLASSPATH", None), help="Path to your CoreNLP directory.", ) # for example, set $CLASSPATH to "C:\\stanford-corenlp-4.5.2\\stanford-corenlp-4.5.2\\*" args = parser.parse_args() CLASSPATH = args.CLASSPATH os.environ["CLASSPATH"] = CLASSPATH if os.environ.get("CLASSPATH") is None: logging.error("Provide a valid $CLASSPATH value (path to your CoreNLP installation).") raise ValueError("Provide a valid $CLASSPATH value (path to your CoreNLP installation).") # run pipeline once per user session if "pipeline" not in st.session_state: en_nlp_stanza = Pipeline( "en", processors="tokenize, pos, lemma, depparse" ) st.session_state["pipeline"] = en_nlp_stanza #### Below is the webpage states that run. Streamlit operates by having the rendered HTML and when the user interacts with # the page, these states are run once more with their internal states possibly altered (e.g. user clicks a button). 
semgrex_state() ssurgeon_state() if __name__ == "__main__": main() ================================================ FILE: stanza/utils/visualization/semgrex_visualizer.py ================================================ import os import argparse import sys root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) sys.path.append(root_dir) from stanza.pipeline.core import Pipeline from stanza.server.semgrex import Semgrex from stanza.models.common.constant import is_right_to_left import spacy from spacy import displacy from spacy.tokens import Doc from IPython.display import display, HTML import typing from typing import List, Tuple, Any from stanza.utils.visualization.utils import find_nth, round_base def get_sentences_html(doc: Any, language: str, visualize_xpos: bool = False) -> List[str]: """ Returns a list of HTML strings representing the dependency visualizations of a given stanza document. One HTML string is generated per sentence of the document object. Converts the stanza document object to a spaCy doc object and generates HTML with displaCy. @param doc: a stanza document object which can be generated with an NLP pipeline. @param language: the two letter language code for the document e.g. "en" for English. @param visualize_xpos: A toggled option to use xpos tags for part-of-speech labels instead of upos. @return: a list of HTML strings which visualize the dependencies of the doc object. 
""" USE_FINE_GRAINED = False if not visualize_xpos else True html_strings, sentences_to_visualize = [], [] nlp = spacy.blank( "en" ) # blank model - we don't use any of the model features, just the visualization for sentence in doc.sentences: words, lemmas, heads, deps, tags = [], [], [], [], [] if is_right_to_left( language ): # order of words displayed is reversed, dependency arcs remain intact sentence_len = len(sentence.words) for word in reversed(sentence.words): words.append(word.text) lemmas.append(word.lemma) deps.append(word.deprel) if visualize_xpos and word.xpos: tags.append(word.xpos) else: tags.append(word.upos) if word.head == 0: # spaCy head indexes are one-off from Stanza's heads.append(sentence_len - word.id) else: heads.append(sentence_len - word.head) else: # left to right rendering for word in sentence.words: words.append(word.text) lemmas.append(word.lemma) deps.append(word.deprel) if visualize_xpos and word.xpos: tags.append(word.xpos) else: tags.append(word.upos) if word.head == 0: heads.append(word.id - 1) else: heads.append(word.head - 1) if USE_FINE_GRAINED: stanza_to_spacy_doc = Doc( nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, tags=tags ) else: stanza_to_spacy_doc = Doc( nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags ) sentences_to_visualize.append(stanza_to_spacy_doc) for line in sentences_to_visualize: # render all sentences through displaCy html_strings.append( displacy.render( line, style="dep", options={ "compact": True, "word_spacing": 30, "distance": 100, "arrow_spacing": 20, "fine_grained": USE_FINE_GRAINED }, jupyter=False, ) ) return html_strings def semgrexify_html(orig_html: str, semgrex_sentence) -> str: """ Modifies the HTML of a sentence's dependency visualization, highlighting words involved in the semgrex_sentence search queries and adding the label of the word inside of the match. @param orig_html: unedited HTML of a sentence's dependency visualization. 
@param semgrex_sentence: a Semgrex result object containing the matches to a provided query. @return: edited HTML containing the visual changes described above. """ tracker = {} # keep track of which words have multiple labels DEFAULT_TSPAN_COUNT = ( 2 # the original displacy html assigns two objects per object ) CLOSING_TSPAN_LEN = 8 # is 8 chars long colors = [ "#4477AA", "#66CCEE", "#228833", "#CCBB44", "#EE6677", "#AA3377", "#BBBBBB", ] # colorblind-friendly scheme css_bolded_class = "\n" opening_svg_end_idx = orig_html.find("\n") # insert the new style class orig_html = ( orig_html[: opening_svg_end_idx + 1] + css_bolded_class + orig_html[opening_svg_end_idx + 1 :] ) # Color and bold words involved in each Semgrex match for query in semgrex_sentence.result: for i, match in enumerate(query.match): color = colors[i] paired_dy = 2 for node in match.node: name, match_index = node.name, node.matchIndex # edit existing to change color and bold the text start = find_nth( orig_html, " of interest if ( match_index not in tracker ): # if we've already bolded and colored, keep the first color tspan_start = orig_html.find( " inside of the tspan_end = orig_html.find( "", start ) # finds start of the end of the above tspan_substr = ( orig_html[tspan_start : tspan_end + CLOSING_TSPAN_LEN + 1] + "\n" ) # color and bold words in the search hit edited_tspan = tspan_substr.replace( 'class="displacy-word"', 'class="bolded"' ).replace('fill="currentColor"', f'fill="{color}"') # insert edited object into html string # TODO: DEBUG. This code has a bug in it that causes the svg to not end on an input like # "The Wimbledon grass-court tennis tournament banned players, resulting in players hating others." # to malfunction and add another copy to the tail-end of the first svg rendering. # This bug has been patched in the end of this function, but need to find out what is going on. 
orig_html = ( orig_html[:tspan_start] + edited_tspan + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2 :] ) tracker[match_index] = DEFAULT_TSPAN_COUNT # next, we have to insert the new object for the label # Copy old to copy formatting when creating new later prev_tspan_start = ( find_nth(orig_html[start:], " start index prev_tspan_end = ( find_nth(orig_html[start:], "", tracker[match_index] - 1) + start ) # find the prev start index prev_tspan = orig_html[ prev_tspan_start : prev_tspan_end + CLOSING_TSPAN_LEN + 1 ] # Find spot to insert new tspan closing_tspan_start = ( find_nth(orig_html[start:], "", tracker[match_index]) + start ) up_to_new_tspan = orig_html[ : closing_tspan_start + CLOSING_TSPAN_LEN + 1 ] rest = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1 :] # Calculate proper x value in svg x_value_start = prev_tspan.find('x="') x_value_end = ( prev_tspan[x_value_start + 3 :].find('"') + 3 ) # 3 is the length of the 'x="' substring x_value = prev_tspan[x_value_start + 3 : x_value_end + x_value_start] # Calculate proper y value in svg DEFAULT_DY_VAL, dy = 2, 2 if ( paired_dy != DEFAULT_DY_VAL and node == match.node[1] ): # we're on the second node and need to adjust height to match the paired node dy = paired_dy if node == match.node[0] and len(match.node) > 1: paired_node_level = 2 if ( match.node[1].matchIndex in tracker ): # check if we need to adjust heights of labels paired_node_level = tracker[match.node[1].matchIndex] dif = tracker[match_index] - paired_node_level if dif > 0: # current node has more labels paired_dy = DEFAULT_DY_VAL * dif + 1 dy = DEFAULT_DY_VAL else: # paired node has more labels, adjust this label down dy = DEFAULT_DY_VAL * (abs(dif) + 1) paired_dy = DEFAULT_DY_VAL # Insert new object new_tspan = f' {name[: 3].title()}.\n' # abbreviate label names to 3 chars orig_html = up_to_new_tspan + new_tspan + rest tracker[match_index] += 1 # process out extra term if present -- TODO: Figure out why the semgrexify_html function lines 
164-168 cause a duplication bug end = find_nth(haystack=orig_html, needle=" has length 6 so add 1 to the end too if len(orig_html) > end + LENGTH_OF_END_SVG: orig_html = orig_html[: end + LENGTH_OF_END_SVG] return orig_html def render_html_strings(edited_html_strings: List[str]) -> None: """ Renders the HTML of each HTML string. """ for html_string in edited_html_strings: display(HTML(html_string)) def visualize_search_doc( doc: Any, semgrex_queries: List[str], lang_code: str, start_match: int = 0, end_match: int = 11, render: bool = True, visualize_xpos: bool = False ) -> List[str]: """ Visualizes the result of running Semgrex search on a document. The i-th element of the returned list is the HTML representation of the i-th sentence's dependency relationships. Only shows sentences that have a match on the Semgrex search. @param doc: A Stanza document object that contains dependency relationships . @param semgrex_queries: A list of Semgrex queries to search for in the document. @param lang_code: A two letter language abbreviation for the language that the Stanza document is written in. @param start_match: Beginning of the splice for which to display elements with. @param end_match: End of the splice for which to display elements with. @param render: A toggled option to render the HTML strings within the returned list @param visualize_xpos: A toggled option to use xpos tags in part-of-speech labels, defaulting to upos tags. @return: A list of HTML strings representing the dependency relations of the doc object. 
""" matches_count = 0 # Limits number of visualizations with Semgrex(classpath="$CLASSPATH") as sem: edited_html_strings = [] semgrex_results = sem.process(doc, *semgrex_queries) # one html string for each sentence unedited_html_strings = get_sentences_html(doc, lang_code, visualize_xpos=visualize_xpos) for semgrex_result in semgrex_results.result: if matches_count >= end_match: # we've collected enough matches break # read the sentence_idx off the matches, # in case they came back in an unexpected order sentence_idx = None for sentence_result in semgrex_result.result: for match in sentence_result.match: sentence_idx = match.sentenceIndex break # don't count empty match objects as having matched if sentence_idx is None: continue if start_match <= matches_count < end_match: unedited_html_string = unedited_html_strings[sentence_idx] edited_string = semgrexify_html( unedited_html_string, semgrex_result ) edited_string = adjust_dep_arrows(edited_string) edited_html_strings.append(edited_string) matches_count += 1 if render: render_html_strings(edited_html_strings) return edited_html_strings def visualize_search_str( text: str, semgrex_queries: List[str], lang_code: str, start_match: int = 0, end_match: int = 11, pipe=None, render: bool = True, visualize_xpos: bool = False ): """ Visualizes the result of running Semgrex search on a string. The i-th element of the returned list is the HTML representation of the i-th sentence's dependency relationships. Only shows sentences that have a match on the Semgrex search. @param text: The string for which Semgrex search will be run on. @param semgrex_queries: A list of Semgrex queries to search for in the document. @param lang_code: A two letter language abbreviation for the language that the Stanza document is written in. @param start_match: Beginning of the splice for which to display elements with. @param end_match: End of the splice for which to display elements with. 
@param pipe: An NLP pipeline through which the text will be processed. @param render: A toggled option to render the HTML strings within the returned list. @param visualize_xpos: A toggled option to use xpos tags for part-of-speech labeling, defaulting to upos tags @return: A list of HTML strings representing the dependency relations of the doc object. """ if pipe is None: nlp = Pipeline(lang_code, processors="tokenize, pos, lemma, depparse") else: nlp = pipe doc = nlp(text) return visualize_search_doc( doc, semgrex_queries, lang_code, start_match=start_match, end_match=end_match, render=render, visualize_xpos=visualize_xpos ) def adjust_dep_arrows(raw_html: str) -> str: """ Default spaCy dependency visualizations have misaligned arrows. Fix arrows by aligning arrow ends and bodies to the word that they are directed to. @param raw_html: Dependency relation visualization generated HTML from displaCy @return: Edited HTML string with fixed arrow placements """ HTML_ARROW_BEGINNING = '' HTML_ARROW_ENDING = "" HTML_ARROW_ENDING_LEN = 6 # there are 2 newline chars after the arrow ending arrows_start_idx = find_nth( haystack=raw_html, needle='', n=1 ) words_html, arrows_html = ( raw_html[:arrows_start_idx], raw_html[arrows_start_idx:], ) # separate html for words and arrows final_html = ( words_html # continually concatenate to this after processing each arrow ) arrow_number = 1 # which arrow we're currently editing (1-indexed) start_idx, end_of_class_idx = ( find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number), find_nth(haystack=arrows_html, needle=HTML_ARROW_ENDING, n=arrow_number), ) while start_idx != -1: # edit every arrow arrow_section = arrows_html[ start_idx : end_of_class_idx + HTML_ARROW_ENDING_LEN ] # slice a single svg arrow object if ( arrow_section[-1] == "<" ): # this is the last arrow in the HTML, don't cut the splice early arrow_section = arrows_html[start_idx:] edited_arrow_section = edit_dep_arrow(arrow_section) final_html = ( 
final_html + edited_arrow_section ) # continually update html with new arrow html until done # Prepare for next iteration arrow_number += 1 start_idx = find_nth(arrows_html, '', arrow_number) end_of_class_idx = find_nth(arrows_html, "", arrow_number) return final_html def edit_dep_arrow(arrow_html: str) -> str: """ The formatting of a single displacy arrow in svg is the following: csubj We edit the 'd = ...' parts of the section to fix the arrow direction and length to round to the nearest 50 units, centering on each word's center. This is because the words start at x=50 and have spacing of 100, so each word is at an x-value that is a multiple of 50. @param arrow_html: Original SVG for a single displaCy arrow. @return: Edited SVG for the displaCy arrow, adjusting its placement """ WORD_SPACING = 50 # words start at x=50 and are separated by 100s so their x values are multiples of 50 M_OFFSET = 4 # length of 'd="M' that we search for to extract the number from d="M70, for instance ARROW_PIXEL_SIZE = 4 first_d_idx, second_d_idx = ( find_nth(arrow_html, 'd="M', 1), find_nth(arrow_html, 'd="M', 2), ) # find where d="M starts first_d_cutoff, second_d_cutoff = ( arrow_html.find(",", first_d_idx), arrow_html.find(",", second_d_idx), ) # isolate the number after 'M' e.g. 'M70' # gives svg x values of arrow body starting position and arrowhead position arrow_position, arrowhead_position = ( float(arrow_html[first_d_idx + M_OFFSET : first_d_cutoff]), float(arrow_html[second_d_idx + M_OFFSET : second_d_cutoff]), ) # gives starting index of where 'fill="none"' or 'fill="currentColor"' begin, reference points to end the d= section first_fill_start_idx, second_fill_start_idx = ( find_nth(arrow_html, "fill", n=1), find_nth(arrow_html, "fill", n=3), ) # isolate the d= ... 
section to edit first_d, second_d = ( arrow_html[first_d_idx:first_fill_start_idx], arrow_html[second_d_idx:second_fill_start_idx], ) first_d_split, second_d_split = first_d.split(","), second_d.split(",") if ( arrow_position == arrowhead_position ): # This arrow is incoming onto the word, center the arrow/head to word center corrected_arrow_pos = corrected_arrowhead_pos = round_base( arrow_position, base=WORD_SPACING ) # edit first_d -- arrow body second_term = first_d_split[1].split(" ")[0] + " " + str(corrected_arrow_pos) first_d = ( 'd="M' + str(corrected_arrow_pos) + "," + second_term + "," + ",".join(first_d_split[2:]) ) # edit second_d -- arrowhead second_term = ( second_d_split[1].split(" ")[0] + " L" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE) ) third_term = ( second_d_split[2].split(" ")[0] + " " + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE) ) second_d = ( 'd="M' + str(corrected_arrowhead_pos) + "," + second_term + "," + third_term + "," + ",".join(second_d_split[3:]) ) else: # This arrow is outgoing to another word, center the arrow/head to that word's center corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING) # edit first_d -- arrow body third_term = first_d_split[2].split(" ")[0] + " " + str(corrected_arrowhead_pos) fourth_term = ( first_d_split[3].split(" ")[0] + " " + str(corrected_arrowhead_pos) ) terms = [ first_d_split[0], first_d_split[1], third_term, fourth_term, ] + first_d_split[4:] first_d = ",".join(terms) # edit second_d -- arrow head first_term = f'd="M{corrected_arrowhead_pos}' second_term = ( second_d_split[1].split(" ")[0] + " L" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE) ) third_term = ( second_d_split[2].split(" ")[0] + " " + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE) ) terms = [first_term, second_term, third_term] + second_d_split[3:] second_d = ",".join(terms) # rebuild and return html from its individual sections return ( arrow_html[:first_d_idx] + first_d + " " + 
arrow_html[first_fill_start_idx:second_d_idx] + second_d + " " + arrow_html[second_fill_start_idx:] ) def edit_html_overflow(html_string: str) -> str: """ Adds to overflow and display settings to the SVG header to visualize overflowing HTML renderings in the Semgrex streamlit app. Prevents Semgrex search tags from being cut off at the bottom of visualizations. The opening of each HTML string looks similar to this; we add to the end of the SVG header. Banning VERB Act. @param html_string: HTML of the result of running Semgrex search on a text @return: Edited HTML to visualize the dependencies even in the case of overflow. """ BUFFER_LEN = 14 # length of 'direction: ltr"' editing_start_idx = find_nth(html_string, "direction: ltr", n=1) SVG_HEADER_ADDITION = "overflow: visible; display: block" return ( html_string[:editing_start_idx] + "; " + SVG_HEADER_ADDITION + html_string[editing_start_idx + BUFFER_LEN :] ) def main(): """ IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. Additionally, set an environment variable CLASSPATH equal to the path of your corenlp directory. Example: CLASSPATH=C:\\Users\\Alex\\PycharmProjects\\pythonProject\\stanford-corenlp-4.5.0\\stanford-corenlp-4.5.0\\* """ nlp = Pipeline("en", processors="tokenize,pos,lemma,depparse") doc = nlp( "Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people." ) queries = [ "{pos:NN}=object = 0 and n > 1: start = haystack.find(needle, start + len(needle)) n -= 1 return start def round_base(num, base=10): """ Rounding a number to its nearest multiple of the base. round_base(49.2, base=50) = 50. """ return base * round(num / base)