[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: bug\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Environment (please complete the following information):**\n - OS: [e.g. Windows, Ubuntu, CentOS, MacOS]\n - Python version: [e.g. Python 3.6.8 from Anaconda]\n - Stanza version: [e.g., 1.0.0]\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: enhancement\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.md",
    "content": "---\nname: Question\nabout: 'Question about general usage. '\ntitle: \"[QUESTION]\"\nlabels: question\nassignees: ''\n\n---\n\nBefore you start, make sure to check out:\n* Our documentation: https://stanfordnlp.github.io/stanza/\n* Our FAQ: https://stanfordnlp.github.io/stanza/faq.html\n* GitHub issues (especially closed ones)\nYour question might have an answer in these places!\n\nIf you still couldn't find the answer to your question, feel free to delete this text and write down your question. The more information you provide with your question, the faster we will be able to help you!\n\nIf you have a question about an issue you're facing when using Stanza, please try to provide a detailed step-by-step guide to reproduce it. Try to at least provide a minimal code sample that reproduces the problem, instead of just describing it. That would greatly help us locate the issue faster and help you resolve it!\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "**BEFORE YOU START**: please make sure your pull request is against the `dev` branch. \nWe cannot accept pull requests against the `main` branch. \nSee our [contributing guide](https://github.com/stanfordnlp/stanza/blob/main/CONTRIBUTING.md) for details.\n\n## Description\nA brief and concise description of what your pull request is trying to accomplish.\n\n## Fixes Issues\nA list of issues/bugs with # references. (e.g., #123)\n\n## Unit test coverage\nAre there unit tests in place to make sure your code is functioning correctly?\n(see [here](https://github.com/stanfordnlp/stanza/blob/master/tests/test_tagger.py) for a simple example)\n\n## Known breaking changes/behaviors\nDoes this break anything in Stanza's existing user interface? If so, what is it and how is it addressed?\n"
  },
  {
    "path": ".github/stale.yml",
    "content": "# Number of days of inactivity before an issue becomes stale\ndaysUntilStale: 60\n# Number of days of inactivity before a stale issue is closed\ndaysUntilClose: 7\n# Issues with these labels will never be considered stale\nexemptLabels:\n  - pinned\n  - security\n  - fixed on dev\n  - bug\n  - enhancement\n# Label to use when marking an issue as stale\nstaleLabel: stale\n# Comment to post when marking an issue as stale. Set to `false` to disable\nmarkComment: >\n  This issue has been automatically marked as stale because it has not had\n  recent activity. It will be closed if no further activity occurs. Thank you\n  for your contributions.\n# Comment to post when closing a stale issue. Set to `false` to disable\ncloseComment: >\n  This issue has been automatically closed due to inactivity.\n"
  },
  {
    "path": ".github/workflows/stanza-tests.yaml",
    "content": "name: Run Stanza Tests\non: [push]\njobs:\n  Run-Stanza-Tests:\n    runs-on: self-hosted\n    steps:\n      - run: echo \"🎉 The job was automatically triggered by a ${{ github.event_name }} event.\"\n      - run: echo \"🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!\"\n      - run: echo \"🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}.\"\n      - name: Check out repository code\n        uses: actions/checkout@v2\n      - run: echo \"💡 The ${{ github.repository }} repository has been cloned to the runner.\"\n      - run: echo \"🖥️ The workflow is now ready to test your code on the runner.\"\n      - name: Run Stanza Tests\n        run: |\n          # set up environment\n          echo \"Setting up environment...\"\n          bash\n          #. $CONDA_PREFIX/etc/profile.d/conda.sh\n          . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh\n          conda activate stanza\n          export STANZA_TEST_HOME=/scr/stanza_test\n          export CORENLP_HOME=$STANZA_TEST_HOME/corenlp_dir\n          export CLASSPATH=$CORENLP_HOME/*:\n          echo CORENLP_HOME=$CORENLP_HOME\n          echo CLASSPATH=$CLASSPATH\n          # install from stanza repo being evaluated\n          echo PWD: $PWD\n          echo PATH: $PATH\n          pip3 install -e .\n          pip3 install -e .[test]\n          pip3 install -e .[transformers]\n          pip3 install -e .[tokenizers]\n          pip3 install -e .[morphseg]\n          # set up for tests\n          echo \"Running stanza test set up...\"\n          rm -rf $STANZA_TEST_HOME\n          python3 stanza/tests/setup.py\n          # run tests\n          echo \"Running tests...\"\n          export CUDA_VISIBLE_DEVICES=2\n          pytest stanza/tests\n          \n      - run: echo \"🍏 This job's status is ${{ job.status }}.\"\n"
  },
  {
    "path": ".gitignore",
    "content": "# kept from original\n.DS_Store\n*.tmp\n*.pkl\n*.conllu\n*.lem\n*.toklabels\n\n# also data w/o any slash to account for symlinks\ndata\ndata/\nstanza_resources/\nstanza_test/\nsaved_models/\nlogs/\nlog/\n*_test_treebanks\nwandb/\n\nparams/*/*.json\n!params/*/default.json\n\n# emacs backup files\n*~\n# VI backup files?\n*py.swp\n\n# standard github python project gitignore\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#  
 install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# IDE-related\n.vscode/\n\n.idea/vcs.xml\n.idea/inspectionProfiles/profiles_settings.xml\n.idea/workspace.xml\n\n# Jekyll stuff, triggered by running the docs locally\n.jekyll-cache/\n.jekyll-metadata\n_site/\n\n# symlink / directory for data files\nextern_data\n"
  },
  {
    "path": ".travis.yml",
    "content": "language: python\npython:\n  - 3.6.5\nnotifications:\n  email: false\ninstall:\n  - pip install --quiet .\n  - export CORENLP_HOME=~/corenlp-latest CORENLP_VERSION=stanford-corenlp-latest\n  - export CORENLP_URL=\"http://nlp.stanford.edu/software/${CORENLP_VERSION}.zip\"\n  - wget $CORENLP_URL -O corenlp-latest.zip\n  - unzip corenlp-latest.zip > unzip.log\n  - export CORENLP_UNZIP=`grep creating unzip.log | head -n 1 | cut -d \":\" -f 2`\n  - mv $CORENLP_UNZIP $CORENLP_HOME\n  - mkdir ~/stanza_test\n  - mkdir ~/stanza_test/in\n  - mkdir ~/stanza_test/out\n  - mkdir ~/stanza_test/scripts\n  - cp tests/data/external_server.properties ~/stanza_test/scripts\n  - cp tests/data/example_french.json ~/stanza_test/out\n  - cp tests/data/tiny_emb.* ~/stanza_test/in\n  - export STANZA_TEST_HOME=~/stanza_test\nscript:\n  - python -m pytest -m travis tests/\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to Stanza\n\nWe would love to see contributions to Stanza from the community! Contributions that we welcome include bugfixes and enhancements. If you want to report a bug or suggest a feature but don't intend to fix or implement it yourself, please create a corresponding issue on [our issues page](https://github.com/stanfordnlp/stanza/issues). If you plan to contribute a bugfix or enhancement, please read the following.\n\n## 🛠️ Bugfixes\n\nFor bugfixes, please follow these steps:\n\n- Make sure a fix does not already exist, by searching through existing [issues](https://github.com/stanfordnlp/stanza/issues) (including closed ones) and [pull requests](https://github.com/stanfordnlp/stanza/pulls).\n- Confirm the bug with us by creating a bug-report issue. In your issue, you should at least include the platform and environment that you are running with, and a minimal code snippet that will reproduce the bug.\n- Once the bug is confirmed, you can go ahead with implementing the bugfix, and create a pull request **against the `dev` branch**.\n\n## 💡 Enhancements\n\nFor enhancements, please follow these steps:\n\n- Make sure a similar enhancement suggestion does not already exist, by searching through existing [issues](https://github.com/stanfordnlp/stanza/issues).\n- Create a feature-request issue and discuss this enhancement with us. We'll need to make sure this enhancement won't break the existing user interface or functionality.\n- Once the enhancement is confirmed with us, you can go ahead with implementing it, and create a pull request **against the `dev` branch**.\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2019 The Board of Trustees of The Leland Stanford Junior University\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\"><img src=\"https://github.com/stanfordnlp/stanza/raw/dev/images/stanza-logo.png\" height=\"100px\"/></div>\n\n<h2 align=\"center\">Stanza: A Python NLP Library for Many Human Languages</h2>\n\n<div align=\"center\">\n    <a href=\"https://github.com/stanfordnlp/stanza/actions\">\n       <img alt=\"Run Tests\" src=\"https://github.com/stanfordnlp/stanza/actions/workflows/stanza-tests.yaml/badge.svg\">\n    </a>\n    <a href=\"https://pypi.org/project/stanza/\">\n        <img alt=\"PyPI Version\" src=\"https://img.shields.io/pypi/v/stanza?color=blue\">\n    </a>\n    <a href=\"https://anaconda.org/stanfordnlp/stanza\">\n        <img alt=\"Conda Versions\" src=\"https://img.shields.io/conda/vn/stanfordnlp/stanza?color=blue&label=conda\">\n    </a>\n    <a href=\"https://pypi.org/project/stanza/\">\n        <img alt=\"Python Versions\" src=\"https://img.shields.io/pypi/pyversions/stanza?colorB=blue\">\n    </a>\n</div>\n\nThe Stanford NLP Group's official Python NLP library. It contains support for running various accurate natural language processing tools on 60+ languages and for accessing the Java Stanford CoreNLP software from Python. For detailed information please visit our [official website](https://stanfordnlp.github.io/stanza/).\n\n🔥 &nbsp;A new collection of **biomedical** and **clinical** English model packages is now available, offering a seamless experience for syntactic analysis and named entity recognition (NER) from biomedical literature text and clinical notes. 
For more information, check out our [Biomedical models documentation page](https://stanfordnlp.github.io/stanza/biomed.html).\n\n### References\n\nIf you use this library in your research, please kindly cite our [ACL2020 Stanza system demo paper](https://arxiv.org/abs/2003.07082):\n\n```bibtex\n@inproceedings{qi2020stanza,\n    title={Stanza: A {Python} Natural Language Processing Toolkit for Many Human Languages},\n    author={Qi, Peng and Zhang, Yuhao and Zhang, Yuhui and Bolton, Jason and Manning, Christopher D.},\n    booktitle = \"Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations\",\n    year={2020}\n}\n```\n\nIf you use our biomedical and clinical models, please also cite our [Stanza Biomedical Models description paper](https://arxiv.org/abs/2007.14640):\n\n```bibtex\n@article{zhang2021biomedical,\n    author = {Zhang, Yuhao and Zhang, Yuhui and Qi, Peng and Manning, Christopher D and Langlotz, Curtis P},\n    title = {Biomedical and clinical {E}nglish model packages for the {S}tanza {P}ython {NLP} library},\n    journal = {Journal of the American Medical Informatics Association},\n    year = {2021},\n    month = {06},\n    issn = {1527-974X}\n}\n```\n\nThe PyTorch implementation of the neural pipeline in this repository is due to [Peng Qi](http://qipeng.me) (@qipeng), [Yuhao Zhang](http://yuhao.im) (@yuhaozhang), and [Yuhui Zhang](https://cs.stanford.edu/~yuhuiz/) (@yuhui-zh15), with help from [Jason Bolton](mailto:jebolton@stanford.edu) (@j38), [Tim Dozat](https://web.stanford.edu/~tdozat/) (@tdozat) and [John Bauer](https://www.linkedin.com/in/john-bauer-b3883b60/) (@AngledLuffa). 
Maintenance of this repo is currently led by [John Bauer](https://www.linkedin.com/in/john-bauer-b3883b60/).\n\nIf you use the CoreNLP software through Stanza, please cite the CoreNLP software package and the respective modules as described [here](https://stanfordnlp.github.io/CoreNLP/#citing-stanford-corenlp-in-papers) (\"Citing Stanford CoreNLP in papers\"). The CoreNLP client is mostly written by [Arun Chaganty](http://arun.chagantys.org/), and [Jason Bolton](mailto:jebolton@stanford.edu) spearheaded merging the two projects together.\n\nIf you use the Semgrex or Ssurgeon part of CoreNLP, please cite [our GURT paper on Semgrex and Ssurgeon](https://aclanthology.org/2023.tlt-1.7/):\n\n```bibtex\n@inproceedings{bauer-etal-2023-semgrex,\n    title = \"Semgrex and Ssurgeon, Searching and Manipulating Dependency Graphs\",\n    author = \"Bauer, John  and\n      Kiddon, Chlo{\\'e}  and\n      Yeh, Eric  and\n      Shan, Alex  and\n      D. Manning, Christopher\",\n    booktitle = \"Proceedings of the 21st International Workshop on Treebanks and Linguistic Theories (TLT, GURT/SyntaxFest 2023)\",\n    month = mar,\n    year = \"2023\",\n    address = \"Washington, D.C.\",\n    publisher = \"Association for Computational Linguistics\",\n    url = \"https://aclanthology.org/2023.tlt-1.7\",\n    pages = \"67--73\",\n    abstract = \"Searching dependency graphs and manipulating them can be a time consuming and challenging task to get right. We document Semgrex, a system for searching dependency graphs, and introduce Ssurgeon, a system for manipulating the output of Semgrex. The compact language used by these systems allows for easy command line or API processing of dependencies. 
Additionally, integration with publicly released toolkits in Java and Python allows for searching text relations and attributes over natural text.\",\n}\n```\n\n## Issues and Usage Q&A\n\nTo ask questions, report issues or request features 🤔, please use the [GitHub Issue Tracker](https://github.com/stanfordnlp/stanza/issues). Before creating a new issue, please make sure to search for existing issues that may solve your problem, or visit the [Frequently Asked Questions (FAQ) page](https://stanfordnlp.github.io/stanza/faq.html) on our website.\n\n## Contributing to Stanza\n\nWe welcome community contributions to Stanza in the form of bugfixes 🛠️ and enhancements 💡! If you want to contribute, please first read [our contribution guideline](CONTRIBUTING.md).\n\n## Installation\n\n### pip\n\nStanza supports Python 3.6 or later. We recommend that you install Stanza via [pip](https://pip.pypa.io/en/stable/installing/), the Python package manager. To install, simply run:\n```bash\npip install stanza\n```\nThis should also help resolve all of the dependencies of Stanza, for instance [PyTorch](https://pytorch.org/) 1.3.0 or above.\n\nIf you currently have a previous version of `stanza` installed, use:\n```bash\npip install stanza -U\n```\n\n### Anaconda\n\nTo install Stanza via Anaconda, use the following conda command:\n\n```bash\nconda install -c stanfordnlp stanza\n```\n\nNote that for now installing Stanza via Anaconda does not work for Python 3.10. For Python 3.10 please use pip installation.\n\n### From Source\n\nAlternatively, you can also install from source of this git repository, which will give you more flexibility in developing on top of Stanza. 
For this option, run\n```bash\ngit clone https://github.com/stanfordnlp/stanza.git\ncd stanza\npip install -e .\n```\n\n## Running Stanza\n\n### Getting Started with the neural pipeline\n\nTo run your first Stanza pipeline, simply follow these steps in your Python interactive interpreter:\n\n```python\n>>> import stanza\n>>> stanza.download('en')       # Optional: pre-download English models (Pipeline can auto-download if needed)\n>>> nlp = stanza.Pipeline('en') # This sets up a default neural pipeline in English\n>>> doc = nlp(\"Barack Obama was born in Hawaii. He was elected president in 2008.\")\n>>> doc.sentences[0].print_dependencies()\n```\n\nIf you encounter `requests.exceptions.ConnectionError`, please try to use a proxy:\n\n```python\n>>> import stanza\n>>> proxies = {'http': 'http://ip:port', 'https': 'http://ip:port'}\n>>> stanza.download('en', proxies=proxies)  # Optional: pre-download English models (Pipeline can auto-download if needed)\n>>> nlp = stanza.Pipeline('en')             # This sets up a default neural pipeline in English\n>>> doc = nlp(\"Barack Obama was born in Hawaii. He was elected president in 2008.\")\n>>> doc.sentences[0].print_dependencies()\n```\n\nThe last command will print out the words in the first sentence in the input string (or [`Document`](https://stanfordnlp.github.io/stanza/data_objects.html#document), as it is represented in Stanza), as well as the indices for the word that governs it in the Universal Dependencies parse of that sentence (its \"head\"), along with the dependency relation between the words. 
The output should look like:\n\n```\n('Barack', '4', 'nsubj:pass')\n('Obama', '1', 'flat')\n('was', '4', 'aux:pass')\n('born', '0', 'root')\n('in', '6', 'case')\n('Hawaii', '4', 'obl')\n('.', '4', 'punct')\n```\n\nSee [our getting started guide](https://stanfordnlp.github.io/stanza/installation_usage.html#getting-started) for more details.\n\n### Accessing Java Stanford CoreNLP software\n\nAside from the neural pipeline, this package also includes an official wrapper for accessing the Java Stanford CoreNLP software with Python code.\n\nThere are a few initial setup steps.\n\n* Download [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/) and models for the language you wish to use\n* Put the model jars in the distribution folder\n* Tell the Python code where Stanford CoreNLP is located by setting the `CORENLP_HOME` environment variable (e.g., in *nix): `export CORENLP_HOME=/path/to/stanford-corenlp-4.5.3`\n\nWe provide [comprehensive examples](https://stanfordnlp.github.io/stanza/corenlp_client.html) in our documentation that show how one can use CoreNLP through Stanza and extract various annotations from it.\n\n### Online Colab Notebooks\n\nTo get you started, we also provide interactive Jupyter notebooks in the `demo` folder. You can also open these notebooks and run them interactively on [Google Colab](https://colab.research.google.com). To view all available notebooks, follow these steps:\n\n* Go to the [Google Colab website](https://colab.research.google.com)\n* Navigate to `File` -> `Open notebook`, and choose `GitHub` in the pop-up menu\n* Note that you do **not** need to give Colab access permission to your GitHub account\n* Type `stanfordnlp/stanza` in the search bar, and press Enter\n\n### Trained Models for the Neural Pipeline\n\nWe currently provide models for all of the [Universal Dependencies](https://universaldependencies.org/) treebanks v2.8, as well as NER models for a few widely-spoken languages. 
You can find instructions for downloading and using these models [here](https://stanfordnlp.github.io/stanza/models.html).\n\n### Batching To Maximize Pipeline Speed\n\nTo maximize speed performance, it is essential to run the pipeline on batches of documents. Running a for loop on one sentence at a time will be very slow. The best approach at this time is to concatenate documents together, with each document separated by a blank line (i.e., two line breaks `\\n\\n`).  The tokenizer will recognize blank lines as sentence breaks. We are actively working on improving multi-document processing.\n\n## Training your own neural pipelines\n\nAll neural modules in this library can be trained with your own data. The tokenizer, the multi-word token (MWT) expander, the POS/morphological features tagger, the lemmatizer and the dependency parser require [CoNLL-U](https://universaldependencies.org/format.html) formatted data, while the NER model requires the BIOES format. Currently, we do not support model training via the `Pipeline` interface. Therefore, to train your own models, you need to clone this git repository and run training from the source.\n\nFor detailed step-by-step guidance on how to train and evaluate your own models, please visit our [training documentation](https://stanfordnlp.github.io/stanza/training.html).\n\n## LICENSE\n\nStanza is released under the Apache License, Version 2.0. See the [LICENSE](https://github.com/stanfordnlp/stanza/blob/master/LICENSE) file for more details.\n"
  },
  {
    "path": "demo/CONLL_Dependency_Visualizer_Example.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c0fd86c8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\\n\",\n    \"\\n\",\n    \"# load necessary conllu files - expected to be in the demo directory along with the notebook\\n\",\n    \"en_file = \\\"en_test.conllu.txt\\\"\\n\",\n    \"\\n\",\n    \"# testing left to right languages\\n\",\n    \"conll_to_visual(en_file, \\\"en\\\", sent_count=2)\\n\",\n    \"conll_to_visual(en_file, \\\"en\\\", sent_count=10)\\n\",\n    \"#conll_to_visual(en_file, \\\"en\\\", display_all=True)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fc4b3f9b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\\n\",\n    \"\\n\",\n    \"jp_file = \\\"japanese_test.conllu.txt\\\"\\n\",\n    \"conll_to_visual(jp_file, \\\"ja\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6852b8e8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.conll_deprel_visualization import conll_to_visual\\n\",\n    \"\\n\",\n    \"# testing right to left languages\\n\",\n    \"ar_file = \\\"arabic_test.conllu.txt\\\"\\n\",\n    \"conll_to_visual(ar_file, \\\"ar\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.22\"\n  }\n },\n 
\"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "demo/Dependency_Visualization_Testing.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"64b2a9e0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.dependency_visualization import visualize_strings\\n\",\n    \"\\n\",\n    \"ar_strings = ['برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة \\\"ليوبارد\\\" الالمانية', \\\"هل بإمكاني مساعدتك؟\\\", \\n\",\n    \"              \\\"أراك في مابعد\\\", \\\"لحظة من فضلك\\\"]\\n\",\n    \"# Testing with right to left language\\n\",\n    \"visualize_strings(ar_strings, \\\"ar\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"35ef521b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.dependency_visualization import visualize_strings\\n\",\n    \"\\n\",\n    \"en_strings = [\\\"This is a sentence.\\\", \\n\",\n    \"              \\\"He is wearing a red shirt\\\",\\n\",\n    \"              \\\"Barack Obama was born in Hawaii. 
He was elected President of the United States in 2008.\\\"]\\n\",\n    \"# Testing with left to right languages\\n\",\n    \"visualize_strings(en_strings, \\\"en\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f3cf10ba\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.dependency_visualization import visualize_strings\\n\",\n    \"\\n\",\n    \"zh_strings = [\\\"中国是一个很有意思的国家。\\\"]\\n\",\n    \"# Testing with another left to right language (Chinese)\\n\",\n    \"visualize_strings(zh_strings, \\\"zh\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d2b9b574\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.22\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "demo/NER_Visualization.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"abf300bb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.ner_visualization import visualize_strings\\n\",\n    \"\\n\",\n    \"en_strings = ['''Samuel Jackson, a Christian man from Utah, went to the JFK Airport for a flight to New York.\\n\",\n    \"                 He was thinking of attending the US Open, his favorite tennis tournament besides Wimbledon.\\n\",\n    \"                 That would be a dream trip, certainly not possible since it is $5000 attendance and 5000 miles away.\\n\",\n    \"                 On the way there, he watched the Super Bowl for 2 hours and read War and Piece by Tolstoy for 1 hour.\\n\",\n    \"                 In New York, he crossed the Brooklyn Bridge and listened to the 5th symphony of Beethoven as well as\\n\",\n    \"                 \\\"All I want for Christmas is You\\\" by Mariah Carey.''', \\n\",\n    \"              \\\"Barack Obama was born in Hawaii. 
He was elected President of the United States in 2008\\\"]\\n\",\n    \"    \\n\",\n    \"visualize_strings(en_strings, \\\"en\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5670921a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.ner_visualization import visualize_strings\\n\",\n    \"\\n\",\n    \"zh_strings = ['''来自犹他州的基督徒塞缪尔杰克逊前往肯尼迪机场搭乘航班飞往纽约。\\n\",\n    \"                 他正在考虑参加美国公开赛，这是除了温布尔登之外他最喜欢的网球赛事。\\n\",\n    \"                 那将是一次梦想之旅，当然不可能，因为它的出勤费为 5000 美元，距离 5000 英里。\\n\",\n    \"                 在去的路上，他看了 2 个小时的超级碗比赛，看了 1 个小时的托尔斯泰的《战争与碎片》。\\n\",\n    \"                 在纽约，他穿过布鲁克林大桥，聆听了贝多芬的第五交响曲以及 玛丽亚凯莉的“圣诞节我想要的就是你”。''',\\n\",\n    \"              \\\"我觉得罗家费德勒住在加州, 在美国里面。\\\"]\\n\",\n    \"visualize_strings(zh_strings, \\\"zh\\\", colors={\\\"PERSON\\\": \\\"yellow\\\", \\\"DATE\\\": \\\"red\\\", \\\"GPE\\\": \\\"blue\\\"})\\n\",\n    \"visualize_strings(zh_strings, \\\"zh\\\", select=['PERSON', 'DATE'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b8d96072\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from stanza.utils.visualization.ner_visualization import visualize_strings\\n\",\n    \"\\n\",\n    \"ar_strings = [\\\".أعيش في سان فرانسيسكو ، كاليفورنيا. اسمي أليكس وأنا ألتحق بجامعة ستانفورد. أنا أدرس علوم الكمبيوتر وأستاذي هو كريس مانينغ\\\"\\n\",\n    \"             , \\\"اسمي أليكس ، أنا من الولايات المتحدة.\\\",  \\n\",\n    \"               '''صامويل جاكسون ، رجل مسيحي من ولاية يوتا ، ذهب إلى مطار جون كنيدي في رحلة إلى نيويورك. كان يفكر في حضور بطولة الولايات المتحدة المفتوحة للتنس ، بطولة التنس المفضلة لديه إلى جانب بطولة ويمبلدون. ستكون هذه رحلة الأحلام ، وبالتأكيد ليست ممكنة لأنها تبلغ 5000 دولار للحضور و 5000 ميل. في الطريق إلى هناك ، شاهد Super Bowl لمدة ساعتين وقرأ War and Piece by Tolstoy لمدة ساعة واحدة. 
في نيويورك ، عبر جسر بروكلين واستمع إلى السيمفونية الخامسة لبيتهوفن وكذلك \\\"كل ما أريده في عيد الميلاد هو أنت\\\" لماريا كاري.''']\\n\",\n    \"\\n\",\n    \"visualize_strings(ar_strings, \\\"ar\\\", colors={\\\"PER\\\": \\\"pink\\\", \\\"LOC\\\": \\\"linear-gradient(90deg, #aa9cfc, #fc9ce7)\\\", \\\"ORG\\\": \\\"yellow\\\"})\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"22489b27\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.22\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "demo/Stanza_Beginners_Guide.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"Stanza-Beginners-Guide.ipynb\",\n      \"provenance\": [],\n      \"collapsed_sections\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    }\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"56LiYCkPM7V_\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"# Welcome to Stanza!\\n\",\n        \"\\n\",\n        \"![Latest Version](https://img.shields.io/pypi/v/stanza.svg?colorB=bc4545)\\n\",\n        \"![Python Versions](https://img.shields.io/pypi/pyversions/stanza.svg?colorB=bc4545)\\n\",\n        \"\\n\",\n        \"Stanza is a Python NLP toolkit that supports 60+ human languages. It is built with highly accurate neural network components that enable efficient training and evaluation with your own annotated data, and offers pretrained models on 100 treebanks. Additionally, Stanza provides a stable, officially maintained Python interface to Java Stanford CoreNLP Toolkit.\\n\",\n        \"\\n\",\n        \"In this tutorial, we will demonstrate how to set up Stanza and annotate text with its native neural network NLP models. For the use of the Python CoreNLP interface, please see other tutorials.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"yQff4Di5Nnq0\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 1. Installing Stanza\\n\",\n        \"\\n\",\n        \"Note that Stanza only supports Python 3.6 and above. 
Installing and importing Stanza are as simple as running the following commands:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"owSj1UtdEvSU\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Install; note that the prefix \\\"!\\\" is not needed if you are running in a terminal\\n\",\n        \"!pip install stanza\\n\",\n        \"\\n\",\n        \"# Import the package\\n\",\n        \"import stanza\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"4ixllwEKeCJg\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### More Information\\n\",\n        \"\\n\",\n        \"For common troubleshooting, please visit our [troubleshooting page](https://stanfordnlp.github.io/stanfordnlp/installation_usage.html#troubleshooting).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"aeyPs5ARO79d\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 2. Downloading Models\\n\",\n        \"\\n\",\n        \"You can download models with the `stanza.download` command. The language can be specified with either a full language name (e.g., \\\"english\\\"), or a short code (e.g., \\\"en\\\"). \\n\",\n        \"\\n\",\n        \"By default, models will be saved to your `~/stanza_resources` directory. 
If you want to specify your own path to save the model files, you can pass a `dir=your_path` argument.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"HDwRm-KXGcYo\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Download an English model into the default directory\\n\",\n        \"print(\\\"Downloading English model...\\\")\\n\",\n        \"stanza.download('en')\\n\",\n        \"\\n\",\n        \"# Similarly, download a (simplified) Chinese model\\n\",\n        \"# Note that you can use verbose=False to turn off all printed messages\\n\",\n        \"print(\\\"Downloading Chinese model...\\\")\\n\",\n        \"stanza.download('zh', verbose=False)\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7HCfQ0SfdmsU\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### More Information\\n\",\n        \"\\n\",\n        \"Pretrained models are provided for 60+ different languages. For all languages, available models and the corresponding short language codes, please check out the [models page](https://stanfordnlp.github.io/stanza/models.html).\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"b3-WZJrzWD2o\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 3. Processing Text\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"XrnKl2m3fq2f\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### Constructing Pipeline\\n\",\n        \"\\n\",\n        \"To process a piece of text, you'll need to first construct a `Pipeline` with different `Processor` units. 
The pipeline is language-specific, so again you'll need to first specify the language (see examples).\\n\",\n        \"\\n\",\n        \"- By default, the pipeline will include all processors, including tokenization, multi-word token expansion, part-of-speech tagging, lemmatization, dependency parsing and named entity recognition (for supported languages). However, you can always specify what processors you want to include with the `processors` argument.\\n\",\n        \"\\n\",\n        \"- Stanza's pipeline is CUDA-aware, meaning that a CUDA device will be used whenever one is available, and the CPU will be used otherwise. You can force the pipeline to use the CPU regardless by setting `use_gpu=False`.\\n\",\n        \"\\n\",\n        \"- Again, you can suppress all printed messages by setting `verbose=False`.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"HbiTSBDPG53o\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Build an English pipeline, with all processors by default\\n\",\n        \"print(\\\"Building an English pipeline...\\\")\\n\",\n        \"en_nlp = stanza.Pipeline('en')\\n\",\n        \"\\n\",\n        \"# Build a Chinese pipeline, with customized processor list and no logging, and force it to use CPU\\n\",\n        \"print(\\\"Building a Chinese pipeline...\\\")\\n\",\n        \"zh_nlp = stanza.Pipeline('zh', processors='tokenize,lemma,pos,depparse', verbose=False, use_gpu=False)\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Go123Bx8e1wt\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### Annotating Text\\n\",\n        \"\\n\",\n        \"After a pipeline is successfully constructed, you can get annotations of a piece of text simply by passing the string into the pipeline object. 
The pipeline will return a `Document` object, from which detailed annotations can be accessed. For example:\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"k_p0h1UTHDMm\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Processing English text\\n\",\n        \"en_doc = en_nlp(\\\"Barack Obama was born in Hawaii.  He was elected president in 2008.\\\")\\n\",\n        \"print(type(en_doc))\\n\",\n        \"\\n\",\n        \"# Processing Chinese text\\n\",\n        \"zh_doc = zh_nlp(\\\"达沃斯世界经济论坛是每年全球政商界领袖聚在一起的年度盛事。\\\")\\n\",\n        \"print(type(zh_doc))\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"DavwCP9egzNZ\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### More Information\\n\",\n        \"\\n\",\n        \"For more information on how to construct a pipeline and information on different processors, please visit our [pipeline page](https://stanfordnlp.github.io/stanfordnlp/pipeline.html).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"O_PYLEGziQWR\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 4. Accessing Annotations\\n\",\n        \"\\n\",\n        \"Annotations can be accessed from the returned `Document` object. \\n\",\n        \"\\n\",\n        \"A `Document` contains a list of `Sentence`s, and a `Sentence` contains a list of `Token`s and `Word`s. For the most part `Token`s and `Word`s overlap, but some tokens can be divided into multiple words, for instance the French token `aux` is divided into the words `à` and `les`, while in English a word and a token are equivalent. 
Note that dependency parses are derived over `Word`s.\\n\",\n        \"\\n\",\n        \"Additionally, a `Span` object is used to represent annotations that are part of a document, such as named entity mentions.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"The following example iterates over all English sentences and words, and prints the word information one by one:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"B5691SpFHFZ6\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"for i, sent in enumerate(en_doc.sentences):\\n\",\n        \"    print(\\\"[Sentence {}]\\\".format(i+1))\\n\",\n        \"    for word in sent.words:\\n\",\n        \"        print(\\\"{:12s}\\\\t{:12s}\\\\t{:6s}\\\\t{:d}\\\\t{:12s}\\\".format(\\\\\\n\",\n        \"              word.text, word.lemma, word.pos, word.head, word.deprel))\\n\",\n        \"    print(\\\"\\\")\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-AUkCkNIrusq\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"The following example iterates over all extracted named entity mentions and prints out their character spans and types.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"5Uu0-WmvsnlK\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"print(\\\"Mention text\\\\tType\\\\tStart-End\\\")\\n\",\n        \"for ent in en_doc.ents:\\n\",\n        \"    print(\\\"{}\\\\t{}\\\\t{}-{}\\\".format(ent.text, ent.type, ent.start_char, ent.end_char))\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Ql1SZlZOnMLo\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n 
       \"And similarly for the Chinese text:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"XsVcEO9tHKPG\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"for i, sent in enumerate(zh_doc.sentences):\\n\",\n        \"    print(\\\"[Sentence {}]\\\".format(i+1))\\n\",\n        \"    for word in sent.words:\\n\",\n        \"        print(\\\"{:12s}\\\\t{:12s}\\\\t{:6s}\\\\t{:d}\\\\t{:12s}\\\".format(\\\\\\n\",\n        \"              word.text, word.lemma, word.pos, word.head, word.deprel))\\n\",\n        \"    print(\\\"\\\")\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"dUhWAs8pnnHT\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"Alternatively, you can directly print a `Word` object to view all its annotations as a Python dict:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"6_UafNb7HHIg\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"word = en_doc.sentences[0].words[0]\\n\",\n        \"print(word)\"\n      ],\n      \"execution_count\": 0,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"TAQlOsuRoq2V\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### More Information\\n\",\n        \"\\n\",\n        \"For all information on different data objects, please visit our [data objects page](https://stanfordnlp.github.io/stanza/data_objects.html).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"hiiWHxYPpmhd\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 5. 
Resources\\n\",\n        \"\\n\",\n        \"Apart from this interactive tutorial, we also provide tutorials on our website that cover a variety of use cases such as how to use different model \\\"packages\\\" for a language, how to use spaCy as a tokenizer, how to process pretokenized text without running the tokenizer, etc. For these tutorials please visit [our Tutorials page](https://stanfordnlp.github.io/stanza/tutorials.html).\\n\",\n        \"\\n\",\n        \"Other resources that you may find helpful include:\\n\",\n        \"\\n\",\n        \"- [Stanza Homepage](https://stanfordnlp.github.io/stanza/index.html)\\n\",\n        \"- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\\n\",\n        \"- [GitHub Repo](https://github.com/stanfordnlp/stanza)\\n\",\n        \"- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\\n\",\n        \"- [Stanza System Description Paper](http://arxiv.org/abs/2003.07082)\\n\"\n      ]\n    }\n  ]\n}"
  },
  {
    "path": "demo/Stanza_CoreNLP_Interface.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"Stanza-CoreNLP-Interface.ipynb\",\n      \"provenance\": [],\n      \"collapsed_sections\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    }\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2-4lzQTC9yxG\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"# Stanza: A Tutorial on the Python CoreNLP Interface\\n\",\n        \"\\n\",\n        \"![Latest Version](https://img.shields.io/pypi/v/stanza.svg?colorB=bc4545)\\n\",\n        \"![Python Versions](https://img.shields.io/pypi/pyversions/stanza.svg?colorB=bc4545)\\n\",\n        \"\\n\",\n        \"While the Stanza library implements accurate neural network modules for basic functionalities such as part-of-speech tagging and dependency parsing, the [Stanford CoreNLP Java library](https://stanfordnlp.github.io/CoreNLP/) has been developed for years and offers more complementary features such as coreference resolution and relation extraction. To unlock these features, the Stanza library also offers an officially maintained Python interface to the CoreNLP Java library. This interface allows you to get NLP anntotations from CoreNLP by writing native Python code.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"This tutorial walks you through the installation, setup and basic usage of this Python CoreNLP interface. If you want to learn how to use the neural network components in Stanza, please refer to other tutorials.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"YpKwWeVkASGt\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 1. 
Installation\\n\",\n        \"\\n\",\n        \"Before the installation starts, please make sure that you have Python 3 and Java installed on your computer. Since Colab already has them installed, we'll skip this procedure in this notebook.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"k1Az2ECuAfG8\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### Installing Stanza\\n\",\n        \"\\n\",\n        \"Installing and importing Stanza are as simple as running the following commands:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"xiFwYAgW4Mss\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Install stanza; note that the prefix \\\"!\\\" is not needed if you are running in a terminal\\n\",\n        \"!pip install stanza\\n\",\n        \"\\n\",\n        \"# Import stanza\\n\",\n        \"import stanza\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2zFvaA8_A32_\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### Setting up Stanford CoreNLP\\n\",\n        \"\\n\",\n        \"In order for the interface to work, the Stanford CoreNLP library has to be installed and a `CORENLP_HOME` environment variable has to be pointed to the installation location.\\n\",\n        \"\\n\",\n        \"Here we are going to show you how to download and install the CoreNLP library on your machine, with Stanza's installation command:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"MgK6-LPV-OdA\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Download the Stanford CoreNLP package with Stanza's installation command\\n\",\n        \"# This'll take several 
minutes, depending on the network speed\\n\",\n        \"corenlp_dir = './corenlp'\\n\",\n        \"stanza.install_corenlp(dir=corenlp_dir)\\n\",\n        \"\\n\",\n        \"# Set the CORENLP_HOME environment variable to point to the installation location\\n\",\n        \"import os\\n\",\n        \"os.environ[\\\"CORENLP_HOME\\\"] = corenlp_dir\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Jdq8MT-NAhKj\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"That's all for the installation! 🎉  We can now double check if the installation is successful by listing files in the CoreNLP directory. You should be able to see a number of `.jar` files by running the following command:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"K5eIOaJp_tuo\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Examine the CoreNLP installation folder to make sure the installation is successful\\n\",\n        \"!ls $CORENLP_HOME\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"S0xb9BHt__gx\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"**Note 1**:\\n\",\n        \"If you want to use the interface in a terminal (instead of a Colab notebook), you can properly set the `CORENLP_HOME` environment variable with:\\n\",\n        \"\\n\",\n        \"```bash\\n\",\n        \"export CORENLP_HOME=path_to_corenlp_dir\\n\",\n        \"```\\n\",\n        \"\\n\",\n        \"Here we instead set this variable with the Python `os` library, simply because the `export` command is not well supported in Colab notebooks.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"**Note 2**:\\n\",\n        \"The `stanza.install_corenlp()` 
function is only available since Stanza v1.1.1. If you are using an earlier version of Stanza, please check out our [manual installation page](https://stanfordnlp.github.io/stanza/client_setup.html#manual-installation) for how to install CoreNLP on your computer.\\n\",\n        \"\\n\",\n        \"**Note 3**:\\n\",\n        \"Besides the installation function, we also provide a `stanza.download_corenlp_models()` function to help you download additional CoreNLP models for different languages that are not shipped with the default installation. Check out our [automatic installation website page](https://stanfordnlp.github.io/stanza/client_setup.html#automated-installation) for more information on how to use it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"xJsuO6D8D05q\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 2. Annotating Text with CoreNLP Interface\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"dZNHxXHkH1K2\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### Constructing CoreNLPClient\\n\",\n        \"\\n\",\n        \"At a high level, the CoreNLP Python interface works by first starting a background Java CoreNLP server process, and then initializing a client instance in Python which can pass the text to the background server process, and accept the returned annotation results.\\n\",\n        \"\\n\",\n        \"We wrap these functionalities in a `CoreNLPClient` class. 
Therefore, we need to start by importing this class from Stanza.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"LS4OKnqJ8wui\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Import client module\\n\",\n        \"from stanza.server import CoreNLPClient\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"WP4Dz6PIJHeL\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"After the import is done, we can construct a `CoreNLPClient` instance. The constructor method takes a Python list of annotator names as argument. Here let's explore some basic annotators including tokenization, sentence split, part-of-speech tagging, lemmatization and named entity recognition (NER). \\n\",\n        \"\\n\",\n        \"Additionally, the client constructor accepts a `memory` argument, which specifies how much memory will be allocated to the background Java process. An `endpoint` option can be used to specify a port number used by the communication between the server and the client. The default port is 9000. However, since this port is pre-occupied by a system process in Colab, we'll manually set it to 9001 in the following example.\\n\",\n        \"\\n\",\n        \"Also, here we manually set `be_quiet=True` to avoid an IO issue in colab notebook. 
You should be able to use `be_quiet=False` on your own computer, which will print detailed logging information from CoreNLP during usage.\\n\",\n        \"\\n\",\n        \"For more options in constructing the clients, please refer to the [CoreNLP Client Options List](https://stanfordnlp.github.io/stanza/corenlp_client.html#corenlp-client-options).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"mbOBugvd9JaM\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Construct a CoreNLPClient with some basic annotators, a memory allocation of 4GB, and port number 9001\\n\",\n        \"client = CoreNLPClient(\\n\",\n        \"    annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \\n\",\n        \"    memory='4G', \\n\",\n        \"    endpoint='http://localhost:9001',\\n\",\n        \"    be_quiet=True)\\n\",\n        \"print(client)\\n\",\n        \"\\n\",\n        \"# Start the background server and wait for some time\\n\",\n        \"# Note that in practice this is totally optional, as by default the server will be started when the first annotation is performed\\n\",\n        \"client.start()\\n\",\n        \"import time; time.sleep(10)\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kgTiVjNydmIW\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"After the above code block finishes executing, if you print the background processes, you should be able to find the Java CoreNLP server running.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"spZrJ-oFdkdF\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Print background processes and look for java\\n\",\n        \"# You should be able to see a StanfordCoreNLPServer java 
process running in the background\\n\",\n        \"!ps -o pid,cmd | grep java\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"KxJeJ0D2LoOs\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### Annotating Text\\n\",\n        \"\\n\",\n        \"Annotating a piece of text is as simple as passing the text into an `annotate` function of the client object. After the annotation is complete, a `Document`  object will be returned with all annotations.\\n\",\n        \"\\n\",\n        \"Note that although in general annotations are very fast, the first annotation might take a while to complete in the notebook. Please stay patient.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"s194RnNg5z95\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Annotate some text\\n\",\n        \"text = \\\"Albert Einstein was a German-born theoretical physicist. He developed the theory of relativity.\\\"\\n\",\n        \"document = client.annotate(text)\\n\",\n        \"print(type(document))\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"semmA3e0TcM1\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 3. Accessing Annotations\\n\",\n        \"\\n\",\n        \"Annotations can be accessed from the returned `Document` object.\\n\",\n        \"\\n\",\n        \"A `Document` contains a list of `Sentence`s, which contain a list of `Token`s. 
Here let's first explore the annotations stored in all tokens.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"lIO4B5d6Rk4I\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Iterate over all tokens in all sentences, and print out the word, lemma, pos and ner tags\\n\",\n        \"print(\\\"{:12s}\\\\t{:12s}\\\\t{:6s}\\\\t{}\\\".format(\\\"Word\\\", \\\"Lemma\\\", \\\"POS\\\", \\\"NER\\\"))\\n\",\n        \"\\n\",\n        \"for i, sent in enumerate(document.sentence):\\n\",\n        \"    print(\\\"[Sentence {}]\\\".format(i+1))\\n\",\n        \"    for t in sent.token:\\n\",\n        \"        print(\\\"{:12s}\\\\t{:12s}\\\\t{:6s}\\\\t{}\\\".format(t.word, t.lemma, t.pos, t.ner))\\n\",\n        \"    print(\\\"\\\")\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"msrJfvu8VV9m\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"Alternatively, you can also browse the NER results by iterating over entity mentions over the sentences. 
For example:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"ezEjc9LeV2Xs\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Iterate over all detected entity mentions\\n\",\n        \"print(\\\"{:30s}\\\\t{}\\\".format(\\\"Mention\\\", \\\"Type\\\"))\\n\",\n        \"\\n\",\n        \"for sent in document.sentence:\\n\",\n        \"    for m in sent.mentions:\\n\",\n        \"        print(\\\"{:30s}\\\\t{}\\\".format(m.entityMentionText, m.entityType))\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ueGzBZ3hWzkN\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"To print all annotations a sentence, token or mention has, you can simply print the corresponding object.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"4_S8o2BHXIed\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Print annotations of a token\\n\",\n        \"print(document.sentence[0].token[0])\\n\",\n        \"\\n\",\n        \"# Print annotations of a mention\\n\",\n        \"print(document.sentence[0].mentions[0])\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Qp66wjZ10xia\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"**Note**: Since the Stanza CoreNLP client interface simply ports the CoreNLP annotation results to native Python objects, for a comprehensive list of available annotators and how their annotation results can be accessed, you will need to visit the [Stanford CoreNLP website](https://stanfordnlp.github.io/CoreNLP/).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      
\"metadata\": {\n        \"id\": \"IPqzMK90X0w3\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 4. Shutting Down the CoreNLP Server\\n\",\n        \"\\n\",\n        \"To shut down the background CoreNLP server process, simply call the `stop` function of the client. Note that once a server is shutdown, you'll have to restart the server with the `start()` function before any annotation is requested.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"xrJq8lZ3Nw7b\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"# Shut down the background CoreNLP server\\n\",\n        \"client.stop()\\n\",\n        \"\\n\",\n        \"time.sleep(10)\\n\",\n        \"!ps -o pid,cmd | grep java\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"23Vwa_ifYfF7\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"### More Information\\n\",\n        \"\\n\",\n        \"For more information on how to use the `CoreNLPClient`, please go to the [CoreNLPClient documentation page](https://stanfordnlp.github.io/stanza/corenlp_client.html).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"YUrVT6kA_Bzx\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 5. Simplifying Client Usage with the Python `with` statement\\n\",\n        \"\\n\",\n        \"In the above demo, we explicitly called the `client.start()` and `client.stop()` functions to start and stop a client-server connection. 
However, doing this in practice is usually suboptimal, since you may forget to call the `stop()` function at the end, resulting in an unused server process occupying your machine memory.\\n\",\n        \"\\n\",\n        \"To avoid this, a simple solution is to use the client interface with the [Python `with` statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement). The `with` statement provides an elegant way to automatically start and stop the server process in your Python program, without you needing to worry about this. The following code snippet demonstrates how to establish a client, annotate an example text and then stop the server with a simple `with` statement. Note that we **always recommend** using the `with` statement when working with the Stanza CoreNLP client interface.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"H0ct2-R4AvJh\",\n        \"colab_type\": \"code\",\n        \"colab\": {}\n      },\n      \"source\": [\n        \"print(\\\"Starting a server with the Python \\\\\\\"with\\\\\\\" statement...\\\")\\n\",\n        \"with CoreNLPClient(annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \\n\",\n        \"                   memory='4G', endpoint='http://localhost:9001', be_quiet=True) as client:\\n\",\n        \"    text = \\\"Albert Einstein was a German-born theoretical physicist.\\\"\\n\",\n        \"    document = client.annotate(text)\\n\",\n        \"\\n\",\n        \"    print(\\\"{:30s}\\\\t{}\\\".format(\\\"Mention\\\", \\\"Type\\\"))\\n\",\n        \"    for sent in document.sentence:\\n\",\n        \"        for m in sent.mentions:\\n\",\n        \"            print(\\\"{:30s}\\\\t{}\\\".format(m.entityMentionText, m.entityType))\\n\",\n        \"\\n\",\n        \"print(\\\"\\\\nThe server should be stopped upon exit from the \\\\\\\"with\\\\\\\" statement.\\\")\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    
{\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"W435Lwc4YqKb\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"## 6. Other Resources\\n\",\n        \"\\n\",\n        \"- [Stanza Homepage](https://stanfordnlp.github.io/stanza/)\\n\",\n        \"- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\\n\",\n        \"- [GitHub Repo](https://github.com/stanfordnlp/stanza)\\n\",\n        \"- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\\n\"\n      ]\n    }\n  ]\n}"
  },
  {
    "path": "demo/arabic_test.conllu.txt",
    "content": "# newdoc id = assabah.20041005.0017\n# newpar id = assabah.20041005.0017:p1\n# sent_id = assabah.20041005.0017:p1u1\n# text = سوريا: تعديل وزاري واسع يشمل 8 حقائب\n# orig_file_sentence ASB_ARB_20041005.0017#1\n1\tسوريا\tسُورِيَا\tX\tX---------\tForeign=Yes\t0\troot\t0:root\tSpaceAfter=No|Vform=سُورِيَا|Gloss=Syria|Root=sUr|Translit=sūriyā|LTranslit=sūriyā\n2\t:\t:\tPUNCT\tG---------\t_\t1\tpunct\t1:punct\tVform=:|Translit=:\n3\tتعديل\tتَعدِيل\tNOUN\tN------S1I\tCase=Nom|Definite=Ind|Number=Sing\t6\tnsubj\t6:nsubj\tVform=تَعدِيلٌ|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=taʿdīlun|LTranslit=taʿdīl\n4\tوزاري\tوِزَارِيّ\tADJ\tA-----MS1I\tCase=Nom|Definite=Ind|Gender=Masc|Number=Sing\t3\tamod\t3:amod\tVform=وِزَارِيٌّ|Gloss=ministry,ministerial|Root=w_z_r|Translit=wizārīyun|LTranslit=wizārīy\n5\tواسع\tوَاسِع\tADJ\tA-----MS1I\tCase=Nom|Definite=Ind|Gender=Masc|Number=Sing\t3\tamod\t3:amod\tVform=وَاسِعٌ|Gloss=wide,extensive,broad|Root=w_s_`|Translit=wāsiʿun|LTranslit=wāsiʿ\n6\tيشمل\tشَمِل\tVERB\tVIIA-3MS--\tAspect=Imp|Gender=Masc|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act\t1\tparataxis\t1:parataxis\tVform=يَشمَلُ|Gloss=comprise,include,contain|Root=^s_m_l|Translit=yašmalu|LTranslit=šamil\n7\t8\t8\tNUM\tQ---------\tNumForm=Digit\t6\tobj\t6:obj\tVform=٨|Translit=8\n8\tحقائب\tحَقِيبَة\tNOUN\tN------P2I\tCase=Gen|Definite=Ind|Number=Plur\t7\tnmod\t7:nmod:gen\tVform=حَقَائِبَ|Gloss=briefcase,suitcase,portfolio,luggage|Root=.h_q_b|Translit=ḥaqāʾiba|LTranslit=ḥaqībat\n\n# newpar id = assabah.20041005.0017:p2\n# sent_id = assabah.20041005.0017:p2u1\n# text = دمشق (وكالات الانباء) - اجرى الرئيس السوري بشار الاسد تعديلا حكومياً واسعا تم بموجبه إقالة وزيري الداخلية والاعلام عن منصبيها في حين ظل محمد ناجي العطري رئيساً للحكومة.\n# orig_file_sentence 
ASB_ARB_20041005.0017#2\n1\tدمشق\tدمشق\tX\tU---------\t_\t0\troot\t0:root\tVform=دمشق|Root=OOV|Translit=dmšq\n2\t(\t(\tPUNCT\tG---------\t_\t3\tpunct\t3:punct\tSpaceAfter=No|Vform=(|Translit=(\n3\tوكالات\tوِكَالَة\tNOUN\tN------P1R\tCase=Nom|Definite=Cons|Number=Plur\t1\tdep\t1:dep\tVform=وِكَالَاتُ|Gloss=agency|Root=w_k_l|Translit=wikālātu|LTranslit=wikālat\n4\tالانباء\tنَبَأ\tNOUN\tN------P2D\tCase=Gen|Definite=Def|Number=Plur\t3\tnmod\t3:nmod:gen\tSpaceAfter=No|Vform=اَلأَنبَاءِ|Gloss=news_item,report|Root=n_b_'|Translit=al-ʾanbāʾi|LTranslit=nabaʾ\n5\t)\t)\tPUNCT\tG---------\t_\t3\tpunct\t3:punct\tVform=)|Translit=)\n6\t-\t-\tPUNCT\tG---------\t_\t1\tpunct\t1:punct\tVform=-|Translit=-\n7\tاجرى\tأَجرَى\tVERB\tVP-A-3MS--\tAspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act\t1\tadvcl\t1:advcl:فِي_حِينَ\tVform=أَجرَى|Gloss=conduct,carry_out,perform|Root=^g_r_y|Translit=ʾaǧrā|LTranslit=ʾaǧrā\n8\tالرئيس\tرَئِيس\tNOUN\tN------S1D\tCase=Nom|Definite=Def|Number=Sing\t7\tnsubj\t7:nsubj\tVform=اَلرَّئِيسُ|Gloss=president,head,chairman|Root=r_'_s|Translit=ar-raʾīsu|LTranslit=raʾīs\n9\tالسوري\tسُورِيّ\tADJ\tA-----MS1D\tCase=Nom|Definite=Def|Gender=Masc|Number=Sing\t8\tamod\t8:amod\tVform=اَلسُّورِيُّ|Gloss=Syrian|Root=sUr|Translit=as-sūrīyu|LTranslit=sūrīy\n10\tبشار\tبشار\tX\tU---------\t_\t11\tnmod\t11:nmod\tVform=بشار|Root=OOV|Translit=bšār\n11\tالاسد\tالاسد\tX\tU---------\t_\t8\tnmod\t8:nmod\tVform=الاسد|Root=OOV|Translit=ālāsd\n12\tتعديلا\tتَعدِيل\tNOUN\tN------S4I\tCase=Acc|Definite=Ind|Number=Sing\t7\tobj\t7:obj\tVform=تَعدِيلًا|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=taʿdīlan|LTranslit=taʿdīl\n13\tحكومياً\tحُكُومِيّ\tADJ\tA-----MS4I\tCase=Acc|Definite=Ind|Gender=Masc|Number=Sing\t12\tamod\t12:amod\tVform=حُكُومِيًّا|Gloss=governmental,state,official|Root=.h_k_m|Translit=ḥukūmīyan|LTranslit=ḥukūmīy\n14\tواسعا\tوَاسِع\tADJ\tA-----MS4I\tCase=Acc|Definite=Ind|Gender=Masc|Number=Sing\t12\tamod\t12:amod\tVform=وَاسِعًا|Gloss=wide,extensive,broad
|Root=w_s_`|Translit=wāsiʿan|LTranslit=wāsiʿ\n15\tتم\tتَمّ\tVERB\tVP-A-3MS--\tAspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act\t12\tacl\t12:acl\tVform=تَمَّ|Gloss=conclude,take_place|Root=t_m_m|Translit=tamma|LTranslit=tamm\n16-18\tبموجبه\t_\t_\t_\t_\t_\t_\t_\t_\n16\tب\tبِ\tADP\tP---------\tAdpType=Prep\t18\tcase\t18:case\tVform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi\n17\tموجب\tمُوجِب\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t16\tfixed\t16:fixed\tVform=مُوجِبِ|Gloss=reason,motive|Root=w_^g_b|Translit=mūǧibi|LTranslit=mūǧib\n18\tه\tهُوَ\tPRON\tSP---3MS2-\tCase=Gen|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t15\tnmod\t15:nmod:بِ_مُوجِب:gen\tVform=هِ|Gloss=he,she,it|Translit=hi|LTranslit=huwa\n19\tإقالة\tإِقَالَة\tNOUN\tN------S1R\tCase=Nom|Definite=Cons|Number=Sing\t15\tnsubj\t15:nsubj\tVform=إِقَالَةُ|Gloss=dismissal,discharge|Root=q_y_l|Translit=ʾiqālatu|LTranslit=ʾiqālat\n20\tوزيري\tوَزِير\tNOUN\tN------D2R\tCase=Gen|Definite=Cons|Number=Dual\t19\tnmod\t19:nmod:gen\tVform=وَزِيرَي|Gloss=minister|Root=w_z_r|Translit=wazīray|LTranslit=wazīr\n21\tالداخلية\tدَاخِلِيّ\tADJ\tA-----FS2D\tCase=Gen|Definite=Def|Gender=Fem|Number=Sing\t20\tamod\t20:amod\tVform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy\n22-23\tوالاعلام\t_\t_\t_\t_\t_\t_\t_\t_\n22\tو\tوَ\tCCONJ\tC---------\t_\t23\tcc\t23:cc\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n23\tالإعلام\tإِعلَام\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t21\tconj\t20:amod|21:conj\tVform=اَلإِعلَامِ|Gloss=information,media|Root=`_l_m|Translit=al-ʾiʿlāmi|LTranslit=ʾiʿlām\n24\tعن\tعَن\tADP\tP---------\tAdpType=Prep\t25\tcase\t25:case\tVform=عَن|Gloss=about,from|Root=`an|Translit=ʿan|LTranslit=ʿan\n25-26\tمنصبيها\t_\t_\t_\t_\t_\t_\t_\t_\n25\tمنصبي\tمَنصِب\tNOUN\tN------D2R\tCase=Gen|Definite=Cons|Number=Dual\t19\tnmod\t19:nmod:عَن:gen\tVform=مَنصِبَي|Gloss=post,position,office|Root=n_.s_b|Translit=manṣibay|LTranslit=
manṣib\n26\tها\tهُوَ\tPRON\tSP---3FS2-\tCase=Gen|Gender=Fem|Number=Sing|Person=3|PronType=Prs\t25\tnmod\t25:nmod:gen\tVform=هَا|Gloss=he,she,it|Translit=hā|LTranslit=huwa\n27\tفي\tفِي\tADP\tP---------\tAdpType=Prep\t7\tmark\t7:mark\tVform=فِي|Gloss=in|Root=fI|Translit=fī|LTranslit=fī\n28\tحين\tحِينَ\tADP\tPI------2-\tAdpType=Prep|Case=Gen\t7\tmark\t7:mark\tVform=حِينِ|Gloss=when|Root=.h_y_n|Translit=ḥīni|LTranslit=ḥīna\n29\tظل\tظَلّ\tVERB\tVP-A-3MS--\tAspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act\t7\tparataxis\t7:parataxis\tVform=ظَلَّ|Gloss=remain,continue|Root=.z_l_l|Translit=ẓalla|LTranslit=ẓall\n30\tمحمد\tمحمد\tX\tU---------\t_\t32\tnmod\t32:nmod\tVform=محمد|Root=OOV|Translit=mḥmd\n31\tناجي\tناجي\tX\tU---------\t_\t32\tnmod\t32:nmod\tVform=ناجي|Root=OOV|Translit=nāǧy\n32\tالعطري\tالعطري\tX\tU---------\t_\t29\tnsubj\t29:nsubj\tVform=العطري|Root=OOV|Translit=ālʿṭry\n33\tرئيساً\tرَئِيس\tNOUN\tN------S4I\tCase=Acc|Definite=Ind|Number=Sing\t29\txcomp\t29:xcomp\tVform=رَئِيسًا|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsan|LTranslit=raʾīs\n34-35\tللحكومة\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No\n34\tل\tلِ\tADP\tP---------\tAdpType=Prep\t35\tcase\t35:case\tVform=لِ|Gloss=for,to|Root=l|Translit=li|LTranslit=li\n35\tالحكومة\tحُكُومَة\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t33\tnmod\t33:nmod:لِ:gen\tVform=اَلحُكُومَةِ|Gloss=government,administration|Root=.h_k_m|Translit=al-ḥukūmati|LTranslit=ḥukūmat\n36\t.\t.\tPUNCT\tG---------\t_\t1\tpunct\t1:punct\tVform=.|Translit=.\n\n# newpar id = assabah.20041005.0017:p3\n# sent_id = assabah.20041005.0017:p3u1\n# text = واضافت المصادر ان مهدي دخل الله رئيس تحرير صحيفة الحزب الحاكم والليبرالي التوجهات تسلم منصب وزير الاعلام خلفا لاحمد الحسن فيما تسلم اللواء غازي كنعان رئيس شعبة الامن السياسي منصب وزير الداخلية.\n# orig_file_sentence 
ASB_ARB_20041005.0017#3\n1-2\tواضافت\t_\t_\t_\t_\t_\t_\t_\t_\n1\tو\tوَ\tCCONJ\tC---------\t_\t0\troot\t0:root\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n2\tأضافت\tأَضَاف\tVERB\tVP-A-3FS--\tAspect=Perf|Gender=Fem|Number=Sing|Person=3|Voice=Act\t1\tparataxis\t1:parataxis\tVform=أَضَافَت|Gloss=add,attach,receive_as_guest|Root=.d_y_f|Translit=ʾaḍāfat|LTranslit=ʾaḍāf\n3\tالمصادر\tمَصدَر\tNOUN\tN------P1D\tCase=Nom|Definite=Def|Number=Plur\t2\tnsubj\t2:nsubj\tVform=اَلمَصَادِرُ|Gloss=source|Root=.s_d_r|Translit=al-maṣādiru|LTranslit=maṣdar\n4\tان\tأَنَّ\tSCONJ\tC---------\t_\t16\tmark\t16:mark\tVform=أَنَّ|Gloss=that|Root='_n|Translit=ʾanna|LTranslit=ʾanna\n5\tمهدي\tمهدي\tX\tU---------\t_\t6\tnmod\t6:nmod\tVform=مهدي|Root=OOV|Translit=mhdy\n6\tدخل\tدخل\tX\tU---------\t_\t16\tnsubj\t16:nsubj\tVform=دخل|Root=OOV|Translit=dḫl\n7\tالله\tالله\tX\tU---------\t_\t6\tnmod\t6:nmod\tVform=الله|Root=OOV|Translit=āllh\n8\tرئيس\tرَئِيس\tNOUN\tN------S4R\tCase=Acc|Definite=Cons|Number=Sing\t6\tnmod\t6:nmod:acc\tVform=رَئِيسَ|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsa|LTranslit=raʾīs\n9\tتحرير\tتَحرِير\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t8\tnmod\t8:nmod:gen\tVform=تَحرِيرِ|Gloss=liberation,liberating,editorship,editing|Root=.h_r_r|Translit=taḥrīri|LTranslit=taḥrīr\n10\tصحيفة\tصَحِيفَة\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t9\tnmod\t9:nmod:gen\tVform=صَحِيفَةِ|Gloss=newspaper,sheet,leaf|Root=.s_.h_f|Translit=ṣaḥīfati|LTranslit=ṣaḥīfat\n11\tالحزب\tحِزب\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t10\tnmod\t10:nmod:gen\tVform=اَلحِزبِ|Gloss=party,band|Root=.h_z_b|Translit=al-ḥizbi|LTranslit=ḥizb\n12\tالحاكم\tحَاكِم\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t11\tnmod\t11:nmod:gen\tVform=اَلحَاكِمِ|Gloss=ruler,governor|Root=.h_k_m|Translit=al-ḥākimi|LTranslit=ḥākim\n13-14\tوالليبرالي\t_\t_\t_\t_\t_\t_\t_\t_\n13\tو\tوَ\tCCONJ\tC---------\t_\t6\tcc\t6:cc\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n14\tالل
يبرالي\tلِيبِرَالِيّ\tADJ\tA-----MS4D\tCase=Acc|Definite=Def|Gender=Masc|Number=Sing\t6\tamod\t6:amod\tVform=اَللِّيبِرَالِيَّ|Gloss=liberal|Root=lIbirAl|Translit=al-lībirālīya|LTranslit=lībirālīy\n15\tالتوجهات\tتَوَجُّه\tNOUN\tN------P2D\tCase=Gen|Definite=Def|Number=Plur\t14\tnmod\t14:nmod:gen\tVform=اَلتَّوَجُّهَاتِ|Gloss=attitude,approach|Root=w_^g_h|Translit=at-tawaǧǧuhāti|LTranslit=tawaǧǧuh\n16\tتسلم\tتَسَلَّم\tVERB\tVP-A-3MS--\tAspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act\t2\tccomp\t2:ccomp\tVform=تَسَلَّمَ|Gloss=receive,assume|Root=s_l_m|Translit=tasallama|LTranslit=tasallam\n17\tمنصب\tمَنصِب\tNOUN\tN------S4R\tCase=Acc|Definite=Cons|Number=Sing\t16\tobj\t16:obj\tVform=مَنصِبَ|Gloss=post,position,office|Root=n_.s_b|Translit=manṣiba|LTranslit=manṣib\n18\tوزير\tوَزِير\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t17\tnmod\t17:nmod:gen\tVform=وَزِيرِ|Gloss=minister|Root=w_z_r|Translit=wazīri|LTranslit=wazīr\n19\tالاعلام\tعَلَم\tNOUN\tN------P2D\tCase=Gen|Definite=Def|Number=Plur\t18\tnmod\t18:nmod:gen\tVform=اَلأَعلَامِ|Gloss=flag,banner,badge|Root=`_l_m|Translit=al-ʾaʿlāmi|LTranslit=ʿalam\n20\tخلفا\tخَلَف\tNOUN\tN------S4I\tCase=Acc|Definite=Ind|Number=Sing\t16\tobl\t16:obl:acc\tVform=خَلَفًا|Gloss=substitute,scion|Root=_h_l_f|Translit=ḫalafan|LTranslit=ḫalaf\n21-22\tلاحمد\t_\t_\t_\t_\t_\t_\t_\t_\n21\tل\tلِ\tADP\tP---------\tAdpType=Prep\t23\tcase\t23:case\tVform=لِ|Gloss=for,to|Root=l|Translit=li|LTranslit=li\n22\tأحمد\tأَحمَد\tNOUN\tN------S2I\tCase=Gen|Definite=Ind|Number=Sing\t23\tnmod\t23:nmod:gen\tVform=أَحمَدَ|Gloss=Ahmad|Root=.h_m_d|Translit=ʾaḥmada|LTranslit=ʾaḥmad\n23\tالحسن\tالحسن\tX\tU---------\t_\t20\tnmod\t20:nmod:لِ\tVform=الحسن|Root=OOV|Translit=ālḥsn\n24\tفيما\tفِيمَا\tCCONJ\tC---------\t_\t25\tcc\t25:cc\tVform=فِيمَا|Gloss=while,during_which|Root=fI|Translit=fīmā|LTranslit=fīmā\n25\tتسلم\tتَسَلَّم\tVERB\tVP-A-3MS--\tAspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act\t16\tconj\t2:ccomp|16:conj\tVform=تَسَلَّمَ|Gloss=rece
ive,assume|Root=s_l_m|Translit=tasallama|LTranslit=tasallam\n26\tاللواء\tلِوَاء\tNOUN\tN------S1D\tCase=Nom|Definite=Def|Number=Sing\t25\tnsubj\t25:nsubj\tVform=اَللِّوَاءُ|Gloss=banner,flag|Root=l_w_y|Translit=al-liwāʾu|LTranslit=liwāʾ\n27\tغازي\tغازي\tX\tU---------\t_\t28\tnmod\t28:nmod\tVform=غازي|Root=OOV|Translit=ġāzy\n28\tكنعان\tكنعان\tX\tU---------\t_\t26\tnmod\t26:nmod\tVform=كنعان|Root=OOV|Translit=knʿān\n29\tرئيس\tرَئِيس\tNOUN\tN------S1R\tCase=Nom|Definite=Cons|Number=Sing\t26\tnmod\t26:nmod:nom\tVform=رَئِيسُ|Gloss=president,head,chairman|Root=r_'_s|Translit=raʾīsu|LTranslit=raʾīs\n30\tشعبة\tشُعبَة\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t29\tnmod\t29:nmod:gen\tVform=شُعبَةِ|Gloss=branch,subdivision|Root=^s_`_b|Translit=šuʿbati|LTranslit=šuʿbat\n31\tالامن\tأَمن\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t30\tnmod\t30:nmod:gen\tVform=اَلأَمنِ|Gloss=security,safety|Root='_m_n|Translit=al-ʾamni|LTranslit=ʾamn\n32\tالسياسي\tسِيَاسِيّ\tADJ\tA-----MS2D\tCase=Gen|Definite=Def|Gender=Masc|Number=Sing\t31\tamod\t31:amod\tVform=اَلسِّيَاسِيِّ|Gloss=political|Root=s_w_s|Translit=as-siyāsīyi|LTranslit=siyāsīy\n33\tمنصب\tمَنصِب\tNOUN\tN------S4R\tCase=Acc|Definite=Cons|Number=Sing\t25\tobj\t25:obj\tVform=مَنصِبَ|Gloss=post,position,office|Root=n_.s_b|Translit=manṣiba|LTranslit=manṣib\n34\tوزير\tوَزِير\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t33\tnmod\t33:nmod:gen\tVform=وَزِيرِ|Gloss=minister|Root=w_z_r|Translit=wazīri|LTranslit=wazīr\n35\tالداخلية\tدَاخِلِيّ\tADJ\tA-----FS2D\tCase=Gen|Definite=Def|Gender=Fem|Number=Sing\t34\tamod\t34:amod\tSpaceAfter=No|Vform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy\n36\t.\t.\tPUNCT\tG---------\t_\t1\tpunct\t1:punct\tVform=.|Translit=.\n\n# newpar id = assabah.20041005.0017:p4\n# sent_id = assabah.20041005.0017:p4u1\n# text = وذكرت وكالة الانباء السورية ان التعديل شمل ثماني حقائب بينها وزارتا الداخلية والاقتصاد.\n# 
orig_file_sentence ASB_ARB_20041005.0017#4\n1-2\tوذكرت\t_\t_\t_\t_\t_\t_\t_\t_\n1\tو\tوَ\tCCONJ\tC---------\t_\t0\troot\t0:root\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n2\tذكرت\tذَكَر\tVERB\tVP-A-3FS--\tAspect=Perf|Gender=Fem|Number=Sing|Person=3|Voice=Act\t1\tparataxis\t1:parataxis\tVform=ذَكَرَت|Gloss=mention,cite,remember|Root=_d_k_r|Translit=ḏakarat|LTranslit=ḏakar\n3\tوكالة\tوِكَالَة\tNOUN\tN------S1R\tCase=Nom|Definite=Cons|Number=Sing\t2\tnsubj\t2:nsubj\tVform=وِكَالَةُ|Gloss=agency|Root=w_k_l|Translit=wikālatu|LTranslit=wikālat\n4\tالانباء\tنَبَأ\tNOUN\tN------P2D\tCase=Gen|Definite=Def|Number=Plur\t3\tnmod\t3:nmod:gen\tVform=اَلأَنبَاءِ|Gloss=news_item,report|Root=n_b_'|Translit=al-ʾanbāʾi|LTranslit=nabaʾ\n5\tالسورية\tسُورِيّ\tADJ\tA-----FS1D\tCase=Nom|Definite=Def|Gender=Fem|Number=Sing\t3\tamod\t3:amod\tVform=اَلسُّورِيَّةُ|Gloss=Syrian|Root=sUr|Translit=as-sūrīyatu|LTranslit=sūrīy\n6\tان\tأَنَّ\tSCONJ\tC---------\t_\t8\tmark\t8:mark\tVform=أَنَّ|Gloss=that|Root='_n|Translit=ʾanna|LTranslit=ʾanna\n7\tالتعديل\tتَعدِيل\tNOUN\tN------S4D\tCase=Acc|Definite=Def|Number=Sing\t8\tobl\t8:obl:acc\tVform=اَلتَّعدِيلَ|Gloss=adjustment,change,modification,amendment|Root=`_d_l|Translit=at-taʿdīla|LTranslit=taʿdīl\n8\tشمل\tشَمِل\tVERB\tVP-A-3MS--\tAspect=Perf|Gender=Masc|Number=Sing|Person=3|Voice=Act\t2\tccomp\t2:ccomp\tVform=شَمِلَ|Gloss=comprise,include,contain|Root=^s_m_l|Translit=šamila|LTranslit=šamil\n9\tثماني\tثَمَانُون\tNUM\tQL------4R\tCase=Acc|Definite=Cons|NumForm=Word\t8\tobj\t8:obj\tVform=ثَمَانِي|Gloss=eighty|Root=_t_m_n|Translit=ṯamānī|LTranslit=ṯamānūn\n10\tحقائب\tحَقِيبَة\tNOUN\tN------P2I\tCase=Gen|Definite=Ind|Number=Plur\t9\tnmod\t9:nmod:gen\tVform=حَقَائِبَ|Gloss=briefcase,suitcase,portfolio,luggage|Root=.h_q_b|Translit=ḥaqāʾiba|LTranslit=ḥaqībat\n11-12\tبينها\t_\t_\t_\t_\t_\t_\t_\t_\n11\tبين\tبَينَ\tADP\tPI------4-\tAdpType=Prep|Case=Acc\t12\tcase\t12:case\tVform=بَينَ|Gloss=between,among|Root=b_y_n|Translit=bayna|LTranslit=bayna\n12
\tها\tهُوَ\tPRON\tSP---3FS2-\tCase=Gen|Gender=Fem|Number=Sing|Person=3|PronType=Prs\t10\tobl\t10:obl:بَينَ:gen\tVform=هَا|Gloss=he,she,it|Translit=hā|LTranslit=huwa\n13\tوزارتا\tوِزَارَة\tNOUN\tN------D1R\tCase=Nom|Definite=Cons|Number=Dual\t12\tnsubj\t12:nsubj\tVform=وِزَارَتَا|Gloss=ministry|Root=w_z_r|Translit=wizāratā|LTranslit=wizārat\n14\tالداخلية\tدَاخِلِيّ\tADJ\tA-----FS2D\tCase=Gen|Definite=Def|Gender=Fem|Number=Sing\t13\tamod\t13:amod\tVform=اَلدَّاخِلِيَّةِ|Gloss=internal,domestic,interior,of_state|Root=d__h_l|Translit=ad-dāḫilīyati|LTranslit=dāḫilīy\n15-16\tوالاقتصاد\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No\n15\tو\tوَ\tCCONJ\tC---------\t_\t16\tcc\t16:cc\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n16\tالاقتصاد\tاِقتِصَاد\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t14\tconj\t13:amod|14:conj\tVform=اَلِاقتِصَادِ|Gloss=economy,saving|Root=q_.s_d|Translit=al-i-ʼqtiṣādi|LTranslit=iqtiṣād\n17\t.\t.\tPUNCT\tG---------\t_\t1\tpunct\t1:punct\tVform=.|Translit=."
  },
  {
    "path": "demo/corenlp.py",
    "content": "from stanza.server import CoreNLPClient\n\n# example text\nprint('---')\nprint('input text')\nprint('')\n\ntext = \"Chris Manning is a nice person. Chris wrote a simple sentence. He also gives oranges to people.\"\n\nprint(text)\n\n# set up the client\nprint('---')\nprint('starting up Java Stanford CoreNLP Server...')\n\n# set up the client\nwith CoreNLPClient(annotators=['tokenize','ssplit','pos','lemma','ner','parse','depparse','coref'], timeout=60000, memory='16G') as client:\n    # submit the request to the server\n    ann = client.annotate(text)\n\n    # get the first sentence\n    sentence = ann.sentence[0]\n\n    # get the dependency parse of the first sentence\n    print('---')\n    print('dependency parse of first sentence')\n    dependency_parse = sentence.basicDependencies\n    print(dependency_parse)\n \n    # get the constituency parse of the first sentence\n    print('---')\n    print('constituency parse of first sentence')\n    constituency_parse = sentence.parseTree\n    print(constituency_parse)\n\n    # get the first subtree of the constituency parse\n    print('---')\n    print('first subtree of constituency parse')\n    print(constituency_parse.child[0])\n\n    # get the value of the first subtree\n    print('---')\n    print('value of first subtree of constituency parse')\n    print(constituency_parse.child[0].value)\n\n    # get the first token of the first sentence\n    print('---')\n    print('first token of first sentence')\n    token = sentence.token[0]\n    print(token)\n\n    # get the part-of-speech tag\n    print('---')\n    print('part of speech tag of token')\n    token.pos\n    print(token.pos)\n\n    # get the named entity tag\n    print('---')\n    print('named entity tag of token')\n    print(token.ner)\n\n    # get an entity mention from the first sentence\n    print('---')\n    print('first entity mention in sentence')\n    print(sentence.mentions[0])\n\n    # access the coref chain\n    print('---')\n    
print('coref chains for the example')\n    print(ann.corefChain)\n\n    # Use tokensregex patterns to find who wrote a sentence.\n    pattern = '([ner: PERSON]+) /wrote/ /an?/ []{0,3} /sentence|article/'\n    matches = client.tokensregex(text, pattern)\n    # sentences contains a list with matches for each sentence.\n    assert len(matches[\"sentences\"]) == 3\n    # length tells you whether or not there are any matches in this sentence.\n    assert matches[\"sentences\"][1][\"length\"] == 1\n    # You can access matches like most regex groups.\n    # (The comparisons below are illustrative expressions; their results are not checked.)\n    matches[\"sentences\"][1][\"0\"][\"text\"] == \"Chris wrote a simple sentence\"\n    matches[\"sentences\"][1][\"0\"][\"1\"][\"text\"] == \"Chris\"\n\n    # Use semgrex patterns to directly find who wrote what.\n    pattern = '{word:wrote} >nsubj {}=subject >obj {}=object'\n    matches = client.semgrex(text, pattern)\n    # sentences contains a list with matches for each sentence.\n    assert len(matches[\"sentences\"]) == 3\n    # length tells you whether or not there are any matches in this sentence.\n    assert matches[\"sentences\"][1][\"length\"] == 1\n    # You can access matches like most regex groups.\n    # (The comparisons below are illustrative expressions; their results are not checked.)\n    matches[\"sentences\"][1][\"0\"][\"text\"] == \"wrote\"\n    matches[\"sentences\"][1][\"0\"][\"$subject\"][\"text\"] == \"Chris\"\n    matches[\"sentences\"][1][\"0\"][\"$object\"][\"text\"] == \"sentence\"\n\n"
  },
  {
    "path": "demo/en_test.conllu.txt",
    "content": "# newdoc id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200\n# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0001\n# newpar id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-p0001\n# text = What if Google Morphed Into GoogleOS?\n1\tWhat\twhat\tPRON\tWP\tPronType=Int\t0\troot\t0:root\t_\n2\tif\tif\tSCONJ\tIN\t_\t4\tmark\t4:mark\t_\n3\tGoogle\tGoogle\tPROPN\tNNP\tNumber=Sing\t4\tnsubj\t4:nsubj\t_\n4\tMorphed\tmorph\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t1\tadvcl\t1:advcl:if\t_\n5\tInto\tinto\tADP\tIN\t_\t6\tcase\t6:case\t_\n6\tGoogleOS\tGoogleOS\tPROPN\tNNP\tNumber=Sing\t4\tobl\t4:obl:into\tSpaceAfter=No\n7\t?\t?\tPUNCT\t.\t_\t4\tpunct\t4:punct\t_\n\n# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0002\n# text = What if Google expanded on its search-engine (and now e-mail) wares into a full-fledged operating system?\n1\tWhat\twhat\tPRON\tWP\tPronType=Int\t0\troot\t0:root\t_\n2\tif\tif\tSCONJ\tIN\t_\t4\tmark\t4:mark\t_\n3\tGoogle\tGoogle\tPROPN\tNNP\tNumber=Sing\t4\tnsubj\t4:nsubj\t_\n4\texpanded\texpand\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t1\tadvcl\t1:advcl:if\t_\n5\ton\ton\tADP\tIN\t_\t15\tcase\t15:case\t_\n6\tits\tits\tPRON\tPRP$\tGender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t15\tnmod:poss\t15:nmod:poss\t_\n7\tsearch\tsearch\tNOUN\tNN\tNumber=Sing\t9\tcompound\t9:compound\tSpaceAfter=No\n8\t-\t-\tPUNCT\tHYPH\t_\t9\tpunct\t9:punct\tSpaceAfter=No\n9\tengine\tengine\tNOUN\tNN\tNumber=Sing\t15\tcompound\t15:compound\t_\n10\t(\t(\tPUNCT\t-LRB-\t_\t9\tpunct\t9:punct\tSpaceAfter=No\n11\tand\tand\tCCONJ\tCC\t_\t13\tcc\t13:cc\t_\n12\tnow\tnow\tADV\tRB\t_\t13\tadvmod\t13:advmod\t_\n13\te-mail\te-mail\tNOUN\tNN\tNumber=Sing\t9\tconj\t9:conj:and|15:compound\tSpaceAfter=No\n14\t)\t)\tPUNCT\t-RRB-\t_\t15\tpunct\t15:punct\t_\n15\twares\twares\tNOUN\tNNS\tNumber=Plur\t4\tobl\t4:obl:on\t_\
n16\tinto\tinto\tADP\tIN\t_\t22\tcase\t22:case\t_\n17\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t22\tdet\t22:det\t_\n18\tfull\tfull\tADV\tRB\t_\t20\tadvmod\t20:advmod\tSpaceAfter=No\n19\t-\t-\tPUNCT\tHYPH\t_\t20\tpunct\t20:punct\tSpaceAfter=No\n20\tfledged\tfledged\tADJ\tJJ\tDegree=Pos\t22\tamod\t22:amod\t_\n21\toperating\toperating\tNOUN\tNN\tNumber=Sing\t22\tcompound\t22:compound\t_\n22\tsystem\tsystem\tNOUN\tNN\tNumber=Sing\t4\tobl\t4:obl:into\tSpaceAfter=No\n23\t?\t?\tPUNCT\t.\t_\t4\tpunct\t4:punct\t_\n\n# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0003\n# text = [via Microsoft Watch from Mary Jo Foley ]\n1\t[\t[\tPUNCT\t-LRB-\t_\t4\tpunct\t4:punct\tSpaceAfter=No\n2\tvia\tvia\tADP\tIN\t_\t4\tcase\t4:case\t_\n3\tMicrosoft\tMicrosoft\tPROPN\tNNP\tNumber=Sing\t4\tcompound\t4:compound\t_\n4\tWatch\tWatch\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\t_\n5\tfrom\tfrom\tADP\tIN\t_\t6\tcase\t6:case\t_\n6\tMary\tMary\tPROPN\tNNP\tNumber=Sing\t4\tnmod\t4:nmod:from\t_\n7\tJo\tJo\tPROPN\tNNP\tNumber=Sing\t6\tflat\t6:flat\t_\n8\tFoley\tFoley\tPROPN\tNNP\tNumber=Sing\t6\tflat\t6:flat\t_\n9\t]\t]\tPUNCT\t-RRB-\t_\t4\tpunct\t4:punct\t_\n\n# newdoc id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700\n# sent_id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-0001\n# newpar id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-p0001\n# text = (And, by the way, is anybody else just a little nostalgic for the days when that was a good 
thing?)\n1\t(\t(\tPUNCT\t-LRB-\t_\t14\tpunct\t14:punct\tSpaceAfter=No\n2\tAnd\tand\tCCONJ\tCC\t_\t14\tcc\t14:cc\tSpaceAfter=No\n3\t,\t,\tPUNCT\t,\t_\t14\tpunct\t14:punct\t_\n4\tby\tby\tADP\tIN\t_\t6\tcase\t6:case\t_\n5\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t6\tdet\t6:det\t_\n6\tway\tway\tNOUN\tNN\tNumber=Sing\t14\tobl\t14:obl:by\tSpaceAfter=No\n7\t,\t,\tPUNCT\t,\t_\t14\tpunct\t14:punct\t_\n8\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t14\tcop\t14:cop\t_\n9\tanybody\tanybody\tPRON\tNN\tNumber=Sing\t14\tnsubj\t14:nsubj\t_\n10\telse\telse\tADJ\tJJ\tDegree=Pos\t9\tamod\t9:amod\t_\n11\tjust\tjust\tADV\tRB\t_\t13\tadvmod\t13:advmod\t_\n12\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t13\tdet\t13:det\t_\n13\tlittle\tlittle\tADJ\tJJ\tDegree=Pos\t14\tobl:npmod\t14:obl:npmod\t_\n14\tnostalgic\tnostalgic\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\t_\n15\tfor\tfor\tADP\tIN\t_\t17\tcase\t17:case\t_\n16\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t17\tdet\t17:det\t_\n17\tdays\tday\tNOUN\tNNS\tNumber=Plur\t14\tnmod\t14:nmod:for|23:obl:npmod\t_\n18\twhen\twhen\tADV\tWRB\tPronType=Rel\t23\tadvmod\t17:ref\t_\n19\tthat\tthat\tPRON\tDT\tNumber=Sing|PronType=Dem\t23\tnsubj\t23:nsubj\t_\n20\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t23\tcop\t23:cop\t_\n21\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t23\tdet\t23:det\t_\n22\tgood\tgood\tADJ\tJJ\tDegree=Pos\t23\tamod\t23:amod\t_\n23\tthing\tthing\tNOUN\tNN\tNumber=Sing\t17\tacl:relcl\t17:acl:relcl\tSpaceAfter=No\n24\t?\t?\tPUNCT\t.\t_\t14\tpunct\t14:punct\tSpaceAfter=No\n25\t)\t)\tPUNCT\t-RRB-\t_\t14\tpunct\t14:punct\t_"
  },
  {
    "path": "demo/japanese_test.conllu.txt",
    "content": "# newdoc id = test-s1\n# sent_id = test-s1\n# text = これに不快感を示す住民はいましたが,現在,表立って反対や抗議の声を挙げている住民はいないようです。\n1\tこれ\t此れ\tPRON\t代名詞\t_\t6\tobl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=代名詞|SpaceAfter=No|UnidicInfo=,此れ,これ,これ,コレ,,,コレ,コレ,此れ\n2\tに\tに\tADP\t助詞-格助詞\t_\t1\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,に,に,に,ニ,,,ニ,ニ,に\n3\t不快\t不快\tNOUN\t名詞-普通名詞-形状詞可能\t_\t4\tcompound\t_\tBunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,不快,不快,不快,フカイ,,,フカイ,フカイカン,不快感\n4\t感\t感\tNOUN\t名詞-普通名詞-一般\t_\t6\tobj\t_\tBunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,感,感,感,カン,,,カン,フカイカン,不快感\n5\tを\tを\tADP\t助詞-格助詞\t_\t4\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,を,を,を,オ,,,ヲ,ヲ,を\n6\t示す\t示す\tVERB\t動詞-一般-五段-サ行\t_\t7\tacl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-五段-サ行|SpaceAfter=No|UnidicInfo=,示す,示す,示す,シメス,,,シメス,シメス,示す\n7\t住民\t住民\tNOUN\t名詞-普通名詞-一般\t_\t9\tnsubj\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,住民,住民,住民,ジューミン,,,ジュウミン,ジュウミン,住民\n8\tは\tは\tADP\t助詞-係助詞\t_\t7\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は\n9\tい\t居る\tVERB\t動詞-非自立可能-上一段-ア行\t_\t29\tadvcl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,イル,居る\n10\tまし\tます\tAUX\t助動詞-助動詞-マス\t_\t9\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-マス|SpaceAfter=No|UnidicInfo=,ます,まし,ます,マシ,,,マス,マス,ます\n11\tた\tた\tAUX\t助動詞-助動詞-タ\t_\t9\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-タ|SpaceAfter=No|UnidicInfo=,た,た,た,タ,,,タ,タ,た\n12\tが\tが\tSCONJ\t助詞-接続助詞\t_\t9\tmark\t_\tBu
nsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-接続助詞|SpaceAfter=No|UnidicInfo=,が,が,が,ガ,,,ガ,ガ,が\n13\t,\t,\tPUNCT\t補助記号-読点\t_\t9\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,，,,,,,,,,，\n14\t現在\t現在\tADV\t名詞-普通名詞-副詞可能\t_\t16\tadvmod\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=副詞|SpaceAfter=No|UnidicInfo=,現在,現在,現在,ゲンザイ,,,ゲンザイ,ゲンザイ,現在\n15\t,\t,\tPUNCT\t補助記号-読点\t_\t14\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,，,,,,,,,,，\n16\t表立っ\t表立つ\tVERB\t動詞-一般-五段-タ行\t_\t24\tadvcl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-五段-タ行|SpaceAfter=No|UnidicInfo=,表立つ,表立っ,表立つ,オモテダッ,,,オモテダツ,オモテダツ,表立つ\n17\tて\tて\tSCONJ\t助詞-接続助詞\t_\t16\tmark\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-接続助詞|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テ,て\n18\t反対\t反対\tNOUN\t名詞-普通名詞-サ変形状詞可能\t_\t20\tnmod\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,反対,反対,反対,ハンタイ,,,ハンタイ,ハンタイ,反対\n19\tや\tや\tADP\t助詞-副助詞\t_\t18\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-副助詞|SpaceAfter=No|UnidicInfo=,や,や,や,ヤ,,,ヤ,ヤ,や\n20\t抗議\t抗議\tNOUN\t名詞-普通名詞-サ変可能\t_\t22\tnmod\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,抗議,抗議,抗議,コーギ,,,コウギ,コウギ,抗議\n21\tの\tの\tADP\t助詞-格助詞\t_\t20\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,の,の,の,ノ,,,ノ,ノ,の\n22\t声\t声\tNOUN\t名詞-普通名詞-一般\t_\t24\tobj\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,声,声,声,コエ,,,コエ,コエ,声\n23\tを\tを\tADP\t助詞-格助詞\t_\t22\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,を,を,を,オ,,,ヲ,ヲ,を\n24\t挙げ\t上げる\tVERB\t動詞-非自立可能-下一段-ガ行\t_\t27\tacl\t_\tBunsetuBILabel=B|Bun
setuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-下一段-ガ行|SpaceAfter=No|UnidicInfo=,上げる,挙げ,挙げる,アゲ,,,アゲル,アゲル,上げる\n25\tて\tて\tSCONJ\t助詞-接続助詞\t_\t24\tmark\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-上一段-ア行|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テイル,ている\n26\tいる\t居る\tVERB\t動詞-非自立可能-上一段-ア行\t_\t25\tfixed\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=助動詞-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,いる,いる,イル,,,イル,テイル,ている\n27\t住民\t住民\tNOUN\t名詞-普通名詞-一般\t_\t29\tnsubj\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,住民,住民,住民,ジューミン,,,ジュウミン,ジュウミン,住民\n28\tは\tは\tADP\t助詞-係助詞\t_\t27\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は\n29\tい\t居る\tVERB\t動詞-非自立可能-上一段-ア行\t_\t0\troot\t_\tBunsetuBILabel=B|BunsetuPositionType=ROOT|LUWBILabel=B|LUWPOS=動詞-一般-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,イル,居る\n30\tない\tない\tAUX\t助動詞-助動詞-ナイ\tPolarity=Neg\t29\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-ナイ|SpaceAfter=No|UnidicInfo=,ない,ない,ない,ナイ,,,ナイ,ナイ,ない\n31\tよう\t様\tAUX\t形状詞-助動詞語幹\t_\t29\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=形状詞-助動詞語幹|PrevUDLemma=よう|SpaceAfter=No|UnidicInfo=,様,よう,よう,ヨー,,,ヨウ,ヨウ,様\n32\tです\tです\tAUX\t助動詞-助動詞-デス\t_\t29\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-デス|PrevUDLemma=だ|SpaceAfter=No|UnidicInfo=,です,です,です,デス,,,デス,デス,です\n33\t。\t。\tPUNCT\t補助記号-句点\t_\t29\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。\n\n# newdoc id = test-s2\n# sent_id = test-s2\n# text = 
幸福の科学側からは,特にどうしてほしいという要望はいただいていません。\n1\t幸福\t幸福\tNOUN\t名詞-普通名詞-形状詞可能\t_\t4\tnmod\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,幸福,幸福,幸福,コーフク,,,コウフク,コウフクノカガクガワ,幸福の科学側\n2\tの\tの\tADP\t助詞-格助詞\t_\t1\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,の,の,の,ノ,,,ノ,コウフクノカガクガワ,幸福の科学側\n3\t科学\t科学\tNOUN\t名詞-普通名詞-サ変可能\t_\t4\tcompound\t_\tBunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,科学,科学,科学,カガク,,,カガク,コウフクノカガクガワ,幸福の科学側\n4\t側\t側\tNOUN\t名詞-普通名詞-一般\t_\t17\tobl\t_\tBunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,側,側,側,ガワ,,,ガワ,コウフクノカガクガワ,幸福の科学側\n5\tから\tから\tADP\t助詞-格助詞\t_\t4\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,から,から,から,カラ,,,カラ,カラ,から\n6\tは\tは\tADP\t助詞-係助詞\t_\t4\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は\n7\t,\t,\tPUNCT\t補助記号-読点\t_\t4\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,，,,,,,,,,，\n8\t特に\t特に\tADV\t副詞\t_\t17\tadvmod\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=副詞|SpaceAfter=No|UnidicInfo=,特に,特に,特に,トクニ,,,トクニ,トクニ,特に\n9\tどう\tどう\tADV\t副詞\t_\t15\tadvcl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,どう,どう,どう,ドー,,,ドウ,ドウスル,どうする\n10\tし\t為る\tAUX\t動詞-非自立可能-サ行変格\t_\t9\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,し,する,シ,,,スル,ドウスル,どうする\n11\tて\tて\tSCONJ\t助詞-接続助詞\t_\t9\tmark\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-形容詞|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テホシイ,てほしい\n12\tほしい\t欲しい\tAUX\t形容詞-非自立可能-形容詞\t_\t11\tfixed\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=
I|LUWPOS=助動詞-形容詞|PrevUDLemma=ほしい|SpaceAfter=No|UnidicInfo=,欲しい,ほしい,ほしい,ホシー,,,ホシイ,テホシイ,てほしい\n13\tと\tと\tADP\t助詞-格助詞\t_\t9\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,と,と,と,ト,,,ト,トイウ,という\n14\tいう\t言う\tVERB\t動詞-一般-五段-ワア行\t_\t13\tfixed\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=I|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,言う,いう,いう,イウ,,,イウ,トイウ,という\n15\t要望\t要望\tNOUN\t名詞-普通名詞-サ変可能\t_\t17\tnsubj\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,要望,要望,要望,ヨーボー,,,ヨウボウ,ヨウボウ,要望\n16\tは\tは\tADP\t助詞-係助詞\t_\t15\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は\n17\tいただい\t頂く\tVERB\t動詞-非自立可能-五段-カ行\t_\t0\troot\t_\tBunsetuBILabel=B|BunsetuPositionType=ROOT|LUWBILabel=B|LUWPOS=動詞-一般-五段-カ行|PrevUDLemma=いただく|SpaceAfter=No|UnidicInfo=,頂く,いただい,いただく,イタダイ,,,イタダク,イタダク,頂く\n18\tて\tて\tSCONJ\t助詞-接続助詞\t_\t17\tmark\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-上一段-ア行|SpaceAfter=No|UnidicInfo=,て,て,て,テ,,,テ,テイル,ている\n19\tい\t居る\tVERB\t動詞-非自立可能-上一段-ア行\t_\t18\tfixed\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=助動詞-上一段-ア行|PrevUDLemma=いる|SpaceAfter=No|UnidicInfo=,居る,い,いる,イ,,,イル,テイル,ている\n20\tませ\tます\tAUX\t助動詞-助動詞-マス\t_\t17\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=B|LUWPOS=助動詞-助動詞-マス|SpaceAfter=No|UnidicInfo=,ます,ませ,ます,マセ,,,マス,マス,ます\n21\tん\tず\tAUX\t助動詞-助動詞-ヌ\tPolarity=Neg\t17\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-ヌ|PrevUDLemma=ぬ|SpaceAfter=No|UnidicInfo=,ず,ん,ぬ,ン,,,ヌ,ズ,ず\n22\t。\t。\tPUNCT\t補助記号-句点\t_\t17\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。\n\n# newdoc id = test-s3\n# sent_id = test-s3\n# text = 
星取り参加は当然とされ,不参加は白眼視される。\n1\t星取り\t星取り\tNOUN\t名詞-普通名詞-一般\t_\t2\tcompound\t_\tBunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,星取り,星取り,星取り,ホシトリ,,,ホシトリ,ホシトリサンカ,星取り参加\n2\t参加\t参加\tNOUN\t名詞-普通名詞-サ変可能\t_\t4\tnsubj\t_\tBunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,参加,参加,参加,サンカ,,,サンカ,ホシトリサンカ,星取り参加\n3\tは\tは\tADP\t助詞-係助詞\t_\t2\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は\n4\t当然\t当然\tADJ\t形状詞-一般\t_\t6\tadvcl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=形状詞-一般|SpaceAfter=No|UnidicInfo=,当然,当然,当然,トーゼン,,,トウゼン,トウゼン,当然\n5\tと\tと\tADP\t助詞-格助詞\t_\t4\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-格助詞|SpaceAfter=No|UnidicInfo=,と,と,と,ト,,,ト,ト,と\n6\tさ\t為る\tVERB\t動詞-非自立可能-サ行変格\t_\t13\tacl\t_\tBunsetuBILabel=B|BunsetuPositionType=SEM_HEAD|LUWBILabel=B|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,さ,する,サ,,,スル,スル,する\n7\tれ\tれる\tAUX\t助動詞-助動詞-レル\t_\t6\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-レル|SpaceAfter=No|UnidicInfo=,れる,れ,れる,レ,,,レル,レル,れる\n8\t,\t,\tPUNCT\t補助記号-読点\t_\t6\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-読点|SpaceAfter=No|UnidicInfo=,，,,,,,,,,，\n9\t不\t不\tNOUN\t接頭辞\tPolarity=Neg\t10\tcompound\t_\tBunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,不,不,不,フ,,,フ,フサンカ,不参加\n10\t参加\t参加\tNOUN\t名詞-普通名詞-サ変可能\t_\t13\tnsubj\t_\tBunsetuBILabel=I|BunsetuPositionType=SEM_HEAD|LUWBILabel=I|LUWPOS=名詞-普通名詞-一般|SpaceAfter=No|UnidicInfo=,参加,参加,参加,サンカ,,,サンカ,フサンカ,不参加\n11\tは\tは\tADP\t助詞-係助詞\t_\t10\tcase\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助詞-係助詞|SpaceAfter=No|UnidicInfo=,は,は,は,ワ,,,ハ,ハ,は\n12\t白眼\t白眼\tNOUN\t名詞-普通名詞-一般\t_\t13\tcompound\t_\tBunsetuBILabel=B|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=動詞-一般-サ
行変格|SpaceAfter=No|UnidicInfo=,白眼,白眼,白眼,ハクガン,,,ハクガン,ハクガンシスル,白眼視する\n13\t視\t視\tNOUN\t接尾辞-名詞的-サ変可能\t_\t0\troot\t_\tBunsetuBILabel=I|BunsetuPositionType=ROOT|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|SpaceAfter=No|UnidicInfo=,視,視,視,シ,,,シ,ハクガンシスル,白眼視する\n14\tさ\t為る\tAUX\t動詞-非自立可能-サ行変格\t_\t13\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=FUNC|LUWBILabel=I|LUWPOS=動詞-一般-サ行変格|PrevUDLemma=する|SpaceAfter=No|UnidicInfo=,為る,さ,する,サ,,,スル,ハクガンシスル,白眼視する\n15\tれる\tれる\tAUX\t助動詞-助動詞-レル\t_\t13\taux\t_\tBunsetuBILabel=I|BunsetuPositionType=SYN_HEAD|LUWBILabel=B|LUWPOS=助動詞-助動詞-レル|SpaceAfter=No|UnidicInfo=,れる,れる,れる,レル,,,レル,レル,れる\n16\t。\t。\tPUNCT\t補助記号-句点\t_\t13\tpunct\t_\tBunsetuBILabel=I|BunsetuPositionType=CONT|LUWBILabel=B|LUWPOS=補助記号-句点|SpaceAfter=Yes|UnidicInfo=,。,。,。,,,,,,。"
  },
  {
    "path": "demo/pipeline_demo.py",
    "content": "\"\"\"\nA basic demo of the Stanza neural pipeline.\n\"\"\"\n\nimport sys\nimport argparse\nimport os\n\nimport stanza\nfrom stanza.resources.common import DEFAULT_MODEL_DIR\n\n\nif __name__ == '__main__':\n    # get arguments\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', '--models_dir', help='location of models files | default: ~/stanza_resources',\n                        default=DEFAULT_MODEL_DIR)\n    parser.add_argument('-l', '--lang', help='Demo language',\n                        default=\"en\")\n    parser.add_argument('-c', '--cpu', action='store_true', help='Use cpu as the device.')\n    args = parser.parse_args()\n\n    example_sentences = {\"en\": \"Barack Obama was born in Hawaii.  He was elected president in 2008.\",\n            \"zh\": \"中国文化经历上千年的历史演变，是各区域、各民族古代文化长期相互交流、借鉴、融合的结果。\",\n            \"fr\": \"Van Gogh grandit au sein d'une famille de l'ancienne bourgeoisie. Il tente d'abord de faire carrière comme marchand d'art chez Goupil & C.\",\n            \"vi\": \"Trận Trân Châu Cảng (hay Chiến dịch Hawaii theo cách gọi của Bộ Tổng tư lệnh Đế quốc Nhật Bản) là một đòn tấn công quân sự bất ngờ được Hải quân Nhật Bản thực hiện nhằm vào căn cứ hải quân của Hoa Kỳ tại Trân Châu Cảng thuộc tiểu bang Hawaii vào sáng Chủ Nhật, ngày 7 tháng 12 năm 1941, dẫn đến việc Hoa Kỳ sau đó quyết định tham gia vào hoạt động quân sự trong Chiến tranh thế giới thứ hai.\"}\n\n    if args.lang not in example_sentences:\n        print(f'Sorry, but we don\\'t have a demo sentence for \"{args.lang}\" for the moment. 
Try one of these languages: {list(example_sentences.keys())}')\n        sys.exit(1)\n\n    # download the models\n    stanza.download(args.lang, dir=args.models_dir)\n    # set up a pipeline\n    print('---')\n    print('Building pipeline...')\n    pipeline = stanza.Pipeline(lang=args.lang, dir=args.models_dir, use_gpu=(not args.cpu))\n    # process the document\n    doc = pipeline(example_sentences[args.lang])\n    # access nlp annotations\n    print('')\n    print('Input: {}'.format(example_sentences[args.lang]))\n    print(\"The tokenizer split the input into {} sentences.\".format(len(doc.sentences)))\n    print('---')\n    print('tokens of first sentence: ')\n    doc.sentences[0].print_tokens()\n    print('')\n    print('---')\n    print('dependency parse of first sentence: ')\n    doc.sentences[0].print_dependencies()\n    print('')\n\n"
  },
  {
    "path": "demo/scenegraph.py",
    "content": "\"\"\"\nVery short demo for the SceneGraph interface in the CoreNLP server\n\nRequires CoreNLP >= 4.5.5, Stanza >= 1.5.1\n\"\"\"\n\nimport json\n\nfrom stanza.server import CoreNLPClient\n\n# start_server=None if you have the server running in another process on the same host\n# you can start it with whatever normal options CoreNLPClient has\n#\n# preload=False avoids having the server unnecessarily load annotators\n# if you don't plan on using them\nwith CoreNLPClient(preload=False) as client:\n    result = client.scenegraph(\"Jennifer's antennae are on her head.\")\n    print(json.dumps(result, indent=2))\n\n\n"
  },
  {
    "path": "demo/semgrex visualization.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2787d5f5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import stanza\\n\",\n    \"from stanza.server.semgrex import Semgrex\\n\",\n    \"from stanza.models.common.constant import is_right_to_left\\n\",\n    \"import spacy\\n\",\n    \"from spacy import displacy\\n\",\n    \"from spacy.tokens import Doc\\n\",\n    \"from IPython.display import display, HTML\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. Additionally,\\n\",\n    \"set an environment variable CLASSPATH equal to the path of your corenlp directory.\\n\",\n    \"\\n\",\n    \"Example: CLASSPATH=C:\\\\\\\\Users\\\\\\\\Alex\\\\\\\\PycharmProjects\\\\\\\\pythonProject\\\\\\\\stanford-corenlp-4.5.0\\\\\\\\stanford-corenlp-4.5.0\\\\\\\\*\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"%env CLASSPATH=C:\\\\\\\\stanford-corenlp-4.5.2\\\\\\\\stanford-corenlp-4.5.2\\\\\\\\*\\n\",\n    \"def get_sentences_html(doc, language):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Returns a list of the HTML strings of the dependency visualizations of a given stanza doc object.\\n\",\n    \"\\n\",\n    \"    The 'language' arg is the two-letter language code for the document to be processed.\\n\",\n    \"\\n\",\n    \"    First converts the stanza doc object to a spacy doc object and uses displacy to generate an HTML\\n\",\n    \"    string for each sentence of the doc object.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    html_strings = []\\n\",\n    \"\\n\",\n    \"    # blank model - we don't use any of the model features, just the visualization\\n\",\n    \"    nlp = spacy.blank(\\\"en\\\")\\n\",\n    \"    sentences_to_visualize = []\\n\",\n    \"    for sentence in doc.sentences:\\n\",\n    \"        words, lemmas, heads, deps, tags = [], [], [], [], 
[]\\n\",\n    \"        if is_right_to_left(language):  # order of words displayed is reversed, dependency arcs remain intact\\n\",\n    \"            sent_len = len(sentence.words)\\n\",\n    \"            for word in reversed(sentence.words):\\n\",\n    \"                words.append(word.text)\\n\",\n    \"                lemmas.append(word.lemma)\\n\",\n    \"                deps.append(word.deprel)\\n\",\n    \"                tags.append(word.upos)\\n\",\n    \"                if word.head == 0:  # spaCy head indexes are formatted differently than that of Stanza\\n\",\n    \"                    heads.append(sent_len - word.id)\\n\",\n    \"                else:\\n\",\n    \"                    heads.append(sent_len - word.head)\\n\",\n    \"        else:  # left to right rendering\\n\",\n    \"            for word in sentence.words:\\n\",\n    \"                words.append(word.text)\\n\",\n    \"                lemmas.append(word.lemma)\\n\",\n    \"                deps.append(word.deprel)\\n\",\n    \"                tags.append(word.upos)\\n\",\n    \"                if word.head == 0:\\n\",\n    \"                    heads.append(word.id - 1)\\n\",\n    \"                else:\\n\",\n    \"                    heads.append(word.head - 1)\\n\",\n    \"        document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)\\n\",\n    \"        sentences_to_visualize.append(document_result)\\n\",\n    \"\\n\",\n    \"    for line in sentences_to_visualize:  # render all sentences through displaCy\\n\",\n    \"        html_strings.append(displacy.render(line, style=\\\"dep\\\",\\n\",\n    \"                                            options={\\\"compact\\\": True, \\\"word_spacing\\\": 30, \\\"distance\\\": 100,\\n\",\n    \"                                                     \\\"arrow_spacing\\\": 20}, jupyter=False))\\n\",\n    \"    return html_strings\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def find_nth(haystack, needle, 
n):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Returns the starting index of the nth occurrence of the substring 'needle' in the string 'haystack'.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    start = haystack.find(needle)\\n\",\n    \"    while start >= 0 and n > 1:\\n\",\n    \"        start = haystack.find(needle, start + len(needle))\\n\",\n    \"        n -= 1\\n\",\n    \"    return start\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def round_base(num, base=10):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Rounding a number to its nearest multiple of the base. round_base(49.2, base=50) = 50.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    return base * round(num/base)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def process_sentence_html(orig_html, semgrex_sentence):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Takes a semgrex sentence object and modifies the HTML of the original sentence's deprel visualization,\\n\",\n    \"    highlighting words involved in the search queries and adding the label of the word inside of the semgrex match.\\n\",\n    \"\\n\",\n    \"    Returns the modified html string of the sentence's deprel visualization.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    tracker = {}  # keep track of which words have multiple labels\\n\",\n    \"    DEFAULT_TSPAN_COUNT = 2  # the original displacy html assigns two <tspan> objects per <text> object\\n\",\n    \"    CLOSING_TSPAN_LEN = 8  # </tspan> is 8 chars long\\n\",\n    \"    colors = ['#4477AA', '#66CCEE', '#228833', '#CCBB44', '#EE6677', '#AA3377', '#BBBBBB']\\n\",\n    \"    css_bolded_class = \\\"<style> .bolded{font-weight: bold;} </style>\\\\n\\\"\\n\",\n    \"    found_index = orig_html.find(\\\"\\\\n\\\")  # returns index where the opening <svg> ends\\n\",\n    \"    # insert the new style class into html string\\n\",\n    \"    orig_html = orig_html[: found_index + 1] + css_bolded_class + orig_html[found_index + 1:]\\n\",\n    \"\\n\",\n    \"    # Add color to words in the 
match, bold words in the match\\n\",\n    \"    for query in semgrex_sentence.result:\\n\",\n    \"        for i, match in enumerate(query.match):\\n\",\n    \"            color = colors[i]\\n\",\n    \"            paired_dy = 2\\n\",\n    \"            for node in match.node:\\n\",\n    \"                name, match_index = node.name, node.matchIndex\\n\",\n    \"                # edit existing <tspan> to change color and bold the text\\n\",\n    \"                start = find_nth(orig_html, \\\"<text\\\", match_index)  # finds start of svg <text> of interest\\n\",\n    \"                if match_index not in tracker:  # if we've already bolded and colored, keep the first color\\n\",\n    \"                    tspan_start = orig_html.find(\\\"<tspan\\\",\\n\",\n    \"                                                 start)  # finds start of the first svg <tspan> inside of the <text>\\n\",\n    \"                    tspan_end = orig_html.find(\\\"</tspan>\\\", start)  # finds start of the end of the above <tspan>\\n\",\n    \"                    tspan_substr = orig_html[tspan_start: tspan_end + CLOSING_TSPAN_LEN + 1] + \\\"\\\\n\\\"\\n\",\n    \"                    # color words in the hit and bold words in the hit\\n\",\n    \"                    edited_tspan = tspan_substr.replace('class=\\\"displacy-word\\\"', 'class=\\\"bolded\\\"').replace(\\n\",\n    \"                        'fill=\\\"currentColor\\\"', f'fill=\\\"{color}\\\"')\\n\",\n    \"                    # insert edited <tspan> object into html string\\n\",\n    \"                    orig_html = orig_html[: tspan_start] + edited_tspan + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2:]\\n\",\n    \"                    tracker[match_index] = DEFAULT_TSPAN_COUNT\\n\",\n    \"\\n\",\n    \"                # next, we have to insert the new <tspan> object for the label\\n\",\n    \"                # Copy old <tspan> to copy formatting when creating new <tspan> later\\n\",\n    \"                prev_tspan_start 
= find_nth(orig_html[start:], \\\"<tspan\\\",\\n\",\n    \"                                            tracker[match_index] - 1) + start  # find the previous <tspan> start index\\n\",\n    \"                prev_tspan_end = find_nth(orig_html[start:], \\\"</tspan>\\\",\\n\",\n    \"                                          tracker[match_index] - 1) + start  # find the prev </tspan> start index\\n\",\n    \"                prev_tspan = orig_html[prev_tspan_start: prev_tspan_end + CLOSING_TSPAN_LEN + 1]\\n\",\n    \"\\n\",\n    \"                # Find spot to insert new tspan\\n\",\n    \"                closing_tspan_start = find_nth(orig_html[start:], \\\"</tspan>\\\", tracker[match_index]) + start\\n\",\n    \"                up_to_new_tspan = orig_html[: closing_tspan_start + CLOSING_TSPAN_LEN + 1]\\n\",\n    \"                rest_need_add_newline = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1:]\\n\",\n    \"\\n\",\n    \"                # Calculate proper x value in svg\\n\",\n    \"                x_value_start = prev_tspan.find('x=\\\"')\\n\",\n    \"                x_value_end = prev_tspan[x_value_start + 3:].find('\\\"') + 3  # 3 is the length of the 'x=\\\"' substring\\n\",\n    \"                x_value = prev_tspan[x_value_start + 3: x_value_end + x_value_start]\\n\",\n    \"\\n\",\n    \"                # Calculate proper y value in svg\\n\",\n    \"                DEFAULT_DY_VAL, dy = 2, 2\\n\",\n    \"                if paired_dy != DEFAULT_DY_VAL and node == match.node[\\n\",\n    \"                    1]:  # we're on the second node and need to adjust height to match the paired node\\n\",\n    \"                    dy = paired_dy\\n\",\n    \"                if node == match.node[0]:\\n\",\n    \"                    paired_node_level = 2\\n\",\n    \"                    if match.node[1].matchIndex in tracker:  # check if we need to adjust heights of labels\\n\",\n    \"                        paired_node_level = 
tracker[match.node[1].matchIndex]\\n\",\n    \"                        dif = tracker[match_index] - paired_node_level\\n\",\n    \"                        if dif > 0:  # current node has more labels\\n\",\n    \"                            paired_dy = DEFAULT_DY_VAL * dif + 1\\n\",\n    \"                            dy = DEFAULT_DY_VAL\\n\",\n    \"                        else:  # paired node has more labels, adjust this label down\\n\",\n    \"                            dy = DEFAULT_DY_VAL * (abs(dif) + 1)\\n\",\n    \"                            paired_dy = DEFAULT_DY_VAL\\n\",\n    \"\\n\",\n    \"                # Insert new <tspan> object\\n\",\n    \"                new_tspan = f'  <tspan class=\\\"displacy-word\\\" dy=\\\"{dy}em\\\" fill=\\\"{color}\\\" x={x_value}>{name[: 3].title()}.</tspan>\\\\n'  # abbreviate label names to 3 chars\\n\",\n    \"                orig_html = up_to_new_tspan + new_tspan + rest_need_add_newline\\n\",\n    \"                tracker[match_index] += 1\\n\",\n    \"    return orig_html\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def render_html_strings(edited_html_strings):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Renders the HTML to make the edits visible\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    for html_string in edited_html_strings:\\n\",\n    \"        display(HTML(html_string))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def visualize_search_doc(doc, semgrex_queries, lang_code, start_match=0, end_match=10):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Visualizes the semgrex results of running semgrex search on a stanza doc object with the given list of\\n\",\n    \"    semgrex queries. Returns a list of the edited HTML strings from the doc. 
Each element in the list represents\\n\",\n    \"    the HTML to render one of the sentences in the document.\\n\",\n    \"\\n\",\n    \"    'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    'start_match' and 'end_match' determine which matches to visualize. Works similar to splices, so that\\n\",\n    \"    start_match=0 and end_match=10 will display the first 10 semgrex matches.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    matches_count = 0  # Limits number of visualizations\\n\",\n    \"    with Semgrex(classpath=\\\"$CLASSPATH\\\") as sem:\\n\",\n    \"        edited_html_strings = []\\n\",\n    \"        semgrex_results = sem.process(doc, *semgrex_queries)\\n\",\n    \"        # one html string for each sentence\\n\",\n    \"        unedited_html_strings = get_sentences_html(doc, lang_code)\\n\",\n    \"        for i in range(len(unedited_html_strings)):\\n\",\n    \"\\n\",\n    \"            if matches_count >= end_match:  # we've collected enough matches, stop early\\n\",\n    \"                break\\n\",\n    \"\\n\",\n    \"            # check if sentence has matches, if not then do not visualize\\n\",\n    \"            has_none = True\\n\",\n    \"            for query in semgrex_results.result[i].result:\\n\",\n    \"                for match in query.match:\\n\",\n    \"                    if match:\\n\",\n    \"                        has_none = False\\n\",\n    \"\\n\",\n    \"            # Process HTML if queries have matches\\n\",\n    \"            if not has_none:\\n\",\n    \"                if start_match <= matches_count < end_match:\\n\",\n    \"                    edited_string = process_sentence_html(unedited_html_strings[i], semgrex_results.result[i])\\n\",\n    \"                    edited_string = adjust_dep_arrows(edited_string)\\n\",\n    \"                    edited_html_strings.append(edited_string)\\n\",\n    \"               
 matches_count += 1\\n\",\n    \"\\n\",\n    \"        render_html_strings(edited_html_strings)\\n\",\n    \"    return edited_html_strings\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def visualize_search_str(text, semgrex_queries, lang_code):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Visualizes the deprel of the semgrex results from running semgrex search on a string with the given list of\\n\",\n    \"    semgrex queries. Returns a list of the edited HTML strings. Each element in the list represents\\n\",\n    \"    the HTML to render one of the sentences in the document.\\n\",\n    \"\\n\",\n    \"    Internally, this function converts the string into a stanza doc object before processing the doc object.\\n\",\n    \"\\n\",\n    \"    'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    nlp = stanza.Pipeline(lang_code, processors=\\\"tokenize, pos, lemma, depparse\\\")\\n\",\n    \"    doc = nlp(text)\\n\",\n    \"    return visualize_search_doc(doc, semgrex_queries, lang_code)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def adjust_dep_arrows(raw_html):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    The default spaCy dependency visualization has misaligned arrows.\\n\",\n    \"    We fix arrows by aligning arrow ends and bodies to the word that they are directed to. 
If a word has an\\n\",\n    \"    arrowhead that is pointing not directly on the word's center, align the arrowhead to match the center of the word.\\n\",\n    \"\\n\",\n    \"    returns the edited html with fixed arrow placement\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    HTML_ARROW_BEGINNING = '<g class=\\\"displacy-arrow\\\">'\\n\",\n    \"    HTML_ARROW_ENDING = \\\"</g>\\\"\\n\",\n    \"    HTML_ARROW_ENDING_LEN = 6   # there are 2 newline chars after the arrow ending\\n\",\n    \"    arrows_start_idx = find_nth(haystack=raw_html, needle='<g class=\\\"displacy-arrow\\\">', n=1)\\n\",\n    \"    words_html, arrows_html = raw_html[: arrows_start_idx], raw_html[arrows_start_idx:]  # separate html for words and arrows\\n\",\n    \"    final_html = words_html  # continually concatenate to this after processing each arrow\\n\",\n    \"    arrow_number = 1  # which arrow we're editing (1-indexed)\\n\",\n    \"    start_idx, end_of_class_idx = find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number), find_nth(arrows_html, HTML_ARROW_ENDING, arrow_number)\\n\",\n    \"    while start_idx != -1:  # edit every arrow\\n\",\n    \"        arrow_section = arrows_html[start_idx: end_of_class_idx + HTML_ARROW_ENDING_LEN]  # slice a single svg arrow object\\n\",\n    \"        if arrow_section[-1] == \\\"<\\\":   # this is the last arrow in the HTML, don't cut the splice early\\n\",\n    \"            arrow_section = arrows_html[start_idx:]\\n\",\n    \"        edited_arrow_section = edit_dep_arrow(arrow_section)\\n\",\n    \"\\n\",\n    \"        final_html = final_html + edited_arrow_section  # continually update html with new arrow html until done\\n\",\n    \"\\n\",\n    \"        # Prepare for next iteration\\n\",\n    \"        arrow_number += 1\\n\",\n    \"        start_idx = find_nth(arrows_html, '<g class=\\\"displacy-arrow\\\">', n=arrow_number)\\n\",\n    \"        end_of_class_idx = find_nth(arrows_html, \\\"</g>\\\", arrow_number)\\n\",\n   
 \"    return final_html\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def edit_dep_arrow(arrow_html):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    The formatting of a displacy arrow in svg is the following:\\n\",\n    \"    <g class=\\\"displacy-arrow\\\">\\n\",\n    \"        <path class=\\\"displacy-arc\\\" id=\\\"arrow-c628889ffbf343e3848193a08606f10a-0-0\\\" stroke-width=\\\"2px\\\" d=\\\"M70,352.0 C70,177.0 390.0,177.0 390.0,352.0\\\" fill=\\\"none\\\" stroke=\\\"currentColor\\\"/>\\n\",\n    \"        <text dy=\\\"1.25em\\\" style=\\\"font-size: 0.8em; letter-spacing: 1px\\\">\\n\",\n    \"            <textPath xlink:href=\\\"#arrow-c628889ffbf343e3848193a08606f10a-0-0\\\" class=\\\"displacy-label\\\" startOffset=\\\"50%\\\" side=\\\"left\\\" fill=\\\"currentColor\\\" text-anchor=\\\"middle\\\">csubj</textPath>\\n\",\n    \"        </text>\\n\",\n    \"        <path class=\\\"displacy-arrowhead\\\" d=\\\"M70,354.0 L62,342.0 78,342.0\\\" fill=\\\"currentColor\\\"/>\\n\",\n    \"    </g>\\n\",\n    \"\\n\",\n    \"    We edit the 'd = ...' parts of the <path class ...> section to fix the arrow direction and length\\n\",\n    \"\\n\",\n    \"    returns the arrow_html with distances fixed\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    WORD_SPACING = 50   # words start at x=50 and are separated by 100s so their x values are multiples of 50\\n\",\n    \"    M_OFFSET = 4  # length of 'd=\\\"M' that we search for to extract the number from d=\\\"M70, for instance\\n\",\n    \"    ARROW_PIXEL_SIZE = 4\\n\",\n    \"    first_d_idx, second_d_idx = find_nth(arrow_html, 'd=\\\"M', 1), find_nth(arrow_html, 'd=\\\"M', 2)  # find where d=\\\"M starts\\n\",\n    \"    first_d_cutoff, second_d_cutoff = arrow_html.find(\\\",\\\", first_d_idx), arrow_html.find(\\\",\\\", second_d_idx)  # isolate the number after 'M' e.g. 
'M70'\\n\",\n    \"    # gives svg x values of arrow body starting position and arrowhead position\\n\",\n    \"    arrow_position, arrowhead_position = float(arrow_html[first_d_idx + M_OFFSET: first_d_cutoff]), float(arrow_html[second_d_idx + M_OFFSET: second_d_cutoff])\\n\",\n    \"    # gives starting index of where 'fill=\\\"none\\\"' or 'fill=\\\"currentColor\\\"' begin, reference points to end the d= section\\n\",\n    \"    first_fill_start_idx, second_fill_start_idx = find_nth(arrow_html, \\\"fill\\\", n=1), find_nth(arrow_html, \\\"fill\\\", n=3)\\n\",\n    \"\\n\",\n    \"    # isolate the d= ... section to edit\\n\",\n    \"    first_d, second_d = arrow_html[first_d_idx: first_fill_start_idx], arrow_html[second_d_idx: second_fill_start_idx]\\n\",\n    \"    first_d_split, second_d_split = first_d.split(\\\",\\\"), second_d.split(\\\",\\\")\\n\",\n    \"\\n\",\n    \"    if arrow_position == arrowhead_position:  # This arrow is incoming onto the word, center the arrow/head to word center\\n\",\n    \"        corrected_arrow_pos = corrected_arrowhead_pos = round_base(arrow_position, base=WORD_SPACING)\\n\",\n    \"\\n\",\n    \"        # edit first_d  -- arrow body\\n\",\n    \"        second_term = first_d_split[1].split(\\\" \\\")[0] + \\\" \\\" + str(corrected_arrow_pos)\\n\",\n    \"        first_d = 'd=\\\"M' + str(corrected_arrow_pos) + \\\",\\\" + second_term + \\\",\\\" + \\\",\\\".join(first_d_split[2:])\\n\",\n    \"\\n\",\n    \"        # edit second_d  -- arrowhead\\n\",\n    \"        second_term = second_d_split[1].split(\\\" \\\")[0] + \\\" L\\\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\\n\",\n    \"        third_term = second_d_split[2].split(\\\" \\\")[0] + \\\" \\\" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\\n\",\n    \"        second_d = 'd=\\\"M' + str(corrected_arrowhead_pos) + \\\",\\\" + second_term + \\\",\\\" + third_term + \\\",\\\" + \\\",\\\".join(second_d_split[3:])\\n\",\n    \"    else:  # This arrow is 
outgoing to another word, center the arrow/head to that word's center\\n\",\n    \"        corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING)\\n\",\n    \"\\n\",\n    \"        # edit first_d -- arrow body\\n\",\n    \"        third_term = first_d_split[2].split(\\\" \\\")[0] + \\\" \\\" + str(corrected_arrowhead_pos)\\n\",\n    \"        fourth_term = first_d_split[3].split(\\\" \\\")[0] + \\\" \\\" + str(corrected_arrowhead_pos)\\n\",\n    \"        terms = [first_d_split[0], first_d_split[1], third_term, fourth_term] + first_d_split[4:]\\n\",\n    \"        first_d = \\\",\\\".join(terms)\\n\",\n    \"\\n\",\n    \"        # edit second_d -- arrow head\\n\",\n    \"        first_term = f'd=\\\"M{corrected_arrowhead_pos}'\\n\",\n    \"        second_term = second_d_split[1].split(\\\" \\\")[0] + \\\" L\\\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\\n\",\n    \"        third_term = second_d_split[2].split(\\\" \\\")[0] + \\\" \\\" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\\n\",\n    \"        terms = [first_term, second_term, third_term] + second_d_split[3:]\\n\",\n    \"        second_d = \\\",\\\".join(terms)\\n\",\n    \"    # rebuild and return html\\n\",\n    \"    return arrow_html[:first_d_idx] + first_d + \\\" \\\" + arrow_html[first_fill_start_idx:second_d_idx] + second_d + \\\" \\\" + arrow_html[second_fill_start_idx:]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def main():\\n\",\n    \"    nlp = stanza.Pipeline(\\\"en\\\", processors=\\\"tokenize,pos,lemma,depparse\\\")\\n\",\n    \"\\n\",\n    \"    # doc = nlp(\\\"This a dummy sentence. Banning opal removed all artifact decks from the meta.  I miss playing lantern. This is a dummy sentence.\\\")\\n\",\n    \"    doc = nlp(\\\"Banning opal removed artifact decks from the meta. 
Banning tennis resulted in players banning people.\\\")\\n\",\n    \"    # A single result .result[i].result[j] is a list of matches for sentence i on semgrex query j.\\n\",\n    \"    queries = [\\\"{pos:NN}=object <obl {}=action\\\",\\n\",\n    \"               \\\"{cpos:NOUN}=thing <obj {cpos:VERB}=action\\\"]\\n\",\n    \"    res = visualize_search_doc(doc, queries, \\\"en\\\")\\n\",\n    \"    print(res[0])  # see the first sentence's deprel visualization HTML\\n\",\n    \"    print(\\\"---------------------------------------\\\")\\n\",\n    \"    print(res[1])  # second sentence's deprel visualization HTML\\n\",\n    \"    return\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if __name__ == '__main__':\\n\",\n    \"    main()\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "demo/semgrex.py",
    "content": "import stanza\nfrom stanza.server.semgrex import Semgrex\n\nnlp = stanza.Pipeline(\"en\", processors=\"tokenize,pos,lemma,depparse\")\n\ndoc = nlp(\"Banning opal removed all artifact decks from the meta.  I miss playing lantern.\")\nwith Semgrex(classpath=\"$CLASSPATH\") as sem:\n    semgrex_results = sem.process(doc,\n                                  \"{pos:NN}=object <obl {}=action\",\n                                  \"{cpos:NOUN}=thing <obj {cpos:VERB}=action\")\n    print(\"COMPLETE RESULTS\")\n    print(semgrex_results)\n\n    print(\"Number of matches in graph 0 ('Banning opal...') for semgrex query 1 (thing <obj action): %d\" % len(semgrex_results.result[0].result[1].match))\n    for match_idx, match in enumerate(semgrex_results.result[0].result[1].match):\n        print(\"Match {}:\\n-----------\\n{}\".format(match_idx, match))\n\n    print(\"graph 1 for semgrex query 0 is an empty match: len %d\" % len(semgrex_results.result[1].result[0].match))\n"
  },
  {
    "path": "demo/ssurgeon_script.txt",
    "content": "# To run this, use the stanza/server/ssurgeon.py main file.\n# For example:\n# python3 stanza/server/ssurgeon.py  --edit_file demo/ssurgeon_script.txt --no_print_input --input_file ../data/ud2_11/UD_English-Pronouns/en_pronouns-ud-test.conllu > en_pronouns.updated.conllu\n# This script updates the UD 2.11 version of UD_English-Pronouns to\n# better match punctuation attachments, MWT, and no double subjects.\n\n# This turns unwanted csubj into advcl\n{}=source >nsubj {} >csubj=bad {}\nrelabelNamedEdge -edge bad -reln advcl\n\n# This detects punctuations which are not attached to the root and reattaches them\n{word:/[.]/}=punct <punct=bad {}=parent << {$}=root : {}=parent << {}=root\nremoveNamedEdge -edge bad\naddEdge -gov root -dep punct -reln punct\n\n# This detects the specific MWT found in the 2.11 dataset\n{}=first . {word:/'s|n't|'ll/}=second\ncombineMWT -node first -node second\n"
  },
  {
    "path": "doc/CoreNLP.proto",
    "content": "syntax = \"proto2\";\n\npackage edu.stanford.nlp.pipeline;\n\noption java_package = \"edu.stanford.nlp.pipeline\";\noption java_outer_classname = \"CoreNLPProtos\";\n\n//\n// From JAVANLP_HOME, you can build me with the command:\n//\n//  protoc -I=src/edu/stanford/nlp/pipeline/ --java_out=src src/edu/stanford/nlp/pipeline/CoreNLP.proto\n//\n\n//\n// To do the python version:\n//\n//  protoc -I=./doc --python_out=./stanza/protobuf ./doc/CoreNLP.proto\n//\n\n//\n// An enumeration for the valid languages allowed in CoreNLP\n//\nenum Language {\n  Unknown  = 0;\n  Any      = 1;\n  Arabic   = 2;\n  Chinese  = 3;\n  English  = 4;\n  German   = 5;\n  French   = 6;\n  Hebrew   = 7;\n  Spanish  = 8;\n  UniversalEnglish = 9;\n  UniversalChinese = 10;\n}\n\n//\n// A document; that is, the equivalent of an Annotation.\n//\nmessage Document {\n  required string     text        = 1;\n  repeated Sentence   sentence    = 2;\n  repeated CorefChain corefChain  = 3;\n  optional string     docID       = 4;\n  optional string     docDate     = 7;\n  optional uint64     calendar    = 8;\n\n  /**\n   * A peculiar field, for the corner case when a Document is\n   * serialized without any sentences. 
Otherwise\n   */\n  repeated Token      sentencelessToken = 5;\n  repeated Token      character = 10;\n\n  repeated Quote      quote = 6;\n\n  /**\n   * This field is for entity mentions across the document.\n   */\n  repeated NERMention mentions = 9;\n  optional bool hasEntityMentionsAnnotation = 13; // used to differentiate between null and empty list\n\n  /**\n   * xml information\n   */\n  optional bool    xmlDoc = 11;\n  repeated Section sections = 12;\n\n  /** coref mentions for entire document **/\n  repeated Mention         mentionsForCoref                    = 14;\n  optional bool hasCorefMentionAnnotation = 15;\n  optional bool hasCorefAnnotation = 16;\n  repeated int32 corefMentionToEntityMentionMappings = 17;\n  repeated int32 entityMentionToCorefMentionMappings = 18;\n\n  extensions 100 to 255;\n}\n\n//\n// The serialized version of a CoreMap representing a sentence.\n//\nmessage Sentence {\n  repeated Token            token                               = 1;\n  required uint32           tokenOffsetBegin                    = 2;\n  required uint32           tokenOffsetEnd                      = 3;\n  optional uint32           sentenceIndex                       = 4;\n  optional uint32           characterOffsetBegin                = 5;\n  optional uint32           characterOffsetEnd                  = 6;\n  optional ParseTree        parseTree                           = 7;\n  optional ParseTree        binarizedParseTree                  = 31;\n  optional ParseTree        annotatedParseTree                  = 32;\n  optional string           sentiment                           = 33;\n  repeated ParseTree        kBestParseTrees                     = 34;\n  optional DependencyGraph  basicDependencies                   = 8;\n  optional DependencyGraph  collapsedDependencies               = 9;\n  optional DependencyGraph  collapsedCCProcessedDependencies    = 10;\n  optional DependencyGraph  alternativeDependencies             = 13;\n  repeated RelationTriple 
  openieTriple                        = 14;   // The OpenIE triples in the sentence\n  repeated RelationTriple   kbpTriple                           = 16;   // The KBP triples in this sentence\n  repeated SentenceFragment entailedSentence                    = 15;   // The entailed sentences, by natural logic\n  repeated SentenceFragment entailedClause                      = 35;   // The entailed clauses, by natural logic\n  optional DependencyGraph  enhancedDependencies                = 17;\n  optional DependencyGraph  enhancedPlusPlusDependencies        = 18;\n  repeated Token            character                           = 19;\n\n  optional uint32           paragraph                           = 11;\n\n  optional string           text                                = 12;   // Only needed if we're only saving the sentence.\n\n  optional uint32           lineNumber                          = 20;\n\n  // Fields set by other annotators in CoreNLP\n  optional bool            hasRelationAnnotations              = 51;\n  repeated Entity          entity                              = 52;\n  repeated Relation        relation                            = 53;\n  optional bool            hasNumerizedTokensAnnotation        = 54;\n  repeated NERMention      mentions                            = 55;\n  repeated Mention         mentionsForCoref                    = 56;\n  optional bool            hasCorefMentionsAnnotation          = 57;\n\n  optional string          sentenceID                          = 58;  // Useful when storing sentences (e.g. 
ForEach)\n  optional string          sectionDate                         = 59;  // date of section\n  optional uint32          sectionIndex                        = 60;  // section index for this sentence's section\n  optional string          sectionName                         = 61;  // name of section\n  optional string          sectionAuthor                       = 62;  // author of section\n  optional string          docID                               = 63;  // doc id\n  optional bool            sectionQuoted                       = 64;  // is this sentence in an xml quote in a post\n\n  optional bool            hasEntityMentionsAnnotation         = 65;  // check if there are entity mentions\n  optional bool            hasKBPTriplesAnnotation             = 68;  // check if there are KBP triples\n  optional bool            hasOpenieTriplesAnnotation          = 69;  // check if there are OpenIE triples\n\n  // quote stuff\n  optional uint32             chapterIndex                     = 66;\n  optional uint32             paragraphIndex                   = 67;\n  // the quote annotator can sometimes add merged sentences\n  optional Sentence           enhancedSentence                 = 70;\n\n  // speaker stuff\n  optional string          speaker                             = 71;  // The speaker speaking this sentence\n  optional string          speakerType                         = 72;  // The type of speaker speaking this sentence\n\n  extensions 100 to 255;\n}\n\n//\n// The serialized version of a Token (a CoreLabel).\n//\nmessage Token {\n  // Fields set by the default annotators [new CoreNLP(new Properties())]\n  optional string word              = 1;    // the word's gloss (post-tokenization)\n  optional string pos               = 2;    // The word's part of speech tag\n  optional string value             = 3;    // The word's 'value', (e.g., parse tree node)\n  optional string category          = 4;    // The word's 'category' (e.g., parse tree node)\n  
optional string before            = 5;    // The whitespace/xml before the token\n  optional string after             = 6;    // The whitespace/xml after the token\n  optional string originalText      = 7;    // The original text for this token\n  optional string ner               = 8;    // The word's NER tag\n  optional string coarseNER         = 62;   // The word's coarse NER tag\n  optional string fineGrainedNER    = 63;   // The word's fine-grained NER tag\n  repeated string nerLabelProbs     = 66;   // listing of probs\n  optional string normalizedNER     = 9;    // The word's normalized NER tag\n  optional string lemma             = 10;   // The word's lemma\n  optional uint32 beginChar         = 11;   // The character offset begin, in the document\n  optional uint32 endChar           = 12;   // The character offset end, in the document\n  optional uint32 utterance         = 13;   // The utterance tag used in dcoref\n  optional string speaker           = 14;   // The speaker speaking this word\n  optional string speakerType       = 77;   // The type of speaker speaking this word\n  optional uint32 beginIndex        = 15;   // The begin index of, e.g., a span\n  optional uint32 endIndex          = 16;   // The end index of, e.g., a span\n  optional uint32 tokenBeginIndex   = 17;   // The begin index of the token\n  optional uint32 tokenEndIndex     = 18;   // The end index of the token\n  optional Timex  timexValue        = 19;   // The time this word refers to\n  optional bool   hasXmlContext     = 21;   // Used by clean xml annotator\n  repeated string xmlContext        = 22;   // Used by clean xml annotator\n  optional uint32 corefClusterID    = 23;   // The [primary] cluster id for this token\n  optional string answer            = 24;   // A temporary annotation which is occasionally left in\n  //  optional string projectedCategory = 25;   // The syntactic category of the maximal constituent headed by the word. 
Not used anywhere, so deleted.\n  optional uint32    headWordIndex  = 26;   // The index of the head word of this word.\n  optional Operator  operator       = 27;   // If this is an operator, which one is it and what is its scope (as per Natural Logic)?\n  optional Polarity  polarity       = 28;   // The polarity of this word, according to Natural Logic\n  optional string    polarity_dir   = 39;   // The polarity of this word, either \"up\", \"down\", or \"flat\"\n  optional Span      span           = 29;   // The span of a leaf node of a tree\n  optional string    sentiment      = 30;   // The final sentiment of the sentence\n  optional int32     quotationIndex = 31;   // The index of the quotation this token refers to\n  optional MapStringString conllUFeatures = 32;\n  optional string coarseTag         = 33; //  The coarse POS tag (used to store the UPOS tag)\n  optional Span conllUTokenSpan     = 34;\n  optional string conllUMisc        = 35;\n  optional MapStringString conllUSecondaryDeps = 36;\n  optional string   wikipediaEntity = 37;\n  optional bool     isNewline = 38;\n\n\n  // Fields set by other annotators in CoreNLP\n  optional string gender          = 51;  // gender annotation (machine reading)\n  optional string trueCase        = 52;  // true case type of token\n  optional string trueCaseText    = 53;  // true case gloss of token\n\n  //  Chinese character info\n  optional string chineseChar     = 54;\n  optional string chineseSeg      = 55;\n  optional string chineseXMLChar  = 60;\n\n  //  Arabic character info\n  optional string arabicSeg       = 76;\n\n  // Section info\n  optional string sectionName     = 56;\n  optional string sectionAuthor   = 57;\n  optional string sectionDate     = 58;\n  optional string sectionEndLabel = 59;\n\n  // French tokens have parents\n  optional string parent          = 61;\n\n  // mention index info\n  repeated uint32 corefMentionIndex = 64;\n  optional uint32 entityMentionIndex = 65;\n\n  // mwt stuff\n  optional 
bool isMWT = 67;\n  optional bool isFirstMWT = 68;\n  optional string mwtText = 69;\n  // setting this to a map might be nice, but there are a couple issues\n  // for one, there can be values with no key\n  // for another, it's a pain to correctly parse, since different treebanks\n  // can have different standards for how to write out the misc field\n  optional string mwtMisc = 78;\n\n  // number info\n  optional uint64 numericValue = 70;\n  optional string numericType = 71;\n  optional uint64 numericCompositeValue = 72;\n  optional string numericCompositeType = 73;\n\n  optional uint32 codepointOffsetBegin   = 74;\n  optional uint32 codepointOffsetEnd     = 75;\n\n  // Fields in the CoreLabel java class that are moved elsewhere\n  //       string text           @see Document#text + character offsets\n  //       uint32 sentenceIndex  @see Sentence#sentenceIndex\n  //       string docID          @see Document#docID\n  //       uint32 paragraph      @see Sentence#paragraph\n\n  // Most serialized annotations will not have this\n  // Some code paths may not correctly process this if serialized,\n  // since many places will read the index off the position in a sentence\n  // In particular, deserializing a Document using ProtobufAnnotationSerializer\n  // will clobber any index value\n  // But Semgrex and Ssurgeon in particular need a way\n  // to pass around nodes where the node's index is not strictly 1, 2, 3, ...\n  // thanks to the empty nodes in UD treebanks such as\n  // English EWT or Estonian EWT (not related to each other)\n  optional uint32 index          = 79;\n  optional uint32 emptyIndex     = 80;\n\n  extensions 100 to 255;\n}\n\n//\n// An enumeration of valid sentiment values for the sentiment classifier.\n//\nenum Sentiment {\n  STRONG_NEGATIVE   = 0;\n  WEAK_NEGATIVE     = 1;\n  NEUTRAL           = 2;\n  WEAK_POSITIVE     = 3;\n  STRONG_POSITIVE   = 4;\n}\n\n//\n// A quotation marker in text\n//\nmessage Quote {\n  optional string text           = 1;\n  
optional uint32 begin          = 2;\n  optional uint32 end            = 3;\n  optional uint32 sentenceBegin  = 5;\n  optional uint32 sentenceEnd    = 6;\n  optional uint32 tokenBegin     = 7;\n  optional uint32 tokenEnd       = 8;\n  optional string docid          = 9;\n  optional uint32 index          = 10;\n  optional string author         = 11;\n  optional string mention        = 12;\n  optional uint32 mentionBegin   = 13;\n  optional uint32 mentionEnd     = 14;\n  optional string mentionType    = 15;\n  optional string mentionSieve   = 16;\n  optional string speaker        = 17;\n  optional string speakerSieve   = 18;\n  optional string canonicalMention = 19;\n  optional uint32 canonicalMentionBegin = 20;\n  optional uint32 canonicalMentionEnd = 21;\n  optional DependencyGraph attributionDependencyGraph = 22;\n}\n\n//\n// A syntactic parse tree, with scores.\n//\nmessage ParseTree {\n  repeated ParseTree child           = 1;\n  optional string    value           = 2;\n  optional uint32    yieldBeginIndex = 3;\n  optional uint32    yieldEndIndex   = 4;\n  optional double    score           = 5;\n  optional Sentiment sentiment       = 6;\n}\n\n//\n// A dependency graph representation.\n//\nmessage DependencyGraph {\n  message Node {\n    required uint32 sentenceIndex  = 1;\n    required uint32 index          = 2;\n    optional uint32 copyAnnotation = 3;\n    optional uint32 emptyIndex     = 4;\n  }\n\n  message Edge {\n    required uint32 source      = 1;\n    required uint32 target      = 2;\n    optional string dep         = 3;\n    optional bool   isExtra     = 4;\n    optional uint32 sourceCopy  = 5;\n    optional uint32 targetCopy  = 6;\n    optional uint32 sourceEmpty = 8;\n    optional uint32 targetEmpty = 9;\n    optional Language language  = 7 [default=Unknown];\n  }\n\n  repeated Node     node     = 1;\n  repeated Edge     edge     = 2;\n  repeated uint32   root     = 3 [packed=true];\n  // optional: if this graph message is not part of a larger 
context,\n  // the tokens will help reconstruct the actual sentence\n  repeated Token    token    = 4;\n  // The values in this field will index directly into the node list\n  // This is useful so that additional information such as emptyIndex\n  // can be considered without having to pass it around a second time\n  repeated uint32   rootNode = 5 [packed=true];\n}\n\n//\n// A coreference chain.\n// These fields are not *really* optional. CoreNLP will crash without them.\n//\nmessage CorefChain {\n  message CorefMention {\n    optional int32  mentionID          = 1;\n    optional string mentionType        = 2;\n    optional string number             = 3;\n    optional string gender             = 4;\n    optional string animacy            = 5;\n    optional uint32 beginIndex         = 6;\n    optional uint32 endIndex           = 7;\n    optional uint32 headIndex          = 9;\n    optional uint32 sentenceIndex      = 10;\n    optional uint32 position           = 11;  // the second element of position\n  }\n\n  required int32        chainID        = 1;\n  repeated CorefMention mention        = 2;\n  required uint32       representative = 3;\n}\n\n//\n// a mention\n//\n\nmessage Mention {\n  optional int32 mentionID             = 1;\n  optional string mentionType          = 2;\n  optional string number               = 3;\n  optional string gender               = 4;\n  optional string animacy              = 5;\n  optional string person               = 6;\n  optional uint32 startIndex           = 7;\n  optional uint32 endIndex             = 9;\n  optional int32 headIndex             = 10;\n  optional string headString           = 11;\n  optional string nerString            = 12;\n  optional int32 originalRef           = 13;\n  optional int32 goldCorefClusterID    = 14;\n  optional int32 corefClusterID        = 15;\n  optional int32 mentionNum            = 16;\n  optional int32 sentNum               = 17;\n  optional int32 utter                 = 18;\n  optional int32 
paragraph             = 19;\n  optional bool isSubject              = 20;\n  optional bool isDirectObject         = 21;\n  optional bool isIndirectObject       = 22;\n  optional bool isPrepositionObject    = 23;\n  optional bool hasTwin                = 24;\n  optional bool generic                = 25;\n  optional bool isSingleton            = 26;\n  optional bool hasBasicDependency     = 27;\n  optional bool hasEnhancedDependency  = 28;\n  optional bool hasContextParseTree    = 29;\n  optional IndexedWord headIndexedWord = 30;\n  optional IndexedWord   dependingVerb = 31;\n  optional IndexedWord       headWord  = 32;\n  optional SpeakerInfo    speakerInfo  = 33;\n\n  repeated IndexedWord sentenceWords   = 50;\n  repeated IndexedWord originalSpan    = 51;\n  repeated string dependents           = 52;\n  repeated string preprocessedTerms    = 53;\n  repeated int32 appositions           = 54;\n  repeated int32 predicateNominatives  = 55;\n  repeated int32 relativePronouns      = 56;\n  repeated int32 listMembers           = 57;\n  repeated int32 belongToLists         = 58;\n\n}\n\n//\n// store the position (sentence, token index) of a CoreLabel\n//\n\nmessage IndexedWord {\n  optional  int32 sentenceNum          = 1;\n  optional  int32 tokenIndex           = 2;\n  optional  int32 docID                = 3;\n  optional uint32 copyCount            = 4;\n}\n\n//\n// speaker info, this is used for Mentions\n//\n\nmessage SpeakerInfo {\n  optional string speakerName          = 1;\n  repeated int32 mentions              = 2;\n}\n\n//\n// A Span of text\n//\nmessage Span {\n  required uint32 begin      = 1;\n  required uint32 end        = 2;\n}\n\n//\n// A Timex object, representing a temporal expression (TIMe EXpression)\n// These fields are not *really* optional. 
CoreNLP will crash without them.\n//\nmessage Timex {\n  optional string value      = 1;\n  optional string altValue   = 2;\n  optional string text       = 3;\n  optional string type       = 4;\n  optional string tid        = 5;\n  optional uint32 beginPoint = 6;\n  optional uint32 endPoint   = 7;\n}\n\n//\n// A representation of an entity in a relation.\n// This corresponds to the EntityMention, and more broadly the\n// ExtractionObject classes.\n//\nmessage Entity {\n  optional uint32 headStart      = 6;\n  optional uint32 headEnd        = 7;\n  optional string mentionType    = 8;\n  optional string normalizedName = 9;\n  optional uint32 headTokenIndex = 10;\n  optional string corefID        = 11;\n  // inherited from ExtractionObject\n  optional string objectID       = 1;\n  optional uint32 extentStart    = 2;\n  optional uint32 extentEnd      = 3;\n  optional string type           = 4;\n  optional string subtype        = 5;\n  // Implicit\n  //       uint32 sentence       @see implicit in sentence\n}\n\n//\n// A representation of a relation, mirroring RelationMention\n//\nmessage Relation {\n  repeated string argName   = 6;\n  repeated Entity arg       = 7;\n  optional string signature = 8;\n  // inherited from ExtractionObject\n  optional string objectID = 1;\n  optional uint32 extentStart    = 2;\n  optional uint32 extentEnd      = 3;\n  optional string type           = 4;\n  optional string subtype        = 5;\n  // Implicit\n  //       uint32 sentence       @see implicit in sentence\n}\n\n//\n// A Natural Logic operator\n//\nmessage Operator {\n  required string name                = 1;\n  required int32  quantifierSpanBegin = 2;\n  required int32  quantifierSpanEnd   = 3;\n  required int32  subjectSpanBegin    = 4;\n  required int32  subjectSpanEnd      = 5;\n  required int32  objectSpanBegin     = 6;\n  required int32  objectSpanEnd       = 7;\n}\n\n//\n// The seven informative Natural Logic relations\n//\nenum NaturalLogicRelation {\n  EQUIVALENCE        
= 0;\n  FORWARD_ENTAILMENT = 1;\n  REVERSE_ENTAILMENT = 2;\n  NEGATION           = 3;\n  ALTERNATION        = 4;\n  COVER              = 5;\n  INDEPENDENCE       = 6;\n}\n\n//\n// The polarity of a word, according to Natural Logic\n//\nmessage Polarity {\n  required NaturalLogicRelation projectEquivalence       = 1;\n  required NaturalLogicRelation projectForwardEntailment = 2;\n  required NaturalLogicRelation projectReverseEntailment = 3;\n  required NaturalLogicRelation projectNegation          = 4;\n  required NaturalLogicRelation projectAlternation       = 5;\n  required NaturalLogicRelation projectCover             = 6;\n  required NaturalLogicRelation projectIndependence      = 7;\n}\n\n//\n// An NER mention in the text\n//\nmessage NERMention {\n  optional uint32 sentenceIndex                 = 1;\n  required uint32 tokenStartInSentenceInclusive = 2;\n  required uint32 tokenEndInSentenceExclusive   = 3;\n  required string ner                           = 4;\n  optional string normalizedNER                 = 5;\n  optional string entityType                    = 6;\n  optional Timex  timex                         = 7;\n  optional string wikipediaEntity               = 8;\n  optional string gender                        = 9;\n  optional uint32 entityMentionIndex            = 10;\n  optional uint32 canonicalEntityMentionIndex   = 11;\n  optional string entityMentionText             = 12;\n}\n\n//\n// An entailed sentence fragment.\n// Created by the openie annotator.\n//\nmessage SentenceFragment {\n  repeated uint32 tokenIndex     = 1;\n  optional uint32 root           = 2;\n  optional bool   assumedTruth   = 3;\n  optional double score          = 4;\n}\n\n\n//\n// The index of a token in a document, including the sentence\n// index and the offset.\n//\nmessage TokenLocation {\n optional uint32 sentenceIndex = 1;\n optional uint32 tokenIndex    = 2;\n\n}\n\n\n//\n// An OpenIE relation triple.\n// Created by the openie annotator.\n//\nmessage RelationTriple {\n  
optional string          subject        = 1;   // The surface form of the subject\n  optional string          relation       = 2;   // The surface form of the relation (required)\n  optional string          object         = 3;   // The surface form of the object\n  optional double          confidence     = 4;   // The [optional] confidence of the extraction\n  repeated TokenLocation   subjectTokens  = 13; // The tokens comprising the subject of the triple\n  repeated TokenLocation   relationTokens = 14; // The tokens comprising the relation of the triple\n  repeated TokenLocation   objectTokens   = 15; // The tokens comprising the object of the triple\n  optional DependencyGraph tree           = 8;   // The dependency graph fragment for this triple\n  optional bool            istmod         = 9;   // If true, this expresses an implicit tmod relation\n  optional bool            prefixBe       = 10;  // If true, this relation string is missing a 'be' prefix\n  optional bool            suffixBe       = 11;  // If true, this relation string is missing a 'be' suffix\n  optional bool            suffixOf       = 12;  // If true, this relation string is missing an 'of' suffix\n}\n\n\n//\n// A map from strings to strings.\n// Used, minimally, in the CoNLLU featurizer\n//\nmessage MapStringString {\n  repeated string key   = 1;\n  repeated string value = 2;\n}\n\n//\n// A map from integers to strings.\n// Used, minimally, in the CoNLLU featurizer\n//\nmessage MapIntString {\n  repeated uint32 key   = 1;\n  repeated string value = 2;\n}\n\n//\n// Store section info\n//\n\nmessage Section {\n  required uint32 charBegin         = 1;\n  required uint32 charEnd           = 2;\n  optional string author            = 3;\n  repeated uint32 sentenceIndexes   = 4;\n  optional string datetime          = 5;\n  repeated Quote quotes             = 6;\n  optional uint32 authorCharBegin   = 7;\n  optional uint32 authorCharEnd     = 8;\n  required Token xmlTag             = 9;\n}\n\n\n\n// A 
message for requesting a semgrex\n// Each sentence stores information about the tokens making up the\n// corresponding graph\n// An alternative would have been to use the existing Document or\n// Sentence classes, but the problem with that is it would be\n// ambiguous which dependency object to use.\nmessage SemgrexRequest {\n  message Dependencies {\n    repeated Token           token       = 1;\n    required DependencyGraph graph       = 2;\n  }\n\n  repeated string            semgrex     = 1;\n  repeated Dependencies      query       = 2;\n}\n\n// The response from running a semgrex\n// If you pass in M semgrex expressions and N dependency graphs,\n// this returns MxN nested results.  Each SemgrexResult can match\n// multiple times in one graph\n//\n// You may want to send multiple semgrexes per query because\n// translating large numbers of dependency graphs to protobufs\n// will be expensive, so doing several queries at once will save time\nmessage SemgrexResponse {\n  message NamedNode {\n    required string          name        = 1;\n    required int32           matchIndex  = 2;\n  }\n\n  message NamedRelation {\n    required string          name        = 1;\n    required string          reln        = 2;\n  }\n\n  message NamedEdge {\n    required string          name        = 1;\n    required int32           source      = 2;\n    required int32           target      = 3;\n    optional string          reln        = 4;\n    optional bool            isExtra     = 5;\n    optional uint32          sourceCopy  = 6;\n    optional uint32          targetCopy  = 7;\n  }\n\n  message VariableString {\n    required string          name        = 1;\n    required string          value       = 2;\n  }\n\n  message Match {\n    required int32           matchIndex   = 1;\n    repeated NamedNode       node         = 2;\n    repeated NamedRelation   reln         = 3;\n    repeated NamedEdge       edge         = 6;\n    repeated VariableString  varstring    = 7;\n\n    // when 
processing multiple sentences at once,\n    // which sentence this applies to\n    // indexed from 0\n    optional int32           sentenceIndex  = 4;\n    // index of the semgrex expression this match applies to\n    // indexed from 0\n    optional int32           semgrexIndex = 5;\n  }\n\n  message SemgrexResult {\n    repeated Match           match       = 1;\n  }\n\n  message GraphResult {\n    repeated SemgrexResult   result      = 1;\n  }\n\n  repeated GraphResult       result      = 1;\n}\n\n\n// A message for processing an Ssurgeon\n// Each sentence stores information about the tokens making up the\n// corresponding graph\n// An alternative would have been to use the existing Document or\n// Sentence classes, but the problem with that is it would be\n// ambiguous which dependency object to use.  Another problem\n// is that if the intent is to use multiple graphs from a\n// Sentence, then edits to the nodes of one graph would show up\n// in the nodes of the other graph (same backing CoreLabels)\n// and the operations themselves may not have the intended effect.\n// The Ssurgeon is composed of two pieces, the semgrex and the\n// ssurgeon operations, along with some optional documentation.\nmessage SsurgeonRequest {\n  message Ssurgeon {\n    optional string          semgrex     = 1;\n    repeated string          operation   = 2;\n    optional string          id          = 3;\n    optional string          notes       = 4;\n    optional string          language    = 5;\n  }\n\n  repeated Ssurgeon          ssurgeon    = 1;\n  repeated DependencyGraph   graph       = 2;\n}\n\nmessage SsurgeonResponse {\n  message SsurgeonResult {\n    optional DependencyGraph graph      = 1;\n    optional bool            changed    = 2;\n  }\n\n  repeated SsurgeonResult    result      = 1;\n}\n\n// It's possible to send in a whole document, but we\n// only care about the Sentences and Tokens\nmessage TokensRegexRequest {\n  required Document          doc         = 1;\n  repeated 
string            pattern     = 2;\n}\n\n// The result will be a nested structure:\n// repeated PatternMatch, one for each pattern\n// each PatternMatch has a repeated Match,\n// which tells you which sentence matched and where\nmessage TokensRegexResponse {\n  message MatchLocation {\n    optional string          text        = 1;\n    optional int32           begin       = 2;\n    optional int32           end         = 3;\n  }\n\n  message Match {\n    required int32           sentence    = 1;\n    required MatchLocation   match       = 2;\n    repeated MatchLocation   group       = 3;\n  }\n\n  message PatternMatch {\n    repeated Match           match       = 1;\n  }\n\n  repeated PatternMatch      match       = 1;\n}\n\n// A protobuf which allows to pass in a document with basic\n// dependencies to be converted to enhanced\nmessage DependencyEnhancerRequest {\n  required Document          document           = 1;\n\n  oneof ref {\n    Language          language           = 2;\n    // The expected value of this is a regex which matches relative pronouns\n    string            relativePronouns   = 3;\n  }\n}\n\n// A version of ParseTree with a flattened structure so that deep trees\n// don't exceed the protobuf stack depth\nmessage FlattenedParseTree {\n  message Node {\n    oneof contents {\n      bool              openNode           = 1;\n      bool              closeNode          = 2;\n      string            value              = 3;\n    }\n\n    optional double     score              = 4;\n  }\n\n  repeated Node         nodes              = 1;\n}\n\n// A protobuf for calling the java constituency parser evaluator from elsewhere\nmessage EvaluateParserRequest {\n  message ParseResult {\n    required FlattenedParseTree         gold           = 1;\n    // repeated so you can send in kbest parses, if your parser handles that\n    // note that this already includes a score field\n    repeated FlattenedParseTree         predicted      = 2;\n  }\n\n  repeated 
ParseResult         treebank       = 1;\n}\n\nmessage EvaluateParserResponse {\n  required double              f1             = 1;\n  optional double              kbestF1        = 2;\n  // keep track of the individual tree F1 scores\n  repeated double              treeF1         = 3;\n}\n\n\n// A protobuf for running Tsurgeon operations on constituency trees\nmessage TsurgeonRequest {\n  message Operation {\n    required string                tregex         = 1;\n    repeated string                tsurgeon       = 2;\n  }\n  repeated Operation               operations     = 1;\n  repeated FlattenedParseTree      trees          = 2;\n}\n\n// The results of the Tsurgeon operation\nmessage TsurgeonResponse {\n  repeated FlattenedParseTree      trees          = 1;\n}\n\n// Sent in Morphology requests - a stream of sentences with tagged words\nmessage MorphologyRequest {\n  message TaggedWord {\n    required string                word           = 1;\n    optional string                xpos           = 2;\n  }\n\n  repeated TaggedWord              words          = 1;\n}\n\n// Sent back from the Morphology request - the words and their tags\nmessage MorphologyResponse {\n  message WordTagLemma {\n    required string                word           = 1;\n    optional string                xpos           = 2;\n    required string                lemma          = 3;\n  }\n\n  repeated WordTagLemma            words          = 1;\n}\n\n\n// A request for converting constituency trees to dependency graphs\nmessage DependencyConverterRequest {\n  repeated FlattenedParseTree      trees          = 1;\n}\n\n// The result of using the CoreNLP dependency converter.\n// One graph per tree\nmessage DependencyConverterResponse {\n  message DependencyConversion {\n    required DependencyGraph       graph          = 1;\n    optional FlattenedParseTree    tree           = 2;\n  }\n\n  repeated DependencyConversion    conversions         = 1;\n}\n\n"
  },
  {
    "path": "scripts/config.sh",
    "content": "#!/bin/bash\n#\n# Set environment variables for the training and testing of stanza modules.\n\n# Set UDBASE to the location of UD data folder\n# The data should be CoNLL-U format\n# For details, see\n#   http://universaldependencies.org/conll18/data.html (CoNLL-18 UD data)\n#   https://universaldependencies.org/\n# When rebuilding models based on Universal Dependencies, download the\n#   UD data to some directory, set UDBASE to that directory, and\n#   uncomment this line.  Alternatively, put UDBASE in your shell\n#   config, Windows env variables, etc as relevant.\n# export UDBASE=/path/to/UD\n\n# Set NERBASE to the location of NER data folder\n# The data should be BIO format or convertable to that format\n# For details, see https://www.aclweb.org/anthology/W03-0419.pdf (CoNLL-03 NER paper)\n# There are other NER datasets, supported in\n#   stanza/utils/datasets/ner/prepare_ner_dataset.py\n# If rebuilding NER data, choose a location for the NER directory\n#   and set NERBASE to that variable.\n# export NERBASE=/path/to/NER\n\n# Set CONSTITUENCY_BASE to the location of NER data folder\n# The data will be in some dataset-specific format\n# There is a conversion script which will turn this\n#   into a PTB style format\n#   stanza/utils/datasets/constituency/prepare_con_dataset.py\n# If processing constituency data, choose a location for the CON data\n#   and set CONSTITUENCY_BASE to that variable.\n# export CONSTITUENCY_BASE=/path/to/CON\n\n# Set directories to store processed training/evaluation files\n# $DATA_ROOT is a default home for where all the outputs from the\n#   preparation scripts will go.  
The training scripts will then look\n#   for the stanza formatted data in that directory.\nexport DATA_ROOT=./data\nexport TOKENIZE_DATA_DIR=$DATA_ROOT/tokenize\nexport MWT_DATA_DIR=$DATA_ROOT/mwt\nexport LEMMA_DATA_DIR=$DATA_ROOT/lemma\nexport POS_DATA_DIR=$DATA_ROOT/pos\nexport DEPPARSE_DATA_DIR=$DATA_ROOT/depparse\nexport ETE_DATA_DIR=$DATA_ROOT/ete\nexport NER_DATA_DIR=$DATA_ROOT/ner\nexport CHARLM_DATA_DIR=$DATA_ROOT/charlm\nexport CONSTITUENCY_DATA_DIR=$DATA_ROOT/constituency\nexport SENTIMENT_DATA_DIR=$DATA_ROOT/sentiment\n\n# Set directories to store external word vector data\nexport WORDVEC_DIR=./extern_data/wordvec\n"
  },
  {
    "path": "scripts/download_vectors.sh",
    "content": "#!/bin/bash\n#\n# Download word vector files for all supported languages. Run as:\n#   ./download_vectors.sh WORDVEC_DIR\n# where WORDVEC_DIR is the target directory to store the word vector data.\n\n# check arguments\n: ${1?\"Usage: $0 WORDVEC_DIR\"}\nWORDVEC_DIR=$1\n\n# constants and functions\nCONLL17_URL=\"https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1989/word-embeddings-conll17.tar\"\nCONLL17_TAR=\"word-embeddings-conll17.tar\"\n\nFASTTEXT_BASE_URL=\"https://dl.fbaipublicfiles.com/fasttext/vectors-wiki\"\n\n# TODO: some fasttext vectors are now at\n# https://fasttext.cc/docs/en/pretrained-vectors.html\n# there are also vectors for\n# Welsh, Icelandic, Thai, Sanskrit\n# https://fasttext.cc/docs/en/crawl-vectors.html\n\n# We get the Armenian word vectors from here:\n# https://github.com/ispras-texterra/word-embeddings-eval-hy\n# https://arxiv.org/ftp/arxiv/papers/1906/1906.03134.pdf\n# In particular, the glove model (dogfooding):\n# https://at.ispras.ru/owncloud/index.php/s/pUUiS1l1jGKNax3/download\n# These vectors improved F1 by about 1 on various tasks for Armenian\n# and had much better coverage of Western Armenian\n\n# For Eryza, we use word vectors available here:\n# https://github.com/mokha/semantics\n# @incollection{Alnajjar_2021,\n#   doi = {10.31885/9789515150257.24},\n#   url = {https://doi.org/10.31885%2F9789515150257.24},\n#   year = 2021,\n#   month = {mar},\n#   publisher = {University of Helsinki},\n#   pages = {275--288},\n#   author = {Khalid Alnajjar},\n#   title = {When Word Embeddings Become Endangered},\n#   booktitle = {Multilingual Facilitation}\n# }\n\ndeclare -a FASTTEXT_LANG=(\"Afrikaans\" \"Breton\" \"Buryat\" \"Chinese\" \"Faroese\" \"Gothic\" \"Kurmanji\" \"North_Sami\" \"Serbian\" \"Upper_Sorbian\")\ndeclare -a FASTTEXT_CODE=(\"af\" \"br\" \"bxr\" \"zh\" \"fo\" \"got\" \"ku\" \"se\" \"sr\" \"hsb\")\ndeclare -a LOCAL_CODE=(\"af\" \"br\" \"bxr\" \"zh\" \"fo\" \"got\" \"kmr\" \"sme\" \"sr\" 
\"hsb\")\n\ncolor_green='\\033[32;1m'\ncolor_clear='\\033[0m' # No Color\nfunction msg() {\n    echo -e \"${color_green}$@${color_clear}\"\n}\n\nfunction prepare_fasttext_vec() {\n    lang=$1\n    ftcode=$2\n    code=$3\n\n    cwd=$(pwd)\n    mkdir -p $lang\n    cd $lang\n    msg \"=== Downloading fasttext vector file for ${lang}...\"\n    url=\"${FASTTEXT_BASE_URL}/wiki.${ftcode}.vec\"\n    fname=\"${code}.vectors\"\n    wget $url -O $fname\n\n    msg \"=== Compressing file ${fname}...\"\n    xz $fname\n    cd $cwd\n}\n\n# do the actual work\nmkdir -p $WORDVEC_DIR\ncd $WORDVEC_DIR\n\nmsg \"Downloading CONLL17 word vectors. This may take a while...\"\nwget $CONLL17_URL -O $CONLL17_TAR\n\nmsg \"Extracting CONLL17 word vector files...\"\ntar -xvf $CONLL17_TAR\nrm $CONLL17_TAR\n\nmsg \"Preparing fasttext vectors for the rest of the languages.\"\nfor (( i=0; i<${#FASTTEXT_LANG[*]}; ++i)); do\n    prepare_fasttext_vec ${FASTTEXT_LANG[$i]} ${FASTTEXT_CODE[$i]} ${LOCAL_CODE[$i]}\ndone\n\n# handle old french\nmkdir Old_French\nln -s French/fr.vectors.xz Old_French/fro.vectors.xz\n\nmsg \"All done.\"\n"
  },
  {
    "path": "setup.py",
    "content": "# Always prefer setuptools over distutils\nimport re\n\nfrom setuptools import setup, find_packages\n# To use a consistent encoding\nfrom codecs import open\nfrom os import path\n\nhere = path.abspath(path.dirname(__file__))\n\n# read the version from stanza/_version.py\nversion_file_contents = open(path.join(here, 'stanza/_version.py'), encoding='utf-8').read()\nVERSION = re.compile('__version__ = \\\"(.*)\\\"').search(version_file_contents).group(1)\n\n# Get the long description from the README file\nwith open(path.join(here, 'README.md'), encoding='utf-8') as f:\n    long_description = f.read()\n\nsetup(\n    name='stanza',\n\n    # Versions should comply with PEP440.  For a discussion on single-sourcing\n    # the version across setup.py and the project code, see\n    # https://packaging.python.org/en/latest/single_source_version.html\n    version=VERSION,\n\n    description='A Python NLP Library for Many Human Languages, by the Stanford NLP Group',\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    # The project's main homepage.\n    url='https://github.com/stanfordnlp/stanza',\n\n    # Author details\n    author='Stanford Natural Language Processing Group',\n    author_email='jebolton@stanford.edu',\n\n    # Choose your license\n    license='Apache License 2.0',\n\n    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers\n    classifiers=[\n        # How mature is this project? 
Common values are\n        #   3 - Alpha\n        #   4 - Beta\n        #   5 - Production/Stable\n        'Development Status :: 4 - Beta',\n\n        # Indicate who your project is intended for\n        'Intended Audience :: Developers',\n        'Intended Audience :: Education',\n        'Intended Audience :: Science/Research',\n        'Intended Audience :: Information Technology',\n        'Topic :: Scientific/Engineering',\n        'Topic :: Scientific/Engineering :: Artificial Intelligence',\n        'Topic :: Scientific/Engineering :: Information Analysis',\n        'Topic :: Text Processing',\n        'Topic :: Text Processing :: Linguistic',\n        'Topic :: Software Development',\n        'Topic :: Software Development :: Libraries',\n\n        # Specify the Python versions you support here. In particular, ensure\n        # that you indicate whether you support Python 2, Python 3 or both.\n        'Programming Language :: Python :: 3.9',\n        'Programming Language :: Python :: 3.10',\n        'Programming Language :: Python :: 3.11',\n        'Programming Language :: Python :: 3.12',\n        'Programming Language :: Python :: 3.13',\n    ],\n\n    # What does your project relate to?\n    keywords='natural-language-processing nlp natural-language-understanding stanford-nlp deep-learning',\n\n    # You can just specify the packages manually here if your project is\n    # simple. Or you can use find_packages().\n    packages=find_packages(exclude=['data', 'docs', 'extern_data', 'figures', 'saved_models']),\n\n    # List run-time dependencies here.  These will be installed by pip when\n    # your project is installed. 
For an analysis of \"install_requires\" vs pip's\n    # requirements files see:\n    # https://packaging.python.org/en/latest/requirements.html\n    install_requires=[\n        'emoji', \n        'numpy', \n        'platformdirs',\n        'protobuf>=3.15.0',\n        'requests', \n        'networkx',\n        'tomli;python_version<\"3.11\"',\n        'torch>=1.13.0',\n        'tqdm',\n        'udtools>=0.2.4',\n    ],\n\n    # List required Python versions\n    python_requires='>=3.9',\n\n    # List additional groups of dependencies here (e.g. development\n    # dependencies). You can install these using the following syntax,\n    # for example:\n    # $ pip install -e .[dev,test]\n    extras_require={\n        'dev': [\n            'check-manifest',\n        ],\n        'test': [\n            'coverage', \n            'pytest',\n        ],\n        'transformers': [\n            'transformers>=3.0.0',\n            'peft>=0.6.1',\n        ],\n        'datasets': [\n            'datasets',\n        ],\n        'tokenizers': [\n            'jieba',\n            'pythainlp',\n            'python-crfsuite',\n            'spacy',\n            'sudachidict_core',\n            'sudachipy',\n        ],\n        'visualization': [\n            'spacy',\n            'streamlit',\n            'ipython',\n        ],\n        'morphseg': [\n            'morphseg>=0.2.0',\n        ]\n    },\n\n    # If there are data files included in your packages that need to be\n    # installed, specify them here.  If using Python 2.6 or less, then these\n    # have to be included in MANIFEST.in as well.\n    package_data={\n        \"\": [\"pipeline/demo/*ttf\",\n             \"pipeline/demo/*css\",\n             \"pipeline/demo/*html\",\n             \"pipeline/demo/*js\",\n             \"pipeline/demo/*gif\",],\n    },\n\n    include_package_data=True,\n\n    # Although 'package_data' is the preferred approach, in some case you may\n    # need to place data files outside of your packages. 
See:\n    # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa\n    # In this case, 'data_file' will be installed into '<sys.prefix>/my_data'\n    data_files=[],\n\n    # To provide executable scripts, use entry points in preference to the\n    # \"scripts\" keyword. Entry points provide cross-platform support and allow\n    # pip to create the appropriate form of executable for the target platform.\n    entry_points={\n    },\n)\n"
  },
  {
    "path": "stanza/__init__.py",
    "content": "from stanza.pipeline.core import DownloadMethod, Pipeline\nfrom stanza.pipeline.multilingual import MultilingualPipeline\nfrom stanza.models.common.doc import Document\nfrom stanza.resources.common import download\nfrom stanza.resources.installation import install_corenlp, download_corenlp_models\nfrom stanza._version import __version__, __resources_version__\nfrom stanza.pipeline.morphseg_processor import MorphSegProcessor\n\nimport logging\nlogger = logging.getLogger('stanza')\n\n# if the client application hasn't set the log level, we set it\n# ourselves to INFO\nif logger.level == 0:\n    logger.setLevel(logging.INFO)\n\nlog_handler = logging.StreamHandler()\nlog_formatter = logging.Formatter(fmt=\"%(asctime)s %(levelname)s: %(message)s\",\n                              datefmt='%Y-%m-%d %H:%M:%S')\nlog_handler.setFormatter(log_formatter)\n\n# also, if the client hasn't added any handlers for this logger\n# (or a default handler), we add a handler of our own\n#\n# client can later do\n#   logger.removeHandler(stanza.log_handler)\nif not logger.hasHandlers():\n    logger.addHandler(log_handler)\n"
  },
  {
    "path": "stanza/_version.py",
    "content": "\"\"\" Single source of truth for version number \"\"\"\n\n__version__ = \"1.11.1\"\n__resources_version__ = '1.11.0'\n"
  },
  {
    "path": "stanza/models/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/_training_logging.py",
    "content": "import logging\n\nlogger = logging.getLogger('stanza')\nlogger.setLevel(logging.DEBUG)"
  },
  {
    "path": "stanza/models/charlm.py",
    "content": "\"\"\"\nEntry point for training and evaluating a character-level neural language model.\n\"\"\"\n\nimport argparse\nfrom copy import copy\nimport logging\nimport lzma\nimport math\nimport os\nimport random\nimport time\nfrom types import GeneratorType\nimport numpy as np\nimport torch\n\nfrom stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel, CharacterLanguageModelTrainer\nfrom stanza.models.common.vocab import CharVocab\nfrom stanza.models.common import utils\nfrom stanza.models import _training_logging\n\nlogger = logging.getLogger('stanza')\n\ndef repackage_hidden(h):\n    \"\"\"Wraps hidden states in new Tensors,\n    to detach them from their history.\"\"\"\n    if isinstance(h, torch.Tensor):\n        return h.detach()\n    else:\n        return tuple(repackage_hidden(v) for v in h)\n\ndef batchify(data, bsz, device):\n    # Work out how cleanly we can divide the dataset into bsz parts.\n    nbatch = data.size(0) // bsz\n    # Trim off any extra elements that wouldn't cleanly fit (remainders).\n    data = data.narrow(0, 0, nbatch * bsz)\n    # Evenly divide the data across the bsz batches.\n    data = data.view(bsz, -1) # batch_first is True\n    data = data.to(device)\n    return data\n\ndef get_batch(source, i, seq_len):\n    seq_len = min(seq_len, source.size(1) - 1 - i)\n    data = source[:, i:i+seq_len]\n    target = source[:, i+1:i+1+seq_len].reshape(-1)\n    return data, target\n\ndef load_file(filename, vocab, direction):\n    with utils.open_read_text(filename) as fin:\n        data = fin.read()\n\n    idx = vocab['char'].map(data)\n    if direction == 'backward': idx = idx[::-1]\n    return torch.tensor(idx)\n\ndef load_data(path, vocab, direction):\n    if os.path.isdir(path):\n        filenames = sorted(os.listdir(path))\n        for filename in filenames:\n            logger.info('Loading data from {}'.format(filename))\n            data = load_file(os.path.join(path, filename), vocab, direction)\n 
           yield data\n    else:\n        data = load_file(path, vocab, direction)\n        yield data\n\ndef build_argparse():\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--train_file', type=str, help=\"Input plaintext file\")\n    parser.add_argument('--train_dir', type=str, help=\"If non-empty, load from directory with multiple training files\")\n    parser.add_argument('--eval_file', type=str, help=\"Input plaintext file for the dev/test set\")\n    parser.add_argument('--shorthand', type=str, help=\"UD treebank shorthand\")\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help=\"Forward or backward language model\")\n    parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help=\"Train a forward language model\")\n    parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help=\"Train a backward language model\")\n\n    parser.add_argument('--char_emb_dim', type=int, default=100, help=\"Dimension of unit embeddings\")\n    parser.add_argument('--char_hidden_dim', type=int, default=1024, help=\"Dimension of hidden units\")\n    parser.add_argument('--char_num_layers', type=int, default=1, help=\"Layers of RNN in the language model\")\n    parser.add_argument('--char_dropout', type=float, default=0.05, help=\"Dropout probability\")\n    parser.add_argument('--char_unit_dropout', type=float, default=1e-5, help=\"Randomly set an input char to UNK during training\")\n    parser.add_argument('--char_rec_dropout', type=float, default=0.0, help=\"Recurrent dropout probability\")\n\n    parser.add_argument('--batch_size', type=int, default=100, help=\"Batch size to use\")\n    parser.add_argument('--bptt_size', type=int, default=250, help=\"Sequence length to consider at a time\")\n    
parser.add_argument('--epochs', type=int, default=50, help=\"Total epochs to train the model for\")\n    parser.add_argument('--max_grad_norm', type=float, default=0.25, help=\"Maximum gradient norm to clip to\")\n    parser.add_argument('--lr0', type=float, default=5, help=\"Initial learning rate\")\n    parser.add_argument('--anneal', type=float, default=0.25, help=\"Anneal the learning rate by this amount when dev performance deteriorates\")\n    parser.add_argument('--patience', type=int, default=1, help=\"Patience for annealing the learning rate\")\n    parser.add_argument('--weight_decay', type=float, default=0.0, help=\"Weight decay\")\n    parser.add_argument('--momentum', type=float, default=0.0, help='Momentum for SGD.')\n    parser.add_argument('--cutoff', type=int, default=1000, help=\"Frequency cutoff for char vocab. By default we assume a very large corpus.\")\n    \n    parser.add_argument('--report_steps', type=int, default=50, help=\"Update step interval to report loss\")\n    parser.add_argument('--eval_steps', type=int, default=100000, help=\"Update step interval to run eval on dev; set to -1 to eval after each epoch\")\n    parser.add_argument('--save_name', type=str, default=None, help=\"File name to save the model\")\n    parser.add_argument('--vocab_save_name', type=str, default=None, help=\"File name to save the vocab\")\n    parser.add_argument('--checkpoint_save_name', type=str, default=None, help=\"File name to save the most recent checkpoint\")\n    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help=\"Don't save checkpoints\")\n    parser.add_argument('--save_dir', type=str, default='saved_models/charlm', help=\"Directory to save models in\")\n    parser.add_argument('--summary', action='store_true', help='Use summary writer to record progress.')\n    utils.add_device_args(parser)\n    parser.add_argument('--seed', type=int, default=1234)\n\n    parser.add_argument('--wandb', action='store_true', 
help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')\n    return parser\n\ndef build_model_filename(args):\n    if args['save_name']:\n        save_name = args['save_name']\n    else:\n        save_name = '{}_{}_charlm.pt'.format(args['shorthand'], args['direction'])\n    model_file = os.path.join(args['save_dir'], save_name)\n    return model_file\n\ndef parse_args(args=None):\n    parser = build_argparse()\n\n    args = parser.parse_args(args=args)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    args = vars(args)\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running {} character-level language model in {} mode\".format(args['direction'], args['mode']))\n    \n    utils.ensure_dir(args['save_dir'])\n\n    if args['mode'] == 'train':\n        train(args)\n    else:\n        evaluate(args)\n\ndef evaluate_epoch(args, vocab, data, model, criterion):\n    \"\"\"\n    Run an evaluation over entire dataset.\n    \"\"\"\n    model.eval()\n    device = next(model.parameters()).device\n    hidden = None\n    total_loss = 0\n    if isinstance(data, GeneratorType):\n        data = list(data)\n        assert len(data) == 1, 'Only support single dev/test file'\n        data = data[0]\n    batches = batchify(data, args['batch_size'], device)\n    with torch.no_grad():\n        for i in range(0, batches.size(1) - 1, args['bptt_size']):\n            data, target = get_batch(batches, i, args['bptt_size'])\n            lens = [data.size(1) for i in range(data.size(0))]\n\n            output, hidden, decoded = model.forward(data, lens, hidden)\n            loss = criterion(decoded.view(-1, len(vocab['char'])), target)\n            \n            
hidden = repackage_hidden(hidden)\n            total_loss += data.size(1) * loss.data.item()\n    return total_loss / batches.size(1)\n\ndef evaluate_and_save(args, vocab, data, trainer, best_loss, model_file, checkpoint_file, writer=None):\n    \"\"\"\n    Run an evaluation over entire dataset, print progress and save the model if necessary.\n    \"\"\"\n    start_time = time.time()\n    loss = evaluate_epoch(args, vocab, data, trainer.model, trainer.criterion)\n    ppl = math.exp(loss)\n    elapsed = int(time.time() - start_time)\n    # TODO: step the scheduler less often when the eval frequency is higher\n    previous_lr = get_current_lr(trainer, args)\n    trainer.scheduler.step(loss)\n    current_lr = get_current_lr(trainer, args)\n    if previous_lr != current_lr:\n        logger.info(\"Updating learning rate to %f\", current_lr)\n    logger.info(\n        \"| eval checkpoint @ global step {:10d} | time elapsed {:6d}s | loss {:5.2f} | ppl {:8.2f}\".format(\n            trainer.global_step,\n            elapsed,\n            loss,\n            ppl,\n        )\n    )\n    if best_loss is None or loss < best_loss:\n        best_loss = loss\n        trainer.save(model_file, full=False)\n        logger.info('new best model saved at step {:10d}'.format(trainer.global_step))\n    if writer:\n        writer.add_scalar('dev_loss', loss, global_step=trainer.global_step)\n        writer.add_scalar('dev_ppl', ppl, global_step=trainer.global_step)\n    if checkpoint_file:\n        trainer.save(checkpoint_file, full=True)\n        logger.info('new checkpoint saved at step {:10d}'.format(trainer.global_step))\n\n    return loss, ppl, best_loss\n\ndef get_current_lr(trainer, args):\n    return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0]\n\ndef load_char_vocab(vocab_file):\n    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))}\n\ndef train(args):\n    utils.log_training_args(args, logger)\n 
   model_file = build_model_filename(args)\n\n    vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \\\n        else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand'])\n\n    if args['checkpoint']:\n        checkpoint_file = utils.checkpoint_name(args['save_dir'], model_file, args['checkpoint_save_name'])\n    else:\n        checkpoint_file = None\n\n    if os.path.exists(vocab_file):\n        logger.info('Loading existing vocab file')\n        vocab = load_char_vocab(vocab_file)\n    else:\n        logger.info('Building and saving vocab')\n        vocab = {'char': build_charlm_vocab(args['train_file'] if args['train_dir'] is None else args['train_dir'], cutoff=args['cutoff'])}\n        torch.save(vocab['char'].state_dict(), vocab_file)\n    logger.info(\"Training model with vocab size: {}\".format(len(vocab['char'])))\n\n    if checkpoint_file and os.path.exists(checkpoint_file):\n        logger.info('Loading existing checkpoint: %s' % checkpoint_file)\n        trainer = CharacterLanguageModelTrainer.load(args, checkpoint_file, finetune=True)\n    else:\n        trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab)\n\n    writer = None\n    if args['summary']:\n        from torch.utils.tensorboard import SummaryWriter\n        summary_dir = '{}/{}_summary'.format(args['save_dir'], args['save_name']) if args['save_name'] is not None \\\n            else '{}/{}_{}_charlm_summary'.format(args['save_dir'], args['shorthand'], args['direction'])\n        writer = SummaryWriter(log_dir=summary_dir)\n    \n    # evaluate model within epoch if eval_interval is set\n    eval_within_epoch = False\n    if args['eval_steps'] > 0:\n        eval_within_epoch = True\n\n    if args['wandb']:\n        import wandb\n        wandb_name = args['wandb_name'] if args['wandb_name'] else '%s_%s_charlm' % (args['shorthand'], args['direction'])\n        wandb.init(name=wandb_name, config=args)\n        
wandb.run.define_metric('best_loss', summary='min')\n        wandb.run.define_metric('ppl', summary='min')\n\n    device = next(trainer.model.parameters()).device\n\n    best_loss = None\n    start_epoch = trainer.epoch  # will default to 1 for a new trainer\n    for trainer.epoch in range(start_epoch, args['epochs']+1):\n        # load train data from train_dir if not empty, otherwise load from file\n        if args['train_dir'] is not None:\n            train_path = args['train_dir']\n        else:\n            train_path = args['train_file']\n        train_data = load_data(train_path, vocab, args['direction'])\n        dev_data = load_file(args['eval_file'], vocab, args['direction']) # dev must be a single file\n\n        # run over entire training set\n        for data_chunk in train_data:\n            batches = batchify(data_chunk, args['batch_size'], device)\n            hidden = None\n            total_loss = 0.0\n            total_batches = math.ceil((batches.size(1) - 1) / args['bptt_size'])\n            iteration, i = 0, 0\n            # over the data chunk\n            while i < batches.size(1) - 1 - 1:\n                trainer.model.train()\n                trainer.global_step += 1\n                start_time = time.time()\n                bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size']/ 2.\n                # prevent excessively small or negative sequence lengths\n                seq_len = max(5, int(np.random.normal(bptt, 5)))\n                # prevent very large sequence length, must be <= 1.2 x bptt\n                seq_len = min(seq_len, int(args['bptt_size'] * 1.2))\n                data, target = get_batch(batches, i, seq_len)\n                lens = [data.size(1) for i in range(data.size(0))]\n                \n                trainer.optimizer.zero_grad()\n                output, hidden, decoded = trainer.model.forward(data, lens, hidden)\n                loss = trainer.criterion(decoded.view(-1, 
len(vocab['char'])), target)\n                total_loss += loss.data.item()\n                loss.backward()\n\n                torch.nn.utils.clip_grad_norm_(trainer.params, args['max_grad_norm'])\n                trainer.optimizer.step()\n\n                hidden = repackage_hidden(hidden)\n\n                if (iteration + 1) % args['report_steps'] == 0:\n                    cur_loss = total_loss / args['report_steps']\n                    elapsed = time.time() - start_time\n                    logger.info(\n                        \"| epoch {:5d} | {:5d}/{:5d} batches | sec/batch {:.6f} | loss {:5.2f} | ppl {:8.2f}\".format(\n                            trainer.epoch,\n                            iteration + 1,\n                            total_batches,\n                            elapsed / args['report_steps'],\n                            cur_loss,\n                            math.exp(cur_loss),\n                        )\n                    )\n                    if args['wandb']:\n                        wandb.log({'train_loss': cur_loss}, step=trainer.global_step)\n                    total_loss = 0.0\n\n                iteration += 1\n                i += seq_len\n\n                # evaluate if necessary\n                if eval_within_epoch and trainer.global_step % args['eval_steps'] == 0:\n                    _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)\n                    if args['wandb']:\n                        wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)\n\n        # if eval_steps is not positive, run evaluation after each epoch;\n        # always run evaluation after the final epoch\n        if not eval_within_epoch or trainer.epoch == args['epochs']:\n            _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)\n            if args['wandb']:\n                wandb.log({'ppl': ppl, 
'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)\n\n    if writer:\n        writer.close()\n    if args['wandb']:\n        wandb.finish()\n    return\n\ndef evaluate(args):\n    model_file = build_model_filename(args)\n\n    model = CharacterLanguageModel.load(model_file).to(args['device'])\n    vocab = model.vocab\n    data = load_data(args['eval_file'], vocab, args['direction'])\n    criterion = torch.nn.CrossEntropyLoss()\n    \n    loss = evaluate_epoch(args, vocab, data, model, criterion)\n    logger.info(\n        \"| best model | loss {:5.2f} | ppl {:8.2f}\".format(\n            loss,\n            math.exp(loss),\n        )\n    )\n    return\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/classifier.py",
    "content": "import argparse\nimport ast\nimport logging\nimport os\nimport random\nimport re\nfrom enum import Enum\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.common import loss\nfrom stanza.models.common import utils\nfrom stanza.models.pos.vocab import CharVocab\n\nimport stanza.models.classifiers.data as data\nfrom stanza.models.classifiers.trainer import Trainer\nfrom stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType\nfrom stanza.models.common.peft_config import add_peft_args, resolve_peft_args\n\nfrom stanza.utils.confusion import format_confusion, confusion_to_accuracy, confusion_to_macro_f1\n\n\nclass Loss(Enum):\n    CROSS = 1\n    WEIGHTED_CROSS = 2\n    LOG_CROSS = 3\n    FOCAL = 4\n\nclass DevScoring(Enum):\n    ACCURACY = 'ACC'\n    WEIGHTED_F1 = 'WF'\n\nlogger = logging.getLogger('stanza')\ntlogger = logging.getLogger('stanza.classifiers.trainer')\n\nlogging.getLogger('elmoformanylangs').setLevel(logging.WARNING)\n\nDEFAULT_TRAIN='data/sentiment/en_sstplus.train.txt'\nDEFAULT_DEV='data/sentiment/en_sst3roots.dev.txt'\nDEFAULT_TEST='data/sentiment/en_sst3roots.test.txt'\n\n\"\"\"A script for training and testing classifier models, especially on the SST.\n\nIf you run the script with no arguments, it will start trying to train\na sentiment model.\n\npython3 -m stanza.models.classifier\n\nThis requires the sentiment dataset to be in an `extern_data`\ndirectory, such as by symlinking it from somewhere else.\n\nThe default model is a CNN where the word vectors are first mapped to\nchannels with filters of a few different widths, those channels are\nmaxpooled over the entire sentence, and then the resulting pools have\nfully connected layers until they reach the number of classes in the\ntraining data.  
You can see the defaults in the options below.\n\nhttps://arxiv.org/abs/1408.5882\n\n(Currently the CNN is the only sentence classifier implemented.)\n\nTo train with a more complicated CNN arch:\n\nnohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 > FC41.out 2>&1 &\n\nYou can train models with word vectors other than the default word2vec.  For example:\n\n nohup python3 -u -m stanza.models.classifier  --wordvec_type google --wordvec_dir extern_data/google --max_epochs 200 --filter_channels 1000 --fc_shapes 200,100 --base_name FC21_google > FC21_google.out 2>&1 &\n\nA model trained on the 5 class dataset can be tested on the 2 class dataset with a command line like this:\n\npython3 -u -m stanza.models.classifier  --no_train --load_name saved_models/classifier/sst_en_ewt_FS_3_4_5_C_1000_FC_400_100_classifier.E0165-ACC41.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels \"{0:0, 1:0, 3:1, 4:1}\"\n\npython3 -u -m stanza.models.classifier  --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0189-ACC45.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels \"{0:0, 1:0, 3:1, 4:1}\"\n\nA model trained on the 3 class dataset can be tested on the 2 class dataset with a command line like this:\n\npython3 -u -m stanza.models.classifier  --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_3C_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0101-ACC68.94.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels \"{0:0, 2:1}\"\n\nTo train models on combined 3 class datasets:\n\nnohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class  --extra_wordvec_method CONCAT --extra_wordvec_dim 200  --train_file 
data/sentiment/en_sstplus.train.txt --dev_file data/sentiment/en_sst3roots.dev.txt --test_file data/sentiment/en_sst3roots.test.txt > FC41_3class.out 2>&1 &\n\nThis tests that model:\n\npython3 -u -m stanza.models.classifier --no_train --load_name en_sstplus.pt --test_file data/sentiment/en_sst3roots.test.txt\n\nHere is an example for training a model in a different language:\n\nnohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german  --train_file data/sentiment/de_sb10k.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 &\n\nThis uses more data, although that wound up being worse for the German model:\n\nnohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german  --train_file data/sentiment/de_sb10k.train.txt,data/sentiment/de_scare.train.txt,data/sentiment/de_usage.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 &\n\nnohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_chinese --train_file data/sentiment/zh_ren.train.txt --dev_file data/sentiment/zh_ren.dev.txt --test_file data/sentiment/zh_ren.test.txt --shorthand zh_ren --wordvec_type fasttext --extra_wordvec_method SUM --wordvec_pretrain_file ../stanza_resources/zh-hans/pretrain/gsdsimp.pt > zh_ren.out 2>&1 &\n\nnohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --save_name vi_vsfc.pt  --train_file data/sentiment/vi_vsfc.train.json --dev_file data/sentiment/vi_vsfc.dev.json --test_file data/sentiment/vi_vsfc.test.json --shorthand 
vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --extra_wordvec_method SUM --dev_eval_scoring WEIGHTED_F1 > vi_vsfc.out 2>&1 &\n\npython3 -u -m stanza.models.classifier --no_train --test_file extern_data/sentiment/vietnamese/_UIT-VSFC/test.txt --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --load_name vi_vsfc.pt\n\"\"\"\n\ndef convert_fc_shapes(arg):\n    \"\"\"\n    Returns a tuple of sizes to use in FC layers.\n\n    For example, converts \"100\" -> (100,)\n    \"100,200\" -> (100,200)\n    \"\"\"\n    arg = arg.strip()\n    if not arg:\n        return ()\n    arg = ast.literal_eval(arg)\n    if isinstance(arg, int):\n        return (arg,)\n    if isinstance(arg, tuple):\n        return arg\n    return tuple(arg)\n\n# For the most part, these values are for the constituency parser.\n# Only the WD for adadelta is originally for sentiment\n# Also LR for adadelta and madgrad\n\n# madgrad learning rate experiment on sstplus\n# note that the hyperparameters are not cross-validated in tandem, so\n# later changes may make some earlier experiments slightly out of date\n# LR\n#   0.01         failed to converge\n#   0.004        failed to converge\n#   0.003        0.5572\n#   0.002        failed to converge\n#   0.001        0.6857\n#   0.0008       0.6799\n#   0.0005       0.6849\n#   0.00025      0.6749\n#   0.0001       0.6746\n#   0.00001      0.6536\n#   0.000001     0.6267\n# LR 0.001 produced the best model, but it does occasionally fail to\n# converge to a working model, so we set the default to 0.0005 instead\nDEFAULT_LEARNING_RATES = { \"adamw\": 0.0002, \"adadelta\": 1.0, \"sgd\": 0.001, \"adabelief\": 0.00005, \"madgrad\": 0.0005 }\nDEFAULT_LEARNING_EPS = { \"adabelief\": 1e-12, \"adadelta\": 1e-6, \"adamw\": 1e-8 }\nDEFAULT_LEARNING_RHO = 0.9\nDEFAULT_MOMENTUM = { \"madgrad\": 0.9, \"sgd\": 0.9 }\nDEFAULT_WEIGHT_DECAY = { \"adamw\": 
0.05, \"adadelta\": 0.0001, \"sgd\": 0.01, \"adabelief\": 1.2e-6, \"madgrad\": 2e-6 }\n\ndef build_argparse():\n    \"\"\"\n    Build the argparse for the classifier.\n\n    Refactored so that other utility scripts can use the same parser if needed.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--train', dest='train', default=True, action='store_true', help='Train the model (default)')\n    parser.add_argument('--no_train', dest='train', action='store_false', help=\"Don't train the model\")\n\n    parser.add_argument('--shorthand', type=str, default='en_ewt', help=\"Treebank shorthand, eg 'en' for English\")\n\n    parser.add_argument('--load_name', type=str, default=None, help='Name for loading an existing model')\n    parser.add_argument('--save_dir', type=str, default='saved_models/classifier', help='Root dir for saving models.')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_{bert_finetuning}_{classifier_type}_classifier.pt\", help='Name for saving the model')\n\n    parser.add_argument('--checkpoint_save_name', type=str, default=None, help=\"File name to save the most recent checkpoint\")\n    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help=\"Don't save checkpoints\")\n\n    parser.add_argument('--save_intermediate_models', default=False, action='store_true',\n                        help='Save all intermediate models - this can be a lot!')\n\n    parser.add_argument('--train_file', type=str, default=DEFAULT_TRAIN, help='Input file(s) to train a model from.  Each line is an example.  Should go <label> <tokenized sentence>.  
Comma separated list.')\n    parser.add_argument('--dev_file', type=str, default=DEFAULT_DEV, help='Input file(s) to use as the dev set.')\n    parser.add_argument('--test_file', type=str, default=DEFAULT_TEST, help='Input file(s) to use as the test set.')\n    parser.add_argument('--output_predictions', default=False, action='store_true', help='Output predictions when running the test set')\n    parser.add_argument('--max_epochs', type=int, default=100)\n    parser.add_argument('--tick', type=int, default=50)\n\n    parser.add_argument('--model_type', type=lambda x: ModelType[x.upper()], default=ModelType.CNN,\n                        help='Model type to use.  Options: %s' % \" \".join(x.name for x in ModelType))\n\n    parser.add_argument('--filter_sizes', default=(3,4,5), type=ast.literal_eval, help='Filter sizes for the layer after the word vectors')\n    parser.add_argument('--filter_channels', default=1000, type=ast.literal_eval, help='Number of channels for layers after the word vectors.  Int for same number of channels (scaled by width) for each filter, or tuple/list for exact lengths for each filter')\n    parser.add_argument('--fc_shapes', default=\"400,100\", type=convert_fc_shapes, help='Extra fully connected layers to put after the initial filters.  If set to blank, will FC directly from the max pooling to the output layer.')\n    parser.add_argument('--dropout', default=0.5, type=float, help='Dropout value to use')\n\n    parser.add_argument('--batch_size', default=50, type=int, help='Batch size when training')\n    parser.add_argument('--batch_single_item', default=200, type=int, help='Items of this size go in their own batch')\n    parser.add_argument('--dev_eval_batches', default=2000, type=int, help='Run the dev set after this many train batches.  
Set to 0 to only do it once per epoch')\n    parser.add_argument('--dev_eval_scoring', type=lambda x: DevScoring[x.upper()], default=DevScoring.WEIGHTED_F1,\n                        help=('Scoring method to use for choosing the best model.  Options: %s' %\n                              \" \".join(x.name for x in DevScoring)))\n\n    parser.add_argument('--weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')\n    parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate to use in the optimizer')\n    parser.add_argument('--momentum', default=None, type=float, help='Momentum to use in the optimizer')\n\n    parser.add_argument('--optim', default='adadelta', choices=['adadelta', 'madgrad', 'sgd'], help='Optimizer type: SGD, Adadelta, or madgrad.  Highly recommend to install madgrad and use that')\n\n    parser.add_argument('--test_remap_labels', default=None, type=ast.literal_eval,\n                        help='Map of which label each classifier label should map to.  For example, \"{0:0, 1:0, 3:1, 4:1}\" to map a 5 class sentiment test to a 2 class.  
Any labels not mapped will be considered wrong')\n    parser.add_argument('--forgive_unmapped_labels', dest='forgive_unmapped_labels', default=True, action='store_true',\n                        help='When remapping labels, such as from 5 class to 2 class, pick a different label if the first guess is not remapped.')\n    parser.add_argument('--no_forgive_unmapped_labels', dest='forgive_unmapped_labels', action='store_false',\n                        help=\"When remapping labels, such as from 5 class to 2 class, DON'T pick a different label if the first guess is not remapped.\")\n\n    parser.add_argument('--loss', type=lambda x: Loss[x.upper()], default=Loss.CROSS,\n                        help=\"Whether to use regular cross entropy or scale it by 1/log(quantity)\")\n    parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')\n    parser.add_argument('--min_train_len', type=int, default=0,\n                        help=\"Filter sentences less than this length\")\n\n    parser.add_argument('--pretrain_max_vocab', type=int, default=-1)\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument('--wordvec_raw_file', type=str, default=None, help='Exact name of the raw wordvec file to read')\n    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')\n    parser.add_argument('--wordvec_type', type=lambda x: WVType[x.upper()], default='word2vec', help='Different vector types have different options, such as google 300d replacing numbers with #')\n    parser.add_argument('--extra_wordvec_dim', type=int, default=0, help=\"Extra dim of word vectors - will be trained\")\n    parser.add_argument('--extra_wordvec_method', type=lambda x: ExtraVectors[x.upper()], default='sum', help='How to train extra dimensions of word vectors, if at all')\n    
parser.add_argument('--extra_wordvec_max_norm', type=float, default=None, help=\"Max norm for initializing the extra vectors\")\n\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n    parser.add_argument('--charlm_projection', type=int, default=None, help=\"Project the charlm values to this dimension\")\n    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help=\"Use lowercased characters in character model.\")\n\n    parser.add_argument('--elmo_model', default='extern_data/manyelmo/english', help='Directory with elmo model')\n    parser.add_argument('--use_elmo', dest='use_elmo', default=False, action='store_true', help='Use an elmo model as a source of parameters')\n    parser.add_argument('--elmo_projection', type=int, default=None, help='Project elmo to this many dimensions')\n\n    parser.add_argument('--bert_model', type=str, default=None, help=\"Use an external bert model (requires the transformers package)\")\n    parser.add_argument('--no_bert_model', dest='bert_model', action=\"store_const\", const=None, help=\"Don't use bert\")\n    parser.add_argument('--bert_finetune', default=False, action='store_true', help=\"Finetune the Bert model\")\n    parser.add_argument('--bert_learning_rate', default=0.01, type=float, help='Scale the learning rate for transformer finetuning by this much')\n    parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')\n    parser.add_argument('--bert_hidden_layers', type=int, default=4, help=\"How many layers of hidden state to use from the transformer\")\n    parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert 
embedding')\n\n    parser.add_argument('--bilstm', dest='bilstm', action='store_true', default=True, help=\"Use a bilstm after the inputs, before the convs.  Using bilstm is about as accurate and significantly faster (because of dim reduction) than going straight to the filters\")\n    parser.add_argument('--no_bilstm', dest='bilstm', action='store_false', help=\"Don't use a bilstm after the inputs, before the convs.\")\n    # somewhere between 200-300 seems to be the sweet spot for a couple datasets:\n    # dev set macro f1 scores on 3 class problems\n    # note that these were only run once each\n    # more trials might narrow down which ones works best\n    # es_tass2020:\n    #   150        0.5580\n    #   200        0.5629\n    #   250        0.5586\n    #   300        0.5642    <---\n    #   400        0.5525\n    #   500        0.5579\n    #   750        0.5585\n    # en_sstplus:\n    #   150        0.6816\n    #   200        0.6721\n    #   250        0.6915    <---\n    #   300        0.6824\n    #   400        0.6757\n    #   500        0.6770\n    #   750        0.6781\n    # de_sb10k\n    #   150        0.6745\n    #   200        0.6798    <---\n    #   250        0.6459\n    #   300        0.6665\n    #   400        0.6521\n    #   500        0.6584\n    #   750        0.6447\n    parser.add_argument('--bilstm_hidden_dim', type=int, default=300, help=\"Dimension of the bilstm to use\")\n\n    parser.add_argument('--maxpool_width', type=int, default=1, help=\"Width of the maxpool kernel to use\")\n\n    parser.add_argument('--no_constituency_backprop', dest='constituency_backprop', default=True, action='store_false', help=\"When using a constituency parser, backprop into the parser's weights if True\")\n    parser.add_argument('--constituency_model', type=str, default=\"/home/john/stanza_resources/it/constituency/vit_bert.pt\", help=\"Which constituency model to use.  
TODO: make this more user friendly\")\n    parser.add_argument('--constituency_batch_norm', default=False, action='store_true', help='Add a LayerNorm between the output of the parser and the classifier layers')\n    parser.add_argument('--constituency_node_attn', default=False, action='store_true', help='True means to make an attn layer out of the tree, with the words as key and nodes as query')\n    parser.add_argument('--no_constituency_node_attn', dest='constituency_node_attn', action='store_false', help='True means to make an attn layer out of the tree, with the words as key and nodes as query')\n    parser.add_argument('--constituency_top_layer', dest='constituency_top_layer', default=False, action='store_true', help='True means use the top (ROOT) layer of the constituents.  Otherwise, the next layer down (S, usually) will be used')\n    parser.add_argument('--no_constituency_top_layer', dest='constituency_top_layer', action='store_false', help='True means use the top (ROOT) layer of the constituents.  Otherwise, the next layer down (S, usually) will be used')\n    parser.add_argument('--constituency_all_words', default=False, action='store_true', help='Use all word positions in the constituency classifier')\n    parser.add_argument('--no_constituency_all_words', dest='constituency_all_words', default=False, action='store_false', help='Use the start and end word embeddings as inputs to the constituency classifier')\n\n    parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training.  A very noisy option')\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  
Will default to the dataset short name')\n\n    parser.add_argument('--seed', default=None, type=int, help='Random seed for model')\n\n    add_peft_args(parser)\n    utils.add_device_args(parser)\n\n    return parser\n\ndef build_model_filename(args):\n    shape = \"FS_%s\" % \"_\".join([str(x) for x in args.filter_sizes])\n    shape = shape + \"_C_%d_\" % args.filter_channels\n    if args.fc_shapes:\n        shape = shape + \"_FC_%s_\" % \"_\".join([str(x) for x in args.fc_shapes])\n\n    model_save_file = utils.standard_model_file_name(vars(args), \"classifier\", shape=shape, classifier_type=args.model_type.name)\n    logger.info(\"Expanded save_name: %s\", model_save_file)\n    return model_save_file\n\ndef parse_args(args=None):\n    \"\"\"\n    Add arguments for building the classifier.\n    Parses command line args and returns the result.\n    \"\"\"\n    parser = build_argparse()\n    args = parser.parse_args(args)\n    resolve_peft_args(args, tlogger)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    args.optim = args.optim.lower()\n    if args.weight_decay is None:\n        args.weight_decay = DEFAULT_WEIGHT_DECAY.get(args.optim, None)\n    if args.momentum is None:\n        args.momentum = DEFAULT_MOMENTUM.get(args.optim, None)\n    if args.learning_rate is None:\n        args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim, None)\n\n    return args\n\n\ndef dataset_predictions(model, dataset):\n    model.eval()\n    index_label_map = {x: y for (x, y) in enumerate(model.labels)}\n\n    dataset_lengths = data.sort_dataset_by_len(dataset, keep_index=True)\n\n    predictions = []\n    o_idx = []\n    for length in dataset_lengths.keys():\n        batch = dataset_lengths[length]\n        output = model([x[0] for x in batch])\n        for i in range(len(batch)):\n            predicted = torch.argmax(output[i])\n            predicted_label = index_label_map[predicted.item()]\n            predictions.append(predicted_label)\n            
o_idx.append(batch[i][1])\n\n    predictions = utils.unsort(predictions, o_idx)\n    return predictions\n\ndef confusion_dataset(predictions, dataset, labels):\n    \"\"\"\n    Returns a confusion matrix\n\n    First key: gold\n    Second key: predicted\n    so: confusion_matrix[gold][predicted]\n    \"\"\"\n    confusion_matrix = {}\n    for label in labels:\n        confusion_matrix[label] = {}\n\n    for predicted_label, datum in zip(predictions, dataset):\n        expected_label = datum.sentiment\n        confusion_matrix[expected_label][predicted_label] = confusion_matrix[expected_label].get(predicted_label, 0) + 1\n\n    return confusion_matrix\n\n\ndef score_dataset(model, dataset, label_map=None,\n                  remap_labels=None, forgive_unmapped_labels=False):\n    \"\"\"\n    remap_labels: a dict from old label to new label to use when\n    testing a classifier on a dataset with a simpler label set.\n    For example, a model trained on 5 class sentiment can be tested\n    on a binary distribution with {\"0\": \"0\", \"1\": \"0\", \"3\": \"1\", \"4\": \"1\"}\n\n    forgive_unmapped_labels says the following: in the case that the\n    model predicts \"2\" in the above example for remap_labels, instead\n    treat the model's prediction as whichever label it gave the\n    highest score\n    \"\"\"\n    model.eval()\n    if label_map is None:\n        label_map = {x: y for (y, x) in enumerate(model.labels)}\n    correct = 0\n    dataset_lengths = data.sort_dataset_by_len(dataset)\n\n    for length in dataset_lengths.keys():\n        # TODO: possibly break this up into smaller batches\n        batch = dataset_lengths[length]\n        expected_labels = [label_map[x.sentiment] for x in batch]\n\n        output = model(batch)\n\n        for i in range(len(expected_labels)):\n            predicted = torch.argmax(output[i])\n            predicted_label = predicted.item()\n            if remap_labels:\n                if predicted_label in remap_labels:\n         
           predicted_label = remap_labels[predicted_label]\n                else:\n                    found = False\n                    if forgive_unmapped_labels:\n                        items = []\n                        for j in range(len(output[i])):\n                            items.append((output[i][j].item(), j))\n                        items.sort(key=lambda x: -x[0])\n                        for _, item in items:\n                            if item in remap_labels:\n                                predicted_label = remap_labels[item]\n                                found = True\n                                break\n                    # if slack guesses allowed, none of the existing\n                    # labels matched, so we count it wrong.  if slack\n                    # guesses not allowed, just count it wrong\n                    if not found:\n                        continue\n\n            if predicted_label == expected_labels[i]:\n                correct = correct + 1\n    return correct\n\ndef score_dev_set(model, dev_set, dev_eval_scoring):\n    predictions = dataset_predictions(model, dev_set)\n    confusion_matrix = confusion_dataset(predictions, dev_set, model.labels)\n    logger.info(\"Dev set confusion matrix:\\n{}\".format(format_confusion(confusion_matrix, model.labels)))\n    correct, total = confusion_to_accuracy(confusion_matrix)\n    macro_f1 = confusion_to_macro_f1(confusion_matrix)\n    logger.info(\"Dev set: %d correct of %d examples.  
Accuracy: %f\" %\n                (correct, len(dev_set), correct / len(dev_set)))\n    logger.info(\"Macro f1: {}\".format(macro_f1))\n\n    accuracy = correct / total\n    if dev_eval_scoring is DevScoring.ACCURACY:\n        return accuracy, accuracy, macro_f1\n    elif dev_eval_scoring is DevScoring.WEIGHTED_F1:\n        return macro_f1, accuracy, macro_f1\n    else:\n        raise ValueError(\"Unknown scoring method {}\".format(dev_eval_scoring))\n\ndef intermediate_name(filename, epoch, dev_scoring, score):\n    \"\"\"\n    Build an informative intermediate checkpoint name from a base name, epoch #, and accuracy\n    \"\"\"\n    root, ext = os.path.splitext(filename)\n    return root + \".E{epoch:04d}-{score_type}{acc:05.2f}\".format(**{\"epoch\": epoch, \"score_type\": dev_scoring.value, \"acc\": score * 100}) + ext\n\ndef log_param_sizes(model):\n    logger.debug(\"--- Model parameter sizes ---\")\n    total_size = 0\n    for name, param in model.named_parameters():\n        param_size = param.element_size() * param.nelement()\n        total_size += param_size\n        logger.debug(\"  %s %d %d %d\", name, param.element_size(), param.nelement(), param_size)\n    logger.debug(\"  Total size: %d\", total_size)\n\ndef train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, labels):\n    tlogger.setLevel(logging.DEBUG)\n\n    # TODO: use a (torch) dataloader to possibly speed up the GPU usage\n    model = trainer.model\n    optimizer = trainer.optimizer\n\n    device = next(model.parameters()).device\n    logger.info(\"Current device: %s\" % device)\n\n    label_map = {x: y for (y, x) in enumerate(labels)}\n    label_tensors = {x: torch.tensor(y, requires_grad=False, device=device)\n                     for (y, x) in enumerate(labels)}\n\n    process_outputs = lambda x: x\n    if args.loss == Loss.CROSS:\n        logger.info(\"Creating CrossEntropyLoss\")\n        loss_function = nn.CrossEntropyLoss()\n    elif args.loss == 
Loss.WEIGHTED_CROSS:\n        logger.info(\"Creating weighted cross entropy loss w/o log\")\n        loss_function = loss.weighted_cross_entropy_loss([label_map[x[0]] for x in train_set], log_dampened=False)\n    elif args.loss == Loss.LOG_CROSS:\n        logger.info(\"Creating weighted cross entropy loss w/ log\")\n        loss_function = loss.weighted_cross_entropy_loss([label_map[x[0]] for x in train_set], log_dampened=True)\n    elif args.loss == Loss.FOCAL:\n        try:\n            from focal_loss.focal_loss import FocalLoss\n        except ImportError:\n            raise ImportError(\"focal_loss not installed.  Must `pip install focal_loss_torch` to use the --loss=focal feature\")\n        logger.info(\"Creating FocalLoss with gamma %f\", args.loss_focal_gamma)\n        process_outputs = lambda x: torch.softmax(x, dim=1)\n        loss_function = FocalLoss(gamma=args.loss_focal_gamma)\n    else:\n        raise ValueError(\"Unknown loss function {}\".format(args.loss))\n    loss_function.to(device)\n\n    train_set_by_len = data.sort_dataset_by_len(train_set)\n\n    if trainer.global_step > 0:\n        # We reloaded the model, so let's report its current dev set score\n        _ = score_dev_set(model, dev_set, args.dev_eval_scoring)\n        logger.info(\"Reloaded model for continued training.\")\n        if trainer.best_score is not None:\n            logger.info(\"Previous best score: %.5f\", trainer.best_score)\n\n    log_param_sizes(model)\n\n    # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html\n    if args.wandb:\n        import wandb\n        wandb_name = args.wandb_name if args.wandb_name else \"%s_classifier\" % args.shorthand\n        wandb.init(name=wandb_name, config=args)\n        wandb.run.define_metric('accuracy', summary='max')\n        wandb.run.define_metric('macro_f1', summary='max')\n        wandb.run.define_metric('epoch_loss', summary='min')\n\n    for opt_name, opt in optimizer.items():\n        current_lr = 
opt.param_groups[0]['lr']\n        logger.info(\"optimizer %s learning rate: %s\", opt_name, current_lr)\n\n    # if this is a brand new training run, and we're saving all intermediate models, save the start model as well\n    if args.save_intermediate_models and trainer.epochs_trained == 0:\n        intermediate_file = intermediate_name(model_file, trainer.epochs_trained, args.dev_eval_scoring, 0.0)\n        trainer.save(intermediate_file, save_optimizer=False)\n    for trainer.epochs_trained in range(trainer.epochs_trained, args.max_epochs):\n        running_loss = 0.0\n        epoch_loss = 0.0\n        shuffled_batches = data.shuffle_dataset(train_set_by_len, args.batch_size, args.batch_single_item)\n\n        model.train()\n        logger.info(\"Starting epoch %d\", trainer.epochs_trained)\n        if args.log_norms:\n            model.log_norms()\n\n        for batch_num, batch in enumerate(shuffled_batches):\n            # logger.debug(\"Batch size %d max len %d\" % (len(batch), max(len(x.text) for x in batch)))\n            trainer.global_step += 1\n            logger.debug(\"Starting batch: %d step %d\", batch_num, trainer.global_step)\n\n            batch_labels = torch.stack([label_tensors[x.sentiment] for x in batch])\n\n            # zero the parameter gradients\n            for opt in optimizer.values():\n                opt.zero_grad()\n\n            outputs = model(batch)\n            outputs = process_outputs(outputs)\n            batch_loss = loss_function(outputs, batch_labels)\n            batch_loss.backward()\n            for opt in optimizer.values():\n                opt.step()\n\n            # print statistics\n            running_loss += batch_loss.item()\n            if (batch_num + 1) % args.tick == 0: # print every so many batches\n                train_loss = running_loss / args.tick\n                logger.info('[%d, %5d] Average loss: %.3f', trainer.epochs_trained + 1, batch_num + 1, train_loss)\n                if args.wandb:\n       
             wandb.log({'train_loss': train_loss}, step=trainer.global_step)\n                if args.dev_eval_batches > 0 and (batch_num + 1) % args.dev_eval_batches == 0:\n                    logger.info('---- Interim analysis ----')\n                    dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)\n                    if args.wandb:\n                        wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step)\n                    if trainer.best_score is None or dev_score > trainer.best_score:\n                        trainer.best_score = dev_score\n                        trainer.save(model_file, save_optimizer=False)\n                        logger.info(\"Saved new best score model!  Accuracy %.5f   Macro F1 %.5f   Epoch %5d   Batch %d\" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1))\n                    model.train()\n                    if args.log_norms:\n                        trainer.model.log_norms()\n                epoch_loss += running_loss\n                running_loss = 0.0\n        # Add any leftover loss to the epoch_loss\n        epoch_loss += running_loss\n\n        logger.info(\"Finished epoch %d  Total loss %.3f\" % (trainer.epochs_trained + 1, epoch_loss))\n        dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)\n        if args.wandb:\n            wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1, 'epoch_loss': epoch_loss}, step=trainer.global_step)\n        if checkpoint_file:\n            trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1)\n        if args.save_intermediate_models:\n            intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)\n            trainer.save(intermediate_file, save_optimizer=False)\n        if trainer.best_score is None or dev_score > trainer.best_score:\n            trainer.best_score = dev_score\n      
      trainer.save(model_file, save_optimizer=False)\n            logger.info(\"Saved new best score model!  Accuracy %.5f   Macro F1 %.5f   Epoch %5d\" % (accuracy, macro_f1, trainer.epochs_trained+1))\n\n    if args.wandb:\n        wandb.finish()\n\ndef main(args=None):\n    args = parse_args(args)\n    seed = utils.set_random_seed(args.seed)\n    logger.info(\"Using random seed: %d\" % seed)\n\n    utils.ensure_dir(args.save_dir)\n\n    save_name = build_model_filename(args)\n\n    # TODO: maybe the dataset needs to be in a torch data loader in order to\n    # make cuda operations faster\n    checkpoint_file = None\n    if args.train:\n        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)\n        logger.info(\"Using training set: %s\" % args.train_file)\n        logger.info(\"Training set has %d labels\" % len(data.dataset_labels(train_set)))\n        tlogger.setLevel(logging.DEBUG)\n\n        tlogger.info(\"Saving checkpoints: %s\", args.checkpoint)\n        if args.checkpoint:\n            checkpoint_file = utils.checkpoint_name(args.save_dir, save_name, args.checkpoint_save_name)\n            tlogger.info(\"Checkpoint filename: %s\", checkpoint_file)\n    elif not args.load_name:\n        if save_name:\n            args.load_name = save_name\n        else:\n            raise ValueError(\"No model provided and not asked to train a model.  
This makes no sense\")\n    else:\n        train_set = None\n\n    if args.train and checkpoint_file is not None and os.path.exists(checkpoint_file):\n        trainer = Trainer.load(checkpoint_file, args, load_optimizer=args.train)\n    elif args.load_name:\n        trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)\n    else:\n        trainer = Trainer.build_new_model(args, train_set)\n\n    trainer.model.log_configuration()\n\n    if args.train:\n        utils.log_training_args(args, logger)\n\n        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, min_len=None)\n        logger.info(\"Using dev set: %s\", args.dev_file)\n        logger.info(\"Training set has %d items\", len(train_set))\n        logger.info(\"Dev set has %d items\", len(dev_set))\n        data.check_labels(trainer.model.labels, dev_set)\n\n        train_model(trainer, save_name, checkpoint_file, args, train_set, dev_set, trainer.model.labels)\n\n    if args.log_norms:\n        trainer.model.log_norms()\n    test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)\n    logger.info(\"Using test set: %s\" % args.test_file)\n    data.check_labels(trainer.model.labels, test_set)\n\n    if args.test_remap_labels is None:\n        predictions = dataset_predictions(trainer.model, test_set)\n        confusion_matrix = confusion_dataset(predictions, test_set, trainer.model.labels)\n        if args.output_predictions:\n            logger.info(\"List of predictions: %s\", predictions)\n        logger.info(\"Confusion matrix:\\n{}\".format(format_confusion(confusion_matrix, trainer.model.labels)))\n        correct, total = confusion_to_accuracy(confusion_matrix)\n        logger.info(\"Macro f1: {}\".format(confusion_to_macro_f1(confusion_matrix)))\n    else:\n        correct = score_dataset(trainer.model, test_set,\n                                remap_labels=args.test_remap_labels,\n                                
forgive_unmapped_labels=args.forgive_unmapped_labels)\n        total = len(test_set)\n    logger.info(\"Test set: %d correct of %d examples.  Accuracy: %f\" %\n                (correct, total, correct / total))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/classifiers/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/classifiers/base_classifier.py",
    "content": "from abc import ABC, abstractmethod\n\nimport logging\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.common.utils import split_into_batches, sort_with_indices, unsort\n\n\"\"\"\nA base classifier type\n\nCurrently, has the ability to process text or other inputs in a manner\nsuitable for the particular model type.\nIn other words, the CNNClassifier processes lists of words,\nand the ConstituencyClassifier processes trees\n\"\"\"\n\nlogger = logging.getLogger('stanza')\n\nclass BaseClassifier(ABC, nn.Module):\n    @abstractmethod\n    def extract_sentences(self, doc):\n        \"\"\"\n        Extract the sentences or the relevant information in the sentences from a document\n        \"\"\"\n\n    def preprocess_sentences(self, sentences):\n        \"\"\"\n        By default, don't do anything\n        \"\"\"\n        return sentences\n\n    def label_sentences(self, sentences, batch_size=None):\n        \"\"\"\n        Given a list of sentences, return the model's results on that text.\n        \"\"\"\n        self.eval()\n\n        sentences = self.preprocess_sentences(sentences)\n\n        if batch_size is None:\n            intervals = [(0, len(sentences))]\n            orig_idx = None\n        else:\n            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)\n            intervals = split_into_batches(sentences, batch_size)\n        labels = []\n        for interval in intervals:\n            if interval[1] - interval[0] == 0:\n                # this can happen for empty text\n                continue\n            output = self(sentences[interval[0]:interval[1]])\n            predicted = torch.argmax(output, dim=1)\n            labels.extend(predicted.tolist())\n\n        if orig_idx:\n            sentences = unsort(sentences, orig_idx)\n            labels = unsort(labels, orig_idx)\n\n        logger.debug(\"Found labels\")\n        for (label, sentence) in zip(labels, sentences):\n            
logger.debug((label, sentence))\n\n        return labels\n"
  },
  {
    "path": "stanza/models/classifiers/cnn_classifier.py",
    "content": "import dataclasses\nimport logging\nimport math\nimport os\nimport random\nimport re\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport stanza.models.classifiers.data as data\nfrom stanza.models.classifiers.base_classifier import BaseClassifier\nfrom stanza.models.classifiers.config import CNNConfig\nfrom stanza.models.classifiers.data import SentimentDatum\nfrom stanza.models.classifiers.utils import ExtraVectors, ModelType, build_output_layers\nfrom stanza.models.common.bert_embedding import extract_bert_embeddings\nfrom stanza.models.common.data import get_long_tensor, sort_all\nfrom stanza.models.common.utils import attach_bert_model\nfrom stanza.models.common.vocab import PAD_ID, UNK_ID\n\n\"\"\"\nThe CNN classifier is based on Yoon Kim's work:\n\nhttps://arxiv.org/abs/1408.5882\n\nAlso included are maxpool 2d, conv 2d, and a bilstm, as in\n\nText Classification Improved by Integrating Bidirectional LSTM\nwith Two-dimensional Max Pooling\nhttps://aclanthology.org/C16-1329.pdf\n\nThe architecture is simple:\n\n- Embedding at the bottom layer\n  - separate learnable entry for UNK, since many of the embeddings we have use 0 for UNK\n- maybe a bilstm layer, as per a command line flag\n- Some number of conv2d layers over the embedding\n- Maxpool layers over small windows, window size being a parameter\n- FC layer to the classification layer\n\nOne experiment which was run and found to be a bit of a negative was\nputting a layer on top of the pretrain.  
You would think that might\nhelp, but dev performance went down for each variation of\n  - trans(emb)\n  - relu(trans(emb))\n  - dropout(trans(emb))\n  - dropout(relu(trans(emb)))\n\"\"\"\n\nlogger = logging.getLogger('stanza')\ntlogger = logging.getLogger('stanza.classifiers.trainer')\n\nclass CNNClassifier(BaseClassifier):\n    def __init__(self, pretrain, extra_vocab, labels,\n                 charmodel_forward, charmodel_backward, elmo_model, bert_model, bert_tokenizer, force_bert_saved, peft_name,\n                 args):\n        \"\"\"\n        pretrain is a pretrained word embedding.  should have .emb and .vocab\n\n        extra_vocab is a collection of words in the training data to\n        be used for the delta word embedding, if used.  can be set to\n        None if delta word embedding is not used.\n\n        labels is the list of labels we expect in the training data.\n        Used to derive the number of classes.  Saving it in the model\n        will let us check that test data has the same labels\n\n        args is either the complete arguments when training, or the\n        subset of arguments stored in the model save file\n        \"\"\"\n        super(CNNClassifier, self).__init__()\n        self.labels = labels\n        bert_finetune = args.bert_finetune\n        use_peft = args.use_peft\n        force_bert_saved = force_bert_saved or bert_finetune\n        logger.debug(\"bert_finetune %s / force_bert_saved %s\", bert_finetune, force_bert_saved)\n\n        # this may change when loaded in a new Pipeline, so it's not part of the config\n        self.peft_name = peft_name\n\n        # we build a separate config out of the args so that we can easily save it in torch\n        self.config = CNNConfig(filter_channels = args.filter_channels,\n                                filter_sizes = args.filter_sizes,\n                                fc_shapes = args.fc_shapes,\n                                dropout = args.dropout,\n                              
  num_classes = len(labels),\n                                wordvec_type = args.wordvec_type,\n                                extra_wordvec_method = args.extra_wordvec_method,\n                                extra_wordvec_dim = args.extra_wordvec_dim,\n                                extra_wordvec_max_norm = args.extra_wordvec_max_norm,\n                                char_lowercase = args.char_lowercase,\n                                charlm_projection = args.charlm_projection,\n                                has_charlm_forward = charmodel_forward is not None,\n                                has_charlm_backward = charmodel_backward is not None,\n                                use_elmo = args.use_elmo,\n                                elmo_projection = args.elmo_projection,\n                                bert_model = args.bert_model,\n                                bert_finetune = bert_finetune,\n                                bert_hidden_layers = args.bert_hidden_layers,\n                                force_bert_saved = force_bert_saved,\n\n                                use_peft = use_peft,\n                                lora_rank = args.lora_rank,\n                                lora_alpha = args.lora_alpha,\n                                lora_dropout = args.lora_dropout,\n                                lora_modules_to_save = args.lora_modules_to_save,\n                                lora_target_modules = args.lora_target_modules,\n\n                                bilstm = args.bilstm,\n                                bilstm_hidden_dim = args.bilstm_hidden_dim,\n                                maxpool_width = args.maxpool_width,\n                                model_type = ModelType.CNN)\n\n        self.char_lowercase = args.char_lowercase\n\n        self.unsaved_modules = []\n\n        emb_matrix = pretrain.emb\n        self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))\n        
self.add_unsaved_module('elmo_model', elmo_model)\n        self.vocab_size = emb_matrix.shape[0]\n        self.embedding_dim = emb_matrix.shape[1]\n\n        self.add_unsaved_module('forward_charlm', charmodel_forward)\n        if charmodel_forward is not None:\n            tlogger.debug(\"Got forward char model of dimension {}\".format(charmodel_forward.hidden_dim()))\n            if not charmodel_forward.is_forward_lm:\n                raise ValueError(\"Got a backward charlm as a forward charlm!\")\n        self.add_unsaved_module('backward_charlm', charmodel_backward)\n        if charmodel_backward is not None:\n            tlogger.debug(\"Got backward char model of dimension {}\".format(charmodel_backward.hidden_dim()))\n            if charmodel_backward.is_forward_lm:\n                raise ValueError(\"Got a forward charlm as a backward charlm!\")\n\n        attach_bert_model(self, bert_model, bert_tokenizer, self.config.use_peft, force_bert_saved)\n\n        # The Pretrain has PAD and UNK already (indices 0 and 1), but we\n        # possibly want to train UNK while freezing the rest of the embedding\n        # note that the /10.0 operation has to be inside nn.Parameter unless\n        # you want to spend a long time debugging this\n        self.unk = nn.Parameter(torch.randn(self.embedding_dim) / np.sqrt(self.embedding_dim) / 10.0)\n\n        # replacing NBSP picks up a whole bunch of words for VI\n        self.vocab_map = { word.replace('\\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }\n\n        if self.config.extra_wordvec_method is not ExtraVectors.NONE:\n            if not extra_vocab:\n                raise ValueError(\"Should have had extra_vocab set for extra_wordvec_method {}\".format(self.config.extra_wordvec_method))\n            if not args.extra_wordvec_dim:\n                self.config.extra_wordvec_dim = self.embedding_dim\n            if self.config.extra_wordvec_method is ExtraVectors.SUM:\n                if 
self.config.extra_wordvec_dim != self.embedding_dim:\n                    raise ValueError(\"extra_wordvec_dim must equal embedding_dim for {}\".format(self.config.extra_wordvec_method))\n\n            self.extra_vocab = list(extra_vocab)\n            self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) }\n            # TODO: possibly add regularization specifically on the extra embedding?\n            # note: it looks like a bug that this doesn't add UNK or PAD, but actually\n            # those are expected to already be the first two entries\n            self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab),\n                                                embedding_dim = self.config.extra_wordvec_dim,\n                                                max_norm = self.config.extra_wordvec_max_norm,\n                                                padding_idx = 0)\n            tlogger.debug(\"Extra embedding size: {}\".format(self.extra_embedding.weight.shape))\n        else:\n            self.extra_vocab = None\n            self.extra_vocab_map = None\n            self.config.extra_wordvec_dim = 0\n            self.extra_embedding = None\n\n        # Pytorch is \"aware\" of the existence of the nn.Modules inside\n        # an nn.ModuleList in terms of parameters() etc\n        if self.config.extra_wordvec_method is ExtraVectors.NONE:\n            total_embedding_dim = self.embedding_dim\n        elif self.config.extra_wordvec_method is ExtraVectors.SUM:\n            total_embedding_dim = self.embedding_dim\n        elif self.config.extra_wordvec_method is ExtraVectors.CONCAT:\n            total_embedding_dim = self.embedding_dim + self.config.extra_wordvec_dim\n        else:\n            raise ValueError(\"unable to handle {}\".format(self.config.extra_wordvec_method))\n\n        if charmodel_forward is not None:\n            if args.charlm_projection:\n                self.charmodel_forward_projection = 
nn.Linear(charmodel_forward.hidden_dim(), args.charlm_projection)\n                total_embedding_dim += args.charlm_projection\n            else:\n                self.charmodel_forward_projection = None\n                total_embedding_dim += charmodel_forward.hidden_dim()\n\n        if charmodel_backward is not None:\n            if args.charlm_projection:\n                self.charmodel_backward_projection = nn.Linear(charmodel_backward.hidden_dim(), args.charlm_projection)\n                total_embedding_dim += args.charlm_projection\n            else:\n                self.charmodel_backward_projection = None\n                total_embedding_dim += charmodel_backward.hidden_dim()\n\n        if self.config.use_elmo:\n            if elmo_model is None:\n                raise ValueError(\"Model requires elmo, but elmo_model not passed in\")\n            elmo_dim = elmo_model.sents2elmo([[\"Test\"]])[0].shape[1]\n\n            # this mapping will combine 3 layers of elmo to 1 layer of features\n            self.elmo_combine_layers = nn.Linear(in_features=3, out_features=1, bias=False)\n            if self.config.elmo_projection:\n                self.elmo_projection = nn.Linear(in_features=elmo_dim, out_features=self.config.elmo_projection)\n                total_embedding_dim = total_embedding_dim + self.config.elmo_projection\n            else:\n                total_embedding_dim = total_embedding_dim + elmo_dim\n\n        if bert_model is not None:\n            if self.config.bert_hidden_layers:\n                # The average will be offset by 1/N so that the default zeros\n                # represent an average of the N layers\n                if self.config.bert_hidden_layers > bert_model.config.num_hidden_layers:\n                    # limit ourselves to the number of layers actually available\n                    # note that we can +1 because of the initial embedding layer\n                    self.config.bert_hidden_layers = 
bert_model.config.num_hidden_layers + 1\n                self.bert_layer_mix = nn.Linear(self.config.bert_hidden_layers, 1, bias=False)\n                nn.init.zeros_(self.bert_layer_mix.weight)\n            else:\n                # an average of layers 2, 3, 4 will be used\n                # (for historic reasons)\n                self.bert_layer_mix = None\n\n            if bert_tokenizer is None:\n                raise ValueError(\"Cannot have a bert model without a tokenizer\")\n            self.bert_dim = self.bert_model.config.hidden_size\n            total_embedding_dim += self.bert_dim\n\n        if self.config.bilstm:\n            conv_input_dim = self.config.bilstm_hidden_dim * 2\n            self.bilstm = nn.LSTM(batch_first=True,\n                                  input_size=total_embedding_dim,\n                                  hidden_size=self.config.bilstm_hidden_dim,\n                                  num_layers=2,\n                                  bidirectional=True,\n                                  dropout=0.2)\n        else:\n            conv_input_dim = total_embedding_dim\n            self.bilstm = None\n\n        self.fc_input_size = 0\n        self.conv_layers = nn.ModuleList()\n        self.max_window = 0\n        for filter_idx, filter_size in enumerate(self.config.filter_sizes):\n            if isinstance(filter_size, int):\n                self.max_window = max(self.max_window, filter_size)\n                if isinstance(self.config.filter_channels, int):\n                    filter_channels = self.config.filter_channels\n                else:\n                    filter_channels = self.config.filter_channels[filter_idx]\n                fc_delta = filter_channels // self.config.maxpool_width\n                tlogger.debug(\"Adding full width filter %d.  
Output channels: %d -> %d\", filter_size, filter_channels, fc_delta)\n                self.fc_input_size += fc_delta\n                self.conv_layers.append(nn.Conv2d(in_channels=1,\n                                                  out_channels=filter_channels,\n                                                  kernel_size=(filter_size, conv_input_dim)))\n            elif isinstance(filter_size, tuple) and len(filter_size) == 2:\n                filter_height, filter_width = filter_size\n                self.max_window = max(self.max_window, filter_width)\n                if isinstance(self.config.filter_channels, int):\n                    filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))\n                else:\n                    filter_channels = self.config.filter_channels[filter_idx]\n                fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width\n                tlogger.debug(\"Adding filter %s.  
Output channels: %d -> %d\", filter_size, filter_channels, fc_delta)\n                self.fc_input_size += fc_delta\n                self.conv_layers.append(nn.Conv2d(in_channels=1,\n                                                  out_channels=filter_channels,\n                                                  stride=(1, filter_width),\n                                                  kernel_size=(filter_height, filter_width)))\n            else:\n                raise ValueError(\"Expected int or 2d tuple for conv size\")\n\n        tlogger.debug(\"Input dim to FC layers: %d\", self.fc_input_size)\n        self.fc_layers = build_output_layers(self.fc_input_size, self.config.fc_shapes, self.config.num_classes)\n\n        self.dropout = nn.Dropout(self.config.dropout)\n\n    def add_unsaved_module(self, name, module):\n        self.unsaved_modules += [name]\n        setattr(self, name, module)\n\n        if module is not None and (name in ('forward_charlm', 'backward_charlm') or\n                                   (name == 'bert_model' and not self.config.use_peft)):\n            # if we are using peft, we should not save the transformer directly\n            # instead, the peft parameters only will be saved later\n            for _, parameter in module.named_parameters():\n                parameter.requires_grad = False\n\n    def is_unsaved_module(self, name):\n        return name.split('.')[0] in self.unsaved_modules\n\n    def log_configuration(self):\n        \"\"\"\n        Log some essential information about the model configuration to the training logger\n        \"\"\"\n        tlogger.info(\"Filter sizes: %s\" % str(self.config.filter_sizes))\n        tlogger.info(\"Filter channels: %s\" % str(self.config.filter_channels))\n        tlogger.info(\"Intermediate layers: %s\" % str(self.config.fc_shapes))\n\n    def log_norms(self):\n        lines = [\"NORMS FOR MODEL PARAMTERS\"]\n        for name, param in self.named_parameters():\n            if 
param.requires_grad and name.split(\".\")[0] not in ('forward_charlm', 'backward_charlm'):\n                lines.append(\"%s %.6g\" % (name, torch.norm(param).item()))\n        logger.info(\"\\n\".join(lines))\n\n    def build_char_reps(self, inputs, max_phrase_len, charlm, projection, begin_paddings, device):\n        char_reps = charlm.build_char_representation(inputs)\n        if projection is not None:\n            char_reps = [projection(x) for x in char_reps]\n        char_inputs = torch.zeros((len(inputs), max_phrase_len, char_reps[0].shape[-1]), device=device)\n        for idx, rep in enumerate(char_reps):\n            start = begin_paddings[idx]\n            end = start + rep.shape[0]\n            char_inputs[idx, start:end, :] = rep\n        return char_inputs\n\n    def extract_bert_embeddings(self, inputs, max_phrase_len, begin_paddings, device):\n        bert_embeddings = extract_bert_embeddings(self.config.bert_model, self.bert_tokenizer, self.bert_model, inputs, device,\n                                                  keep_endpoints=False,\n                                                  num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,\n                                                  detach=not self.config.bert_finetune,\n                                                  peft_name=self.peft_name)\n        if self.bert_layer_mix is not None:\n            # add the average so that the default behavior is to\n            # take an average of the N layers, and anything else\n            # other than that needs to be learned\n            bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]\n        bert_inputs = torch.zeros((len(inputs), max_phrase_len, bert_embeddings[0].shape[-1]), device=device)\n        for idx, rep in enumerate(bert_embeddings):\n            start = begin_paddings[idx]\n            end = start + 
rep.shape[0]\n            bert_inputs[idx, start:end, :] = rep\n        return bert_inputs\n\n    def forward(self, inputs):\n        # assume all pieces are on the same device\n        device = next(self.parameters()).device\n\n        vocab_map = self.vocab_map\n        def map_word(word):\n            idx = vocab_map.get(word, None)\n            if idx is not None:\n                return idx\n            if word[-1] == \"'\":\n                idx = vocab_map.get(word[:-1], None)\n                if idx is not None:\n                    return idx\n            return vocab_map.get(word.lower(), UNK_ID)\n\n        inputs = [x.text if isinstance(x, SentimentDatum) else x for x in inputs]\n        # we will pad each phrase so either it matches the longest\n        # conv or the longest phrase in the input, whichever is longer\n        max_phrase_len = max(len(x) for x in inputs)\n        if self.max_window > max_phrase_len:\n            max_phrase_len = self.max_window\n\n        batch_indices = []\n        batch_unknowns = []\n        extra_batch_indices = []\n        begin_paddings = []\n        end_paddings = []\n\n        elmo_batch_words = []\n\n        for phrase in inputs:\n            # we use random at training time to try to learn different\n            # positions of padding.  
at test time, though, we want to\n            # have consistent results, so we set that to 0 begin_pad\n            if self.training:\n                begin_pad_width = random.randint(0, max_phrase_len - len(phrase))\n            else:\n                begin_pad_width = 0\n            end_pad_width = max_phrase_len - begin_pad_width - len(phrase)\n\n            begin_paddings.append(begin_pad_width)\n            end_paddings.append(end_pad_width)\n\n            # the initial lists are the length of the begin padding\n            sentence_indices = [PAD_ID] * begin_pad_width\n            sentence_indices.extend([map_word(x) for x in phrase])\n            sentence_indices.extend([PAD_ID] * end_pad_width)\n\n            # the \"unknowns\" will be the locations of the unknown words.\n            # these locations will get the specially trained unknown vector\n            # TODO: split UNK based on part of speech?  might be an interesting experiment\n            sentence_unknowns = [idx for idx, word in enumerate(sentence_indices) if word == UNK_ID]\n\n            batch_indices.append(sentence_indices)\n            batch_unknowns.append(sentence_unknowns)\n\n            if self.extra_vocab:\n                extra_sentence_indices = [PAD_ID] * begin_pad_width\n                for word in phrase:\n                    if word in self.extra_vocab_map:\n                        # the extra vocab is initialized from the\n                        # words in the training set, which means there\n                        # would be no unknown words.  
to occasionally\n                        # train the extra vocab's unknown words, we\n                        # replace 1% of the words with UNK\n                        # we don't do that for the original embedding\n                        # on the assumption that there may be some\n                        # unknown words in the training set anyway\n                        # TODO: maybe train unk for the original embedding?\n                        if self.training and random.random() < 0.01:\n                            extra_sentence_indices.append(UNK_ID)\n                        else:\n                            extra_sentence_indices.append(self.extra_vocab_map[word])\n                    else:\n                        extra_sentence_indices.append(UNK_ID)\n                extra_sentence_indices.extend([PAD_ID] * end_pad_width)\n                extra_batch_indices.append(extra_sentence_indices)\n\n            if self.config.use_elmo:\n                elmo_phrase_words = [\"\"] * begin_pad_width\n                for word in phrase:\n                    elmo_phrase_words.append(word)\n                elmo_phrase_words.extend([\"\"] * end_pad_width)\n                elmo_batch_words.append(elmo_phrase_words)\n\n        # creating a single large list with all the indices lets us\n        # create a single tensor, which is much faster than creating\n        # many tiny tensors\n        # we can convert this to the input to the CNN\n        # it is padded at one or both ends so that it is now num_phrases x max_len x emb_size\n        # there are two ways in which this padding is suboptimal\n        # the first is that for short sentences, smaller windows will\n        #   be padded to the point that some windows are entirely pad\n        # the second is that a sentence S will have more or less padding\n        #   depending on what other sentences are in its batch\n        # we assume these effects are pretty minimal\n        batch_indices = 
torch.tensor(batch_indices, requires_grad=False, device=device)\n        input_vectors = self.embedding(batch_indices)\n        # we use the random unk so that we are not necessarily\n        # learning to match 0s for unk\n        for phrase_num, sentence_unknowns in enumerate(batch_unknowns):\n            input_vectors[phrase_num][sentence_unknowns] = self.unk\n\n        if self.extra_vocab:\n            extra_batch_indices = torch.tensor(extra_batch_indices, requires_grad=False, device=device)\n            extra_input_vectors = self.extra_embedding(extra_batch_indices)\n            if self.config.extra_wordvec_method is ExtraVectors.CONCAT:\n                all_inputs = [input_vectors, extra_input_vectors]\n            elif self.config.extra_wordvec_method is ExtraVectors.SUM:\n                all_inputs = [input_vectors + extra_input_vectors]\n            else:\n                raise ValueError(\"unable to handle {}\".format(self.config.extra_wordvec_method))\n        else:\n            all_inputs = [input_vectors]\n\n        if self.forward_charlm is not None:\n            char_reps_forward = self.build_char_reps(inputs, max_phrase_len, self.forward_charlm, self.charmodel_forward_projection, begin_paddings, device)\n            all_inputs.append(char_reps_forward)\n\n        if self.backward_charlm is not None:\n            char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)\n            all_inputs.append(char_reps_backward)\n\n        if self.config.use_elmo:\n            # this will be N arrays of 3xMx1024 where M is the number of words\n            # and N is the number of sentences (and 1024 is actually the number of weights)\n            elmo_arrays = self.elmo_model.sents2elmo(elmo_batch_words, output_layer=-2)\n            elmo_tensors = [torch.tensor(x).to(device=device) for x in elmo_arrays]\n            # elmo_tensor will now be Nx3xMx1024\n            
elmo_tensor = torch.stack(elmo_tensors)\n            # Nx1024xMx3\n            elmo_tensor = torch.transpose(elmo_tensor, 1, 3)\n            # NxMx1024x3\n            elmo_tensor = torch.transpose(elmo_tensor, 1, 2)\n            # NxMx1024x1\n            elmo_tensor = self.elmo_combine_layers(elmo_tensor)\n            # NxMx1024\n            elmo_tensor = elmo_tensor.squeeze(3)\n            if self.config.elmo_projection:\n                elmo_tensor = self.elmo_projection(elmo_tensor)\n            all_inputs.append(elmo_tensor)\n\n        if self.bert_model is not None:\n            bert_embeddings = self.extract_bert_embeddings(inputs, max_phrase_len, begin_paddings, device)\n            all_inputs.append(bert_embeddings)\n\n        # still works even if there's just one item\n        input_vectors = torch.cat(all_inputs, dim=2)\n\n        if self.config.bilstm:\n            input_vectors, _ = self.bilstm(self.dropout(input_vectors))\n\n        # reshape to fit the input tensors\n        x = input_vectors.unsqueeze(1)\n\n        conv_outs = []\n        for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):\n            if isinstance(filter_size, int):\n                conv_out = self.dropout(F.relu(conv(x).squeeze(3)))\n                conv_outs.append(conv_out)\n            else:\n                conv_out = conv(x).transpose(2, 3).flatten(1, 2)\n                conv_out = self.dropout(F.relu(conv_out))\n                conv_outs.append(conv_out)\n        pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]\n        pooled = torch.cat(pool_outs, dim=1)\n\n        previous_layer = pooled\n        for fc in self.fc_layers[:-1]:\n            previous_layer = self.dropout(F.relu(fc(previous_layer)))\n        out = self.fc_layers[-1](previous_layer)\n        # note that we return the raw logits rather than use a softmax\n        # 
https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4\n        return out\n\n    def get_params(self, skip_modules=True):\n        model_state = self.state_dict()\n        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file\n        if skip_modules:\n            skipped = [k for k in model_state.keys() if self.is_unsaved_module(k)]\n            for k in skipped:\n                del model_state[k]\n\n        config = dataclasses.asdict(self.config)\n        config['wordvec_type'] = config['wordvec_type'].name\n        config['extra_wordvec_method'] = config['extra_wordvec_method'].name\n        config['model_type'] = config['model_type'].name\n\n        params = {\n            'model':        model_state,\n            'config':       config,\n            'labels':       self.labels,\n            'extra_vocab':  self.extra_vocab,\n        }\n        if self.config.use_peft:\n            # Hide import so that peft dependency is optional\n            from peft import get_peft_model_state_dict\n            params[\"bert_lora\"] = get_peft_model_state_dict(self.bert_model, adapter_name=self.peft_name)\n        return params\n\n    def preprocess_data(self, sentences):\n        sentences = [data.update_text(s, self.config.wordvec_type) for s in sentences]\n        return sentences\n\n    def extract_sentences(self, doc):\n        # TODO: tokens or words better here?\n        return [[token.text for token in sentence.tokens] for sentence in doc.sentences]\n"
  },
  {
    "path": "stanza/models/classifiers/config.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Union\n\n# TODO: perhaps put the enums in this file\nfrom stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType\n\n@dataclass\nclass CNNConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods\n        filter_channels: Union[int, tuple]\n        filter_sizes: tuple\n        fc_shapes: tuple\n        dropout: float\n        num_classes: int\n        wordvec_type: WVType\n        extra_wordvec_method: ExtraVectors\n        extra_wordvec_dim: int\n        extra_wordvec_max_norm: float\n        char_lowercase: bool\n        charlm_projection: int\n        has_charlm_forward: bool\n        has_charlm_backward: bool\n\n        use_elmo: bool\n        elmo_projection: int\n\n        bert_model: str\n        bert_finetune: bool\n        bert_hidden_layers: int\n        force_bert_saved: bool\n\n        use_peft: bool\n        lora_rank: int\n        lora_alpha: float\n        lora_dropout: float\n        lora_modules_to_save: List\n        lora_target_modules: List\n\n        bilstm: bool\n        bilstm_hidden_dim: int\n        maxpool_width: int\n        model_type: ModelType\n\n@dataclass\nclass ConstituencyConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods\n        fc_shapes: tuple\n        dropout: float\n        num_classes: int\n\n        constituency_backprop: bool\n        constituency_batch_norm: bool\n        constituency_node_attn: bool\n        constituency_top_layer: bool\n        constituency_all_words: bool\n\n        model_type: ModelType\n"
  },
  {
    "path": "stanza/models/classifiers/constituency_classifier.py",
    "content": "\"\"\"\nA classifier that uses a constituency parser for the base embeddings\n\"\"\"\n\nimport dataclasses\nimport logging\nfrom types import SimpleNamespace\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom stanza.models.classifiers.base_classifier import BaseClassifier\nfrom stanza.models.classifiers.config import ConstituencyConfig\nfrom stanza.models.classifiers.data import SentimentDatum\nfrom stanza.models.classifiers.utils import ModelType, build_output_layers\n\nfrom stanza.models.common.utils import split_into_batches, sort_with_indices, unsort\n\nlogger = logging.getLogger('stanza')\ntlogger = logging.getLogger('stanza.classifiers.trainer')\n\nclass ConstituencyClassifier(BaseClassifier):\n    def __init__(self, tree_embedding, labels, args):\n        super(ConstituencyClassifier, self).__init__()\n        self.labels = labels\n        # we build a separate config out of the args so that we can easily save it in torch\n        self.config = ConstituencyConfig(fc_shapes = args.fc_shapes,\n                                         dropout = args.dropout,\n                                         num_classes = len(labels),\n                                         constituency_backprop = args.constituency_backprop,\n                                         constituency_batch_norm = args.constituency_batch_norm,\n                                         constituency_node_attn = args.constituency_node_attn,\n                                         constituency_top_layer = args.constituency_top_layer,\n                                         constituency_all_words = args.constituency_all_words,\n                                         model_type = ModelType.CONSTITUENCY)\n\n        self.tree_embedding = tree_embedding\n\n        self.fc_layers = build_output_layers(self.tree_embedding.output_size, self.config.fc_shapes, self.config.num_classes)\n        self.dropout = nn.Dropout(self.config.dropout)\n\n    def 
is_unsaved_module(self, name):\n        return False\n\n    def log_configuration(self):\n        tlogger.info(\"Backprop into parser: %s\", self.config.constituency_backprop)\n        tlogger.info(\"Batch norm: %s\", self.config.constituency_batch_norm)\n        tlogger.info(\"Word positions used: %s\", \"all words\" if self.config.constituency_all_words else \"start and end words\")\n        tlogger.info(\"Attention over nodes: %s\", self.config.constituency_node_attn)\n        tlogger.info(\"Intermediate layers: %s\", self.config.fc_shapes)\n\n    def log_norms(self):\n        lines = [\"NORMS FOR MODEL PARAMETERS\"]\n        lines.extend([\"tree_embedding.\" + x for x in self.tree_embedding.get_norms()])\n        for name, param in self.named_parameters():\n            if param.requires_grad and not name.startswith('tree_embedding.'):\n                lines.append(\"%s %.6g\" % (name, torch.norm(param).item()))\n        logger.info(\"\\n\".join(lines))\n\n\n    def forward(self, inputs):\n        inputs = [x.constituency if isinstance(x, SentimentDatum) else x for x in inputs]\n\n        embedding = self.tree_embedding.embed_trees(inputs)\n        previous_layer = torch.stack([torch.max(x, dim=0)[0] for x in embedding], dim=0)\n        previous_layer = self.dropout(previous_layer)\n        for fc in self.fc_layers[:-1]:\n            # use GELU rather than ReLU, since ReLU causes many neurons to die\n            previous_layer = self.dropout(F.gelu(fc(previous_layer)))\n        out = self.fc_layers[-1](previous_layer)\n        return out\n\n    def get_params(self, skip_modules=True):\n        model_state = self.state_dict()\n        # skip all of the constituency parameters here -\n        # we will add them by calling the model's get_params()\n        skipped = [k for k in model_state.keys() if k.startswith(\"tree_embedding.\")]\n        for k in skipped:\n            del model_state[k]\n\n        tree_embedding = self.tree_embedding.get_params(skip_modules)\n\n        config = 
dataclasses.asdict(self.config)\n        config['model_type'] = config['model_type'].name\n\n        params = {\n            'model':           model_state,\n            'tree_embedding':  tree_embedding,\n            'config':          config,\n            'labels':          self.labels,\n        }\n        return params\n\n    def extract_sentences(self, doc):\n        return [sentence.constituency for sentence in doc.sentences]\n"
  },
  {
    "path": "stanza/models/classifiers/data.py",
    "content": "\"\"\"Stanza models classifier data functions.\"\"\"\n\nimport collections\nfrom collections import namedtuple\nimport logging\nimport json\nimport random\nimport re\nfrom typing import List\n\nfrom stanza.models.classifiers.utils import WVType\nfrom stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID\nimport stanza.models.constituency.tree_reader as tree_reader\n\nlogger = logging.getLogger('stanza')\n\nclass SentimentDatum:\n    def __init__(self, sentiment, text, constituency=None):\n        self.sentiment = sentiment\n        self.text = text\n        self.constituency = constituency\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if not isinstance(other, SentimentDatum):\n            return False\n        return self.sentiment == other.sentiment and self.text == other.text and self.constituency == other.constituency\n\n    def __str__(self):\n        return str(self._asdict())\n\n    def _asdict(self):\n        if self.constituency is None:\n            return {'sentiment': self.sentiment, 'text': self.text}\n        else:\n            return {'sentiment': self.sentiment, 'text': self.text, 'constituency': str(self.constituency)}\n\ndef update_text(sentence: List[str], wordvec_type: WVType) -> List[str]:\n    \"\"\"\n    Process a line of text (with tokenization provided as whitespace)\n    into a list of strings.\n    \"\"\"\n    # stanford sentiment dataset has a lot of random - and /\n    # remove those characters and flatten the newly created sublists into one list each time\n    sentence = [y for x in sentence for y in x.split(\"-\") if y]\n    sentence = [y for x in sentence for y in x.split(\"/\") if y]\n    sentence = [x.strip() for x in sentence]\n    sentence = [x for x in sentence if x]\n    if sentence == []:\n        # removed too much\n        sentence = [\"-\"]\n    # our current word vectors are all entirely lowercased\n    sentence = [word.lower() for word in sentence]\n    if 
wordvec_type == WVType.WORD2VEC:\n        return sentence\n    elif wordvec_type == WVType.GOOGLE:\n        new_sentence = []\n        for word in sentence:\n            if word != '0' and word != '1':\n                word = re.sub('[0-9]', '#', word)\n            new_sentence.append(word)\n        return new_sentence\n    elif wordvec_type == WVType.FASTTEXT:\n        return sentence\n    elif wordvec_type == WVType.OTHER:\n        return sentence\n    else:\n        raise ValueError(\"Unknown wordvec_type {}\".format(wordvec_type))\n\n\ndef read_dataset(dataset, wordvec_type: WVType, min_len: int) -> List[SentimentDatum]:\n    \"\"\"\n    Reads one or more comma-separated JSON files and returns a list of SentimentDatum,\n    each holding a label, a list of tokens, and (if present) a constituency tree\n    \"\"\"\n    lines = []\n    for filename in str(dataset).split(\",\"):\n        with open(filename, encoding=\"utf-8\") as fin:\n            new_lines = json.load(fin)\n        new_lines = [(str(x['sentiment']), x['text'], x.get('constituency', None)) for x in new_lines]\n        lines.extend(new_lines)\n    # TODO: maybe do this processing later, once the model is built.\n    # then move the processing into the model so we can use\n    # overloading to potentially make future model types\n    lines = [SentimentDatum(x[0], update_text(x[1], wordvec_type), tree_reader.read_trees(x[2])[0] if x[2] else None) for x in lines]\n    if min_len:\n        lines = [x for x in lines if len(x.text) >= min_len]\n    return lines\n\ndef dataset_labels(dataset):\n    \"\"\"\n    Returns a sorted list of label names\n    \"\"\"\n    labels = set([x.sentiment for x in dataset])\n    if all(re.match(\"^[0-9]+$\", label) for label in labels):\n        # if all of the labels are integers, sort numerically\n        # maybe not super important, but it would be nicer than having\n        # 10 before 2\n        labels = [str(x) for x in sorted(map(int, list(labels)))]\n    else:\n        labels = sorted(list(labels))\n    return labels\n\ndef dataset_vocab(dataset):\n    
vocab = set()\n    for line in dataset:\n        for word in line.text:\n            vocab.add(word)\n    vocab = [PAD, UNK] + list(vocab)\n    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:\n        raise ValueError(\"Unexpected values for PAD and UNK!\")\n    return vocab\n\ndef sort_dataset_by_len(dataset, keep_index=False):\n    \"\"\"\n    returns a dict mapping length -> list of items of that length\n\n    an OrderedDict is used so that the mapping is sorted from smallest to largest\n    \"\"\"\n    sorted_dataset = collections.OrderedDict()\n    lengths = sorted(list(set(len(x.text) for x in dataset)))\n    for l in lengths:\n        sorted_dataset[l] = []\n    for item_idx, item in enumerate(dataset):\n        if keep_index:\n            sorted_dataset[len(item.text)].append((item, item_idx))\n        else:\n            sorted_dataset[len(item.text)].append(item)\n    return sorted_dataset\n\ndef shuffle_dataset(sorted_dataset, batch_size, batch_single_item):\n    \"\"\"\n    Given a dataset sorted by len, sorts within each length to make\n    chunks of roughly the same size.  
Returns all items as a single list.\n    \"\"\"\n    dataset = []\n    for l in sorted_dataset.keys():\n        items = list(sorted_dataset[l])\n        random.shuffle(items)\n        dataset.extend(items)\n    batches = []\n    next_batch = []\n    for item in dataset:\n        if batch_single_item > 0 and len(item.text) >= batch_single_item:\n            batches.append([item])\n        else:\n            next_batch.append(item)\n            if len(next_batch) >= batch_size:\n                batches.append(next_batch)\n                next_batch = []\n    if len(next_batch) > 0:\n        batches.append(next_batch)\n    random.shuffle(batches)\n    return batches\n\n\ndef check_labels(labels, dataset):\n    \"\"\"\n    Check that all of the labels in the dataset are in the known labels.\n\n    Actually, unknown labels could be acceptable if we just treat the model as always wrong.\n    However, this is a good sanity check to make sure the datasets match\n    \"\"\"\n    new_labels = dataset_labels(dataset)\n    not_found = [i for i in new_labels if i not in labels]\n    if not_found:\n        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(not_found))\n\n"
  },
  {
    "path": "stanza/models/classifiers/iterate_test.py",
    "content": "\"\"\"Iterate test.\"\"\"\nimport argparse\nimport glob\nimport logging\n\nimport stanza.models.classifier as classifier\nimport stanza.models.classifiers.cnn_classifier as cnn_classifier\nimport stanza.models.classifiers.data as data\nfrom stanza.models.common import utils\n\nfrom stanza.utils.confusion import format_confusion, confusion_to_accuracy\n\n\"\"\"\nA script for running the same test file on several different classifiers.\n\nFor each one, it will output the accuracy and, if possible, the confusion matrix.\n\nIncludes the arguments for pretrain, which allows for passing in a\ndifferent directory for the pretrain file.\n\nExample command line:\n  python3 -m stanza.models.classifiers.iterate_test  --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt --glob \"saved_models/classifier/FC41_3class_en_ewt_FS*ACC66*\"\n\"\"\"\n\nlogger = logging.getLogger('stanza')\n\n\ndef parse_args():\n    \"\"\"Add and parse arguments.\"\"\"\n    parser = classifier.build_argparse()\n\n    parser.add_argument('--glob', type=str, default='saved_models/classifier/*classifier*pt', help='Model file(s) to test.')\n\n    args = parser.parse_args()\n    return args\n\nargs = parse_args()\nseed = utils.set_random_seed(args.seed)\n\nmodel_files = []\nfor glob_piece in args.glob.split():\n    model_files.extend(glob.glob(glob_piece))\nmodel_files = sorted(set(model_files))\n\ntest_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)\nlogger.info(\"Using test set: %s\" % args.test_file)\n\ndevice = None\nfor load_name in model_files:\n    args.load_name = load_name\n    model = classifier.load_model(args)\n\n    logger.info(\"Testing %s\" % load_name)\n    if device is None:\n        device = next(model.parameters()).device\n        logger.info(\"Current device: %s\" % device)\n\n    labels = model.labels\n    data.check_labels(labels, test_set)\n\n    confusion = classifier.confusion_dataset(model, 
test_set, device=device)\n    correct, total = confusion_to_accuracy(confusion)\n    logger.info(\"  Results: %d correct of %d examples.  Accuracy: %f\" % (correct, total, correct / total))\n    logger.info(\"Confusion matrix:\\n{}\".format(format_confusion(confusion, model.labels)))\n"
  },
  {
    "path": "stanza/models/classifiers/trainer.py",
    "content": "\"\"\"\nOrganizes the model itself and its optimizer in one place\n\nSaving the optimizer allows for easy restarting of training\n\"\"\"\n\nimport logging\nimport os\nimport torch\nimport torch.optim as optim\nfrom types import SimpleNamespace\n\nimport stanza.models.classifiers.data as data\nimport stanza.models.classifiers.cnn_classifier as cnn_classifier\nimport stanza.models.classifiers.constituency_classifier as constituency_classifier\nfrom stanza.models.classifiers.config import CNNConfig, ConstituencyConfig\nfrom stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors\nfrom stanza.models.common import utils\nfrom stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain\nfrom stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper\nfrom stanza.models.common.pretrain import Pretrain\nfrom stanza.models.common.utils import get_split_optimizer\nfrom stanza.models.constituency.tree_embedding import TreeEmbedding\n\nfrom pickle import UnpicklingError\nimport warnings\n\nlogger = logging.getLogger('stanza')\n\nclass Trainer:\n    \"\"\"\n    Stores a classifier model and its optimizer\n    \"\"\"\n\n    def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):\n        self.model = model\n        self.optimizer = optimizer\n        # we keep track of position in the learning so that we can\n        # checkpoint & restart if needed without restarting the epoch count\n        self.epochs_trained = epochs_trained\n        self.global_step = global_step\n        # save the best dev score so that when reloading a checkpoint\n        # of a model, we know how far we got\n        self.best_score = best_score\n\n    def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):\n        \"\"\"\n        save the current model, optimizer, and other state to filename\n\n        epochs_trained can be passed as a parameter to handle saving at the end of an 
epoch\n        \"\"\"\n        if epochs_trained is None:\n            epochs_trained = self.epochs_trained\n        save_dir = os.path.split(filename)[0]\n        os.makedirs(save_dir, exist_ok=True)\n        model_params = self.model.get_params(skip_modules)\n        params = {\n            'params':         model_params,\n            'epochs_trained': epochs_trained,\n            'global_step':    self.global_step,\n            'best_score':     self.best_score,\n        }\n        if save_optimizer and self.optimizer is not None:\n            params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}\n        torch.save(params, filename, _use_new_zipfile_serialization=False)\n        logger.info(\"Model saved to {}\".format(filename))\n\n    @staticmethod\n    def load(filename, args, foundation_cache=None, load_optimizer=False):\n        if not os.path.exists(filename):\n            if args.save_dir is None:\n                raise FileNotFoundError(\"Cannot find model in {} and args.save_dir is None\".format(filename))\n            elif os.path.exists(os.path.join(args.save_dir, filename)):\n                filename = os.path.join(args.save_dir, filename)\n            else:\n                raise FileNotFoundError(\"Cannot find model in {} or in {}\".format(filename, os.path.join(args.save_dir, filename)))\n        try:\n            # TODO: can remove the try/except once the new version is out\n            #checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n            try:\n                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n            except UnpicklingError as e:\n                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)\n                warnings.warn(\"The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config.  
This version of Stanza can support reading both the new and the old formats.  Future versions will only allow loading with weights_only=True.  Please resave the pretrained classifier using this version ASAP.\")\n        except BaseException:\n            logger.exception(\"Cannot load model from {}\".format(filename))\n            raise\n        logger.debug(\"Loaded model {}\".format(filename))\n\n        epochs_trained = checkpoint.get('epochs_trained', 0)\n        global_step = checkpoint.get('global_step', 0)\n        best_score = checkpoint.get('best_score', None)\n\n        # TODO: can remove this block once all models are retrained\n        if 'params' not in checkpoint:\n            model_params = {\n                'model':        checkpoint['model'],\n                'config':       checkpoint['config'],\n                'labels':       checkpoint['labels'],\n                'extra_vocab':  checkpoint['extra_vocab'],\n            }\n        else:\n            model_params = checkpoint['params']\n        # TODO: this can be removed once v1.10.0 is out\n        if isinstance(model_params['config'], SimpleNamespace):\n            model_params['config'] = vars(model_params['config'])\n        # TODO: these isinstance can go away after 1.10.0\n        model_type = model_params['config']['model_type']\n        if isinstance(model_type, str):\n            model_type = ModelType[model_type]\n            model_params['config']['model_type'] = model_type\n\n        if model_type == ModelType.CNN:\n            # TODO: these updates are only necessary during the\n            # transition to the @dataclass version of the config\n            # Once those are all saved, it is no longer necessary\n            # to patch existing models (since they will all be patched)\n            if 'has_charlm_forward' not in model_params['config']:\n                model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None\n            if 'has_charlm_backward' 
not in model_params['config']:\n                model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None\n            for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',\n                            'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:\n                model_params['config'][argname] = model_params['config'].get(argname, None)\n            # TODO: these isinstance can go away after 1.10.0\n            if isinstance(model_params['config']['wordvec_type'], str):\n                model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]\n            if isinstance(model_params['config']['extra_wordvec_method'], str):\n                model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]\n            model_params['config'] = CNNConfig(**model_params['config'])\n\n            pretrain = Trainer.load_pretrain(args, foundation_cache)\n            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None\n\n            if model_params['config'].has_charlm_forward:\n                charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)\n            else:\n                charmodel_forward = None\n            if model_params['config'].has_charlm_backward:\n                charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)\n            else:\n                charmodel_backward = None\n\n            bert_model = model_params['config'].bert_model\n            # TODO: can get rid of the getattr after rebuilding all models\n            use_peft = getattr(model_params['config'], 'use_peft', False)\n            force_bert_saved = getattr(model_params['config'], 'force_bert_saved', False)\n            peft_name = None\n            if use_peft:\n                # if loading a peft model, we first load the base 
transformer\n                # the CNNClassifier code wraps the transformer in peft\n                # after creating the CNNClassifier with the peft wrapper,\n                # we *then* load the weights\n                bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, \"classifier\", foundation_cache)\n                bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(model_params['config']), logger, peft_name)\n            elif force_bert_saved:\n                bert_model, bert_tokenizer = load_bert(bert_model)\n            else:\n                bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)\n            model = cnn_classifier.CNNClassifier(pretrain=pretrain,\n                                                 extra_vocab=model_params['extra_vocab'],\n                                                 labels=model_params['labels'],\n                                                 charmodel_forward=charmodel_forward,\n                                                 charmodel_backward=charmodel_backward,\n                                                 elmo_model=elmo_model,\n                                                 bert_model=bert_model,\n                                                 bert_tokenizer=bert_tokenizer,\n                                                 force_bert_saved=force_bert_saved,\n                                                 peft_name=peft_name,\n                                                 args=model_params['config'])\n        elif model_type == ModelType.CONSTITUENCY:\n            # the constituency version doesn't have a peft feature yet\n            use_peft = False\n            pretrain_args = {\n                'wordvec_pretrain_file': args.wordvec_pretrain_file,\n                'charlm_forward_file': args.charlm_forward_file,\n                'charlm_backward_file': args.charlm_backward_file,\n            }\n            # TODO: integrate with peft 
for the constituency version\n            tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)\n            model_params['config'] = ConstituencyConfig(**model_params['config'])\n            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,\n                                                                   labels=model_params['labels'],\n                                                                   args=model_params['config'])\n        else:\n            raise ValueError(\"Unknown model type {}\".format(model_type))\n        model.load_state_dict(model_params['model'], strict=False)\n        model = model.to(args.device)\n\n        logger.debug(\"-- MODEL CONFIG --\")\n        for k in model.config.__dict__:\n            logger.debug(\"  --{}: {}\".format(k, model.config.__dict__[k]))\n\n        logger.debug(\"-- MODEL LABELS --\")\n        logger.debug(\"  {}\".format(\" \".join(model.labels)))\n\n        optimizer = None\n        if load_optimizer:\n            optimizer = Trainer.build_optimizer(model, args)\n            if checkpoint.get('optimizer_state_dict', None) is not None:\n                for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():\n                    optimizer[opt_name].load_state_dict(opt_state_dict)\n            else:\n                logger.info(\"Attempted to load optimizer to resume training, but optimizer not saved.  
Creating new optimizer\")\n\n        trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)\n\n        return trainer\n\n\n    def load_pretrain(args, foundation_cache):\n        if args.wordvec_pretrain_file:\n            pretrain_file = args.wordvec_pretrain_file\n        elif args.wordvec_type:\n            pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())\n        else:\n            raise RuntimeError(\"TODO: need to get the wv type back from get_wordvec_file\")\n\n        logger.debug(\"Looking for pretrained vectors in {}\".format(pretrain_file))\n        if os.path.exists(pretrain_file):\n            return load_pretrain(pretrain_file, foundation_cache)\n        elif args.wordvec_raw_file:\n            vec_file = args.wordvec_raw_file\n            logger.debug(\"Pretrain not found.  Looking in {}\".format(vec_file))\n        else:\n            vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())\n            logger.debug(\"Pretrain not found.  
Looking in {}\".format(vec_file))\n        pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)\n        logger.debug(\"Embedding shape: %s\" % str(pretrain.emb.shape))\n        return pretrain\n\n\n    @staticmethod\n    def build_new_model(args, train_set):\n        \"\"\"\n        Load pretrained pieces and then build a new model\n        \"\"\"\n        if train_set is None:\n            raise ValueError(\"Must have a train set to build a new model - needed for labels and delta word vectors\")\n\n        labels = data.dataset_labels(train_set)\n\n        if args.model_type == ModelType.CNN:\n            pretrain = Trainer.load_pretrain(args, foundation_cache=None)\n            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None\n            charmodel_forward = load_charlm(args.charlm_forward_file)\n            charmodel_backward = load_charlm(args.charlm_backward_file)\n            peft_name = None\n            bert_model, bert_tokenizer = load_bert(args.bert_model)\n\n            use_peft = getattr(args, \"use_peft\", False)\n            if use_peft:\n                peft_name = \"sentiment\"\n                bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)\n\n            extra_vocab = data.dataset_vocab(train_set)\n            force_bert_saved = args.bert_finetune\n            model = cnn_classifier.CNNClassifier(pretrain=pretrain,\n                                                 extra_vocab=extra_vocab,\n                                                 labels=labels,\n                                                 charmodel_forward=charmodel_forward,\n                                                 charmodel_backward=charmodel_backward,\n                                                 elmo_model=elmo_model,\n                                                 bert_model=bert_model,\n                                                 bert_tokenizer=bert_tokenizer,\n                   
                              force_bert_saved=force_bert_saved,\n                                                 peft_name=peft_name,\n                                                 args=args)\n            model = model.to(args.device)\n        elif args.model_type == ModelType.CONSTITUENCY:\n            # this passes flags such as \"constituency_backprop\" from\n            # the classifier to the TreeEmbedding as the \"backprop\" flag\n            parser_args = { x[len(\"constituency_\"):]: y for x, y in vars(args).items() if x.startswith(\"constituency_\") }\n            parser_args.update({\n                \"wordvec_pretrain_file\": args.wordvec_pretrain_file,\n                \"charlm_forward_file\": args.charlm_forward_file,\n                \"charlm_backward_file\": args.charlm_backward_file,\n                \"bert_model\": args.bert_model,\n                # we found that finetuning from the classifier output\n                # all the way to the bert layers caused the bert model\n                # to go astray\n                # could make this an option... 
but it is much less accurate\n                # with the Bert finetuning\n                # noting that the constituency parser itself works better\n                # after finetuning, of course\n                \"bert_finetune\": False,\n                \"stage1_bert_finetune\": False,\n            })\n            logger.info(\"Building constituency classifier using %s as the base model\" % args.constituency_model)\n            tree_embedding = TreeEmbedding.from_parser_file(parser_args)\n            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,\n                                                                   labels=labels,\n                                                                   args=args)\n            model = model.to(args.device)\n        else:\n            raise ValueError(\"Unhandled model type {}\".format(args.model_type))\n\n        optimizer = Trainer.build_optimizer(model, args)\n\n        return Trainer(model, optimizer)\n\n\n    @staticmethod\n    def build_optimizer(model, args):\n        return get_split_optimizer(args.optim.lower(), model, args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, bert_learning_rate=args.bert_learning_rate, bert_weight_decay=args.weight_decay * args.bert_weight_decay, is_peft=args.use_peft)\n"
  },
  {
    "path": "stanza/models/classifiers/utils.py",
    "content": "from enum import Enum\n\nfrom torch import nn\n\n\"\"\"\nDefines some methods which may occur in multiple model types\n\"\"\"\n# NLP machines:\n# word2vec are in\n# /u/nlp/data/stanfordnlp/model_production/stanfordnlp/extern_data/word2vec\n# google vectors are in\n# /scr/nlp/data/wordvectors/en/google/GoogleNews-vectors-negative300.txt\n\nclass WVType(Enum):\n    WORD2VEC = 1\n    GOOGLE = 2\n    FASTTEXT = 3\n    OTHER = 4\n\nclass ExtraVectors(Enum):\n    NONE = 1\n    CONCAT = 2\n    SUM = 3\n\nclass ModelType(Enum):\n    CNN = 1\n    CONSTITUENCY = 2\n\ndef build_output_layers(fc_input_size, fc_shapes, num_classes):\n    \"\"\"\n    Build a sequence of fully connected layers to go from the final conv layer to num_classes\n\n    Returns an nn.ModuleList\n    \"\"\"\n    fc_layers = []\n    previous_layer_size = fc_input_size\n    for shape in fc_shapes:\n        fc_layers.append(nn.Linear(previous_layer_size, shape))\n        previous_layer_size = shape\n    fc_layers.append(nn.Linear(previous_layer_size, num_classes))\n    return nn.ModuleList(fc_layers)\n"
  },
  {
    "path": "stanza/models/common/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/common/beam.py",
    "content": "from __future__ import division\nimport torch\n\nimport stanza.models.common.seq2seq_constant as constant\n\nr\"\"\"\n Adapted and modified from the OpenNMT project.\n\n Class for managing the internals of the beam search process.\n\n\n         hyp1-hyp1---hyp1 -hyp1\n                 \\             /\n         hyp2 \\-hyp2 /-hyp2hyp2\n                               /      \\\n         hyp3-hyp3---hyp3 -hyp3\n         ========================\n\n Takes care of beams, back pointers, and scores.\n\"\"\"\n\n\n# TORCH COMPATIBILITY\n#\n# Here we special case trunc division\n# torch < 1.8.0 has no rounding_mode='trunc' argument for torch.div\n# however, there were several versions in a row where // would loudly\n# proclaim it was buggy, and users complained about that\n# this hopefully maintains compatibility for torch\ntry:\n    a = torch.tensor([1.])\n    b = torch.tensor([2.])\n    c = torch.div(a, b, rounding_mode='trunc')\n    def trunc_division(a, b):\n        return torch.div(a, b, rounding_mode='trunc')\nexcept TypeError:\n    def trunc_division(a, b):\n        return a // b\n\nclass Beam(object):\n    def __init__(self, size, device=None):\n        self.size = size\n        self.done = False\n\n        # The score for each translation on the beam.\n        self.scores = torch.zeros(size, dtype=torch.float32, device=device)\n        self.allScores = []\n\n        # The backpointers at each time-step.\n        self.prevKs = []\n\n        # The outputs at each time-step.\n        self.nextYs = [torch.zeros(size, dtype=torch.int64, device=device).fill_(constant.PAD_ID)]\n        self.nextYs[0][0] = constant.SOS_ID\n\n        # The copy indices for each time\n        self.copy = []\n\n    def get_current_state(self):\n        \"Get the outputs for the current timestep.\"\n        return self.nextYs[-1]\n\n    def get_current_origin(self):\n        \"Get the backpointers for the current timestep.\"\n        return self.prevKs[-1]\n\n    def 
advance(self, wordLk, copy_indices=None):\n        \"\"\"\n        Given prob over words for every last beam `wordLk` and optional copy\n        indices `copy_indices`: Compute and update the beam search.\n\n        Parameters:\n\n        * `wordLk`- probs of advancing from the last step (K x words)\n        * `copy_indices` - copy indices (K x ctx_len)\n\n        Returns: True if beam search is complete.\n        \"\"\"\n        if self.done:\n            return True\n        numWords = wordLk.size(1)\n\n        # Sum the previous scores.\n        if len(self.prevKs) > 0:\n            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)\n        else:\n            # first step, expand from the first position\n            beamLk = wordLk[0]\n\n        flatBeamLk = beamLk.view(-1)\n\n        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)\n        self.allScores.append(self.scores)\n        self.scores = bestScores\n\n        # bestScoresId is flattened beam x word array, so calculate which\n        # word and beam each score came from\n        # bestScoresId is the integer ids, and numWords is the integer length.\n        # Need to do integer division\n        prevK = trunc_division(bestScoresId, numWords)\n        self.prevKs.append(prevK)\n        self.nextYs.append(bestScoresId - prevK * numWords)\n        if copy_indices is not None:\n            self.copy.append(copy_indices.index_select(0, prevK))\n\n        # End condition is when top-of-beam is EOS.\n        if self.nextYs[-1][0] == constant.EOS_ID:\n            self.done = True\n            self.allScores.append(self.scores)\n\n        return self.done\n\n    def sort_best(self):\n        return torch.sort(self.scores, 0, True)\n\n    def get_best(self):\n        \"Get the score of the best in the beam.\"\n        # scores are sorted in descending order, so index 0 is the best\n        scores, ids = self.sort_best()\n        return scores[0], ids[0]\n\n    def get_hyp(self, k):\n        \"\"\"\n        Walk back to construct the full hypothesis.\n\n        
Parameters:\n\n             * `k` - the position in the beam to construct.\n\n         Returns: The hypothesis\n        \"\"\"\n        hyp = []\n        cpy = []\n        for j in range(len(self.prevKs) - 1, -1, -1):\n            hyp.append(self.nextYs[j+1][k])\n            if len(self.copy) > 0:\n                cpy.append(self.copy[j][k])\n            k = self.prevKs[j][k]\n\n        hyp = hyp[::-1]\n        cpy = cpy[::-1]\n        # postprocess: if cpy index is not -1, use cpy index instead of hyp word\n        for i,cidx in enumerate(cpy):\n            if cidx >= 0:\n                hyp[i] = -(cidx+1) # make index 1-based and flip it for token generation\n\n        return hyp\n"
  },
  {
    "path": "stanza/models/common/bert_embedding.py",
    "content": "import math\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence\n\nlogger = logging.getLogger('stanza')\n\nBERT_ARGS = {\n    \"vinai/phobert-base\": { \"use_fast\": True },\n    \"vinai/phobert-large\": { \"use_fast\": True },\n}\n\nclass TextTooLongError(ValueError):\n    \"\"\"\n    A text was too long for the underlying model (possibly BERT)\n    \"\"\"\n    def __init__(self, length, max_len, line_num, text):\n        super().__init__(\"Found a text of length %d (possibly after tokenizing).  Maximum handled length is %d  Error occurred at line %d\" % (length, max_len, line_num))\n        self.line_num = line_num\n        self.text = text\n\n\ndef update_max_length(model_name, tokenizer):\n    if model_name in ('hf-internal-testing/tiny-bert',\n                      'google/muril-base-cased',\n                      'google/muril-large-cased',\n                      'airesearch/wangchanberta-base-att-spm-uncased',\n                      'camembert/camembert-large',\n                      'hfl/chinese-electra-180g-large-discriminator',\n                      'hfl/chinese-macbert-large',\n                      'NYTK/electra-small-discriminator-hungarian'):\n        tokenizer.model_max_length = 512\n\ndef load_tokenizer(model_name, tokenizer_kwargs=None, local_files_only=False):\n    if model_name:\n        # note that use_fast is the default\n        try:\n            from transformers import AutoTokenizer\n        except ImportError:\n            raise ImportError(\"Please install transformers library for BERT support! 
Try `pip install transformers`.\")\n        bert_args = BERT_ARGS.get(model_name, dict())\n        if not model_name.startswith(\"vinai/phobert\"):\n            bert_args[\"add_prefix_space\"] = True\n        if tokenizer_kwargs:\n            bert_args.update(tokenizer_kwargs)\n        bert_args['local_files_only'] = local_files_only\n        bert_tokenizer = AutoTokenizer.from_pretrained(model_name, **bert_args)\n        update_max_length(model_name, bert_tokenizer)\n        if model_name == 'princeton-nlp/Sheared-LLaMA-1.3B':\n            bert_tokenizer.pad_token = bert_tokenizer.eos_token\n            logger.debug(\"Tokenizer does not have a pad_token - setting to %s (%s)\", bert_tokenizer.pad_token, bert_tokenizer.eos_token)\n        return bert_tokenizer\n    return None\n\ndef load_bert(model_name, tokenizer_kwargs=None, local_files_only=False):\n    if model_name:\n        # such as: \"vinai/phobert-base\"\n        try:\n            from transformers import AutoModel\n        except ImportError:\n            raise ImportError(\"Please install transformers library for BERT support! 
Try `pip install transformers`.\")\n        bert_model = AutoModel.from_pretrained(model_name, local_files_only=local_files_only)\n        bert_tokenizer = load_tokenizer(model_name, tokenizer_kwargs=tokenizer_kwargs, local_files_only=local_files_only)\n        return bert_model, bert_tokenizer\n    return None, None\n\ndef tokenize_manual(model_name, sent, tokenizer):\n    \"\"\"\n    Tokenize a sentence manually, using for checking long sentences and PHOBert.\n    \"\"\"\n    #replace \\xa0 or whatever the space character is by _ since PhoBERT expects _ between syllables\n    tokenized = [word.replace(\"\\xa0\",\"_\").replace(\" \", \"_\") for word in sent] if model_name.startswith(\"vinai/phobert\") else [word.replace(\"\\xa0\",\" \") for word in sent]\n\n    #concatenate to a sentence\n    sentence = ' '.join(tokenized)\n\n    #tokenize using AutoTokenizer PhoBERT\n    tokenized = tokenizer.tokenize(sentence)\n\n    #convert tokens to ids\n    sent_ids = tokenizer.convert_tokens_to_ids(tokenized)\n\n    #add start and end tokens to sent_ids\n    tokenized_sent = [tokenizer.bos_token_id] + sent_ids + [tokenizer.eos_token_id]\n\n    return tokenized, tokenized_sent\n\ndef filter_data(model_name, data, tokenizer = None, log_level=logging.DEBUG):\n    \"\"\"\n    Filter out the (NER, POS) data that is too long for BERT model.\n    \"\"\"\n    if tokenizer is None:\n        tokenizer = load_tokenizer(model_name) \n    filtered_data = []\n    #eliminate all the sentences that are too long for bert model\n    for sent in data:\n        sentence = [word if isinstance(word, str) else word[0] for word in sent]\n        _, tokenized_sent = tokenize_manual(model_name, sentence, tokenizer)\n        \n        if len(tokenized_sent) > tokenizer.model_max_length - 2:\n            continue\n\n        filtered_data.append(sent)\n\n    logger.log(log_level, \"Eliminated %d of %d datapoints because their length is over maximum size of BERT model.\", (len(data)-len(filtered_data)), 
len(data))\n    \n    return filtered_data\n\ndef needs_length_filter(model_name):\n    \"\"\"\n    TODO: we were lazy and didn't implement any form of length fudging for models other than bert/roberta/electra\n    \"\"\"\n    if 'bart' in model_name or 'xlnet' in model_name:\n        return True\n    if model_name.startswith(\"vinai/phobert\"):\n        return True\n    return False\n\ndef cloned_feature(feature, num_layers, detach=True):\n    \"\"\"\n    Clone & detach the feature, keeping the last N layers (or averaging -2,-3,-4 if not specified)\n\n    averaging 3 of the last 4 layers worked well for non-VI languages\n    \"\"\"\n    # in most cases, need to call with features.hidden_states\n    # bartpho is different - it has features.decoder_hidden_states\n    # feature[2] is the same for bert, but it didn't work for\n    # older versions of transformers for xlnet\n    if num_layers is None:\n        feature = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4\n    else:\n        feature = torch.stack(feature[-num_layers:], axis=3)\n    if detach:\n        return feature.clone().detach()\n    else:\n        return feature\n\ndef extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):\n    \"\"\"\n    Handles vi-bart.  
May need testing before using on other bart\n\n    https://github.com/VinAIResearch/BARTpho\n    \"\"\"\n    processed = [] # final product, returns the list of list of word representation\n\n    sentences = [\" \".join([word.replace(\" \", \"_\") for word in sentence]) for sentence in data]\n    tokenized = tokenizer(sentences, return_tensors='pt', padding=True, return_attention_mask=True)\n    input_ids = tokenized['input_ids'].to(device)\n    attention_mask = tokenized['attention_mask'].to(device)\n\n    for i in range(int(math.ceil(len(sentences)/128))):\n        start_sentence = i * 128\n        end_sentence = min(start_sentence + 128, len(sentences))\n        # slice into per-batch tensors so the full tensors are still intact for later batches\n        batch_ids = input_ids[start_sentence:end_sentence]\n        batch_mask = attention_mask[start_sentence:end_sentence]\n\n        if detach:\n            with torch.no_grad():\n                features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)\n                features = cloned_feature(features.decoder_hidden_states, num_layers, detach)\n        else:\n            features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)\n            features = cloned_feature(features.decoder_hidden_states, num_layers, detach)\n\n        # pair each feature with the sentence from the same batch\n        for feature, sentence in zip(features, data[start_sentence:end_sentence]):\n            # +2 for the endpoints\n            feature = feature[:len(sentence)+2]\n            if not keep_endpoints:\n                feature = feature[1:-1]\n            processed.append(feature)\n\n    return processed\n\ndef extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):\n    \"\"\"\n    Extract transformer embeddings using a method specifically for phobert\n\n    Since phobert doesn't have the is_split_into_words / tokenized.word_ids(batch_index=0)\n    capability, we instead look for @@ to denote a continued token.\n    data: list of list of string (the text tokens)\n    \"\"\"\n    processed = [] # final 
product, returns the list of list of word representation\n    tokenized_sents = [] # list of sentences, each is a torch tensor with start and end token\n    list_tokenized = [] # list of tokenized sentences from phobert\n    for idx, sent in enumerate(data):\n\n        tokenized, tokenized_sent = tokenize_manual(model_name, sent, tokenizer)\n\n        #add tokenized to list_tokenzied for later checking\n        list_tokenized.append(tokenized)\n\n        if len(tokenized_sent) > tokenizer.model_max_length:\n            logger.error(\"Invalid size, max size: %d, got %d %s\", tokenizer.model_max_length, len(tokenized_sent), data[idx])\n            raise TextTooLongError(len(tokenized_sent), tokenizer.model_max_length, idx, \" \".join(data[idx]))\n\n        #add to tokenized_sents\n        tokenized_sents.append(torch.tensor(tokenized_sent).detach())\n\n        processed_sent = []\n        processed.append(processed_sent)\n\n        # done loading bert emb\n\n    size = len(tokenized_sents)\n\n    #padding the inputs\n    tokenized_sents_padded = torch.nn.utils.rnn.pad_sequence(tokenized_sents,batch_first=True,padding_value=tokenizer.pad_token_id)\n\n    features = []\n\n    # Feed into PhoBERT 128 at a time in a batch fashion. 
In testing, the loop was\n    # run only 1 time as the batch size for the outer model was less than that\n    # (30 for conparser, for example)\n    for i in range(int(math.ceil(size/128))):\n        padded_input = tokenized_sents_padded[128*i:128*i+128]\n        start_sentence = i * 128\n        end_sentence = start_sentence + padded_input.shape[0]\n        attention_mask = torch.zeros(end_sentence - start_sentence, padded_input.shape[1], device=device)\n        for sent_idx, sent in enumerate(tokenized_sents[start_sentence:end_sentence]):\n            attention_mask[sent_idx, :len(sent)] = 1\n        if detach:\n            with torch.no_grad():\n                # TODO: is the clone().detach() necessary?\n                feature = model(padded_input.clone().detach().to(device), attention_mask=attention_mask, output_hidden_states=True)\n                features += cloned_feature(feature.hidden_states, num_layers, detach)\n        else:\n            feature = model(padded_input.to(device), attention_mask=attention_mask, output_hidden_states=True)\n            features += cloned_feature(feature.hidden_states, num_layers, detach)\n\n    assert len(features)==size\n    assert len(features)==len(processed)\n\n    #process the output\n    #only take the vector of the last word piece of a word/ you can do other methods such as first word piece or averaging.\n    # idx2+1 compensates for the start token at the start of a sentence\n    offsets = [[idx2+1 for idx2, _ in enumerate(list_tokenized[idx]) if (idx2 > 0 and not list_tokenized[idx][idx2-1].endswith(\"@@\")) or (idx2==0)]\n                for idx, sent in enumerate(processed)]\n    if keep_endpoints:\n        # [0] and [-1] grab the start and end representations as well\n        offsets = [[0] + off + [-1] for off in offsets]\n    processed = [feature[offset] for feature, offset in zip(features, offsets)]\n\n    # This is a list of tensors\n    # Each tensor holds the representation of a sentence extracted from 
phobert\n    return processed\n\nBAD_TOKENIZERS = ('bert-base-german-cased',\n                  # the dbmdz tokenizers turn one or more types of characters into empty words\n                  # for example, from PoSTWITA:\n                  #   ewww 󾓺 — in viaggio Roma\n                  # the character which may not be rendering properly is 0xFE4FA\n                  # https://github.com/dbmdz/berts/issues/48\n                  'dbmdz/bert-base-german-cased',\n                  'dbmdz/bert-base-italian-xxl-cased',\n                  'dbmdz/bert-base-italian-cased',\n                  'dbmdz/electra-base-italian-xxl-cased-discriminator',\n                  # each of these (perhaps using similar tokenizers?)\n                  # does not digest the script-flip-mark \\u200f\n                  'avichr/heBERT',\n                  'onlplab/alephbert-base',\n                  'imvladikon/alephbertgimmel-base-512',\n                  # these indonesian models fail on a sentence in the Indonesian GSD dataset:\n                  # 'Tak', 'dapat', 'disangkal', 'jika', '\\u200e', 'kemenangan', ...\n                  # weirdly some other indonesian models (even by the same group) don't have that problem\n                  'cahya/bert-base-indonesian-1.5G',\n                  'indolem/indobert-base-uncased',\n                  'google/muril-base-cased',\n                  'l3cube-pune/marathi-roberta')\n\ndef fix_blank_tokens(tokenizer, data):\n    \"\"\"Patch bert tokenizers with missing characters\n\n    There is an issue that some tokenizers (so far the German ones identified above)\n    tokenize soft hyphens or other unknown characters into nothing\n    If an entire word is tokenized as a soft hyphen, this means the tokenizer\n    simply vaporizes that word.  
The result is we're missing an embedding for\n    an entire word we wanted to use.\n\n    The solution we take here is to look for any words which get vaporized\n    in such a manner, eg `len(token) == 2`, and replace it with a regular \"-\"\n\n    Actually, recently we have found that even the Bert / Electra tokenizer\n    can do this in the case of \"words\" which are one special character long,\n    so the easiest thing to do is just always run this function\n    \"\"\"\n    new_data = []\n    for sentence in data:\n        tokenized = tokenizer(sentence, is_split_into_words=False).input_ids\n        new_sentence = [word if len(token) > 2 else \"-\" for word, token in zip(sentence, tokenized)]\n        new_data.append(new_sentence)\n    return new_data\n\ndef extract_llama_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):\n    # will calculate attention masks ourselves later\n    tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)\n\n    list_offsets = []\n    for idx in range(len(data)):\n        converted_offsets = convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))\n        list_offsets.append(converted_offsets)\n\n    if any(any(x is None for x in converted_offsets) for converted_offsets in list_offsets):\n        raise ValueError(\"OOPS, hit None when preparing to use transformer at idx {}\\ndata[idx]: {}\\nlist_offsets[idx]: {}\\ntokenizer output: {}\".format(idx, data[idx], list_offsets[idx], tokenized))\n\n    features = []\n    for i in range(int(math.ceil(len(data)/128))):\n        id_rows = [id_row + [tokenizer.eos_token_id] for id_row in tokenized['input_ids'][128*i:128*i+128]]\n        max_id_len = max(len(x) for x in id_rows)\n        attention_tensor = torch.zeros((len(id_rows), max_id_len), dtype=torch.long, device=device)\n        for idx, id_row in enumerate(id_rows):\n            attention_tensor[idx, :len(id_row)] = 
1\n            if len(id_row) < max_id_len:\n                # actually this value doesn't matter... autoregressive\n                id_row.extend([0] * (max_id_len - len(id_row)))\n        id_tensor = torch.tensor(id_rows, device=device)\n\n        if detach:\n            with torch.no_grad():\n                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)\n        else:\n            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)\n\n    processed = []\n    #process the output\n    if not keep_endpoints:\n        #remove the bos and eos tokens\n        list_offsets = [sent[1:-1] for sent in list_offsets]\n    for feature, offsets in zip(features, list_offsets):\n        new_sent = feature[offsets]\n        processed.append(new_sent)\n\n    return processed\n\n\ndef extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):\n    # using attention masks makes contextual embeddings much more useful for downstream tasks\n    tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)\n    #tokenized = tokenizer(data, padding=\"longest\", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)\n\n    list_offsets = [[None] * (len(sentence)+2) for sentence in data]\n    for idx in range(len(data)):\n        offsets = tokenized.word_ids(batch_index=idx)\n        list_offsets[idx][0] = 0\n        for pos, offset in enumerate(offsets):\n            if offset is None:\n                break\n            # this uses the last token piece for any offset by overwriting the previous value\n            # this will be one token earlier\n            # we will add a <pad> to the start of each sentence for the endpoints\n            list_offsets[idx][offset+1] = pos + 1\n        list_offsets[idx][-1] = list_offsets[idx][-2] + 1\n    
    if any(x is None for x in list_offsets[idx]):\n            raise ValueError(\"OOPS, hit None when preparing to use Bert\\ndata[idx]: {}\\noffsets: {}\\nlist_offsets[idx]: {}\".format(data[idx], offsets, list_offsets[idx], tokenized))\n\n        if len(offsets) > tokenizer.model_max_length - 2:\n            logger.error(\"Invalid size, max size: %d, got %d %s\", tokenizer.model_max_length, len(offsets), data[idx])\n            raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, \" \".join(data[idx]))\n\n    features = []\n    for i in range(int(math.ceil(len(data)/128))):\n        # TODO: find a suitable representation for attention masks for xlnet\n        # xlnet base on WSJ:\n        # sep_token_id at beginning, cls_token_id at end:     0.9441\n        # bos_token_id at beginning, eos_token_id at end:     0.9463\n        # bos_token_id at beginning, sep_token_id at end:     0.9459\n        # bos_token_id at beginning, cls_token_id at end:     0.9457\n        # bos_token_id at beginning, sep/cls at end:          0.9454\n        # use the xlnet tokenization with words at end,\n        # begin token is last pad, end token is sep, no mask: 0.9463\n        # same, but with masks:                               0.9440\n        input_ids = [[tokenizer.bos_token_id] + x[:-2] + [tokenizer.eos_token_id] for x in tokenized['input_ids'][128*i:128*i+128]]\n        max_len = max(len(x) for x in input_ids)\n        attention_mask = torch.zeros(len(input_ids), max_len, dtype=torch.long, device=device)\n        for idx, input_row in enumerate(input_ids):\n            attention_mask[idx, :len(input_row)] = 1\n            if len(input_row) < max_len:\n                input_row.extend([tokenizer.pad_token_id] * (max_len - len(input_row)))\n        if detach:\n            with torch.no_grad():\n                id_tensor = torch.tensor(input_ids, device=device)\n                feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)\n  
              # feature[2] is the same for bert, but it didn't work for\n                # older versions of transformers for xlnet\n                # feature = feature[2]\n                features += cloned_feature(feature.hidden_states, num_layers, detach)\n        else:\n            id_tensor = torch.tensor(input_ids, device=device)\n            feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)\n            # feature[2] is the same for bert, but it didn't work for\n            # older versions of transformers for xlnet\n            # feature = feature[2]\n            features += cloned_feature(feature.hidden_states, num_layers, detach)\n\n    processed = []\n    #process the output\n    if not keep_endpoints:\n        #remove the bos and eos tokens\n        list_offsets = [sent[1:-1] for sent in list_offsets]\n    for feature, offsets in zip(features, list_offsets):\n        new_sent = feature[offsets]\n        processed.append(new_sent)\n\n    return processed\n\ndef build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device):\n    \"\"\"\n    Extract an embedding from the given transformer for a certain attention mask and tokens range\n\n    In the event that the tokens are longer than the max length\n    supported by the model, the range is split up into overlapping\n    sections and the overlapping pieces are connected.  
No idea if\n    this is actually any good, but at least it returns something\n    instead of horribly failing\n\n    TODO: at least two upgrades are very relevant\n      1) cut off some overlap at the end as well\n      2) use this on the phobert, bart, and xln versions as well\n    \"\"\"\n    if attention_tensor.shape[1] <= tokenizer.model_max_length:\n        features = model(id_tensor, attention_mask=attention_tensor, output_hidden_states=True)\n        features = cloned_feature(features.hidden_states, num_layers, detach)\n        return features\n\n    slices = []\n    slice_len = max(tokenizer.model_max_length - 20, tokenizer.model_max_length // 2)\n    prefix_len = tokenizer.model_max_length - slice_len\n    if slice_len < 5:\n        raise RuntimeError(\"Really tiny tokenizer!\")\n    remaining_attention = attention_tensor\n    remaining_ids = id_tensor\n    while True:\n        attention_slice = remaining_attention[:, :tokenizer.model_max_length]\n        id_slice = remaining_ids[:, :tokenizer.model_max_length]\n        features = model(id_slice, attention_mask=attention_slice, output_hidden_states=True)\n        features = cloned_feature(features.hidden_states, num_layers, detach)\n        if len(slices) > 0:\n            features = features[:, prefix_len:, :]\n        slices.append(features)\n        if remaining_attention.shape[1] <= tokenizer.model_max_length:\n            break\n        remaining_attention = remaining_attention[:, slice_len:]\n        remaining_ids = remaining_ids[:, slice_len:]\n    slices = torch.cat(slices, axis=1)\n    return slices\n\n\ndef convert_to_position_list(sentence, offsets):\n    \"\"\"\n    Convert a transformers-tokenized sentence's offsets to a list of word to position\n    \"\"\"\n    # +2 for the beginning and end\n    list_offsets = [None] * (len(sentence) + 2)\n    for pos, offset in enumerate(offsets):\n        if offset is None:\n            continue\n        # this uses the last token piece for any offset by 
overwriting the previous value\n        list_offsets[offset+1] = pos\n    list_offsets[0] = 0\n    for offset in list_offsets[-2::-1]:\n        # count backwards in case the last position was\n        # a word or character that got erased by the tokenizer\n        # this loop should eventually find something...\n        # after all, we just set the first one to be 0\n        if offset is not None:\n            list_offsets[-1] = offset + 1\n            break\n    return list_offsets\n\ndef extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach):\n    #add add_prefix_space = True for RoBerTa-- error if not\n    # using attention masks makes contextual embeddings much more useful for downstream tasks\n    tokenized = tokenizer(data, padding=\"longest\", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)\n    list_offsets = []\n    for idx in range(len(data)):\n        converted_offsets = convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))\n        list_offsets.append(converted_offsets)\n\n        #if list_offsets[idx][-1] > tokenizer.model_max_length - 1:\n        #    logger.error(\"Invalid size, max size: %d, got %d.\\nTokens: %s\\nTokenized: %s\", tokenizer.model_max_length, len(offsets), data[idx][:1000], offsets[:1000])\n        #    raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, \" \".join(data[idx]))\n\n    if any(any(x is None for x in converted_offsets) for converted_offsets in list_offsets):\n        # at least one of the tokens in the data is composed entirely of characters the tokenizer doesn't know about\n        # one possible approach would be to retokenize only those sentences\n        # however, in that case the attention mask might be of a different length,\n        # as would the token ids, and it would be a pain to fix those\n        # easiest to just retokenize the whole thing, hopefully a rare event\n        data = 
fix_blank_tokens(tokenizer, data)\n\n        tokenized = tokenizer(data, padding=\"longest\", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)\n        list_offsets = []\n        for idx in range(len(data)):\n            converted_offsets = convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))\n            list_offsets.append(converted_offsets)\n\n    if any(any(x is None for x in converted_offsets) for converted_offsets in list_offsets):\n        raise ValueError(\"OOPS, hit None when preparing to use transformer at idx {}\\ndata[idx]: {}\\nlist_offsets[idx]: {}\\ntokenizer output: {}\".format(idx, data[idx], list_offsets[idx], tokenized))\n\n\n    features = []\n    for i in range(int(math.ceil(len(data)/128))):\n        attention_tensor = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)\n        id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)\n        if detach:\n            with torch.no_grad():\n                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)\n        else:\n            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)\n\n    processed = []\n    #process the output\n    if not keep_endpoints:\n        #remove the bos and eos tokens\n        list_offsets = [sent[1:-1] for sent in list_offsets]\n    for feature, offsets in zip(features, list_offsets):\n        new_sent = feature[offsets]\n        processed.append(new_sent)\n\n    return processed\n\ndef extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers=None, detach=True, peft_name=None):\n    \"\"\"\n    Extract transformer embeddings using a generic roberta extraction\n\n    data: list of list of string (the text tokens)\n    num_layers: how many to return.  
If None, the average of -2, -3, -4 is returned\n    \"\"\"\n    # TODO: can maybe cache this value for a model and save some time\n    # TODO: too bad it isn't thread safe, but then again, who does?\n    if peft_name is None:\n        if model._hf_peft_config_loaded:\n            model.disable_adapters()\n    else:\n        model.enable_adapters()\n        model.set_adapter(peft_name)\n\n    if model_name.startswith(\"vinai/phobert\"):\n        return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)\n\n    if 'bart' in model_name:\n        # this should work with \"vinai/bartpho-word\"\n        # not sure this works with any other Bart\n        return extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)\n\n    if isinstance(data, tuple):\n        data = list(data)\n\n    if \"xlnet\" in model_name:\n        return extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)\n\n    if \"LLaMA\" in model_name:\n        return extract_llama_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)\n\n    return extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)\n\n"
  },
  {
    "path": "stanza/models/common/biaffine.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass PairwiseBilinear(nn.Module):\n    ''' A bilinear module that deals with broadcasting for efficient memory usage.\n    Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)\n    Output: tensor of size (N x L1 x L2 x O)'''\n    def __init__(self, input1_size, input2_size, output_size, bias=True):\n        super().__init__()\n\n        self.input1_size = input1_size\n        self.input2_size = input2_size\n        self.output_size = output_size\n\n        self.weight = nn.Parameter(torch.Tensor(input1_size, input2_size, output_size))\n        self.bias = nn.Parameter(torch.Tensor(output_size)) if bias else 0\n\n    def forward(self, input1, input2):\n        input1_size = list(input1.size())\n        input2_size = list(input2.size())\n        output_size = [input1_size[0], input1_size[1], input2_size[1], self.output_size]\n\n        # ((N x L1) x D1) * (D1 x (D2 x O)) -> (N x L1) x (D2 x O)\n        intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size))\n        # (N x L2 x D2) -> (N x D2 x L2)\n        input2 = input2.transpose(1, 2)\n        # (N x (L1 x O) x D2) * (N x D2 x L2) -> (N x (L1 x O) x L2)\n        output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)\n        # (N x (L1 x O) x L2) -> (N x L1 x L2 x O)\n        output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)\n\n        return output\n\nclass BiaffineScorer(nn.Module):\n    def __init__(self, input1_size, input2_size, output_size):\n        super().__init__()\n        self.W_bilin = nn.Bilinear(input1_size + 1, input2_size + 1, output_size)\n\n        self.W_bilin.weight.data.zero_()\n        self.W_bilin.bias.data.zero_()\n\n    def forward(self, input1, input2):\n        input1 = torch.cat([input1, input1.new_ones(*input1.size()[:-1], 1)], 
len(input1.size())-1)\n        input2 = torch.cat([input2, input2.new_ones(*input2.size()[:-1], 1)], len(input2.size())-1)\n        return self.W_bilin(input1, input2)\n\nclass PairwiseBiaffineScorer(nn.Module):\n    def __init__(self, input1_size, input2_size, output_size):\n        super().__init__()\n        self.W_bilin = PairwiseBilinear(input1_size + 1, input2_size + 1, output_size)\n\n        self.W_bilin.weight.data.zero_()\n        self.W_bilin.bias.data.zero_()\n\n    def forward(self, input1, input2):\n        input1 = torch.cat([input1, input1.new_ones(*input1.size()[:-1], 1)], len(input1.size())-1)\n        input2 = torch.cat([input2, input2.new_ones(*input2.size()[:-1], 1)], len(input2.size())-1)\n        return self.W_bilin(input1, input2)\n\nclass DeepBiaffineScorer(nn.Module):\n    def __init__(self, input1_size, input2_size, hidden_size, output_size, hidden_func=F.relu, dropout=0, pairwise=True):\n        super().__init__()\n        self.W1 = nn.Linear(input1_size, hidden_size)\n        self.W2 = nn.Linear(input2_size, hidden_size)\n        self.hidden_func = hidden_func\n        if pairwise:\n            self.scorer = PairwiseBiaffineScorer(hidden_size, hidden_size, output_size)\n        else:\n            self.scorer = BiaffineScorer(hidden_size, hidden_size, output_size)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, input1, input2):\n        return self.scorer(self.dropout(self.hidden_func(self.W1(input1))), self.dropout(self.hidden_func(self.W2(input2))))\n\nif __name__ == \"__main__\":\n    # the default pairwise scorer expects batched 3D inputs: (N x L1 x D1) and (N x L2 x D2)\n    x1 = torch.randn(3, 2, 4)\n    x2 = torch.randn(3, 2, 5)\n    scorer = DeepBiaffineScorer(4, 5, 6, 7)\n    print(scorer(x1, x2))\n"
  },
  {
    "path": "stanza/models/common/build_short_name_to_treebank.py",
    "content": "import glob\nimport os\n\nfrom stanza.models.common.constant import treebank_to_short_name, UnknownLanguageError, treebank_special_cases\nfrom stanza.utils import default_paths\n\npaths = default_paths.get_default_paths()\nudbase = paths[\"UDBASE\"]\n\ndirectories = glob.glob(udbase + \"/UD_*\")\ndirectories.sort()\n\noutput_name = os.path.join(os.path.split(__file__)[0], \"short_name_to_treebank.py\")\nud_names = [os.path.split(ud_path)[1] for ud_path in directories]\nshort_names = []\n\n# check that all languages are known in the language map\n# use that language map to come up with a shortname for these treebanks\nfor directory, ud_name in zip(directories, ud_names):\n    try:\n        short_names.append(treebank_to_short_name(ud_name))\n    except UnknownLanguageError as e:\n        raise UnknownLanguageError(\"Could not find language short name for dataset %s, path %s\" % (ud_name, directory)) from e\n\nfor directory, ud_name in zip(directories, ud_names):\n    if ud_name.startswith(\"UD_Norwegian\"):\n        if ud_name not in treebank_special_cases:\n            raise ValueError(\"Please figure out if dataset %s is NN or NB, then add to treebank_special_cases\" % ud_name)\n    if ud_name.startswith(\"UD_Chinese\"):\n        if ud_name not in treebank_special_cases:\n            raise ValueError(\"Please figure out if dataset %s is NN or NB, then add to treebank_special_cases\" % ud_name)\n\nmax_len = max(len(x) for x in short_names) + 8\nline_format = \"    %-\" + str(max_len) + \"s '%s',\\n\"\n\n\nprint(\"Writing to %s\" % output_name)\nwith open(output_name, \"w\") as fout:\n    fout.write(\"# This module is autogenerated by build_short_name_to_treebank.py\\n\")\n    fout.write(\"# Please do not edit\\n\")\n    fout.write(\"\\n\")\n    fout.write(\"SHORT_NAMES = {\\n\")\n    for short_name, ud_name in zip(short_names, ud_names):\n        fout.write(line_format % (\"'\" + short_name + \"':\", ud_name))\n\n        if 
short_name.startswith(\"zh_\"):\n            short_name = \"zh-hans_\" + short_name[3:]\n            fout.write(line_format % (\"'\" + short_name + \"':\", ud_name))\n        elif short_name.startswith(\"zh-hans_\") or short_name.startswith(\"zh-hant_\"):\n            short_name = \"zh_\" + short_name[8:]\n            fout.write(line_format % (\"'\" + short_name + \"':\", ud_name))\n        elif short_name == 'nb_bokmaal':\n            short_name = 'no_bokmaal'\n            fout.write(line_format % (\"'\" + short_name + \"':\", ud_name))\n\n    fout.write(\"}\\n\")\n\n    fout.write(\"\"\"\n\ndef short_name_to_treebank(short_name):\n    return SHORT_NAMES[short_name]\n\n\n\"\"\")\n\n    max_len = max(len(x) for x in ud_names) + 5\n    line_format = \"    %-\" + str(max_len) + \"s '%s',\\n\"\n    fout.write(\"CANONICAL_NAMES = {\\n\")\n    for ud_name in ud_names:\n        fout.write(line_format % (\"'\" + ud_name.lower() + \"':\", ud_name))\n    fout.write(\"}\\n\")\n    fout.write(\"\"\"\n\ndef canonical_treebank_name(ud_name):\n    if ud_name in SHORT_NAMES:\n        return SHORT_NAMES[ud_name]\n    return CANONICAL_NAMES.get(ud_name.lower(), ud_name)\n\"\"\")\n"
  },
  {
    "path": "stanza/models/common/char_model.py",
    "content": "\"\"\"\nBased on\n\n@inproceedings{akbik-etal-2018-contextual,\n    title = \"Contextual String Embeddings for Sequence Labeling\",\n    author = \"Akbik, Alan  and\n      Blythe, Duncan  and\n      Vollgraf, Roland\",\n    booktitle = \"Proceedings of the 27th International Conference on Computational Linguistics\",\n    month = aug,\n    year = \"2018\",\n    address = \"Santa Fe, New Mexico, USA\",\n    publisher = \"Association for Computational Linguistics\",\n    url = \"https://aclanthology.org/C18-1139\",\n    pages = \"1638--1649\",\n}\n\"\"\"\n\nfrom collections import Counter\nfrom operator import itemgetter\nimport os\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pack_padded_sequence, PackedSequence\n\nfrom stanza.models.common.data import get_long_tensor\nfrom stanza.models.common.packed_lstm import PackedLSTM\nfrom stanza.models.common.utils import open_read_text, tensor_unsort, unsort\nfrom stanza.models.common.dropout import SequenceUnitDropout\nfrom stanza.models.common.vocab import UNK_ID, CharVocab\n\nclass CharacterModel(nn.Module):\n    def __init__(self, args, vocab, pad=False, bidirectional=False, attention=True):\n        super().__init__()\n        self.args = args\n        self.pad = pad\n        self.num_dir = 2 if bidirectional else 1\n        self.attn = attention\n\n        # char embeddings\n        self.char_emb = nn.Embedding(len(vocab['char']), self.args['char_emb_dim'], padding_idx=0)\n        if self.attn: \n            self.char_attn = nn.Linear(self.num_dir * self.args['char_hidden_dim'], 1, bias=False)\n            self.char_attn.weight.data.zero_()\n\n        # modules\n        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True, \\\n                dropout=0 if self.args['char_num_layers'] == 1 else args['dropout'], rec_dropout = self.args['char_rec_dropout'], 
bidirectional=bidirectional)\n        self.charlstm_h_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))\n        self.charlstm_c_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))\n\n        self.dropout = nn.Dropout(args['dropout'])\n\n    def forward(self, chars, chars_mask, word_orig_idx, sentlens, wordlens):\n        embs = self.dropout(self.char_emb(chars))\n        batch_size = embs.size(0)\n        embs = pack_padded_sequence(embs, wordlens, batch_first=True)\n        output = self.charlstm(embs, wordlens, hx=(\\\n                self.charlstm_h_init.expand(self.num_dir * self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(), \\\n                self.charlstm_c_init.expand(self.num_dir * self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous()))\n         \n        # apply attention, otherwise take final states\n        if self.attn:\n            char_reps = output[0]\n            weights = torch.sigmoid(self.char_attn(self.dropout(char_reps.data)))\n            char_reps = PackedSequence(char_reps.data * weights, char_reps.batch_sizes)\n            char_reps, _ = pad_packed_sequence(char_reps, batch_first=True)\n            res = char_reps.sum(1)\n        else:\n            h, c = output[1]\n            res = h[-2:].transpose(0,1).contiguous().view(batch_size, -1)\n\n        # recover character order and word separation\n        res = tensor_unsort(res, word_orig_idx)\n        res = pack_sequence(res.split(sentlens))\n        if self.pad:\n            res = pad_packed_sequence(res, batch_first=True)[0]\n\n        return res\n\ndef build_charlm_vocab(path, cutoff=0):\n    \"\"\"\n    Build a vocab for a CharacterLanguageModel\n\n    Requires a large amount of memory, but only need to build once\n\n    here we need some trick to deal with excessively large files\n    for each file 
we accumulate the counter of characters, and\n    at the end we simply pass a list of chars to the vocab builder\n    \"\"\"\n    counter = Counter()\n    if os.path.isdir(path):\n        filenames = sorted(os.listdir(path))\n    else:\n        filenames = [os.path.split(path)[1]]\n        path = os.path.split(path)[0]\n\n    for filename in filenames:\n        filename = os.path.join(path, filename)\n        with open_read_text(filename) as fin:\n            for line in fin:\n                counter.update(list(line))\n\n    if len(counter) == 0:\n        raise ValueError(\"Training data was empty!\")\n    # remove infrequent characters from vocab\n    for k in list(counter.keys()):\n        if counter[k] < cutoff:\n            del counter[k]\n    # a singleton list of all characters\n    data = [sorted([x[0] for x in counter.most_common()])]\n    if len(data[0]) == 0:\n        raise ValueError(\"All characters in the training data were less frequent than --cutoff!\")\n    vocab = CharVocab(data) # skip cutoff argument because this has been dealt with\n    return vocab\n\nCHARLM_START = \"\\n\"\nCHARLM_END = \" \"\n\nclass CharacterLanguageModel(nn.Module):\n\n    def __init__(self, args, vocab, pad=False, is_forward_lm=True):\n        super().__init__()\n        self.args = args\n        self.vocab = vocab\n        self.is_forward_lm = is_forward_lm\n        self.pad = pad\n        self.finetune = True # always finetune unless otherwise specified\n\n        # char embeddings\n        self.char_emb = nn.Embedding(len(self.vocab['char']), self.args['char_emb_dim'], padding_idx=None) # we use space as padding, so padding_idx is not necessary\n        \n        # modules\n        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True, \\\n                dropout=0 if self.args['char_num_layers'] == 1 else args['char_dropout'], rec_dropout = self.args['char_rec_dropout'], 
bidirectional=False)\n        self.charlstm_h_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))\n        self.charlstm_c_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))\n\n        # decoder\n        self.decoder = nn.Linear(self.args['char_hidden_dim'], len(self.vocab['char']))\n        self.dropout = nn.Dropout(args['char_dropout'])\n        self.char_dropout = SequenceUnitDropout(args.get('char_unit_dropout', 0), UNK_ID)\n\n    def forward(self, chars, charlens, hidden=None):\n        chars = self.char_dropout(chars)\n        embs = self.dropout(self.char_emb(chars))\n        batch_size = embs.size(0)\n        embs = pack_padded_sequence(embs, charlens, batch_first=True)\n        if hidden is None: \n            hidden = (self.charlstm_h_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(),\n                      self.charlstm_c_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous())\n        output, hidden = self.charlstm(embs, charlens, hx=hidden)\n        output = self.dropout(pad_packed_sequence(output, batch_first=True)[0])\n        decoded = self.decoder(output)\n        return output, hidden, decoded\n\n    def get_representation(self, chars, charoffsets, charlens, char_orig_idx):\n        with torch.no_grad():\n            output, _, _ = self.forward(chars, charlens)\n            res = [output[i, offsets] for i, offsets in enumerate(charoffsets)]\n            res = unsort(res, char_orig_idx)\n            res = pack_sequence(res)\n            if self.pad:\n                res = pad_packed_sequence(res, batch_first=True)[0]\n        return res\n\n    def per_char_representation(self, words):\n        device = next(self.parameters()).device\n        vocab = self.char_vocab()\n\n        all_data = [(vocab.map(word), len(word), idx) for idx, word in enumerate(words)]\n        
all_data.sort(key=itemgetter(1), reverse=True)\n        chars = [x[0] for x in all_data]\n        char_lens = [x[1] for x in all_data]\n        char_tensor = get_long_tensor(chars, len(chars), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)\n        with torch.no_grad():\n            output, _, _ = self.forward(char_tensor, char_lens)\n            output = [x[:y, :] for x, y in zip(output, char_lens)]\n            output = unsort(output, [x[2] for x in all_data])\n        return output\n\n    def build_char_representation(self, sentences):\n        \"\"\"\n        Return values from this charlm for a list of list of words\n\n        input: [[str]]\n          K sentences, each of length Ki (can be different for each sentence)\n        output: [tensor(Ki x dim)]\n          list of tensors, each one with shape Ki by the dim of the character model\n\n        Values are taken from the last character in a word for each word.\n        The words are effectively treated as if they are whitespace separated\n        (which may actually be somewhat inaccurate for languages such as Chinese or for MWT)\n        \"\"\"\n        forward = self.is_forward_lm\n        vocab = self.char_vocab()\n        device = next(self.parameters()).device\n\n        all_data = []\n        for idx, words in enumerate(sentences):\n            if not forward:\n                words = [x[::-1] for x in reversed(words)]\n\n            chars = [CHARLM_START]\n            offsets = []\n            for w in words:\n                chars.extend(w)\n                chars.append(CHARLM_END)\n                offsets.append(len(chars) - 1)\n            if not forward:\n                offsets.reverse()\n            chars = vocab.map(chars)\n            all_data.append((chars, offsets, len(chars), len(all_data)))\n\n        all_data.sort(key=itemgetter(2), reverse=True)\n        chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))\n        # TODO: can this be faster?\n        chars = 
get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)\n\n        with torch.no_grad():\n            output, _, _ = self.forward(chars, char_lens)\n            res = [output[i, offsets] for i, offsets in enumerate(char_offsets)]\n            res = unsort(res, orig_idx)\n\n        return res\n\n    def hidden_dim(self):\n        return self.args['char_hidden_dim']\n\n    def char_vocab(self):\n        return self.vocab['char']\n\n    def train(self, mode=True):\n        \"\"\"\n        Override the default train() function, so that when self.finetune == False, the training mode \n        won't be impacted by the parent models' status change.\n        \"\"\"\n        if not mode: # eval() is always allowed, regardless of finetune status\n            super().train(mode)\n        else:\n            if self.finetune: # only set to training mode in finetune status\n                super().train(mode)\n\n    def full_state(self):\n        state = {\n            'vocab': self.vocab['char'].state_dict(),\n            'args': self.args,\n            'state_dict': self.state_dict(),\n            'pad': self.pad,\n            'is_forward_lm': self.is_forward_lm\n        }\n        return state\n\n    def save(self, filename):\n        os.makedirs(os.path.split(filename)[0], exist_ok=True)\n        state = self.full_state()\n        torch.save(state, filename, _use_new_zipfile_serialization=False)\n\n    @classmethod\n    def from_full_state(cls, state, finetune=False):\n        vocab = {'char': CharVocab.load_state_dict(state['vocab'])}\n        model = cls(state['args'], vocab, state['pad'], state['is_forward_lm'])\n        model.load_state_dict(state['state_dict'])\n        model.eval()\n        model.finetune = finetune # set finetune status\n        return model\n\n    @classmethod\n    def load(cls, filename, finetune=False):\n        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        # allow saving 
just the Model object,\n        # and allow for old charlms to still work\n        if 'state_dict' in state:\n            return cls.from_full_state(state, finetune)\n        return cls.from_full_state(state['model'], finetune)\n\nclass CharacterLanguageModelWordAdapter(nn.Module):\n    \"\"\"\n    Adapts a character model to return embeddings for each character in a word\n    \"\"\"\n    def __init__(self, charlms):\n        super().__init__()\n        self.charlms = charlms\n\n    def forward(self, words, wrap=True):\n        if wrap:\n            words = [CHARLM_START + x + CHARLM_END for x in words]\n        padded_reps = []\n        for charlm in self.charlms:\n            rep = charlm.per_char_representation(words)\n            padded_rep = torch.zeros(len(rep), max(x.shape[0] for x in rep), rep[0].shape[1], dtype=rep[0].dtype, device=rep[0].device)\n            for idx, row in enumerate(rep):\n                padded_rep[idx, :row.shape[0], :] = row\n            padded_reps.append(padded_rep)\n        padded_rep = torch.cat(padded_reps, dim=2)\n        return padded_rep\n\n    def hidden_dim(self):\n        return sum(charlm.hidden_dim() for charlm in self.charlms)\n\nclass CharacterLanguageModelTrainer():\n    def __init__(self, model, params, optimizer, criterion, scheduler, epoch=1, global_step=0):\n        self.model = model\n        self.params = params\n        self.optimizer = optimizer\n        self.criterion = criterion\n        self.scheduler = scheduler\n        self.epoch = epoch\n        self.global_step = global_step\n\n    def save(self, filename, full=True):\n        os.makedirs(os.path.split(filename)[0], exist_ok=True)\n        state = {\n            'model': self.model.full_state(),\n            'epoch': self.epoch,\n            'global_step': self.global_step,\n        }\n        if full and self.optimizer is not None:\n            state['optimizer'] = self.optimizer.state_dict()\n        if full and self.criterion is not None:\n           
 state['criterion'] = self.criterion.state_dict()\n        if full and self.scheduler is not None:\n            state['scheduler'] = self.scheduler.state_dict()\n        torch.save(state, filename, _use_new_zipfile_serialization=False)\n\n    @classmethod\n    def from_new_model(cls, args, vocab):\n        model = CharacterLanguageModel(args, vocab, is_forward_lm=True if args['direction'] == 'forward' else False)\n        model = model.to(args['device'])\n        params = [param for param in model.parameters() if param.requires_grad]\n        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])\n        criterion = torch.nn.CrossEntropyLoss()\n        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args['anneal'], patience=args['patience'])\n        return cls(model, params, optimizer, criterion, scheduler)\n\n\n    @classmethod\n    def load(cls, args, filename, finetune=False):\n        \"\"\"\n        Load the model along with any other saved state for training\n\n        Note that you MUST set finetune=True if planning to continue training\n        Otherwise the only benefit you will get will be a warm GPU\n        \"\"\"\n        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        model = CharacterLanguageModel.from_full_state(state['model'], finetune)\n        model = model.to(args['device'])\n\n        params = [param for param in model.parameters() if param.requires_grad]\n        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])\n        if 'optimizer' in state: optimizer.load_state_dict(state['optimizer'])\n\n        criterion = torch.nn.CrossEntropyLoss()\n        if 'criterion' in state: criterion.load_state_dict(state['criterion'])\n\n        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args['anneal'], patience=args['patience'])\n        if 
'scheduler' in state: scheduler.load_state_dict(state['scheduler'])\n\n        epoch = state.get('epoch', 1)\n        global_step = state.get('global_step', 0)\n        return cls(model, params, optimizer, criterion, scheduler, epoch, global_step)\n\n"
  },
  {
    "path": "stanza/models/common/chuliu_edmonds.py",
    "content": "# Adapted from Tim's code here: https://github.com/tdozat/Parser-v3/blob/master/scripts/chuliu_edmonds.py\n\nimport numpy as np\n\ndef tarjan(tree):\n    \"\"\"Finds the cycles in a dependency graph\n\n    The input should be a numpy array of integers,\n    where in the standard use case,\n    tree[i] is the head of node i.\n\n    tree[0] == 0 to represent the root\n\n    so for example, for the English sentence \"This is a test\",\n    the input is\n\n    [0 4 4 4 0]\n\n    \"Arthritis makes my hip hurt\"\n\n    [0 2 0 4 2 2]\n\n    The return is a list of cycles, where in cycle has True if the\n    node at that index is participating in the cycle.\n    So, for example, the previous examples both return empty lists,\n    whereas an input of\n      np.array([0, 3, 1, 2])\n    has an output of\n      [np.array([False,  True,  True,  True])]\n    \"\"\"\n    indices = -np.ones_like(tree)\n    lowlinks = -np.ones_like(tree)\n    onstack = np.zeros_like(tree, dtype=bool)\n    stack = list()\n    _index = [0]\n    cycles = []\n    #-------------------------------------------------------------\n    def maybe_pop_cycle(i):\n        if lowlinks[i] == indices[i]:\n            # There's a cycle!\n            cycle = np.zeros_like(indices, dtype=bool)\n            while stack[-1] != i:\n                j = stack.pop()\n                onstack[j] = False\n                cycle[j] = True\n            stack.pop()\n            onstack[i] = False\n            cycle[i] = True\n            if cycle.sum() > 1:\n                cycles.append(cycle)\n\n    def initialize_strong_connect(i):\n        _index[0] += 1\n        index = _index[-1]\n        indices[i] = lowlinks[i] = index - 1\n        stack.append(i)\n        onstack[i] = True\n\n    def strong_connect(i):\n        # this ridiculous atrocity is because somehow people keep\n        # coming up with graphs which overflow python's call stack\n        # so instead we make our own call stack and turn the 
recursion\n        # into a loop\n        # see for example\n        #   https://github.com/stanfordnlp/stanza/issues/962\n        #   https://github.com/spraakbanken/sparv-pipeline/issues/166\n        # in an ideal world this block of code would look like this\n        #    initialize_strong_connect(i)\n        #    dependents = iter(np.where(np.equal(tree, i))[0])\n        #    for j in dependents:\n        #        if indices[j] == -1:\n        #            strong_connect(j)\n        #            lowlinks[i] = min(lowlinks[i], lowlinks[j])\n        #        elif onstack[j]:\n        #            lowlinks[i] = min(lowlinks[i], indices[j])\n        #\n        #     maybe_pop_cycle(i)\n        call_stack = [(i, None, None)]\n        while len(call_stack) > 0:\n            i, dependents_iterator, j = call_stack.pop()\n            if dependents_iterator is None: # first time getting here for this i\n                initialize_strong_connect(i)\n                dependents_iterator = iter(np.where(np.equal(tree, i))[0])\n            else: # been here before.  
j was the dependent we were just considering\n                lowlinks[i] = min(lowlinks[i], lowlinks[j])\n            for j in dependents_iterator:\n                if indices[j] == -1:\n                    # have to remember where we were...\n                    # put the current iterator & its state on the \"call stack\"\n                    # we will come back to it later\n                    call_stack.append((i, dependents_iterator, j))\n                    # also, this is what we do next...\n                    call_stack.append((j, None, None))\n                    # this will break this iterator for now\n                    # the next time through, we will continue progressing this iterator\n                    break\n                elif onstack[j]:\n                    lowlinks[i] = min(lowlinks[i], indices[j])\n            else:\n                # this is an intended use of for/else\n                # please stop filing git issues on obscure language features\n                # we finished iterating without a break\n                # and can finally resolve any possible cycles\n                maybe_pop_cycle(i)\n            # at this point, there are two cases:\n            #\n            # we iterated all the way through an iterator (the else in the for/else)\n            # and have resolved any possible cycles.  
can then proceed to the previous\n            # iterator we were considering (or finish, if there are no others)\n            # OR\n            # we have hit a break in the iteration over the dependents\n            # for a node\n            # and we need to dig deeper into the graph and resolve the dependent's dependents\n            # before we can continue the previous node\n            #\n            # either way, we check to see if there are unfinished subtrees\n            # when that is finally done, we can return\n\n    #-------------------------------------------------------------\n    for i in range(len(tree)):\n        if indices[i] == -1:\n            strong_connect(i)\n    return cycles\n\ndef process_cycle(tree, cycle, scores):\n    \"\"\"\n    Build a subproblem with one cycle broken\n    \"\"\"\n    # indices of cycle in original tree; (c) in t\n    cycle_locs = np.where(cycle)[0]\n    # heads of cycle in original tree; (c) in t\n    cycle_subtree = tree[cycle]\n    # scores of cycle in original tree; (c) in R\n    cycle_scores = scores[cycle, cycle_subtree]\n    # total score of cycle; () in R\n    cycle_score = cycle_scores.sum()\n\n    # locations of noncycle; (t) in [0,1]\n    noncycle = np.logical_not(cycle)\n    # indices of noncycle in original tree; (n) in t\n    noncycle_locs = np.where(noncycle)[0]\n    #print(cycle_locs, noncycle_locs)\n\n    # scores of cycle's potential heads; (c x n) - (c) + () -> (n x c) in R\n    metanode_head_scores = scores[cycle][:,noncycle] - cycle_scores[:,None] + cycle_score\n    # scores of cycle's potential dependents; (n x c) in R\n    metanode_dep_scores = scores[noncycle][:,cycle]\n    # best noncycle head for each cycle dependent; (n) in c\n    metanode_heads = np.argmax(metanode_head_scores, axis=0)\n    # best cycle head for each noncycle dependent; (n) in c\n    metanode_deps = np.argmax(metanode_dep_scores, axis=1)\n\n    # scores of noncycle graph; (n x n) in R\n    subscores = 
scores[noncycle][:,noncycle]\n    # pad to contracted graph; (n+1 x n+1) in R\n    subscores = np.pad(subscores, ( (0,1) , (0,1) ), 'constant')\n    # set the contracted graph scores of cycle's potential heads; (c x n)[:, (n) in n] in R -> (n) in R\n    subscores[-1, :-1] = metanode_head_scores[metanode_heads, np.arange(len(noncycle_locs))]\n    # set the contracted graph scores of cycle's potential dependents; (n x c)[(n) in n] in R-> (n) in R\n    subscores[:-1,-1] = metanode_dep_scores[np.arange(len(noncycle_locs)), metanode_deps]\n    return subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps\n\n\ndef expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps):\n    \"\"\"\n    Given a partially solved tree with a cycle and a solved subproblem\n    for the cycle, build a larger solution without the cycle\n    \"\"\"\n    # head of the cycle; () in n\n    #print(contracted_tree)\n    cycle_head = contracted_tree[-1]\n    # fixed tree: (n) in n+1\n    contracted_tree = contracted_tree[:-1]\n    # initialize new tree; (t) in 0\n    new_tree = -np.ones_like(tree)\n    #print(0, new_tree)\n    # fixed tree with no heads coming from the cycle: (n) in [0,1]\n    contracted_subtree = contracted_tree < len(contracted_tree)\n    # add the nodes to the new tree (t)[(n)[(n) in [0,1]] in t] in t = (n)[(n)[(n) in [0,1]] in n] in t\n    new_tree[noncycle_locs[contracted_subtree]] = noncycle_locs[contracted_tree[contracted_subtree]]\n    #print(1, new_tree)\n    # fixed tree with heads coming from the cycle: (n) in [0,1]\n    contracted_subtree = np.logical_not(contracted_subtree)\n    # add the nodes to the tree (t)[(n)[(n) in [0,1]] in t] in t = (c)[(n)[(n) in [0,1]] in c] in t\n    new_tree[noncycle_locs[contracted_subtree]] = cycle_locs[metanode_deps[contracted_subtree]]\n    #print(2, new_tree)\n    # add the old cycle to the tree; (t)[(c) in t] in t = (t)[(c) in t] in t\n    new_tree[cycle_locs] = 
tree[cycle_locs]\n    #print(3, new_tree)\n    # root of the cycle; (n)[() in n] in c = () in c\n    cycle_root = metanode_heads[cycle_head]\n    # add the root of the cycle to the new tree; (t)[(c)[() in c] in t] = (c)[() in c]\n    new_tree[cycle_locs[cycle_root]] = noncycle_locs[cycle_head]\n    #print(4, new_tree)\n    return new_tree\n\ndef prepare_scores(scores):\n    \"\"\"\n    Alter the scores matrix to avoid self loops and handle the root\n    \"\"\"\n    # prevent self-loops, set up the root location\n    np.fill_diagonal(scores, -float('inf')) # prevent self-loops\n    scores[0] = -float('inf')\n    scores[0,0] = 0\n\ndef chuliu_edmonds(scores):\n    subtree_stack = []\n\n    prepare_scores(scores)\n    tree = np.argmax(scores, axis=1)\n    cycles = tarjan(tree)\n\n    #print(scores)\n    #print(cycles)\n\n    # recursive implementation:\n    #if cycles:\n    #    # t = len(tree); c = len(cycle); n = len(noncycle)\n    #    # cycles.pop(): locations of cycle; (t) in [0,1]\n    #    subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)\n    #    # MST with contraction; (n+1) in n+1\n    #    contracted_tree = chuliu_edmonds(subscores)\n    #    tree = expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)\n    # unfortunately, while the recursion is simpler to understand, it can get too deep for python's stack limit\n    # so instead we make our own recursion, with blackjack and (you know how it goes)\n\n    while cycles:\n        # t = len(tree); c = len(cycle); n = len(noncycle)\n        # cycles.pop(): locations of cycle; (t) in [0,1]\n        subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)\n        subtree_stack.append((tree, cycles, scores, subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps))\n\n        scores = subscores\n        prepare_scores(scores)\n        
tree = np.argmax(scores, axis=1)\n        cycles = tarjan(tree)\n\n    while len(subtree_stack) > 0:\n        contracted_tree = tree\n        (tree, cycles, scores, subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps) = subtree_stack.pop()\n        tree = expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)\n\n    return tree\n\n#===============================================================\ndef chuliu_edmonds_one_root(scores):\n    \"\"\"\n    Return the results of the dependency tree search, but with exactly one link to root (0)\n\n    scores is a numpy array, where scores[x][y] is the score for assigning y to be the head of x\n\n    Here we reweight the root arcs so as to ensure that the picker only ever chooses one root.\n    See for example\n\n    https://aclanthology.org/2021.emnlp-main.823/\n    A Root of a Problem: Optimizing Single-Root Dependency Parsing\n    Miloš Stanojević, Shay B. Cohen\n    \"\"\"\n    # we fiddle the scores to prevent double root arcs\n    # we therefore copy the array so it doesn't get messed up at the source\n    scores = scores.copy()\n    scores = scores.astype(np.float64)\n    min_score = scores[np.isfinite(scores)].min()\n    scores[:, 0] = scores[:, 0] + (min_score * scores.shape[0])\n    tree = chuliu_edmonds(scores)\n    # +1 because we cut off the first column of the tree\n    roots_to_try = np.where(np.equal(tree[1:], 0))[0]+1\n    assert len(roots_to_try) == 1, \"Rescaling by the lowest score should have prevented using multiple root edges\"\n    return tree\n"
  },
  {
    "path": "stanza/models/common/constant.py",
    "content": "\"\"\"\nGlobal constants.\n\nThese language codes mirror UD language codes when possible\n\"\"\"\n\nimport re\n\nclass UnknownLanguageError(ValueError):\n    pass\n\n# tuples in a list so we can assert that the langcodes are all unique\n# When applicable, we favor the UD decision over any other possible\n# language code or language name\n# An example of this is sab -> Bokota, instead of bgd in ISO 693-3\n# ISO 639-1 is out of date, but many of the UD datasets are labeled\n# using the two letter abbreviations, so we add those for non-UD\n# languages in the hopes that we've guessed right if those languages\n# are eventually processed\nlcode2lang_raw = [\n    (\"abq\", \"Abaza\"),\n    (\"ab\",  \"Abkhazian\"),\n    (\"aa\",  \"Afar\"),\n    (\"af\",  \"Afrikaans\"),\n    (\"ak\",  \"Akan\"),\n    (\"akk\", \"Akkadian\"),\n    (\"aqz\", \"Akuntsu\"),\n    (\"sq\",  \"Albanian\"),\n    (\"am\",  \"Amharic\"),\n    (\"grc\", \"Ancient_Greek\"),\n    (\"hbo\", \"Ancient_Hebrew\"),\n    (\"apu\", \"Apurina\"),\n    (\"ar\",  \"Arabic\"),\n    (\"arz\", \"Egyptian_Arabic\"),\n    (\"an\",  \"Aragonese\"),\n    (\"hy\",  \"Armenian\"),\n    (\"as\",  \"Assamese\"),\n    (\"aii\", \"Assyrian\"),\n    (\"ast\", \"Asturian\"),\n    (\"av\",  \"Avaric\"),\n    (\"ae\",  \"Avestan\"),\n    (\"ay\",  \"Aymara\"),\n    (\"az\",  \"Azerbaijani\"),\n    (\"bm\",  \"Bambara\"),\n    (\"ba\",  \"Bashkir\"),\n    (\"eu\",  \"Basque\"),\n    (\"bar\", \"Bavarian\"),\n    (\"bej\", \"Beja\"),\n    (\"be\",  \"Belarusian\"),\n    (\"bn\",  \"Bengali\"),\n    (\"bho\", \"Bhojpuri\"),\n    (\"bpy\", \"Bishnupriya_Manipuri\"),\n    (\"bi\",  \"Bislama\"),\n    (\"bor\", \"Bororo\"),\n    (\"sab\", \"Bokota\"),\n    (\"bs\",  \"Bosnian\"),\n    (\"br\",  \"Breton\"),\n    (\"bg\",  \"Bulgarian\"),\n    (\"bxr\", \"Buryat\"),\n    (\"yue\", \"Cantonese\"),\n    (\"cpg\", \"Cappadocian\"),\n    (\"ca\",  \"Catalan\"),\n    (\"ceb\", \"Cebuano\"),\n    (\"km\",  
\"Central_Khmer\"),\n    (\"ch\",  \"Chamorro\"),\n    (\"ce\",  \"Chechen\"),\n    (\"ny\",  \"Chichewa\"),\n    (\"ctn\", \"Chintang\"),\n    (\"ckt\", \"Chukchi\"),\n    (\"cv\",  \"Chuvash\"),\n    (\"xcl\", \"Classical_Armenian\"),\n    (\"lzh\", \"Classical_Chinese\"),\n    (\"cop\", \"Coptic\"),\n    (\"kw\",  \"Cornish\"),\n    (\"co\",  \"Corsican\"),\n    (\"cr\",  \"Cree\"),\n    (\"hr\",  \"Croatian\"),\n    (\"cs\",  \"Czech\"),\n    (\"da\",  \"Danish\"),\n    (\"dar\", \"Dargwa\"),\n    (\"dv\",  \"Dhivehi\"),\n    (\"nl\",  \"Dutch\"),\n    (\"dz\",  \"Dzongkha\"),\n    (\"egy\", \"Egyptian\"),\n    (\"en\",  \"English\"),\n    (\"myv\", \"Erzya\"),\n    (\"eo\",  \"Esperanto\"),\n    (\"et\",  \"Estonian\"),\n    (\"ee\",  \"Ewe\"),\n    (\"ext\", \"Extremaduran\"),\n    (\"fo\",  \"Faroese\"),\n    (\"fj\",  \"Fijian\"),\n    (\"fi\",  \"Finnish\"),\n    (\"fon\", \"Fon\"),\n    (\"fr\",  \"French\"),\n    (\"qfn\", \"Frisian_Dutch\"),\n    (\"ff\",  \"Fulah\"),\n    (\"gl\",  \"Galician\"),\n    (\"lg\",  \"Ganda\"),\n    (\"ka\",  \"Georgian\"),\n    (\"de\",  \"German\"),\n    (\"aln\", \"Gheg\"),\n    (\"bbj\", \"Ghomálá'\"),\n    (\"got\", \"Gothic\"),\n    (\"el\",  \"Greek\"),\n    (\"kl\",  \"Greenlandic\"),\n    (\"gub\", \"Guajajara\"),\n    (\"gn\",  \"Guarani\"),\n    (\"gu\",  \"Gujarati\"),\n    (\"gwi\", \"Gwichin\"),\n    (\"ht\",  \"Haitian\"),\n    (\"ha\",  \"Hausa\"),\n    (\"he\",  \"Hebrew\"),\n    (\"hz\",  \"Herero\"),\n    (\"azz\", \"Highland_Puebla_Nahuatl\"),\n    (\"hil\", \"Hiligaynon\"),\n    (\"hi\",  \"Hindi\"),\n    (\"qhe\", \"Hindi_English\"),\n    (\"ho\",  \"Hiri_Motu\"),\n    (\"hit\", \"Hittite\"),\n    (\"hu\",  \"Hungarian\"),\n    (\"is\",  \"Icelandic\"),\n    (\"io\",  \"Ido\"),\n    (\"ig\",  \"Igbo\"),\n    (\"arh\", \"Ika\"),\n    (\"ilo\", \"Ilocano\"),\n    (\"arc\", \"Imperial_Aramaic\"),\n    (\"id\",  \"Indonesian\"),\n    (\"iu\",  \"Inuktitut\"),\n    (\"ik\",  \"Inupiaq\"),\n    (\"ga\",  
\"Irish\"),\n    (\"it\",  \"Italian\"),\n    (\"ja\",  \"Japanese\"),\n    (\"jv\",  \"Javanese\"),\n    (\"urb\", \"Kaapor\"),\n    (\"kab\", \"Kabyle\"),\n    (\"xnr\", \"Kangri\"),\n    (\"kn\",  \"Kannada\"),\n    (\"kr\",  \"Kanuri\"),\n    (\"pam\", \"Kapampangan\"),\n    (\"krl\", \"Karelian\"),\n    (\"arr\", \"Karo\"),\n    (\"ks\",  \"Kashmiri\"),\n    (\"kk\",  \"Kazakh\"),\n    (\"naq\", \"Khoekhoe\"),\n    (\"kfm\", \"Khunsari\"),\n    (\"quc\", \"Kiche\"),\n    (\"cgg\", \"Kiga\"),\n    (\"ki\",  \"Kikuyu\"),\n    (\"rw\",  \"Kinyarwanda\"),\n    (\"ky\",  \"Kyrgyz\"),\n    (\"kv\",  \"Komi\"),\n    (\"koi\", \"Komi_Permyak\"),\n    (\"kpv\", \"Komi_Zyrian\"),\n    (\"kg\",  \"Kongo\"),\n    (\"ko\",  \"Korean\"),\n    (\"ku\",  \"Kurdish\"),\n    (\"kmr\", \"Northern_Kurdish\"),\n    (\"kj\",  \"Kwanyama\"),\n    (\"lad\", \"Ladino\"),\n    (\"lo\",  \"Lao\"),\n    (\"ltg\", \"Latgalian\"),\n    (\"la\",  \"Latin\"),\n    (\"lv\",  \"Latvian\"),\n    (\"lij\", \"Ligurian\"),\n    (\"li\",  \"Limburgish\"),\n    (\"ln\",  \"Lingala\"),\n    (\"lt\",  \"Lithuanian\"),\n    (\"liv\", \"Livonian\"),\n    (\"olo\", \"Livvi\"),\n    (\"nds\", \"Low_Saxon\"),\n    (\"lu\",  \"Luba_Katanga\"),\n    (\"lb\",  \"Luxembourgish\"),\n    (\"mk\",  \"Macedonian\"),\n    (\"jaa\", \"Madi\"),\n    (\"mag\", \"Magahi\"),\n    (\"qaf\", \"Maghrebi_Arabic_French\"),\n    (\"mai\", \"Maithili\"),\n    (\"mpu\", \"Makurap\"),\n    (\"mg\",  \"Malagasy\"),\n    (\"ms\",  \"Malay\"),\n    (\"ml\",  \"Malayalam\"),\n    (\"mt\",  \"Maltese\"),\n    (\"mjl\", \"Mandyali\"),\n    (\"gv\",  \"Manx\"),\n    (\"mi\",  \"Maori\"),\n    (\"mr\",  \"Marathi\"),\n    (\"mh\",  \"Marshallese\"),\n    (\"mzn\", \"Mazandarani\"),\n    (\"gun\", \"Mbya_Guarani\"),\n    (\"enm\", \"Middle_English\"),\n    (\"frm\", \"Middle_French\"),\n    (\"min\", \"Minangkabau\"),\n    (\"xmf\", \"Mingrelian\"),\n    (\"mwl\", \"Mirandese\"),\n    (\"mdf\", \"Moksha\"),\n    (\"mn\",  
\"Mongolian\"),\n    (\"mos\", \"Mossi\"),\n    (\"myu\", \"Munduruku\"),\n    (\"my\",  \"Myanmar\"),\n    (\"nqo\", \"N'Ko\"),\n    (\"nmf\", \"Naga\"),\n    (\"nah\", \"Nahuatl\"),\n    (\"pcm\", \"Naija\"),\n    (\"na\",  \"Nauru\"),\n    (\"nv\",  \"Navajo\"),\n    (\"nyq\", \"Nayini\"),\n    (\"ng\",  \"Ndonga\"),\n    (\"nap\", \"Neapolitan\"),\n    (\"nrk\", \"Nenets\"),\n    (\"ne\",  \"Nepali\"),\n    (\"new\", \"Newar\"),\n    (\"yrl\", \"Nheengatu\"),\n    (\"nyn\", \"Nkore\"),\n    (\"frr\", \"North_Frisian\"),\n    (\"nd\",  \"North_Ndebele\"),\n    (\"sme\", \"North_Sami\"),\n    (\"nso\", \"Northern_Sotho\"),\n    (\"gya\", \"Northwest_Gbaya\"),\n    (\"nb\",  \"Norwegian_Bokmaal\"),\n    (\"nn\",  \"Norwegian_Nynorsk\"),\n    (\"ii\",  \"Nuosu\"),\n    (\"oc\",  \"Occitan\"),\n    (\"or\",  \"Odia\"),\n    (\"oj\",  \"Ojibwa\"),\n    (\"cu\",  \"Old_Church_Slavonic\"),\n    (\"orv\", \"Old_East_Slavic\"),\n    (\"ang\", \"Old_English\"),\n    (\"fro\", \"Old_French\"),\n    (\"sga\", \"Old_Irish\"),\n    (\"ojp\", \"Old_Japanese\"),\n    (\"pro\", \"Old_Occitan\"),\n    (\"otk\", \"Old_Turkish\"),\n    (\"om\",  \"Oromo\"),\n    (\"os\",  \"Ossetian\"),\n    (\"ota\", \"Ottoman_Turkish\"),\n    (\"pi\",  \"Pali\"),\n    (\"ps\",  \"Pashto\"),\n    (\"pad\", \"Paumari\"),\n    (\"fa\",  \"Persian\"),\n    (\"pay\", \"Pesh\"),\n    (\"xpg\", \"Phrygian\"),\n    (\"pbv\", \"Pnar\"),\n    (\"pl\",  \"Polish\"),\n    (\"qpm\", \"Pomak\"),\n    (\"pnt\", \"Pontic\"),\n    (\"pt\",  \"Portuguese\"),\n    (\"pra\", \"Prakrit\"),\n    (\"pa\",  \"Punjabi\"),\n    (\"qu\",  \"Quechua\"),\n    (\"rhg\", \"Rohingya\"),\n    (\"ro\",  \"Romanian\"),\n    (\"rm\",  \"Romansh\"),\n    (\"rn\",  \"Rundi\"),\n    (\"ru\",  \"Russian\"),\n    (\"sm\",  \"Samoan\"),\n    (\"sg\",  \"Sango\"),\n    (\"sa\",  \"Sanskrit\"),\n    (\"skr\", \"Saraiki\"),\n    (\"sc\",  \"Sardinian\"),\n    (\"sco\", \"Scots\"),\n    (\"gd\",  \"Scottish_Gaelic\"),\n    (\"sr\",  
\"Serbian\"),\n    (\"wuu\", \"Shanghainese\"),\n    (\"sn\",  \"Shona\"),\n    (\"zh-hans\", \"Simplified_Chinese\"),\n    (\"scn\", \"Sicilian\"),\n    (\"sd\",  \"Sindhi\"),\n    (\"si\",  \"Sinhala\"),\n    (\"sms\", \"Skolt_Sami\"),\n    (\"sk\",  \"Slovak\"),\n    (\"sl\",  \"Slovenian\"),\n    (\"soj\", \"Soi\"),\n    (\"so\",  \"Somali\"),\n    (\"ckb\", \"Sorani\"),\n    (\"ajp\", \"South_Levantine_Arabic\"),\n    (\"sdh\", \"Southern_Kurdish\"),\n    (\"nr\",  \"South_Ndebele\"),\n    (\"st\",  \"Southern_Sotho\"),\n    (\"es\",  \"Spanish\"),\n    (\"ssp\", \"Spanish_Sign_Language\"),\n    (\"su\",  \"Sundanese\"),\n    (\"sw\",  \"Swahili\"),\n    (\"ss\",  \"Swati\"),\n    (\"sv\",  \"Swedish\"),\n    (\"swl\", \"Swedish_Sign_Language\"),\n    (\"gsw\", \"Swiss_German\"),\n    (\"syr\", \"Syriac\"),\n    (\"tl\",  \"Tagalog\"),\n    (\"ty\",  \"Tahitian\"),\n    (\"tg\",  \"Tajik\"),\n    (\"ta\",  \"Tamil\"),\n    (\"tt\",  \"Tatar\"),\n    (\"eme\", \"Teko\"),\n    (\"te\",  \"Telugu\"),\n    (\"qte\", \"Telugu_English\"),\n    (\"th\",  \"Thai\"),\n    (\"bo\",  \"Tibetan\"),\n    (\"ti\",  \"Tigrinya\"),\n    (\"to\",  \"Tonga\"),\n    (\"zh-hant\", \"Traditional_Chinese\"),\n    (\"ts\",  \"Tsonga\"),\n    (\"tn\",  \"Tswana\"),\n    (\"tpn\", \"Tupinamba\"),\n    (\"tr\",  \"Turkish\"),\n    (\"qti\", \"Turkish_English\"),\n    (\"qtd\", \"Turkish_German\"),\n    (\"tk\",  \"Turkmen\"),\n    (\"tw\",  \"Twi\"),\n    (\"uk\",  \"Ukrainian\"),\n    (\"xum\", \"Umbrian\"),\n    (\"hsb\", \"Upper_Sorbian\"),\n    (\"ur\",  \"Urdu\"),\n    (\"ug\",  \"Uyghur\"),\n    (\"uz\",  \"Uzbek\"),\n    (\"ve\",  \"Venda\"),\n    (\"vep\", \"Veps\"),\n    (\"vi\",  \"Vietnamese\"),\n    (\"vo\",  \"Volapük\"),\n    (\"wa\",  \"Walloon\"),\n    (\"war\", \"Waray\"),\n    (\"wbp\", \"Warlpiri\"),\n    (\"cy\",  \"Welsh\"),\n    (\"hyw\", \"Western_Armenian\"),\n    (\"fy\",  \"Western_Frisian\"),\n    (\"nhi\", \"Western_Sierra_Puebla_Nahuatl\"),\n    (\"wo\",  
\"Wolof\"),\n    (\"xav\", \"Xavante\"),\n    (\"xh\",  \"Xhosa\"),\n    (\"sjo\", \"Xibe\"),\n    (\"sah\", \"Yakut\"),\n    (\"yi\",  \"Yiddish\"),\n    (\"yo\",  \"Yoruba\"),\n    (\"ess\", \"Yupik\"),\n    (\"say\", \"Zaar\"),\n    (\"zza\", \"Zazaki\"),\n    (\"zea\", \"Zeelandic\"),\n    (\"za\",  \"Zhuang\"),\n    (\"zu\",  \"Zulu\"),\n]\n\n# build the dictionary, checking for duplicate language codes\nlcode2lang = {}\nfor code, language in lcode2lang_raw:\n    assert code not in lcode2lang\n    lcode2lang[code] = language\n\n# invert the dictionary, checking for possible duplicate language names\nlang2lcode = {}\nfor code, language in lcode2lang_raw:\n    assert language not in lang2lcode\n    lang2lcode[language] = code\n\n# check that nothing got clobbered\nassert len(lcode2lang_raw) == len(lcode2lang)\nassert len(lcode2lang_raw) == len(lang2lcode)\n\n# some of the two letter langcodes get used elsewhere as three letters\n# for example, Wolof is abbreviated \"wo\" in UD, but \"wol\" in Masakhane NER\ntwo_to_three_letters_raw = (\n    (\"bm\",  \"bam\"),\n    (\"ee\",  \"ewe\"),\n    (\"ha\",  \"hau\"),\n    (\"ig\",  \"ibo\"),\n    (\"rw\",  \"kin\"),\n    (\"lg\",  \"lug\"),\n    (\"ny\",  \"nya\"),\n    (\"sn\",  \"sna\"),\n    (\"sw\",  \"swa\"),\n    (\"tn\",  \"tsn\"),\n    (\"tw\",  \"twi\"),\n    (\"wo\",  \"wol\"),\n    (\"xh\",  \"xho\"),\n    (\"yo\",  \"yor\"),\n    (\"zu\",  \"zul\"),\n\n    # this is a weird case where a 2 letter code was available,\n    # but UD used the 3 letter code instead\n    (\"se\",  \"sme\"),\n)\n\nfor two, three in two_to_three_letters_raw:\n    if two in lcode2lang:\n        assert two in lcode2lang\n        assert three not in lcode2lang\n        assert three not in lang2lcode\n        lang2lcode[three] = two\n        lcode2lang[three] = lcode2lang[two]\n    elif three in lcode2lang:\n        assert three in lcode2lang\n        assert two not in lcode2lang\n        assert two not in lang2lcode\n        
lang2lcode[two] = three\n        lcode2lang[two] = lcode2lang[three]\n    else:\n        raise AssertionError(\"Found a proposed alias %s -> %s when neither code was already known\" % (two, three))\n\ntwo_to_three_letters = {\n    two: three for two, three in two_to_three_letters_raw\n}\n\nthree_to_two_letters = {\n    three: two for two, three in two_to_three_letters_raw\n}\n\nassert len(two_to_three_letters) == len(two_to_three_letters_raw)\nassert len(three_to_two_letters) == len(two_to_three_letters_raw)\n\n# additional useful code to language mapping\n# added after dict invert to avoid conflict\nlcode2lang['bgd'] = 'Bokota'   # ISO 639-3 code, although UD used sab\nlcode2lang['nb'] = 'Norwegian' # Norwegian Bokmaal mapped to default Norwegian\nlcode2lang['no'] = 'Norwegian'\nlcode2lang['zh'] = 'Simplified_Chinese'\n\nextra_lang_to_lcodes = {\n    \"ab\":  \"Abkhaz\",\n    \"gsw\": \"Alemannic\",\n    \"my\":  \"Burmese\",\n    \"ckb\": \"Central_Kurdish\",\n    \"ny\":  \"Chewa\",\n    \"zh\":  \"Chinese\",\n    \"za\":  \"Chuang\",\n    \"dv\":  \"Divehi\",\n    \"eme\": \"Emerillon\",\n    \"lij\": \"Genoese\",\n    \"ga\":  \"Gaelic\",\n    \"ne\":  \"Gorkhali\",\n    \"ht\":  \"Haitian_Creole\",\n    \"ilo\": \"Ilokano\",\n    \"nr\":  \"isiNdebele\",\n    \"xh\":  \"isiXhosa\",\n    \"zu\":  \"isiZulu\",\n    \"jaa\": \"Jamamadí\",\n    \"kab\": \"Kabylian\",\n    \"kl\":  \"Kalaallisut\",\n    \"km\":  \"Khmer\",\n    \"ky\":  \"Kirghiz\",\n    \"lb\":  \"Letzeburgesch\",\n    \"lg\":  \"Luganda\",\n    \"jaa\": \"Madí\",\n    \"dv\":  \"Maldivian\",\n    \"mjl\": \"Mandeali\",\n    \"skr\": \"Multani\",\n    \"nb\":  \"Norwegian\",\n    \"kmr\": \"Kurmanji\",\n    \"ny\":  \"Nyanja\",\n    \"sga\": \"Old_Gaelic\",\n    \"or\":  \"Oriya\",\n    \"arr\": \"Ramarama\",\n    \"sah\": \"Sakha\",\n    \"nso\": \"Sepedi\",\n    \"tn\":  \"Setswana\",\n    \"ii\":  \"Sichuan_Yi\",\n    \"si\":  \"Sinhalese\",\n    \"ss\":  \"Siswati\",\n    \"soj\": \"Sohi\",\n 
   \"st\":  \"Sesotho\",\n    \"ve\":  \"Tshivenda\",\n    \"ts\":  \"Xitsonga\",\n    \"fy\":  \"West_Frisian\",\n    \"zza\": \"Zaza\",\n}\n\nfor code, language in extra_lang_to_lcodes.items():\n    assert language not in lang2lcode\n    assert code in lcode2lang\n    lang2lcode[language] = code\n\n# treebank names changed from Old Russian to Old East Slavic in 2.8\nlang2lcode['Old_Russian'] = 'orv'\n\n# build a lowercase map from language to langcode\nlanglower2lcode = {}\nfor k in lang2lcode:\n    langlower2lcode[k.lower()] = lang2lcode[k]\n\ntreebank_special_cases = {\n    \"UD_Chinese-Beginner\": \"zh-hans_beginner\",\n    \"UD_Chinese-GSDSimp\": \"zh-hans_gsdsimp\",\n    \"UD_Chinese-GSD\": \"zh-hant_gsd\",\n    \"UD_Chinese-HK\": \"zh-hant_hk\",\n    \"UD_Chinese-CFL\": \"zh-hans_cfl\",\n    \"UD_Chinese-PatentChar\": \"zh-hans_patentchar\",\n    \"UD_Chinese-PUD\": \"zh-hant_pud\",\n    \"UD_Norwegian-Bokmaal\": \"nb_bokmaal\",\n    \"UD_Norwegian-Nynorsk\": \"nn_nynorsk\",\n    \"UD_Norwegian-NynorskLIA\": \"nn_nynorsklia\",\n}\n\nSHORTNAME_RE = re.compile(\"^[a-z-]+_[a-z0-9-_]+$\")\n\ndef langcode_to_lang(lcode):\n    if lcode in lcode2lang:\n        return lcode2lang[lcode]\n    elif lcode.lower() in lcode2lang:\n        return lcode2lang[lcode.lower()]\n    else:\n        return lcode\n\ndef pretty_langcode_to_lang(lcode):\n    lang = langcode_to_lang(lcode)\n    lang = lang.replace(\"_\", \" \")\n    if lang == 'Simplified Chinese':\n        lang = 'Chinese (Simplified)'\n    elif lang == 'Traditional Chinese':\n        lang = 'Chinese (Traditional)'\n    return lang\n\ndef lang_to_langcode(lang):\n    if lang in lang2lcode:\n        lcode = lang2lcode[lang]\n    elif lang.lower() in langlower2lcode:\n        lcode = langlower2lcode[lang.lower()]\n    elif lang in lcode2lang:\n        lcode = lang\n    elif lang.lower() in lcode2lang:\n        lcode = lang.lower()\n    else:\n        raise UnknownLanguageError(\"Unable to find language code for %s\" % 
lang)\n    return lcode\n\nRIGHT_TO_LEFT = set([\"ar\", \"arc\", \"az\", \"ckb\", \"dv\", \"ff\", \"he\", \"ku\", \"mzn\", \"nqo\", \"ps\", \"fa\", \"rhg\", \"sd\", \"syr\", \"ur\"])\n\ndef is_right_to_left(lang):\n    \"\"\"\n    Covers all the RtL languages we support, as well as many we don't.\n\n    If a language is left out, please let us know!\n    \"\"\"\n    lcode = lang_to_langcode(lang)\n    return lcode in RIGHT_TO_LEFT\n\ndef treebank_to_short_name(treebank):\n    \"\"\" Convert treebank name to short code. \"\"\"\n    if treebank in treebank_special_cases:\n        return treebank_special_cases.get(treebank)\n    if SHORTNAME_RE.match(treebank):\n        lang, corpus = treebank.split(\"_\", 1)\n        lang = lang_to_langcode(lang)\n        return lang + \"_\" + corpus\n\n    if treebank.startswith('UD_'):\n        treebank = treebank[3:]\n    # special case starting with zh in case the input is an already-converted ZH treebank\n    if treebank.startswith(\"zh-hans\") or treebank.startswith(\"zh-hant\"):\n        splits = (treebank[:len(\"zh-hans\")], treebank[len(\"zh-hans\")+1:])\n    else:\n        splits = treebank.split('-')\n        if len(splits) == 1:\n            splits = treebank.split(\"_\", 1)\n    assert len(splits) == 2, \"Unable to process %s\" % treebank\n    lang, corpus = splits\n\n    lcode = lang_to_langcode(lang)\n\n    short = \"{}_{}\".format(lcode, corpus.lower())\n    return short\n\ndef treebank_to_langid(treebank):\n    \"\"\" Convert treebank name to langid \"\"\"\n    short_name = treebank_to_short_name(treebank)\n    return short_name.split(\"_\")[0]\n\n"
  },
  {
    "path": "stanza/models/common/convert_pretrain.py",
    "content": "\"\"\"\nA utility script to load a word embedding file from a text file and save it as a .pt\n\nRun it as follows:\n  python stanza/models/common/convert_pretrain.py <.pt file> <text file> <# vectors>\n\nNote that -1 for # of vectors will keep all the vectors.\nYou probably want to keep fewer than that for most publicly released\nembeddings, though, as they can get quite large.\n\nAs a concrete example, you can convert a newly downloaded Faroese WV file as follows:\n  python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/fo_farpahc.pretrain.pt ~/extern_data/wordvec/fasttext/faroese.txt -1\nor save part of an Icelandic WV file:\n  python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/is_icepahc.pretrain.pt ~/extern_data/wordvec/fasttext/icelandic.cc.is.300.vec 150000\nNote that if the pretrain already exists, nothing will be changed.  It will not overwrite an existing .pt file.\n\n\"\"\"\n\nimport argparse\nimport os\nimport sys\n\nfrom stanza.models.common import pretrain\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"output_pt\", default=None, help=\"Where to write the converted PT file\")\n    parser.add_argument(\"input_vec\", default=None, help=\"Unconverted vectors file\")\n    parser.add_argument(\"max_vocab\", type=int, default=-1, nargs=\"?\", help=\"How many vectors to convert.  -1 means convert them all\")\n    args = parser.parse_args()\n\n    if os.path.exists(args.output_pt):\n        print(\"Not overwriting existing pretrain file in %s\" % args.output_pt)\n\n    if args.input_vec.endswith(\".csv\"):\n        pt = pretrain.Pretrain(args.output_pt, max_vocab=args.max_vocab, csv_filename=args.input_vec)\n    else:\n        pt = pretrain.Pretrain(args.output_pt, args.input_vec, max_vocab=args.max_vocab)\n    print(\"Pretrain is of size {}\".format(len(pt.vocab)))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/common/count_ner_coverage.py",
    "content": "from stanza.models.common import pretrain\nimport argparse\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on')\n    parser.add_argument('--pretrain', type=str, default=\"/home/john/stanza_resources/hi/pretrain/hdtb.pt\", help='Which pretrain to use')\n    parser.set_defaults(ners=[\"/home/john/stanza/data/ner/hi_fire2013.train.csv\",\n                              \"/home/john/stanza/data/ner/hi_fire2013.dev.csv\"])\n    args = parser.parse_args()\n    return args\n\n\ndef read_ner(filename):\n    words = []\n    for line in open(filename).readlines():\n        line = line.strip()\n        if not line:\n            continue\n        if line.split(\"\\t\")[1] == 'O':\n            continue\n        words.append(line.split(\"\\t\")[0])\n    return words\n\ndef count_coverage(pretrain, words):\n    count = 0\n    for w in words:\n        if w in pretrain.vocab:\n            count = count + 1\n    return count / len(words)\n\nargs = parse_args()\npt = pretrain.Pretrain(args.pretrain)\nfor dataset in args.ners:\n    words = read_ner(dataset)\n    print(dataset)\n    print(count_coverage(pt, words))\n    print()\n"
  },
  {
    "path": "stanza/models/common/count_pretrain_coverage.py",
    "content": "\"\"\"A simple script to count the fraction of words in a UD dataset which are in a particular pretrain.\n\nFor example, this script shows that the word2vec Armenian vectors,\ntruncated at 250K words, have 75% coverage of the Western Armenian\ndataset, whereas the vectors available here have 88% coverage:\n\nhttps://github.com/ispras-texterra/word-embeddings-eval-hy\n\"\"\"\n\nfrom stanza.models.common import pretrain\nfrom stanza.utils.conll import CoNLL\n\nimport argparse\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('treebanks', type=str, nargs='*', help='Which treebanks to run on')\n    parser.add_argument('--pretrain', type=str, default=\"/home/john/extern_data/wordvec/glove/armenian.pt\", help='Which pretrain to use')\n    parser.set_defaults(treebanks=[\"/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Western_Armenian-ArmTDP/hyw_armtdp-ud-train.conllu\",\n                                   \"/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu\"])\n    args = parser.parse_args()\n    return args\n\n\nargs = parse_args()\npt = pretrain.Pretrain(args.pretrain)\npt.load()\nprint(\"Pretrain stats: {} vectors, {} dim\".format(len(pt.vocab), pt.emb[0].shape[0]))\n\nfor treebank in args.treebanks:\n    print(treebank)\n    found = 0\n    total = 0\n    doc = CoNLL.conll2doc(treebank)\n    for sentence in doc.sentences:\n        for word in sentence.words:\n            total = total + 1\n            if word.text in pt.vocab:\n                found = found + 1\n\n    print (found / total)\n"
  },
  {
    "path": "stanza/models/common/crf.py",
    "content": "\"\"\"\nCRF loss and viterbi decoding.\n\"\"\"\n\nimport math\nfrom numbers import Number\nimport numpy as np\nimport torch\nfrom torch import nn\nimport torch.nn.init as init\n\nclass CRFLoss(nn.Module):\n    \"\"\"\n    Calculate log-space crf loss, given unary potentials, a transition matrix\n    and gold tag sequences.\n    \"\"\"\n    def __init__(self, num_tag, batch_average=True):\n        super().__init__()\n        self._transitions = nn.Parameter(torch.zeros(num_tag, num_tag))\n        self._batch_average = batch_average # if not batch average, average on all tokens\n\n    def forward(self, inputs, masks, tag_indices):\n        \"\"\"\n        inputs: batch_size x seq_len x num_tags\n        masks: batch_size x seq_len\n        tag_indices: batch_size x seq_len\n        \n        @return:\n            loss: CRF negative log likelihood on all instances.\n            transitions: the transition matrix\n        \"\"\"\n        # TODO: handle <start> and <end> tags\n        input_bs, input_sl, input_nc = inputs.size()\n        unary_scores = self.crf_unary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)\n        binary_scores = self.crf_binary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)\n        log_norm = self.crf_log_norm(inputs, masks, tag_indices)\n        log_likelihood = unary_scores + binary_scores - log_norm # batch_size\n        loss = torch.sum(-log_likelihood)\n        if self._batch_average:\n            loss = loss / input_bs\n        else:\n            total = masks.eq(0).sum()\n            loss = loss / (total + 1e-8)\n        return loss, self._transitions\n\n    def crf_unary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):\n        \"\"\"\n        @return:\n            unary_scores: batch_size\n        \"\"\"\n        flat_inputs = inputs.view(input_bs, -1)\n        flat_tag_indices = tag_indices + torch.arange(input_sl, device=tag_indices.device).long().unsqueeze(0) * 
input_nc\n        unary_scores = torch.gather(flat_inputs, 1, flat_tag_indices).view(input_bs, -1)\n        unary_scores.masked_fill_(masks, 0)\n        return unary_scores.sum(dim=1)\n    \n    def crf_binary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):\n        \"\"\"\n        @return:\n            binary_scores: batch_size\n        \"\"\"\n        # get number of transitions\n        nt = tag_indices.size(-1) - 1\n        start_indices = tag_indices[:, :nt]\n        end_indices = tag_indices[:, 1:]\n        # flat matrices\n        flat_transition_indices = start_indices * input_nc + end_indices\n        flat_transition_indices = flat_transition_indices.view(-1)\n        flat_transition_matrix = self._transitions.view(-1)\n        binary_scores = torch.gather(flat_transition_matrix, 0, flat_transition_indices)\\\n                .view(input_bs, -1)\n        score_masks = masks[:, 1:]\n        binary_scores.masked_fill_(score_masks, 0)\n        return binary_scores.sum(dim=1)\n\n    def crf_log_norm(self, inputs, masks, tag_indices):\n        \"\"\"\n        Calculate the CRF partition in log space for each instance, following:\n            http://www.cs.columbia.edu/~mcollins/fb.pdf\n        @return:\n            log_norm: batch_size\n        \"\"\"\n        start_inputs = inputs[:,0,:] # bs x nc\n        rest_inputs = inputs[:,1:,:]\n        # TODO: technically we need to pay attention to the initial\n        # value being masked.  
Currently we do compensate for the\n        # entire row being masked at the end of the operation\n        rest_masks = masks[:,1:]\n        alphas = start_inputs # bs x nc\n        trans = self._transitions.unsqueeze(0) # 1 x nc x nc\n        # accumulate alphas in log space\n        for i in range(rest_inputs.size(1)):\n            transition_scores = alphas.unsqueeze(2) + trans # bs x nc x nc\n            new_alphas = rest_inputs[:,i,:] + log_sum_exp(transition_scores, dim=1)\n            m = rest_masks[:,i].unsqueeze(1).expand_as(new_alphas) # bs x nc, 1 for padding idx\n            # apply masks\n            new_alphas.masked_scatter_(m, alphas.masked_select(m))\n            alphas = new_alphas\n        log_norm = log_sum_exp(alphas, dim=1)\n\n        # if any row was entirely masked, we just turn its log denominator to 0\n        # eg, the empty summation for the denominator will be 1, and its log will be 0\n        all_masked = torch.all(masks, dim=1)\n        log_norm = log_norm * torch.logical_not(all_masked)\n        return log_norm\n\ndef viterbi_decode(scores, transition_params):\n    \"\"\"\n    Decode a tag sequence with viterbi algorithm.\n    scores: seq_len x num_tags (numpy array)\n    transition_params: num_tags x num_tags (numpy array)\n    @return:\n        viterbi: a list of tag ids with highest score\n        viterbi_score: the highest score\n    \"\"\"\n    trellis = np.zeros_like(scores)\n    backpointers = np.zeros_like(scores, dtype=np.int32)\n    trellis[0] = scores[0]\n\n    for t in range(1, scores.shape[0]):\n        v = np.expand_dims(trellis[t-1], 1) + transition_params\n        trellis[t] = scores[t] + np.max(v, 0)\n        backpointers[t] = np.argmax(v, 0)\n\n    viterbi = [np.argmax(trellis[-1])]\n    for bp in reversed(backpointers[1:]):\n        viterbi.append(bp[viterbi[-1]])\n    viterbi.reverse()\n    viterbi_score = np.max(trellis[-1])\n    return viterbi, viterbi_score\n\ndef log_sum_exp(value, dim=None, keepdim=False):\n  
  \"\"\"Numerically stable implementation of the operation\n    value.exp().sum(dim, keepdim).log()\n    \"\"\"\n    if dim is not None:\n        m, _ = torch.max(value, dim=dim, keepdim=True)\n        value0 = value - m\n        if keepdim is False:\n            m = m.squeeze(dim)\n        return m + torch.log(torch.sum(torch.exp(value0),\n                                       dim=dim, keepdim=keepdim))\n    else:\n        m = torch.max(value)\n        sum_exp = torch.sum(torch.exp(value - m))\n        if isinstance(sum_exp, Number):\n            return m + math.log(sum_exp)\n        else:\n            return m + torch.log(sum_exp)\n"
  },
  {
    "path": "stanza/models/common/data.py",
    "content": "\"\"\"\nUtility functions for data transformations.\n\"\"\"\n\nimport logging\nimport random\n\nimport torch\n\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.doc import HEAD, ID, UPOS\n\nlogger = logging.getLogger('stanza')\n\ndef map_to_ids(tokens, vocab):\n    ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]\n    return ids\n\ndef get_long_tensor(tokens_list, batch_size, pad_id=constant.PAD_ID):\n    \"\"\" Convert (list of )+ tokens to a padded LongTensor. \"\"\"\n    sizes = []\n    x = tokens_list\n    while isinstance(x[0], list):\n        sizes.append(max(len(y) for y in x))\n        x = [z for y in x for z in y]\n    # TODO: pass in a device parameter and put it directly on the relevant device?\n    # that might be faster than creating it and then moving it\n    tokens = torch.LongTensor(batch_size, *sizes).fill_(pad_id)\n    for i, s in enumerate(tokens_list):\n        tokens[i, :len(s)] = torch.LongTensor(s)\n    return tokens\n\ndef get_float_tensor(features_list, batch_size):\n    if features_list is None or features_list[0] is None:\n        return None\n    seq_len = max(len(x) for x in features_list)\n    feature_len = len(features_list[0][0])\n    features = torch.FloatTensor(batch_size, seq_len, feature_len).zero_()\n    for i,f in enumerate(features_list):\n        features[i,:len(f),:] = torch.FloatTensor(f)\n    return features\n\ndef sort_all(batch, lens):\n    \"\"\" Sort all fields by descending order of lens, and return the original indices. 
\"\"\"\n    if batch == [[]]:\n        return [[]], []\n    unsorted_all = [lens] + [range(len(lens))] + list(batch)\n    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]\n    return sorted_all[2:], sorted_all[1]\n\ndef get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate, desired_ratio=0.1, max_ratio=0.5):\n    \"\"\"\n    Returns X so that if you randomly select X * N sentences, you get 10%\n\n    The ratio will be chosen in the assumption that the final dataset\n    is of size N rather than N + X * N.\n\n    should_augment_predicate: returns True if the sentence has some\n      feature which we may want to change occasionally.  for example,\n      depparse sentences which end in punct\n    can_augment_predicate: in the depparse sentences example, it is\n      technically possible for the punct at the end to be the parent\n      of some other word in the sentence.  in that case, the sentence\n      should not be chosen.  should be at least as restrictive as\n      should_augment_predicate\n    \"\"\"\n    n_data = len(train_data)\n    n_should_augment = sum(should_augment_predicate(sentence) for sentence in train_data)\n    n_can_augment = sum(can_augment_predicate(sentence) for sentence in train_data)\n    n_error = sum(can_augment_predicate(sentence) and not should_augment_predicate(sentence)\n                  for sentence in train_data)\n    if n_error > 0:\n        raise AssertionError(\"can_augment_predicate allowed sentences not allowed by should_augment_predicate\")\n\n    if n_can_augment == 0:\n        logger.warning(\"Found no sentences which matched can_augment_predicate {}\".format(can_augment_predicate))\n        return 0.0\n    n_needed = n_data * desired_ratio - (n_data - n_should_augment)\n    # if we want 10%, for example, and more than 10% already matches, we can skip\n    if n_needed < 0:\n        return 0.0\n    ratio = n_needed / n_can_augment\n    if ratio > max_ratio:\n        return 
max_ratio\n    return ratio\n\n\ndef should_augment_nopunct_predicate(sentence):\n    last_word = sentence[-1]\n    return last_word.get(UPOS, None) == 'PUNCT'\n\ndef can_augment_nopunct_predicate(sentence):\n    \"\"\"\n    Check that the sentence ends with PUNCT and also doesn't have any words which depend on the last word\n    \"\"\"\n    last_word = sentence[-1]\n    if last_word.get(UPOS, None) != 'PUNCT':\n        return False\n    # don't cut off MWT\n    if len(last_word[ID]) > 1:\n        return False\n    if any(len(word[ID]) == 1 and word[HEAD] == last_word[ID][0] for word in sentence):\n        return False\n    return True\n\ndef augment_punct(train_data, augment_ratio,\n                  should_augment_predicate=should_augment_nopunct_predicate,\n                  can_augment_predicate=can_augment_nopunct_predicate,\n                  keep_original_sentences=True):\n\n    \"\"\"\n    Adds extra training data to compensate for some models having all sentences end with PUNCT\n\n    Some of the models (for example, UD_Hebrew-HTB) have the flaw that\n    all of the training sentences end with PUNCT.  The model therefore\n    learns to finish every sentence with punctuation, even if it is\n    given a sentence with non-punct at the end.\n\n    One simple way to fix this is to train on some fraction of training data with the final punct removed.\n\n    Params:\n    train_data: list of list of dicts, eg a conll doc\n    augment_ratio: the fraction to augment.  
if None, a best guess is made to get to 10%\n\n    should_augment_predicate: a function which returns T/F if a sentence ends with PUNCT and so is a candidate for augmentation\n    can_augment_predicate: a function which returns T/F if it makes sense to remove the last PUNCT\n\n    TODO: do this dynamically, as part of the DataLoader or elsewhere?\n    One complication is the data comes back from the DataLoader as\n    tensors & indices, so it is much more complicated to manipulate\n    \"\"\"\n    if len(train_data) == 0:\n        return []\n\n    if augment_ratio is None:\n        augment_ratio = get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate)\n\n    if augment_ratio <= 0:\n        if keep_original_sentences:\n            return list(train_data)\n        else:\n            return []\n\n    new_data = []\n    for sentence in train_data:\n        if can_augment_predicate(sentence):\n            if random.random() < augment_ratio and len(sentence) > 1:\n                # todo: could deep copy the words\n                #       or not deep copy any of this\n                new_sentence = list(sentence[:-1])\n                new_data.append(new_sentence)\n            elif keep_original_sentences:\n                # keep the unmodified sentence; new_sentence may be unset or stale here\n                new_data.append(sentence)\n\n    return new_data\n"
  },
  {
    "path": "stanza/models/common/doc.py",
    "content": "\"\"\"\nBasic data structures\n\"\"\"\n\nimport io\nfrom itertools import repeat\nimport re\nimport json\nimport pickle\nimport warnings\n\nfrom enum import Enum\n\nimport networkx as nx\n\nfrom stanza.models.common.stanza_object import StanzaObject\nfrom stanza.models.common.utils import misc_to_space_after, space_after_to_misc, misc_to_space_before, space_before_to_misc\nfrom stanza.models.ner.utils import decode_from_bioes\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.coref.coref_chain import CorefMention, CorefChain, CorefAttachment\n\nclass MWTProcessingType(Enum):\n    FLATTEN = 0 # flatten the current token into one ID instead of MWT\n    PROCESS = 1 # process the current token as an MWT and expand it as such\n    SKIP = 2 # do nothing on this token, simply increment IDs\n\nmulti_word_token_id = re.compile(r\"([0-9]+)-([0-9]+)\")\nmulti_word_token_misc = re.compile(r\".*MWT=Yes.*\")\n\nMEXP = 'manual_expansion'\nID = 'id'\nTEXT = 'text'\nLEMMA = 'lemma'\nUPOS = 'upos'\nXPOS = 'xpos'\nFEATS = 'feats'\nHEAD = 'head'\nDEPREL = 'deprel'\nDEPS = 'deps'\nMISC = 'misc'\nNER = 'ner'\nMULTI_NER = 'multi_ner'     # will represent tags from multiple NER models\nSTART_CHAR = 'start_char'\nEND_CHAR = 'end_char'\nTYPE = 'type'\nSENTIMENT = 'sentiment'\nCONSTITUENCY = 'constituency'\nCOREF_CHAINS = 'coref_chains'\nLINE_NUMBER = 'line_number'\nMORPHEMES = 'morphemes'\n\n# field indices when converting the document to conll\nFIELD_TO_IDX = {ID: 0, TEXT: 1, LEMMA: 2, UPOS: 3, XPOS: 4, FEATS: 5, HEAD: 6, DEPREL: 7, DEPS: 8, MISC: 9}\nFIELD_NUM = len(FIELD_TO_IDX)\n\nDEFAULT_OUTPUT_FIELDS = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, MULTI_NER, MEXP, COREF_CHAINS, MORPHEMES]\nNO_OFFSETS_OUTPUT_FIELDS = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, NER, MULTI_NER, MEXP, COREF_CHAINS, MORPHEMES]\n\nclass DocJSONEncoder(json.JSONEncoder):\n    def default(self, obj):\n        
if isinstance(obj, CorefMention):\n            return obj.__dict__\n        if isinstance(obj, CorefAttachment):\n            return obj.to_json()\n        return json.JSONEncoder.default(self, obj)\n\nclass Document(StanzaObject):\n    \"\"\" A document class that stores attributes of a document and carries a list of sentences.\n    \"\"\"\n\n    def __init__(self, sentences, text=None, comments=None, empty_sentences=None):\n        \"\"\" Construct a document given a list of sentences in the form of lists of CoNLL-U dicts.\n\n        Args:\n            sentences: a list of sentences, which being a list of token entry, in the form of a CoNLL-U dict.\n            text: the raw text of the document.\n            comments: A list of list of strings to use as comments on the sentences, either None or the same length as sentences\n        \"\"\"\n        self._sentences = []\n        self._lang = None\n        self._text = text\n        self._num_tokens = 0\n        self._num_words = 0\n\n        self._process_sentences(sentences, comments, empty_sentences)\n        self._ents = []\n        self._coref = []\n        if self._text is not None:\n            self.build_ents()\n            self.mark_whitespace()\n\n    def mark_whitespace(self):\n        for sentence in self._sentences:\n            # TODO: pairwise, once we move to minimum 3.10\n            for prev_token, next_token in zip(sentence.tokens[:-1], sentence.tokens[1:]):\n                whitespace = self._text[prev_token.end_char:next_token.start_char]\n                prev_token.spaces_after = whitespace\n        for prev_sentence, next_sentence in zip(self._sentences[:-1], self._sentences[1:]):\n            prev_token = prev_sentence.tokens[-1]\n            next_token = next_sentence.tokens[0]\n            whitespace = self._text[prev_token.end_char:next_token.start_char]\n            prev_token.spaces_after = whitespace\n        if len(self._sentences) > 0 and len(self._sentences[-1].tokens) > 0:\n        
    final_token = self._sentences[-1].tokens[-1]\n            whitespace = self._text[final_token.end_char:]\n            final_token.spaces_after = whitespace\n        if len(self._sentences) > 0 and len(self._sentences[0].tokens) > 0:\n            first_token = self._sentences[0].tokens[0]\n            whitespace = self._text[:first_token.start_char]\n            first_token.spaces_before = whitespace\n\n\n    @property\n    def lang(self):\n        \"\"\" Access the language of this document \"\"\"\n        return self._lang\n\n    @lang.setter\n    def lang(self, value):\n        \"\"\" Set the language of this document \"\"\"\n        self._lang = value\n\n    @property\n    def text(self):\n        \"\"\" Access the raw text for this document. \"\"\"\n        return self._text\n\n    @text.setter\n    def text(self, value):\n        \"\"\" Set the raw text for this document. \"\"\"\n        self._text = value\n\n    @property\n    def sentences(self):\n        \"\"\" Access the list of sentences for this document. \"\"\"\n        return self._sentences\n\n    @sentences.setter\n    def sentences(self, value):\n        \"\"\" Set the list of tokens for this document. \"\"\"\n        self._sentences = value\n\n    @property\n    def num_tokens(self):\n        \"\"\" Access the number of tokens for this document. \"\"\"\n        return self._num_tokens\n\n    @num_tokens.setter\n    def num_tokens(self, value):\n        \"\"\" Set the number of tokens for this document. \"\"\"\n        self._num_tokens = value\n\n    @property\n    def num_words(self):\n        \"\"\" Access the number of words for this document. \"\"\"\n        return self._num_words\n\n    @num_words.setter\n    def num_words(self, value):\n        \"\"\" Set the number of words for this document. \"\"\"\n        self._num_words = value\n\n    @property\n    def ents(self):\n        \"\"\" Access the list of entities in this document. 
\"\"\"\n        return self._ents\n\n    @ents.setter\n    def ents(self, value):\n        \"\"\" Set the list of entities in this document. \"\"\"\n        self._ents = value\n\n    @property\n    def entities(self):\n        \"\"\" Access the list of entities. This is just an alias of `ents`. \"\"\"\n        return self._ents\n\n    @entities.setter\n    def entities(self, value):\n        \"\"\" Set the list of entities in this document. \"\"\"\n        self._ents = value\n\n    def _process_sentences(self, sentences, comments=None, empty_sentences=None):\n        self.sentences = []\n        if empty_sentences is None:\n            empty_sentences = repeat([])\n        for sent_idx, (tokens, empty_words) in enumerate(zip(sentences, empty_sentences)):\n            try:\n                sentence = Sentence(tokens, doc=self, empty_words=empty_words)\n            except IndexError as e:\n                raise IndexError(\"Could not process document at sentence %d\" % sent_idx) from e\n            except ValueError as e:\n                tokens = [\"|%s|\" % t for t in tokens]\n                tokens = \", \".join(tokens)\n                raise ValueError(\"Could not process document at sentence %d\\n  Raw tokens: %s\" % (sent_idx, tokens)) from e\n            self.sentences.append(sentence)\n            begin_idx, end_idx = sentence.tokens[0].start_char, sentence.tokens[-1].end_char\n            if all((self.text is not None, begin_idx is not None, end_idx is not None)): sentence.text = self.text[begin_idx: end_idx]\n            sentence.index = sent_idx\n\n        self._count_words()\n\n        # Add a #text comment to each sentence in a doc if it doesn't already exist\n        if not comments:\n            comments = [[] for x in self.sentences]\n        else:\n            comments = [list(x) for x in comments]\n        for sentence, sentence_comments in zip(self.sentences, comments):\n            # the space after text can occur in treebanks such as the 
Naija-NSC treebank,\n            # which extensively uses `# text_en =` and `# text_ortho`\n            if sentence.text and not any(comment.startswith(\"# text \") or comment.startswith(\"#text \") or comment.startswith(\"# text=\") or comment.startswith(\"#text=\") for comment in sentence_comments):\n                # split/join to handle weird whitespace, especially newlines\n                sentence_comments.append(\"# text = \" + ' '.join(sentence.text.split()))\n            elif not sentence.text:\n                for comment in sentence_comments:\n                    if comment.startswith(\"# text \") or comment.startswith(\"#text \") or comment.startswith(\"# text=\") or comment.startswith(\"#text=\"):\n                        sentence.text = comment.split(\"=\", 1)[-1].strip()\n                        break\n\n            for comment in sentence_comments:\n                sentence.add_comment(comment)\n\n            # look for sent_id in the comments\n            # if it's there, overwrite the sent_idx id from above\n            for comment in sentence_comments:\n                if comment.startswith(\"# sent_id\"):\n                    sentence.sent_id = comment.split(\"=\", 1)[-1].strip()\n                    break\n            else:\n                # no sent_id found.  
add a comment with our enumerated id\n                # setting the sent_id on the sentence will automatically add the comment\n                sentence.sent_id = str(sentence.index)\n\n            # look for speaker in the comments\n            for comment in sentence_comments:\n                if comment.startswith(\"# speaker\"):\n                    sentence.speaker = comment.split(\"=\", 1)[-1].strip()\n                    break\n            else:\n                sentence.speaker = None\n\n    def _count_words(self):\n        \"\"\"\n        Count the number of tokens and words\n        \"\"\"\n        self.num_tokens = sum([len(sentence.tokens) for sentence in self.sentences])\n        self.num_words = sum([len(sentence.words) for sentence in self.sentences])\n\n    def get(self, fields, as_sentences=False, from_token=False):\n        \"\"\" Get fields from a list of field names.\n        If only one field name (string or singleton list) is provided,\n        return a list of that field; if more than one, return a list of list.\n        Note that all returned fields are after multi-word expansion.\n\n        Args:\n            fields: name of the fields as a list or a single string\n            as_sentences: if True, return the fields as a list of sentences; otherwise as a whole list\n            from_token: if True, get the fields from Token; otherwise from Word\n\n        Returns:\n            All requested fields.\n        \"\"\"\n        if isinstance(fields, str):\n            fields = [fields]\n        assert isinstance(fields, list), \"Must provide field names as a list.\"\n        assert len(fields) >= 1, \"Must have at least one field.\"\n\n        results = []\n        for sentence in self.sentences:\n            cursent = []\n            # decide word or token\n            if from_token:\n                units = sentence.tokens\n            else:\n                units = sentence.words\n            for unit in units:\n                if 
len(fields) == 1:\n                    cursent += [getattr(unit, fields[0])]\n                else:\n                    cursent += [[getattr(unit, field) for field in fields]]\n\n            # decide whether append the results as a sentence or a whole list\n            if as_sentences:\n                results.append(cursent)\n            else:\n                results += cursent\n        return results\n\n    def set(self, fields, contents, to_token=False, to_sentence=False):\n        \"\"\"Set fields based on contents. If only one field (string or\n        singleton list) is provided, then a list of content will be\n        expected; otherwise a list of list of contents will be expected.\n\n        Args:\n            fields: name of the fields as a list or a single string\n            contents: field values to set; total length should be equal to number of words/tokens\n            to_token: if True, set field values to tokens; otherwise to words\n\n        \"\"\"\n        if isinstance(fields, str):\n            fields = [fields]\n        assert isinstance(fields, (tuple, list)), \"Must provide field names as a list.\"\n        assert isinstance(contents, (tuple, list)), \"Must provide contents as a list (one item per line).\"\n        assert len(fields) >= 1, \"Must have at least one field.\"\n\n        assert not to_sentence or not to_token, \"Both to_token and to_sentence set to True, which is very confusing\"\n\n        if to_sentence:\n            assert len(self.sentences) == len(contents), \\\n                \"Contents must have the same length as the sentences\"\n            for sentence, content in zip(self.sentences, contents):\n                if len(fields) == 1:\n                    setattr(sentence, fields[0], content)\n                else:\n                    for field, piece in zip(fields, content):\n                        setattr(sentence, field, piece)\n        else:\n            assert (to_token and self.num_tokens == len(contents)) or 
self.num_words == len(contents), \\\n                \"Contents must have the same length as the original file.\"\n\n            cidx = 0\n            for sentence in self.sentences:\n                # decide word or token\n                if to_token:\n                    units = sentence.tokens\n                else:\n                    units = sentence.words\n                for unit in units:\n                    if len(fields) == 1:\n                        setattr(unit, fields[0], contents[cidx])\n                    else:\n                        for field, content in zip(fields, contents[cidx]):\n                            setattr(unit, field, content)\n                    cidx += 1\n\n    def set_mwt_expansions(self, expansions,\n                           fake_dependencies=False,\n                           process_manual_expanded=None):\n        \"\"\" Extend the multi-word tokens annotated by tokenizer. A list of list of expansions\n        will be expected for each multi-word token. Use `process_manual_expanded` to limit\n        processing for tokens marked manually expanded:\n\n        There are two types of MWT expansions: those with `misc`: `MWT=True`, and those with\n        `manual_expansion`: True. 
The latter of which means that it is an expansion which the\n        user manually specified through a postprocessor; the former means that it is a MWT\n        which the detector picked out, but needs to be automatically expanded.\n\n        process_manual_expanded = None - default; doesn't process manually expanded tokens\n                                = True - process only manually expanded tokens (with `manual_expansion`: True)\n                                = False - process only tokens explicitly tagged as MWT (`misc`: `MWT=True`)\n        \"\"\"\n\n        idx_e = 0\n        for sentence in self.sentences:\n            idx_w = 0\n            for token in sentence.tokens:\n                idx_w += 1\n                is_multi = (len(token.id) > 1)\n                is_mwt = (multi_word_token_misc.match(token.misc) if token.misc is not None else None)\n                is_manual_expansion = token.manual_expansion\n\n                perform_mwt_processing = MWTProcessingType.FLATTEN\n\n                if (process_manual_expanded and is_manual_expansion):\n                    perform_mwt_processing = MWTProcessingType.PROCESS\n                elif (process_manual_expanded==False and is_mwt):\n                    perform_mwt_processing = MWTProcessingType.PROCESS\n                elif (process_manual_expanded==False and is_manual_expansion):\n                    perform_mwt_processing = MWTProcessingType.SKIP\n                elif (process_manual_expanded==None and (is_mwt or is_multi)):\n                    perform_mwt_processing = MWTProcessingType.PROCESS\n\n                if perform_mwt_processing == MWTProcessingType.FLATTEN:\n                    for word in token.words:\n                        token.id = (idx_w, )\n                        # delete dependency information\n                        word.deps = None\n                        word.head, word.deprel = None, None\n                        word.id = idx_w\n                elif 
perform_mwt_processing == MWTProcessingType.PROCESS:\n                    expanded = [x for x in expansions[idx_e].split(' ') if len(x) > 0]\n                    # in the event the MWT annotator only split the\n                    # Token into a single Word, we preserve its text\n                    # otherwise the Token's text is different from its\n                    # only Word's text\n                    if len(expanded) == 1:\n                        expanded = [token.text]\n                    idx_e += 1\n                    idx_w_end = idx_w + len(expanded) - 1\n                    if token.misc:  # None can happen when using a prebuilt doc\n                        token.misc = None if token.misc == 'MWT=Yes' else '|'.join([x for x in token.misc.split('|') if x != 'MWT=Yes'])\n                    token.id = (idx_w, idx_w_end) if len(expanded) > 1 else (idx_w,)\n                    token.words = []\n                    for i, e_word in enumerate(expanded):\n                        token.words.append(Word(sentence, {ID: idx_w + i, TEXT: e_word}))\n                    idx_w = idx_w_end\n                elif perform_mwt_processing == MWTProcessingType.SKIP:\n                    token.id = tuple(orig_id + idx_e for orig_id in token.id)\n                    for i in token.words:\n                        i.id += idx_e\n                    idx_w = token.id[-1]\n                    token.manual_expansion = None\n\n            # reprocess the words using the new tokens\n            sentence.words = []\n            for token in sentence.tokens:\n                token.sent = sentence\n                for word in token.words:\n                    word.sent = sentence\n                    word.parent = token\n                    sentence.words.append(word)\n                if len(token.words) == 1:\n                    word.start_char = token.start_char\n                    word.end_char = token.end_char\n                elif token.start_char is not None and 
token.end_char is not None:\n                    search_string = \"^%s$\" % (\"\\\\s*\".join(\"(%s)\" % re.escape(word.text) for word in token.words))\n                    match = re.compile(search_string).match(token.text)\n                    if match:\n                        for word_idx, word in enumerate(token.words):\n                            word.start_char = match.start(word_idx+1) + token.start_char\n                            word.end_char = match.end(word_idx+1) + token.start_char\n\n            if fake_dependencies:\n                sentence.build_fake_dependencies()\n            else:\n                sentence.rebuild_dependencies()\n\n        self._count_words() # update number of words & tokens\n        assert idx_e == len(expansions), \"{} {}\".format(idx_e, len(expansions))\n        return\n\n    def get_mwt_expansions(self, evaluation=False):\n        \"\"\" Get the multi-word tokens. For training, return a list of\n        [multi-word token, expanded multi-word token] pairs; if `evaluation` is True,\n        return a list of the multi-word token texts only. Manually expanded tokens\n        are skipped unless they are also marked with MWT=Yes in misc.\n        \"\"\"\n        expansions = []\n        for sentence in self.sentences:\n            for token in sentence.tokens:\n                is_multi = (len(token.id) > 1)\n                is_mwt = multi_word_token_misc.match(token.misc) if token.misc is not None else None\n                is_manual_expansion = token.manual_expansion\n                if (is_multi and not is_manual_expansion) or is_mwt:\n                    src = token.text\n                    dst = ' '.join([word.text for word in token.words])\n                    expansions.append([src, dst])\n        if evaluation:\n            expansions = [e[0] for e in expansions]\n        return expansions\n\n    def build_ents(self):\n        \"\"\" Build the list of entities by iterating over all words. Return all entities as a list. 
\"\"\"\n        self.ents = []\n        for s in self.sentences:\n            s_ents = s.build_ents()\n            self.ents += s_ents\n        return self.ents\n\n    def sort_features(self):\n        \"\"\" Sort the features on all the words... useful for prototype treebanks, for example \"\"\"\n        for sentence in self.sentences:\n            for word in sentence.words:\n                if not word.feats:\n                    continue\n                pieces = word.feats.split(\"|\")\n                pieces = sorted(pieces, key=str.casefold)\n                word.feats = \"|\".join(pieces)\n\n    def iter_words(self):\n        \"\"\" An iterator that returns all of the words in this Document. \"\"\"\n        for s in self.sentences:\n            yield from s.words\n\n    def iter_tokens(self):\n        \"\"\" An iterator that returns all of the tokens in this Document. \"\"\"\n        for s in self.sentences:\n            yield from s.tokens\n\n    def sentence_comments(self):\n        \"\"\" Returns a list of list of comments for the sentences \"\"\"\n        return [[comment for comment in sentence.comments] for sentence in self.sentences]\n\n    @property\n    def coref(self):\n        \"\"\"\n        Access the coref lists of the document\n        \"\"\"\n        return self._coref\n\n    @coref.setter\n    def coref(self, chains):\n        \"\"\" Set the document's coref lists \"\"\"\n        self._coref = chains\n        self._attach_coref_mentions(chains)\n\n    def _attach_coref_mentions(self, chains):\n        for sentence in self.sentences:\n            for word in sentence.all_words:\n                word.coref_chains = []\n\n        for chain in chains:\n            for mention_idx, mention in enumerate(chain.mentions):\n                sentence = self.sentences[mention.sentence]\n                if isinstance(mention.start_word, tuple):\n                    attachment = CorefAttachment(chain, True, True, False)\n                    
sentence._empty_words[mention.start_word[1]-1].coref_chains.append(attachment)\n                else:\n                    for word_idx in range(mention.start_word, mention.end_word):\n                            is_start = word_idx == mention.start_word\n                            is_end = word_idx == mention.end_word - 1\n                            is_representative = mention_idx == chain.representative_index\n                            attachment = CorefAttachment(chain, is_start, is_end, is_representative)\n                            sentence.words[word_idx].coref_chains.append(attachment)\n\n    def reindex_sentences(self, start_index):\n        for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences):\n            sentence.sent_id = str(sent_id)\n\n    def to_dict(self):\n        \"\"\" Dumps the whole document into a list of list of dictionary for each token in each sentence in the doc.\n        \"\"\"\n        return [sentence.to_dict() for sentence in self.sentences]\n\n    def __repr__(self):\n        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)\n\n    def __format__(self, spec):\n        if spec and spec[0] in ('c', 'C'):\n            spec = \"{:%s}\" % spec\n            return \"\\n\\n\".join(spec.format(s) for s in self.sentences)\n        else:\n            return str(self)\n\n    def to_serialized(self):\n        \"\"\" Dumps the whole document including text to a byte array containing a list of list of dictionaries for each token in each sentence in the doc.\n        \"\"\"\n        return pickle.dumps((self.text, self.to_dict(), self.sentence_comments()))\n\n    @classmethod\n    def from_serialized(cls, serialized_string):\n        \"\"\" Create and initialize a new document from a serialized string generated by Document.to_serialized_string():\n        \"\"\"\n        stuff = pickle.loads(serialized_string)\n        if not isinstance(stuff, tuple):\n            
raise TypeError(\"Serialized data was not a tuple when building a Document\")\n        # reuse the already unpickled tuple rather than unpickling the bytes again\n        if len(stuff) == 2:\n            text, sentences = stuff\n            doc = cls(sentences, text)\n        else:\n            text, sentences, comments = stuff\n            doc = cls(sentences, text, comments)\n        return doc\n\n\nclass Sentence(StanzaObject):\n    \"\"\" A sentence class that stores attributes of a sentence and carries a list of tokens.\n    \"\"\"\n\n    def __init__(self, tokens, doc=None, empty_words=None):\n        \"\"\" Construct a sentence given a list of tokens in the form of CoNLL-U dicts.\n        \"\"\"\n        self._tokens = []\n        self._words = []\n        self._dependencies = []\n        self._text = None\n        self._ents = []\n        self._doc = doc\n        self._constituency = None\n        self._sentiment = None\n        # comments are a list of comment lines occurring before the\n        # sentence in a CoNLL-U file.  
Can be empty\n        self._comments = []\n        self._doc_id = None\n\n        # enhanced_dependencies represents the DEPS column\n        # this is a networkx MultiDiGraph\n        # with edges from the parent to the dependent\n        # however, we set it to None until needed, as it is somewhat slow\n        self._enhanced_dependencies = None\n        self._process_tokens(tokens)\n\n        if empty_words is not None:\n            self._empty_words = [Word(self, entry) for entry in empty_words]\n        else:\n            self._empty_words = []\n\n    def _process_tokens(self, tokens):\n        st, en = -1, -1\n        self.tokens, self.words = [], []\n        for i, entry in enumerate(tokens):\n            if ID not in entry: # manually set a 1-based id for word if not exist\n                entry[ID] = (i+1, )\n            if isinstance(entry[ID], int):\n                entry[ID] = (entry[ID], )\n            if len(entry.get(ID)) > 1: # if this token is a multi-word token\n                st, en = entry[ID]\n                self.tokens.append(Token(self, entry))\n            else: # else this token is a word\n                new_word = Word(self, entry)\n                if len(self.words) > 0 and self.words[-1].id == new_word.id:\n                    # this can happen in the following context:\n                    # a document was created with MWT=Yes to mark that a token should be split\n                    # and then there was an MWT \"expansion\" with a single word after that token\n                    # we replace the Word in the Token assuming that the expansion token might\n                    # have more information than the Token dict did\n                    # note that a single word MWT like that can be detected with something like\n                    #   multi_word_token_misc.match(entry.get(MISC)) if entry.get(MISC, None)\n                    self.words[-1] = new_word\n                    self.tokens[-1].words[-1] = new_word\n                    
continue\n                self.words.append(new_word)\n                idx = entry.get(ID)[0]\n                if idx <= en:\n                    self.tokens[-1].words.append(new_word)\n                else:\n                    self.tokens.append(Token(self, entry, words=[new_word]))\n                new_word.parent = self.tokens[-1]\n\n        # put all of the whitespace annotations (if any) on the Tokens instead of the Words\n        for token in self.tokens:\n            token.consolidate_whitespace()\n        self.rebuild_dependencies()\n\n    def has_enhanced_dependencies(self):\n        \"\"\"\n        Whether or not the enhanced dependencies are part of this sentence\n        \"\"\"\n        return self._enhanced_dependencies is not None and len(self._enhanced_dependencies) > 0\n\n    @property\n    def enhanced_dependencies(self):\n        \"\"\"\n        Returns the enhanced_dependencies graph.\n\n        Creates an empty one if one currently does not exist.\n        \"\"\"\n        graph = self._enhanced_dependencies\n        if graph is None:\n            graph = nx.MultiDiGraph()\n            self._enhanced_dependencies = graph\n        return graph\n\n    @property\n    def index(self):\n        \"\"\"\n        Access the index of this sentence within the doc.\n\n        If multiple docs were processed together,\n        the sentence index will continue counting across docs.\n        \"\"\"\n        return self._index\n\n    @index.setter\n    def index(self, value):\n        \"\"\" Set the sentence's index value. \"\"\"\n        self._index = value\n\n    @property\n    def id(self):\n        \"\"\"\n        Access the index of this sentence within the doc.\n\n        If multiple docs were processed together,\n        the sentence index will continue counting across docs.\n        \"\"\"\n        warnings.warn(\"Use of sentence.id is deprecated.  
Please use sentence.index instead\", stacklevel=2)\n        return self._index\n\n    @id.setter\n    def id(self, value):\n        \"\"\" Set the sentence's index value. \"\"\"\n        warnings.warn(\"Use of sentence.id is deprecated.  Please use sentence.index instead\", stacklevel=2)\n        self._index = value\n\n    @property\n    def sent_id(self):\n        \"\"\" conll-style sent_id  Will be set from index if unknown \"\"\"\n        return self._sent_id\n\n    @sent_id.setter\n    def sent_id(self, value):\n        \"\"\" Set the sentence's sent_id value. \"\"\"\n        self._sent_id = value\n        sent_id_comment = \"# sent_id = \" + str(value)\n        for comment_idx, comment in enumerate(self._comments):\n            if comment.startswith(\"# sent_id = \"):\n                self._comments[comment_idx] = sent_id_comment\n                break\n        else: # this is intended to be a for/else loop\n            self._comments.append(sent_id_comment)\n\n    @property\n    def speaker(self):\n        \"\"\" conll-style speaker - adopt the EN GUM formatting \"\"\"\n        return self._speaker\n\n    @speaker.setter\n    def speaker(self, value):\n        \"\"\" Set the sentence's speaker value. 
\"\"\"\n        self._speaker = value\n        speaker_comment = \"# speaker = \" + str(value)\n        if not value:\n            for comment_idx, comment in enumerate(self._comments):\n                if comment.startswith(\"# speaker = \"):\n                    self._comments.pop(comment_idx)\n                    break\n        else:\n            for comment_idx, comment in enumerate(self._comments):\n                if comment.startswith(\"# speaker = \"):\n                    self._comments[comment_idx] = speaker_comment\n                    break\n            else: # this is intended to be a for/else loop\n                self._comments.append(speaker_comment)\n\n    @property\n    def doc_id(self):\n        \"\"\" conll-style doc_id  Can be left blank if unknown \"\"\"\n        return self._doc_id\n\n    @doc_id.setter\n    def doc_id(self, value):\n        \"\"\" Set the sentence's doc_id value. \"\"\"\n        self._doc_id = value\n        doc_id_comment = \"# doc_id = \" + str(value)\n        for comment_idx, comment in enumerate(self._comments):\n            if comment.startswith(\"# doc_id = \"):\n                self._comments[comment_idx] = doc_id_comment\n                break\n        else: # this is intended to be a for/else loop\n            self._comments.append(doc_id_comment)\n\n    @property\n    def doc(self):\n        \"\"\" Access the parent doc of this span. \"\"\"\n        return self._doc\n\n    @doc.setter\n    def doc(self, value):\n        \"\"\" Set the parent doc of this span. \"\"\"\n        self._doc = value\n\n    @property\n    def text(self):\n        \"\"\" Access the raw text for this sentence. \"\"\"\n        return self._text\n\n    @text.setter\n    def text(self, value):\n        \"\"\" Set the raw text for this sentence. \"\"\"\n        self._text = value\n\n    @property\n    def dependencies(self):\n        \"\"\" Access list of dependencies for this sentence. 
\"\"\"\n        return self._dependencies\n\n    @dependencies.setter\n    def dependencies(self, value):\n        \"\"\" Set the list of dependencies for this sentence. \"\"\"\n        self._dependencies = value\n\n    @property\n    def tokens(self):\n        \"\"\" Access the list of tokens for this sentence. \"\"\"\n        return self._tokens\n\n    @tokens.setter\n    def tokens(self, value):\n        \"\"\" Set the list of tokens for this sentence. \"\"\"\n        self._tokens = value\n\n    @property\n    def words(self):\n        \"\"\" Access the list of words for this sentence. \"\"\"\n        return self._words\n\n    @words.setter\n    def words(self, value):\n        \"\"\" Set the list of words for this sentence. \"\"\"\n        self._words = value\n\n    @property\n    def empty_words(self):\n        \"\"\" Access the list of words for this sentence. \"\"\"\n        return self._empty_words\n\n    @empty_words.setter\n    def empty_words(self, value):\n        \"\"\" Set the list of words for this sentence. \"\"\"\n        self._empty_words = value\n\n    @property\n    def all_words(self):\n        \"\"\" Access the list of words + empty words for this sentence. \"\"\"\n        words = self._words\n        empty_words = self._empty_words\n\n        all_words = sorted(words + empty_words,\n                           key=lambda x:(x.id,) if isinstance(x.id, int) else x.id)\n\n        return all_words\n\n    @property\n    def ents(self):\n        \"\"\" Access the list of entities in this sentence. \"\"\"\n        return self._ents\n\n    @ents.setter\n    def ents(self, value):\n        \"\"\" Set the list of entities in this sentence. \"\"\"\n        self._ents = value\n\n    @property\n    def entities(self):\n        \"\"\" Access the list of entities. This is just an alias of `ents`. \"\"\"\n        return self._ents\n\n    @entities.setter\n    def entities(self, value):\n        \"\"\" Set the list of entities in this sentence. 
\"\"\"\n        self._ents = value\n\n    def build_ents(self):\n        \"\"\" Build the list of entities by iterating over all tokens. Return all entities as a list.\n\n        Note that unlike other attributes, since NER requires raw text, the actual tagging are always\n        performed at and attached to the `Token`s, instead of `Word`s.\n        \"\"\"\n        self.ents = []\n        tags = [w.ner for w in self.tokens]\n        decoded = decode_from_bioes(tags)\n        for e in decoded:\n            ent_tokens = self.tokens[e['start']:e['end']+1]\n            self.ents.append(Span(tokens=ent_tokens, type=e['type'], doc=self.doc, sent=self))\n        return self.ents\n\n    @property\n    def sentiment(self):\n        \"\"\" Returns the sentiment value for this sentence \"\"\"\n        return self._sentiment\n\n    @sentiment.setter\n    def sentiment(self, value):\n        \"\"\" Set the sentiment value \"\"\"\n        self._sentiment = value\n        sentiment_comment = \"# sentiment = \" + str(value)\n        for comment_idx, comment in enumerate(self._comments):\n            if comment.startswith(\"# sentiment = \"):\n                self._comments[comment_idx] = sentiment_comment\n                break\n        else: # this is intended to be a for/else loop\n            self._comments.append(sentiment_comment)\n\n    @property\n    def constituency(self):\n        \"\"\" Returns the constituency tree for this sentence \"\"\"\n        return self._constituency\n\n    @constituency.setter\n    def constituency(self, value):\n        \"\"\"\n        Set the constituency tree\n\n        This incidentally updates the #constituency comment if it already exists,\n        or otherwise creates a new comment # constituency = ...\n        \"\"\"\n        self._constituency = value\n        constituency_comment = \"# constituency = \" + str(value)\n        constituency_comment = constituency_comment.replace(\"\\n\", \"*NL*\").replace(\"\\r\", \"\")\n        for 
comment_idx, comment in enumerate(self._comments):\n            if comment.startswith(\"# constituency = \"):\n                self._comments[comment_idx] = constituency_comment\n                break\n        else: # this is intended to be a for/else loop\n            self._comments.append(constituency_comment)\n\n\n    @property\n    def comments(self):\n        \"\"\" Returns CoNLL-style comments for this sentence \"\"\"\n        return self._comments\n\n    def add_comment(self, comment):\n        \"\"\" Adds a single comment to this sentence.\n\n        If the comment does not already have # at the start, it will be added.\n        \"\"\"\n        if not comment.startswith(\"#\"):\n            comment = \"# \" + comment\n        if comment.startswith(\"# constituency =\"):\n            _, tree_text = comment.split(\"=\", 1)\n            tree = tree_reader.read_trees(tree_text)\n            if len(tree) > 1:\n                raise ValueError(\"Multiple constituency trees for one sentence: %s\" % tree_text)\n            self._constituency = tree[0]\n            self._comments = [x for x in self._comments if not x.startswith(\"# constituency =\")]\n        elif comment.startswith(\"# sentiment =\"):\n            _, sentiment = comment.split(\"=\", 1)\n            sentiment = int(sentiment.strip())\n            self._sentiment = sentiment\n            self._comments = [x for x in self._comments if not x.startswith(\"# sentiment =\")]\n        elif comment.startswith(\"# sent_id =\"):\n            _, sent_id = comment.split(\"=\", 1)\n            sent_id = sent_id.strip()\n            self._sent_id = sent_id\n            self._comments = [x for x in self._comments if not x.startswith(\"# sent_id =\")]\n        elif comment.startswith(\"# doc_id =\"):\n            _, doc_id = comment.split(\"=\", 1)\n            doc_id = doc_id.strip()\n            self._doc_id = doc_id\n            self._comments = [x for x in self._comments if not x.startswith(\"# doc_id =\")]\n   
     self._comments.append(comment)\n\n    def rebuild_dependencies(self):\n        # rebuild dependencies if there is dependency info\n        is_complete_dependencies = all(word.head is not None and word.deprel is not None for word in self.words)\n        is_complete_words = (len(self.words) >= len(self.tokens)) and (len(self.words) == self.words[-1].id)\n        if is_complete_dependencies and is_complete_words: self.build_dependencies()\n\n    def build_dependencies(self):\n        \"\"\" Build the dependency graph for this sentence. Each dependency graph entry is\n        a list of (head, deprel, word).\n        \"\"\"\n        self.dependencies = []\n        for word in self.words:\n            if word.head == 0:\n                # make a word for the ROOT\n                word_entry = {ID: 0, TEXT: \"ROOT\"}\n                head = Word(self, word_entry)\n            else:\n                # id is index in words list + 1\n                try:\n                    head = self.words[word.head - 1]\n                except IndexError as e:\n                    raise IndexError(\"Word head {} is not a valid word index for word {}\".format(word.head, word.id)) from e\n                if word.head != head.id:\n                    raise ValueError(\"Dependency tree is incorrectly constructed\")\n            self.dependencies.append((head, word.deprel, word))\n\n    def build_fake_dependencies(self):\n        self.dependencies = []\n        for word_idx, word in enumerate(self.words):\n            word.head = word_idx   # note that this goes one previous to the index\n            word.deprel = \"root\" if word_idx == 0 else \"dep\"\n            word.deps = \"%d:%s\" % (word.head, word.deprel)\n            self.dependencies.append((word_idx, word.deprel, word))\n\n    def print_dependencies(self, file=None):\n        \"\"\" Print the dependencies for this sentence. 
\"\"\"\n        for dep_edge in self.dependencies:\n            print((dep_edge[2].text, dep_edge[0].id, dep_edge[1]), file=file)\n\n    def dependencies_string(self):\n        \"\"\" Dump the dependencies for this sentence into string. \"\"\"\n        dep_string = io.StringIO()\n        self.print_dependencies(file=dep_string)\n        return dep_string.getvalue().strip()\n\n    def get_roots(self):\n        \"\"\" Return a list of root(s) from a sentence \"\"\"\n        roots = []\n        for word in self.words:\n            if word.head == 0:\n                roots.append(word)\n        return roots\n\n    def print_tokens(self, file=None):\n        \"\"\" Print the tokens for this sentence. \"\"\"\n        for tok in self.tokens:\n            print(tok.pretty_print(), file=file)\n\n    def tokens_string(self):\n        \"\"\" Dump the tokens for this sentence into string. \"\"\"\n        toks_string = io.StringIO()\n        self.print_tokens(file=toks_string)\n        return toks_string.getvalue().strip()\n\n    def print_words(self, file=None):\n        \"\"\" Print the words for this sentence. \"\"\"\n        for word in self.words:\n            print(word.pretty_print(), file=file)\n\n    def words_string(self):\n        \"\"\" Dump the words for this sentence into string. 
\"\"\"\n        wrds_string = io.StringIO()\n        self.print_words(file=wrds_string)\n        return wrds_string.getvalue().strip()\n\n    def to_dict(self):\n        \"\"\" Dumps the sentence into a list of dictionary for each token in the sentence.\n        \"\"\"\n        ret = []\n        empty_idx = 0\n        for token_idx, token in enumerate(self.tokens):\n            while empty_idx < len(self._empty_words) and self._empty_words[empty_idx].id[0] < token.id[0]:\n                ret.append(self._empty_words[empty_idx].to_dict())\n                empty_idx += 1\n            ret += token.to_dict()\n        for empty_word in self._empty_words[empty_idx:]:\n            ret.append(empty_word.to_dict())\n        return ret\n\n    def __repr__(self):\n        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)\n\n    def __format__(self, spec):\n        if not spec:\n            return str(self)\n        if not spec[0] == 'c' and not spec[0] == 'C':\n            return str(self)\n        if \"-o\" in spec:\n            fields = NO_OFFSETS_OUTPUT_FIELDS\n        else:\n            fields = DEFAULT_OUTPUT_FIELDS\n\n        pieces = []\n        empty_idx = 0\n        for token_idx, token in enumerate(self.tokens):\n            while empty_idx < len(self._empty_words) and self._empty_words[empty_idx].id[0] < token.id[0]:\n                pieces.append(self._empty_words[empty_idx].to_conll_text(fields))\n                empty_idx += 1\n            pieces.append(token.to_conll_text(fields))\n        for empty_word in self._empty_words[empty_idx:]:\n            pieces.append(empty_word.to_conll_text(fields))\n\n        if spec[0] == 'c':\n            return \"\\n\".join(pieces)\n        elif spec[0] == 'C':\n            tokens = \"\\n\".join(pieces)\n            if len(self.comments) > 0:\n                text = \"\\n\".join(self.comments)\n                return text + \"\\n\" + tokens\n            return tokens\n\ndef 
init_from_misc(unit):\n    \"\"\"Create attributes by parsing from the `misc` field.\n\n    Also, remove start_char, end_char, and any other values we can set\n    from the misc field if applicable, so that we don't repeat ourselves\n    \"\"\"\n    remaining_values = []\n    for item in unit._misc.split('|'):\n        key_value = item.split('=', 1)\n        if len(key_value) == 2:\n            # some key_value can not be split\n            key, value = key_value\n            # start & end char are kept as ints\n            if key in (START_CHAR, END_CHAR, LINE_NUMBER):\n                value = int(value)\n            # set attribute\n            attr = f'_{key}'\n            if hasattr(unit, attr):\n                setattr(unit, attr, value)\n                continue\n            elif key == NER:\n                # special case skipping NER for Words, since there is no Word NER field\n                continue\n        remaining_values.append(item)\n    unit._misc = \"|\".join(remaining_values)\n\n\ndef dict_to_conll_text(token_dict, id_connector=\"-\"):\n    token_conll = ['_' for i in range(FIELD_NUM)]\n\n    misc = []\n    if token_dict.get(MISC):\n        # avoid appending a blank misc entry.\n        # otherwise the resulting misc field in the conll doc will wind up being blank text\n        # TODO: potentially need to escape =|\\ in the MISC as well\n        misc.append(token_dict[MISC])\n\n    # for other items meant to be in the MISC field,\n    # we try to operate on those columns in a deterministic order\n    # so that the output doesn't change based on the order of keys\n    # in the token_dict\n    for key in [START_CHAR, END_CHAR, NER]:\n        if key in token_dict:\n            misc.append(\"{}={}\".format(key, token_dict[key]))\n\n    if COREF_CHAINS in token_dict:\n        chains = token_dict[COREF_CHAINS]\n        if len(chains) > 0:\n            misc_chains = []\n            for chain in chains:\n                if chain.is_start and 
chain.is_end:\n                    coref_position = \"unit-\"\n                elif chain.is_start:\n                    coref_position = \"start-\"\n                elif chain.is_end:\n                    coref_position = \"end-\"\n                else:\n                    coref_position = \"middle-\"\n                is_representative = \"repr-\" if chain.is_representative else \"\"\n                misc_chains.append(\"%s%sid%d\" % (coref_position, is_representative, chain.chain.index))\n            # label the chains with COREF_CHAINS rather than the loop variable left over from above\n            misc.append(\"{}={}\".format(COREF_CHAINS, \",\".join(misc_chains)))\n\n    for key in token_dict.keys():\n        if key == ID:\n            token_conll[FIELD_TO_IDX[key]] = id_connector.join([str(x) for x in token_dict[key]]) if isinstance(token_dict[key], tuple) else str(token_dict[key])\n        elif key == FEATS:\n            feats = token_dict[key]\n            if feats:\n                pieces = feats.split(\"|\")\n                pieces = sorted(pieces, key=str.casefold)\n                feats = \"|\".join(pieces)\n            token_conll[FIELD_TO_IDX[key]] = str(feats)\n        elif key in FIELD_TO_IDX:\n            token_conll[FIELD_TO_IDX[key]] = str(token_dict[key])\n        elif key == LINE_NUMBER:\n            # skip this when converting back for now\n            pass\n    if misc:\n        token_conll[FIELD_TO_IDX[MISC]] = \"|\".join(misc)\n    else:\n        token_conll[FIELD_TO_IDX[MISC]] = '_'\n    # when a word (not mwt token) without head is found, we insert dummy head as required by the UD eval script\n    if '-' not in token_conll[FIELD_TO_IDX[ID]] and '.' not in token_conll[FIELD_TO_IDX[ID]] and HEAD not in token_dict:\n        token_conll[FIELD_TO_IDX[HEAD]] = str(int(token_dict[ID] if isinstance(token_dict[ID], int) else token_dict[ID][0]) - 1) # evaluation script requires head: int\n    return \"\\t\".join(token_conll)\n\n\nclass Token(StanzaObject):\n    \"\"\" A token class that stores attributes of a token and carries a list of words. 
A token corresponds to a unit in the raw\n    text. In some languages such as English, a token has a one-to-one mapping to a word, while in other languages such as French,\n    a (multi-word) token might be expanded into multiple words that carry syntactic annotations.\n    \"\"\"\n\n    def __init__(self, sentence, token_entry, words=None):\n        \"\"\"\n        Construct a token given a dictionary format token entry. Optionally link itself to the corresponding words.\n        The owning sentence must be passed in.\n        \"\"\"\n        self._id = token_entry.get(ID)\n        self._text = token_entry.get(TEXT)\n        if not self._id:\n            raise ValueError('id not included for the token')\n        if not self._text:\n            raise ValueError('text not included for the token')\n        self._misc = token_entry.get(MISC, None)\n        self._ner = token_entry.get(NER, None)\n        self._multi_ner = token_entry.get(MULTI_NER, None)\n        self._words = words if words is not None else []\n        self._start_char = token_entry.get(START_CHAR, None)\n        self._end_char = token_entry.get(END_CHAR, None)\n        self._sent = sentence\n        self._mexp = token_entry.get(MEXP, None)\n        self._spaces_before = \"\"\n        self._spaces_after = \" \"\n        self._line_number = None\n\n        if self._misc is not None:\n            init_from_misc(self)\n\n    @property\n    def id(self):\n        \"\"\" Access the index of this token. \"\"\"\n        return self._id\n\n    @id.setter\n    def id(self, value):\n        \"\"\" Set the token's id value. \"\"\"\n        self._id = value\n\n    @property\n    def manual_expansion(self):\n        \"\"\" Access whether this token was manually expanded. \"\"\"\n        return self._mexp\n\n    @manual_expansion.setter\n    def manual_expansion(self, value):\n        \"\"\" Set whether this token was manually expanded. 
\"\"\"\n        self._mexp = value\n\n    @property\n    def text(self):\n        \"\"\" Access the text of this token. Example: 'The' \"\"\"\n        return self._text\n\n    @text.setter\n    def text(self, value):\n        \"\"\" Set the token's text value. Example: 'The' \"\"\"\n        self._text = value\n\n    @property\n    def misc(self):\n        \"\"\" Access the miscellaneousness of this token. \"\"\"\n        return self._misc\n\n    @misc.setter\n    def misc(self, value):\n        \"\"\" Set the token's miscellaneousness value. \"\"\"\n        self._misc = value if self._is_null(value) == False else None\n\n    def consolidate_whitespace(self):\n        \"\"\"\n        Remove whitespace misc annotations from the Words and mark the whitespace on the Tokens\n        \"\"\"\n        found_after = False\n        found_before = False\n        num_words = len(self.words)\n        for word_idx, word in enumerate(self.words):\n            misc = word.misc\n            if not misc:\n                continue\n            pieces = misc.split(\"|\")\n            if word_idx == 0:\n                if any(piece.startswith(\"SpacesBefore=\") for piece in pieces):\n                    self.spaces_before = misc_to_space_before(misc)\n                    found_before = True\n            else:\n                if any(piece.startswith(\"SpacesBefore=\") for piece in pieces):\n                    warnings.warn(\"Found a SpacesBefore MISC annotation on a Word that was not the first Word in a Token\")\n            if word_idx == num_words - 1:\n                if any(piece.startswith(\"SpaceAfter=\") or piece.startswith(\"SpacesAfter=\") for piece in pieces):\n                    self.spaces_after = misc_to_space_after(misc)\n                    found_after = True\n            else:\n                if any(piece.startswith(\"SpaceAfter=\") or piece.startswith(\"SpacesAfter=\") for piece in pieces):\n                    unexpected_space_after = misc_to_space_after(misc)\n    
                if unexpected_space_after == \"\":\n                        warnings.warn(\"Unexpected SpaceAfter=No annotation on a word in the middle of an MWT\")\n                    else:\n                        warnings.warn(\"Unexpected SpacesAfter on a word in the middle on an MWT\")\n            pieces = [x for x in pieces if not x.startswith(\"SpacesAfter=\") and not x.startswith(\"SpaceAfter=\") and not x.startswith(\"SpacesBefore=\")]\n            word.misc = \"|\".join(pieces)\n\n        misc = self.misc\n        if misc:\n            pieces = misc.split(\"|\")\n            if any(piece.startswith(\"SpacesBefore=\") for piece in pieces):\n                spaces_before = misc_to_space_before(misc)\n                if found_before:\n                    if spaces_before != self.spaces_before:\n                        warnings.warn(\"Found conflicting SpacesBefore on a token and its word!\")\n                else:\n                    self.spaces_before = spaces_before\n            if any(piece.startswith(\"SpaceAfter=\") or piece.startswith(\"SpacesAfter=\") for piece in pieces):\n                spaces_after = misc_to_space_after(misc)\n                if found_after:\n                    if spaces_after != self.spaces_after:\n                        warnings.warn(\"Found conflicting SpaceAfter / SpacesAfter on a token and its word!\")\n                else:\n                    self.spaces_after = spaces_after\n            pieces = [x for x in pieces if not x.startswith(\"SpacesAfter=\") and not x.startswith(\"SpaceAfter=\") and not x.startswith(\"SpacesBefore=\")]\n            self.misc = \"|\".join(pieces)\n\n    @property\n    def spaces_before(self):\n        \"\"\" SpacesBefore for the token. 
Translated from the MISC fields \"\"\"\n        return self._spaces_before\n\n    @spaces_before.setter\n    def spaces_before(self, value):\n        self._spaces_before = value\n\n    @property\n    def spaces_after(self):\n        \"\"\" SpaceAfter or SpacesAfter for the token.  Translated from the MISC field \"\"\"\n        return self._spaces_after\n\n    @spaces_after.setter\n    def spaces_after(self, value):\n        self._spaces_after = value\n\n    @property\n    def words(self):\n        \"\"\" Access the list of syntactic words underlying this token. \"\"\"\n        return self._words\n\n    @words.setter\n    def words(self, value):\n        \"\"\" Set this token's list of underlying syntactic words. \"\"\"\n        self._words = value\n        for w in self._words:\n            w.parent = self\n\n    @property\n    def line_number(self):\n        \"\"\" Access the line number from the original document, if set \"\"\"\n        return self._line_number\n\n    @property\n    def start_char(self):\n        \"\"\" Access the start character index for this token in the raw text. \"\"\"\n        return self._start_char\n\n    @property\n    def end_char(self):\n        \"\"\" Access the end character index for this token in the raw text. \"\"\"\n        return self._end_char\n\n    @property\n    def ner(self):\n        \"\"\" Access the NER tag of this token. Example: 'B-ORG'\"\"\"\n        return self._ner\n\n    @ner.setter\n    def ner(self, value):\n        \"\"\" Set the token's NER tag. Example: 'B-ORG'\"\"\"\n        self._ner = value if self._is_null(value) == False else None\n\n    @property\n    def multi_ner(self):\n        \"\"\" Access the MULTI_NER tag of this token. Example: '(B-ORG, B-DISEASE)'\"\"\"\n        return self._multi_ner\n\n    @multi_ner.setter\n    def multi_ner(self, value):\n        \"\"\" Set the token's MULTI_NER tag. 
Example: '(B-ORG, B-DISEASE)'\"\"\"\n        self._multi_ner = value if self._is_null(value) == False else None\n\n    @property\n    def sent(self):\n        \"\"\" Access the pointer to the sentence that this token belongs to. \"\"\"\n        return self._sent\n\n    @sent.setter\n    def sent(self, value):\n        \"\"\" Set the pointer to the sentence that this token belongs to. \"\"\"\n        self._sent = value\n\n    def __repr__(self):\n        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)\n\n    def __format__(self, spec):\n        if spec == 'C':\n            # to_conll_text already returns a newline-joined string\n            return self.to_conll_text(DEFAULT_OUTPUT_FIELDS)\n        elif spec == 'P':\n            return self.pretty_print()\n        else:\n            return str(self)\n\n    def to_conll_text(self, fields=DEFAULT_OUTPUT_FIELDS):\n        return \"\\n\".join(dict_to_conll_text(x) for x in self.to_dict(fields))\n\n    def to_dict(self, fields=DEFAULT_OUTPUT_FIELDS):\n        \"\"\" Dumps the token into a list of dictionaries: one for the token itself if it is\n        a multi-word token, followed by one for each of its words.\n        \"\"\"\n        ret = []\n        if len(self.id) > 1:\n            token_dict = {}\n            for field in fields:\n                if getattr(self, field, None) is not None:\n                    token_dict[field] = getattr(self, field)\n            if MISC in fields:\n                spaces_after = self.spaces_after\n                if spaces_after is not None and spaces_after != ' ':\n                    space_misc = space_after_to_misc(spaces_after)\n                    if token_dict.get(MISC):\n                        token_dict[MISC] = token_dict[MISC] + \"|\" + space_misc\n                    else:\n                        token_dict[MISC] = space_misc\n\n                spaces_before = self.spaces_before\n                if spaces_before is not None and spaces_before != '':\n                    space_misc = 
space_before_to_misc(spaces_before)\n                    if token_dict.get(MISC):\n                        token_dict[MISC] = token_dict[MISC] + \"|\" + space_misc\n                    else:\n                        token_dict[MISC] = space_misc\n\n            ret.append(token_dict)\n        for word in self.words:\n            word_dict = word.to_dict(fields)\n            if len(self.id) == 1 and NER in fields and getattr(self, NER) is not None: # propagate NER label to Word if it is a single-word token\n                word_dict[NER] = getattr(self, NER)\n            if len(self.id) == 1 and MULTI_NER in fields and getattr(self, MULTI_NER) is not None: # propagate MULTI_NER label to Word if it is a single-word token\n                word_dict[MULTI_NER] = getattr(self, MULTI_NER)\n            if len(self.id) == 1 and MISC in fields:\n                spaces_after = self.spaces_after\n                if spaces_after is not None and spaces_after != ' ':\n                    space_misc = space_after_to_misc(spaces_after)\n                    if word_dict.get(MISC):\n                        word_dict[MISC] = word_dict[MISC] + \"|\" + space_misc\n                    else:\n                        word_dict[MISC] = space_misc\n\n                spaces_before = self.spaces_before\n                if spaces_before is not None and spaces_before != '':\n                    space_misc = space_before_to_misc(spaces_before)\n                    if word_dict.get(MISC):\n                        word_dict[MISC] = word_dict[MISC] + \"|\" + space_misc\n                    else:\n                        word_dict[MISC] = space_misc\n            ret.append(word_dict)\n        return ret\n\n    def pretty_print(self):\n        \"\"\" Print this token with its extended words in one line. 
\"\"\"\n        return f\"<{self.__class__.__name__} id={'-'.join([str(x) for x in self.id])};words=[{', '.join([word.pretty_print() for word in self.words])}]>\"\n\n    def _is_null(self, value):\n        return (value is None) or (value == '_')\n\n    def is_mwt(self):\n        return len(self.words) > 1\n\nclass Word(StanzaObject):\n    \"\"\" A word class that stores attributes of a word.\n    \"\"\"\n\n    def __init__(self, sentence, word_entry):\n        \"\"\" Construct a word given a dictionary format word entry.\n        \"\"\"\n        self._id = word_entry.get(ID, None)\n        if isinstance(self._id, tuple):\n            if len(self._id) == 1:\n                self._id = self._id[0]\n        self._text = word_entry.get(TEXT, None)\n\n        assert self._id is not None and self._text is not None, 'id and text should be included for the word. {}'.format(word_entry)\n\n        self._lemma = word_entry.get(LEMMA, None)\n        self._upos = word_entry.get(UPOS, None)\n        self._xpos = word_entry.get(XPOS, None)\n        self._feats = word_entry.get(FEATS, None)\n        self._head = word_entry.get(HEAD, None)\n        self._deprel = word_entry.get(DEPREL, None)\n        self._misc = word_entry.get(MISC, None)\n        self._start_char = word_entry.get(START_CHAR, None)\n        self._end_char = word_entry.get(END_CHAR, None)\n        self._parent = None\n        self._sent = sentence\n        self._mexp = word_entry.get(MEXP, None)\n        self._coref_chains = None\n        self._line_number = None\n\n        if self._misc is not None:\n            init_from_misc(self)\n\n        # use the setter, which will go up to the sentence and set the\n        # dependencies on that graph\n        self.deps = word_entry.get(DEPS, None)\n\n    @property\n    def manual_expansion(self):\n        \"\"\" Access the whether this token was manually expanded. 
\"\"\"\n        return self._mexp\n\n    @manual_expansion.setter\n    def manual_expansion(self, value):\n        \"\"\" Set the whether this token was manually expanded. \"\"\"\n        self._mexp = value\n\n    @property\n    def id(self):\n        \"\"\" Access the index of this word. \"\"\"\n        return self._id\n\n    @id.setter\n    def id(self, value):\n        \"\"\" Set the word's index value. \"\"\"\n        self._id = value\n\n    @property\n    def text(self):\n        \"\"\" Access the text of this word. Example: 'The'\"\"\"\n        return self._text\n\n    @text.setter\n    def text(self, value):\n        \"\"\" Set the word's text value. Example: 'The'\"\"\"\n        self._text = value\n\n    @property\n    def lemma(self):\n        \"\"\" Access the lemma of this word. \"\"\"\n        return self._lemma\n\n    @lemma.setter\n    def lemma(self, value):\n        \"\"\" Set the word's lemma value. \"\"\"\n        self._lemma = value if self._is_null(value) == False or self._text == '_' else None\n\n    @property\n    def upos(self):\n        \"\"\" Access the universal part-of-speech of this word. Example: 'NOUN'\"\"\"\n        return self._upos\n\n    @upos.setter\n    def upos(self, value):\n        \"\"\" Set the word's universal part-of-speech value. Example: 'NOUN'\"\"\"\n        self._upos = value if self._is_null(value) == False else None\n\n    @property\n    def xpos(self):\n        \"\"\" Access the treebank-specific part-of-speech of this word. Example: 'NNP'\"\"\"\n        return self._xpos\n\n    @xpos.setter\n    def xpos(self, value):\n        \"\"\" Set the word's treebank-specific part-of-speech value. Example: 'NNP'\"\"\"\n        self._xpos = value if self._is_null(value) == False else None\n\n    @property\n    def feats(self):\n        \"\"\" Access the morphological features of this word. 
Example: 'Gender=Fem'\"\"\"\n        return self._feats\n\n    @feats.setter\n    def feats(self, value):\n        \"\"\" Set this word's morphological features. Example: 'Gender=Fem'\"\"\"\n        self._feats = value if self._is_null(value) == False else None\n\n    @property\n    def head(self):\n        \"\"\" Access the id of the governor of this word. \"\"\"\n        return self._head\n\n    @head.setter\n    def head(self, value):\n        \"\"\" Set the word's governor id value. \"\"\"\n        self._head = int(value) if self._is_null(value) == False else None\n\n    @property\n    def deprel(self):\n        \"\"\" Access the dependency relation of this word. Example: 'nmod'\"\"\"\n        return self._deprel\n\n    @deprel.setter\n    def deprel(self, value):\n        \"\"\" Set the word's dependency relation value. Example: 'nmod'\"\"\"\n        self._deprel = value if self._is_null(value) == False else None\n\n    @property\n    def deps(self):\n        \"\"\" Access the dependencies of this word. \"\"\"\n        graph = self._sent._enhanced_dependencies\n        if graph is None or not graph.has_node(self.id):\n            return None\n\n        data = []\n        predecessors = sorted(list(graph.predecessors(self.id)), key=lambda x: x if isinstance(x, tuple) else (x,))\n        for parent in predecessors:\n            deps = sorted(list(graph.get_edge_data(parent, self.id)))\n            for dep in deps:\n                if isinstance(parent, int):\n                    data.append(\"%d:%s\" % (parent, dep))\n                else:\n                    data.append(\"%d.%d:%s\" % (parent[0], parent[1], dep))\n        if not data:\n            return None\n\n        return \"|\".join(data)\n\n    @deps.setter\n    def deps(self, value):\n        \"\"\" Set the word's dependencies value. 
\"\"\"\n        graph = self._sent._enhanced_dependencies\n        # if we don't have a graph, and we aren't trying to set any actual\n        # dependencies, we can save the time of doing anything else\n        if graph is None and value is None:\n            return\n\n        if graph is None:\n            graph = nx.MultiDiGraph()\n            self._sent._enhanced_dependencies = graph\n        # need to make a new list: cannot iterate and delete at the same time\n        if graph.has_node(self.id):\n            in_edges = list(graph.in_edges(self.id))\n            graph.remove_edges_from(in_edges)\n\n        if value is None:\n            return\n\n        if isinstance(value, str):\n            value = value.split(\"|\")\n        if all(isinstance(x, str) for x in value):\n            value = [x.split(\":\", maxsplit=1) for x in value]\n        for parent, dep in value:\n            # we have to match the format of the IDs.  since the IDs\n            # of the words are int if they aren't empty words, we need\n            # to convert single int IDs into int instead of tuple\n            parent = tuple(map(int, parent.split(\".\", maxsplit=1)))\n            if len(parent) == 1:\n                parent = parent[0]\n            graph.add_edge(parent, self.id, dep)\n\n    @property\n    def misc(self):\n        \"\"\" Access the miscellaneousness of this word. \"\"\"\n        return self._misc\n\n    @misc.setter\n    def misc(self, value):\n        \"\"\" Set the word's miscellaneousness value. \"\"\"\n        self._misc = value if self._is_null(value) == False else None\n\n    @property\n    def line_number(self):\n        \"\"\" Access the line number from the original document, if set \"\"\"\n        return self._line_number\n\n    @property\n    def start_char(self):\n        \"\"\" Access the start character index for this token in the raw text. 
\"\"\"\n        return self._start_char\n\n    @start_char.setter\n    def start_char(self, value):\n        self._start_char = value\n\n    @property\n    def end_char(self):\n        \"\"\" Access the end character index for this token in the raw text. \"\"\"\n        return self._end_char\n\n    @end_char.setter\n    def end_char(self, value):\n        self._end_char = value\n\n    @property\n    def parent(self):\n        \"\"\" Access the parent token of this word. In the case of a multi-word token, a token can be the parent of\n        multiple words. Note that this should return a reference to the parent token object.\n        \"\"\"\n        return self._parent\n\n    @parent.setter\n    def parent(self, value):\n        \"\"\" Set this word's parent token. In the case of a multi-word token, a token can be the parent of\n        multiple words. Note that value here should be a reference to the parent token object.\n        \"\"\"\n        self._parent = value\n\n    @property\n    def pos(self):\n        \"\"\" Access the universal part-of-speech of this word. Example: 'NOUN'\"\"\"\n        return self._upos\n\n    @pos.setter\n    def pos(self, value):\n        \"\"\" Set the word's universal part-of-speech value. Example: 'NOUN'\"\"\"\n        self._upos = value if self._is_null(value) == False else None\n\n    @property\n    def coref_chains(self):\n        \"\"\"\n        coref_chains points to a list of CorefChain namedtuple, which has a list of mentions and a representative mention.\n\n        Useful for disambiguating words such as \"him\" (in languages where coref is available)\n\n        Theoretically it is possible for multiple corefs to occur at the same word.  
For example,\n          \"Chris Manning's NLP Group\"\n        could have \"Chris Manning\" and \"Chris Manning's NLP Group\" as overlapping entities\n        \"\"\"\n        return self._coref_chains\n\n    @coref_chains.setter\n    def coref_chains(self, chain):\n        \"\"\" Set the backref for the coref chains \"\"\"\n        self._coref_chains = chain\n\n    @property\n    def sent(self):\n        \"\"\" Access the pointer to the sentence that this word belongs to. \"\"\"\n        return self._sent\n\n    @sent.setter\n    def sent(self, value):\n        \"\"\" Set the pointer to the sentence that this word belongs to. \"\"\"\n        self._sent = value\n\n    def __repr__(self):\n        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)\n\n    def __format__(self, spec):\n        if spec == 'C':\n            return self.to_conll_text(DEFAULT_OUTPUT_FIELDS)\n        elif spec == 'P':\n            return self.pretty_print()\n        else:\n            return str(self)\n\n    def to_conll_text(self, fields=DEFAULT_OUTPUT_FIELDS):\n        \"\"\"\n        Turn a word into a conll representation (10 column tab separated)\n        \"\"\"\n        token_dict = self.to_dict(fields)\n        return dict_to_conll_text(token_dict, '.')\n\n    def to_dict(self, fields=DEFAULT_OUTPUT_FIELDS):\n        \"\"\" Dumps the word into a dictionary.\n        \"\"\"\n        word_dict = {}\n        for field in fields:\n            if getattr(self, field, None) is not None:\n                word_dict[field] = getattr(self, field)\n        return word_dict\n\n    def pretty_print(self):\n        \"\"\" Print the word in one line. 
\"\"\"\n        features = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL]\n        feature_str = \";\".join([\"{}={}\".format(k, getattr(self, k)) for k in features if getattr(self, k) is not None])\n        return f\"<{self.__class__.__name__} {feature_str}>\"\n\n    def _is_null(self, value):\n        return (value is None) or (value == '_')\n\n\nclass Span(StanzaObject):\n    \"\"\" A span class that stores attributes of a textual span. A span can be typed.\n    A range of objects (e.g., entity mentions) can be represented as spans.\n    \"\"\"\n\n    def __init__(self, span_entry=None, tokens=None, type=None, doc=None, sent=None):\n        \"\"\" Construct a span given a span entry or a list of tokens. A valid reference to a doc\n        must be provided to construct a span (otherwise the text of the span cannot be initialized).\n        \"\"\"\n        assert span_entry is not None or (tokens is not None and type is not None), \\\n                'Either a span_entry or a token list needs to be provided to construct a span.'\n        assert doc is not None, 'A parent doc must be provided to construct a span.'\n        self._text, self._type, self._start_char, self._end_char = [None] * 4\n        self._tokens = []\n        self._words = []\n        self._doc = doc\n        self._sent = sent\n\n        if span_entry is not None:\n            self.init_from_entry(span_entry)\n\n        if tokens is not None:\n            self.init_from_tokens(tokens, type)\n\n    def init_from_entry(self, span_entry):\n        self.text = span_entry.get(TEXT, None)\n        self.type = span_entry.get(TYPE, None)\n        self.start_char = span_entry.get(START_CHAR, None)\n        self.end_char = span_entry.get(END_CHAR, None)\n\n    def init_from_tokens(self, tokens, type):\n        assert isinstance(tokens, list), 'Tokens must be provided as a list to construct a span.'\n        assert len(tokens) > 0, \"Tokens of a span cannot be an empty list.\"\n        self.tokens = 
tokens\n        self.type = type\n        # load start and end char offsets from tokens\n        self.start_char = self.tokens[0].start_char\n        self.end_char = self.tokens[-1].end_char\n        if self.doc is not None and self.doc.text is not None:\n            self.text = self.doc.text[self.start_char:self.end_char]\n        elif tokens[0].sent is tokens[-1].sent:\n            sentence = tokens[0].sent\n            if tokens[-1].end_char is not None and tokens[0].start_char is not None and sentence.tokens[0].start_char is not None:\n                text_start = tokens[0].start_char - sentence.tokens[0].start_char\n                text_end = tokens[-1].end_char - sentence.tokens[0].start_char\n                self.text = sentence.text[text_start:text_end]\n            else:\n                text = []\n                for token in tokens:\n                    text.append(token.text)\n                    text.append(token.spaces_after)\n                self.text = \"\".join(text[:-1])\n        else:\n            # TODO: do any spans ever cross sentences?\n            raise RuntimeError(\"Document text does not exist, and the span tested crosses two sentences, so it is impossible to extract the entity text!\")\n        # collect the words of the span following tokens\n        self.words = [w for t in tokens for w in t.words]\n        # set the sentence back-pointer to point to the sentence of the first token\n        self.sent = tokens[0].sent\n\n    @property\n    def doc(self):\n        \"\"\" Access the parent doc of this span. \"\"\"\n        return self._doc\n\n    @doc.setter\n    def doc(self, value):\n        \"\"\" Set the parent doc of this span. \"\"\"\n        self._doc = value\n\n    @property\n    def text(self):\n        \"\"\" Access the text of this span. Example: 'Stanford University'\"\"\"\n        return self._text\n\n    @text.setter\n    def text(self, value):\n        \"\"\" Set the span's text value. 
Example: 'Stanford University'\"\"\"\n        self._text = value\n\n    @property\n    def tokens(self):\n        \"\"\" Access reference to a list of tokens that correspond to this span. \"\"\"\n        return self._tokens\n\n    @tokens.setter\n    def tokens(self, value):\n        \"\"\" Set the span's list of tokens. \"\"\"\n        self._tokens = value\n\n    @property\n    def words(self):\n        \"\"\" Access reference to a list of words that correspond to this span. \"\"\"\n        return self._words\n\n    @words.setter\n    def words(self, value):\n        \"\"\" Set the span's list of words. \"\"\"\n        self._words = value\n\n    @property\n    def type(self):\n        \"\"\" Access the type of this span. Example: 'PERSON'\"\"\"\n        return self._type\n\n    @type.setter\n    def type(self, value):\n        \"\"\" Set the type of this span. \"\"\"\n        self._type = value\n\n    @property\n    def start_char(self):\n        \"\"\" Access the start character offset of this span. \"\"\"\n        return self._start_char\n\n    @start_char.setter\n    def start_char(self, value):\n        \"\"\" Set the start character offset of this span. \"\"\"\n        self._start_char = value\n\n    @property\n    def end_char(self):\n        \"\"\" Access the end character offset of this span. \"\"\"\n        return self._end_char\n\n    @end_char.setter\n    def end_char(self, value):\n        \"\"\" Set the end character offset of this span. \"\"\"\n        self._end_char = value\n\n    @property\n    def sent(self):\n        \"\"\" Access the pointer to the sentence that this span belongs to. \"\"\"\n        return self._sent\n\n    @sent.setter\n    def sent(self, value):\n        \"\"\" Set the pointer to the sentence that this span belongs to. \"\"\"\n        self._sent = value\n\n    def to_dict(self):\n        \"\"\" Dumps the span into a dictionary. 
\"\"\"\n        attrs = ['text', 'type', 'start_char', 'end_char']\n        span_dict = dict([(attr_name, getattr(self, attr_name)) for attr_name in attrs])\n        return span_dict\n\n    def __repr__(self):\n        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)\n\n    def pretty_print(self):\n        \"\"\" Print the span in one line. \"\"\"\n        span_dict = self.to_dict()\n        feature_str = \";\".join([\"{}={}\".format(k,v) for k,v in span_dict.items()])\n        return f\"<{self.__class__.__name__} {feature_str}>\"\n"
  },
  {
    "path": "stanza/models/common/dropout.py",
    "content": "import torch\nimport torch.nn as nn\n\nclass WordDropout(nn.Module):\n    \"\"\" A word dropout layer that's designed for embedded inputs (e.g., any inputs to an LSTM layer).\n    Given a batch of embedded inputs, this layer randomly set some of them to be a replacement state.\n    Note that this layer assumes the last dimension of the input to be the hidden dimension of a unit.\n    \"\"\"\n    def __init__(self, dropprob):\n        super().__init__()\n        self.dropprob = dropprob\n\n    def forward(self, x, replacement=None):\n        if not self.training or self.dropprob == 0:\n            return x\n\n        masksize = [y for y in x.size()]\n        masksize[-1] = 1\n        dropmask = torch.rand(*masksize, device=x.device) < self.dropprob\n\n        res = x.masked_fill(dropmask, 0)\n        if replacement is not None:\n            res = res + dropmask.float() * replacement\n\n        return res\n    \n    def extra_repr(self):\n        return 'p={}'.format(self.dropprob)\n\nclass LockedDropout(nn.Module):\n    \"\"\"\n    A variant of dropout layer that consistently drops out the same parameters over time. Also known as the variational dropout. 
\n    This implementation was modified from the LockedDropout implementation in the flair library (https://github.com/zalandoresearch/flair).\n    \"\"\"\n    def __init__(self, dropprob, batch_first=True):\n        super().__init__()\n        self.dropprob = dropprob\n        self.batch_first = batch_first\n\n    def forward(self, x):\n        if not self.training or self.dropprob == 0:\n            return x\n\n        if not self.batch_first:\n            m = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(1 - self.dropprob)\n        else:\n            m = x.new_empty(x.size(0), 1, x.size(2), requires_grad=False).bernoulli_(1 - self.dropprob)\n\n        mask = m.div(1 - self.dropprob).expand_as(x)\n        return mask * x\n    \n    def extra_repr(self):\n        return 'p={}'.format(self.dropprob)\n\nclass SequenceUnitDropout(nn.Module):\n    \"\"\" A unit dropout layer that's designed for input of sequence units (e.g., word sequence, char sequence, etc.).\n    Given a sequence of unit indices, this layer randomly sets some of them to a replacement id (usually <UNK>).\n    \"\"\"\n    def __init__(self, dropprob, replacement_id):\n        super().__init__()\n        self.dropprob = dropprob\n        self.replacement_id = replacement_id\n\n    def forward(self, x):\n        \"\"\" :param: x must be a LongTensor of unit indices. \"\"\"\n        if not self.training or self.dropprob == 0:\n            return x\n        masksize = [y for y in x.size()]\n        dropmask = torch.rand(*masksize, device=x.device) < self.dropprob\n        res = x.masked_fill(dropmask, self.replacement_id)\n        return res\n    \n    def extra_repr(self):\n        return 'p={}, replacement_id={}'.format(self.dropprob, self.replacement_id)\n\n"
  },
  {
    "path": "stanza/models/common/exceptions.py",
    "content": "\"\"\"\nA couple more specific FileNotFoundError exceptions\n\nThe idea being, the caller can catch it and report a more useful error resolution\n\"\"\"\n\nimport errno\n\nclass ForwardCharlmNotFoundError(FileNotFoundError):\n    def __init__(self, msg, filename):\n        super().__init__(errno.ENOENT, msg, filename)\n\nclass BackwardCharlmNotFoundError(FileNotFoundError):\n    def __init__(self, msg, filename):\n        super().__init__(errno.ENOENT, msg, filename)\n"
  },
  {
    "path": "stanza/models/common/foundation_cache.py",
    "content": "\"\"\"\nKeeps BERT, charlm, word embedings in a cache to save memory\n\"\"\"\n\nfrom collections import namedtuple\nfrom copy import deepcopy\nimport logging\nimport threading\n\nfrom stanza.models.common import bert_embedding\nfrom stanza.models.common.char_model import CharacterLanguageModel\nfrom stanza.models.common.pretrain import Pretrain\n\nlogger = logging.getLogger('stanza')\n\nBertRecord = namedtuple('BertRecord', ['model', 'tokenizer', 'peft_ids'])\n\nclass FoundationCache:\n    def __init__(self, other=None, local_files_only=False):\n        if other is None:\n            self.bert = {}\n            self.charlms = {}\n            self.pretrains = {}\n            # future proof the module by using a lock for the glorious day\n            # when the GIL is finally gone\n            self.lock = threading.Lock()\n        else:\n            self.bert = other.bert\n            self.charlms = other.charlms\n            self.pretrains = other.pretrains\n            self.lock = other.lock\n        self.local_files_only=local_files_only\n\n    def load_bert(self, transformer_name, local_files_only=None):\n        m, t, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)\n        return m, t\n\n    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):\n        \"\"\"\n        Load a transformer only once\n\n        Uses a lock for thread safety\n        \"\"\"\n        if transformer_name is None:\n            return None, None, None\n        with self.lock:\n            if transformer_name not in self.bert:\n                if local_files_only is None:\n                    local_files_only = self.local_files_only\n                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)\n                self.bert[transformer_name] = BertRecord(model, tokenizer, {})\n            else:\n                logger.debug(\"Reusing bert %s\", 
transformer_name)\n\n            bert_record = self.bert[transformer_name]\n            if not peft_name:\n                return bert_record.model, bert_record.tokenizer, None\n            if peft_name not in bert_record.peft_ids:\n                bert_record.peft_ids[peft_name] = 0\n            else:\n                bert_record.peft_ids[peft_name] = bert_record.peft_ids[peft_name] + 1\n            peft_name = \"%s_%d\" % (peft_name, bert_record.peft_ids[peft_name])\n            return bert_record.model, bert_record.tokenizer, peft_name\n\n    def load_charlm(self, filename):\n        if not filename:\n            return None\n\n        with self.lock:\n            if filename not in self.charlms:\n                logger.debug(\"Loading charlm from %s\", filename)\n                self.charlms[filename] = CharacterLanguageModel.load(filename, finetune=False)\n            else:\n                logger.debug(\"Reusing charlm from %s\", filename)\n\n            return self.charlms[filename]\n\n    def load_pretrain(self, filename):\n        \"\"\"\n        Load a pretrained word embedding only once\n\n        Uses a lock for thread safety\n        \"\"\"\n        if filename is None:\n            return None\n        with self.lock:\n            if filename not in self.pretrains:\n                logger.debug(\"Loading pretrain %s\", filename)\n                self.pretrains[filename] = Pretrain(filename)\n            else:\n                logger.debug(\"Reusing pretrain %s\", filename)\n\n            return self.pretrains[filename]\n\nclass NoTransformerFoundationCache(FoundationCache):\n    \"\"\"\n    Uses the underlying FoundationCache, but hiding the transformer.\n\n    Useful for when loading a downstream model such as POS which has a\n    finetuned transformer, and we don't want the transformer reused\n    since it will then have the finetuned weights for other models\n    which don't want them\n    \"\"\"\n    def load_bert(self, transformer_name, 
local_files_only=None):\n        return load_bert(transformer_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)\n\n    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):\n        return load_bert_with_peft(transformer_name, peft_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)\n\ndef load_bert(model_name, foundation_cache=None, local_files_only=None):\n    \"\"\"\n    Load a bert, possibly using a foundation cache, ignoring the cache if None\n    \"\"\"\n    if foundation_cache is None:\n        return bert_embedding.load_bert(model_name, local_files_only=local_files_only)\n    else:\n        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)\n\ndef load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):\n    if foundation_cache is None:\n        m, t = bert_embedding.load_bert(model_name, local_files_only=local_files_only)\n        return m, t, peft_name\n    return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)\n\ndef load_charlm(charlm_file, foundation_cache=None, finetune=False):\n    if not charlm_file:\n        return None\n\n    if finetune:\n        # can't use the cache in the case of a model which will be finetuned\n        # and the numbers will be different for other users of the model\n        return CharacterLanguageModel.load(charlm_file, finetune=True)\n\n    if foundation_cache is not None:\n        return foundation_cache.load_charlm(charlm_file)\n\n    logger.debug(\"Loading charlm from %s\", charlm_file)\n    return CharacterLanguageModel.load(charlm_file, finetune=False)\n\ndef load_pretrain(filename, foundation_cache=None):\n    if not filename:\n        return None\n\n    if foundation_cache is not None:\n        return foundation_cache.load_pretrain(filename)\n\n    logger.debug(\"Loading pretrain from %s\", 
filename)\n    return Pretrain(filename)\n"
  },
  {
    "path": "stanza/models/common/hlstm.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence\n\nfrom stanza.models.common.packed_lstm import PackedLSTM\n\nclass HLSTMCell(nn.modules.rnn.RNNCellBase):\n    \"\"\"\n    A Highway LSTM Cell as proposed in Zhang et al. (2018) Highway Long Short-Term Memory RNNs for \n    Distant Speech Recognition.\n    \"\"\"\n    def __init__(self, input_size, hidden_size, bias=True):\n        super(HLSTMCell, self).__init__()\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n\n        # LSTM parameters\n        self.Wi = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)\n        self.Wf = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)\n        self.Wo = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)\n        self.Wg = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)\n\n        # highway gate parameters\n        self.gate = nn.Linear(input_size + 2 * hidden_size, hidden_size, bias=bias)\n\n    def forward(self, input, c_l_minus_one=None, hx=None):\n        self.check_forward_input(input)\n        if hx is None:\n            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)\n            hx = (hx, hx)\n        if c_l_minus_one is None:\n            c_l_minus_one = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)\n\n        self.check_forward_hidden(input, hx[0], '[0]')\n        self.check_forward_hidden(input, hx[1], '[1]')\n        self.check_forward_hidden(input, c_l_minus_one, 'c_l_minus_one')\n\n        # vanilla LSTM computation\n        rec_input = torch.cat([input, hx[0]], 1)\n        i = F.sigmoid(self.Wi(rec_input))\n        f = F.sigmoid(self.Wf(rec_input))\n        o = F.sigmoid(self.Wo(rec_input))\n        g = F.tanh(self.Wg(rec_input))\n\n        # highway gates\n        gate = 
F.sigmoid(self.gate(torch.cat([c_l_minus_one, hx[1], input], 1)))\n\n        c = gate * c_l_minus_one + f * hx[1] + i * g\n        h = o * F.tanh(c)\n\n        return h, c\n\n# Highway LSTM network, does NOT use the HLSTMCell above\nclass HighwayLSTM(nn.Module):\n    \"\"\"\n    A Highway LSTM network, as used in the original Tensorflow version of the Dozat parser. Note that this\n    is independent from the HLSTMCell above.\n    \"\"\"\n    def __init__(self, input_size, hidden_size,\n                 num_layers=1, bias=True, batch_first=False,\n                 dropout=0, bidirectional=False, rec_dropout=0, highway_func=None, pad=False):\n        super(HighwayLSTM, self).__init__()\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.bias = bias\n        self.batch_first = batch_first\n        self.dropout = dropout\n        self.dropout_state = {}\n        self.bidirectional = bidirectional\n        self.num_directions = 2 if bidirectional else 1\n        self.highway_func = highway_func\n        self.pad = pad\n\n        self.lstm = nn.ModuleList()\n        self.highway = nn.ModuleList()\n        self.gate = nn.ModuleList()\n        self.drop = nn.Dropout(dropout, inplace=True)\n\n        in_size = input_size\n        for l in range(num_layers):\n            self.lstm.append(PackedLSTM(in_size, hidden_size, num_layers=1, bias=bias,\n                batch_first=batch_first, dropout=0, bidirectional=bidirectional, rec_dropout=rec_dropout))\n            self.highway.append(nn.Linear(in_size, hidden_size * self.num_directions))\n            self.gate.append(nn.Linear(in_size, hidden_size * self.num_directions))\n            self.highway[-1].bias.data.zero_()\n            self.gate[-1].bias.data.zero_()\n            in_size = hidden_size * self.num_directions\n\n    def forward(self, input, seqlens, hx=None):\n        highway_func = (lambda x: x) if self.highway_func is None else 
self.highway_func\n\n        hs = []\n        cs = []\n\n        if not isinstance(input, PackedSequence):\n            input = pack_padded_sequence(input, seqlens, batch_first=self.batch_first)\n\n        for l in range(self.num_layers):\n            if l > 0:\n                input = PackedSequence(self.drop(input.data), input.batch_sizes, input.sorted_indices, input.unsorted_indices)\n            layer_hx = (hx[0][l * self.num_directions:(l+1)*self.num_directions], hx[1][l * self.num_directions:(l+1)*self.num_directions]) if hx is not None else None\n            h, (ht, ct) = self.lstm[l](input, seqlens, layer_hx)\n\n            hs.append(ht)\n            cs.append(ct)\n\n            input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes, input.sorted_indices, input.unsorted_indices)\n\n        if self.pad:\n            input = pad_packed_sequence(input, batch_first=self.batch_first)[0]\n        return input, (torch.cat(hs, 0), torch.cat(cs, 0))\n\nif __name__ == \"__main__\":\n    T = 10\n    bidir = True\n    num_dir = 2 if bidir else 1\n    rnn = HighwayLSTM(10, 20, num_layers=2, bidirectional=bidir)\n    input = torch.randn(T, 3, 10)\n    # forward() takes the sequence lengths as its second argument;\n    # here all three sequences in the batch have the full length T\n    seqlens = [T] * 3\n    hx = torch.randn(2 * num_dir, 3, 20)\n    cx = torch.randn(2 * num_dir, 3, 20)\n    output = rnn(input, seqlens, (hx, cx))\n    print(output)\n"
  },
  {
    "path": "stanza/models/common/large_margin_loss.py",
    "content": "\"\"\"\nLargeMarginInSoftmax, from the article\n\n@inproceedings{kobayashi2019bmvc,\n  title={Large Margin In Softmax Cross-Entropy Loss},\n  author={Takumi Kobayashi},\n  booktitle={Proceedings of the British Machine Vision Conference (BMVC)},\n  year={2019}\n}\n\nimplementation from\n\nhttps://github.com/tk1980/LargeMarginInSoftmax\n\nThere is no license specifically chosen; they just ask people to cite the paper if the work is useful.\n\"\"\"\n\n\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\n\n\nclass LargeMarginInSoftmaxLoss(nn.CrossEntropyLoss):\n    r\"\"\"\n    This combines the Softmax Cross-Entropy Loss (nn.CrossEntropyLoss) and the large-margin inducing\n    regularization proposed in\n       T. Kobayashi, \"Large-Margin In Softmax Cross-Entropy Loss.\" In BMVC2019.\n\n    This loss function inherits the parameters from nn.CrossEntropyLoss except for `reg_lambda` and `deg_logit`.\n    Args:\n         reg_lambda (float, optional): a regularization parameter. (default: 0.3)\n         deg_logit (bool, optional): underestimate (degrade) the target logit by -1 or not. 
(default: False)\n                                     If True, it realizes the method that incorporates the modified loss into ours\n                                     as described in the above paper (Table 4).\n    \"\"\"\n    def __init__(self, reg_lambda=0.3, deg_logit=None,\n                weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):\n        super(LargeMarginInSoftmaxLoss, self).__init__(weight=weight, size_average=size_average,\n                                ignore_index=ignore_index, reduce=reduce, reduction=reduction)\n        self.reg_lambda = reg_lambda\n        self.deg_logit = deg_logit\n\n    def forward(self, input, target):\n        N = input.size(0) # number of samples\n        C = input.size(1) # number of classes\n        Mask = torch.zeros_like(input, requires_grad=False)\n        Mask[range(N),target] = 1\n\n        if self.deg_logit is not None:\n            input = input - self.deg_logit * Mask\n\n        loss = F.cross_entropy(input, target, weight=self.weight,\n                               ignore_index=self.ignore_index, reduction=self.reduction)\n\n        X = input - 1.e6 * Mask # [N x C], excluding the target class\n        reg = 0.5 * ((F.softmax(X, dim=1) - 1.0/(C-1)) * F.log_softmax(X, dim=1) * (1.0-Mask)).sum(dim=1)\n        if self.reduction == 'sum':\n            reg = reg.sum()\n        elif self.reduction == 'mean':\n            reg = reg.mean()\n        elif self.reduction == 'none':\n            reg = reg\n\n        return loss + self.reg_lambda * reg\n"
  },
  {
    "path": "stanza/models/common/loss.py",
    "content": "\"\"\"\nDifferent loss functions.\n\"\"\"\n\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nimport stanza.models.common.seq2seq_constant as constant\n\nlogger = logging.getLogger('stanza')\n\ndef SequenceLoss(vocab_size):\n    weight = torch.ones(vocab_size)\n    weight[constant.PAD_ID] = 0\n    crit = nn.NLLLoss(weight)\n    return crit\n\ndef weighted_cross_entropy_loss(labels, log_dampened=False):\n    \"\"\"\n    Either return a loss function which reweights all examples so the\n    classes have the same effective weight, or dampened reweighting\n    using log() so that the biggest class has some priority\n    \"\"\"\n    if isinstance(labels, list):\n        all_labels = np.array(labels)\n    _, weights = np.unique(labels, return_counts=True)\n    weights = weights / float(np.sum(weights))\n    weights = np.sum(weights) / weights\n    if log_dampened:\n        weights = 1 + np.log(weights)\n    logger.debug(\"Reweighting cross entropy by {}\".format(weights))\n    loss = nn.CrossEntropyLoss(\n        weight=torch.from_numpy(weights).type('torch.FloatTensor')\n    )\n    return loss\n\nclass FocalLoss(nn.Module):\n    \"\"\"\n    Uses the model's assessment of how likely the correct answer is\n    to weight the loss for a each error\n\n    multi-category focal loss, in other words\n\n    from \"Focal Loss for Dense Object Detection\"\n\n    https://arxiv.org/abs/1708.02002\n    \"\"\"\n    def __init__(self, reduction='mean', gamma=2.0):\n        super().__init__()\n        if reduction not in ('sum', 'none', 'mean'):\n            raise ValueError(\"Unknown reduction: %s\" % reduction)\n\n        self.reduction = reduction\n        self.ce_loss = nn.CrossEntropyLoss(reduction='none')\n        self.gamma = gamma\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Weight the loss using the models assessment of the correct answer\n\n        inputs: [N, C]\n        targets: [N]\n        \"\"\"\n        if 
len(inputs.shape) == 2 and len(targets.shape) == 1:\n            if inputs.shape[0] != targets.shape[0]:\n                raise ValueError(\"Expected inputs N,C and targets N, but got {} and {}\".format(inputs.shape, targets.shape))\n        elif len(inputs.shape) == 1 and len(targets.shape) == 0:\n            raise NotImplementedError(\"This would be a reasonable thing to implement, but we haven't done it yet\")\n        else:\n            raise ValueError(\"Expected inputs N,C and targets N, but got {} and {}\".format(inputs.shape, targets.shape))\n\n        raw_loss = self.ce_loss(inputs, targets)\n        assert len(raw_loss.shape) == 1 and raw_loss.shape[0] == inputs.shape[0]\n\n        # https://www.tutorialexample.com/implement-focal-loss-for-multi-label-classification-in-pytorch-pytorch-tutorial/\n        final_loss = raw_loss * ((1 - torch.exp(-raw_loss)) ** self.gamma)\n        assert len(final_loss.shape) == 1 and final_loss.shape[0] == inputs.shape[0]\n        if self.reduction == 'sum':\n            return final_loss.sum()\n        elif self.reduction == 'mean':\n            return final_loss.mean()\n        elif self.reduction == 'none':\n            return final_loss\n        raise AssertionError(\"unknown reduction!  
how did this happen??\")\n\nclass MixLoss(nn.Module):\n    \"\"\"\n    A mixture of SequenceLoss and CrossEntropyLoss.\n    Loss = SequenceLoss + alpha * CELoss\n    \"\"\"\n    def __init__(self, vocab_size, alpha):\n        super().__init__()\n        self.seq_loss = SequenceLoss(vocab_size)\n        self.ce_loss = nn.CrossEntropyLoss()\n        assert alpha >= 0\n        self.alpha = alpha\n\n    def forward(self, seq_inputs, seq_targets, class_inputs, class_targets):\n        sl = self.seq_loss(seq_inputs, seq_targets)\n        cel = self.ce_loss(class_inputs, class_targets)\n        loss = sl + self.alpha * cel\n        return loss\n\nclass MaxEntropySequenceLoss(nn.Module):\n    \"\"\"\n    A max entropy loss that encourages the model to have large entropy,\n    therefore giving more diverse outputs.\n\n    Loss = NLLLoss + alpha * EntropyLoss\n    \"\"\"\n    def __init__(self, vocab_size, alpha):\n        super().__init__()\n        weight = torch.ones(vocab_size)\n        weight[constant.PAD_ID] = 0\n        self.nll = nn.NLLLoss(weight)\n        self.alpha = alpha\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        inputs: [N, C]\n        targets: [N]\n        \"\"\"\n        assert inputs.size(0) == targets.size(0)\n        nll_loss = self.nll(inputs, targets)\n        # entropy loss\n        mask = targets.eq(constant.PAD_ID).unsqueeze(1).expand_as(inputs)\n        masked_inputs = inputs.clone().masked_fill_(mask, 0.0)\n        p = torch.exp(masked_inputs)\n        ent_loss = p.mul(masked_inputs).sum() / inputs.size(0) # average over minibatch\n        loss = nll_loss + self.alpha * ent_loss\n        return loss\n\n"
  },
  {
    "path": "stanza/models/common/maxout_linear.py",
    "content": "\"\"\"\nA layer which implements maxout from the \"Maxout Networks\" paper\n\nhttps://arxiv.org/pdf/1302.4389v4.pdf\nGoodfellow, Warde-Farley, Mirza, Courville, Bengio\n\nor a simpler explanation here:\n\nhttps://stats.stackexchange.com/questions/129698/what-is-maxout-in-neural-network/298705#298705\n\nThe implementation here:\nfor k layers of maxout, in -> out channels, we make a single linear\n  map of size in -> out*k\nthen we reshape the end to be (..., k, out)\nand return the max over the k layers\n\"\"\"\n\n\nimport torch\nimport torch.nn as nn\n\nclass MaxoutLinear(nn.Module):\n    def __init__(self, in_channels, out_channels, maxout_k):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.maxout_k = maxout_k\n\n        self.linear = nn.Linear(in_channels, out_channels * maxout_k)\n\n    def forward(self, inputs):\n        \"\"\"\n        Use the oversized linear as the repeated linear, then take the max\n\n        One large linear map makes the implementation simpler and easier for pytorch to make parallel\n        \"\"\"\n        outputs = self.linear(inputs)\n        outputs = outputs.view(*outputs.shape[:-1], self.maxout_k, self.out_channels)\n        outputs = torch.max(outputs, dim=-2)[0]\n        return outputs\n\n"
  },
  {
    "path": "stanza/models/common/packed_lstm.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence\n\nclass PackedLSTM(nn.Module):\n    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):\n        super().__init__()\n\n        self.batch_first = batch_first\n        self.pad = pad\n        if rec_dropout == 0:\n            # use the fast, native LSTM implementation\n            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)\n        else:\n            self.lstm = LSTMwRecDropout(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, rec_dropout=rec_dropout)\n\n    def forward(self, input, lengths, hx=None):\n        if not isinstance(input, PackedSequence):\n            input = pack_padded_sequence(input, lengths, batch_first=self.batch_first)\n\n        res = self.lstm(input, hx)\n        if self.pad:\n            res = (pad_packed_sequence(res[0], batch_first=self.batch_first)[0], res[1])\n        return res\n\nclass LSTMwRecDropout(nn.Module):\n    \"\"\" An LSTM implementation that supports recurrent dropout \"\"\"\n    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):\n        super().__init__()\n        self.batch_first = batch_first\n        self.pad = pad\n        self.num_layers = num_layers\n        self.hidden_size = hidden_size\n\n        self.dropout = dropout\n        self.drop = nn.Dropout(dropout, inplace=True)\n        self.rec_drop = nn.Dropout(rec_dropout, inplace=True)\n\n        self.num_directions = 2 if bidirectional else 1\n\n        self.cells = nn.ModuleList()\n        for l in range(num_layers):\n            
in_size = input_size if l == 0 else self.num_directions * hidden_size\n            for d in range(self.num_directions):\n                self.cells.append(nn.LSTMCell(in_size, hidden_size, bias=bias))\n\n    def forward(self, input, hx=None):\n        def rnn_loop(x, batch_sizes, cell, inits, reverse=False):\n            # RNN loop for one layer in one direction with recurrent dropout\n            # Assumes input is PackedSequence, returns PackedSequence as well\n            batch_size = batch_sizes[0].item()\n            states = [list(init.split([1] * batch_size)) for init in inits]\n            h_drop_mask = x.new_ones(batch_size, self.hidden_size)\n            h_drop_mask = self.rec_drop(h_drop_mask)\n            resh = []\n\n            if not reverse:\n                st = 0\n                for bs in batch_sizes:\n                    s1 = cell(x[st:st+bs], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))\n                    resh.append(s1[0])\n                    for j in range(bs):\n                        states[0][j] = s1[0][j].unsqueeze(0)\n                        states[1][j] = s1[1][j].unsqueeze(0)\n                    st += bs\n            else:\n                en = x.size(0)\n                for i in range(batch_sizes.size(0)-1, -1, -1):\n                    bs = batch_sizes[i]\n                    s1 = cell(x[en-bs:en], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))\n                    resh.append(s1[0])\n                    for j in range(bs):\n                        states[0][j] = s1[0][j].unsqueeze(0)\n                        states[1][j] = s1[1][j].unsqueeze(0)\n                    en -= bs\n                resh = list(reversed(resh))\n\n            return torch.cat(resh, 0), tuple(torch.cat(s, 0) for s in states)\n\n        all_states = [[], []]\n        inputdata, batch_sizes = input.data, input.batch_sizes\n        for l in range(self.num_layers):\n            new_input = 
[]\n\n            if self.dropout > 0 and l > 0:\n                inputdata = self.drop(inputdata)\n            for d in range(self.num_directions):\n                idx = l * self.num_directions + d\n                cell = self.cells[idx]\n                out, states = rnn_loop(inputdata, batch_sizes, cell, (hx[i][idx] for i in range(2)) if hx is not None else (input.data.new_zeros(input.batch_sizes[0].item(), self.hidden_size, requires_grad=False) for _ in range(2)), reverse=(d == 1))\n\n                new_input.append(out)\n                all_states[0].append(states[0].unsqueeze(0))\n                all_states[1].append(states[1].unsqueeze(0))\n\n            if self.num_directions > 1:\n                # concatenate both directions\n                inputdata = torch.cat(new_input, 1)\n            else:\n                inputdata = new_input[0]\n\n        input = PackedSequence(inputdata, batch_sizes)\n\n        return input, tuple(torch.cat(x, 0) for x in all_states)\n"
  },
  {
    "path": "stanza/models/common/peft_config.py",
    "content": "\"\"\"\nSet a few common flags for peft usage\n\"\"\"\n\n\nTRANSFORMER_LORA_RANK = {}\nDEFAULT_LORA_RANK = 64\n\nTRANSFORMER_LORA_ALPHA = {}\nDEFAULT_LORA_ALPHA = 128\n\nTRANSFORMER_LORA_DROPOUT = {}\nDEFAULT_LORA_DROPOUT = 0.1\n\n\nTRANSFORMER_LORA_TARGETS = {\n    \"princeton-nlp/Sheared-LLaMA-1.3B\": \"self_attn.k_proj,self_attn.v_proj,self_attn.o_proj,mlp.gate_proj,mlp.up_proj,mlp.down_proj\"\n}\nDEFAULT_LORA_TARGETS = \"query,value,output.dense,intermediate.dense\"\n\nTRANSFORMER_LORA_SAVE = {}\nDEFAULT_LORA_SAVE = \"\"\n\ndef add_peft_args(parser):\n    \"\"\"\n    Add common default flags to an argparse\n    \"\"\"\n    parser.add_argument('--lora_rank', type=int, default=None, help=\"Rank of a LoRA approximation.  Default will be %d or a model-specific parameter\" % DEFAULT_LORA_RANK)\n    parser.add_argument('--lora_alpha', type=int, default=None, help=\"Alpha of a LoRA approximation.  Default will be %d or a model-specific parameter\" % DEFAULT_LORA_ALPHA)\n    parser.add_argument('--lora_dropout', type=float, default=None, help=\"Dropout for the LoRA approximation.  Default will be %s or a model-specific parameter\" % DEFAULT_LORA_DROPOUT)\n    parser.add_argument('--lora_target_modules', type=str, default=None, help=\"Comma separated list of LoRA targets.  Default will be '%s' or a model-specific parameter\" % DEFAULT_LORA_TARGETS)\n    parser.add_argument('--lora_modules_to_save', type=str, default=None, help=\"Comma separated list of modules to save (eg, fully tune) when using LoRA.  
Default will be '%s' or a model-specific parameter\" % DEFAULT_LORA_SAVE)\n\n    parser.add_argument('--use_peft', default=False, action='store_true', help=\"Finetune Bert using peft\")\n\ndef pop_peft_args(args):\n    \"\"\"\n    Pop all of the peft-related arguments from a given dict\n\n    Useful for making sure a model loaded from disk is recreated with\n    the right shapes, for example\n    \"\"\"\n    args.pop(\"lora_rank\", None)\n    args.pop(\"lora_alpha\", None)\n    args.pop(\"lora_dropout\", None)\n    args.pop(\"lora_target_modules\", None)\n    args.pop(\"lora_modules_to_save\", None)\n\n    args.pop(\"use_peft\", None)\n\n\ndef resolve_peft_args(args, logger, check_bert_finetune=True):\n    if not hasattr(args, 'bert_model'):\n        return\n\n    if args.lora_rank is None:\n        args.lora_rank = TRANSFORMER_LORA_RANK.get(args.bert_model, DEFAULT_LORA_RANK)\n\n    if args.lora_alpha is None:\n        args.lora_alpha = TRANSFORMER_LORA_ALPHA.get(args.bert_model, DEFAULT_LORA_ALPHA)\n\n    if args.lora_dropout is None:\n        args.lora_dropout = TRANSFORMER_LORA_DROPOUT.get(args.bert_model, DEFAULT_LORA_DROPOUT)\n\n    if args.lora_target_modules is None:\n        args.lora_target_modules = TRANSFORMER_LORA_TARGETS.get(args.bert_model, DEFAULT_LORA_TARGETS)\n    if not args.lora_target_modules.strip():\n        args.lora_target_modules = []\n    else:\n        args.lora_target_modules = args.lora_target_modules.split(\",\")\n\n    if args.lora_modules_to_save is None:\n        args.lora_modules_to_save = TRANSFORMER_LORA_SAVE.get(args.bert_model, DEFAULT_LORA_SAVE)\n    if not args.lora_modules_to_save.strip():\n        args.lora_modules_to_save = []\n    else:\n        args.lora_modules_to_save = args.lora_modules_to_save.split(\",\")\n\n    if check_bert_finetune and hasattr(args, 'bert_finetune'):\n        if args.use_peft and not args.bert_finetune:\n            logger.info(\"--use_peft set.  
setting --bert_finetune as well\")\n            args.bert_finetune = True\n\ndef build_peft_config(args, logger):\n    # Hide import so that the peft dependency is optional\n    from peft import LoraConfig\n    logger.debug(\"Creating lora adapter with rank %d and alpha %d\", args['lora_rank'], args['lora_alpha'])\n    peft_config = LoraConfig(inference_mode=False,\n                             r=args['lora_rank'],\n                             target_modules=args['lora_target_modules'],\n                             lora_alpha=args['lora_alpha'],\n                             lora_dropout=args['lora_dropout'],\n                             modules_to_save=args['lora_modules_to_save'],\n                             bias=\"none\")\n    return peft_config\n\ndef build_peft_wrapper(bert_model, args, logger, adapter_name=\"default\"):\n    # Hide import so that the peft dependency is optional\n    from peft import get_peft_model\n    peft_config = build_peft_config(args, logger)\n\n    pefted = get_peft_model(bert_model, peft_config, adapter_name=adapter_name)\n    # apparently get_peft_model doesn't actually mark that\n    # peft configs are loaded, making it impossible to turn off (or on)\n    # the peft adapter later\n    bert_model._hf_peft_config_loaded = True\n    pefted._hf_peft_config_loaded = True\n    pefted.set_adapter(adapter_name)\n    return pefted\n\ndef load_peft_wrapper(bert_model, lora_params, args, logger, adapter_name):\n    peft_config = build_peft_config(args, logger)\n\n    try:\n        bert_model.load_adapter(adapter_name=adapter_name, peft_config=peft_config, adapter_state_dict=lora_params)\n    except (ValueError, TypeError) as _:\n        from peft import set_peft_model_state_dict\n        # this can happen if the adapter already exists...\n        # in that case, try setting the adapter weights?\n        set_peft_model_state_dict(bert_model, lora_params, adapter_name=adapter_name)\n    bert_model.set_adapter(adapter_name)\n    return 
bert_model\n"
  },
  {
    "path": "stanza/models/common/pretrain.py",
    "content": "\"\"\"\nSupports for pretrained data.\n\"\"\"\nimport csv\nimport os\nimport re\n\nimport lzma\nimport logging\nimport numpy as np\nimport torch\n\nfrom .vocab import BaseVocab, VOCAB_PREFIX, UNK_ID\n\nfrom stanza.models.common.utils import open_read_binary, open_read_text\nfrom stanza.resources.common import DEFAULT_MODEL_DIR\n\nfrom pickle import UnpicklingError\nimport warnings\n\nlogger = logging.getLogger('stanza')\n\nclass PretrainedWordVocab(BaseVocab):\n    def build_vocab(self):\n        self._id2unit = VOCAB_PREFIX + self.data\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\n    def normalize_unit(self, unit):\n        unit = super().normalize_unit(unit)\n        if unit:\n            unit = unit.replace(\" \",\"\\xa0\")\n        return unit\n\nclass Pretrain:\n    \"\"\" A loader and saver for pretrained embeddings. \"\"\"\n\n    def __init__(self, filename=None, vec_filename=None, max_vocab=-1, save_to_file=True, csv_filename=None):\n        self.filename = filename\n        self._vec_filename = vec_filename\n        self._csv_filename = csv_filename\n        self._max_vocab = max_vocab\n        self._save_to_file = save_to_file\n\n    def __len__(self):\n        return len(self.vocab)\n\n    @property\n    def vocab(self):\n        if not hasattr(self, '_vocab'):\n            self.load()\n        return self._vocab\n\n    @property\n    def emb(self):\n        if not hasattr(self, '_emb'):\n            self.load()\n        return self._emb\n\n    def load(self):\n        if self.filename is not None and os.path.exists(self.filename):\n            try:\n                # TODO: after making the next release, remove the weights_only=False version\n                try:\n                    data = torch.load(self.filename, lambda storage, loc: storage, weights_only=True)\n                except UnpicklingError:\n                    data = torch.load(self.filename, lambda storage, loc: storage, weights_only=False)\n      
              warnings.warn(\"The saved pretrain has an old format using numpy.ndarray instead of torch to store weights.  This version of Stanza can support reading both the new and the old formats.  Future versions will only allow loading with weights_only=True.  Please resave the pretrained embedding using this version ASAP.\")\n                logger.debug(\"Loaded pretrain from {}\".format(self.filename))\n                if not isinstance(data, dict):\n                    raise RuntimeError(\"File {} exists but is not a stanza pretrain file.  It is not a dict, whereas a Stanza pretrain should have a dict with 'emb' and 'vocab'\".format(self.filename))\n                if 'emb' not in data or 'vocab' not in data:\n                    raise RuntimeError(\"File {} exists but is not a stanza pretrain file.  A Stanza pretrain file should have 'emb' and 'vocab' fields in its state dict\".format(self.filename))\n                self._vocab = PretrainedWordVocab.load_state_dict(data['vocab'])\n                self._emb = data['emb']\n                if isinstance(self._emb, np.ndarray):\n                    self._emb = torch.from_numpy(self._emb)\n                return\n            except (KeyboardInterrupt, SystemExit):\n                raise\n            except BaseException as e:\n                if not self._vec_filename and not self._csv_filename:\n                    raise\n                logger.warning(\"Pretrained file exists but cannot be loaded from {}, due to the following exception:\\n\\t{}\".format(self.filename, e))\n                vocab, emb = self.read_pretrain()\n        else:\n            if not self._vec_filename and not self._csv_filename:\n                raise FileNotFoundError(\"Pretrained file {} does not exist, and no text/xz file was provided\".format(self.filename))\n            if self.filename is not None:\n                logger.info(\"Pretrained filename %s specified, but file does not exist.  
Attempting to load from text file\" % self.filename)\n            vocab, emb = self.read_pretrain()\n\n        self._vocab = vocab\n        self._emb = emb\n\n        if self._save_to_file:\n            # save to file\n            assert self.filename is not None, \"Filename must be provided to save pretrained vector to file.\"\n            self.save(self.filename)\n\n    def save(self, filename):\n        directory, _ = os.path.split(filename)\n        if directory:\n            os.makedirs(directory, exist_ok=True)\n        # should not infinite loop since the load function sets _vocab and _emb before trying to save\n        data = {'vocab': self.vocab.state_dict(), 'emb': self.emb}\n        try:\n            torch.save(data, filename, _use_new_zipfile_serialization=False)\n            logger.info(\"Saved pretrained vocab and vectors to {}\".format(filename))\n        except (KeyboardInterrupt, SystemExit):\n            raise\n        except BaseException as e:\n            logger.warning(\"Saving pretrained data failed due to the following exception... 
continuing anyway.\\n\\t{}\".format(e))\n\n\n    def write_text(self, filename, header=False):\n        \"\"\"\n        Write the vocab & values to a text file\n        \"\"\"\n        with open(filename, \"w\") as fout:\n            if header:\n                word_dim = self.emb[0].shape[0]\n                fout.write(\"%d %d\\n\" % (len(self.vocab), word_dim))\n            for word_idx, word in enumerate(self.vocab):\n                row = self.emb[word_idx].to(\"cpu\")\n                fout.write(word)\n                fout.write(\" \")\n                fout.write(\" \".join([\"%.6f\" % x.item() for x in row]))\n                fout.write(\"\\n\")\n\n\n    def read_pretrain(self):\n        # load from pretrained filename\n        if self._vec_filename is not None:\n            words, emb, failed = self.read_from_file(self._vec_filename, self._max_vocab)\n        elif self._csv_filename is not None:\n            words, emb = self.read_from_csv(self._csv_filename)\n        else:\n            raise RuntimeError(\"Vector file is not provided.\")\n\n        if len(emb) - len(VOCAB_PREFIX) != len(words):\n            raise RuntimeError(\"Loaded number of vectors does not match number of words.\")\n\n        # Use a fixed vocab size\n        if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words) + len(VOCAB_PREFIX):\n            words = words[:self._max_vocab - len(VOCAB_PREFIX)]\n            emb = emb[:self._max_vocab]\n\n        vocab = PretrainedWordVocab(words)\n        \n        return vocab, emb\n\n    @staticmethod\n    def read_from_csv(filename):\n        \"\"\"\n        Read vectors from CSV\n\n        Skips the first row\n        \"\"\"\n        logger.info(\"Reading pretrained vectors from csv file %s ...\", filename)\n        with open_read_text(filename) as fin:\n            csv_reader = csv.reader(fin)\n            # the header of the thai csv vector file we have is just the number of columns\n            # so we read past the first 
line\n            for line in csv_reader:\n                break\n            lines = [line for line in csv_reader]\n\n        rows = len(lines)\n        cols = len(lines[0]) - 1\n\n        emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)\n        for i, line in enumerate(lines):\n            emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)\n        words = [line[0].replace(' ', '\\xa0') for line in lines]\n        return words, emb\n\n    @staticmethod\n    def read_from_file(filename, max_vocab=None):\n        \"\"\"\n        Open a vector file using the provided function and read from it.\n        \"\"\"\n        logger.info(\"Reading pretrained vectors from %s ...\", filename)\n\n        # some vector files, such as Google News, use tabs\n        tab_space_pattern = re.compile(r\"[ \\t]+\")\n        first = True\n        cols = None\n        lines = []\n        failed = 0\n        unk_line = None\n        with open_read_binary(filename) as f:\n            for i, line in enumerate(f):\n                try:\n                    line = line.decode()\n                except UnicodeDecodeError:\n                    failed += 1\n                    continue\n                line = line.rstrip()\n                if not line:\n                    continue\n                pieces = tab_space_pattern.split(line)\n                if first:\n                    # the first line contains the number of word vectors and the dimensionality\n                    # note that a 1d embedding with a number as the first entry\n                    # will fail to read properly.  
we ignore that case\n                    first = False\n                    if len(pieces) == 2:\n                        cols = int(pieces[1])\n                        continue\n\n                if pieces[0] == '<unk>':\n                    if unk_line is not None:\n                        logger.error(\"More than one <unk> line in the pretrain!  Keeping the most recent one\")\n                    else:\n                        logger.debug(\"Found an unk line while reading the pretrain\")\n                    unk_line = pieces\n                else:\n                    if not max_vocab or max_vocab < 0 or len(lines) < max_vocab:\n                        lines.append(pieces)\n\n        if cols is None:\n            # another failure case: all words have spaces in them\n            cols = min(len(x) for x in lines) - 1\n        rows = len(lines)\n        emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)\n        if unk_line is not None:\n            emb[UNK_ID] = torch.tensor([float(x) for x in unk_line[-cols:]], dtype=torch.float32)\n        for i, line in enumerate(lines):\n            emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)\n\n        # if there were word pieces separated with spaces, rejoin them with nbsp instead\n        # this way, the normalize_unit method in vocab.py can find the word at test time\n        words = ['\\xa0'.join(line[:-cols]) for line in lines]\n        if failed > 0:\n            logger.info(\"Failed to read %d lines from embedding\", failed)\n        return words, emb, failed\n\n\ndef find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang):\n    \"\"\"\n    When training a model, look in a few different places for a .pt file\n\n    If a specific argument was passed in, prefer that location\n    Otherwise, check in a few places:\n      saved_models/{model}/{shorthand}.pretrain.pt\n      saved_models/{model}/{shorthand}_pretrain.pt\n      
~/stanza_resources/{language}/pretrain/{shorthand}_pretrain.pt\n    \"\"\"\n    if wordvec_pretrain_file:\n        return wordvec_pretrain_file\n\n    default_pretrain_file = os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand))\n    if os.path.exists(default_pretrain_file):\n        logger.debug(\"Found existing .pt file in %s\" % default_pretrain_file)\n        return default_pretrain_file\n    else:\n        logger.debug(\"Cannot find pretrained vectors in %s\" % default_pretrain_file)\n\n    pretrain_file = os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand))\n    if os.path.exists(pretrain_file):\n        logger.debug(\"Found existing .pt file in %s\" % pretrain_file)\n        return pretrain_file\n    else:\n        logger.debug(\"Cannot find pretrained vectors in %s\" % pretrain_file)\n\n    if shorthand.find(\"_\") >= 0:\n        # try to assemble /home/user/stanza_resources/vi/pretrain/vtb.pt for example\n        pretrain_file = os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(shorthand.split('_', 1)[1]))\n        if os.path.exists(pretrain_file):\n            logger.debug(\"Found existing .pt file in %s\" % pretrain_file)\n            return pretrain_file\n        else:\n            logger.debug(\"Cannot find pretrained vectors in %s\" % pretrain_file)\n\n    # if we can't find it anywhere, just return the first location searched...\n    # maybe we'll get lucky and the original .txt file can be found\n    return default_pretrain_file\n\n\nif __name__ == '__main__':\n    with open('test.txt', 'w') as fout:\n        fout.write('3 2\\na 1 1\\nb -1 -1\\nc 0 0\\n')\n    # 1st load: save to pt file\n    pretrain = Pretrain('test.pt', 'test.txt')\n    print(pretrain.emb)\n    # verify pt file\n    x = torch.load('test.pt', weights_only=True)\n    print(x)\n    # 2nd load: load saved pt file\n    pretrain = Pretrain('test.pt', 'test.txt')\n    print(pretrain.emb)\n\n"
  },
  {
    "path": "stanza/models/common/relative_attn.py",
    "content": "import logging\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nlogger = logging.getLogger('stanza')\n\nclass RelativeAttention(nn.Module):\n    def __init__(self, d_model, num_heads, window=8, dropout=0.2, reverse=False, d_output=None, fudge_output=False, num_sinks=0):\n        super().__init__()\n        if d_output is None:\n            d_output = d_model\n\n        d_head, remainder = divmod(d_output, num_heads)\n        if remainder:\n            if fudge_output:\n                d_head = d_head + 1\n                logger.debug(\"Relative attn: %d %% %d != 0, updating d_output to %d\", d_output, num_heads, num_heads * d_head)\n                d_output = num_heads * d_head\n            else:\n                raise ValueError(\"incompatible `d_model` and `num_heads`\")\n        self.window = window\n        self.num_sinks = num_sinks\n        self.d_model = d_model\n        self.d_head = d_head\n        self.num_heads = num_heads\n        self.d_output = d_output\n        self.key = nn.Linear(d_model, d_output)\n        # the bias for query all gets trained to 0 anyway\n        self.query = nn.Linear(d_model, d_output, bias=False)\n        self.value = nn.Linear(d_model, d_output, bias=False)\n        # initializing value with eye seems to hurt!\n        #nn.init.eye_(self.value.weight)\n\n        self.dropout = nn.Dropout(dropout)\n        self.position = nn.Parameter(torch.randn(1, 1, d_head, window + num_sinks, 1))\n\n        self.register_buffer(\n            \"mask\", \n            torch.tril(torch.ones(window, window), diagonal=-1).unsqueeze(0).unsqueeze(0).unsqueeze(0)\n        )\n        self.register_buffer(\n            \"flipped_mask\",\n            torch.flip(self.mask, (-1,))\n        )\n\n        self.reverse = reverse\n\n    def forward(self, x, sink=None):\n        # x.shape == (batch_size, seq_len, d_model)\n        batch_size, seq_len, d_model = x.shape\n        if d_model != self.d_model:\n            
raise ValueError(\"Incompatible input\")\n\n        if self.reverse:\n            x = torch.flip(x, (1,))\n\n        orig_seq_len = seq_len\n        if seq_len < self.window:\n            zeros = torch.zeros((x.shape[0], self.window - seq_len, x.shape[2]), dtype=x.dtype, device=x.device)\n            x = torch.cat((x, zeros), axis=1)\n            seq_len = self.window\n\n        if self.num_sinks > 0:\n            # could keep a parameter to train sinks, but as it turns out,\n            # the position vectors just overlap that parameter space anyway\n            # generally the model trains the sinks to zero if we do that\n            if sink is None:\n                sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)\n            else:\n                sink = sink.expand(batch_size, self.num_sinks, d_model)\n            x = torch.cat((sink, x), axis=1)\n\n        # k.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)\n        k = self.key(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)[:, :, :, self.num_sinks:]\n\n        # v.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)\n        v = self.value(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)\n\n        # q.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)\n        q = self.query(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)\n        # q.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)\n        q = self.skew_repeat(q)\n        q = q + self.position\n\n        # qk.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)\n        qk = torch.einsum('bndws,bnds->bndws', q, k)\n\n        # TODO: fix mask\n        # mask out the padding spaces at the end\n        # can only attend to spots that aren't padded\n        if orig_seq_len < seq_len:\n            # mask out the part of the sentence 
which is empty\n            shorter_mask = self.flipped_mask[:, :, :, :orig_seq_len, -orig_seq_len:]\n            qk = qk[:, :, :, :(orig_seq_len + self.num_sinks), :orig_seq_len]\n            qk[:, :, :, -orig_seq_len:, :] = qk[:, :, :, -orig_seq_len:, :].masked_fill(shorter_mask == 1, float(\"-inf\"))\n        else:\n            qk[:, :, :, -self.window:, -self.window:] = qk[:, :, :, -self.window:, -self.window:].masked_fill(self.flipped_mask == 1, float(\"-inf\"))\n        qk = F.softmax(qk, dim=3)\n\n        # v.shape = (batch_size, num_heads, d_head, window, seq_len)\n        v = self.skew_repeat(v)\n        if orig_seq_len < seq_len:\n            v = v[:, :, :, :(orig_seq_len + self.num_sinks), :orig_seq_len]\n        # result.shape = (batch_size, num_heads, d_head, orig_seq_len)\n        result = torch.einsum('bndws,bndws->bnds', qk, v)\n        # batch_size, orig_seq_len, d_output\n        result = result.reshape(batch_size, self.d_output, orig_seq_len).transpose(1, 2)\n\n        if self.reverse:\n            result = torch.flip(result, (1,))\n\n        return self.dropout(result)\n\n    def skew_repeat(self, q):\n        \"\"\"\n        q (currently, at least) is num_sinks + seq_len long\n        and the num_sinks are there to be chopped off the front\n        then the seq_len remainder is skewed\n        \"\"\"\n        if self.num_sinks > 0:\n            q_sink = q[:, :, :, :self.num_sinks]\n            q_sink = q_sink.unsqueeze(4)\n            q_sink = q_sink.repeat(1, 1, 1, 1, q.shape[-1] - self.num_sinks)\n            q = q[:, :, :, self.num_sinks:]\n        # make stripes that look like this\n        # (seq_len 5, window 3)\n        #   1 2 3 4 5\n        #   1 2 3 4 5\n        #   1 2 3 4 5\n        q = q.unsqueeze(4).repeat(1, 1, 1, 1, self.window).transpose(3, 4)\n        # now the stripes look like\n        #   1 2 3 4 5\n        #   0 2 3 4 5\n        #   0 0 3 4 5\n        q[:, :, :, :, :self.window] = q[:, :, :, :, 
:self.window].masked_fill(self.mask == 1, 0)\n        q_shape = list(q.shape)\n        q_new_shape = list(q.shape)[:-2] + [-1]\n        q = q.reshape(q_new_shape)\n        zeros = torch.zeros_like(q[:, :, :, :1])\n        zeros = zeros.repeat(1, 1, 1, self.window)\n        q = torch.cat((q, zeros), axis=-1)\n        q_new_shape = q_new_shape[:-1] + [self.window, -1]\n        # now the stripes look like\n        #   1 2 3 4 5\n        #   2 3 4 5 0\n        #   3 4 5 0 0\n        # q.shape = (batch_size, num_heads, d_head, window, seq_len)\n        q = q.reshape(q_new_shape)[:, :, :, :, :-1]\n        if self.num_sinks > 0:\n            q = torch.cat([q_sink, q], dim=3)\n        return q\n"
  },
  {
    "path": "stanza/models/common/seq2seq_constant.py",
    "content": "\"\"\"\nConstants for seq2seq models.\n\"\"\"\n\nPAD = '<PAD>'\nPAD_ID = 0\nUNK = '<UNK>'\nUNK_ID = 1\nSOS = '<SOS>'\nSOS_ID = 2\nEOS = '<EOS>'\nEOS_ID = 3\n\nVOCAB_PREFIX = [PAD, UNK, SOS, EOS]\n\nEMB_INIT_RANGE = 1.0\nINFINITY_NUMBER = 1e12\n"
  },
  {
    "path": "stanza/models/common/seq2seq_model.py",
    "content": "\"\"\"\nThe full encoder-decoder model, built on top of the base seq2seq modules.\n\"\"\"\n\nimport logging\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common import utils\nfrom stanza.models.common.seq2seq_modules import LSTMAttention\nfrom stanza.models.common.beam import Beam\nfrom stanza.models.common.seq2seq_constant import UNK_ID\n\nlogger = logging.getLogger('stanza')\n\nclass Seq2SeqModel(nn.Module):\n    \"\"\"\n    A complete encoder-decoder model, with optional attention.\n\n    A parent class which makes use of the contextual_embedding (such as a charlm)\n    can make use of unsaved_modules when saving.\n    \"\"\"\n    def __init__(self, args, emb_matrix=None, contextual_embedding=None):\n        super().__init__()\n\n        self.unsaved_modules = []\n\n        self.vocab_size = args['vocab_size']\n        self.emb_dim = args['emb_dim']\n        self.hidden_dim = args['hidden_dim']\n        self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1\n        self.emb_dropout = args.get('emb_dropout', 0.0)\n        self.dropout = args['dropout']\n        self.pad_token = constant.PAD_ID\n        self.max_dec_len = args['max_dec_len']\n        self.top = args.get('top', 1e10)\n        self.args = args\n        self.emb_matrix = emb_matrix\n        self.add_unsaved_module(\"contextual_embedding\", contextual_embedding)\n\n        logger.debug(\"Building an attentional Seq2Seq model...\")\n        logger.debug(\"Using a Bi-LSTM encoder\")\n        self.num_directions = 2\n        self.enc_hidden_dim = self.hidden_dim // 2\n        self.dec_hidden_dim = self.hidden_dim\n\n        self.use_pos = args.get('pos', False)\n        self.pos_dim = args.get('pos_dim', 0)\n        self.pos_vocab_size = args.get('pos_vocab_size', 0)\n        self.pos_dropout = args.get('pos_dropout', 0)\n        self.edit = 
args.get('edit', False)\n        self.num_edit = args.get('num_edit', 0)\n        self.copy = args.get('copy', False)\n\n        self.emb_drop = nn.Dropout(self.emb_dropout)\n        self.drop = nn.Dropout(self.dropout)\n        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)\n        self.input_dim = self.emb_dim\n        if self.contextual_embedding is not None:\n            self.input_dim += self.contextual_embedding.hidden_dim()\n        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \\\n                bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)\n        self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \\\n                batch_first=True, attn_type=self.args['attn_type'])\n        self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)\n        if self.use_pos and self.pos_dim > 0:\n            logger.debug(\"Using POS in encoder\")\n            self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token)\n            self.pos_drop = nn.Dropout(self.pos_dropout)\n        if self.edit:\n            edit_hidden = self.hidden_dim//2\n            self.edit_clf = nn.Sequential(\n                    nn.Linear(self.hidden_dim, edit_hidden),\n                    nn.ReLU(),\n                    nn.Linear(edit_hidden, self.num_edit))\n\n        if self.copy:\n            self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)\n\n        SOS_tensor = torch.LongTensor([constant.SOS_ID])\n        self.register_buffer('SOS_tensor', SOS_tensor)\n\n        self.init_weights()\n\n    def add_unsaved_module(self, name, module):\n        self.unsaved_modules += [name]\n        setattr(self, name, module)\n\n    def init_weights(self):\n        # initialize embeddings\n        init_range = constant.EMB_INIT_RANGE\n        if self.emb_matrix is not None:\n            if isinstance(self.emb_matrix, np.ndarray):\n                
self.emb_matrix = torch.from_numpy(self.emb_matrix)\n            assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \\\n                    \"Input embedding matrix must match size: {} x {}\".format(self.vocab_size, self.emb_dim)\n            self.embedding.weight.data.copy_(self.emb_matrix)\n        else:\n            self.embedding.weight.data.uniform_(-init_range, init_range)\n        # decide finetuning\n        if self.top <= 0:\n            logger.debug(\"Do not finetune embedding layer.\")\n            self.embedding.weight.requires_grad = False\n        elif self.top < self.vocab_size:\n            logger.debug(\"Finetune top {} embeddings.\".format(self.top))\n            self.embedding.weight.register_hook(lambda x: utils.keep_partial_grad(x, self.top))\n        else:\n            logger.debug(\"Finetune all embeddings.\")\n        # initialize pos embeddings\n        if self.use_pos:\n            self.pos_embedding.weight.data.uniform_(-init_range, init_range)\n\n    def zero_state(self, inputs):\n        batch_size = inputs.size(0)\n        device = self.SOS_tensor.device\n        h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)\n        c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)\n        return h0, c0\n\n    def encode(self, enc_inputs, lens):\n        \"\"\" Encode source sequence. 
\"\"\"\n        h0, c0 = self.zero_state(enc_inputs)\n\n        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)\n        packed_h_in, (hn, cn) = self.encoder(packed_inputs, (h0, c0))\n        h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_h_in, batch_first=True)\n        hn = torch.cat((hn[-1], hn[-2]), 1)\n        cn = torch.cat((cn[-1], cn[-2]), 1)\n        return h_in, (hn, cn)\n\n    def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):\n        \"\"\" Decode a step, based on context encoding and source context states.\"\"\"\n        dec_hidden = (hn, cn)\n        decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)\n        if self.copy:\n            h_out, dec_hidden, log_attn = decoder_output\n        else:\n            h_out, dec_hidden = decoder_output\n\n        h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)\n        decoder_logits = self.dec2vocab(h_out_reshape)\n        decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)\n        log_probs = self.get_log_prob(decoder_logits)\n\n        if self.copy:\n            copy_logit = self.copy_gate(h_out)\n            if self.use_pos:\n                # can't copy the UPOS\n                log_attn = log_attn[:, :, 1:]\n\n            # renormalize\n            log_attn = torch.log_softmax(log_attn, -1)\n            # calculate copy probability for each word in the vocab\n            log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn\n            # scatter logsumexp\n            mx = log_copy_prob.max(-1, keepdim=True)[0]\n            log_copy_prob = log_copy_prob - mx\n            # here we make space in the log probs for vocab items\n            # which might be copied from the encoder side, but which\n            # were not known at training time\n            # note that such an item cannot possibly be predicted by\n           
 # the model as a raw output token\n            # however, the copy gate might score high on copying a\n            # previously unknown vocab item\n            copy_prob = torch.exp(log_copy_prob)\n            copied_vocab_shape = list(log_probs.size())\n            if torch.max(src) >= copied_vocab_shape[-1]:\n                copied_vocab_shape[-1] = torch.max(src) + 1\n            copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)\n            scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))\n            # fill in the copy tensor with the copy probs of each character\n            # the rest of the copy tensor will be filled with -largenumber\n            copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)\n            zero_mask = (copied_vocab_prob == 0)\n            log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx\n            log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)\n\n            # combine with normal vocab probability\n            log_nocopy_prob = -torch.log(1 + torch.exp(copy_logit))\n            if log_probs.shape[-1] < copied_vocab_shape[-1]:\n                # for previously unknown vocab items which are in the encoder,\n                # we reuse the UNK_ID prediction\n                # this gives a baseline number which we can combine with\n                # the copy gate prediction\n                # technically this makes log_probs no longer represent\n                # a probability distribution when looking at unknown vocab\n                # this is probably not a serious problem\n                # an example of this usage is in the Lemmatizer, such as a\n                # plural word in English with the character \"ã\" in it instead of \"a\"\n                # if \"ã\" is not known in the training data, the lemmatizer would\n                # ordinarily be unable to output it, and thus the seq2seq 
model\n                # would have no chance to depluralize \"ãntennae\" -> \"ãntenna\"\n                # however, if we temporarily add \"ã\" to the encoder vocab,\n                # then let the copy gate accept that letter, we find the Lemmatizer\n                # seq2seq model will want to copy that particular vocab item\n                # this allows the Lemmatizer to produce \"ã\" instead of requiring\n                # that it produces UNK, then going back to the input text to\n                # figure out which UNK it intended to produce\n                new_log_probs = log_probs.new_zeros(copied_vocab_shape)\n                new_log_probs[:, :, :log_probs.shape[-1]] = log_probs\n                new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)\n                log_probs = new_log_probs\n            log_probs = log_probs + log_nocopy_prob\n            log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)\n\n        if never_decode_unk:\n            log_probs[:, :, UNK_ID] = float(\"-inf\")\n        return log_probs, dec_hidden\n\n    def embed(self, src, src_mask, pos, raw):\n        embed_src = src.clone()\n        embed_src[embed_src >= self.vocab_size] = UNK_ID\n        enc_inputs = self.emb_drop(self.embedding(embed_src))\n        batch_size = enc_inputs.size(0)\n        if self.use_pos:\n            assert pos is not None, \"Missing POS input for seq2seq lemmatizer.\"\n            pos_inputs = self.pos_drop(self.pos_embedding(pos))\n            enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)\n            pos_src_mask = src_mask.new_zeros([batch_size, 1])\n            src_mask = torch.cat([pos_src_mask, src_mask], dim=1)\n        if raw is not None and self.contextual_embedding is not None:\n            raw_inputs = self.contextual_embedding(raw)\n            if self.use_pos:\n                raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, 
raw_inputs.shape[2]))\n                raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1)\n            enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)\n        src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))\n        return enc_inputs, batch_size, src_lens, src_mask\n\n    def forward(self, src, src_mask, tgt_in, pos=None, raw=None):\n        # prepare for encoder/decoder\n        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)\n\n        # encode source\n        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)\n\n        if self.edit:\n            edit_logits = self.edit_clf(hn)\n        else:\n            edit_logits = None\n\n        dec_inputs = self.emb_drop(self.embedding(tgt_in))\n\n        log_probs, _ = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src)\n        return log_probs, edit_logits\n\n    def get_log_prob(self, logits):\n        logits_reshape = logits.view(-1, self.vocab_size)\n        log_probs = F.log_softmax(logits_reshape, dim=1)\n        if logits.dim() == 2:\n            return log_probs\n        return log_probs.view(logits.size(0), logits.size(1), logits.size(2))\n\n    def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):\n        \"\"\" Predict with greedy decoding. 
\"\"\"\n        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)\n\n        # encode source\n        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)\n\n        if self.edit:\n            edit_logits = self.edit_clf(hn)\n        else:\n            edit_logits = None\n\n        # greedy decode by step\n        dec_inputs = self.embedding(self.SOS_tensor)\n        dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))\n\n        done = [False for _ in range(batch_size)]\n        total_done = 0\n        max_len = 0\n        output_seqs = [[] for _ in range(batch_size)]\n\n        while total_done < batch_size and max_len < self.max_dec_len:\n            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)\n            assert log_probs.size(1) == 1, \"Output must have 1-step of output.\"\n            _, preds = log_probs.squeeze(1).max(1, keepdim=True)\n            # if a unlearned character is predicted via the copy mechanism,\n            # use the UNK embedding for it\n            dec_inputs = preds.clone()\n            dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID\n            dec_inputs = self.embedding(dec_inputs) # update decoder inputs\n            max_len += 1\n            for i in range(batch_size):\n                if not done[i]:\n                    token = preds.data[i][0].item()\n                    if token == constant.EOS_ID:\n                        done[i] = True\n                        total_done += 1\n                    else:\n                        output_seqs[i].append(token)\n        return output_seqs, edit_logits\n\n    def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):\n        \"\"\" Predict with beam search. 
\"\"\"\n        if beam_size == 1:\n            return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)\n\n        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)\n\n        # (1) encode source\n        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)\n\n        if self.edit:\n            edit_logits = self.edit_clf(hn)\n        else:\n            edit_logits = None\n\n        # (2) set up beam\n        with torch.no_grad():\n            h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search\n            src_mask = src_mask.repeat(beam_size, 1)\n            # repeat decoder hidden states\n            hn = hn.data.repeat(beam_size, 1)\n            cn = cn.data.repeat(beam_size, 1)\n        device = self.SOS_tensor.device\n        beam = [Beam(beam_size, device) for _ in range(batch_size)]\n\n        def update_state(states, idx, positions, beam_size):\n            \"\"\" Select the states according to back pointers. 
\"\"\"\n            for e in states:\n                br, d = e.size()\n                s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]\n                s.data.copy_(s.data.index_select(0, positions))\n\n        # (3) main loop\n        for i in range(self.max_dec_len):\n            dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)\n            # if a unlearned character is predicted via the copy mechanism,\n            # use the UNK embedding for it\n            dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID\n            dec_inputs = self.embedding(dec_inputs)\n            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)\n            log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]\n\n            # advance each beam\n            done = []\n            for b in range(batch_size):\n                is_done = beam[b].advance(log_probs.data[b])\n                if is_done:\n                    done += [b]\n                # update beam state\n                update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)\n\n            if len(done) == batch_size:\n                break\n\n        # back trace and find hypothesis\n        all_hyp, all_scores = [], []\n        for b in range(batch_size):\n            scores, ks = beam[b].sort_best()\n            all_scores += [scores[0]]\n            k = ks[0]\n            hyp = beam[b].get_hyp(k)\n            hyp = utils.prune_hyp(hyp)\n            hyp = [i.item() for i in hyp]\n            all_hyp += [hyp]\n\n        return all_hyp, edit_logits\n\n"
  },
  {
    "path": "stanza/models/common/seq2seq_modules.py",
    "content": "\"\"\"\nPytorch implementation of basic sequence to Sequence modules.\n\"\"\"\n\nimport logging\nimport torch\nimport torch.nn as nn\nimport math\nimport numpy as np\n\nimport stanza.models.common.seq2seq_constant as constant\n\nlogger = logging.getLogger('stanza')\n\nclass BasicAttention(nn.Module):\n    \"\"\"\n    A basic MLP attention layer.\n    \"\"\"\n    def __init__(self, dim):\n        super(BasicAttention, self).__init__()\n        self.linear_in = nn.Linear(dim, dim, bias=False)\n        self.linear_c = nn.Linear(dim, dim)\n        self.linear_v = nn.Linear(dim, 1, bias=False)\n        self.linear_out = nn.Linear(dim * 2, dim, bias=False)\n        self.tanh = nn.Tanh()\n        self.sm = nn.Softmax(dim=1)\n\n    def forward(self, input, context, mask=None, attn_only=False):\n        \"\"\"\n        input: batch x dim\n        context: batch x sourceL x dim\n        \"\"\"\n        batch_size = context.size(0)\n        source_len = context.size(1)\n        dim = context.size(2)\n        target = self.linear_in(input) # batch x dim\n        source = self.linear_c(context.contiguous().view(-1, dim)).view(batch_size, source_len, dim)\n        attn = target.unsqueeze(1).expand_as(context) + source\n        attn = self.tanh(attn) # batch x sourceL x dim\n        attn = self.linear_v(attn.view(-1, dim)).view(batch_size, source_len)\n\n        if mask is not None:\n            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)\n\n        attn = self.sm(attn)\n        if attn_only:\n            return attn\n\n        weighted_context = torch.bmm(attn.unsqueeze(1), context).squeeze(1)\n        h_tilde = torch.cat((weighted_context, input), 1)\n        h_tilde = self.tanh(self.linear_out(h_tilde))\n\n        return h_tilde, attn\n\nclass SoftDotAttention(nn.Module):\n    \"\"\"Soft Dot Attention.\n\n    Ref: http://www.aclweb.org/anthology/D15-1166\n    Adapted from PyTorch OPEN NMT.\n    \"\"\"\n\n    def __init__(self, dim):\n        
\"\"\"Initialize layer.\"\"\"\n        super(SoftDotAttention, self).__init__()\n        self.linear_in = nn.Linear(dim, dim, bias=False)\n        self.sm = nn.Softmax(dim=1)\n        self.linear_out = nn.Linear(dim * 2, dim, bias=False)\n        self.tanh = nn.Tanh()\n        self.mask = None\n\n    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):\n        \"\"\"Propagate input through the network.\n\n        input: batch x dim\n        context: batch x sourceL x dim\n        \"\"\"\n        target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1\n\n        # Get attention\n        attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL\n\n        if mask is not None:\n            # sett the padding attention logits to -inf\n            assert mask.size() == attn.size(), \"Mask size must match the attention size!\"\n            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)\n\n        if return_logattn:\n            attn = torch.log_softmax(attn, 1)\n            attn_w = torch.exp(attn)\n        else:\n            attn = self.sm(attn)\n            attn_w = attn\n        if attn_only:\n            return attn\n\n        attn3 = attn_w.view(attn_w.size(0), 1, attn_w.size(1))  # batch x 1 x sourceL\n\n        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim\n        h_tilde = torch.cat((weighted_context, input), 1)\n\n        h_tilde = self.tanh(self.linear_out(h_tilde))\n\n        return h_tilde, attn\n\n\nclass LinearAttention(nn.Module):\n    \"\"\" A linear attention form, inspired by BiDAF:\n        a = W (u; v; u o v)\n    \"\"\"\n\n    def __init__(self, dim):\n        super(LinearAttention, self).__init__()\n        self.linear = nn.Linear(dim*3, 1, bias=False)\n        self.linear_out = nn.Linear(dim * 2, dim, bias=False)\n        self.sm = nn.Softmax(dim=1)\n        self.tanh = nn.Tanh()\n        self.mask = None\n\n    def forward(self, input, context, mask=None, 
attn_only=False):\n        \"\"\"\n        input: batch x dim\n        context: batch x sourceL x dim\n        \"\"\"\n        batch_size = context.size(0)\n        source_len = context.size(1)\n        dim = context.size(2)\n        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim)  # batch*sourceL x dim\n        v = context.contiguous().view(-1, dim)\n        attn_in = torch.cat((u, v, u.mul(v)), 1)\n        attn = self.linear(attn_in).view(batch_size, source_len)\n\n        if mask is not None:\n            # set the padding attention logits to -inf\n            assert mask.size() == attn.size(), \"Mask size must match the attention size!\"\n            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)\n\n        attn = self.sm(attn)\n        if attn_only:\n            return attn\n\n        attn3 = attn.view(batch_size, 1, source_len)  # batch x 1 x sourceL\n\n        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim\n        h_tilde = torch.cat((weighted_context, input), 1)\n        h_tilde = self.tanh(self.linear_out(h_tilde))\n        return h_tilde, attn\n\nclass DeepAttention(nn.Module):\n    \"\"\" A deep attention form, invented by Robert:\n        u = ReLU(Wx)\n        v = ReLU(Wy)\n        a = V.(u o v)\n    \"\"\"\n\n    def __init__(self, dim):\n        super(DeepAttention, self).__init__()\n        self.linear_in = nn.Linear(dim, dim, bias=False)\n        self.linear_v = nn.Linear(dim, 1, bias=False)\n        self.linear_out = nn.Linear(dim * 2, dim, bias=False)\n        self.relu = nn.ReLU()\n        self.sm = nn.Softmax(dim=1)\n        self.tanh = nn.Tanh()\n        self.mask = None\n\n    def forward(self, input, context, mask=None, attn_only=False):\n        \"\"\"\n        input: batch x dim\n        context: batch x sourceL x dim\n        \"\"\"\n        batch_size = context.size(0)\n        source_len = context.size(1)\n        dim = context.size(2)\n        u = 
input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim)  # batch*sourceL x dim\n        u = self.relu(self.linear_in(u))\n        v = self.relu(self.linear_in(context.contiguous().view(-1, dim)))\n        attn = self.linear_v(u.mul(v)).view(batch_size, source_len)\n\n        if mask is not None:\n            # set the padding attention logits to -inf\n            assert mask.size() == attn.size(), \"Mask size must match the attention size!\"\n            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)\n\n        attn = self.sm(attn)\n        if attn_only:\n            return attn\n\n        attn3 = attn.view(batch_size, 1, source_len)  # batch x 1 x sourceL\n\n        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim\n        h_tilde = torch.cat((weighted_context, input), 1)\n        h_tilde = self.tanh(self.linear_out(h_tilde))\n        return h_tilde, attn\n\nclass LSTMAttention(nn.Module):\n    r\"\"\"A long short-term memory (LSTM) cell with attention.\"\"\"\n\n    def __init__(self, input_size, hidden_size, batch_first=True, attn_type='soft'):\n        \"\"\"Initialize params.\"\"\"\n        super(LSTMAttention, self).__init__()\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.batch_first = batch_first\n\n        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)\n\n        if attn_type == 'soft':\n            self.attention_layer = SoftDotAttention(hidden_size)\n        elif attn_type == 'mlp':\n            self.attention_layer = BasicAttention(hidden_size)\n        elif attn_type == 'linear':\n            self.attention_layer = LinearAttention(hidden_size)\n        elif attn_type == 'deep':\n            self.attention_layer = DeepAttention(hidden_size)\n        else:\n            raise Exception(\"Unsupported LSTM attention type: {}\".format(attn_type))\n        logger.debug(\"Using {} attention for LSTM.\".format(attn_type))\n\n    def forward(self, input, hidden, ctx, 
ctx_mask=None, return_logattn=False):\n        \"\"\"Propagate input through the network.\"\"\"\n        if self.batch_first:\n            input = input.transpose(0,1)\n\n        output = []\n        attn = []\n        steps = range(input.size(0))\n        for i in steps:\n            hidden = self.lstm_cell(input[i], hidden)\n            hy, cy = hidden\n            h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask, return_logattn=return_logattn)\n            output.append(h_tilde)\n            attn.append(alpha)\n        output = torch.cat(output, 0).view(input.size(0), *output[0].size())\n\n        if self.batch_first:\n            output = output.transpose(0,1)\n\n        if return_logattn:\n            attn = torch.stack(attn, 0)\n            if self.batch_first:\n                attn = attn.transpose(0, 1)\n            return output, hidden, attn\n\n        return output, hidden\n\n"
  },
  {
    "path": "stanza/models/common/seq2seq_utils.py",
    "content": "\"\"\"\nUtils for seq2seq models.\n\"\"\"\nfrom collections import Counter\nimport random\nimport json\nimport torch\n\nimport stanza.models.common.seq2seq_constant as constant\n\n# torch utils\ndef get_optimizer(name, parameters, lr):\n    if name == 'sgd':\n        return torch.optim.SGD(parameters, lr=lr)\n    elif name == 'adagrad':\n        return torch.optim.Adagrad(parameters, lr=lr)\n    elif name == 'adam':\n        return torch.optim.Adam(parameters) # use default lr\n    elif name == 'adamax':\n        return torch.optim.Adamax(parameters) # use default lr\n    else:\n        raise Exception(\"Unsupported optimizer: {}\".format(name))\n\ndef change_lr(optimizer, new_lr):\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = new_lr\n\ndef flatten_indices(seq_lens, width):\n    flat = []\n    for i, l in enumerate(seq_lens):\n        for j in range(l):\n            flat.append(i * width + j)\n    return flat\n\ndef keep_partial_grad(grad, topk):\n    \"\"\"\n    Keep only the topk rows of grads.\n    \"\"\"\n    assert topk < grad.size(0)\n    grad.data[topk:].zero_()\n    return grad\n\n# other utils\ndef save_config(config, path, verbose=True):\n    with open(path, 'w') as outfile:\n        json.dump(config, outfile, indent=2)\n    if verbose:\n        print(\"Config saved to file {}\".format(path))\n    return config\n\ndef load_config(path, verbose=True):\n    with open(path) as f:\n        config = json.load(f)\n    if verbose:\n        print(\"Config loaded from file {}\".format(path))\n    return config\n\ndef unmap_with_copy(indices, src_tokens, vocab):\n    \"\"\"\n    Unmap a list of list of indices, by optionally copying from src_tokens.\n    \"\"\"\n    result = []\n    for ind, tokens in zip(indices, src_tokens):\n        words = []\n        for idx in ind:\n            if idx >= 0:\n                words.append(vocab.id2word[idx])\n            else:\n                idx = -idx - 1 # flip and minus 1\n    
            words.append(tokens[idx])\n        result += [words]\n    return result\n\ndef prune_decoded_seqs(seqs):\n    \"\"\"\n    Prune decoded sequences after EOS token.\n    \"\"\"\n    out = []\n    for s in seqs:\n        if constant.EOS in s:\n            idx = s.index(constant.EOS)\n            out += [s[:idx]]\n        else:\n            out += [s]\n    return out\n\ndef prune_hyp(hyp):\n    \"\"\"\n    Prune a decoded hypothesis\n    \"\"\"\n    if constant.EOS_ID in hyp:\n        idx = hyp.index(constant.EOS_ID)\n        return hyp[:idx]\n    else:\n        return hyp\n\ndef prune(data_list, lens):\n    assert len(data_list) == len(lens)\n    nl = []\n    for d, l in zip(data_list, lens):\n        nl.append(d[:l])\n    return nl\n\ndef sort(packed, ref, reverse=True):\n    \"\"\"\n    Sort a series of packed lists, according to a ref list.\n    Also return the original index before the sort.\n    \"\"\"\n    assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list)\n    packed = [ref] + [range(len(ref))] + list(packed)\n    sorted_packed = [list(t) for t in zip(*sorted(zip(*packed), reverse=reverse))]\n    return tuple(sorted_packed[1:])\n\ndef unsort(sorted_list, oidx):\n    \"\"\"\n    Unsort a sorted list, based on the original idx.\n    \"\"\"\n    assert len(sorted_list) == len(oidx), \"Number of list elements must match with original indices.\"\n    _, unsorted = [list(t) for t in zip(*sorted(zip(oidx, sorted_list)))]\n    return unsorted\n\n"
  },
  {
    "path": "stanza/models/common/short_name_to_treebank.py",
    "content": "# This module is autogenerated by build_short_name_to_treebank.py\n# Please do not edit\n\nSHORT_NAMES = {\n    'abq_atb':                   'UD_Abaza-ATB',\n    'ab_abnc':                   'UD_Abkhaz-AbNC',\n    'af_afribooms':              'UD_Afrikaans-AfriBooms',\n    'akk_pisandub':              'UD_Akkadian-PISANDUB',\n    'akk_riao':                  'UD_Akkadian-RIAO',\n    'aqz_tudet':                 'UD_Akuntsu-TuDeT',\n    'sq_staf':                   'UD_Albanian-STAF',\n    'sq_tsa':                    'UD_Albanian-TSA',\n    'gsw_divital':               'UD_Alemannic-DIVITAL',\n    'gsw_uzh':                   'UD_Alemannic-UZH',\n    'am_att':                    'UD_Amharic-ATT',\n    'grc_proiel':                'UD_Ancient_Greek-PROIEL',\n    'grc_ptnk':                  'UD_Ancient_Greek-PTNK',\n    'grc_perseus':               'UD_Ancient_Greek-Perseus',\n    'hbo_ptnk':                  'UD_Ancient_Hebrew-PTNK',\n    'apu_ufpa':                  'UD_Apurina-UFPA',\n    'ar_nyuad':                  'UD_Arabic-NYUAD',\n    'ar_padt':                   'UD_Arabic-PADT',\n    'ar_pud':                    'UD_Arabic-PUD',\n    'hy_armtdp':                 'UD_Armenian-ArmTDP',\n    'hy_bsut':                   'UD_Armenian-BSUT',\n    'aii_as':                    'UD_Assyrian-AS',\n    'az_tuecl':                  'UD_Azerbaijani-TueCL',\n    'bm_crb':                    'UD_Bambara-CRB',\n    'eu_bdt':                    'UD_Basque-BDT',\n    'bar_maibaam':               'UD_Bavarian-MaiBaam',\n    'bej_autogramm':             'UD_Beja-Autogramm',\n    'be_hse':                    'UD_Belarusian-HSE',\n    'bn_bru':                    'UD_Bengali-BRU',\n    'bho_bhtb':                  'UD_Bhojpuri-BHTB',\n    'sab_chibergis':             'UD_Bokota-ChibErgIS',\n    'bor_bdt':                   'UD_Bororo-BDT',\n    'br_keb':                    'UD_Breton-KEB',\n    'bg_btb':                    'UD_Bulgarian-BTB',\n    'bxr_bdt':   
                'UD_Buryat-BDT',\n    'yue_hk':                    'UD_Cantonese-HK',\n    'cpg_amgic':                 'UD_Cappadocian-AMGiC',\n    'cpg_tuecl':                 'UD_Cappadocian-TueCL',\n    'ca_ancora':                 'UD_Catalan-AnCora',\n    'ceb_gja':                   'UD_Cebuano-GJA',\n    'ckb_mukri':                 'UD_Central_Kurdish-Mukri',\n    'zh-hans_beginner':          'UD_Chinese-Beginner',\n    'zh_beginner':               'UD_Chinese-Beginner',\n    'zh-hans_cfl':               'UD_Chinese-CFL',\n    'zh_cfl':                    'UD_Chinese-CFL',\n    'zh-hant_gsd':               'UD_Chinese-GSD',\n    'zh_gsd':                    'UD_Chinese-GSD',\n    'zh-hans_gsdsimp':           'UD_Chinese-GSDSimp',\n    'zh_gsdsimp':                'UD_Chinese-GSDSimp',\n    'zh-hant_hk':                'UD_Chinese-HK',\n    'zh_hk':                     'UD_Chinese-HK',\n    'zh-hant_pud':               'UD_Chinese-PUD',\n    'zh_pud':                    'UD_Chinese-PUD',\n    'zh-hans_patentchar':        'UD_Chinese-PatentChar',\n    'zh_patentchar':             'UD_Chinese-PatentChar',\n    'ctn_ctntb':                 'UD_Chintang-CTNTB',\n    'ckt_hse':                   'UD_Chukchi-HSE',\n    'xcl_caval':                 'UD_Classical_Armenian-CAVaL',\n    'lzh_kyoto':                 'UD_Classical_Chinese-Kyoto',\n    'lzh_tuecl':                 'UD_Classical_Chinese-TueCL',\n    'cop_bohairic':              'UD_Coptic-Bohairic',\n    'cop_scriptorium':           'UD_Coptic-Scriptorium',\n    'hr_set':                    'UD_Croatian-SET',\n    'cs_cac':                    'UD_Czech-CAC',\n    'cs_cltt':                   'UD_Czech-CLTT',\n    'cs_fictree':                'UD_Czech-FicTree',\n    'cs_pdtc':                   'UD_Czech-PDTC',\n    'cs_pud':                    'UD_Czech-PUD',\n    'cs_poetry':                 'UD_Czech-Poetry',\n    'da_ddt':                    'UD_Danish-DDT',\n    'nl_alpino':                 
'UD_Dutch-Alpino',\n    'nl_lassysmall':             'UD_Dutch-LassySmall',\n    'egy_ujaen':                 'UD_Egyptian-UJaen',\n    'en_atis':                   'UD_English-Atis',\n    'en_childes':                'UD_English-CHILDES',\n    'en_ctetex':                 'UD_English-CTeTex',\n    'en_eslspok':                'UD_English-ESLSpok',\n    'en_ewt':                    'UD_English-EWT',\n    'en_gentle':                 'UD_English-GENTLE',\n    'en_gum':                    'UD_English-GUM',\n    'en_gumreddit':              'UD_English-GUMReddit',\n    'en_lines':                  'UD_English-LinES',\n    'en_littleprince':           'UD_English-LittlePrince',\n    'en_pud':                    'UD_English-PUD',\n    'en_partut':                 'UD_English-ParTUT',\n    'en_pronouns':               'UD_English-Pronouns',\n    'myv_jr':                    'UD_Erzya-JR',\n    'eo_cairo':                  'UD_Esperanto-Cairo',\n    'eo_prago':                  'UD_Esperanto-Prago',\n    'et_edt':                    'UD_Estonian-EDT',\n    'et_ewt':                    'UD_Estonian-EWT',\n    'fo_farpahc':                'UD_Faroese-FarPaHC',\n    'fo_oft':                    'UD_Faroese-OFT',\n    'fi_ftb':                    'UD_Finnish-FTB',\n    'fi_ood':                    'UD_Finnish-OOD',\n    'fi_pud':                    'UD_Finnish-PUD',\n    'fi_tdt':                    'UD_Finnish-TDT',\n    'fr_alts':                   'UD_French-ALTS',\n    'fr_fqb':                    'UD_French-FQB',\n    'fr_gsd':                    'UD_French-GSD',\n    'fr_pud':                    'UD_French-PUD',\n    'fr_partut':                 'UD_French-ParTUT',\n    'fr_parisstories':           'UD_French-ParisStories',\n    'fr_poitevindivital':        'UD_French-PoitevinDIVITAL',\n    'fr_rhapsodie':              'UD_French-Rhapsodie',\n    'fr_sequoia':                'UD_French-Sequoia',\n    'qfn_fame':                  'UD_Frisian_Dutch-Fame',\n    'gl_ctg':   
                 'UD_Galician-CTG',\n    'gl_pud':                    'UD_Galician-PUD',\n    'gl_treegal':                'UD_Galician-TreeGal',\n    'ka_glc':                    'UD_Georgian-GLC',\n    'ka_gnc':                    'UD_Georgian-GNC',\n    'de_gsd':                    'UD_German-GSD',\n    'de_hdt':                    'UD_German-HDT',\n    'de_lit':                    'UD_German-LIT',\n    'de_pud':                    'UD_German-PUD',\n    'aln_gps':                   'UD_Gheg-GPS',\n    'got_proiel':                'UD_Gothic-PROIEL',\n    'el_cretan':                 'UD_Greek-Cretan',\n    'el_gdt':                    'UD_Greek-GDT',\n    'el_gud':                    'UD_Greek-GUD',\n    'el_lesbian':                'UD_Greek-Lesbian',\n    'el_messinian':              'UD_Greek-Messinian',\n    'gub_tudet':                 'UD_Guajajara-TuDeT',\n    'gn_oldtudet':               'UD_Guarani-OldTuDeT',\n    'gu_gujtb':                  'UD_Gujarati-GujTB',\n    'gwi_tuecl':                 'UD_Gwichin-TueCL',\n    'ht_adolphe':                'UD_Haitian_Creole-Adolphe',\n    'ht_autogramm':              'UD_Haitian_Creole-Autogramm',\n    'ha_northernautogramm':      'UD_Hausa-NorthernAutogramm',\n    'ha_southernautogramm':      'UD_Hausa-SouthernAutogramm',\n    'ha_westernautogramm':       'UD_Hausa-WesternAutogramm',\n    'he_htb':                    'UD_Hebrew-HTB',\n    'he_iahltknesset':           'UD_Hebrew-IAHLTknesset',\n    'he_iahltwiki':              'UD_Hebrew-IAHLTwiki',\n    'azz_itml':                  'UD_Highland_Puebla_Nahuatl-ITML',\n    'hi_hdtb':                   'UD_Hindi-HDTB',\n    'hi_pud':                    'UD_Hindi-PUD',\n    'hit_hittb':                 'UD_Hittite-HitTB',\n    'hu_szeged':                 'UD_Hungarian-Szeged',\n    'is_gc':                     'UD_Icelandic-GC',\n    'is_icepahc':                'UD_Icelandic-IcePaHC',\n    'is_modern':                 'UD_Icelandic-Modern',\n    'is_pud':      
              'UD_Icelandic-PUD',\n    'arh_chibergis':             'UD_Ika-ChibErgIS',\n    'id_csui':                   'UD_Indonesian-CSUI',\n    'id_gsd':                    'UD_Indonesian-GSD',\n    'id_pud':                    'UD_Indonesian-PUD',\n    'ga_cadhan':                 'UD_Irish-Cadhan',\n    'ga_idt':                    'UD_Irish-IDT',\n    'ga_twittirish':             'UD_Irish-TwittIrish',\n    'it_isdt':                   'UD_Italian-ISDT',\n    'it_kiparlaforest':          'UD_Italian-KIParlaForest',\n    'it_markit':                 'UD_Italian-MarkIT',\n    'it_old':                    'UD_Italian-Old',\n    'it_pud':                    'UD_Italian-PUD',\n    'it_partut':                 'UD_Italian-ParTUT',\n    'it_parlamint':              'UD_Italian-ParlaMint',\n    'it_postwita':               'UD_Italian-PoSTWITA',\n    'it_twittiro':               'UD_Italian-TWITTIRO',\n    'it_vit':                    'UD_Italian-VIT',\n    'it_valico':                 'UD_Italian-Valico',\n    'ja_bccwj':                  'UD_Japanese-BCCWJ',\n    'ja_bccwjluw':               'UD_Japanese-BCCWJLUW',\n    'ja_gsd':                    'UD_Japanese-GSD',\n    'ja_gsdluw':                 'UD_Japanese-GSDLUW',\n    'ja_pud':                    'UD_Japanese-PUD',\n    'ja_pudluw':                 'UD_Japanese-PUDLUW',\n    'jv_csui':                   'UD_Javanese-CSUI',\n    'urb_tudet':                 'UD_Kaapor-TuDeT',\n    'xnr_kdtb':                  'UD_Kangri-KDTB',\n    'krl_kkpp':                  'UD_Karelian-KKPP',\n    'arr_tudet':                 'UD_Karo-TuDeT',\n    'kk_ktb':                    'UD_Kazakh-KTB',\n    'naq_kdt':                   'UD_Khoekhoe-KDT',\n    'kfm_aha':                   'UD_Khunsari-AHA',\n    'quc_iu':                    'UD_Kiche-IU',\n    'koi_uh':                    'UD_Komi_Permyak-UH',\n    'kpv_ikdp':                  'UD_Komi_Zyrian-IKDP',\n    'kpv_lattice':               'UD_Komi_Zyrian-Lattice',\n   
 'ko_gsd':                    'UD_Korean-GSD',\n    'ko_ksl':                    'UD_Korean-KSL',\n    'ko_kaist':                  'UD_Korean-Kaist',\n    'ko_littleprince':           'UD_Korean-LittlePrince',\n    'ko_pud':                    'UD_Korean-PUD',\n    'ky_ktmu':                   'UD_Kyrgyz-KTMU',\n    'ky_tuecl':                  'UD_Kyrgyz-TueCL',\n    'ltg_cairo':                 'UD_Latgalian-Cairo',\n    'la_circse':                 'UD_Latin-CIRCSE',\n    'la_ittb':                   'UD_Latin-ITTB',\n    'la_llct':                   'UD_Latin-LLCT',\n    'la_proiel':                 'UD_Latin-PROIEL',\n    'la_perseus':                'UD_Latin-Perseus',\n    'la_udante':                 'UD_Latin-UDante',\n    'lv_cairo':                  'UD_Latvian-Cairo',\n    'lv_lvtb':                   'UD_Latvian-LVTB',\n    'lij_glt':                   'UD_Ligurian-GLT',\n    'lt_alksnis':                'UD_Lithuanian-ALKSNIS',\n    'lt_hse':                    'UD_Lithuanian-HSE',\n    'olo_kkpp':                  'UD_Livvi-KKPP',\n    'nds_lsdc':                  'UD_Low_Saxon-LSDC',\n    'lb_luxbank':                'UD_Luxembourgish-LuxBank',\n    'mk_mtb':                    'UD_Macedonian-MTB',\n    'jaa_jarawara':              'UD_Madi-Jarawara',\n    'qaf_arabizi':               'UD_Maghrebi_Arabic_French-Arabizi',\n    'mpu_tudet':                 'UD_Makurap-TuDeT',\n    'ml_ufal':                   'UD_Malayalam-UFAL',\n    'mt_mudt':                   'UD_Maltese-MUDT',\n    'gv_cadhan':                 'UD_Manx-Cadhan',\n    'mr_ufal':                   'UD_Marathi-UFAL',\n    'gun_dooley':                'UD_Mbya_Guarani-Dooley',\n    'gun_thomas':                'UD_Mbya_Guarani-Thomas',\n    'frm_altm':                  'UD_Middle_French-ALTM',\n    'frm_profiterole':           'UD_Middle_French-PROFITEROLE',\n    'mdf_jr':                    'UD_Moksha-JR',\n    'myu_tudet':                 'UD_Munduruku-TuDeT',\n    'nmf_suansu':    
            'UD_Naga-Suansu',\n    'pcm_nsc':                   'UD_Naija-NSC',\n    'nyq_aha':                   'UD_Nayini-AHA',\n    'nap_rb':                    'UD_Neapolitan-RB',\n    'nrk_tundra':                'UD_Nenets-Tundra',\n    'yrl_complin':               'UD_Nheengatu-CompLin',\n    'sme_giella':                'UD_North_Sami-Giella',\n    'kmr_kurmanji':              'UD_Northern_Kurdish-Kurmanji',\n    'gya_autogramm':             'UD_Northwest_Gbaya-Autogramm',\n    'nb_bokmaal':                'UD_Norwegian-Bokmaal',\n    'no_bokmaal':                'UD_Norwegian-Bokmaal',\n    'nn_nynorsk':                'UD_Norwegian-Nynorsk',\n    'oc_ttb':                    'UD_Occitan-TTB',\n    'or_odtb':                   'UD_Odia-ODTB',\n    'cu_proiel':                 'UD_Old_Church_Slavonic-PROIEL',\n    'orv_birchbark':             'UD_Old_East_Slavic-Birchbark',\n    'orv_rnc':                   'UD_Old_East_Slavic-RNC',\n    'orv_ruthenian':             'UD_Old_East_Slavic-Ruthenian',\n    'orv_torot':                 'UD_Old_East_Slavic-TOROT',\n    'ang_cairo':                 'UD_Old_English-Cairo',\n    'fro_altm':                  'UD_Old_French-ALTM',\n    'fro_profiterole':           'UD_Old_French-PROFITEROLE',\n    'sga_dipsgg':                'UD_Old_Irish-DipSGG',\n    'sga_dipwbg':                'UD_Old_Irish-DipWBG',\n    'pro_corag':                 'UD_Old_Occitan-CorAG',\n    'otk_clausal':               'UD_Old_Turkish-Clausal',\n    'ota_boun':                  'UD_Ottoman_Turkish-BOUN',\n    'ota_dudu':                  'UD_Ottoman_Turkish-DUDU',\n    'ps_sikaram':                'UD_Pashto-Sikaram',\n    'pad_tuecl':                 'UD_Paumari-TueCL',\n    'fa_perdt':                  'UD_Persian-PerDT',\n    'fa_seraji':                 'UD_Persian-Seraji',\n    'pay_chibergis':             'UD_Pesh-ChibErgIS',\n    'xpg_kul':                   'UD_Phrygian-KUL',\n    'pl_lfg':                    'UD_Polish-LFG',\n    
'pl_mpdt':                   'UD_Polish-MPDT',\n    'pl_pdb':                    'UD_Polish-PDB',\n    'pl_pud':                    'UD_Polish-PUD',\n    'qpm_philotis':              'UD_Pomak-Philotis',\n    'pt_bosque':                 'UD_Portuguese-Bosque',\n    'pt_cintil':                 'UD_Portuguese-CINTIL',\n    'pt_dantestocks':            'UD_Portuguese-DANTEStocks',\n    'pt_gsd':                    'UD_Portuguese-GSD',\n    'pt_pud':                    'UD_Portuguese-PUD',\n    'pt_petrogold':              'UD_Portuguese-PetroGold',\n    'pt_porttinari':             'UD_Portuguese-Porttinari',\n    'ro_art':                    'UD_Romanian-ArT',\n    'ro_moldoro':                'UD_Romanian-MolDoRo',\n    'ro_nonstandard':            'UD_Romanian-Nonstandard',\n    'ro_rrt':                    'UD_Romanian-RRT',\n    'ro_simonero':               'UD_Romanian-SiMoNERo',\n    'ro_tuecl':                  'UD_Romanian-TueCL',\n    'ru_gsd':                    'UD_Russian-GSD',\n    'ru_pud':                    'UD_Russian-PUD',\n    'ru_poetry':                 'UD_Russian-Poetry',\n    'ru_syntagrus':              'UD_Russian-SynTagRus',\n    'ru_taiga':                  'UD_Russian-Taiga',\n    'sa_ufal':                   'UD_Sanskrit-UFAL',\n    'sa_vedic':                  'UD_Sanskrit-Vedic',\n    'gd_arcosg':                 'UD_Scottish_Gaelic-ARCOSG',\n    'sr_set':                    'UD_Serbian-SET',\n    'wuu_shud':                  'UD_Shanghainese-ShUD',\n    'scn_stb':                   'UD_Sicilian-STB',\n    'sd_isra':                   'UD_Sindhi-Isra',\n    'si_stb':                    'UD_Sinhala-STB',\n    'sms_giellagas':             'UD_Skolt_Sami-Giellagas',\n    'sk_snk':                    'UD_Slovak-SNK',\n    'sl_ssj':                    'UD_Slovenian-SSJ',\n    'sl_sst':                    'UD_Slovenian-SST',\n    'soj_aha':                   'UD_Soi-AHA',\n    'ajp_madar':                 
'UD_South_Levantine_Arabic-MADAR',\n    'sdh_garrusi':               'UD_Southern_Kurdish-Garrusi',\n    'es_ancora':                 'UD_Spanish-AnCora',\n    'es_coser':                  'UD_Spanish-COSER',\n    'es_gsd':                    'UD_Spanish-GSD',\n    'es_pud':                    'UD_Spanish-PUD',\n    'ssp_lse':                   'UD_Spanish_Sign_Language-LSE',\n    'sv_lines':                  'UD_Swedish-LinES',\n    'sv_old':                    'UD_Swedish-Old',\n    'sv_pud':                    'UD_Swedish-PUD',\n    'sv_swell':                  'UD_Swedish-SweLL',\n    'sv_talbanken':              'UD_Swedish-Talbanken',\n    'swl_sslc':                  'UD_Swedish_Sign_Language-SSLC',\n    'tl_trg':                    'UD_Tagalog-TRG',\n    'tl_ugnayan':                'UD_Tagalog-Ugnayan',\n    'ta_mwtt':                   'UD_Tamil-MWTT',\n    'ta_ttb':                    'UD_Tamil-TTB',\n    'tt_nmctt':                  'UD_Tatar-NMCTT',\n    'eme_tudet':                 'UD_Teko-TuDeT',\n    'te_mtg':                    'UD_Telugu-MTG',\n    'qte_tect':                  'UD_Telugu_English-TECT',\n    'th_pud':                    'UD_Thai-PUD',\n    'th_tud':                    'UD_Thai-TUD',\n    'tn_popapolelo':             'UD_Tswana-Popapolelo',\n    'tpn_tudet':                 'UD_Tupinamba-TuDeT',\n    'tr_atis':                   'UD_Turkish-Atis',\n    'tr_boun':                   'UD_Turkish-BOUN',\n    'tr_framenet':               'UD_Turkish-FrameNet',\n    'tr_gb':                     'UD_Turkish-GB',\n    'tr_imst':                   'UD_Turkish-IMST',\n    'tr_kenet':                  'UD_Turkish-Kenet',\n    'tr_pud':                    'UD_Turkish-PUD',\n    'tr_penn':                   'UD_Turkish-Penn',\n    'tr_tourism':                'UD_Turkish-Tourism',\n    'tr_tuecl':                  'UD_Turkish-TueCL',\n    'qti_butr':                  'UD_Turkish_English-BUTR',\n    'qtd_sagt':                  
'UD_Turkish_German-SAGT',\n    'uk_iu':                     'UD_Ukrainian-IU',\n    'uk_parlamint':              'UD_Ukrainian-ParlaMint',\n    'xum_ikuvina':               'UD_Umbrian-IKUVINA',\n    'hsb_ufal':                  'UD_Upper_Sorbian-UFAL',\n    'ur_udtb':                   'UD_Urdu-UDTB',\n    'ug_udt':                    'UD_Uyghur-UDT',\n    'uz_tuecl':                  'UD_Uzbek-TueCL',\n    'uz_ut':                     'UD_Uzbek-UT',\n    'uz_uzudt':                  'UD_Uzbek-UzUDT',\n    'vep_vwt':                   'UD_Veps-VWT',\n    'vi_tuecl':                  'UD_Vietnamese-TueCL',\n    'vi_vtb':                    'UD_Vietnamese-VTB',\n    'wbp_ufal':                  'UD_Warlpiri-UFAL',\n    'cy_ccg':                    'UD_Welsh-CCG',\n    'hyw_armtdp':                'UD_Western_Armenian-ArmTDP',\n    'nhi_itml':                  'UD_Western_Sierra_Puebla_Nahuatl-ITML',\n    'wo_wtb':                    'UD_Wolof-WTB',\n    'xav_xdt':                   'UD_Xavante-XDT',\n    'sjo_xdt':                   'UD_Xibe-XDT',\n    'sah_yktdt':                 'UD_Yakut-YKTDT',\n    'yi_yitb':                   'UD_Yiddish-YiTB',\n    'yo_ytb':                    'UD_Yoruba-YTB',\n    'ess_sli':                   'UD_Yupik-SLI',\n    'say_autogramm':             'UD_Zaar-Autogramm',\n}\n\n\ndef short_name_to_treebank(short_name):\n    return SHORT_NAMES[short_name]\n\n\nCANONICAL_NAMES = {\n    'ud_abaza-atb':                            'UD_Abaza-ATB',\n    'ud_abkhaz-abnc':                          'UD_Abkhaz-AbNC',\n    'ud_afrikaans-afribooms':                  'UD_Afrikaans-AfriBooms',\n    'ud_akkadian-pisandub':                    'UD_Akkadian-PISANDUB',\n    'ud_akkadian-riao':                        'UD_Akkadian-RIAO',\n    'ud_akuntsu-tudet':                        'UD_Akuntsu-TuDeT',\n    'ud_albanian-staf':                        'UD_Albanian-STAF',\n    'ud_albanian-tsa':                         'UD_Albanian-TSA',\n    
'ud_alemannic-divital':                    'UD_Alemannic-DIVITAL',\n    'ud_alemannic-uzh':                        'UD_Alemannic-UZH',\n    'ud_amharic-att':                          'UD_Amharic-ATT',\n    'ud_ancient_greek-proiel':                 'UD_Ancient_Greek-PROIEL',\n    'ud_ancient_greek-ptnk':                   'UD_Ancient_Greek-PTNK',\n    'ud_ancient_greek-perseus':                'UD_Ancient_Greek-Perseus',\n    'ud_ancient_hebrew-ptnk':                  'UD_Ancient_Hebrew-PTNK',\n    'ud_apurina-ufpa':                         'UD_Apurina-UFPA',\n    'ud_arabic-nyuad':                         'UD_Arabic-NYUAD',\n    'ud_arabic-padt':                          'UD_Arabic-PADT',\n    'ud_arabic-pud':                           'UD_Arabic-PUD',\n    'ud_armenian-armtdp':                      'UD_Armenian-ArmTDP',\n    'ud_armenian-bsut':                        'UD_Armenian-BSUT',\n    'ud_assyrian-as':                          'UD_Assyrian-AS',\n    'ud_azerbaijani-tuecl':                    'UD_Azerbaijani-TueCL',\n    'ud_bambara-crb':                          'UD_Bambara-CRB',\n    'ud_basque-bdt':                           'UD_Basque-BDT',\n    'ud_bavarian-maibaam':                     'UD_Bavarian-MaiBaam',\n    'ud_beja-autogramm':                       'UD_Beja-Autogramm',\n    'ud_belarusian-hse':                       'UD_Belarusian-HSE',\n    'ud_bengali-bru':                          'UD_Bengali-BRU',\n    'ud_bhojpuri-bhtb':                        'UD_Bhojpuri-BHTB',\n    'ud_bokota-chibergis':                     'UD_Bokota-ChibErgIS',\n    'ud_bororo-bdt':                           'UD_Bororo-BDT',\n    'ud_breton-keb':                           'UD_Breton-KEB',\n    'ud_bulgarian-btb':                        'UD_Bulgarian-BTB',\n    'ud_buryat-bdt':                           'UD_Buryat-BDT',\n    'ud_cantonese-hk':                         'UD_Cantonese-HK',\n    'ud_cappadocian-amgic':                    'UD_Cappadocian-AMGiC',\n    
'ud_cappadocian-tuecl':                    'UD_Cappadocian-TueCL',\n    'ud_catalan-ancora':                       'UD_Catalan-AnCora',\n    'ud_cebuano-gja':                          'UD_Cebuano-GJA',\n    'ud_central_kurdish-mukri':                'UD_Central_Kurdish-Mukri',\n    'ud_chinese-beginner':                     'UD_Chinese-Beginner',\n    'ud_chinese-cfl':                          'UD_Chinese-CFL',\n    'ud_chinese-gsd':                          'UD_Chinese-GSD',\n    'ud_chinese-gsdsimp':                      'UD_Chinese-GSDSimp',\n    'ud_chinese-hk':                           'UD_Chinese-HK',\n    'ud_chinese-pud':                          'UD_Chinese-PUD',\n    'ud_chinese-patentchar':                   'UD_Chinese-PatentChar',\n    'ud_chintang-ctntb':                       'UD_Chintang-CTNTB',\n    'ud_chukchi-hse':                          'UD_Chukchi-HSE',\n    'ud_classical_armenian-caval':             'UD_Classical_Armenian-CAVaL',\n    'ud_classical_chinese-kyoto':              'UD_Classical_Chinese-Kyoto',\n    'ud_classical_chinese-tuecl':              'UD_Classical_Chinese-TueCL',\n    'ud_coptic-bohairic':                      'UD_Coptic-Bohairic',\n    'ud_coptic-scriptorium':                   'UD_Coptic-Scriptorium',\n    'ud_croatian-set':                         'UD_Croatian-SET',\n    'ud_czech-cac':                            'UD_Czech-CAC',\n    'ud_czech-cltt':                           'UD_Czech-CLTT',\n    'ud_czech-fictree':                        'UD_Czech-FicTree',\n    'ud_czech-pdtc':                           'UD_Czech-PDTC',\n    'ud_czech-pud':                            'UD_Czech-PUD',\n    'ud_czech-poetry':                         'UD_Czech-Poetry',\n    'ud_danish-ddt':                           'UD_Danish-DDT',\n    'ud_dutch-alpino':                         'UD_Dutch-Alpino',\n    'ud_dutch-lassysmall':                     'UD_Dutch-LassySmall',\n    'ud_egyptian-ujaen':                       
'UD_Egyptian-UJaen',\n    'ud_english-atis':                         'UD_English-Atis',\n    'ud_english-childes':                      'UD_English-CHILDES',\n    'ud_english-ctetex':                       'UD_English-CTeTex',\n    'ud_english-eslspok':                      'UD_English-ESLSpok',\n    'ud_english-ewt':                          'UD_English-EWT',\n    'ud_english-gentle':                       'UD_English-GENTLE',\n    'ud_english-gum':                          'UD_English-GUM',\n    'ud_english-gumreddit':                    'UD_English-GUMReddit',\n    'ud_english-lines':                        'UD_English-LinES',\n    'ud_english-littleprince':                 'UD_English-LittlePrince',\n    'ud_english-pud':                          'UD_English-PUD',\n    'ud_english-partut':                       'UD_English-ParTUT',\n    'ud_english-pronouns':                     'UD_English-Pronouns',\n    'ud_erzya-jr':                             'UD_Erzya-JR',\n    'ud_esperanto-cairo':                      'UD_Esperanto-Cairo',\n    'ud_esperanto-prago':                      'UD_Esperanto-Prago',\n    'ud_estonian-edt':                         'UD_Estonian-EDT',\n    'ud_estonian-ewt':                         'UD_Estonian-EWT',\n    'ud_faroese-farpahc':                      'UD_Faroese-FarPaHC',\n    'ud_faroese-oft':                          'UD_Faroese-OFT',\n    'ud_finnish-ftb':                          'UD_Finnish-FTB',\n    'ud_finnish-ood':                          'UD_Finnish-OOD',\n    'ud_finnish-pud':                          'UD_Finnish-PUD',\n    'ud_finnish-tdt':                          'UD_Finnish-TDT',\n    'ud_french-alts':                          'UD_French-ALTS',\n    'ud_french-fqb':                           'UD_French-FQB',\n    'ud_french-gsd':                           'UD_French-GSD',\n    'ud_french-pud':                           'UD_French-PUD',\n    'ud_french-partut':                        'UD_French-ParTUT',\n    
'ud_french-parisstories':                  'UD_French-ParisStories',\n    'ud_french-poitevindivital':               'UD_French-PoitevinDIVITAL',\n    'ud_french-rhapsodie':                     'UD_French-Rhapsodie',\n    'ud_french-sequoia':                       'UD_French-Sequoia',\n    'ud_frisian_dutch-fame':                   'UD_Frisian_Dutch-Fame',\n    'ud_galician-ctg':                         'UD_Galician-CTG',\n    'ud_galician-pud':                         'UD_Galician-PUD',\n    'ud_galician-treegal':                     'UD_Galician-TreeGal',\n    'ud_georgian-glc':                         'UD_Georgian-GLC',\n    'ud_georgian-gnc':                         'UD_Georgian-GNC',\n    'ud_german-gsd':                           'UD_German-GSD',\n    'ud_german-hdt':                           'UD_German-HDT',\n    'ud_german-lit':                           'UD_German-LIT',\n    'ud_german-pud':                           'UD_German-PUD',\n    'ud_gheg-gps':                             'UD_Gheg-GPS',\n    'ud_gothic-proiel':                        'UD_Gothic-PROIEL',\n    'ud_greek-cretan':                         'UD_Greek-Cretan',\n    'ud_greek-gdt':                            'UD_Greek-GDT',\n    'ud_greek-gud':                            'UD_Greek-GUD',\n    'ud_greek-lesbian':                        'UD_Greek-Lesbian',\n    'ud_greek-messinian':                      'UD_Greek-Messinian',\n    'ud_guajajara-tudet':                      'UD_Guajajara-TuDeT',\n    'ud_guarani-oldtudet':                     'UD_Guarani-OldTuDeT',\n    'ud_gujarati-gujtb':                       'UD_Gujarati-GujTB',\n    'ud_gwichin-tuecl':                        'UD_Gwichin-TueCL',\n    'ud_haitian_creole-adolphe':               'UD_Haitian_Creole-Adolphe',\n    'ud_haitian_creole-autogramm':             'UD_Haitian_Creole-Autogramm',\n    'ud_hausa-northernautogramm':              'UD_Hausa-NorthernAutogramm',\n    'ud_hausa-southernautogramm':              
'UD_Hausa-SouthernAutogramm',\n    'ud_hausa-westernautogramm':               'UD_Hausa-WesternAutogramm',\n    'ud_hebrew-htb':                           'UD_Hebrew-HTB',\n    'ud_hebrew-iahltknesset':                  'UD_Hebrew-IAHLTknesset',\n    'ud_hebrew-iahltwiki':                     'UD_Hebrew-IAHLTwiki',\n    'ud_highland_puebla_nahuatl-itml':         'UD_Highland_Puebla_Nahuatl-ITML',\n    'ud_hindi-hdtb':                           'UD_Hindi-HDTB',\n    'ud_hindi-pud':                            'UD_Hindi-PUD',\n    'ud_hittite-hittb':                        'UD_Hittite-HitTB',\n    'ud_hungarian-szeged':                     'UD_Hungarian-Szeged',\n    'ud_icelandic-gc':                         'UD_Icelandic-GC',\n    'ud_icelandic-icepahc':                    'UD_Icelandic-IcePaHC',\n    'ud_icelandic-modern':                     'UD_Icelandic-Modern',\n    'ud_icelandic-pud':                        'UD_Icelandic-PUD',\n    'ud_ika-chibergis':                        'UD_Ika-ChibErgIS',\n    'ud_indonesian-csui':                      'UD_Indonesian-CSUI',\n    'ud_indonesian-gsd':                       'UD_Indonesian-GSD',\n    'ud_indonesian-pud':                       'UD_Indonesian-PUD',\n    'ud_irish-cadhan':                         'UD_Irish-Cadhan',\n    'ud_irish-idt':                            'UD_Irish-IDT',\n    'ud_irish-twittirish':                     'UD_Irish-TwittIrish',\n    'ud_italian-isdt':                         'UD_Italian-ISDT',\n    'ud_italian-kiparlaforest':                'UD_Italian-KIParlaForest',\n    'ud_italian-markit':                       'UD_Italian-MarkIT',\n    'ud_italian-old':                          'UD_Italian-Old',\n    'ud_italian-pud':                          'UD_Italian-PUD',\n    'ud_italian-partut':                       'UD_Italian-ParTUT',\n    'ud_italian-parlamint':                    'UD_Italian-ParlaMint',\n    'ud_italian-postwita':                     'UD_Italian-PoSTWITA',\n    
'ud_italian-twittiro':                     'UD_Italian-TWITTIRO',\n    'ud_italian-vit':                          'UD_Italian-VIT',\n    'ud_italian-valico':                       'UD_Italian-Valico',\n    'ud_japanese-bccwj':                       'UD_Japanese-BCCWJ',\n    'ud_japanese-bccwjluw':                    'UD_Japanese-BCCWJLUW',\n    'ud_japanese-gsd':                         'UD_Japanese-GSD',\n    'ud_japanese-gsdluw':                      'UD_Japanese-GSDLUW',\n    'ud_japanese-pud':                         'UD_Japanese-PUD',\n    'ud_japanese-pudluw':                      'UD_Japanese-PUDLUW',\n    'ud_javanese-csui':                        'UD_Javanese-CSUI',\n    'ud_kaapor-tudet':                         'UD_Kaapor-TuDeT',\n    'ud_kangri-kdtb':                          'UD_Kangri-KDTB',\n    'ud_karelian-kkpp':                        'UD_Karelian-KKPP',\n    'ud_karo-tudet':                           'UD_Karo-TuDeT',\n    'ud_kazakh-ktb':                           'UD_Kazakh-KTB',\n    'ud_khoekhoe-kdt':                         'UD_Khoekhoe-KDT',\n    'ud_khunsari-aha':                         'UD_Khunsari-AHA',\n    'ud_kiche-iu':                             'UD_Kiche-IU',\n    'ud_komi_permyak-uh':                      'UD_Komi_Permyak-UH',\n    'ud_komi_zyrian-ikdp':                     'UD_Komi_Zyrian-IKDP',\n    'ud_komi_zyrian-lattice':                  'UD_Komi_Zyrian-Lattice',\n    'ud_korean-gsd':                           'UD_Korean-GSD',\n    'ud_korean-ksl':                           'UD_Korean-KSL',\n    'ud_korean-kaist':                         'UD_Korean-Kaist',\n    'ud_korean-littleprince':                  'UD_Korean-LittlePrince',\n    'ud_korean-pud':                           'UD_Korean-PUD',\n    'ud_kyrgyz-ktmu':                          'UD_Kyrgyz-KTMU',\n    'ud_kyrgyz-tuecl':                         'UD_Kyrgyz-TueCL',\n    'ud_latgalian-cairo':                      'UD_Latgalian-Cairo',\n    'ud_latin-circse':           
              'UD_Latin-CIRCSE',\n    'ud_latin-ittb':                           'UD_Latin-ITTB',\n    'ud_latin-llct':                           'UD_Latin-LLCT',\n    'ud_latin-proiel':                         'UD_Latin-PROIEL',\n    'ud_latin-perseus':                        'UD_Latin-Perseus',\n    'ud_latin-udante':                         'UD_Latin-UDante',\n    'ud_latvian-cairo':                        'UD_Latvian-Cairo',\n    'ud_latvian-lvtb':                         'UD_Latvian-LVTB',\n    'ud_ligurian-glt':                         'UD_Ligurian-GLT',\n    'ud_lithuanian-alksnis':                   'UD_Lithuanian-ALKSNIS',\n    'ud_lithuanian-hse':                       'UD_Lithuanian-HSE',\n    'ud_livvi-kkpp':                           'UD_Livvi-KKPP',\n    'ud_low_saxon-lsdc':                       'UD_Low_Saxon-LSDC',\n    'ud_luxembourgish-luxbank':                'UD_Luxembourgish-LuxBank',\n    'ud_macedonian-mtb':                       'UD_Macedonian-MTB',\n    'ud_madi-jarawara':                        'UD_Madi-Jarawara',\n    'ud_maghrebi_arabic_french-arabizi':       'UD_Maghrebi_Arabic_French-Arabizi',\n    'ud_makurap-tudet':                        'UD_Makurap-TuDeT',\n    'ud_malayalam-ufal':                       'UD_Malayalam-UFAL',\n    'ud_maltese-mudt':                         'UD_Maltese-MUDT',\n    'ud_manx-cadhan':                          'UD_Manx-Cadhan',\n    'ud_marathi-ufal':                         'UD_Marathi-UFAL',\n    'ud_mbya_guarani-dooley':                  'UD_Mbya_Guarani-Dooley',\n    'ud_mbya_guarani-thomas':                  'UD_Mbya_Guarani-Thomas',\n    'ud_middle_french-altm':                   'UD_Middle_French-ALTM',\n    'ud_middle_french-profiterole':            'UD_Middle_French-PROFITEROLE',\n    'ud_moksha-jr':                            'UD_Moksha-JR',\n    'ud_munduruku-tudet':                      'UD_Munduruku-TuDeT',\n    'ud_naga-suansu':                          'UD_Naga-Suansu',\n    'ud_naija-nsc': 
                           'UD_Naija-NSC',\n    'ud_nayini-aha':                           'UD_Nayini-AHA',\n    'ud_neapolitan-rb':                        'UD_Neapolitan-RB',\n    'ud_nenets-tundra':                        'UD_Nenets-Tundra',\n    'ud_nheengatu-complin':                    'UD_Nheengatu-CompLin',\n    'ud_north_sami-giella':                    'UD_North_Sami-Giella',\n    'ud_northern_kurdish-kurmanji':            'UD_Northern_Kurdish-Kurmanji',\n    'ud_northwest_gbaya-autogramm':            'UD_Northwest_Gbaya-Autogramm',\n    'ud_norwegian-bokmaal':                    'UD_Norwegian-Bokmaal',\n    'ud_norwegian-nynorsk':                    'UD_Norwegian-Nynorsk',\n    'ud_occitan-ttb':                          'UD_Occitan-TTB',\n    'ud_odia-odtb':                            'UD_Odia-ODTB',\n    'ud_old_church_slavonic-proiel':           'UD_Old_Church_Slavonic-PROIEL',\n    'ud_old_east_slavic-birchbark':            'UD_Old_East_Slavic-Birchbark',\n    'ud_old_east_slavic-rnc':                  'UD_Old_East_Slavic-RNC',\n    'ud_old_east_slavic-ruthenian':            'UD_Old_East_Slavic-Ruthenian',\n    'ud_old_east_slavic-torot':                'UD_Old_East_Slavic-TOROT',\n    'ud_old_english-cairo':                    'UD_Old_English-Cairo',\n    'ud_old_french-altm':                      'UD_Old_French-ALTM',\n    'ud_old_french-profiterole':               'UD_Old_French-PROFITEROLE',\n    'ud_old_irish-dipsgg':                     'UD_Old_Irish-DipSGG',\n    'ud_old_irish-dipwbg':                     'UD_Old_Irish-DipWBG',\n    'ud_old_occitan-corag':                    'UD_Old_Occitan-CorAG',\n    'ud_old_turkish-clausal':                  'UD_Old_Turkish-Clausal',\n    'ud_ottoman_turkish-boun':                 'UD_Ottoman_Turkish-BOUN',\n    'ud_ottoman_turkish-dudu':                 'UD_Ottoman_Turkish-DUDU',\n    'ud_pashto-sikaram':                       'UD_Pashto-Sikaram',\n    'ud_paumari-tuecl':                        
'UD_Paumari-TueCL',\n    'ud_persian-perdt':                        'UD_Persian-PerDT',\n    'ud_persian-seraji':                       'UD_Persian-Seraji',\n    'ud_pesh-chibergis':                       'UD_Pesh-ChibErgIS',\n    'ud_phrygian-kul':                         'UD_Phrygian-KUL',\n    'ud_polish-lfg':                           'UD_Polish-LFG',\n    'ud_polish-mpdt':                          'UD_Polish-MPDT',\n    'ud_polish-pdb':                           'UD_Polish-PDB',\n    'ud_polish-pud':                           'UD_Polish-PUD',\n    'ud_pomak-philotis':                       'UD_Pomak-Philotis',\n    'ud_portuguese-bosque':                    'UD_Portuguese-Bosque',\n    'ud_portuguese-cintil':                    'UD_Portuguese-CINTIL',\n    'ud_portuguese-dantestocks':               'UD_Portuguese-DANTEStocks',\n    'ud_portuguese-gsd':                       'UD_Portuguese-GSD',\n    'ud_portuguese-pud':                       'UD_Portuguese-PUD',\n    'ud_portuguese-petrogold':                 'UD_Portuguese-PetroGold',\n    'ud_portuguese-porttinari':                'UD_Portuguese-Porttinari',\n    'ud_romanian-art':                         'UD_Romanian-ArT',\n    'ud_romanian-moldoro':                     'UD_Romanian-MolDoRo',\n    'ud_romanian-nonstandard':                 'UD_Romanian-Nonstandard',\n    'ud_romanian-rrt':                         'UD_Romanian-RRT',\n    'ud_romanian-simonero':                    'UD_Romanian-SiMoNERo',\n    'ud_romanian-tuecl':                       'UD_Romanian-TueCL',\n    'ud_russian-gsd':                          'UD_Russian-GSD',\n    'ud_russian-pud':                          'UD_Russian-PUD',\n    'ud_russian-poetry':                       'UD_Russian-Poetry',\n    'ud_russian-syntagrus':                    'UD_Russian-SynTagRus',\n    'ud_russian-taiga':                        'UD_Russian-Taiga',\n    'ud_sanskrit-ufal':                        'UD_Sanskrit-UFAL',\n    'ud_sanskrit-vedic':            
           'UD_Sanskrit-Vedic',\n    'ud_scottish_gaelic-arcosg':               'UD_Scottish_Gaelic-ARCOSG',\n    'ud_serbian-set':                          'UD_Serbian-SET',\n    'ud_shanghainese-shud':                    'UD_Shanghainese-ShUD',\n    'ud_sicilian-stb':                         'UD_Sicilian-STB',\n    'ud_sindhi-isra':                          'UD_Sindhi-Isra',\n    'ud_sinhala-stb':                          'UD_Sinhala-STB',\n    'ud_skolt_sami-giellagas':                 'UD_Skolt_Sami-Giellagas',\n    'ud_slovak-snk':                           'UD_Slovak-SNK',\n    'ud_slovenian-ssj':                        'UD_Slovenian-SSJ',\n    'ud_slovenian-sst':                        'UD_Slovenian-SST',\n    'ud_soi-aha':                              'UD_Soi-AHA',\n    'ud_south_levantine_arabic-madar':         'UD_South_Levantine_Arabic-MADAR',\n    'ud_southern_kurdish-garrusi':             'UD_Southern_Kurdish-Garrusi',\n    'ud_spanish-ancora':                       'UD_Spanish-AnCora',\n    'ud_spanish-coser':                        'UD_Spanish-COSER',\n    'ud_spanish-gsd':                          'UD_Spanish-GSD',\n    'ud_spanish-pud':                          'UD_Spanish-PUD',\n    'ud_spanish_sign_language-lse':            'UD_Spanish_Sign_Language-LSE',\n    'ud_swedish-lines':                        'UD_Swedish-LinES',\n    'ud_swedish-old':                          'UD_Swedish-Old',\n    'ud_swedish-pud':                          'UD_Swedish-PUD',\n    'ud_swedish-swell':                        'UD_Swedish-SweLL',\n    'ud_swedish-talbanken':                    'UD_Swedish-Talbanken',\n    'ud_swedish_sign_language-sslc':           'UD_Swedish_Sign_Language-SSLC',\n    'ud_tagalog-trg':                          'UD_Tagalog-TRG',\n    'ud_tagalog-ugnayan':                      'UD_Tagalog-Ugnayan',\n    'ud_tamil-mwtt':                           'UD_Tamil-MWTT',\n    'ud_tamil-ttb':                            'UD_Tamil-TTB',\n    
'ud_tatar-nmctt':                          'UD_Tatar-NMCTT',\n    'ud_teko-tudet':                           'UD_Teko-TuDeT',\n    'ud_telugu-mtg':                           'UD_Telugu-MTG',\n    'ud_telugu_english-tect':                  'UD_Telugu_English-TECT',\n    'ud_thai-pud':                             'UD_Thai-PUD',\n    'ud_thai-tud':                             'UD_Thai-TUD',\n    'ud_tswana-popapolelo':                    'UD_Tswana-Popapolelo',\n    'ud_tupinamba-tudet':                      'UD_Tupinamba-TuDeT',\n    'ud_turkish-atis':                         'UD_Turkish-Atis',\n    'ud_turkish-boun':                         'UD_Turkish-BOUN',\n    'ud_turkish-framenet':                     'UD_Turkish-FrameNet',\n    'ud_turkish-gb':                           'UD_Turkish-GB',\n    'ud_turkish-imst':                         'UD_Turkish-IMST',\n    'ud_turkish-kenet':                        'UD_Turkish-Kenet',\n    'ud_turkish-pud':                          'UD_Turkish-PUD',\n    'ud_turkish-penn':                         'UD_Turkish-Penn',\n    'ud_turkish-tourism':                      'UD_Turkish-Tourism',\n    'ud_turkish-tuecl':                        'UD_Turkish-TueCL',\n    'ud_turkish_english-butr':                 'UD_Turkish_English-BUTR',\n    'ud_turkish_german-sagt':                  'UD_Turkish_German-SAGT',\n    'ud_ukrainian-iu':                         'UD_Ukrainian-IU',\n    'ud_ukrainian-parlamint':                  'UD_Ukrainian-ParlaMint',\n    'ud_umbrian-ikuvina':                      'UD_Umbrian-IKUVINA',\n    'ud_upper_sorbian-ufal':                   'UD_Upper_Sorbian-UFAL',\n    'ud_urdu-udtb':                            'UD_Urdu-UDTB',\n    'ud_uyghur-udt':                           'UD_Uyghur-UDT',\n    'ud_uzbek-tuecl':                          'UD_Uzbek-TueCL',\n    'ud_uzbek-ut':                             'UD_Uzbek-UT',\n    'ud_uzbek-uzudt':                          'UD_Uzbek-UzUDT',\n    'ud_veps-vwt':               
              'UD_Veps-VWT',\n    'ud_vietnamese-tuecl':                     'UD_Vietnamese-TueCL',\n    'ud_vietnamese-vtb':                       'UD_Vietnamese-VTB',\n    'ud_warlpiri-ufal':                        'UD_Warlpiri-UFAL',\n    'ud_welsh-ccg':                            'UD_Welsh-CCG',\n    'ud_western_armenian-armtdp':              'UD_Western_Armenian-ArmTDP',\n    'ud_western_sierra_puebla_nahuatl-itml':   'UD_Western_Sierra_Puebla_Nahuatl-ITML',\n    'ud_wolof-wtb':                            'UD_Wolof-WTB',\n    'ud_xavante-xdt':                          'UD_Xavante-XDT',\n    'ud_xibe-xdt':                             'UD_Xibe-XDT',\n    'ud_yakut-yktdt':                          'UD_Yakut-YKTDT',\n    'ud_yiddish-yitb':                         'UD_Yiddish-YiTB',\n    'ud_yoruba-ytb':                           'UD_Yoruba-YTB',\n    'ud_yupik-sli':                            'UD_Yupik-SLI',\n    'ud_zaar-autogramm':                       'UD_Zaar-Autogramm',\n}\n\n\ndef canonical_treebank_name(ud_name):\n    if ud_name in SHORT_NAMES:\n        return SHORT_NAMES[ud_name]\n    return CANONICAL_NAMES.get(ud_name.lower(), ud_name)\n"
  },
  {
    "path": "stanza/models/common/stanza_object.py",
    "content": "def _readonly_setter(self, name):\n    full_classname = self.__class__.__module__\n    if full_classname is None:\n        full_classname = self.__class__.__qualname__\n    else:\n        full_classname += '.' + self.__class__.__qualname__\n    raise ValueError(f'Property \"{name}\" of \"{full_classname}\" is read-only.')\n\nclass StanzaObject(object):\n    \"\"\"\n    Base class for all Stanza data objects that allows for some flexibility handling annotations\n    \"\"\"\n\n    @classmethod\n    def add_property(cls, name, default=None, getter=None, setter=None):\n        \"\"\"\n        Add a property accessible through self.{name} with underlying variable self._{name}.\n        Optionally setup a setter as well.\n        \"\"\"\n\n        if hasattr(cls, name):\n            raise ValueError(f'Property by the name of {name} already exists in {cls}. Maybe you want to find another name?')\n\n        setattr(cls, f'_{name}', default)\n        if getter is None:\n            getter = lambda self: getattr(self, f'_{name}')\n        if setter is None:\n            setter = lambda self, value: _readonly_setter(self, name)\n\n        setattr(cls, name, property(getter, setter))\n\n"
  },
  {
    "path": "stanza/models/common/trainer.py",
    "content": "import torch\n\nclass Trainer:\n    def change_lr(self, new_lr):\n        for param_group in self.optimizer.param_groups:\n            param_group['lr'] = new_lr\n\n    def save(self, filename):\n        savedict = {\n                   'model': self.model.state_dict(),\n                   'optimizer': self.optimizer.state_dict()\n                   }\n        torch.save(savedict, filename)\n\n    def load(self, filename):\n        savedict = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n\n        self.model.load_state_dict(savedict['model'])\n        if self.args['mode'] == 'train':\n            self.optimizer.load_state_dict(savedict['optimizer'])\n"
  },
  {
    "path": "stanza/models/common/utils.py",
    "content": "\"\"\"\nUtility functions.\n\"\"\"\n\nimport argparse\nfrom collections import Counter\nfrom contextlib import contextmanager\nimport gzip\nimport json\nimport logging\nimport lzma\nimport os\nimport random\nimport re\nimport sys\nimport unicodedata\nimport zipfile\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\ntry:\n    from udtools import udeval\nexcept ImportError:\n    from udtools.src.udtools import udeval\n\ntry:\n    from udtools.udeval import UDError\nexcept ImportError:\n    from udtools.src.udtools.udeval import UDError\n\nfrom stanza.models.common.constant import lcode2lang\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.resources.default_packages import TRANSFORMER_NICKNAMES\n\nlogger = logging.getLogger('stanza')\n\n# filenames\ndef get_wordvec_file(wordvec_dir, shorthand, wordvec_type=None):\n    \"\"\" Lookup the name of the word vectors file, given a directory and the language shorthand.\n    \"\"\"\n    lcode, tcode = shorthand.split('_', 1)\n    lang = lcode2lang[lcode]\n    # locate language folder\n    word2vec_dir = os.path.join(wordvec_dir, 'word2vec', lang)\n    fasttext_dir = os.path.join(wordvec_dir, 'fasttext', lang)\n    lang_dir = None\n    if wordvec_type is not None:\n        lang_dir = os.path.join(wordvec_dir, wordvec_type, lang)\n        if not os.path.exists(lang_dir):\n            raise FileNotFoundError(\"Word vector type {} was specified, but directory {} does not exist\".format(wordvec_type, lang_dir))\n    elif os.path.exists(word2vec_dir): # first try word2vec\n        lang_dir = word2vec_dir\n    elif os.path.exists(fasttext_dir): # otherwise try fasttext\n        lang_dir = fasttext_dir\n    else:\n        raise FileNotFoundError(\"Cannot locate word vector directory for language: {}  Looked in {} and {}\".format(lang, word2vec_dir, fasttext_dir))\n    # look for wordvec filename in {lang_dir}\n    filename = os.path.join(lang_dir, '{}.vectors'.format(lcode))\n    if 
os.path.exists(filename + \".xz\"):\n        filename = filename + \".xz\"\n    elif os.path.exists(filename + \".txt\"):\n        filename = filename + \".txt\"\n    return filename\n\n@contextmanager\ndef output_stream(filename=None):\n    \"\"\"\n    Yields the given file if a file is given, or returns sys.stdout if filename is None\n\n    Opens the file in a context manager so it closes nicely\n    \"\"\"\n    if filename is None:\n        yield sys.stdout\n    else:\n        with open(filename, \"w\", encoding=\"utf-8\") as fout:\n            yield fout\n\n\n@contextmanager\ndef open_read_text(filename, encoding=\"utf-8\"):\n    \"\"\"\n    Opens a file as an .xz file or .gz if it ends with .xz or .gz, or regular text otherwise.\n\n    Use as a context\n\n    eg:\n    with open_read_text(filename) as fin:\n        do stuff\n\n    File will be closed once the context exits\n    \"\"\"\n    if filename.endswith(\".xz\"):\n        with lzma.open(filename, mode='rt', encoding=encoding) as fin:\n            yield fin\n    elif filename.endswith(\".gz\"):\n        with gzip.open(filename, mode='rt', encoding=encoding) as fin:\n            yield fin\n    else:\n        with open(filename, encoding=encoding) as fin:\n            yield fin\n\n@contextmanager\ndef open_read_binary(filename):\n    \"\"\"\n    Opens a file as an .xz file or .gz if it ends with .xz or .gz, or regular binary file otherwise.\n\n    If a .zip file is given, it can be read if there is a single file in there\n\n    Use as a context\n\n    eg:\n    with open_read_binary(filename) as fin:\n        do stuff\n\n    File will be closed once the context exits\n    \"\"\"\n    if filename.endswith(\".xz\"):\n        with lzma.open(filename, mode='rb') as fin:\n            yield fin\n    elif filename.endswith(\".gz\"):\n        with gzip.open(filename, mode='rb') as fin:\n            yield fin\n    elif filename.endswith(\".zip\"):\n        with zipfile.ZipFile(filename) as zin:\n            
input_names = zin.namelist()\n            if len(input_names) == 0:\n                raise ValueError(\"Empty zip archive\")\n            if len(input_names) > 1:\n                raise ValueError(\"zip file %s has more than one file in it\" % filename)\n            with zin.open(input_names[0]) as fin:\n                yield fin\n    else:\n        with open(filename, mode='rb') as fin:\n            yield fin\n\n# training schedule\ndef get_adaptive_eval_interval(cur_dev_size, thres_dev_size, base_interval):\n    \"\"\" Adjust the evaluation interval adaptively.\n    If cur_dev_size <= thres_dev_size, return base_interval;\n    else, linearly increase the interval (round to integer times of base interval).\n    \"\"\"\n    if cur_dev_size <= thres_dev_size:\n        return base_interval\n    else:\n        alpha = round(cur_dev_size / thres_dev_size)\n        return base_interval * alpha\n\n# ud utils\ndef ud_scores(gold_conllu_file, system_conllu_file):\n    def has_readline(f):\n        return hasattr(f, 'readline') and callable(f.readline)\n\n    if has_readline(gold_conllu_file):\n        try:\n            gold_ud = udeval.load_conllu(gold_conllu_file, '', {})\n        except UDError as e:\n            raise UDError(\"Could not process gold UD file\") from e\n    else:\n        try:\n            gold_ud = udeval.load_conllu_file(gold_conllu_file)\n        except UDError as e:\n            raise UDError(\"Could not read %s\" % gold_conllu_file) from e\n\n    if has_readline(system_conllu_file):\n        try:\n            system_ud = udeval.load_conllu(system_conllu_file, '', {})\n        except UDError as e:\n            raise UDError(\"Could not process system UD file\") from e\n    else:\n        try:\n            system_ud = udeval.load_conllu_file(system_conllu_file)\n        except UDError as e:\n            raise UDError(\"Could not read %s\" % system_conllu_file) from e\n\n    evaluation = udeval.evaluate(gold_ud, system_ud)\n\n    return evaluation\n\ndef 
harmonic_mean(a, weights=None):\n    if any([x == 0 for x in a]):\n        return 0\n    else:\n        assert weights is None or len(weights) == len(a), 'Weights has length {} which is different from that of the array ({}).'.format(len(weights), len(a))\n        if weights is None:\n            return len(a) / sum([1/x for x in a])\n        else:\n            return sum(weights) / sum(w/x for x, w in zip(a, weights))\n\n# torch utils\ndef dispatch_optimizer(name, parameters, opt_logger, lr=None, betas=None, eps=None, momentum=None, **extra_args):\n    extra_logging = \"\"\n    if len(extra_args) > 0:\n        extra_logging = \", \" + \", \".join(\"%s=%s\" % (x, y) for x, y in extra_args.items())\n\n    if name == 'amsgrad':\n        opt_logger.debug(\"Building Adam w/ amsgrad with lr=%f, betas=%s, eps=%f%s\", lr, betas, eps, extra_logging)\n        return torch.optim.Adam(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)\n    elif name == 'amsgradw':\n        opt_logger.debug(\"Building AdamW w/ amsgrad with lr=%f, betas=%s, eps=%f%s\", lr, betas, eps, extra_logging)\n        return torch.optim.AdamW(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)\n    elif name == 'sgd':\n        opt_logger.debug(\"Building SGD with lr=%f, momentum=%f%s\", lr, momentum, extra_logging)\n        return torch.optim.SGD(parameters, lr=lr, momentum=momentum, **extra_args)\n    elif name == 'adagrad':\n        opt_logger.debug(\"Building Adagrad with lr=%f%s\", lr, extra_logging)\n        return torch.optim.Adagrad(parameters, lr=lr, **extra_args)\n    elif name == 'adam':\n        opt_logger.debug(\"Building Adam with lr=%f, betas=%s, eps=%f%s\", lr, betas, eps, extra_logging)\n        return torch.optim.Adam(parameters, lr=lr, betas=betas, eps=eps, **extra_args)\n    elif name == 'adamw':\n        opt_logger.debug(\"Building AdamW with lr=%f, betas=%s, eps=%f%s\", lr, betas, eps, extra_logging)\n        return torch.optim.AdamW(parameters, 
lr=lr, betas=betas, eps=eps, **extra_args)\n    elif name == 'adamax':\n        opt_logger.debug(\"Building Adamax%s\", extra_logging)\n        return torch.optim.Adamax(parameters, **extra_args) # use default lr\n    elif name == 'adadelta':\n        opt_logger.debug(\"Building Adadelta with lr=%f%s\", lr, extra_logging)\n        return torch.optim.Adadelta(parameters, lr=lr, **extra_args)\n    elif name == 'adabelief':\n        try:\n            from adabelief_pytorch import AdaBelief\n        except ModuleNotFoundError as e:\n            raise ModuleNotFoundError(\"Could not create adabelief optimizer.  Perhaps the adabelief-pytorch package is not installed\") from e\n        opt_logger.debug(\"Building AdaBelief with lr=%f, eps=%f%s\", lr, eps, extra_logging)\n        # TODO: add weight_decouple and rectify as extra args?\n        return AdaBelief(parameters, lr=lr, eps=eps, weight_decouple=True, rectify=True, **extra_args)\n    elif name == 'madgrad':\n        try:\n            import madgrad\n        except ModuleNotFoundError as e:\n            raise ModuleNotFoundError(\"Could not create madgrad optimizer.  Perhaps the madgrad package is not installed\") from e\n        opt_logger.debug(\"Building MADGRAD with lr=%f, momentum=%f%s\", lr, momentum, extra_logging)\n        return madgrad.MADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)\n    elif name == 'mirror_madgrad':\n        try:\n            import madgrad\n        except ModuleNotFoundError as e:\n            raise ModuleNotFoundError(\"Could not create mirror_madgrad optimizer.  
Perhaps the madgrad package is not installed\") from e\n        opt_logger.debug(\"Building MirrorMADGRAD with lr=%f, momentum=%f%s\", lr, momentum, extra_logging)\n        return madgrad.MirrorMADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)\n    elif name == 'rmsprop':\n        opt_logger.debug(\"Building RMSprop with lr=%f%s\", lr, extra_logging)\n        return torch.optim.RMSprop(parameters, lr=lr, **extra_args)\n    else:\n        raise ValueError(\"Unsupported optimizer: {}\".format(name))\n\n\ndef get_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None, opt_logger=None):\n    opt_logger = opt_logger if opt_logger is not None else logger\n    base_parameters = [p for n, p in model.named_parameters()\n                       if p.requires_grad and not n.startswith(\"bert_model.\")\n                       and not n.startswith(\"charmodel_forward.\") and not n.startswith(\"charmodel_backward.\")]\n    parameters = [{'param_group_name': 'base', 'params': base_parameters}]\n\n    charlm_parameters = [p for n, p in model.named_parameters()\n                         if p.requires_grad and (n.startswith(\"charmodel_forward.\") or n.startswith(\"charmodel_backward.\"))]\n    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:\n        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})\n\n    if not is_peft:\n        bert_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith(\"bert_model.\")]\n\n        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model\n        if len(bert_parameters) > 0 and bert_finetune_layers is not None:\n            num_layers = model.bert_model.config.num_hidden_layers\n            start_layer = num_layers - bert_finetune_layers\n            
bert_parameters = []\n            for layer_num in range(start_layer, num_layers):\n                bert_parameters.extend([param for name, param in model.named_parameters()\n                                        if param.requires_grad and name.startswith(\"bert_model.\") and \"layer.%d.\" % layer_num in name])\n\n        if len(bert_parameters) > 0 and bert_learning_rate > 0:\n            opt_logger.debug(\"Finetuning %d bert parameters with LR %s and WD %s\", len(bert_parameters), lr * bert_learning_rate, bert_weight_decay)\n            parameters.append({'param_group_name': 'bert', 'params': bert_parameters, 'lr': lr * bert_learning_rate})\n            if bert_weight_decay is not None:\n                parameters[-1]['weight_decay'] = bert_weight_decay\n    else:\n        # some optimizers seem to train some even with a learning rate of 0...\n        if bert_learning_rate > 0:\n            # because PEFT handles what to hand to an optimizer, we don't want to touch that\n            parameters.append({'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate})\n            if bert_weight_decay is not None:\n                parameters[-1]['weight_decay'] = bert_weight_decay\n\n    extra_args = {}\n    if weight_decay is not None:\n        extra_args[\"weight_decay\"] = weight_decay\n\n    return dispatch_optimizer(name, parameters, opt_logger=opt_logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)\n\ndef get_split_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None):\n    \"\"\"Same as `get_optimizer`, but splits the optimizer for Bert into a separate optimizer\"\"\"\n    base_parameters = [p for n, p in model.named_parameters()\n                       if p.requires_grad and not n.startswith(\"bert_model.\")\n                       and not 
n.startswith(\"charmodel_forward.\") and not n.startswith(\"charmodel_backward.\")]\n    parameters = [{'param_group_name': 'base', 'params': base_parameters}]\n\n    charlm_parameters = [p for n, p in model.named_parameters()\n                         if p.requires_grad and (n.startswith(\"charmodel_forward.\") or n.startswith(\"charmodel_backward.\"))]\n    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:\n        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})\n\n    bert_parameters = None\n    if not is_peft:\n        trainable_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith(\"bert_model.\")]\n\n        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model\n        if len(trainable_parameters) > 0 and bert_finetune_layers is not None:\n            num_layers = model.bert_model.config.num_hidden_layers\n            start_layer = num_layers - bert_finetune_layers\n            trainable_parameters = []\n            for layer_num in range(start_layer, num_layers):\n                trainable_parameters.extend([param for name, param in model.named_parameters()\n                                             if param.requires_grad and name.startswith(\"bert_model.\") and \"layer.%d.\" % layer_num in name])\n\n        if len(trainable_parameters) > 0:\n            bert_parameters = [{'param_group_name': 'bert', 'params': trainable_parameters, 'lr': lr * bert_learning_rate}]\n    else:\n        # because PEFT handles what to hand to an optimizer, we don't want to touch that\n        bert_parameters = [{'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate}]\n\n    extra_args = {}\n    if weight_decay is not None:\n        extra_args[\"weight_decay\"] = weight_decay\n\n    optimizers = {\n        \"general_optimizer\": dispatch_optimizer(name, parameters, opt_logger=logger, lr=lr, 
betas=betas, eps=eps, momentum=momentum, **extra_args)\n    }\n    if bert_parameters is not None and bert_learning_rate > 0.0:\n        if bert_weight_decay is not None:\n            extra_args['weight_decay'] = bert_weight_decay\n        optimizers[\"bert_optimizer\"] = dispatch_optimizer(name, bert_parameters, opt_logger=logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)\n    return optimizers\n\n\ndef change_lr(optimizer, new_lr):\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = new_lr\n\ndef flatten_indices(seq_lens, width):\n    flat = []\n    for i, l in enumerate(seq_lens):\n        for j in range(l):\n            flat.append(i * width + j)\n    return flat\n\ndef keep_partial_grad(grad, topk):\n    \"\"\"\n    Keep only the topk rows of grads.\n    \"\"\"\n    assert topk < grad.size(0)\n    grad.data[topk:].zero_()\n    return grad\n\n# other utils\ndef ensure_dir(d, verbose=True):\n    if not os.path.exists(d):\n        if verbose:\n            logger.info(\"Directory {} does not exist; creating...\".format(d))\n        # exist_ok: guard against race conditions\n        os.makedirs(d, exist_ok=True)\n\ndef save_config(config, path, verbose=True):\n    with open(path, 'w') as outfile:\n        json.dump(config, outfile, indent=2)\n    if verbose:\n        print(\"Config saved to file {}\".format(path))\n    return config\n\ndef load_config(path, verbose=True):\n    with open(path) as f:\n        config = json.load(f)\n    if verbose:\n        print(\"Config loaded from file {}\".format(path))\n    return config\n\ndef print_config(config):\n    info = \"Running with the following configs:\\n\"\n    for k,v in config.items():\n        info += \"\\t{} : {}\\n\".format(k, str(v))\n    logger.info(\"\\n\" + info + \"\\n\")\n\ndef normalize_text(text):\n    return unicodedata.normalize('NFD', text)\n\ndef unmap_with_copy(indices, src_tokens, vocab):\n    \"\"\"\n    Unmap a list of list of indices, by optionally 
copying from src_tokens.\n    \"\"\"\n    result = []\n    for ind, tokens in zip(indices, src_tokens):\n        words = []\n        for idx in ind:\n            if idx >= 0:\n                words.append(vocab.id2word[idx])\n            else:\n                idx = -idx - 1 # flip and minus 1\n                words.append(tokens[idx])\n        result += [words]\n    return result\n\ndef prune_decoded_seqs(seqs):\n    \"\"\"\n    Prune decoded sequences after EOS token.\n    \"\"\"\n    out = []\n    for s in seqs:\n        if constant.EOS in s:\n            idx = s.index(constant.EOS)\n            out += [s[:idx]]\n        else:\n            out += [s]\n    return out\n\ndef prune_hyp(hyp):\n    \"\"\"\n    Prune a decoded hypothesis\n    \"\"\"\n    if constant.EOS_ID in hyp:\n        idx = hyp.index(constant.EOS_ID)\n        return hyp[:idx]\n    else:\n        return hyp\n\ndef prune(data_list, lens):\n    assert len(data_list) == len(lens)\n    nl = []\n    for d, l in zip(data_list, lens):\n        nl.append(d[:l])\n    return nl\n\ndef sort(packed, ref, reverse=True):\n    \"\"\"\n    Sort a series of packed list, according to a ref list.\n    Also return the original index before the sort.\n    \"\"\"\n    assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list)\n    packed = [ref] + [range(len(ref))] + list(packed)\n    sorted_packed = [list(t) for t in zip(*sorted(zip(*packed), reverse=reverse))]\n    return tuple(sorted_packed[1:])\n\ndef unsort(sorted_list, oidx):\n    \"\"\"\n    Unsort a sorted list, based on the original idx.\n    \"\"\"\n    assert len(sorted_list) == len(oidx), \"Number of list elements must match with original indices.\"\n    if len(sorted_list) == 0:\n        return []\n    _, unsorted = [list(t) for t in zip(*sorted(zip(oidx, sorted_list)))]\n    return unsorted\n\ndef sort_with_indices(data, key=None, reverse=False):\n    \"\"\"\n    Sort data and return both the data and the original 
indices.\n\n    One useful application is to sort by length, which can be done with key=len\n    Returns the data as a sorted list, then the indices of the original list.\n    \"\"\"\n    if not data:\n        return [], []\n    if key:\n        ordered = sorted(enumerate(data), key=lambda x: key(x[1]), reverse=reverse)\n    else:\n        ordered = sorted(enumerate(data), key=lambda x: x[1], reverse=reverse)\n\n    result = tuple(zip(*ordered))\n    return result[1], result[0]\n\ndef split_into_batches(data, batch_size):\n    \"\"\"\n    Returns a list of intervals so that each interval is either <= batch_size or one element long.\n\n    Long elements are not dropped from the intervals.\n    data is a list of lists\n    batch_size is how long to make each batch\n    return value is a list of pairs, start_idx end_idx\n    \"\"\"\n    intervals = []\n    interval_start = 0\n    interval_size = 0\n    for idx, line in enumerate(data):\n        if len(line) > batch_size:\n            # guess we'll just hope the model can handle a batch of this size after all\n            if interval_size > 0:\n                intervals.append((interval_start, idx))\n            intervals.append((idx, idx+1))\n            interval_start = idx+1\n            interval_size = 0\n        elif len(line) + interval_size > batch_size:\n            # this line puts us over batch_size\n            intervals.append((interval_start, idx))\n            interval_start = idx\n            interval_size = len(line)\n        else:\n            interval_size = interval_size + len(line)\n    if interval_size > 0:\n        # there's some leftover\n        intervals.append((interval_start, len(data)))\n    return intervals\n\ndef tensor_unsort(sorted_tensor, oidx):\n    \"\"\"\n    Unsort a sorted tensor on its 0-th dimension, based on the original idx.\n    \"\"\"\n    assert sorted_tensor.size(0) == len(oidx), \"Number of list elements must match with original indices.\"\n    backidx = [x[0] for x in 
sorted(enumerate(oidx), key=lambda x: x[1])]\n    return sorted_tensor[backidx]\n\n\ndef set_random_seed(seed):\n    \"\"\"\n    Set a random seed on all of the things which might need it.\n    torch, np, python random, and torch.cuda\n    \"\"\"\n    if seed is None:\n        seed = random.randint(0, 1000000000)\n\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    # some of these calls are probably redundant\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n    return seed\n\ndef find_missing_tags(known_tags, test_tags):\n    if isinstance(known_tags, list) and isinstance(known_tags[0], list):\n        known_tags = set(x for y in known_tags for x in y)\n    if isinstance(test_tags, list) and isinstance(test_tags[0], list):\n        test_tags = sorted(set(x for y in test_tags for x in y))\n    missing_tags = sorted(x for x in test_tags if x not in known_tags)\n    return missing_tags\n\ndef warn_missing_tags(known_tags, test_tags, test_set_name):\n    \"\"\"\n    Print a warning if any tags present in the second list are not in the first list.\n\n    Can also handle a list of lists.\n    \"\"\"\n    missing_tags = find_missing_tags(known_tags, test_tags)\n    if len(missing_tags) > 0:\n        logger.warning(\"Found tags in {} missing from the expected tag set: {}\".format(test_set_name, missing_tags))\n        return True\n    return False\n\ndef checkpoint_name(save_dir, save_name, checkpoint_name):\n    \"\"\"\n    Will return a recommended checkpoint name for the given dir, save_name, optional checkpoint_name\n\n    For example, can pass in args['save_dir'], args['save_name'], args['checkpoint_save_name']\n    \"\"\"\n    if checkpoint_name:\n        model_dir = os.path.split(checkpoint_name)[0]\n        if model_dir == save_dir:\n            return checkpoint_name\n        return os.path.join(save_dir, checkpoint_name)\n\n    model_dir 
= os.path.split(save_name)[0]\n    if model_dir != save_dir:\n        save_name = os.path.join(save_dir, save_name)\n    if save_name.endswith(\".pt\"):\n        return save_name[:-3] + \"_checkpoint.pt\"\n\n    return save_name + \"_checkpoint\"\n\ndef default_device():\n    \"\"\"\n    Pick a default device based on what's available on this system\n    \"\"\"\n    if torch.cuda.is_available():\n        return 'cuda'\n    return 'cpu'\n\ndef add_device_args(parser):\n    \"\"\"\n    Add args which specify cpu, cuda, or arbitrary device\n    \"\"\"\n    parser.add_argument('--device', type=str, default=default_device(), help='Which device to run on - use a torch device string name')\n    parser.add_argument('--cuda', dest='device', action='store_const', const='cuda', help='Run on CUDA')\n    parser.add_argument('--cpu', dest='device', action='store_const', const='cpu', help='Ignore CUDA and run on CPU')\n\ndef load_elmo(elmo_model):\n    # This import is here so that Elmo integration can be treated\n    # as an optional feature\n    import elmoformanylangs\n\n    logger.info(\"Loading elmo: %s\" % elmo_model)\n    elmo_model = elmoformanylangs.Embedder(elmo_model)\n    return elmo_model\n\ndef log_training_args(args, args_logger, name=\"training\"):\n    \"\"\"\n    For record keeping purposes, log the arguments when training\n    \"\"\"\n    if isinstance(args, argparse.Namespace):\n        args = vars(args)\n    keys = sorted(args.keys())\n    log_lines = ['%s: %s' % (k, args[k]) for k in keys]\n    args_logger.info('ARGS USED AT %s TIME:\\n%s\\n', name.upper(), '\\n'.join(log_lines))\n\ndef embedding_name(args):\n    \"\"\"\n    Return the generic name of the biggest embedding used by a model.\n\n    Used by POS and depparse, for example.\n\n    TODO: Probably will make the transformer names a bit more informative,\n    such as electra, roberta, etc.  
Maybe even phobert for VI, for example\n    \"\"\"\n    embedding = \"nocharlm\"\n    if args['wordvec_pretrain_file'] is None and args['wordvec_file'] is None:\n        embedding = \"nopretrain\"\n    if args.get('charlm', True) and (args['charlm_forward_file'] or args['charlm_backward_file']):\n        embedding = \"charlm\"\n    if args['bert_model']:\n        if args['bert_model'] in TRANSFORMER_NICKNAMES:\n            embedding = TRANSFORMER_NICKNAMES[args['bert_model']]\n        else:\n            embedding = \"transformer\"\n\n    return embedding\n\ndef standard_model_file_name(args, model_type, **kwargs):\n    \"\"\"\n    Returns a model file name based on some common args found in the various models.\n\n    The expectation is that the args will have something like\n\n      parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_parser.pt\", help=\"File name to save the model\")\n\n    Then the model shorthand, embedding type, and other args will be\n    turned into arguments in a format string\n    \"\"\"\n    embedding = embedding_name(args)\n\n    finetune = \"\"\n    transformer_lr = \"\"\n    if args.get(\"bert_finetune\", False):\n        finetune = \"finetuned\"\n        if \"bert_learning_rate\" in args:\n            transformer_lr = \"{}\".format(args[\"bert_learning_rate\"])\n\n    use_peft = \"nopeft\"\n    if args.get(\"bert_finetune\", False) and args.get(\"use_peft\", False):\n        use_peft = \"peft\"\n\n    bert_finetuning = \"\"\n    if args.get(\"bert_finetune\", False):\n        if args.get(\"use_peft\", False):\n            bert_finetuning = \"peft\"\n        else:\n            bert_finetuning = \"ft\"\n\n    seed = args.get('seed', None)\n    if seed is None:\n        seed = \"\"\n    else:\n        seed = str(seed)\n\n    format_args = {\n        \"batch_size\":      args['batch_size'],\n        \"bert_finetuning\": bert_finetuning,\n        \"embedding\":       embedding,\n        \"finetune\":        
finetune,\n        \"peft\":            use_peft,\n        \"seed\":            seed,\n        \"shorthand\":       args['shorthand'],\n        \"transformer_lr\":  transformer_lr,\n    }\n    format_args.update(**kwargs)\n    model_file = args['save_name'].format(**format_args)\n    model_file = re.sub(\"_+\", \"_\", model_file)\n\n    model_dir = os.path.split(model_file)[0]\n\n    if not os.path.exists(os.path.join(args['save_dir'], model_file)) and os.path.exists(model_file):\n        return model_file\n    if model_dir.startswith(args['save_dir']):\n        return model_file\n    return os.path.join(args['save_dir'], model_file)\n\ndef escape_misc_space(space):\n    spaces = []\n    for char in space:\n        if char == ' ':\n            spaces.append('\\\\s')\n        elif char == '\\t':\n            spaces.append('\\\\t')\n        elif char == '\\r':\n            spaces.append('\\\\r')\n        elif char == '\\n':\n            spaces.append('\\\\n')\n        elif char == '|':\n            spaces.append('\\\\p')\n        elif char == '\\\\':\n            spaces.append('\\\\\\\\')\n        elif char == ' ':\n            spaces.append('\\\\u00A0')\n        else:\n            spaces.append(char)\n    escaped_space = \"\".join(spaces)\n    return escaped_space\n\ndef unescape_misc_space(misc_space):\n    spaces = []\n    pos = 0\n    while pos < len(misc_space):\n        if misc_space[pos:pos+2] == '\\\\s':\n            spaces.append(' ')\n            pos += 2\n        elif misc_space[pos:pos+2] == '\\\\t':\n            spaces.append('\\t')\n            pos += 2\n        elif misc_space[pos:pos+2] == '\\\\r':\n            spaces.append('\\r')\n            pos += 2\n        elif misc_space[pos:pos+2] == '\\\\n':\n            spaces.append('\\n')\n            pos += 2\n        elif misc_space[pos:pos+2] == '\\\\p':\n            spaces.append('|')\n            pos += 2\n        elif misc_space[pos:pos+2] == '\\\\\\\\':\n            spaces.append('\\\\')\n           
 pos += 2\n        elif misc_space[pos:pos+6] == '\\\\u00A0':\n            spaces.append(' ')\n            pos += 6\n        else:\n            spaces.append(misc_space[pos])\n            pos += 1\n    unescaped_space = \"\".join(spaces)\n    return unescaped_space\n\ndef space_before_to_misc(space):\n    \"\"\"\n    Convert whitespace to SpacesBefore specifically for the start of a document.\n\n    In general, UD datasets do not have both SpacesAfter on a token and SpacesBefore on the next token.\n\n    The space(s) are only marked on one of the tokens.\n\n    Only at the very beginning of a document is it necessary to mark what spaces occurred before the actual text,\n    and the default assumption is that there is no space if there is no SpacesBefore annotation.\n    \"\"\"\n    if not space:\n        return \"\"\n    escaped_space = escape_misc_space(space)\n    return \"SpacesBefore=%s\" % escaped_space\n\ndef space_after_to_misc(space):\n    \"\"\"\n    Convert whitespace back to the escaped format - either SpaceAfter=No or SpacesAfter=...\n    \"\"\"\n    if not space:\n        return \"SpaceAfter=No\"\n    if space == \" \":\n        return \"\"\n    escaped_space = escape_misc_space(space)\n    return \"SpacesAfter=%s\" % escaped_space\n\ndef misc_to_space_before(misc):\n    \"\"\"\n    Find any SpacesBefore annotation in the MISC column and turn it into a space value\n    \"\"\"\n    if not misc:\n        return \"\"\n    pieces = misc.split(\"|\")\n    for piece in pieces:\n        if not piece.lower().startswith(\"spacesbefore=\"):\n            continue\n        misc_space = piece.split(\"=\", maxsplit=1)[1]\n        return unescape_misc_space(misc_space)\n    return \"\"\n\ndef misc_to_space_after(misc):\n    \"\"\"\n    Convert either SpaceAfter=No or the SpacesAfter annotation\n\n    see https://universaldependencies.org/misc.html#spacesafter\n\n    We compensate for some treebanks using SpaceAfter=\\n instead of SpacesAfter=\\n\n    On the way back, 
though, those annotations will be turned into SpacesAfter\n    \"\"\"\n    if not misc:\n        return \" \"\n    pieces = misc.split(\"|\")\n    if any(piece.lower() == \"spaceafter=no\" for piece in pieces):\n        return \"\"\n    if \"SpaceAfter=Yes\" in pieces:\n        # as of UD 2.11, the Cantonese treebank had this as a misc feature\n        return \" \"\n    if \"SpaceAfter=No~\" in pieces:\n        # as of UD 2.11, a weird typo in the Russian Taiga dataset\n        return \"\"\n    for piece in pieces:\n        if piece.startswith(\"SpaceAfter=\") or piece.startswith(\"SpacesAfter=\"):\n            misc_space = piece.split(\"=\", maxsplit=1)[1]\n            return unescape_misc_space(misc_space)\n    return \" \"\n\ndef log_norms(model):\n    lines = [\"NORMS FOR MODEL PARAMETERS\"]\n    pieces = []\n    for name, param in model.named_parameters():\n        if param.requires_grad:\n            pieces.append((name, \"%.6g\" % torch.norm(param).item(), \"%d\" % param.numel()))\n    name_len = max(len(x[0]) for x in pieces)\n    norm_len = max(len(x[1]) for x in pieces)\n    line_format = \"  %-\" + str(name_len) + \"s   %\" + str(norm_len) + \"s     %s\"\n    for line in pieces:\n        lines.append(line_format % line)\n    logger.info(\"\\n\".join(lines))\n\ndef attach_bert_model(model, bert_model, bert_tokenizer, use_peft, force_bert_saved):\n    if use_peft:\n        # we use a peft-specific pathway for saving peft weights\n        model.add_unsaved_module('bert_model', bert_model)\n        model.bert_model.train()\n    elif force_bert_saved:\n        model.bert_model = bert_model\n    elif bert_model is not None:\n        model.add_unsaved_module('bert_model', bert_model)\n        for _, parameter in bert_model.named_parameters():\n            parameter.requires_grad = False\n    else:\n        model.bert_model = None\n    model.add_unsaved_module('bert_tokenizer', bert_tokenizer)\n\ndef build_save_each_filename(base_filename):\n    \"\"\"\n    If 
the given name doesn't have %d in it, add _%04d before the file extension\n\n    This way, there's something to count how many models have been saved\n    \"\"\"\n    try:\n        base_filename % 1\n    except TypeError:\n        # so models.pt -> models_0001.pt, etc\n        pieces = os.path.splitext(base_filename)\n        base_filename = pieces[0] + \"_%04d\" + pieces[1]\n    return base_filename\n\n# the constituency parser went through a large suite of experiments to\n# optimize which nonlinearity to use\n#\n# this is on a VI dataset, VLSP_22, using 1/10th of the data as a dev set\n# (no released test set at the time of the experiment)\n# original non-Bert tagger, with 1 iteration each instead of averaged over 5\n# considering the number of experiments and the length of time they would take\n#\n# Gelu had the highest score, which tracks with other experiments run.\n# Note that publicly released models have typically used Relu\n# on account of the runtime speed improvement\n#\n# Anyway, a larger experiment of 5x models on gelu or relu, using the\n# Roberta POS tagger and a corpus of silver trees, resulted in 0.8270\n# for relu and 0.8248 for gelu.  
So it is not even clear that\n# switching to gelu would be an accuracy improvement.\n#\n# Gelu: 82.32\n# Relu: 82.14\n# Mish: 81.95\n# Relu6: 81.91\n# Silu: 81.90\n# ELU: 81.73\n# Hardswish: 81.67\n# Softsign: 81.63\n# Hardtanh: 81.44\n# Celu: 81.43\n# Selu: 81.17\n#   TODO: need to redo the prelu experiment with\n#         possibly different numbers of parameters\n#         and proper weight decay\n# Prelu: 80.95 (terminated early)\n# Softplus: 80.94\n# Logsigmoid: 80.91\n# Hardsigmoid: 79.03\n# RReLU: 77.00\n# Hardshrink: failed\n# Softshrink: failed\nNONLINEARITY = {\n    'none':       nn.Identity,\n    'celu':       nn.CELU,\n    'elu':        nn.ELU,\n    'gelu':       nn.GELU,\n    'glu':        nn.GLU,\n    'hardsigmoid':nn.Hardsigmoid,\n    'hardshrink': nn.Hardshrink,\n    'hardswish':  nn.Hardswish,\n    'hardtanh':   nn.Hardtanh,\n    'leaky_relu': nn.LeakyReLU,\n    'logsigmoid': nn.LogSigmoid,\n    'mish':       nn.Mish,\n    'prelu':      nn.PReLU,\n    'relu':       nn.ReLU,\n    'relu6':      nn.ReLU6,\n    'rrelu':      nn.RReLU,\n    'selu':       nn.SELU,\n    'silu':       nn.SiLU,\n    'softplus':   nn.Softplus,\n    'softshrink': nn.Softshrink,\n    'softsign':   nn.Softsign,\n    'tanhshrink': nn.Tanhshrink,\n    'tanh':       nn.Tanh,\n}\n\ndef build_nonlinearity(nonlinearity):\n    \"\"\"\n    Look up \"nonlinearity\" in a map from function name to function, build the appropriate layer.\n    \"\"\"\n    if nonlinearity is None:\n        return nn.Identity()\n    if nonlinearity in NONLINEARITY:\n        return NONLINEARITY[nonlinearity]()\n    raise ValueError('Chosen value of nonlinearity, \"%s\", not handled' % nonlinearity)\n\nDEFAULT_WORD_CUTOFF = 7\n\ndef update_word_cutoff(pt, word_cutoff):\n    \"\"\"\n    If a word cutoff option wasn't set, pick a word cutoff based on the size of the pretrain\n\n    Using a lower word cutoff for the smaller pretrains helps quite a bit on the Abkhaz tagger,\n    where all we have is a very small 
PT.\n\n    no WV:\n    ab_abnc dev\n      UPOS    XPOS  UFeats AllTags\n    89.06   62.53   75.21   61.53\n    ab_abnc test\n      UPOS    XPOS  UFeats AllTags\n    88.96   61.37   74.85   60.29\n\n    WV, cutoff 7\n    ab_abnc dev\n      UPOS    XPOS  UFeats AllTags\n    89.15   62.76   75.43   61.62\n    ab_abnc test\n      UPOS    XPOS  UFeats AllTags\n    89.64   61.56   75.31   60.88\n\n    WV, cutoff 0\n    ab_abnc\n      UPOS    XPOS  UFeats AllTags\n    90.02   64.81   76.75   64.13\n    ab_abnc\n      UPOS    XPOS  UFeats AllTags\n    90.19   63.95   76.62   63.59\n\n    The results are less compelling for depparse, though:\n\n    no WV\n    ab_abnc dev\n      UAS   LAS  CLAS  MLAS  BLEX\n    78.85 65.27 57.31 56.27 57.31\n    ab_abnc test\n      UAS   LAS  CLAS  MLAS  BLEX\n    78.11 64.22 57.45 56.90 57.45\n\n    WV with cutoff 7\n    ab_abnc dev\n      UAS   LAS  CLAS  MLAS  BLEX\n    79.49 65.41 57.15 56.38 57.15\n    ab_abnc test\n      UAS   LAS  CLAS  MLAS  BLEX\n    77.30 64.41 57.13 56.65 57.13\n\n    WV with cutoff 0\n    ab_abnc dev\n      UAS   LAS  CLAS  MLAS  BLEX\n    80.04 65.68 56.81 56.04 56.81\n    ab_abnc test\n      UAS   LAS  CLAS  MLAS  BLEX\n    77.66 64.86 57.28 57.00 57.28\n    \"\"\"\n    if word_cutoff is not None:\n        return word_cutoff\n\n    if pt is None:\n        logger.info('Using 0 as the word cutoff (no pretrain available)')\n        return 0\n\n    if len(pt) < 5000:\n        word_cutoff = 0\n    else:\n        word_cutoff = DEFAULT_WORD_CUTOFF\n    logger.info('Using %d as the word cutoff based on the size of the pretrain (%d)', word_cutoff, len(pt))\n    return word_cutoff\n\n\nQUESTION_RE = re.compile(\"^[?？︖﹖⁇][?？︖﹖⁇!！︕﹗‼]+$\")\nEXCLAM_RE = re.compile(\"^[!！︕﹗‼][?？︖﹖⁇!！︕﹗‼]+$\")\n\ndef simplify_punct(data):\n    \"\"\"\n    For the data formats used in the POS and depparse, replace long punct words with simpler forms\n\n    replace ?[?!]+ -> ?\n    replace ![?!]+ -> !\n    also, include other non-ascii ?!\n    
\"\"\"\n    for sent_idx in range(len(data)):\n        for tok_idx in range(len(data[sent_idx])):\n            data[sent_idx][tok_idx][0] = QUESTION_RE.sub(\"?\", data[sent_idx][tok_idx][0])\n            data[sent_idx][tok_idx][0] = EXCLAM_RE.sub(\"!\", data[sent_idx][tok_idx][0])\n    return data\n\n"
  },
  {
    "path": "stanza/models/common/vocab.py",
    "content": "from copy import copy\nfrom collections import Counter, OrderedDict\nfrom collections.abc import Iterable\nimport os\nimport pickle\n\nPAD = '<PAD>'\nPAD_ID = 0\nUNK = '<UNK>'\nUNK_ID = 1\nEMPTY = '<EMPTY>'\nEMPTY_ID = 2\nROOT = '<ROOT>'\nROOT_ID = 3\nVOCAB_PREFIX = [PAD, UNK, EMPTY, ROOT]\nVOCAB_PREFIX_SIZE = len(VOCAB_PREFIX)\n\nclass BaseVocab:\n    \"\"\" A base class for common vocabulary operations. Each subclass should at least \n    implement its own build_vocab() function.\"\"\"\n    def __init__(self, data=None, lang=\"\", idx=0, cutoff=0, lower=False):\n        self.data = data\n        self.lang = lang\n        self.idx = idx\n        self.cutoff = cutoff\n        self.lower = lower\n        if data is not None:\n            self.build_vocab()\n        self.state_attrs = ['lang', 'idx', 'cutoff', 'lower', '_unit2id', '_id2unit']\n\n    def build_vocab(self):\n        raise NotImplementedError(\"This BaseVocab does not have build_vocab implemented.  This method should create _id2unit and _unit2id\")\n\n    def state_dict(self):\n        \"\"\" Returns a dictionary containing all states that are necessary to recover\n        this vocab. Useful for serialization.\"\"\"\n        state = OrderedDict()\n        for attr in self.state_attrs:\n            if hasattr(self, attr):\n                state[attr] = getattr(self, attr)\n        return state\n\n    @classmethod\n    def load_state_dict(cls, state_dict):\n        \"\"\" Returns a new Vocab instance constructed from a state dict. 
\"\"\"\n        new = cls()\n        for attr, value in state_dict.items():\n            setattr(new, attr, value)\n        return new\n\n    def normalize_unit(self, unit):\n        # be sure to look in subclasses for other normalization being done\n        # especially PretrainWordVocab\n        if unit is None:\n            return unit\n        if self.lower:\n            return unit.lower()\n        return unit\n\n    def unit2id(self, unit):\n        unit = self.normalize_unit(unit)\n        if unit in self._unit2id:\n            return self._unit2id[unit]\n        else:\n            return self._unit2id[UNK]\n\n    def id2unit(self, id):\n        return self._id2unit[id]\n\n    def map(self, units):\n        return [self.unit2id(x) for x in units]\n\n    def unmap(self, ids):\n        return [self.id2unit(x) for x in ids]\n\n    def __str__(self):\n        lang_str = \"(%s)\" % self.lang if self.lang else \"\"\n        name = str(type(self)) + lang_str\n        return \"<%s: %s>\" % (name, self._id2unit)\n\n    def __len__(self):\n        return len(self._id2unit)\n\n    def __getitem__(self, key):\n        if isinstance(key, str):\n            return self.unit2id(key)\n        elif isinstance(key, int) or isinstance(key, list):\n            return self.id2unit(key)\n        else:\n            raise TypeError(\"Vocab key must be one of str, list, or int\")\n\n    def __contains__(self, key):\n        return self.normalize_unit(key) in self._unit2id\n\n    @property\n    def size(self):\n        return len(self)\n\nclass DeltaVocab(BaseVocab):\n    \"\"\"\n    A vocab that starts off with a BaseVocab, then possibly adds more tokens based on the text in the given data\n\n    Currently meant only for characters, such as built by MWT or Lemma\n\n    Expected data format is either a list of strings, or a list of list of strings\n    \"\"\"\n    def __init__(self, data, orig_vocab):\n        self.orig_vocab = orig_vocab\n        super().__init__(data=data, 
lang=orig_vocab.lang, idx=orig_vocab.idx, cutoff=orig_vocab.cutoff, lower=orig_vocab.lower)\n\n    def build_vocab(self):\n        if all(isinstance(word, str) for word in self.data):\n            allchars = \"\".join(self.data)\n        else:\n            allchars = \"\".join([word for sentence in self.data for word in sentence])\n\n        unk = [c for c in allchars if c not in self.orig_vocab._unit2id]\n        if len(unk) > 0:\n            unk = sorted(set(unk))\n            self._id2unit = self.orig_vocab._id2unit + unk\n            self._unit2id = dict(self.orig_vocab._unit2id)\n            for c in unk:\n                self._unit2id[c] = len(self._unit2id)\n        else:\n            self._id2unit = self.orig_vocab._id2unit\n            self._unit2id = self.orig_vocab._unit2id\n\nclass CompositeVocab(BaseVocab):\n    ''' Vocabulary class that handles parsing and printing composite values such as\n    compositional XPOS and universal morphological features (UFeats).\n\n    Two key options are `keyed` and `sep`. `sep` specifies the separator used between\n    different parts of the composite values, which is `|` for UFeats, for example.\n    If `keyed` is `True`, then the incoming value is treated similarly to UFeats, where\n    each part is a key/value pair separated by an equals sign (`=`). There is no inherent\n    order to the keys, and we sort them alphabetically for serialization and deserialization.\n    Whenever a part is absent, its internal value is a special `<EMPTY>` symbol that will\n    be treated accordingly when generating the output. 
If `keyed` is `False`, then the parts\n    are treated as positioned values, and `<EMPTY>` is used to pad parts at the end when the\n    incoming value is not long enough.'''\n\n    def __init__(self, data=None, lang=\"\", idx=0, sep=\"\", keyed=False):\n        self.sep = sep\n        self.keyed = keyed\n        super().__init__(data, lang, idx=idx)\n        self.state_attrs += ['sep', 'keyed']\n\n    def unit2parts(self, unit):\n        # unpack parts of a unit\n        if not self.sep:\n            parts = [x for x in unit]\n        else:\n            parts = unit.split(self.sep)\n        if self.keyed:\n            if len(parts) == 1 and parts[0] == '_':\n                return dict()\n            parts = [x.split('=') for x in parts]\n            if any(len(x) != 2 for x in parts):\n                raise ValueError('Received \"%s\" for a dictionary which is supposed to be keyed, eg the entries should all be of the form key=value and separated by %s' % (unit, self.sep))\n\n            # Just treat multi-valued properties values as one possible value\n            parts = dict(parts)\n        elif unit == '_':\n            parts = []\n        return parts\n\n    def unit2id(self, unit):\n        parts = self.unit2parts(unit)\n        if self.keyed:\n            # treat multi-valued properties as singletons\n            return [self._unit2id[k].get(parts[k], UNK_ID) if k in parts else EMPTY_ID for k in self._unit2id]\n        else:\n            return [self._unit2id[i].get(parts[i], UNK_ID) if i < len(parts) else EMPTY_ID for i in range(len(self._unit2id))]\n\n    def id2unit(self, id):\n        # special case: allow single ids for vocabs with length 1\n        if len(self._id2unit) == 1 and not isinstance(id, Iterable):\n            id = (id,)\n        items = []\n        for v, k in zip(id, self._id2unit.keys()):\n            if v == EMPTY_ID: continue\n            if self.keyed:\n                items.append(\"{}={}\".format(k, self._id2unit[k][v]))\n           
 else:\n                items.append(self._id2unit[k][v])\n        if self.sep is not None:\n            res = self.sep.join(items)\n            if res == \"\":\n                res = \"_\"\n            return res\n        else:\n            return items\n\n    def build_vocab(self):\n        allunits = [w[self.idx] for sent in self.data for w in sent]\n        if self.keyed:\n            self._id2unit = dict()\n\n            for u in allunits:\n                parts = self.unit2parts(u)\n                for key in parts:\n                    if key not in self._id2unit:\n                        self._id2unit[key] = copy(VOCAB_PREFIX)\n\n                    # treat multi-valued properties as singletons\n                    if parts[key] not in self._id2unit[key]:\n                        self._id2unit[key].append(parts[key])\n\n            # special handle for the case where upos/xpos/ufeats are always empty\n            if len(self._id2unit) == 0:\n                self._id2unit['_'] = copy(VOCAB_PREFIX) # use an arbitrary key\n\n        else:\n            self._id2unit = dict()\n\n            allparts = [self.unit2parts(u) for u in allunits]\n            maxlen = max([len(p) for p in allparts])\n\n            for parts in allparts:\n                for i, p in enumerate(parts):\n                    if i not in self._id2unit:\n                        self._id2unit[i] = copy(VOCAB_PREFIX)\n                    if i < len(parts) and p not in self._id2unit[i]:\n                        self._id2unit[i].append(p)\n\n            # special handle for the case where upos/xpos/ufeats are always empty\n            if len(self._id2unit) == 0:\n                self._id2unit[0] = copy(VOCAB_PREFIX) # use an arbitrary key\n\n        self._id2unit = OrderedDict([(k, self._id2unit[k]) for k in sorted(self._id2unit.keys())])\n        self._unit2id = {k: {w:i for i, w in enumerate(self._id2unit[k])} for k in self._id2unit}\n\n    def lens(self):\n        return [len(self._unit2id[k]) 
for k in self._unit2id]\n\n    def items(self, idx):\n        return self._id2unit[idx]\n\n    def __str__(self):\n        pieces = [\"[\" + \",\".join(x) + \"]\" for _, x in self._id2unit.items()]\n        rep = \"<{}:\\n {}>\".format(type(self), \"\\n \".join(pieces))\n        return rep\n\nclass BaseMultiVocab:\n    \"\"\" A convenient vocab container that can store multiple BaseVocab instances, and support \n    safe serialization of all instances via state dicts. Each subclass of this base class \n    should implement the load_state_dict() function to specify how a saved state dict \n    should be loaded back.\"\"\"\n    def __init__(self, vocab_dict=None):\n        self._vocabs = OrderedDict()\n        if vocab_dict is None:\n            return\n        # check all values provided must be a subclass of the Vocab base class\n        assert all([isinstance(v, BaseVocab) for v in vocab_dict.values()])\n        for k, v in vocab_dict.items():\n            self._vocabs[k] = v\n\n    def __setitem__(self, key, item):\n        self._vocabs[key] = item\n\n    def __getitem__(self, key):\n        return self._vocabs[key]\n\n    def __str__(self):\n        return \"<{}: [{}]>\".format(type(self), \", \".join(self._vocabs.keys()))\n\n    def __contains__(self, key):\n        return key in self._vocabs\n\n    def keys(self):\n        return self._vocabs.keys()\n\n    def state_dict(self):\n        \"\"\" Build a state dict by iteratively calling state_dict() of all vocabs. 
\"\"\"\n        state = OrderedDict()\n        for k, v in self._vocabs.items():\n            state[k] = v.state_dict()\n        return state\n\n    @classmethod\n    def load_state_dict(cls, state_dict):\n        \"\"\" Construct a MultiVocab by reading from a state dict.\"\"\"\n        raise NotImplementedError\n\n\n\nclass CharVocab(BaseVocab):\n    def build_vocab(self):\n        if isinstance(self.data[0][0], (list, tuple)): # general data from DataLoader\n            counter = Counter([c for sent in self.data for w in sent for c in w[self.idx]])\n            for k in list(counter.keys()):\n                if counter[k] < self.cutoff:\n                    del counter[k]\n        else: # special data from Char LM\n            counter = Counter([c for sent in self.data for c in sent])\n        self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: (counter[k], k), reverse=True))\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\n"
  },
  {
    "path": "stanza/models/constituency/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/constituency/base_model.py",
    "content": "\"\"\"\nThe BaseModel is passed to the transitions so that the transitions\ncan operate on a parsing state without knowing the exact\nrepresentation used in the model.\n\nFor example, a SimpleModel simply looks at the top of the various stacks in the state.\n\nA model with LSTM representations for the different transitions may\nattach the hidden and output states of the LSTM to the word /\nconstituent / transition stacks.\n\nReminder: the parsing state is a list of words to parse, the\ntransitions used to build a (possibly incomplete) parse, and the\nconstituent(s) built so far by those transitions.  Each of these\ncomponents are represented using stacks to improve the efficiency\nof operations such as \"combine the most recent 4 constituents\"\nor \"turn the next input word into a constituent\"\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom collections import defaultdict\nimport logging\n\nimport torch\n\nfrom stanza.models.common import utils\nfrom stanza.models.constituency import transition_sequence\nfrom stanza.models.constituency.parse_transitions import TransitionScheme, CloseConstituent\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.models.constituency.state import State\nfrom stanza.models.constituency.tree_stack import TreeStack\nfrom stanza.server.parser_eval import ParseResult, ScoredTree\n\n# default unary limit.  
some treebanks may have longer chains (CTB, for example)\nUNARY_LIMIT = 4\n\nlogger = logging.getLogger('stanza.constituency.trainer')\n\nclass BaseModel(ABC):\n    \"\"\"\n    This base class defines abstract methods for manipulating a State.\n\n    Applying transitions may change important metadata about a State\n    such as the vectors associated with LSTM hidden states, for example.\n\n    The constructor forwards all unused arguments to other classes in the\n    constructor sequence, so put this before other classes such as nn.Module\n    \"\"\"\n    def __init__(self, transition_scheme, unary_limit, reverse_sentence, root_labels, *args, **kwargs):\n        super().__init__(*args, **kwargs)  # forwards all unused arguments\n\n        self._transition_scheme = transition_scheme\n        self._unary_limit = unary_limit\n        self._reverse_sentence = reverse_sentence\n        self._root_labels = sorted(list(root_labels))\n\n        self._is_top_down = (self._transition_scheme is TransitionScheme.TOP_DOWN or\n                             self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY or\n                             self._transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND)\n\n    @abstractmethod\n    def initial_word_queues(self, tagged_word_lists):\n        \"\"\"\n        For each list of tagged words, builds a TreeStack of word nodes\n\n        The word lists should be backwards so that the first word is the last word put on the stack (LIFO)\n        \"\"\"\n\n    @abstractmethod\n    def initial_transitions(self):\n        \"\"\"\n        Builds an initial transition stack with whatever values need to go into first position\n        \"\"\"\n\n    @abstractmethod\n    def initial_constituents(self):\n        \"\"\"\n        Builds an initial constituent stack with whatever values need to go into first position\n        \"\"\"\n\n    @abstractmethod\n    def get_word(self, word_node):\n        \"\"\"\n        Get the word corresponding to 
this position in the word queue\n        \"\"\"\n\n    @abstractmethod\n    def transform_word_to_constituent(self, state):\n        \"\"\"\n        Transform the top node of word_queue to something that can be pushed onto the constituent stack\n        \"\"\"\n\n    @abstractmethod\n    def dummy_constituent(self, dummy):\n        \"\"\"\n        When using a dummy node as a sentinel, transform it to something usable by this model\n        \"\"\"\n\n    @abstractmethod\n    def build_constituents(self, labels, children_lists):\n        \"\"\"\n        Build multiple constituents at once.  This gives the opportunity for batching operations\n        \"\"\"\n\n    @abstractmethod\n    def push_constituents(self, constituent_stacks, constituents):\n        \"\"\"\n        Add multiple constituents to multiple constituent_stacks\n\n        Useful to factor this out in case batching will help\n        \"\"\"\n\n    @abstractmethod\n    def get_top_constituent(self, constituents):\n        \"\"\"\n        Get the first constituent from the constituent stack\n\n        For example, a model might want to remove embeddings and LSTM state vectors\n        \"\"\"\n\n    @abstractmethod\n    def push_transitions(self, transition_stacks, transitions):\n        \"\"\"\n        Add multiple transitions to multiple transition_stacks\n\n        Useful to factor this out in case batching will help\n        \"\"\"\n\n    @abstractmethod\n    def get_top_transition(self, transitions):\n        \"\"\"\n        Get the first transition from the transition stack\n\n        For example, a model might want to remove transition embeddings before returning the transition\n        \"\"\"\n\n    @property\n    def root_labels(self):\n        \"\"\"\n        Return ROOT labels for this model.  
Probably ROOT, TOP, or both\n\n        (Danish uses 's', though)\n        \"\"\"\n        return self._root_labels\n\n    def unary_limit(self):\n        \"\"\"\n        Limit on the number of consecutive unary transitions\n        \"\"\"\n        return self._unary_limit\n\n\n    def transition_scheme(self):\n        \"\"\"\n        Transition scheme used - see parse_transitions\n        \"\"\"\n        return self._transition_scheme\n\n    def has_unary_transitions(self):\n        \"\"\"\n        Whether or not this model uses unary transitions, based on transition_scheme\n        \"\"\"\n        return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY\n\n    @property\n    def is_top_down(self):\n        \"\"\"\n        Whether or not this model is TOP_DOWN\n        \"\"\"\n        return self._is_top_down\n\n    @property\n    def reverse_sentence(self):\n        \"\"\"\n        Whether or not this model is built to parse backwards\n        \"\"\"\n        return self._reverse_sentence\n\n    def predict(self, states, is_legal=True):\n        raise NotImplementedError(\"LSTMModel can predict, but SimpleModel cannot\")\n\n    def weighted_choice(self, states):\n        raise NotImplementedError(\"LSTMModel can weighted_choice, but SimpleModel cannot\")\n\n    def predict_gold(self, states, is_legal=True):\n        \"\"\"\n        For each State, return the next item in the gold_sequence\n        \"\"\"\n        transitions = [y.gold_sequence[y.num_transitions] for y in states]\n        if is_legal:\n            for trans, state in zip(transitions, states):\n                if not trans.is_legal(state, self):\n                    raise RuntimeError(\"Transition {}:{} was not legal in a transition sequence:\\nOriginal tree: {}\\nTransitions: {}\".format(state.num_transitions, trans, state.gold_tree, state.gold_sequence))\n        return None, transitions, None\n\n    def initial_state_from_preterminals(self, preterminal_lists, gold_trees, 
gold_sequences):\n        \"\"\"\n        what is passed in should be a list of list of preterminals\n        \"\"\"\n        word_queues = self.initial_word_queues(preterminal_lists)\n        # this is the bottom of the TreeStack and will be the same for each State\n        transitions = self.initial_transitions()\n        constituents = self.initial_constituents()\n        states = [State(sentence_length=len(wq)-2,   # -2 because it starts and ends with a sentinel\n                        num_opens=0,\n                        word_queue=wq,\n                        gold_tree=None,\n                        gold_sequence=None,\n                        transitions=transitions,\n                        constituents=constituents,\n                        word_position=0,\n                        score=0.0,\n                        broken=False)\n                  for idx, wq in enumerate(word_queues)]\n        if gold_trees:\n            states = [state._replace(gold_tree=gold_tree) for gold_tree, state in zip(gold_trees, states)]\n        if gold_sequences:\n            states = [state._replace(gold_sequence=gold_sequence) for gold_sequence, state in zip(gold_sequences, states)]\n        return states\n\n    def initial_state_from_words(self, word_lists):\n        preterminal_lists = [[Tree(tag, Tree(word)) for word, tag in words]\n                             for words in word_lists]\n        return self.initial_state_from_preterminals(preterminal_lists, gold_trees=None, gold_sequences=None)\n\n    def initial_state_from_gold_trees(self, trees, gold_sequences=None):\n        preterminal_lists = [[Tree(pt.label, Tree(pt.children[0].label))\n                              for pt in tree.yield_preterminals()]\n                             for tree in trees]\n        return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees, gold_sequences=gold_sequences)\n\n    def build_batch_from_trees(self, batch_size, data_iterator):\n        \"\"\"\n        
Read from the data_iterator batch_size trees and turn them into new parsing states\n        \"\"\"\n        state_batch = []\n        for _ in range(batch_size):\n            gold_tree = next(data_iterator, None)\n            if gold_tree is None:\n                break\n            state_batch.append(gold_tree)\n\n        if len(state_batch) > 0:\n            state_batch = self.initial_state_from_gold_trees(state_batch)\n        return state_batch\n\n    def build_batch_from_trees_with_gold_sequence(self, batch_size, data_iterator):\n        \"\"\"\n        Same as build_batch_from_trees, but use the model parameters to turn the trees into gold sequences and include the sequence\n        \"\"\"\n        state_batch = self.build_batch_from_trees(batch_size, data_iterator)\n        if len(state_batch) == 0:\n            return state_batch\n\n        gold_sequences = transition_sequence.build_treebank([state.gold_tree for state in state_batch], self.transition_scheme(), self.reverse_sentence)\n        state_batch = [state._replace(gold_sequence=sequence) for state, sequence in zip(state_batch, gold_sequences)]\n        return state_batch\n\n    def build_batch_from_tagged_words(self, batch_size, data_iterator):\n        \"\"\"\n        Read from the data_iterator batch_size tagged sentences and turn them into new parsing states\n\n        Expects a list of list of (word, tag)\n        \"\"\"\n        state_batch = []\n        for _ in range(batch_size):\n            sentence = next(data_iterator, None)\n            if sentence is None:\n                break\n            state_batch.append(sentence)\n\n        if len(state_batch) > 0:\n            state_batch = self.initial_state_from_words(state_batch)\n        return state_batch\n\n\n    def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):\n        \"\"\"\n        Repeat transitions to build a list of trees from the 
input batches.\n\n        The data_iterator should be anything which returns the data for a parse task via next()\n        build_batch_fn is a function that turns that data into State objects\n        This will be called to generate batches of size batch_size until the data is exhausted\n\n        The return is a list of tuples: (gold_tree, [(predicted, score) ...])\n        gold_tree will be left blank if the data did not include gold trees\n        if keep_scores is true, the score will be the sum of the values\n          returned by the model for each transition\n\n        transition_choice: which method of the model to use for choosing the next transition\n          predict for predicting the transition based on the model\n          predict_gold to just extract the gold transition from the sequence\n        \"\"\"\n        treebank = []\n        treebank_indices = []\n        state_batch = build_batch_fn(batch_size, data_iterator)\n        # used to track which indices we are currently parsing\n        # since the parses get finished at different times, this will let us unsort after\n        batch_indices = list(range(len(state_batch)))\n        horizon_iterator = iter([])\n\n        if keep_constituents:\n            constituents = defaultdict(list)\n\n        while len(state_batch) > 0:\n            pred_scores, transitions, scores = transition_choice(state_batch)\n            if keep_scores and scores is not None:\n                state_batch = [state._replace(score=state.score + score) for state, score in zip(state_batch, scores)]\n            state_batch = self.bulk_apply(state_batch, transitions)\n\n            if keep_constituents:\n                for t_idx, transition in enumerate(transitions):\n                    if isinstance(transition, CloseConstituent):\n                        # constituents is a TreeStack with information on how to build the next state of the LSTM or attn\n                        # constituents.value is the TreeStack node\n     
                   # constituents.value.value is the Constituent itself (with the tree and the embedding)\n                        constituents[batch_indices[t_idx]].append(state_batch[t_idx].constituents.value.value)\n\n            remove = set()\n            for idx, state in enumerate(state_batch):\n                if state.broken:\n                    # TODO: make a fake tree with the appropriate words at least?\n                    # something like the X-tree CoreNLP does\n                    #gold_tree = state.gold_tree\n                    #treebank.append(ParseResult(gold_tree, [], state if keep_state else None, constituents[batch_indices[idx]] if keep_constituents else None))\n                    #treebank_indices.append(batch_indices[idx])\n                    remove.add(idx)\n                elif state.finished(self):\n                    predicted_tree = state.get_tree(self)\n                    if self.reverse_sentence:\n                        predicted_tree = predicted_tree.reverse()\n                    gold_tree = state.gold_tree\n                    treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, state.score)], state if keep_state else None, constituents[batch_indices[idx]] if keep_constituents else None))\n                    treebank_indices.append(batch_indices[idx])\n                    remove.add(idx)\n\n            if len(remove) > 0:\n                state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]\n                batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]\n\n            for _ in range(batch_size - len(state_batch)):\n                horizon_state = next(horizon_iterator, None)\n                if not horizon_state:\n                    horizon_batch = build_batch_fn(batch_size, data_iterator)\n                    if len(horizon_batch) == 0:\n                        break\n                    horizon_iterator = iter(horizon_batch)\n   
                 horizon_state = next(horizon_iterator, None)\n\n                state_batch.append(horizon_state)\n                batch_indices.append(len(treebank) + len(state_batch))\n\n        treebank = utils.unsort(treebank, treebank_indices)\n        return treebank\n\n    def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):\n        \"\"\"\n        Given an iterator over the data and a method for building batches, returns a list of parse trees.\n\n        no_grad() is so that gradients aren't kept, which makes the model\n        run faster and use less memory at inference time\n        \"\"\"\n        with torch.no_grad():\n            return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)\n\n    def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituents=True, keep_scores=True):\n        \"\"\"\n        Return a ParseResult for each tree in the trees list\n\n        The transitions run will be the transitions represented by the tree\n        The output layers will be available in result.state for each result\n\n        keep_state=True as a default here as a method which keeps the grad\n        is likely to want to keep the resulting state as well\n        \"\"\"\n        if batch_size is None:\n            # TODO: refactor?\n            batch_size = self.args['eval_batch_size']\n        tree_iterator = iter(trees)\n        treebank = self.parse_sentences(tree_iterator, self.build_batch_from_trees_with_gold_sequence, batch_size, self.predict_gold, keep_state, keep_constituents, keep_scores=keep_scores)\n        return treebank\n\n    def parse_tagged_words(self, words, batch_size):\n        \"\"\"\n        This parses tagged words and returns a list of trees.\n\n        `parse_tagged_words` is useful at Pipeline time -\n          it takes words & tags 
and processes that into trees.\n\n        The tagged words should be represented:\n          one list per sentence\n            each sentence is a list of (word, tag)\n        The return value is a list of ParseTree objects\n        \"\"\"\n        logger.debug(\"Processing %d sentences\", len(words))\n        self.eval()\n\n        sentence_iterator = iter(words)\n        treebank = self.parse_sentences_no_grad(sentence_iterator, self.build_batch_from_tagged_words, batch_size, self.predict, keep_state=False, keep_constituents=False)\n\n        results = [t.predictions[0].tree for t in treebank]\n        return results\n\n    def bulk_apply(self, state_batch, transitions, fail=False):\n        \"\"\"\n        Apply the given list of Transitions to the given list of States, using the model as a reference\n\n        model: SimpleModel, LSTMModel, or any other form of model\n        state_batch: list of States\n        transitions: list of transitions, one per state\n        fail: throw an exception on a failed transition, as opposed to skipping the tree\n        \"\"\"\n        word_positions = []\n        constituents = []\n        new_constituents = []\n        valid_state_indices = []\n        callbacks = defaultdict(list)\n\n        state_batch = list(state_batch)\n        for idx, (state, transition) in enumerate(zip(state_batch, transitions)):\n            if not transition:\n                error = \"Got stuck and couldn't find a legal transition on the following gold tree:\\n{}\\n\\nFinal state:\\n{}\".format(state.gold_tree, state.to_string(self))\n                if fail:\n                    raise ValueError(error)\n                else:\n                    logger.error(error)\n                    state_batch[idx] = state._replace(broken=True)\n                    continue\n\n            if state.num_transitions >= len(state.word_queue) * 20:\n                # too many transitions\n                # x20 is somewhat empirically chosen based on certain\n   
             # treebanks having deep unary structures, especially early\n                # on when the model is fumbling around\n                if state.gold_tree:\n                    error = \"Went infinite on the following gold tree:\\n{}\\n\\nFinal state:\\n{}\".format(state.gold_tree, state.to_string(self))\n                else:\n                    error = \"Went infinite!:\\nFinal state:\\n{}\".format(state.to_string(self))\n                if fail:\n                    raise ValueError(error)\n                else:\n                    logger.error(error)\n                    state_batch[idx] = state._replace(broken=True)\n                    continue\n\n            wq, c, nc, callback = transition.update_state(state, self)\n\n            word_positions.append(wq)\n            constituents.append(c)\n            new_constituents.append(nc)\n            valid_state_indices.append(idx)\n            if callback:\n                # not `idx` in case something was broken\n                callbacks[callback].append(len(new_constituents)-1)\n\n        for key, idxs in callbacks.items():\n            data = [new_constituents[x] for x in idxs]\n            callback_constituents = key.build_constituents(self, data)\n            for idx, constituent in zip(idxs, callback_constituents):\n                new_constituents[idx] = constituent\n\n        transitions = [trans for state, trans in zip(state_batch, transitions) if not state.broken]\n        if len(transitions) > 0:\n            new_transitions = self.push_transitions([state.transitions for state in state_batch if not state.broken], transitions)\n            new_constituents = self.push_constituents(constituents, new_constituents)\n        else:\n            new_transitions = []\n            new_constituents = []\n\n        for state, transition, word_position, transition_stack, constituents, state_idx in zip(state_batch, transitions, word_positions, new_transitions, new_constituents, valid_state_indices):\n   
         state_batch[state_idx] = state._replace(num_opens=state.num_opens + transition.delta_opens(),\n                                                    word_position=word_position,\n                                                    transitions=transition_stack,\n                                                    constituents=constituents)\n\n        return state_batch\n\nclass SimpleModel(BaseModel):\n    \"\"\"\n    This model allows pushing and popping with no extra data\n\n    This class is primarily used for testing various operations which\n    don't need the NN's weights\n\n    Also, for rebuilding trees from transitions when verifying the\n    transitions in situations where the NN state is not relevant,\n    as this class will be faster than using the NN\n    \"\"\"\n    def __init__(self, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, unary_limit=UNARY_LIMIT, reverse_sentence=False, root_labels=(\"ROOT\",)):\n        super().__init__(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse_sentence, root_labels=root_labels)\n\n    def initial_word_queues(self, tagged_word_lists):\n        word_queues = []\n        for tagged_words in tagged_word_lists:\n            word_queue =  [None]\n            word_queue += [tag_node for tag_node in tagged_words]\n            word_queue.append(None)\n            if self.reverse_sentence:\n                word_queue.reverse()\n            word_queues.append(word_queue)\n        return word_queues\n\n    def initial_transitions(self):\n        return TreeStack(value=None, parent=None, length=1)\n\n    def initial_constituents(self):\n        return TreeStack(value=None, parent=None, length=1)\n\n    def get_word(self, word_node):\n        return word_node\n\n    def transform_word_to_constituent(self, state):\n        return state.get_word(state.word_position)\n\n    def dummy_constituent(self, dummy):\n        return dummy\n\n    def build_constituents(self, labels, 
children_lists):\n        constituents = []\n        for label, children in zip(labels, children_lists):\n            if isinstance(label, str):\n                label = (label,)\n            for value in reversed(label):\n                children = Tree(label=value, children=children)\n            constituents.append(children)\n        return constituents\n\n    def push_constituents(self, constituent_stacks, constituents):\n        return [stack.push(constituent) for stack, constituent in zip(constituent_stacks, constituents)]\n\n    def get_top_constituent(self, constituents):\n        return constituents.value\n\n    def push_transitions(self, transition_stacks, transitions):\n        return [stack.push(transition) for stack, transition in zip(transition_stacks, transitions)]\n\n    def get_top_transition(self, transitions):\n        return transitions.value\n"
  },
  {
    "path": "stanza/models/constituency/base_trainer.py",
    "content": "from enum import Enum\nimport logging\nimport os\n\nimport torch\n\nfrom pickle import UnpicklingError\nimport warnings\n\nlogger = logging.getLogger('stanza')\n\nclass ModelType(Enum):\n    LSTM               = 1\n    ENSEMBLE           = 2\n\nclass BaseTrainer:\n    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):\n        self.model = model\n        self.optimizer = optimizer\n        self.scheduler = scheduler\n        # keeping track of the epochs trained will be useful\n        # for adjusting the learning scheme\n        self.epochs_trained = epochs_trained\n        self.batches_trained = batches_trained\n        self.best_f1 = best_f1\n        self.best_epoch = best_epoch\n        self.first_optimizer = first_optimizer\n\n    def save(self, filename, save_optimizer=True):\n        params = self.model.get_params()\n        checkpoint = {\n            'params': params,\n            'epochs_trained': self.epochs_trained,\n            'batches_trained': self.batches_trained,\n            'best_f1': self.best_f1,\n            'best_epoch': self.best_epoch,\n            'model_type': self.model_type.name,\n            'first_optimizer': self.first_optimizer,\n        }\n        checkpoint[\"bert_lora\"] = self.get_peft_params()\n        if save_optimizer and self.optimizer is not None:\n            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()\n            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()\n        torch.save(checkpoint, filename, _use_new_zipfile_serialization=False)\n        logger.info(\"Model saved to %s\", filename)\n\n    def log_norms(self):\n        self.model.log_norms()\n\n    def log_shapes(self):\n        self.model.log_shapes()\n\n    @property\n    def transitions(self):\n        return self.model.transitions\n\n    @property\n    def root_labels(self):\n        return 
self.model.root_labels\n\n    @property\n    def device(self):\n        return next(self.model.parameters()).device\n\n    def train(self):\n        return self.model.train()\n\n    def eval(self):\n        return self.model.eval()\n\n    # TODO: make ABC with methods such as model_from_params?\n    # TODO: if we save the type in the checkpoint, use that here to figure out which to load\n    @staticmethod\n    def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_name=None):\n        \"\"\"\n        Load back a model and possibly its optimizer.\n        \"\"\"\n        # hide the import here to avoid circular imports\n        from stanza.models.constituency.ensemble import EnsembleTrainer\n        from stanza.models.constituency.trainer import Trainer\n\n        if not os.path.exists(filename):\n            if args.get('save_dir', None) is None:\n                raise FileNotFoundError(\"Cannot find model in {} and args['save_dir'] is None\".format(filename))\n            elif os.path.exists(os.path.join(args['save_dir'], filename)):\n                filename = os.path.join(args['save_dir'], filename)\n            else:\n                raise FileNotFoundError(\"Cannot find model in {} or in {}\".format(filename, os.path.join(args['save_dir'], filename)))\n        try:\n            # TODO: currently cannot switch this to weights_only=True\n            # without in some way changing the model to save enums in\n            # a safe manner, probably by converting to int\n            try:\n                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n            except UnpicklingError as e:\n                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)\n                warnings.warn(\"The saved constituency parser has an old format using Enum, set, unsanitized Transitions, etc.  This version of Stanza can support reading both the new and the old formats.  
Future versions will only allow loading with weights_only=True.  Please resave the constituency parser using this version ASAP.\")\n        except BaseException:\n            logger.exception(\"Cannot load model from %s\", filename)\n            raise\n        logger.debug(\"Loaded model from %s\", filename)\n\n        params = checkpoint['params']\n\n        if 'model_type' not in checkpoint:\n            # old models will have this trait\n            # TODO: can remove this after 1.10\n            checkpoint['model_type'] = ModelType.LSTM\n        if isinstance(checkpoint['model_type'], str):\n            checkpoint['model_type'] = ModelType[checkpoint['model_type']]\n        if checkpoint['model_type'] == ModelType.LSTM:\n            clazz = Trainer\n        elif checkpoint['model_type'] == ModelType.ENSEMBLE:\n            clazz = EnsembleTrainer\n        else:\n            raise ValueError(\"Unexpected model type: %s\" % checkpoint['model_type'])\n        model = clazz.model_from_params(params, checkpoint.get('bert_lora', None), args, foundation_cache, peft_name)\n\n        epochs_trained = checkpoint['epochs_trained']\n        batches_trained = checkpoint.get('batches_trained', 0)\n        best_f1 = checkpoint['best_f1']\n        best_epoch = checkpoint['best_epoch']\n\n        if 'first_optimizer' not in checkpoint:\n            # this will only apply to old (LSTM) Trainers\n            # EnsembleTrainers will always have this value saved\n            # so here we can compensate by looking at the old training statistics...\n            # we use params['config'] here instead of model.args\n            # because the args might have a different training\n            # mechanism, but in order to reload the optimizer, we need\n            # to match the optimizer we build with the one that was\n            # used at training time\n            build_simple_adadelta = params['config']['multistage'] and epochs_trained < params['config']['epochs'] // 2\n            
checkpoint['first_optimizer'] = build_simple_adadelta\n        first_optimizer = checkpoint['first_optimizer']\n\n        if load_optimizer:\n            optimizer = clazz.load_optimizer(model, checkpoint, first_optimizer, filename)\n            scheduler = clazz.load_scheduler(model, optimizer, checkpoint, first_optimizer)\n        else:\n            optimizer = None\n            scheduler = None\n\n        if checkpoint['model_type'] == ModelType.LSTM:\n            logger.debug(\"-- MODEL CONFIG --\")\n            for k in model.args.keys():\n                logger.debug(\"  --%s: %s\", k, model.args[k])\n            return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)\n        elif checkpoint['model_type'] == ModelType.ENSEMBLE:\n            return EnsembleTrainer(ensemble=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)\n        else:\n            raise ValueError(\"Unexpected model type: %s\" % checkpoint['model_type'])\n\n"
  },
  {
    "path": "stanza/models/constituency/dynamic_oracle.py",
    "content": "from collections import namedtuple\n\nimport numpy as np\n\nfrom stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent\n\nRepairEnum = namedtuple(\"RepairEnum\", \"name value is_correct\")\n\ndef score_candidates_single_block(model, state, candidates, candidate_idx):\n    \"\"\"\n    score candidate fixed sequences by summing up the transition scores of the most important block\n\n    the candidate with the best summed score is chosen, and the candidate sequence is reconstructed from the blocks\n    \"\"\"\n    scores = []\n    # could bulkify this if we wanted\n    for candidate in candidates:\n        current_state = [state]\n        for block in candidate[1:candidate_idx]:\n            for transition in block:\n                current_state = model.bulk_apply(current_state, [transition])\n        score = 0.0\n        for transition in candidate[candidate_idx]:\n            predictions = model.forward(current_state)\n            t_idx = model.transition_map[transition]\n            score += predictions[0, t_idx].cpu().item()\n            current_state = model.bulk_apply(current_state, [transition])\n        scores.append(score)\n    best_idx = np.argmax(scores)\n    best_candidate = [x for block in candidates[best_idx] for x in block]\n    return scores, best_idx, best_candidate\n\ndef score_candidates(model, state, candidates):\n    \"\"\"\n    score candidate fixed sequences by summing up the transition scores of the most important block\n\n    the candidate with the best summed score is chosen, and the candidate sequence is reconstructed from the blocks\n\n    actually, using either this or just scoring a single block doesn't really help\n      eg, score_candidates_single_block(candidate_idx=2)\n    it still winds up being slightly better for accuracy to simply\n    revert to teacher forcing for ambiguous transition errors\n    \"\"\"\n    scores = []\n    # could bulkify this if we wanted\n    for 
candidate in candidates:\n        current_state = [state]\n        score = 0.0\n        for block in candidate[1:]:\n            for transition in block:\n                predictions = model.forward(current_state)\n                t_idx = model.transition_map[transition]\n                score += predictions[0, t_idx].cpu().item()\n                current_state = model.bulk_apply(current_state, [transition])\n        scores.append(score)\n    best_idx = np.argmax(scores)\n    best_candidate = [x for block in candidates[best_idx] for x in block]\n    return scores, best_idx, best_candidate\n\ndef advance_past_constituents(gold_sequence, cur_index):\n    \"\"\"\n    Advance cur_index through gold_sequence until we have seen 1 more Close than Open\n\n    The index returned is the index of the Close which occurred after all the stuff\n    \"\"\"\n    count = 0\n    while cur_index < len(gold_sequence):\n        if isinstance(gold_sequence[cur_index], OpenConstituent):\n            count = count + 1\n        elif isinstance(gold_sequence[cur_index], CloseConstituent):\n            count = count - 1\n            if count == -1: return cur_index\n        cur_index = cur_index + 1\n    return None\n\ndef find_previous_open(gold_sequence, cur_index):\n    \"\"\"\n    Go backwards from cur_index to find the open which opens the previous block of stuff.\n\n    Return None if it can't be found.\n    \"\"\"\n    count = 0\n    cur_index = cur_index - 1\n    while cur_index >= 0:\n        if isinstance(gold_sequence[cur_index], OpenConstituent):\n            count = count + 1\n            if count > 0:\n                return cur_index\n        elif isinstance(gold_sequence[cur_index], CloseConstituent):\n            count = count - 1\n        cur_index = cur_index - 1\n    return None\n\ndef find_in_order_constituent_end(gold_sequence, cur_index):\n    \"\"\"\n    Advance cur_index through gold_sequence until the next block has ended\n\n    This is different from 
advance_past_constituents in that it will\n    also return when there is a Shift when count == 0.  That way, we\n    return the first block of things we know attach to the left\n    \"\"\"\n    count = 0\n    saw_shift = False\n    while cur_index < len(gold_sequence):\n        if isinstance(gold_sequence[cur_index], OpenConstituent):\n            count = count + 1\n        elif isinstance(gold_sequence[cur_index], CloseConstituent):\n            count = count - 1\n            if count == -1: return cur_index\n        elif isinstance(gold_sequence[cur_index], Shift):\n            if saw_shift and count == 0:\n                return cur_index\n            else:\n                saw_shift = True\n        cur_index = cur_index + 1\n    return None\n\nclass DynamicOracle():\n    def __init__(self, root_labels, oracle_level, repair_types, additional_levels, deactivated_levels):\n        self.root_labels = root_labels\n        # default oracle_level will be the UNKNOWN repair type (which each oracle should have)\n        # transitions after that are experimental or ambiguous, not to be used by default\n        self.oracle_level = oracle_level if oracle_level is not None else repair_types.UNKNOWN.value\n        self.repair_types = repair_types\n        self.additional_levels = set()\n        if additional_levels:\n            self.additional_levels = set([repair_types[x.upper()] for x in additional_levels.split(\",\")])\n        self.deactivated_levels = set()\n        if deactivated_levels:\n            self.deactivated_levels = set([repair_types[x.upper()] for x in deactivated_levels.split(\",\")])\n\n    def fix_error(self, pred_transition, model, state):\n        \"\"\"\n        Return which error has been made, if any, along with an updated transition list\n\n        We assume the transition sequence builds a correct tree, meaning\n        that there will always be a CloseConstituent sometime after an\n        OpenConstituent, for example\n        \"\"\"\n        
gold_transition = state.gold_sequence[state.num_transitions]\n        if gold_transition == pred_transition:\n            return self.repair_types.CORRECT, None\n\n        for repair_type in self.repair_types:\n            if repair_type.fn is None:\n                continue\n            if self.oracle_level is not None and repair_type.value > self.oracle_level and repair_type not in self.additional_levels and not repair_type.debug:\n                continue\n            if repair_type in self.deactivated_levels:\n                continue\n            repair = repair_type.fn(gold_transition, pred_transition, state.gold_sequence, state.num_transitions, self.root_labels, model, state)\n            if repair is None:\n                continue\n\n            if isinstance(repair, tuple) and len(repair) == 2:\n                return repair\n\n            # TODO: could update all of the returns to be tuples of length 2\n            if repair is not None:\n                return repair_type, repair\n\n        return self.repair_types.UNKNOWN, None\n"
  },
  {
    "path": "stanza/models/constituency/ensemble.py",
    "content": "\"\"\"\nPrototype of ensembling N models together on the same dataset\n\nThe main inference method is to run the normal transition sequence,\nbut sum the scores for the N models and use that to choose the highest\nscoring transition\n\nExample of how to run it to build a silver dataset\n(or just parse a text file in general):\n\n# first, use this tool to build a saved ensemble\npython3 stanza/models/constituency/ensemble.py\n   saved_models/constituency/wsj_inorder_?.pt\n   --save_name saved_models/constituency/en_ensemble.pt\n\n# then use the ensemble directly as a model in constituency_parser.py\npython3 stanza/models/constituency_parser.py\n   --save_name saved_models/constituency/en_ensemble.pt\n   --mode parse_text\n   --tokenized_file /nlp/scr/horatio/en_silver/en_split_100\n   --predict_file /nlp/scr/horatio/en_silver/en_split_100.inorder.mrg\n   --retag_package en_combined_bert\n   --lang en\n\nthen, ideally, run a second time with a set of topdown models,\nthen take the trees which match from the files\n\"\"\"\n\n\nimport argparse\nimport copy\nimport logging\nimport os\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.common import utils\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.constituency.base_trainer import BaseTrainer, ModelType\nfrom stanza.models.constituency.state import MultiState\nfrom stanza.models.constituency.trainer import Trainer\nfrom stanza.models.constituency.utils import build_optimizer, build_scheduler\nfrom stanza.server.parser_eval import ParseResult, ScoredTree\n\nlogger = logging.getLogger('stanza.constituency.trainer')\n\nclass Ensemble(nn.Module):\n    def __init__(self, args, filenames=None, models=None, foundation_cache=None):\n        \"\"\"\n        Loads each model in filenames\n\n        If foundation_cache is None, we build one on our own,\n        as the expectation is the models will reuse modules\n        such as pretrain, charlm, bert\n        
\"\"\"\n        super().__init__()\n\n        self.args = args\n        if filenames:\n            if models:\n                raise ValueError(\"both filenames and models set when making the Ensemble\")\n\n            if foundation_cache is None:\n                foundation_cache = FoundationCache()\n\n            if isinstance(filenames, str):\n                filenames = [filenames]\n            logger.info(\"Models used for ensemble:\\n  %s\", \"\\n  \".join(filenames))\n            models = [Trainer.load(filename, args, load_optimizer=False, foundation_cache=foundation_cache).model for filename in filenames]\n        elif not models:\n            raise ValueError(\"filenames and models both not set!\")\n\n        self.models = nn.ModuleList(models)\n\n        for model_idx, model in enumerate(self.models):\n            if self.models[0].transition_scheme() != model.transition_scheme():\n                raise ValueError(\"Models {} and {} are incompatible.  {} vs {}\".format(filenames[0], filenames[model_idx], self.models[0].transition_scheme(), model.transition_scheme()))\n            if self.models[0].transitions != model.transitions:\n                raise ValueError(f\"Models {filenames[0]} and {filenames[model_idx]} are incompatible: different transitions\\n{filenames[0]}:\\n{self.models[0].transitions}\\n{filenames[model_idx]}:\\n{model.transitions}\")\n            if self.models[0].constituents != model.constituents:\n                raise ValueError(\"Models %s and %s are incompatible: different constituents\" % (filenames[0], filenames[model_idx]))\n            if self.models[0].root_labels != model.root_labels:\n                raise ValueError(\"Models %s and %s are incompatible: different root_labels\" % (filenames[0], filenames[model_idx]))\n            if self.models[0].uses_xpos() != model.uses_xpos():\n                raise ValueError(\"Models %s and %s are incompatible: different uses_xpos\" % (filenames[0], filenames[model_idx]))\n            
if self.models[0].reverse_sentence != model.reverse_sentence:\n                raise ValueError(\"Models %s and %s are incompatible: different reverse_sentence\" % (filenames[0], filenames[model_idx]))\n\n        self._reverse_sentence = self.models[0].reverse_sentence\n\n        # submodels are not trained (so far)\n        self.detach_submodels()\n\n        logger.debug(\"Number of models in the Ensemble: %d\", len(self.models))\n        self.register_parameter('weighted_sum', torch.nn.Parameter(torch.zeros(len(self.models), len(self.transitions), requires_grad=True)))\n\n    def detach_submodels(self):\n        # submodels are not trained (so far)\n        for model in self.models:\n            for _, parameter in model.named_parameters():\n                parameter.requires_grad = False\n\n    def train(self, mode=True):\n        super().train(mode)\n        if mode:\n            # peft has a weird interaction where it turns requires_grad back on\n            # even if it was previously off\n            self.detach_submodels()\n\n    @property\n    def transitions(self):\n        return self.models[0].transitions\n\n    @property\n    def root_labels(self):\n        return self.models[0].root_labels\n\n    @property\n    def device(self):\n        return next(self.parameters()).device\n\n    def unary_limit(self):\n        \"\"\"\n        Limit on the number of consecutive unary transitions\n        \"\"\"\n        return min(m.unary_limit() for m in self.models)\n\n    def transition_scheme(self):\n        return self.models[0].transition_scheme()\n\n    def has_unary_transitions(self):\n        return self.models[0].has_unary_transitions()\n\n    @property\n    def is_top_down(self):\n        return self.models[0].is_top_down\n\n    @property\n    def reverse_sentence(self):\n        return self._reverse_sentence\n\n    @property\n    def retag_method(self):\n        # TODO: make the method an enum\n        return self.models[0].args['retag_method']\n\n    
def uses_xpos(self):\n        return self.models[0].uses_xpos()\n\n    def get_top_constituent(self, constituents):\n        return self.models[0].get_top_constituent(constituents)\n\n    def get_top_transition(self, transitions):\n        return self.models[0].get_top_transition(transitions)\n\n    def log_norms(self):\n        lines = [\"NORMS FOR MODEL PARAMETERS\"]\n        for name, param in self.named_parameters():\n            if param.requires_grad and not name.startswith(\"models.\"):\n                zeros = torch.sum(param.abs() < 0.000001).item()\n                norm = \"%.6g\" % torch.norm(param).item()\n                lines.append(\"%s %s %d %d\" % (name, norm, zeros, param.nelement()))\n        for model_idx, model in enumerate(self.models):\n            sublines = model.get_norms()\n            if len(sublines) > 0:\n                lines.append(\"  ---- MODEL %d ----\" % model_idx)\n                lines.extend(sublines)\n        logger.info(\"\\n\".join(lines))\n\n    def log_shapes(self):\n        lines = [\"NORMS FOR MODEL PARAMETERS\"]\n        for name, param in self.named_parameters():\n            if param.requires_grad:\n                lines.append(\"{} {}\".format(name, param.shape))\n        logger.info(\"\\n\".join(lines))\n\n    def get_params(self):\n        model_state = self.state_dict()\n        # don't save the children in the base params\n        model_state = {k: v for k, v in model_state.items() if not k.startswith(\"models.\")}\n        return {\n            \"base_params\": model_state,\n            \"children_params\": [x.get_params() for x in self.models]\n        }\n\n    def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):\n        state_batch = [model.initial_state_from_preterminals(preterminal_lists, gold_trees, gold_sequences) for model in self.models]\n        state_batch = list(zip(*state_batch))\n        state_batch = [MultiState(states, gold_tree, gold_sequence, 0.0)\n         
              for states, gold_tree, gold_sequence in zip(state_batch, gold_trees, gold_sequences)]\n        return state_batch\n\n    def build_batch_from_tagged_words(self, batch_size, data_iterator):\n        \"\"\"\n        Read from the data_iterator batch_size tagged sentences and turn them into new parsing states\n\n        Expects a list of list of (word, tag)\n        \"\"\"\n        state_batch = []\n        for _ in range(batch_size):\n            sentence = next(data_iterator, None)\n            if sentence is None:\n                break\n            state_batch.append(sentence)\n\n        if len(state_batch) > 0:\n            state_batch = [model.initial_state_from_words(state_batch) for model in self.models]\n            state_batch = list(zip(*state_batch))\n            state_batch = [MultiState(states, None, None, 0.0) for states in state_batch]\n        return state_batch\n\n    def build_batch_from_trees(self, batch_size, data_iterator):\n        \"\"\"\n        Read from the data_iterator batch_size trees and turn them into N lists of parsing states\n        \"\"\"\n        state_batch = []\n        for _ in range(batch_size):\n            gold_tree = next(data_iterator, None)\n            if gold_tree is None:\n                break\n            state_batch.append(gold_tree)\n\n        if len(state_batch) > 0:\n            state_batch = [model.initial_state_from_gold_trees(state_batch) for model in self.models]\n            state_batch = list(zip(*state_batch))\n            state_batch = [MultiState(states, None, None, 0.0) for states in state_batch]\n        return state_batch\n\n    def predict(self, states, is_legal=True):\n        states = list(zip(*[x.states for x in states]))\n        predictions = [model.forward(state_batch) for model, state_batch in zip(self.models, states)]\n\n        # batch X num transitions X num models\n        predictions = torch.stack(predictions, dim=2)\n\n        flat_predictions = torch.einsum(\"BTM,MT->BT\", 
predictions, self.weighted_sum)\n        predictions = torch.sum(predictions, dim=2) + flat_predictions\n\n        model = self.models[0]\n\n        # TODO: possibly refactor with lstm_model.predict\n        pred_max = torch.argmax(predictions, dim=1)\n        scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)\n        pred_max = pred_max.detach().cpu()\n\n        pred_trans = [model.transitions[pred_max[idx]] for idx in range(len(states[0]))]\n        if is_legal:\n            for idx, (state, trans) in enumerate(zip(states[0], pred_trans)):\n                if not trans.is_legal(state, model):\n                    _, indices = predictions[idx, :].sort(descending=True)\n                    for index in indices:\n                        if model.transitions[index].is_legal(state, model):\n                            pred_trans[idx] = model.transitions[index]\n                            scores[idx] = predictions[idx, index]\n                            break\n                    else: # yeah, else on a for loop, deal with it\n                        pred_trans[idx] = None\n                        scores[idx] = None\n\n        return predictions, pred_trans, scores.squeeze(1)\n\n    def bulk_apply(self, state_batch, transitions, fail=False):\n        new_states = []\n\n        states = list(zip(*[x.states for x in state_batch]))\n        states = [x.bulk_apply(y, transitions, fail=fail) for x, y in zip(self.models, states)]\n        states = list(zip(*states))\n        state_batch = [x._replace(states=y) for x, y in zip(state_batch, states)]\n        return state_batch\n\n    def parse_tagged_words(self, words, batch_size):\n        \"\"\"\n        This parses tagged words and returns a list of trees.\n\n        `parse_tagged_words` is useful at Pipeline time -\n          it takes words & tags and processes that into trees.\n\n        The tagged words should be represented:\n          one list per sentence\n            each sentence is a list 
of (word, tag)\n        The return value is a list of ParseTree objects\n\n        TODO: this really ought to be refactored with base_model\n        \"\"\"\n        logger.debug(\"Processing %d sentences\", len(words))\n        self.eval()\n\n        sentence_iterator = iter(words)\n        treebank = self.parse_sentences_no_grad(sentence_iterator, self.build_batch_from_tagged_words, batch_size, self.predict, keep_state=False, keep_constituents=False)\n\n        results = [t.predictions[0].tree for t in treebank]\n        return results\n\n    def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):\n        \"\"\"\n        Repeat transitions to build a list of trees from the input batches.\n\n        The data_iterator should be anything which returns the data for a parse task via next()\n        build_batch_fn is a function that turns that data into State objects\n        This will be called to generate batches of size batch_size until the data is exhausted\n\n        The return is a list of tuples: (gold_tree, [(predicted, score) ...])\n        gold_tree will be left blank if the data did not include gold trees\n        currently score is always 1.0, but the interface may be expanded\n        to get a score from the result of the parsing\n\n        transition_choice: which method of the model to use for\n        choosing the next transition\n\n        TODO: refactor with base_model\n        \"\"\"\n        treebank = []\n        treebank_indices = []\n        # this will produce tuples of states\n        # batch size lists of num models tuples\n        state_batch = build_batch_fn(batch_size, data_iterator)\n        batch_indices = list(range(len(state_batch)))\n        horizon_iterator = iter([])\n\n        if keep_constituents:\n            constituents = defaultdict(list)\n\n        while len(state_batch) > 0:\n            pred_scores, transitions, scores = 
transition_choice(state_batch)\n            # num models lists of batch size states\n            state_batch = self.bulk_apply(state_batch, transitions)\n\n            remove = set()\n            for idx, states in enumerate(state_batch):\n                if states.finished(self):\n                    predicted_tree = states.get_tree(self)\n                    if self.reverse_sentence:\n                        predicted_tree = predicted_tree.reverse()\n                    gold_tree = states.gold_tree\n                    # TODO: could easily store the score here\n                    # not sure what it means to store the state,\n                    # since each model is tracking its own state\n                    treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, None)], None, None))\n                    treebank_indices.append(batch_indices[idx])\n                    remove.add(idx)\n\n            if len(remove) > 0:\n                state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]\n                batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]\n\n            for _ in range(batch_size - len(state_batch)):\n                horizon_state = next(horizon_iterator, None)\n                if not horizon_state:\n                    horizon_batch = build_batch_fn(batch_size, data_iterator)\n                    if len(horizon_batch) == 0:\n                        break\n                    horizon_iterator = iter(horizon_batch)\n                    horizon_state = next(horizon_iterator, None)\n\n                state_batch.append(horizon_state)\n                batch_indices.append(len(treebank) + len(state_batch))\n\n        treebank = utils.unsort(treebank, treebank_indices)\n        return treebank\n\n    def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):\n        
with torch.no_grad():\n            return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)\n\nclass EnsembleTrainer(BaseTrainer):\n    \"\"\"\n    Stores a list of constituency models, useful for combining their results into one stronger model\n    \"\"\"\n    def __init__(self, ensemble, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):\n        super().__init__(ensemble, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)\n\n    @staticmethod\n    def from_files(args, filenames, foundation_cache=None):\n        ensemble = Ensemble(args, filenames, foundation_cache=foundation_cache)\n        ensemble = ensemble.to(args.get('device', None))\n        return EnsembleTrainer(ensemble)\n\n    def get_peft_params(self):\n        params = []\n        for model in self.model.models:\n            if model.args.get('use_peft', False):\n                from peft import get_peft_model_state_dict\n                params.append(get_peft_model_state_dict(model.bert_model, adapter_name=model.peft_name))\n            else:\n                params.append(None)\n\n        return params\n\n    @property\n    def model_type(self):\n        return ModelType.ENSEMBLE\n\n    def log_num_words_known(self, words):\n        nwk = [m.num_words_known(words) for m in self.model.models]\n        if all(x == nwk[0] for x in nwk):\n            logger.info(\"Number of words in the training set known to each sub-model: %d out of %d\", nwk[0], len(words))\n        else:\n            logger.info(\"Number of words in the training set known to the sub-models:\\n  %s\" % \"\\n  \".join([\"%d/%d\" % (x, len(words)) for x in nwk]))\n\n    @staticmethod\n    def build_optimizer(args, model, first_optimizer):\n        def fake_named_parameters():\n            for n, p in model.named_parameters():\n                if not 
n.startswith(\"models.\"):\n                    yield n, p\n\n        # TODO: there has to be a cleaner way to do this, like maybe a \"keep\" callback\n        # TODO: if we finetune the underlying models, we will want a series of optimizers\n        # so that they can have a different learning rate from the ensemble's fields\n        fake_model = copy.copy(model)\n        fake_model.named_parameters = fake_named_parameters\n        optimizer = build_optimizer(args, fake_model, first_optimizer)\n        return optimizer\n\n    @staticmethod\n    def load_optimizer(model, checkpoint, first_optimizer, filename):\n        optimizer = EnsembleTrainer.build_optimizer(model.models[0].args, model, first_optimizer)\n        if checkpoint.get('optimizer_state_dict', None) is not None:\n            try:\n                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n            except ValueError as e:\n                raise ValueError(\"Failed to load optimizer from %s\" % filename) from e\n        else:\n            logger.info(\"Attempted to load optimizer to resume training, but optimizer not saved.  
Creating new optimizer\")\n        return optimizer\n\n    @staticmethod\n    def load_scheduler(model, optimizer, checkpoint, first_optimizer):\n        scheduler = build_scheduler(model.models[0].args, optimizer, first_optimizer=first_optimizer)\n        if 'scheduler_state_dict' in checkpoint:\n            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n        return scheduler\n\n    @staticmethod\n    def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):\n        # TODO: no need for the if/else once the models are rebuilt\n        children_params = params[\"children_params\"] if isinstance(params, dict) else params\n        base_params = params[\"base_params\"] if isinstance(params, dict) else {}\n\n        # TODO: fill in peft_name\n        if peft_params is None:\n            peft_params = [None] * len(children_params)\n        if peft_name is None:\n            peft_name = [None] * len(children_params)\n\n        if len(children_params) != len(peft_params):\n            raise ValueError(\"Model file had params length %d and peft params length %d\" % (len(children_params), len(peft_params)))\n        if len(children_params) != len(peft_name):\n            raise ValueError(\"Model file had params length %d and peft name length %d\" % (len(children_params), len(peft_name)))\n\n        models = [Trainer.model_from_params(model_param, peft_param, args, foundation_cache, peft_name=pname)\n                  for model_param, peft_param, pname in zip(children_params, peft_params, peft_name)]\n        ensemble = Ensemble(args, models=models)\n        ensemble.load_state_dict(base_params, strict=False)\n        ensemble = ensemble.to(args.get('device', None))\n        return ensemble\n\ndef parse_args(args=None):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, 
default=None, help=\"Exact path to use for backward charlm\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n\n    utils.add_device_args(parser)\n\n    parser.add_argument('--lang', default='en', help='Language to use')\n\n    parser.add_argument('models', type=str, nargs='+', default=None, help=\"Which model(s) to load\")\n\n    parser.add_argument('--save_name', type=str, default=None, required=True, help='Where to save the combined ensemble')\n\n    args = vars(parser.parse_args(args))\n\n    return args\n\ndef main(args=None):\n    args = parse_args(args)\n    foundation_cache = FoundationCache()\n\n    ensemble = EnsembleTrainer.from_files(args, args['models'], foundation_cache)\n    ensemble.save(args['save_name'], save_optimizer=False)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/constituency/error_analysis_in_order.py",
    "content": "\"\"\"\nA tool with an initial set of error analysis for in-order parsing.\n\nAnalyzes the first error created in the parser\n\nTODO: there are more errors to analyze, and see below for a case where attachment is misidentified as bracket\n\"\"\"\n\nfrom enum import Enum\n\nfrom stanza.models.constituency.dynamic_oracle import advance_past_constituents\nfrom stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize\nfrom stanza.models.constituency.transition_sequence import build_sequence\n\nclass FirstError(Enum):\n    NONE                        = 1\n    UNKNOWN                     = 2\n    WRONG_OPEN_LABEL_NO_CASCADE = 3\n    WRONG_OPEN_LABEL_CASCADE    = 4\n    WRONG_SUBTREE_NO_CASCADE    = 5\n    WRONG_SUBTREE_CASCADE       = 6\n    EXTRA_ATTACHMENT            = 7\n    MISSING_ATTACHMENT          = 8\n    EXTRA_BRACKET_NO_CASCADE    = 9\n    EXTRA_BRACKET_CASCADE       = 10\n    MISSING_BRACKET_NO_CASCADE  = 11\n    MISSING_BRACKET_CASCADE     = 12\n\ndef advance_past_unaries(sequence, idx):\n    while idx + 2 < len(sequence) and isinstance(sequence[idx+1], OpenConstituent) and isinstance(sequence[idx+2], CloseConstituent):\n        idx += 2\n    return idx\n\ndef check_attachment_error(gold_sequence, pred_sequence, idx, error_type):\n    # this will find the Close that closes the constituent that\n    # was just closed in the gold sequence\n    # hopefully we will have built the same constituent(s)\n    # that were built after the gold sequence closed\n    pred_close_idx = advance_past_constituents(pred_sequence, idx)\n    gold_close_idx = pred_close_idx + 1\n    #gold_close_idx = find_in_order_constituent_end(gold_sequence, idx+1) # +1 represents, start counting from the Shift\n    #pred_close_idx = find_in_order_constituent_end(pred_sequence, idx)\n    if gold_sequence[idx+1:gold_close_idx] != pred_sequence[idx:pred_close_idx]:\n        return FirstError.UNKNOWN\n   
 if (isinstance(gold_sequence[gold_close_idx], CloseConstituent) and\n        isinstance(pred_sequence[pred_close_idx], CloseConstituent) and\n        isinstance(pred_sequence[pred_close_idx+1], CloseConstituent)):\n        #print(gold_sequence)\n        #print(gold_close_idx)\n        #print(pred_sequence)\n        #print(pred_close_idx)\n        #print(\"{:P}\".format(gold_tree))\n        #print(\"{:P}\".format(pred_tree))\n        #print(\"=================\")\n        return error_type\n\n    return None\n\ndef analyze_tree(gold_tree, pred_tree):\n    if gold_tree == pred_tree:\n        return FirstError.NONE\n\n    gold_sequence = build_sequence(gold_tree, TransitionScheme.IN_ORDER)\n    pred_sequence = build_sequence(pred_tree, TransitionScheme.IN_ORDER)\n\n    for idx, (gold_trans, pred_trans) in enumerate(zip(gold_sequence, pred_sequence)):\n        if gold_trans != pred_trans:\n            break\n    else:\n        # guess only the tags were different?\n        return FirstError.NONE\n\n    if isinstance(gold_trans, CloseConstituent) and isinstance(pred_trans, Shift) and isinstance(gold_sequence[idx + 1], Shift):\n        # perhaps this is an attachment error\n        # we can see if the exact same sequence of moved constituent was built\n        error = check_attachment_error(gold_sequence, pred_sequence, idx, FirstError.EXTRA_ATTACHMENT)\n        if error is not None:\n            return error\n\n    if isinstance(pred_trans, CloseConstituent) and isinstance(gold_trans, Shift) and isinstance(pred_sequence[idx + 1], Shift):\n        # perhaps this is an attachment error\n        # we can see if the exact same sequence of moved constituent was built\n        error = check_attachment_error(pred_sequence, gold_sequence, idx, FirstError.MISSING_ATTACHMENT)\n        if error is not None:\n            return error\n\n    if isinstance(gold_trans, OpenConstituent) and isinstance(pred_trans, OpenConstituent):\n        gold_close_idx = 
advance_past_constituents(gold_sequence, idx+1)\n        gold_unary_idx = advance_past_unaries(gold_sequence, gold_close_idx)\n\n        pred_close_idx = advance_past_constituents(pred_sequence, idx+1)\n        pred_unary_idx = advance_past_unaries(pred_sequence, pred_close_idx)\n        if gold_sequence[idx+1:gold_close_idx] != pred_sequence[idx+1:pred_close_idx]:\n            # maybe the internal structure is the same?\n            # actually, if the number of shifts inside is the same,\n            # then the words shifted were the same,\n            # so the internal structure is different but the parser\n            # is getting back on track after closing\n            if (sum(isinstance(gt, Shift) for gt in gold_sequence[idx+1:gold_close_idx]) ==\n                sum(isinstance(pt, Shift) for pt in pred_sequence[idx+1:pred_close_idx])):\n                if gold_sequence[gold_unary_idx:] == pred_sequence[pred_unary_idx:]:\n                    return FirstError.WRONG_SUBTREE_NO_CASCADE\n                else:\n                    return FirstError.WRONG_SUBTREE_CASCADE\n            return FirstError.UNKNOWN\n        # at this point, everything is the same aside from the open being a different label\n        if gold_sequence[gold_unary_idx:] == pred_sequence[pred_unary_idx:]:\n            return FirstError.WRONG_OPEN_LABEL_NO_CASCADE\n        else:\n            return FirstError.WRONG_OPEN_LABEL_CASCADE\n\n    if isinstance(gold_trans, Shift) and isinstance(pred_trans, OpenConstituent):\n        # This could be a case of an extra bracket inserted into the tree\n        # We will search for the end of the new bracket, then check if\n        # all the children were properly constructed the way the gold sequence wanted to,\n        # aside from the extra bracket\n\n        # TODO: this is also capturing what are effectively attachment\n        # errors in the case of nested nodes (S over S) where a node\n        # at the start should have been connected to the below 
node\n        #   gold:\n        #  (ROOT\n        #    (S\n        #      (S\n        #        (`` ``)\n        #        (NP (PRP$ Our) (NN balance) (NNS sheets))\n        #        (VP\n        #          (VBP look)\n        #          (SBAR\n        #            (IN like)\n        #            (S\n        #              (NP (PRP they))\n        #              (VP\n        #                (VBD came)\n        #                (PP\n        #                  (IN from)\n        #                  (NP\n        #                    (NP (NNP Alice) (POS 's))\n        #                    (NN wonderland)))))))\n        #        (, ,)\n        #        ('' ''))\n        #      (NP (NNP Mr.) (NNP Fromstein))\n        #      (VP (VBD said))\n        #      (. .)))\n        #\n        #  pred:\n        #  (ROOT\n        #    (S\n        #      (`` ``)\n        #      (S\n        #        (NP (PRP$ Our) (NN balance) (NNS sheets))\n        #        (VP\n        #          (VBP look)\n        #          (SBAR\n        #            (IN like)\n        #            (S\n        #              (NP (PRP they))\n        #              (VP\n        #                (VBD came)\n        #                (PP\n        #                  (IN from)\n        #                  (NP\n        #                    (NP (NNP Alice) (POS 's))\n        #                    (NN wonderland))))))))\n        #      (, ,)\n        #      ('' '')\n        #      (NP (NNP Mr.) (NNP Fromstein))\n        #      (VP (VBD said))\n        #      (. 
.)))\n\n        pred_close_idx = advance_past_constituents(pred_sequence, idx+1)\n        pred_unary_idx = advance_past_unaries(pred_sequence, pred_close_idx + 1)\n        if gold_sequence[idx:pred_close_idx-1] == pred_sequence[idx+1:pred_close_idx]:\n            #print(gold_sequence)\n            #print(pred_sequence)\n            #print(idx, pred_close_idx)\n            #print(\"{:P}\".format(gold_tree))\n            #print(\"{:P}\".format(pred_tree))\n            #print(\"=================\")\n            gold_unary_idx = advance_past_unaries(gold_sequence, pred_close_idx - 1)\n            if pred_sequence[pred_unary_idx:] == gold_sequence[gold_unary_idx:]:\n                return FirstError.EXTRA_BRACKET_NO_CASCADE\n            else:\n                return FirstError.EXTRA_BRACKET_CASCADE\n\n    if isinstance(pred_trans, Shift) and isinstance(gold_trans, OpenConstituent):\n        # presumably this has attachment errors as well, similarly to EXTRA_BRACKET\n        gold_close_idx = advance_past_constituents(gold_sequence, idx+1)\n        gold_unary_idx = advance_past_unaries(gold_sequence, gold_close_idx + 1)\n        if pred_sequence[idx:gold_close_idx-1] == gold_sequence[idx+1:gold_close_idx]:\n            #print(gold_sequence)\n            #print(pred_sequence)\n            #print(idx, gold_close_idx)\n            #print(\"{:P}\".format(gold_tree))\n            #print(\"{:P}\".format(pred_tree))\n            #print(\"=================\")\n            pred_unary_idx = advance_past_unaries(pred_sequence, gold_close_idx - 1)\n            if pred_sequence[pred_unary_idx:] == gold_sequence[gold_unary_idx:]:\n                return FirstError.MISSING_BRACKET_NO_CASCADE\n            else:\n                return FirstError.MISSING_BRACKET_CASCADE\n\n    return FirstError.UNKNOWN\n"
  },
  {
    "path": "stanza/models/constituency/evaluate_treebanks.py",
    "content": "\"\"\"\nRead multiple treebanks, score the results.\n\nReports the k-best score if multiple predicted treebanks are given.\n\"\"\"\n\nimport argparse\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.server.parser_eval import EvaluateParser, ParseResult\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold')\n    parser.add_argument('gold', type=str, help='Which file to load as the gold trees')\n    parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions.  If more than one is given, the evaluation will be \"k-best\" with the first prediction treated as the canonical')\n    args = parser.parse_args()\n\n    print(\"Loading gold treebank: \" + args.gold)\n    gold = tree_reader.read_treebank(args.gold)\n    # args.pred is a list (nargs='+'), so join it before concatenating\n    print(\"Loading predicted treebanks: \" + \", \".join(args.pred))\n    pred = [tree_reader.read_treebank(x) for x in args.pred]\n\n    full_results = [ParseResult(parses[0], [*parses[1:]])\n                    for parses in zip(gold, *pred)]\n\n    if len(pred) <= 1:\n        kbest = None\n    else:\n        kbest = len(pred)\n\n    with EvaluateParser(kbest=kbest) as evaluator:\n        response = evaluator.process(full_results)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/constituency/in_order_compound_oracle.py",
    "content": "from enum import Enum\n\nfrom stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, DynamicOracle\nfrom stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, CompoundUnary, Finalize\n\ndef fix_missing_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    A CompoundUnary transition was missed after a Shift, but the sequence was continued correctly otherwise\n    \"\"\"\n    if not isinstance(gold_transition, CompoundUnary):\n        return None\n\n    if pred_transition != gold_sequence[gold_index + 1]:\n        return None\n    if isinstance(pred_transition, Finalize):\n        # this can happen if the entire tree is a single word\n        # but it can't be fixed if it means the parser missed the ROOT transition\n        return None\n\n    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:]\n\ndef fix_wrong_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CompoundUnary):\n        return None\n\n    if not isinstance(pred_transition, CompoundUnary):\n        return None\n\n    assert gold_transition != pred_transition\n\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:]\n\ndef fix_spurious_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if isinstance(gold_transition, CompoundUnary):\n        return None\n\n    if not isinstance(pred_transition, CompoundUnary):\n        return None\n\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:]\n\ndef fix_open_shift_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix a missed Open constituent where we predicted a Shift and the next 
transition was a Shift\n\n    In fact, the subsequent transition MUST be a Shift with this transition scheme\n    \"\"\"\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    #if not isinstance(gold_sequence[gold_index+1], Shift):\n    #    return None\n    assert isinstance(gold_sequence[gold_index+1], Shift)\n\n    # close_index represents the Close for the missing Open\n    close_index = advance_past_constituents(gold_sequence, gold_index+1)\n    assert close_index is not None\n    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + gold_sequence[close_index+1:]\n\ndef fix_open_open_two_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if gold_transition == pred_transition:\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)\n    if isinstance(gold_sequence[block_end], Shift):\n        # this is a multiple subtrees version of this error\n        # we are only skipping the two subtrees errors for now\n        return None\n\n    # no fix is possible, so we just return here\n    return RepairType.OPEN_OPEN_TWO_SUBTREES_ERROR, None\n\ndef fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three):\n    if gold_transition == pred_transition:\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)\n    if not isinstance(gold_sequence[block_end], Shift):\n        # this is a multiple subtrees version of this error\n        # we are 
only skipping the two subtrees errors for now\n        return None\n\n    next_block_end = find_in_order_constituent_end(gold_sequence, block_end+1)\n    if exactly_three and isinstance(gold_sequence[next_block_end], Shift):\n        # for exactly three subtrees,\n        # we can put back the missing open transition\n        # and now we have no recall error, only precision error\n        # for more than three, we separate that out as an ambiguous choice\n        return None\n    elif not exactly_three and isinstance(gold_sequence[next_block_end], CloseConstituent):\n        # this is ambiguous, but we can still try this fix\n        return None\n\n    # at this point, we build a new sequence with the origin constituent inserted\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]\n\n\ndef fix_open_open_three_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three=True)\n\ndef fix_open_open_many_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three=False)\n\ndef fix_open_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Find the closed bracket, reopen it\n\n    The Open we just missed must be forgotten - it cannot be reopened\n    \"\"\"\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    # find the appropriate Open so we can reopen it\n    open_idx = find_previous_open(gold_sequence, gold_index)\n    # actually, if the Close is legal, this 
can't happen\n    # but it might happen in a unit test which doesn't check legality\n    if open_idx is None:\n        return None\n\n    # also, since we are punting on the missed Open, we need to skip\n    # the Close which would have closed it\n    close_idx = advance_past_constituents(gold_sequence, gold_index+1)\n\n    return gold_sequence[:gold_index] + [pred_transition, gold_sequence[open_idx]] + gold_sequence[gold_index+1:close_idx] + gold_sequence[close_idx+1:]\n\ndef fix_shift_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Find the closed bracket, reopen it\n    \"\"\"\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    # don't do this at the start or immediately after opening\n    if gold_index == 0 or isinstance(gold_sequence[gold_index - 1], OpenConstituent):\n        return None\n\n    open_idx = find_previous_open(gold_sequence, gold_index)\n    assert open_idx is not None\n\n    return gold_sequence[:gold_index] + [pred_transition, gold_sequence[open_idx]] + gold_sequence[gold_index:]\n\ndef fix_shift_open_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    bracket_end = find_in_order_constituent_end(gold_sequence, gold_index)\n    assert bracket_end is not None\n    if isinstance(gold_sequence[bracket_end], Shift):\n        # this is an ambiguous error\n        # multiple possible places to end the wrong constituent\n        return None\n    assert isinstance(gold_sequence[bracket_end], CloseConstituent)\n\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]\n\ndef 
fix_close_shift_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n\n    if not isinstance(pred_transition, Shift):\n        return None\n    if not isinstance(gold_sequence[gold_index+1], Shift):\n        return None\n\n    bracket_end = find_in_order_constituent_end(gold_sequence, gold_index+1)\n    assert bracket_end is not None\n    if isinstance(gold_sequence[bracket_end], Shift):\n        # this is an ambiguous error\n        # multiple possible places to end the wrong constituent\n        return None\n    assert isinstance(gold_sequence[bracket_end], CloseConstituent)\n\n    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]\n\nclass RepairType(Enum):\n    \"\"\"\n    Keep track of which repair is used, if any, on an incorrect transition\n\n    Effects of different repair types:\n      no oracle:                0.9251  0.9226\n     +missing_unary:            0.9246  0.9214\n     +wrong_unary:              0.9236  0.9213\n     +spurious_unary:           0.9247  0.9229\n     +open_shift_error:         0.9258  0.9226\n     +open_open_two_subtrees:   0.9256  0.9215    # nothing changes with this one...\n     +open_open_three_subtrees: 0.9256  0.9226\n     +open_open_many_subtrees:  0.9257  0.9234\n     +shift_close:              0.9267  0.9250\n     +shift_open:               0.9273  0.9247\n     +close_shift:              0.9266  0.9229\n     +open_close:               0.9267  0.9256\n    \"\"\"\n    def __new__(cls, fn, correct=False, debug=False):\n        \"\"\"\n        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error\n        \"\"\"\n        value = len(cls.__members__)\n        obj = object.__new__(cls)\n        obj._value_ = value + 1\n        obj.fn = fn\n        obj.correct = correct\n        
obj.debug = debug\n        return obj\n\n    @property\n    def is_correct(self):\n        return self.correct\n\n    # The correct sequence went Shift - Unary - Stuff\n    # but the CompoundUnary was missed and Stuff predicted\n    # so now we just proceed as if nothing happened\n    # note that CompoundUnary happens immediately after a Shift\n    # complicated nodes are created with single Open transitions\n    MISSING_UNARY_ERROR                    = (fix_missing_unary_error,)\n\n    # Predicted a wrong CompoundUnary.  No way to fix this, so just keep going\n    WRONG_UNARY_ERROR                      = (fix_wrong_unary_error,)\n\n    # The correct sequence went Shift - Stuff\n    # but instead we predicted a CompoundUnary\n    # again, we just keep going\n    SPURIOUS_UNARY_ERROR                   = (fix_spurious_unary_error,)\n\n    # Were supposed to open a new constituent,\n    # but instead shifted an item onto the stack\n    #\n    # The missed Open cannot be recovered\n    #\n    # One could ask, is it possible to open a bigger constituent later,\n    # but if the constituent patterns go\n    #   X (good open) Y (missed open) Z\n    # when we eventually close Y and Z, because of the missed Open,\n    # it is guaranteed to capture X as well\n    # since it will grab constituents until one left of the previous Open before Y\n    #\n    # Therefore, in this case, we must simply forget about this Open (recall error)\n    OPEN_SHIFT_ERROR                       = (fix_open_shift_error,)\n\n    # With this transition scheme, it is not possible to fix the following pattern:\n    #   T1 O_x T2 C -> T1 O_y T2 C\n    # seeing as how there are no unary transitions\n    # so whatever precision & recall errors are caused by substituting O_x -> O_y\n    # (which could include multiple transitions)\n    # those errors are unfixable in any way\n    OPEN_OPEN_TWO_SUBTREES_ERROR           = (fix_open_open_two_subtrees_error,)\n\n    # With this transition scheme, a three 
subtree branch with a wrong Open\n    # has a non-ambiguous fix\n    #   T1 O_x T2 T3 C -> T1 O_y T2 T3 C\n    # this can become\n    #   T1 O_y T2 C O_x T3 C\n    # now there are precision errors from the incorrectly added transition(s),\n    # but the correctly replaced transitions are unambiguous\n    OPEN_OPEN_THREE_SUBTREES_ERROR         = (fix_open_open_three_subtrees_error,)\n\n    # We were supposed to shift a new item onto the stack,\n    # but instead we closed the previous constituent\n    # This causes a precision error, but we can avoid the recall error\n    # by immediately reopening the closed constituent.\n    SHIFT_CLOSE_ERROR                      = (fix_shift_close_error,)\n\n    # We opened a new constituent instead of shifting\n    # In the event that the next constituent ends with a close,\n    # rather than building another new constituent,\n    # then there is no ambiguity\n    SHIFT_OPEN_UNAMBIGUOUS_ERROR           = (fix_shift_open_unambiguous_error,)\n\n    # Suppose we were supposed to Close, then Shift\n    # but instead we just did a Shift\n    # Similar to shift_open_unambiguous, we now have an opened\n    # constituent which shouldn't be there\n    # We can scroll past the next constituent created to see\n    # if the outer constituents close at that point\n    # If so, we can close this constituent as well in an unambiguous manner\n    # TODO: analyze the case where we were supposed to Close, Open\n    # but instead did a Shift\n    CLOSE_SHIFT_UNAMBIGUOUS_ERROR          = (fix_close_shift_unambiguous_error,)\n\n    # Supposed to open a new constituent,\n    # instead closed an existing constituent\n    #\n    #  X (good open) Y (open -> close) Z\n    #\n    # the constituent that should contain Y, Z is unfortunately lost\n    # since now the stack has\n    #\n    #  XY ...\n    #\n    # furthermore, there is now a precision error for the extra XY\n    # constituent that should not exist\n    # however, what we can do to minimize 
further errors is\n    # to at least reopen the label between X and Y\n    OPEN_CLOSE_ERROR                       = (fix_open_close_error,)\n\n    # this is ambiguous, but we can still try the same fix as three_subtrees (see above)\n    OPEN_OPEN_MANY_SUBTREES_ERROR          = (fix_open_open_many_subtrees_error,)\n\n    CORRECT                                = (None, True)\n\n    UNKNOWN                                = None\n\n\nclass InOrderCompoundOracle(DynamicOracle):\n    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):\n        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)\n"
  },
  {
    "path": "stanza/models/constituency/in_order_oracle.py",
    "content": "from enum import Enum\n\nfrom stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, score_candidates, DynamicOracle, RepairEnum\nfrom stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent\n\ndef fix_wrong_open_root_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    If there is an open/open error specifically at the ROOT, close the wrong open and try again\n    \"\"\"\n    if gold_transition == pred_transition:\n        return None\n\n    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent) and gold_transition.top_label in root_labels:\n        return gold_sequence[:gold_index] + [pred_transition, CloseConstituent()] + gold_sequence[gold_index:]\n\n    return None\n\ndef fix_wrong_open_unary_chain(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix a wrong open/open in a unary chain by removing the skipped unary transitions\n\n    Only applies if the wrong pred transition is a transition found higher up in the unary chain\n    \"\"\"\n    # useful to have this check here in case the call is made independently in a unit test\n    if gold_transition == pred_transition:\n        return None\n\n    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):\n        cur_index = gold_index + 1  # This is now a Close if we are in this particular context\n        while cur_index + 1 < len(gold_sequence) and isinstance(gold_sequence[cur_index], CloseConstituent) and isinstance(gold_sequence[cur_index+1], OpenConstituent):\n            cur_index = cur_index + 1  # advance to the next Open\n            if gold_sequence[cur_index] == pred_transition:\n                return gold_sequence[:gold_index] + gold_sequence[cur_index:]\n            cur_index = 
cur_index + 1  # advance to the next Close\n\n    return None\n\ndef fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two):\n    if gold_transition == pred_transition:\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if isinstance(gold_sequence[gold_index+1], CloseConstituent):\n        # if Close, the gold was a unary\n        return None\n    assert not isinstance(gold_sequence[gold_index+1], OpenConstituent)\n    assert isinstance(gold_sequence[gold_index+1], Shift)\n\n    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)\n    assert block_end is not None\n\n    if more_than_two and isinstance(gold_sequence[block_end], CloseConstituent):\n        return None\n    if not more_than_two and isinstance(gold_sequence[block_end], Shift):\n        return None\n\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]\n\ndef fix_wrong_open_two_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two=False)\n\ndef fix_wrong_open_multiple_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two=True)\n\ndef advance_past_unaries(gold_sequence, cur_index):\n    while cur_index + 2 < len(gold_sequence) and isinstance(gold_sequence[cur_index], OpenConstituent) and isinstance(gold_sequence[cur_index+1], CloseConstituent):\n        cur_index += 2\n    return cur_index\n\ndef fix_wrong_open_stuff_unary(gold_transition, 
pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix a wrong open/open when there is an intervening constituent and then the guessed NT\n\n    This happens when the correct pattern is\n      stuff_1 NT_X stuff_2 close NT_Y ...\n    and instead of guessing the gold transition NT_X,\n    the prediction was NT_Y\n    \"\"\"\n    if gold_transition == pred_transition:\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n    # TODO: Here we could advance past unary transitions while\n    # watching for hitting pred_transition.  However, that is an open\n    # question... is it better to try to keep such an Open as part of\n    # the sequence, or is it better to skip them and attach the inner\n    # nodes to the upper level\n    stuff_start = gold_index + 1\n    if not isinstance(gold_sequence[stuff_start], Shift):\n        return None\n    stuff_end = advance_past_constituents(gold_sequence, stuff_start)\n    if stuff_end is None:\n        return None\n    # at this point, stuff_end points to the Close which occurred after stuff_2\n    # also, stuff_start points to the first transition which makes stuff_2, the Shift\n    cur_index = stuff_end + 1\n    while isinstance(gold_sequence[cur_index], OpenConstituent):\n        if gold_sequence[cur_index] == pred_transition:\n            return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index+1:]\n        # this was an OpenConstituent, but not the OpenConstituent we guessed\n        # maybe there's a unary transition which lets us try again\n        if cur_index + 2 < len(gold_sequence) and isinstance(gold_sequence[cur_index + 1], CloseConstituent):\n            cur_index = cur_index + 2\n        else:\n            break\n\n    # oh well, none of this worked\n    return None\n\ndef 
fix_wrong_open_general(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix a general wrong open/open transition by accepting the open and continuing\n\n    A couple other open/open patterns have already been carved out\n\n    TODO: negative checks for the previous patterns, in case we turn those off\n    \"\"\"\n    if gold_transition == pred_transition:\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n    # If the top is a ROOT, then replacing it with a non-ROOT creates an illegal\n    # transition sequence.  The ROOT case was already handled elsewhere anyway\n    if gold_transition.top_label in root_labels:\n        return None\n\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:]\n\ndef fix_missed_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix a missed unary which is followed by an otherwise correct transition\n\n    (also handles multiple missed unary transitions)\n    \"\"\"\n    if gold_transition == pred_transition:\n        return None\n\n    cur_index = gold_index\n    cur_index = advance_past_unaries(gold_sequence, cur_index)\n    if gold_sequence[cur_index] == pred_transition:\n        return gold_sequence[:gold_index] + gold_sequence[cur_index:]\n    return None\n\ndef fix_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix an Open replaced with a Shift\n\n    Suppose we were supposed to guess NT_X and instead did S\n\n    We derive the repair as follows.\n\n    For simplicity, assume the open is not a unary for now\n\n    Since we know an Open was legal, there must be stuff\n      stuff NT_X\n    Shift is also legal, so there must be other stuff and a previous Open\n      stuff_1 NT_Y stuff_2 
NT_X\n    After the NT_X which we missed, there was a bunch of stuff and a close for NT_X\n      stuff_1 NT_Y stuff_2 NT_X stuff_3 C\n    There could be more stuff here which can be saved...\n      stuff_1 NT_Y stuff_2 NT_X stuff_3 C stuff_4 C\n      stuff_1 NT_Y stuff_2 NT_X stuff_3 C C\n    \"\"\"\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    cur_index = gold_index\n    cur_index = advance_past_unaries(gold_sequence, cur_index)\n    if not isinstance(gold_sequence[cur_index], OpenConstituent):\n        return None\n    if gold_sequence[cur_index].top_label in root_labels:\n        return None\n    # cur_index now points to the NT_X we missed (not counting unaries)\n\n    stuff_start = cur_index + 1\n    # can't be a Close, since we just went past an Open and checked for unaries\n    # can't be an Open, since two Open in a row is illegal\n    assert isinstance(gold_sequence[stuff_start], Shift)\n    stuff_end = advance_past_constituents(gold_sequence, stuff_start)\n    # stuff_end is now the Close which ends NT_X\n    cur_index = stuff_end + 1\n    if cur_index >= len(gold_sequence):\n        return None\n    if isinstance(gold_sequence[cur_index], OpenConstituent):\n        cur_index = advance_past_unaries(gold_sequence, cur_index)\n        if cur_index >= len(gold_sequence):\n            return None\n    if isinstance(gold_sequence[cur_index], OpenConstituent):\n        # an Open here signifies that there was a bracket containing X underneath Y\n        # TODO: perhaps try to salvage something out of that situation?\n        return None\n    # the repair starts with the sequence up through the error,\n    # then stuff_3, which includes the error\n    # skip the Close for the missed NT_X\n    # then finish the sequence with any potential stuff_4, the next Close, and everything else\n    repair = gold_sequence[:gold_index] + 
gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]\n    return repair\n\ndef fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix an Open replaced with a Close\n\n    Call the Open NT_X\n    Open legal, so there must be stuff:\n      stuff NT_X\n    Close legal, so there must be something to close:\n      stuff_1 NT_Y stuff_2 NT_X\n\n    The incorrect close makes the following brackets:\n      (Y stuff_1 stuff_2)\n    We were supposed to build\n      (Y stuff_1 (X stuff_2 ...) (possibly more stuff))\n    The simplest fix here is to reopen Y at this point.\n\n    One issue might be if there is another bracket which encloses X underneath Y\n    So, for example, the tree was supposed to be\n      (Y stuff_1 (Z (X stuff_2 stuff_3) stuff_4))\n    The pattern for this case is\n      stuff_1 NT_Y stuff_2 NT_X stuff_3 close NT_Z stuff_4 close close\n    \"\"\"\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    cur_index = advance_past_unaries(gold_sequence, gold_index)\n    if cur_index >= len(gold_sequence):\n        return None\n    if not isinstance(gold_sequence[cur_index], OpenConstituent):\n        return None\n    if gold_sequence[cur_index].top_label in root_labels:\n        return None\n\n    prev_open_index = find_previous_open(gold_sequence, gold_index)\n    if prev_open_index is None:\n        return None\n    prev_open = gold_sequence[prev_open_index]\n    # prev_open is now NT_Y from above\n\n    stuff_start = cur_index + 1\n    assert isinstance(gold_sequence[stuff_start], Shift)\n    stuff_end = advance_past_constituents(gold_sequence, stuff_start)\n    # stuff_end is now the Close which ends NT_X\n    # stuff_start:stuff_end is the stuff_3 block above\n    cur_index = stuff_end + 1\n    if cur_index >= len(gold_sequence):\n        return None\n    # if there 
are unary transitions here, we want to skip those.\n    # those are unary transitions on X and cannot be recovered, since X is gone\n    cur_index = advance_past_unaries(gold_sequence, cur_index)\n    # now there is a certain failure case which has to be accounted for.\n\n    # specifically, if there is a new non-terminal which opens\n    # immediately after X closes, it is encompassing X in a way that\n    # cannot be recovered now that part of X is stuck under Y.\n    # The two choices at this point would be to eliminate the new\n    # transition or just reject the tree from the repair\n    # For now, we reject the tree\n    if isinstance(gold_sequence[cur_index], OpenConstituent):\n        return None\n\n    repair = gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]\n    return repair\n\ndef fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    This fixes Shift replaced with a Close transition.\n\n    This error occurs in the following pattern:\n      stuff_1 NT_X stuff... shift\n    Instead of shift, you close the NT_X\n    The easiest fix here is to just restore the NT_X.\n    \"\"\"\n\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    # this fix can also be applied if there were unaries on the\n    # previous constituent.  
we just skip those until the Shift\n    cur_index = gold_index\n    if isinstance(gold_transition, OpenConstituent):\n        cur_index = advance_past_unaries(gold_sequence, cur_index)\n    if not isinstance(gold_sequence[cur_index], Shift):\n        return None\n\n    prev_open_index = find_previous_open(gold_sequence, gold_index)\n    if prev_open_index is None:\n        return None\n    prev_open = gold_sequence[prev_open_index]\n    # prev_open is now NT_X from above\n\n    return gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[cur_index:]\n\ndef fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if len(gold_sequence) < gold_index + 3:\n        return None\n    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):\n        return None\n\n    open_index = advance_past_unaries(gold_sequence, gold_index+1)\n    if not isinstance(gold_sequence[open_index], OpenConstituent):\n        return None\n    if not isinstance(gold_sequence[open_index+1], Shift):\n        return None\n\n    # check that the next operation was to open a *different* constituent\n    # from the one we just closed\n    prev_open_index = find_previous_open(gold_sequence, gold_index)\n    if prev_open_index is None:\n        return None\n    prev_open = gold_sequence[prev_open_index]\n    if gold_sequence[open_index] == prev_open:\n        return None\n\n    # check that the following stuff is a single bracket, not multiple brackets\n    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)\n    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):\n        return None\n    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):\n        return None\n\n    # if closing at the end of the next 
blocks,\n    # instead of closing after the first block ends,\n    # we go to the end of the last block\n    if late:\n        end_index = advance_past_constituents(gold_sequence, open_index+1)\n\n    return gold_sequence[:gold_index] + gold_sequence[open_index+1:end_index] + gold_sequence[gold_index:open_index+1] + gold_sequence[end_index:]\n\ndef fix_close_open_shift_unambiguous_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous=False, late=False)\n\ndef fix_close_open_shift_ambiguous_bracket_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous=True, late=False)\n\ndef fix_close_open_shift_ambiguous_bracket_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous=True, late=True)\n\ndef fix_close_open_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if len(gold_sequence) < gold_index + 3:\n        return None\n    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):\n        return None\n\n    open_index = advance_past_unaries(gold_sequence, gold_index+1)\n    if not isinstance(gold_sequence[open_index], OpenConstituent):\n        return None\n    if not isinstance(gold_sequence[open_index+1], Shift):\n        return None\n\n    # check that the next operation was to open a *different* constituent\n    # from the one we just closed\n    
prev_open_index = find_previous_open(gold_sequence, gold_index)\n    if prev_open_index is None:\n        return None\n    prev_open = gold_sequence[prev_open_index]\n    if gold_sequence[open_index] == prev_open:\n        return None\n\n    # alright, at long last we have:\n    #   a close that was missed\n    #   a non-nested open that was missed\n    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)\n\n    candidates = []\n    candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))\n    while isinstance(gold_sequence[end_index], Shift):\n        end_index = find_in_order_constituent_end(gold_sequence, end_index+1)\n        candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))\n\n    scores, best_idx, best_candidate = score_candidates(model, state, candidates)\n    if len(candidates) == 1:\n        return RepairType.CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET, best_candidate\n\n    if best_idx == len(candidates) - 1:\n        best_idx = -1\n    repair_type = RepairEnum(name=RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.name,\n                             value=\"%d.%d\" % (RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),\n                             is_correct=False)\n    return repair_type, best_candidate\n\ndef fix_close_open_shift_nested(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Fix a Close X..Open X..Shift pattern where both the Close and Open were skipped.\n\n    Here the pattern we are trying to fix is\n      stuff_A open_X stuff_B *close* open_X shift...\n    replaced with\n      stuff_A open_X stuff_B shift...\n    the missed close & open means a missed recall error for (X A B)\n    whereas the previous open_X can still get the outer bracket\n    \"\"\"\n    if not 
isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if len(gold_sequence) < gold_index + 3:\n        return None\n    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):\n        return None\n\n    # handle the sequence:\n    #   stuff_A open_X stuff_B close open_Y close open_X shift\n    open_index = advance_past_unaries(gold_sequence, gold_index+1)\n    if not isinstance(gold_sequence[open_index], OpenConstituent):\n        return None\n    if not isinstance(gold_sequence[open_index+1], Shift):\n        return None\n\n    # check that the next operation was to open the same constituent\n    # we just closed\n    prev_open_index = find_previous_open(gold_sequence, gold_index)\n    if prev_open_index is None:\n        return None\n    prev_open = gold_sequence[prev_open_index]\n    if gold_sequence[open_index] != prev_open:\n        return None\n\n    return gold_sequence[:gold_index] + gold_sequence[open_index+1:]\n\ndef fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):\n    \"\"\"\n    Repair Close/Shift -> Shift by moving the Close to after the next block is created\n    \"\"\"\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n    if len(gold_sequence) < gold_index + 2:\n        return None\n    start_index = gold_index + 1\n    start_index = advance_past_unaries(gold_sequence, start_index)\n    if len(gold_sequence) < start_index + 2:\n        return None\n    if not isinstance(gold_sequence[start_index], Shift):\n        return None\n\n    end_index = find_in_order_constituent_end(gold_sequence, start_index)\n    if end_index is None:\n        return None\n    # if this *isn't* a close, we don't allow it in the unambiguous case\n    # that case seems to be ambiguous...\n    #   stuff_1 close 
stuff_2 stuff_3\n    # if you would normally start building stuff_3,\n    # it is not clear if you want to close at the end of\n    # stuff_2 or build stuff_3 instead.\n    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):\n        return None\n    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):\n        return None\n\n    # close at the end of the brackets, rather than once the first bracket is finished\n    if late:\n        end_index = advance_past_constituents(gold_sequence, start_index)\n\n    return gold_sequence[:gold_index] + gold_sequence[start_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]\n\ndef fix_close_shift_shift_unambiguous(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous=False, late=False)\n\ndef fix_close_shift_shift_ambiguous_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous=True, late=False)\n\ndef fix_close_shift_shift_ambiguous_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous=True, late=True)\n\ndef fix_close_shift_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n    if len(gold_sequence) < gold_index + 2:\n        return None\n    start_index = gold_index + 1\n    start_index = advance_past_unaries(gold_sequence, start_index)\n    if len(gold_sequence) < start_index + 2:\n        return None\n 
   if not isinstance(gold_sequence[start_index], Shift):\n        return None\n\n    # now we know that the gold pattern was\n    #   Close (unaries) Shift\n    # and instead the model predicted Shift\n    candidates = []\n    current_index = start_index\n    while isinstance(gold_sequence[current_index], Shift):\n        current_index = find_in_order_constituent_end(gold_sequence, current_index)\n        assert current_index is not None\n        candidates.append((gold_sequence[:gold_index], gold_sequence[start_index:current_index], [CloseConstituent()], gold_sequence[current_index:]))\n    scores, best_idx, best_candidate = score_candidates(model, state, candidates)\n    if len(candidates) == 1:\n        return RepairType.CLOSE_SHIFT_SHIFT, best_candidate\n    if best_idx == len(candidates) - 1:\n        best_idx = -1\n    repair_type = RepairEnum(name=RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.name,\n                             value=\"%d.%d\" % (RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),\n                             is_correct=False)\n    #print(best_idx, len(candidates), repair_type)\n    return repair_type, best_candidate\n\ndef ambiguous_shift_open_unary_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return gold_sequence[:gold_index] + [pred_transition, CloseConstituent()] + gold_sequence[gold_index:]\n\ndef ambiguous_shift_open_early_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    # Find when the current block ends,\n    # either via a Shift or a Close\n    end_index = find_in_order_constituent_end(gold_sequence, gold_index)\n   
 return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]\n\ndef ambiguous_shift_open_late_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    end_index = advance_past_constituents(gold_sequence, gold_index)\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]\n\ndef ambiguous_shift_open_predicted_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    unary_candidate = (gold_sequence[:gold_index], [pred_transition], [CloseConstituent()], gold_sequence[gold_index:])\n\n    early_index = find_in_order_constituent_end(gold_sequence, gold_index)\n    early_candidate = (gold_sequence[:gold_index], [pred_transition] + gold_sequence[gold_index:early_index], [CloseConstituent()], gold_sequence[early_index:])\n\n    late_index = advance_past_constituents(gold_sequence, gold_index)\n    if early_index == late_index:\n        candidates = [unary_candidate, early_candidate]\n        scores, best_idx, best_candidate = score_candidates(model, state, candidates)\n        if best_idx == 0:\n            return_label = \"U\"\n        else:\n            return_label = \"S\"\n    else:\n        late_candidate = (gold_sequence[:gold_index], [pred_transition] + gold_sequence[gold_index:late_index], [CloseConstituent()], gold_sequence[late_index:])\n        candidates = [unary_candidate, early_candidate, late_candidate]\n        scores, best_idx, best_candidate = score_candidates(model, state, candidates)\n        if 
best_idx == 0:\n            return_label = \"U\"\n        elif best_idx == 1:\n            return_label = \"E\"\n        else:\n            return_label = \"L\"\n    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_PREDICTED_CLOSE.name,\n                             value=\"%d.%s\" % (RepairType.SHIFT_OPEN_PREDICTED_CLOSE.value, return_label),\n                             is_correct=False)\n    return repair_type, best_candidate\n\n\ndef report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    return RepairType.OTHER_CLOSE_SHIFT, None\n\ndef report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return RepairType.OTHER_CLOSE_OPEN, None\n\ndef report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return RepairType.OTHER_OPEN_OPEN, None\n\ndef report_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    return RepairType.OTHER_OPEN_SHIFT, None\n\ndef report_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    return 
RepairType.OTHER_OPEN_CLOSE, None\n\ndef report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return RepairType.OTHER_SHIFT_OPEN, None\n\nclass RepairType(Enum):\n    \"\"\"\n    Keep track of which repair is used, if any, on an incorrect transition\n\n    Statistics on English w/ no charlm, no transformer,\n      eg word vectors only, best model as of January 2024\n\n    unambiguous transitions only:\n        oracle scheme          dev      test\n         no oracle            0.9245   0.9226\n          +wrong_open_root    0.9244   0.9224\n          +wrong_unary_chain  0.9243   0.9237\n          +wrong_open_unary   0.9249   0.9223\n          +wrong_open_general 0.9251   0.9215\n          +missed_unary       0.9248   0.9215\n          +open_shift         0.9243   0.9216\n          +open_close         0.9254   0.9217\n          +shift_close        0.9261   0.9238\n          +close_shift_nested 0.9253   0.9250\n\n    Redoing the wrong_open_general, which seemed to hurt test scores:\n          wrong_open_two_subtrees - L4             0.9244   0.9220\n          every else w/o ambiguous open/open fix   0.9259   0.9241\n          everything w/ open_two_subtrees          0.9261   0.9246\n          w/ ambiguous open_three_subtrees         0.9264   0.9243\n\n    Testing three different possible repairs for shift-open:\n          w/ ambiguous open_three_subtrees 0.9264   0.9243\n          immediate close (unary)          0.9267   0.9246\n          close after first bracket        0.9265   0.9256\n          close after last bracket         0.9264   0.9240\n\n    Testing three possible repairs for close-open-shift/shift\n          w/ ambiguous open_three_subtrees   0.9264   0.9243\n          unambiguous c-o-s/shift            0.9265   0.9246\n          ambiguous 
c-o-s/shift closed early 0.9262   0.9246\n          ambiguous c-o-s/shift closed late  0.9259   0.9245\n\n    Testing three possible repairs for close-shift/shift\n          w/ ambiguous open_three_subtrees   0.9264   0.9243\n          unambiguous c-s/shift              0.9253   0.9239\n          ambiguous c-s/shift closed early   0.9259   0.9235\n          ambiguous c-s/shift closed late    0.9252   0.9241\n          ambiguous c-s/shift predicted      0.9264   0.9243\n\n    --------------------------------------------------------\n\n    Running ID experiments to verify some of the above findings\n    no charlm or bert, only 200 epochs\n\n    Comparing wrong_open fixes\n          w/ ambiguous open_two_subtrees     0.8448   0.8335\n          w/ ambiguous open_three_subtrees   0.8424   0.8336\n\n    Testing three possible repairs for close-shift/shift\n          unambiguous c-s/shift              0.8448   0.8360\n          ambiguous c-s/shift closed early   0.8425   0.8352\n          ambiguous c-s/shift closed late    0.8452   0.8334\n\n    --------------------------------------------------------\n\n    Running ID experiments to verify some of the above findings\n    bert + peft, only 200 epochs\n\n    Comparing wrong_open fixes\n          w/o ambiguous open/open fix        0.8923   0.8834\n          w/ ambiguous open_two_subtrees     0.8908   0.8828\n          w/ ambiguous open_three_subtrees   0.8901   0.8801\n\n    Testing three possible repairs for close-shift/shift\n          unambiguous c-s/shift              0.8921   0.8825\n          ambiguous c-s/shift closed early   0.8924   0.8841\n          ambiguous c-s/shift closed late    0.8921   0.8806\n          ambiguous c-s/shift predicted      0.8923   0.8835\n\n    --------------------------------------------------------\n\n    Running DE experiments to verify some of the above findings\n    bert + peft, only 200 epochs\n\n    Comparing wrong_open fixes\n          w/o ambiguous open/open fix        0.9576   
0.9402\n          w/ ambiguous open_two_subtrees     0.9570   0.9410\n          w/ ambiguous open_three_subtrees   0.9569   0.9412\n\n    Testing three possible repairs for close-shift/shift\n          unambiguous c-s/shift              0.9566   0.9408\n          ambiguous c-s/shift closed early   0.9564   0.9394\n          ambiguous c-s/shift closed late    0.9572   0.9408\n          ambiguous c-s/shift predicted      0.9571   0.9404\n\n    --------------------------------------------------------\n\n    Running IT experiments to verify some of the above findings\n    bert + peft, only 200 epochs\n\n    Comparing wrong_open fixes\n          w/o ambiguous open/open fix        0.8380   0.8361\n          w/ ambiguous open_two_subtrees     0.8377   0.8351\n          w/ ambiguous open_three_subtrees   0.8381   0.8368\n\n    Testing three possible repairs for close-shift/shift\n          unambiguous c-s/shift              0.8376   0.8392\n          ambiguous c-s/shift closed early   0.8363   0.8359\n          ambiguous c-s/shift closed late    0.8365   0.8383\n          ambiguous c-s/shift predicted      0.8379   0.8371\n\n    --------------------------------------------------------\n\n    Running ZH experiments to verify some of the above findings\n    bert + peft, only 200 epochs\n\n    Comparing wrong_open fixes\n          w/o ambiguous open/open fix        0.9160   0.9143\n          w/ ambiguous open_two_subtrees     0.9145   0.9144\n          w/ ambiguous open_three_subtrees   0.9146   0.9142\n\n    Testing three possible repairs for close-shift/shift\n          unambiguous c-s/shift              0.9155   0.9146\n          ambiguous c-s/shift closed early   0.9145   0.9153\n          ambiguous c-s/shift closed late    0.9138   0.9140\n          ambiguous c-s/shift predicted      0.9154   0.9144\n\n    --------------------------------------------------------\n\n    Running VI experiments to verify some of the above findings\n    bert + peft, only 200 epochs\n\n    
Comparing wrong_open fixes\n          w/o ambiguous open/open fix        0.8282   0.7668\n          w/ ambiguous open_two_subtrees     0.8272   0.7670\n          w/ ambiguous open_three_subtrees   0.8282   0.7668\n\n    Testing three possible repairs for close-shift/shift\n          unambiguous c-s/shift              0.8285   0.7683\n          ambiguous c-s/shift closed early   0.8276   0.7678\n          ambiguous c-s/shift closed late    0.8278   0.7668\n          ambiguous c-s/shift predicted      0.8270   0.7668\n\n    --------------------------------------------------------\n\n    Testing a combination of ambiguous vs predicted transitions\n\n      ambiguous\n    EN: (no CSS_U)                           0.9258   0.9252\n    ZH: (no CSS_U)                           0.9153   0.9145\n\n      predicted\n    EN: (no CSS_U)                           0.9264   0.9241\n    ZH: (no CSS_U)                           0.9145   0.9141\n    \"\"\"\n    def __new__(cls, fn, correct=False, debug=False):\n        \"\"\"\n        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error\n\n        correct: this represents a correct transition\n\n        debug: always run this, as it just counts statistics\n        \"\"\"\n        value = len(cls.__members__)\n        obj = object.__new__(cls)\n        obj._value_ = value + 1\n        obj.fn = fn\n        obj.correct = correct\n        obj.debug = debug\n        return obj\n\n    @property\n    def is_correct(self):\n        return self.correct\n\n    # The first section is a sequence of repairs when the parser\n    # should have chosen NTx but instead chose NTy\n\n    # Blocks of transitions which can be abstracted away to be\n    # anything will be represented as S1, S2, etc... 
S for stuff\n\n    # We carve out an exception for a wrong open at the root\n    # The only possible transitions at this point are to close\n    # the error and try again with the root\n    WRONG_OPEN_ROOT_ERROR  = (fix_wrong_open_root_error,)\n\n    # The simplest form of such an error is when there is a sequence\n    # of unary transitions and the parser chose a wrong parent.\n    # Remember that a unary transition is represented by a pair\n    # of transitions, NTx, Close.\n    # In this case, the correct sequence was\n    #   S1 NTx Close NTy Close NTz ...\n    # but the parser chose NTy, NTz, etc\n    # The repair in this case is to simply discard the unchosen\n    # unary transitions and continue\n    WRONG_OPEN_UNARY_CHAIN = (fix_wrong_open_unary_chain,)\n\n    # Similar to the UNARY_CHAIN error, but in this case there is a\n    # bunch of stuff (one or more constituents built) between the\n    # missed open transition and the close transition\n    WRONG_OPEN_STUFF_UNARY = (fix_wrong_open_stuff_unary,)\n\n    # If the correct sequence is\n    #   T1 O_x T2 C\n    # and instead we predicted\n    #   T1 O_y ...\n    # this can be fixed with a unary transition after\n    #   T1 O_y T2 C O_x C\n    # note that this is technically ambiguous\n    # could have done\n    #   T1 O_x C O_y T2 C\n    # but doing this should be easier for the parser to detect (untested)\n    # also this way the same code paths can be used for two subtrees\n    # and for multiple subtrees\n    WRONG_OPEN_TWO_SUBTREES = (fix_wrong_open_two_subtrees,)\n\n    # If the gold transition is an Open because it is part of\n    # a unary transition, and the following transition is a\n    # correct Shift or Close, we can just skip past the unary.\n    MISSED_UNARY           = (fix_missed_unary,)\n\n    # Open -> Shift errors which don't just represent a unary\n    # generally represent a missing bracket which cannot be\n    # recovered using the in-order mechanism.  
Dropping the\n    # missing transition is generally the only fix.\n    # (This means removing the corresponding Close)\n    # One could theoretically create a new transition which\n    # grabs two constituents, though\n    OPEN_SHIFT             = (fix_open_shift,)\n\n    # Open -> Close is a rather drastic break in the\n    # potential structure of the tree.  We can no longer\n    # recover the missed Open, and we might not be able\n    # to recover other following missed Opens as well.\n    # In most cases, the only thing to do is reopen the\n    # incorrectly closed outer bracket and keep going.\n    OPEN_CLOSE             = (fix_open_close,)\n\n    # Similar to the Open -> Close error, but at least\n    # in this case we are just introducing one wrong bracket\n    # rather than also breaking some existing brackets.\n    # The fix here is to reopen the closed bracket.\n    SHIFT_CLOSE            = (fix_shift_close,)\n\n    # Specifically fixes an error where bracket X is\n    # closed and then immediately opened to build a\n    # new X bracket.  
In this case, the simplest fix\n    # will be to skip both the close and the new open\n    # and continue from there.\n    CLOSE_OPEN_SHIFT_NESTED = (fix_close_open_shift_nested,)\n\n    # Fix an error where the correct sequence was to Close X, Open Y,\n    # then continue building,\n    # but instead the model did a Shift in place of C_X O_Y\n    # The damage here is a recall error for the missed X and\n    # a precision error for the incorrectly opened X\n    # However, the Y can actually be recovered - whenever we finally\n    # close X, we can then open Y\n    # One form of that is unambiguous, that of\n    #   T_A O_X T_B C O_Y T_C C\n    # with only one subtree after the O_Y\n    # In that case, the Close that would have closed Y\n    # is the only place for the missing close of X\n    # So we can produce the following:\n    #   T_A O_X T_B T_C C O_Y C\n    CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET = (fix_close_open_shift_unambiguous_bracket,)\n\n    # Similarly to WRONG_OPEN_TWO_SUBTREES, if the correct sequence is\n    #   T1 O_x T2 T3 C\n    # and instead we predicted\n    #   T1 O_y ...\n    # this can be fixed by closing O_y in any number of places\n    #   T1 O_y T2 C O_x T3 C\n    #   T1 O_y T2 C T3 O_x C\n    # Either solution is a single precision error,\n    # but keeps the O_x subtree correct\n    # This is an ambiguous transition - we can experiment with different fixes\n    WRONG_OPEN_MULTIPLE_SUBTREES = (fix_wrong_open_multiple_subtrees,)\n\n    CORRECT                = (None, True)\n\n    UNKNOWN                = None\n\n    # If the model is supposed to build a block after a Close\n    # operation, attach that block to the piece to the left\n    # a couple different variations on this were tried\n    # we tried attaching all constituents to the\n    #   bracket which should have been closed\n    # we tried attaching exactly one constituent\n    # and we tried attaching only if there was\n    #   exactly one following constituent\n    # none of 
these improved f1.  for example, on the VI dataset, we\n    # lost 0.15 F1 with the exactly one following constituent version\n    # it might be worthwhile double checking some of the other\n    # versions to make sure those also fail, though\n    CLOSE_SHIFT_SHIFT                   = (fix_close_shift_shift_unambiguous,)\n\n    # In the ambiguous close-shift/shift case, this closes the surrounding bracket\n    # (which should have already been closed)\n    # as soon as the next constituent is built\n    # this turns\n    #   (A (B s1 s2) s3 s4)\n    # into\n    #   (A (B s1 s2 s3) s4)\n    CLOSE_SHIFT_SHIFT_AMBIGUOUS_EARLY   = (fix_close_shift_shift_ambiguous_early,)\n\n    # In the ambiguous close-shift/shift case, this closes the surrounding bracket\n    # (which should have already been closed)\n    # when the rest of the constituents in this bracket are built\n    # this turns\n    #   (A (B s1 s2) s3 s4)\n    # into\n    #   (A (B s1 s2 s3 s4))\n    CLOSE_SHIFT_SHIFT_AMBIGUOUS_LATE    = (fix_close_shift_shift_ambiguous_late,)\n\n    # For the close-shift/shift errors which are ambiguous,\n    # this uses the model's predictions to guess which block\n    # to put the close after\n    CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_shift_shift_ambiguous_predicted,)\n\n    # If a sequence should have gone Close - Open - Shift,\n    # and instead we went Shift,\n    # we need to close the previous bracket\n    # If it is ambiguous\n    # such as Close - Open - Shift - Shift\n    # close the bracket ASAP\n    # eg, Shift - Close - Open - Shift\n    CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_EARLY = (fix_close_open_shift_ambiguous_bracket_early,)\n\n    # for Close - Open - Shift - Shift\n    # close the bracket as late as possible\n    # eg, Shift - Shift - Close - Open\n    CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_LATE = (fix_close_open_shift_ambiguous_bracket_late,)\n\n    # If the sequence should have gone\n    #   Close - Open - Shift\n    # and instead we predicted a 
Shift\n    # in a context where closing the bracket would be ambiguous\n    # we use the model to predict where the close should actually happen\n    CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_open_shift_ambiguous_predicted,)\n\n    # This particular repair effectively turns the shift -> ambiguous open\n    # into a unary transition\n    SHIFT_OPEN_UNARY_CLOSE       = (ambiguous_shift_open_unary_close,)\n\n    # Fix the shift -> ambiguous open by closing after the first constituent\n    # This is an ambiguous solution because it could also be closed either\n    # as a unary transition or with a close at the end of the outer bracket\n    SHIFT_OPEN_EARLY_CLOSE       = (ambiguous_shift_open_early_close,)\n\n    # Fix the shift -> ambiguous open by closing after all constituents\n    # This is an ambiguous solution because it could also be closed either\n    # as a unary transition or with a close at the end of the first constituent\n    SHIFT_OPEN_LATE_CLOSE        = (ambiguous_shift_open_late_close,)\n\n    # Use the model to predict when to close!\n    # The different options for where to put the Close are put into the model,\n    # and the highest scoring close is used\n    SHIFT_OPEN_PREDICTED_CLOSE   = (ambiguous_shift_open_predicted_close,)\n\n    OTHER_CLOSE_SHIFT            = (report_close_shift, False, True)\n\n    OTHER_CLOSE_OPEN             = (report_close_open, False, True)\n\n    OTHER_OPEN_OPEN              = (report_open_open, False, True)\n\n    OTHER_OPEN_CLOSE             = (report_open_close, False, True)\n\n    OTHER_OPEN_SHIFT             = (report_open_shift, False, True)\n\n    OTHER_SHIFT_OPEN             = (report_shift_open, False, True)\n\n    # any other open transition we get wrong, which hasn't already\n    # been carved out as an exception above, we just accept the\n    # incorrect Open and keep going\n    #\n    # TODO: check if there is a way to improve this\n    # it appears to hurt scores simply by existing\n    # 
explanation: this is wrong logic\n    # Suppose the correct sequence had been\n    #   T1 open(NP) T2 T3 close\n    # Instead we had done\n    #   T1 open(VP) T2 T3 close\n    # We can recover the missing NP!\n    #   T1 open(VP) T2 close open(NP) T3 close\n    # Can also recover it as\n    #   T1 open(VP) T2 T3 close open(NP) close\n    # So this is actually an ambiguous transition\n    # except in the case of\n    #   T1 open(...) close\n    # In this case, a unary transition can make it so we only have\n    # a precision error, not also a recall error\n    # Currently, the approach is to put this after the default fixes\n    # and use the two & more-than-two versions of the fix above\n    WRONG_OPEN_GENERAL     = (fix_wrong_open_general,)\n\nclass InOrderOracle(DynamicOracle):\n    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):\n        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)\n"
  },
  {
    "path": "stanza/models/constituency/label_attention.py",
    "content": "import numpy as np\nimport functools\nimport sys\nimport torch\nfrom torch.autograd import Variable\nimport torch.nn as nn\nimport torch.nn.init as init\n\n# publicly available versions alternate between torch.uint8 and torch.bool,\n# but that is for older versions of torch anyway\nDTYPE = torch.bool\n\nclass BatchIndices:\n    \"\"\"\n    Batch indices container class (used to implement packed batches)\n    \"\"\"\n    def __init__(self, batch_idxs_np, device):\n        self.batch_idxs_np = batch_idxs_np\n        self.batch_idxs_torch = torch.as_tensor(batch_idxs_np, dtype=torch.long, device=device)\n\n        self.batch_size = int(1 + np.max(batch_idxs_np))\n\n        batch_idxs_np_extra = np.concatenate([[-1], batch_idxs_np, [-1]])\n        self.boundaries_np = np.nonzero(batch_idxs_np_extra[1:] != batch_idxs_np_extra[:-1])[0]\n\n        #print(f\"boundaries_np: {self.boundaries_np}\")\n        #print(f\"boundaries_np[1:]: {self.boundaries_np[1:]}\")\n        #print(f\"boundaries_np[:-1]: {self.boundaries_np[:-1]}\")\n        self.seq_lens_np = self.boundaries_np[1:] - self.boundaries_np[:-1]\n        #print(f\"seq_lens_np: {self.seq_lens_np}\")\n        #print(f\"batch_size: {self.batch_size}\")\n        assert len(self.seq_lens_np) == self.batch_size\n        self.max_len = int(np.max(self.boundaries_np[1:] - self.boundaries_np[:-1]))\n\n\nclass FeatureDropoutFunction(torch.autograd.function.InplaceFunction):\n    @classmethod\n    def forward(cls, ctx, input, batch_idxs, p=0.5, train=False, inplace=False):\n        if p < 0 or p > 1:\n            raise ValueError(\"dropout probability has to be between 0 and 1, \"\n                             \"but got {}\".format(p))\n\n        ctx.p = p\n        ctx.train = train\n        ctx.inplace = inplace\n\n        if ctx.inplace:\n            ctx.mark_dirty(input)\n            output = input\n        else:\n            output = input.clone()\n\n        if ctx.p > 0 and ctx.train:\n            
ctx.noise = input.new().resize_(batch_idxs.batch_size, input.size(1))\n            if ctx.p == 1:\n                ctx.noise.fill_(0)\n            else:\n                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)\n            ctx.noise = ctx.noise[batch_idxs.batch_idxs_torch, :]\n            output.mul_(ctx.noise)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        if ctx.p > 0 and ctx.train:\n            return grad_output.mul(ctx.noise), None, None, None, None\n        else:\n            return grad_output, None, None, None, None\n\n#\nclass FeatureDropout(nn.Module):\n    \"\"\"\n    Feature-level dropout: takes an input of size len x num_features and drops\n    each feature with probability p. A feature is dropped across the full\n    portion of the input that corresponds to a single batch element.\n    \"\"\"\n    def __init__(self, p=0.5, inplace=False):\n        super().__init__()\n        if p < 0 or p > 1:\n            raise ValueError(\"dropout probability has to be between 0 and 1, \"\n                             \"but got {}\".format(p))\n        self.p = p\n        self.inplace = inplace\n\n    def forward(self, input, batch_idxs):\n        return FeatureDropoutFunction.apply(input, batch_idxs, self.p, self.training, self.inplace)\n\n\n\nclass LayerNormalization(nn.Module):\n    def __init__(self, d_hid, eps=1e-3, affine=True):\n        super(LayerNormalization, self).__init__()\n\n        self.eps = eps\n        self.affine = affine\n        if self.affine:\n            self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)\n            self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)\n\n    def forward(self, z):\n        if z.size(-1) == 1:\n            return z\n\n        mu = torch.mean(z, keepdim=True, dim=-1)\n        sigma = torch.std(z, keepdim=True, dim=-1)\n        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)\n        if self.affine:\n            ln_out 
= ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)\n\n        return ln_out\n\n\n\nclass ScaledDotProductAttention(nn.Module):\n    def __init__(self, d_model, attention_dropout=0.1):\n        super(ScaledDotProductAttention, self).__init__()\n        self.temper = d_model ** 0.5\n        self.dropout = nn.Dropout(attention_dropout)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, q, k, v, attn_mask=None):\n        # q: [batch, slot, feat] or (batch * d_l) x max_len x d_k\n        # k: [batch, slot, feat] or (batch * d_l) x max_len x d_k\n        # v: [batch, slot, feat] or (batch * d_l) x max_len x d_v\n        # q in LAL is (batch * d_l) x 1 x d_k\n\n        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper # (batch * d_l) x max_len x max_len\n        # in LAL, gives: (batch * d_l) x 1 x max_len\n        # attention weights from each word to each word, for each label\n        # in best model (repeated q): attention weights from label (as vector weights) to each word\n\n        if attn_mask is not None:\n            assert attn_mask.size() == attn.size(), \\\n                    'Attention mask shape {} mismatch ' \\\n                    'with Attention logit tensor shape ' \\\n                    '{}.'.format(attn_mask.size(), attn.size())\n\n            attn.data.masked_fill_(attn_mask, -float('inf'))\n\n        attn = self.softmax(attn)\n        # Note that this makes the distribution not sum to 1. 
At some point it\n        # may be worth researching whether this is the right way to apply\n        # dropout to the attention.\n        # Note that the t2t code also applies dropout in this manner\n        attn = self.dropout(attn)\n        output = torch.bmm(attn, v) # (batch * d_l) x max_len x d_v\n        # in LAL, gives: (batch * d_l) x 1 x d_v\n\n        return output, attn\n\n\nclass MultiHeadAttention(nn.Module):\n    \"\"\"\n    Multi-head attention module\n    \"\"\"\n\n    def __init__(self, n_head, d_model, d_k, d_v, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):\n        super(MultiHeadAttention, self).__init__()\n\n        self.n_head = n_head\n        self.d_k = d_k\n        self.d_v = d_v\n\n        if not d_positional:\n            self.partitioned = False\n        else:\n            self.partitioned = True\n\n        if self.partitioned:\n            self.d_content = d_model - d_positional\n            self.d_positional = d_positional\n\n            self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))\n            self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))\n            self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))\n\n            self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))\n            self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))\n            self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))\n\n            init.xavier_normal_(self.w_qs1)\n            init.xavier_normal_(self.w_ks1)\n            init.xavier_normal_(self.w_vs1)\n\n            init.xavier_normal_(self.w_qs2)\n            init.xavier_normal_(self.w_ks2)\n            init.xavier_normal_(self.w_vs2)\n        else:\n            self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))\n            self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))\n 
           self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))\n\n            init.xavier_normal_(self.w_qs)\n            init.xavier_normal_(self.w_ks)\n            init.xavier_normal_(self.w_vs)\n\n        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)\n        self.layer_norm = LayerNormalization(d_model)\n\n        if not self.partitioned:\n            # The lack of a bias term here is consistent with the t2t code, though\n            # in my experiments I have never observed this making a difference.\n            self.proj = nn.Linear(n_head*d_v, d_model, bias=False)\n        else:\n            self.proj1 = nn.Linear(n_head*(d_v//2), self.d_content, bias=False)\n            self.proj2 = nn.Linear(n_head*(d_v//2), self.d_positional, bias=False)\n\n        self.residual_dropout = FeatureDropout(residual_dropout)\n\n    def split_qkv_packed(self, inp, qk_inp=None):\n        v_inp_repeated = inp.repeat(self.n_head, 1).view(self.n_head, -1, inp.size(-1)) # n_head x len_inp x d_model\n        if qk_inp is None:\n            qk_inp_repeated = v_inp_repeated\n        else:\n            qk_inp_repeated = qk_inp.repeat(self.n_head, 1).view(self.n_head, -1, qk_inp.size(-1))\n\n        if not self.partitioned:\n            q_s = torch.bmm(qk_inp_repeated, self.w_qs) # n_head x len_inp x d_k\n            k_s = torch.bmm(qk_inp_repeated, self.w_ks) # n_head x len_inp x d_k\n            v_s = torch.bmm(v_inp_repeated, self.w_vs) # n_head x len_inp x d_v\n        else:\n            q_s = torch.cat([\n                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_qs1),\n                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_qs2),\n                ], -1)\n            k_s = torch.cat([\n                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_ks1),\n                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_ks2),\n                ], -1)\n            v_s = torch.cat([\n       
         torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),\n                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),\n                ], -1)\n        return q_s, k_s, v_s\n\n    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):\n        # Input is packed representation: n_head x len_inp x d\n        # Output is padded representation: (n_head * mb_size) x len_padded x d\n        # (along with masks for the attention and output)\n        n_head = self.n_head\n        d_k, d_v = self.d_k, self.d_v\n\n        len_padded = batch_idxs.max_len\n        mb_size = batch_idxs.batch_size\n        q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))\n        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))\n        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))\n        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)\n\n        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):\n            q_padded[:,i,:end-start,:] = q_s[:,start:end,:]\n            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]\n            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]\n            invalid_mask[i, :end-start].fill_(False)\n\n        return(\n            q_padded.view(-1, len_padded, d_k),\n            k_padded.view(-1, len_padded, d_k),\n            v_padded.view(-1, len_padded, d_v),\n            invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1),\n            (~invalid_mask).repeat(n_head, 1),\n            )\n\n    def combine_v(self, outputs):\n        # Combine attention information from the different heads\n        n_head = self.n_head\n        outputs = outputs.view(n_head, -1, self.d_v) # n_head x len_inp x d_kv\n\n        if not self.partitioned:\n            # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)\n            outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * 
self.d_v)\n\n            # Project back to residual size\n            outputs = self.proj(outputs)\n        else:\n            d_v1 = self.d_v // 2\n            outputs1 = outputs[:,:,:d_v1]\n            outputs2 = outputs[:,:,d_v1:]\n            outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)\n            outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)\n            outputs = torch.cat([\n                self.proj1(outputs1),\n                self.proj2(outputs2),\n                ], -1)\n\n        return outputs\n\n    def forward(self, inp, batch_idxs, qk_inp=None):\n        residual = inp\n\n        # While still using a packed representation, project to obtain the\n        # query/key/value for each head\n        q_s, k_s, v_s = self.split_qkv_packed(inp, qk_inp=qk_inp)\n        # n_head x len_inp x d_kv\n\n        # Switch to padded representation, perform attention, then switch back\n        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)\n        # (n_head * batch) x len_padded x d_kv\n\n        outputs_padded, attns_padded = self.attention(\n            q_padded, k_padded, v_padded,\n            attn_mask=attn_mask,\n            )\n        outputs = outputs_padded[output_mask]\n        # (n_head * len_inp) x d_kv\n        outputs = self.combine_v(outputs)\n        # len_inp x d_model\n\n        outputs = self.residual_dropout(outputs, batch_idxs)\n\n        return self.layer_norm(outputs + residual), attns_padded\n\n#\nclass PositionwiseFeedForward(nn.Module):\n    \"\"\"\n    A position-wise feed forward module.\n\n    Projects to a higher-dimensional space before applying ReLU, then projects\n    back.\n    \"\"\"\n\n    def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1):\n        super(PositionwiseFeedForward, self).__init__()\n        self.w_1 = nn.Linear(d_hid, d_ff)\n        self.w_2 = nn.Linear(d_ff, d_hid)\n\n 
       self.layer_norm = LayerNormalization(d_hid)\n        self.relu_dropout = FeatureDropout(relu_dropout)\n        self.residual_dropout = FeatureDropout(residual_dropout)\n        self.relu = nn.ReLU()\n\n\n    def forward(self, x, batch_idxs):\n        residual = x\n\n        output = self.w_1(x)\n        output = self.relu_dropout(self.relu(output), batch_idxs)\n        output = self.w_2(output)\n\n        output = self.residual_dropout(output, batch_idxs)\n        return self.layer_norm(output + residual)\n\n#\nclass PartitionedPositionwiseFeedForward(nn.Module):\n    def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1):\n        super().__init__()\n        self.d_content = d_hid - d_positional\n        self.w_1c = nn.Linear(self.d_content, d_ff//2)\n        self.w_1p = nn.Linear(d_positional, d_ff//2)\n        self.w_2c = nn.Linear(d_ff//2, self.d_content)\n        self.w_2p = nn.Linear(d_ff//2, d_positional)\n        self.layer_norm = LayerNormalization(d_hid)\n        self.relu_dropout = FeatureDropout(relu_dropout)\n        self.residual_dropout = FeatureDropout(residual_dropout)\n        self.relu = nn.ReLU()\n\n    def forward(self, x, batch_idxs):\n        residual = x\n        xc = x[:, :self.d_content]\n        xp = x[:, self.d_content:]\n\n        outputc = self.w_1c(xc)\n        outputc = self.relu_dropout(self.relu(outputc), batch_idxs)\n        outputc = self.w_2c(outputc)\n\n        outputp = self.w_1p(xp)\n        outputp = self.relu_dropout(self.relu(outputp), batch_idxs)\n        outputp = self.w_2p(outputp)\n\n        output = torch.cat([outputc, outputp], -1)\n\n        output = self.residual_dropout(output, batch_idxs)\n        return self.layer_norm(output + residual)\n\nclass LabelAttention(nn.Module):\n    \"\"\"\n    Single-head Attention layer for label-specific representations\n    \"\"\"\n\n    def __init__(self, d_model, d_k, d_v, d_l, d_proj, combine_as_self, use_resdrop=True, q_as_matrix=False, 
residual_dropout=0.1, attention_dropout=0.1, d_positional=None):\n        super(LabelAttention, self).__init__()\n        self.d_k = d_k\n        self.d_v = d_v\n        self.d_l = d_l # Number of Labels\n        self.d_model = d_model # Model Dimensionality\n        self.d_proj = d_proj # Projection dimension of each label output\n        self.use_resdrop = use_resdrop # Using Residual Dropout?\n        self.q_as_matrix = q_as_matrix # Using a Matrix of Q to be multiplied with input instead of learned q vectors\n        self.combine_as_self = combine_as_self # Using the Combination Method of Self-Attention\n\n        if not d_positional:\n            self.partitioned = False\n        else:\n            self.partitioned = True\n\n        if self.partitioned:\n            if d_model <= d_positional:\n                raise ValueError(\"Unable to build LabelAttention.  d_model %d <= d_positional %d\" % (d_model, d_positional))\n            self.d_content = d_model - d_positional\n            self.d_positional = d_positional\n\n            if self.q_as_matrix:\n                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)\n            else:\n                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)\n            self.w_ks1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)\n            self.w_vs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_v // 2), requires_grad=True)\n\n            if self.q_as_matrix:\n                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)\n            else:\n                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)\n            self.w_ks2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)\n            self.w_vs2 = 
nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_v // 2), requires_grad=True)\n\n            init.xavier_normal_(self.w_qs1)\n            init.xavier_normal_(self.w_ks1)\n            init.xavier_normal_(self.w_vs1)\n\n            init.xavier_normal_(self.w_qs2)\n            init.xavier_normal_(self.w_ks2)\n            init.xavier_normal_(self.w_vs2)\n        else:\n            if self.q_as_matrix:\n                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)\n            else:\n                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_k), requires_grad=True)\n            self.w_ks = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)\n            self.w_vs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_v), requires_grad=True)\n\n            init.xavier_normal_(self.w_qs)\n            init.xavier_normal_(self.w_ks)\n            init.xavier_normal_(self.w_vs)\n\n        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)\n        if self.combine_as_self:\n            self.layer_norm = LayerNormalization(d_model)\n        else:\n            self.layer_norm = LayerNormalization(self.d_proj)\n\n        if not self.partitioned:\n            # The lack of a bias term here is consistent with the t2t code, though\n            # in my experiments I have never observed this making a difference.\n            if self.combine_as_self:\n                self.proj = nn.Linear(self.d_l * d_v, d_model, bias=False)\n            else:\n                self.proj = nn.Linear(d_v, d_model, bias=False) # input dimension does not match, should be d_l * d_v\n        else:\n            if self.combine_as_self:\n                self.proj1 = nn.Linear(self.d_l*(d_v//2), self.d_content, bias=False)\n                self.proj2 = nn.Linear(self.d_l*(d_v//2), self.d_positional, bias=False)\n            else:\n                self.proj1 = nn.Linear(d_v//2, 
self.d_content, bias=False)\n                self.proj2 = nn.Linear(d_v//2, self.d_positional, bias=False)\n        if not self.combine_as_self:\n            self.reduce_proj = nn.Linear(d_model, self.d_proj, bias=False)\n\n        self.residual_dropout = FeatureDropout(residual_dropout)\n\n    def split_qkv_packed(self, inp, k_inp=None):\n        len_inp = inp.size(0)\n        v_inp_repeated = inp.repeat(self.d_l, 1).view(self.d_l, -1, inp.size(-1)) # d_l x len_inp x d_model\n        if k_inp is None:\n            k_inp_repeated = v_inp_repeated\n        else:\n            k_inp_repeated = k_inp.repeat(self.d_l, 1).view(self.d_l, -1, k_inp.size(-1)) # d_l x len_inp x d_model\n\n        if not self.partitioned:\n            if self.q_as_matrix:\n                q_s = torch.bmm(k_inp_repeated, self.w_qs) # d_l x len_inp x d_k\n            else:\n                q_s = self.w_qs.unsqueeze(1) # d_l x 1 x d_k\n            k_s = torch.bmm(k_inp_repeated, self.w_ks) # d_l x len_inp x d_k\n            v_s = torch.bmm(v_inp_repeated, self.w_vs) # d_l x len_inp x d_v\n        else:\n            if self.q_as_matrix:\n                q_s = torch.cat([\n                    torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_qs1),\n                    torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_qs2),\n                    ], -1)\n            else:\n                q_s = torch.cat([\n                    self.w_qs1.unsqueeze(1),\n                    self.w_qs2.unsqueeze(1),\n                    ], -1)\n            k_s = torch.cat([\n                torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_ks1),\n                torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_ks2),\n                ], -1)\n            v_s = torch.cat([\n                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),\n                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),\n                ], -1)\n        return q_s, k_s, v_s\n\n    def 
pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):\n        # Input is packed representation: n_head x len_inp x d\n        # Output is padded representation: (n_head * mb_size) x len_padded x d\n        # (along with masks for the attention and output)\n        n_head = self.d_l\n        d_k, d_v = self.d_k, self.d_v\n\n        len_padded = batch_idxs.max_len\n        mb_size = batch_idxs.batch_size\n        if self.q_as_matrix:\n            q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))\n        else:\n            q_padded = q_s.repeat(mb_size, 1, 1) # (d_l * mb_size) x 1 x d_k\n        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))\n        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))\n        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)\n\n        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):\n            if self.q_as_matrix:\n                q_padded[:,i,:end-start,:] = q_s[:,start:end,:]\n            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]\n            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]\n            invalid_mask[i, :end-start].fill_(False)\n\n        if self.q_as_matrix:\n            q_padded = q_padded.view(-1, len_padded, d_k)\n            attn_mask = invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1)\n        else:\n            attn_mask = invalid_mask.unsqueeze(1).repeat(n_head, 1, 1)\n\n        output_mask = (~invalid_mask).repeat(n_head, 1)\n\n        return(\n            q_padded,\n            k_padded.view(-1, len_padded, d_k),\n            v_padded.view(-1, len_padded, d_v),\n            attn_mask,\n            output_mask,\n            )\n\n    def combine_v(self, outputs):\n        # Combine attention information from the different labels\n        d_l = self.d_l\n        outputs = outputs.view(d_l, -1, self.d_v) # d_l x len_inp x d_v\n\n        if not self.partitioned:\n       
     # Switch from d_l x len_inp x d_v to len_inp x d_l x d_v\n            if self.combine_as_self:\n                outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, d_l * self.d_v)\n            else:\n                outputs = torch.transpose(outputs, 0, 1)#.contiguous() #.view(-1, d_l * self.d_v)\n            # Project back to residual size\n            outputs = self.proj(outputs) # Becomes len_inp x d_l x d_model\n        else:\n            d_v1 = self.d_v // 2\n            outputs1 = outputs[:,:,:d_v1]\n            outputs2 = outputs[:,:,d_v1:]\n            if self.combine_as_self:\n                outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, d_l * d_v1)\n                outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, d_l * d_v1)\n            else:\n                outputs1 = torch.transpose(outputs1, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)\n                outputs2 = torch.transpose(outputs2, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)\n            outputs = torch.cat([\n                self.proj1(outputs1),\n                self.proj2(outputs2),\n                ], -1)#.contiguous()\n\n        return outputs\n\n    def forward(self, inp, batch_idxs, k_inp=None):\n        residual = inp # len_inp x d_model\n        #print()\n        #print(f\"inp.shape: {inp.shape}\")\n        len_inp = inp.size(0)\n        #print(f\"len_inp: {len_inp}\")\n\n        # While still using a packed representation, project to obtain the\n        # query/key/value for each head\n        q_s, k_s, v_s = self.split_qkv_packed(inp, k_inp=k_inp)\n        # d_l x len_inp x d_k\n        # q_s is d_l x 1 x d_k\n\n        # Switch to padded representation, perform attention, then switch back\n        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)\n        # q_padded, k_padded, v_padded: (d_l * batch_size) x max_len x d_kv\n        # q_s is (d_l * batch_size) x 1 x d_kv\n\n        
outputs_padded, attns_padded = self.attention(\n            q_padded, k_padded, v_padded,\n            attn_mask=attn_mask,\n            )\n        # outputs_padded: (d_l * batch_size) x max_len x d_kv\n        # in LAL: (d_l * batch_size) x 1 x d_kv\n        # on the best model, this is one value vector per label that is repeated max_len times\n        if not self.q_as_matrix:\n            outputs_padded = outputs_padded.repeat(1,output_mask.size(-1),1)\n        outputs = outputs_padded[output_mask]\n        # outputs: (d_l * len_inp) x d_kv or LAL: (d_l * len_inp) x d_kv\n        # output_mask: (d_l * batch_size) x max_len\n        outputs = self.combine_v(outputs)\n        #print(f\"outputs shape: {outputs.shape}\")\n        # outputs: len_inp x d_l x d_model, whereas a normal self-attention layer gets len_inp x d_model\n        if self.use_resdrop:\n            if self.combine_as_self:\n                outputs = self.residual_dropout(outputs, batch_idxs)\n            else:\n                outputs = torch.cat([self.residual_dropout(outputs[:,i,:], batch_idxs).unsqueeze(1) for i in range(self.d_l)], 1)\n        if self.combine_as_self:\n            outputs = self.layer_norm(outputs + inp)\n        else:\n            for l in range(self.d_l):\n                outputs[:, l, :] = outputs[:, l, :] + inp\n\n            outputs = self.reduce_proj(outputs) # len_inp x d_l x d_proj\n            outputs = self.layer_norm(outputs) # len_inp x d_l x d_proj\n            outputs = outputs.view(len_inp, -1).contiguous() # len_inp x (d_l * d_proj)\n\n        return outputs, attns_padded\n\n\n#\nclass LabelAttentionModule(nn.Module):\n    \"\"\"\n    Label Attention Module for label-specific representations\n    The module can be used right after the Partitioned Attention, or it can be experimented with for the transition stack\n    \"\"\"\n    #\n    def __init__(self,\n                 d_model,\n                 d_input_proj,\n                 d_k,\n                 d_v,\n    
             d_l,\n                 d_proj,\n                 combine_as_self,\n                 use_resdrop=True,\n                 q_as_matrix=False,\n                 residual_dropout=0.1,\n                 attention_dropout=0.1,\n                 d_positional=None,\n                 d_ff=2048,\n                 relu_dropout=0.2,\n                 lattn_partitioned=True):\n        super().__init__()\n        self.ff_dim = d_proj * d_l\n\n        if not lattn_partitioned:\n            self.d_positional = 0\n        else:\n            self.d_positional = d_positional if d_positional else 0\n\n        if d_input_proj:\n            if d_input_proj <= self.d_positional:\n                raise ValueError(\"Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d\" % (d_input_proj, self.d_positional))\n            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)\n            d_input = d_input_proj\n        else:\n            self.input_projection = None\n            d_input = d_model\n\n        self.label_attention = LabelAttention(d_input,\n                                              d_k,\n                                              d_v,\n                                              d_l,\n                                              d_proj,\n                                              combine_as_self,\n                                              use_resdrop,\n                                              q_as_matrix,\n                                              residual_dropout,\n                                              attention_dropout,\n                                              self.d_positional)\n\n        if not lattn_partitioned:\n            self.lal_ff = PositionwiseFeedForward(self.ff_dim,\n                                                  d_ff,\n                                                  relu_dropout,\n                                             
     residual_dropout)\n        else:\n            self.lal_ff = PartitionedPositionwiseFeedForward(self.ff_dim,\n                                                             d_ff,\n                                                             self.d_positional,\n                                                             relu_dropout,\n                                                             residual_dropout)\n\n    def forward(self, word_embeddings, tagged_word_lists):\n        if self.input_projection:\n            if self.d_positional > 0:\n                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),\n                                              sentence[:, -self.d_positional:]), dim=1)\n                                   for sentence in word_embeddings]\n            else:\n                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]\n        # Extract Labeled Representation\n        packed_len = sum(sentence.shape[0] for sentence in word_embeddings)\n        batch_idxs = np.zeros(packed_len, dtype=int)\n\n        batch_size = len(word_embeddings)\n        i = 0\n\n        sentence_lengths = [0] * batch_size\n        for sentence_idx, sentence in enumerate(word_embeddings):\n            sentence_lengths[sentence_idx] = len(sentence)\n            for word in sentence:\n                batch_idxs[i] = sentence_idx\n                i += 1\n\n        batch_indices = batch_idxs\n        batch_idxs = BatchIndices(batch_idxs, word_embeddings[0].device)\n\n        new_embeds = []\n        for sentence_idx, batch in enumerate(word_embeddings):\n            for word_idx, embed in enumerate(batch):\n                if word_idx < sentence_lengths[sentence_idx]:\n                    new_embeds.append(embed)\n\n        new_word_embeddings = torch.stack(new_embeds)\n\n        labeled_representations, _ = self.label_attention(new_word_embeddings, batch_idxs)\n        labeled_representations = 
self.lal_ff(labeled_representations, batch_idxs)\n        final_labeled_representations = [[] for i in range(batch_size)]\n\n        for idx, embed in enumerate(labeled_representations):\n            final_labeled_representations[batch_indices[idx]].append(embed)\n\n        for idx, representation in enumerate(final_labeled_representations):\n            final_labeled_representations[idx]  = torch.stack(representation)\n\n        return final_labeled_representations\n\n"
  },
  {
    "path": "stanza/models/constituency/lstm_model.py",
    "content": "\"\"\"\nA version of the BaseModel which uses LSTMs to predict the correct next transition\nbased on the current known state.\n\nThe primary purpose of this class is to implement the prediction of the next\ntransition, which is done by concatenating the output of an LSTM operated over\nprevious transitions, the words, and the partially built constituents.\n\nA complete processing of a sentence is as follows:\n  1) Run the input words through an encoder.\n     The encoder includes some or all of the following:\n       pretrained word embedding\n       finetuned word embedding for training set words - \"delta_embedding\"\n       POS tag embedding\n       pretrained charlm representation\n       BERT or similar large language model representation\n       attention transformer over the previous inputs\n       labeled attention transformer over the first attention layer\n     The encoded input is then put through a bi-lstm, giving a word representation\n  2) Transitions are put in an embedding, and transitions already used are tracked\n     in an LSTM\n  3) Constituents already built are also processed in an LSTM\n  4) Every transition is chosen by taking the output of the current word position,\n     the transition LSTM, and the constituent LSTM, and classifying the next\n     transition\n  5) Transitions are repeated (with constraints) until the sentence is completed\n\"\"\"\n\nfrom collections import namedtuple\nimport copy\nfrom enum import Enum\nimport logging\nimport math\nimport random\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils.rnn import pack_padded_sequence\n\nfrom stanza.models.common.bert_embedding import extract_bert_embeddings\nfrom stanza.models.common.maxout_linear import MaxoutLinear\nfrom stanza.models.common.relative_attn import RelativeAttention\nfrom stanza.models.common.utils import attach_bert_model, build_nonlinearity, unsort\nfrom stanza.models.common.vocab import PAD_ID, UNK_ID\nfrom 
stanza.models.constituency.base_model import BaseModel\nfrom stanza.models.constituency.label_attention import LabelAttentionModule\nfrom stanza.models.constituency.lstm_tree_stack import LSTMTreeStack\nfrom stanza.models.constituency.parse_transitions import TransitionScheme\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule\nfrom stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding\nfrom stanza.models.constituency.transformer_tree_stack import TransformerTreeStack\nfrom stanza.models.constituency.tree_stack import TreeStack\nfrom stanza.models.constituency.utils import initialize_linear\n\nlogger = logging.getLogger('stanza')\ntlogger = logging.getLogger('stanza.constituency.trainer')\n\nWordNode = namedtuple(\"WordNode\", ['value', 'hx'])\n\n# lstm_hx & lstm_cx are the hidden & cell states of the LSTM going across constituents\n# tree_hx and tree_cx are the states of the lstm going up the constituents in the case of the tree_lstm combination method\nConstituent = namedtuple(\"Constituent\", ['value', 'tree_hx', 'tree_cx'])\n\n# The sentence boundary vectors are marginally useful at best.\n# However, they make it much easier to use non-bert layers as input to\n# attention layers, as the attention layers work better when they have\n# an index 0 to attend to.\nclass SentenceBoundary(Enum):\n    NONE               = 1\n    WORDS              = 2\n    EVERYTHING         = 3\n\nclass StackHistory(Enum):\n    LSTM               = 1\n    ATTN               = 2\n\n# How to compose constituent children into new constituents\n# MAX is simply take the max value of the children\n# this is surprisingly effective\n# for example, a Turkish dataset went from 81-81.5 dev, 75->75.5 test\n# BILSTM is the method described in the papers of making an lstm\n# out of the constituents\n# BILSTM_MAX is the same as BILSTM, but instead of using a Linear\n# to reduce 
the outputs of the lstm, we first take the max\n# and then use a linear to reduce the max\n# BIGRAM combines pairs of children and then takes the max over those\n# ATTN means to put an attention layer over the children nodes\n# we then take the max of the children with their attention\n#\n# Experiments show that MAX is noticeably better than the other options\n# On ja_alt, here are a few results after 200 iterations,\n# averaged over 5 iterations:\n#   MAX:         0.8985\n#   BILSTM:      0.8964\n#   BILSTM_MAX:  0.8973\n#   BIGRAM:      0.8982\n#\n# The MAX method has a linear transform after the max.\n#   Removing that transform makes the score go down to 0.8982\n#\n# We tried a few varieties of BILSTM_MAX\n# In particular:\n# max over LSTM, combining forward & backward using the max: 0.8970\n# max over forward & backward separately, then reduce:       0.8970\n# max over forward & backward only over 1:-1\n#   (eg, leave out the node embedding):                      0.8969\n# same as previous, but split the reduce into 2 pieces:      0.8973\n# max over forward & backward separately, then reduce as\n#   1/2(F + B) + W(F,B)\n#   the idea being that this way F and B are guaranteed\n#   to be represented:                                       0.8971\n#\n# BIGRAM is an attempt to mix information from nodes\n#   when building constituents, but it didn't help\n#   The first example, just taking pairs and learning\n#   a transform, went to NaN.  Likely the transform\n#   expanded the embedding too much.  
Switching it to\n#   scale the matrix by 0.5 didn't go to Nan, but only\n#   resulted in 0.8982\n#\n# A couple varieties of ATTN:\n# first an input linear, then attn, then an output linear\n#   the upside of this would be making the dimension of the attn\n#   independent from the rest of the model\n#   however, this caused an expansion in the magnitude of the vectors,\n#   resulting in NaN for deep enough trees\n# adding layernorm or tanh to balance this out resulted in\n#   disappointing performance\n#   tanh: 0.8972\n# another alternative not tested yet: lower initialization weights\n#   and enforce that the norms of the matrices are low enough that\n#   exponential explosion up the layers of the tree doesn't happen\n# just an attention layer means hidden_size % reduce_heads == 0\n#   that is simple enough to enforce by slightly changing hidden_size\n#   if needed\n# appending the embedding for the open state to the start of the\n#   sequence of children and taking only the content nodes\n#   was very disappointing: 0.8967\n# taking the entire sequence of children including the open state\n#   embedding resulted in 0.8973\n# long story short, this looks like an idea that should work, but it\n#   doesn't help.  
suggestions welcome for improving these results\n#\n# The current TREE_LSTM_CX mechanism uses a word's embedding\n#   as the hx and a trained embedding over tags as the cx    0.8996\n# This worked slightly better than 0s for cx (TREE_LSTM)     0.8992\n# A variant of TREE_LSTM which didn't work out:\n#   nodes are combined with an LSTM\n#   hx & cx are embeddings of the node type (eg S, NP, etc)\n#   input is the max over children:                          0.8977\n# Another variant which didn't work: use the word embedding\n#   as input to the same LSTM to get hx & cx                 0.8985\n# Note that although the scores for TREE_LSTM_CX are slightly higher\n# than MAX for the JA dataset, the benefit was not as clear for EN,\n# so we left the default at MAX.\n# For example, on English WSJ, before switching to Bert POS and\n# a learned Bert mixing layer, a comparison of 5x models trained\n# for 400 iterations got dev scores of:\n#   TREE_LSTM_CX        0.9589\n#   MAX                 0.9593\n#\n# UNTIED_MAX has a different reduce_linear for each type of\n#   constituent in the model.  
Similar to the different linear\n#   maps used in the CVG paper from Socher, Bauer, Manning, Ng\n# This is implemented as a large CxHxH parameter,\n#   with num_constituent layers of hidden-hidden transform,\n#   along with a CxH bias parameter.\n#   Essentially C Linears stacked on top of each other,\n#   but in a parameter so that indexing can be done quickly.\n# Unfortunately this does not beat out MAX with one combined linear.\n#   On an experiment on WSJ with all the best settings as of early\n#   October 2022, such as a Bert model POS tagger:\n#   MAX                 0.9597\n#   UNTIED_MAX          0.9592\n# Furthermore, starting from a finished MAX model and restarting\n#   by splitting the MAX layer into multiple pieces did not improve.\n#\n# KEY has a single Key which is used for a facsimile of ATTN\n#   each incoming subtree has its values weighted by a Query\n#   then the Key is used to calculate a softmax\n#   finally, a Value is used to scale the subtrees\n#   reduce_heads is used to determine the number of heads\n# There is an option to use or not use position information\n#   using a sinusoidal position embedding\n# UNTIED_KEY is the same, but has a different key\n#   for each possible constituent\n# On a VI dataset:\n#   MAX                    0.82064\n#   KEY (pos, 8)           0.81739\n#   UNTIED_KEY (pos, 8)    0.82046\n#   UNTIED_KEY (pos, 4)    0.81742\n# Attempted to add a linear to mix the attn heads together,\n#   but that was awful:    0.81567\n# Adding two position vectors, one in each direction, did not help:\n#   UNTIED_KEY (2x pos, 8) 0.8188\n# To redo that experiment, double the width of reduce_query and\n#   reduce_value, then call reduce_position on nhx, flip it,\n#   and call reduce_position again\n# Evidently the experiments to try should be:\n#   no pos at all\n#   more heads\nclass ConstituencyComposition(Enum):\n    BILSTM                = 1\n    MAX                   = 2\n    TREE_LSTM             = 3\n    BILSTM_MAX            
= 4\n    BIGRAM                = 5\n    ATTN                  = 6\n    TREE_LSTM_CX          = 7\n    UNTIED_MAX            = 8\n    KEY                   = 9\n    UNTIED_KEY            = 10\n\nclass LSTMModel(BaseModel, nn.Module):\n    def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, force_bert_saved, peft_name, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):\n        \"\"\"\n        pretrain: a Pretrain object\n        transitions: a list of all possible transitions which will be\n          used to build trees\n        constituents: a list of all possible constituents in the treebank\n        tags: a list of all possible tags in the treebank\n        words: a list of all known words, used for a delta word embedding.\n          note that there will be an attempt made to learn UNK words as well,\n          and tags by themselves may help UNK words\n        rare_words: a list of rare words, used to occasionally replace with UNK\n        root_labels: probably ROOT, although apparently some treebanks like TOP or even s\n        constituent_opens: a list of all possible open nodes which will go on the stack\n          - this might be different from constituents if there are nodes\n            which represent multiple constituents at once\n        args: hidden_size, transition_hidden_size, etc as gotten from\n          constituency_parser.py\n\n        Note that it might look like a hassle to pass all of this in\n        when it can be collected directly from the trees themselves.\n        However, that would only work at train time.  
At eval or\n        pipeline time we will load the lists from the saved model.\n        \"\"\"\n        super().__init__(transition_scheme=args['transition_scheme'], unary_limit=unary_limit, reverse_sentence=args.get('reversed', False), root_labels=root_labels)\n\n        self.args = args\n        self.unsaved_modules = []\n\n        emb_matrix = pretrain.emb\n        self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))\n\n        # replacing NBSP picks up a whole bunch of words for VI\n        self.vocab_map = { word.replace('\\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }\n        # precompute tensors for the word indices\n        # the tensors should be put on the GPU if needed by calling to(device)\n        self.register_buffer('vocab_tensors', torch.tensor(range(len(pretrain.vocab)), requires_grad=False))\n        self.vocab_size = emb_matrix.shape[0]\n        self.embedding_dim = emb_matrix.shape[1]\n\n        self.constituents = sorted(list(constituents))\n\n        self.hidden_size = self.args['hidden_size']\n        self.constituency_composition = self.args.get(\"constituency_composition\", ConstituencyComposition.BILSTM)\n        if self.constituency_composition in (ConstituencyComposition.ATTN, ConstituencyComposition.KEY, ConstituencyComposition.UNTIED_KEY):\n            self.reduce_heads = self.args['reduce_heads']\n            if self.hidden_size % self.reduce_heads != 0:\n                self.hidden_size = self.hidden_size + self.reduce_heads - (self.hidden_size % self.reduce_heads)\n\n        if args['constituent_stack'] == StackHistory.ATTN:\n            self.reduce_heads = self.args['reduce_heads']\n            if self.hidden_size % args['constituent_heads'] != 0:\n                # TODO: technically we should either use the LCM of this and reduce_heads, or just have two separate fields\n                self.hidden_size = self.hidden_size + args['constituent_heads'] - (self.hidden_size % 
args['constituent_heads'])\n                if self.constituency_composition == ConstituencyComposition.ATTN and self.hidden_size % self.reduce_heads != 0:\n                    raise ValueError(\"--reduce_heads and --constituent_heads not compatible!\")\n\n        self.transition_hidden_size = self.args['transition_hidden_size']\n        if args['transition_stack'] == StackHistory.ATTN:\n            if self.transition_hidden_size % args['transition_heads'] > 0:\n                logger.warning(\"transition_hidden_size %d %% transition_heads %d != 0.  reconfiguring\", self.transition_hidden_size, args['transition_heads'])\n                self.transition_hidden_size = self.transition_hidden_size + args['transition_heads'] - (self.transition_hidden_size % args['transition_heads'])\n\n        self.tag_embedding_dim = self.args['tag_embedding_dim']\n        self.transition_embedding_dim = self.args['transition_embedding_dim']\n        self.delta_embedding_dim = self.args['delta_embedding_dim']\n\n        self.word_input_size = self.embedding_dim + self.tag_embedding_dim + self.delta_embedding_dim\n\n        if forward_charlm is not None:\n            self.add_unsaved_module('forward_charlm', forward_charlm)\n            self.word_input_size += self.forward_charlm.hidden_dim()\n            if not forward_charlm.is_forward_lm:\n                raise ValueError(\"Got a backward charlm as a forward charlm!\")\n        else:\n            self.forward_charlm = None\n        if backward_charlm is not None:\n            self.add_unsaved_module('backward_charlm', backward_charlm)\n            self.word_input_size += self.backward_charlm.hidden_dim()\n            if backward_charlm.is_forward_lm:\n                raise ValueError(\"Got a forward charlm as a backward charlm!\")\n        else:\n            self.backward_charlm = None\n\n        self.delta_words = sorted(set(words))\n        self.delta_word_map = { word: i+2 for i, word in enumerate(self.delta_words) }\n        assert 
PAD_ID == 0\n        assert UNK_ID == 1\n        # initialization is chosen based on the observed values of the norms\n        # after several long training cycles\n        # (this is true for other embeddings and embedding-like vectors as well)\n        # the experiments which show this slightly helps were done with\n        # Adadelta, and the correct initialization may be slightly\n        # different for a different optimizer.\n        # in fact, it is likely a scheme other than normal_ would\n        # be better - the optimizer tends to learn the weights\n        # rather close to 0 before learning in the direction it\n        # actually wants to go\n        self.delta_embedding = nn.Embedding(num_embeddings = len(self.delta_words)+2,\n                                            embedding_dim = self.delta_embedding_dim,\n                                            padding_idx = 0)\n        nn.init.normal_(self.delta_embedding.weight, std=0.05)\n        self.register_buffer('delta_tensors', torch.tensor(range(len(self.delta_words) + 2), requires_grad=False))\n\n        self.rare_words = set(rare_words)\n\n        self.tags = sorted(list(tags))\n        if self.tag_embedding_dim > 0:\n            self.tag_map = { t: i+2 for i, t in enumerate(self.tags) }\n            self.tag_embedding = nn.Embedding(num_embeddings = len(tags)+2,\n                                              embedding_dim = self.tag_embedding_dim,\n                                              padding_idx = 0)\n            nn.init.normal_(self.tag_embedding.weight, std=0.25)\n            self.register_buffer('tag_tensors', torch.tensor(range(len(self.tags) + 2), requires_grad=False))\n\n        self.num_lstm_layers = self.args['num_lstm_layers']\n        self.num_tree_lstm_layers = self.args['num_tree_lstm_layers']\n        self.lstm_layer_dropout = self.args['lstm_layer_dropout']\n\n        self.word_dropout = nn.Dropout(self.args['word_dropout'])\n        self.predict_dropout = 
nn.Dropout(self.args['predict_dropout'])\n        self.lstm_input_dropout = nn.Dropout(self.args['lstm_input_dropout'])\n\n        # also register a buffer of zeros so that we can always get zeros on the appropriate device\n        self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))\n        self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))\n\n        # possibly add a couple vectors for bookends of the sentence\n        # We put the word_start and word_end here, AFTER counting the\n        # charlm dimension, but BEFORE counting the bert dimension,\n        # as we want word_start and word_end to not have dimensions\n        # for the bert embedding.  The bert model will add its own\n        # start and end representation.\n        self.sentence_boundary_vectors = self.args['sentence_boundary_vectors']\n        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:\n            self.register_parameter('word_start_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))\n            self.register_parameter('word_end_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))\n\n        # we set up the bert AFTER building word_start and word_end\n        # so that we can use the charlm endpoint values rather than\n        # try to train our own\n        self.force_bert_saved = force_bert_saved or self.args['bert_finetune'] or self.args['stage1_bert_finetune']\n        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), self.force_bert_saved)\n        self.peft_name = peft_name\n\n        if bert_model is not None:\n            if bert_tokenizer is None:\n                raise ValueError(\"Cannot have a bert model without a tokenizer\")\n            self.bert_dim = self.bert_model.config.hidden_size\n            if args['bert_hidden_layers']:\n                # The 
average will be offset by 1/N so that the default zeros\n                # represents an average of the N layers\n                if args['bert_hidden_layers'] > bert_model.config.num_hidden_layers:\n                    # limit ourselves to the number of layers actually available\n                    # note that we can +1 because of the initial embedding layer\n                    args['bert_hidden_layers'] = bert_model.config.num_hidden_layers + 1\n                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)\n                nn.init.zeros_(self.bert_layer_mix.weight)\n            else:\n                # an average of layers 2, 3, 4 will be used\n                # (for historic reasons)\n                self.bert_layer_mix = None\n            self.word_input_size = self.word_input_size + self.bert_dim\n\n        self.partitioned_transformer_module = None\n        self.pattn_d_model = 0\n        if LSTMModel.uses_pattn(self.args):\n            # Initializations of parameters for the Partitioned Attention\n            # round off the size of the model so that it divides in half evenly\n            self.pattn_d_model = self.args['pattn_d_model'] // 2 * 2\n\n            # Initializations for the Partitioned Attention\n            # experiments suggest having a bias does not help here\n            self.partitioned_transformer_module = PartitionedTransformerModule(\n                self.args['pattn_num_layers'],\n                d_model=self.pattn_d_model,\n                n_head=self.args['pattn_num_heads'],\n                d_qkv=self.args['pattn_d_kv'],\n                d_ff=self.args['pattn_d_ff'],\n                ff_dropout=self.args['pattn_relu_dropout'],\n                residual_dropout=self.args['pattn_residual_dropout'],\n                attention_dropout=self.args['pattn_attention_dropout'],\n                word_input_size=self.word_input_size,\n                bias=self.args['pattn_bias'],\n                
morpho_emb_dropout=self.args['pattn_morpho_emb_dropout'],\n                timing=self.args['pattn_timing'],\n                encoder_max_len=self.args['pattn_encoder_max_len']\n            )\n            self.word_input_size += self.pattn_d_model\n\n        self.label_attention_module = None\n        if LSTMModel.uses_lattn(self.args):\n            if self.partitioned_transformer_module is None:\n                logger.error(\"Not using Labeled Attention, as the Partitioned Attention module is not used\")\n            else:\n                # TODO: think of a couple ways to use alternate inputs\n                # for example, could pass in the word inputs with a positional embedding\n                # that would also allow it to work in the case of no partitioned module\n                if self.args['lattn_combined_input']:\n                    self.lattn_d_input = self.word_input_size\n                else:\n                    self.lattn_d_input = self.pattn_d_model\n                self.label_attention_module = LabelAttentionModule(self.lattn_d_input,\n                                                                   self.args['lattn_d_input_proj'],\n                                                                   self.args['lattn_d_kv'],\n                                                                   self.args['lattn_d_kv'],\n                                                                   self.args['lattn_d_l'],\n                                                                   self.args['lattn_d_proj'],\n                                                                   self.args['lattn_combine_as_self'],\n                                                                   self.args['lattn_resdrop'],\n                                                                   self.args['lattn_q_as_matrix'],\n                                                                   self.args['lattn_residual_dropout'],\n                                                
                   self.args['lattn_attention_dropout'],\n                                                                   self.pattn_d_model // 2,\n                                                                   self.args['lattn_d_ff'],\n                                                                   self.args['lattn_relu_dropout'],\n                                                                   self.args['lattn_partitioned'])\n                self.word_input_size = self.word_input_size + self.args['lattn_d_proj']*self.args['lattn_d_l']\n\n        self.rel_attn_forward = None\n        self.rel_attn_reverse = None\n        if self.args.get('use_rattn', False):\n            if not self.args['rattn_cat'] and self.word_input_size % self.args['rattn_heads'] != 0:\n                for rattn_heads in range(self.args['rattn_heads'] // 2):\n                    if self.word_input_size % (self.args['rattn_heads'] + rattn_heads) == 0:\n                        new_rattn_heads = self.args['rattn_heads'] + rattn_heads\n                        break\n                    if self.word_input_size % (self.args['rattn_heads'] - rattn_heads) == 0:\n                        new_rattn_heads = self.args['rattn_heads'] - rattn_heads\n                        break\n                else:\n                    raise ValueError(\"Number of heads %d does not divide evenly into input size %d\" % (self.args['rattn_heads'], self.word_input_size))\n                logger.warning(\"rattn_heads of %d does not work, but found a similar value of %d which does work\", self.args['rattn_heads'], new_rattn_heads)\n                self.args['rattn_heads'] = new_rattn_heads\n\n            if self.args['rattn_forward']:\n                if self.args['rattn_cat']:\n                    self.rel_attn_forward = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], d_output=self.args['rattn_dim'], fudge_output=True, num_sinks=self.args['rattn_sinks'])\n       
         else:\n                    self.rel_attn_forward = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], num_sinks=self.args['rattn_sinks'])\n\n            if self.args['rattn_reverse']:\n                if self.args['rattn_cat']:\n                    self.rel_attn_reverse = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], reverse=True, d_output=self.args['rattn_dim'], fudge_output=True, num_sinks=self.args['rattn_sinks'])\n                else:\n                    self.rel_attn_reverse = RelativeAttention(self.word_input_size, self.args['rattn_heads'], window=self.args['rattn_window'], reverse=True, num_sinks=self.args['rattn_sinks'])\n\n            if self.args['rattn_forward'] and self.args['rattn_cat']:\n                self.word_input_size += self.rel_attn_forward.d_output\n\n            if self.args['rattn_reverse'] and self.args['rattn_cat']:\n                self.word_input_size += self.rel_attn_reverse.d_output\n\n        self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)\n\n        # after putting the word_delta_tag input through the word_lstm, we get back\n        # hidden_size * 2 output with the front and back lstms concatenated.\n        # this transforms it into hidden_size with the values mixed together\n        self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size * self.num_tree_lstm_layers)\n        initialize_linear(self.word_to_constituent, self.args['nonlinearity'], self.hidden_size * 2)\n\n        self.transitions = sorted(list(transitions))\n        self.transition_map = { t: i for i, t in enumerate(self.transitions) }\n        # precompute tensors for the transitions\n        self.register_buffer('transition_tensors', torch.tensor(range(len(transitions)), requires_grad=False))\n        
self.transition_embedding = nn.Embedding(num_embeddings = len(transitions),\n                                                 embedding_dim = self.transition_embedding_dim)\n        nn.init.normal_(self.transition_embedding.weight, std=0.25)\n        if args['transition_stack'] == StackHistory.LSTM:\n            self.transition_stack = LSTMTreeStack(input_size=self.transition_embedding_dim,\n                                                  hidden_size=self.transition_hidden_size,\n                                                  num_lstm_layers=self.num_lstm_layers,\n                                                  dropout=self.lstm_layer_dropout,\n                                                  uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,\n                                                  input_dropout=self.lstm_input_dropout)\n        elif args['transition_stack'] == StackHistory.ATTN:\n            self.transition_stack = TransformerTreeStack(input_size=self.transition_embedding_dim,\n                                                         output_size=self.transition_hidden_size,\n                                                         input_dropout=self.lstm_input_dropout,\n                                                         use_position=True,\n                                                         num_heads=args['transition_heads'])\n        else:\n            raise ValueError(\"Unhandled transition_stack StackHistory: {}\".format(args['transition_stack']))\n\n        self.constituent_opens = sorted(list(constituent_opens))\n        # an embedding for the spot on the constituent LSTM taken up by the Open transitions\n        # the pattern when condensing constituents is embedding - con1 - con2 - con3 - embedding\n        # TODO: try the two ends have different embeddings?\n        self.constituent_open_map = { x: i for (i, x) in enumerate(self.constituent_opens) }\n        self.constituent_open_embedding = 
nn.Embedding(num_embeddings = len(self.constituent_open_map),\n                                                       embedding_dim = self.hidden_size)\n        nn.init.normal_(self.constituent_open_embedding.weight, std=0.2)\n\n        # input_size is hidden_size - could introduce a new constituent_size instead if we liked\n        if args['constituent_stack'] == StackHistory.LSTM:\n            self.constituent_stack = LSTMTreeStack(input_size=self.hidden_size,\n                                                   hidden_size=self.hidden_size,\n                                                   num_lstm_layers=self.num_lstm_layers,\n                                                   dropout=self.lstm_layer_dropout,\n                                                   uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,\n                                                   input_dropout=self.lstm_input_dropout)\n        elif args['constituent_stack'] == StackHistory.ATTN:\n            self.constituent_stack = TransformerTreeStack(input_size=self.hidden_size,\n                                                          output_size=self.hidden_size,\n                                                          input_dropout=self.lstm_input_dropout,\n                                                          use_position=True,\n                                                          num_heads=args['constituent_heads'])\n        else:\n            raise ValueError(\"Unhandled constituent_stack StackHistory: {}\".format(args['constituent_stack']))\n\n\n        if args['combined_dummy_embedding']:\n            self.dummy_embedding = self.constituent_open_embedding\n        else:\n            self.dummy_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),\n                                                embedding_dim = self.hidden_size)\n            nn.init.normal_(self.dummy_embedding.weight, std=0.2)\n        
self.register_buffer('constituent_open_tensors', torch.tensor(range(len(constituent_opens)), requires_grad=False))\n\n        # TODO: refactor\n        if (self.constituency_composition == ConstituencyComposition.BILSTM or\n            self.constituency_composition == ConstituencyComposition.BILSTM_MAX):\n            # forward and backward pieces for crunching several\n            # constituents into one, combined into a bi-lstm\n            # TODO: make the hidden size here an option?\n            self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)\n            # affine transformation from bi-lstm reduce to a new hidden layer\n            if self.constituency_composition == ConstituencyComposition.BILSTM:\n                self.reduce_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)\n                initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size * 2)\n            else:\n                self.reduce_forward = nn.Linear(self.hidden_size, self.hidden_size)\n                self.reduce_backward = nn.Linear(self.hidden_size, self.hidden_size)\n                initialize_linear(self.reduce_forward, self.args['nonlinearity'], self.hidden_size)\n                initialize_linear(self.reduce_backward, self.args['nonlinearity'], self.hidden_size)\n        elif self.constituency_composition == ConstituencyComposition.MAX:\n            # transformation to turn several constituents into one new constituent\n            self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)\n            initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)\n        elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:\n            # transformation to turn several constituents into one new constituent\n            self.register_parameter('reduce_linear_weight', 
torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, self.hidden_size, requires_grad=True)))\n            self.register_parameter('reduce_linear_bias', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, requires_grad=True)))\n            for layer_idx in range(len(constituent_opens)):\n                nn.init.kaiming_normal_(self.reduce_linear_weight[layer_idx], nonlinearity=self.args['nonlinearity'])\n            nn.init.uniform_(self.reduce_linear_bias, 0, 1 / (self.hidden_size * 2) ** 0.5)\n        elif self.constituency_composition == ConstituencyComposition.BIGRAM:\n            self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)\n            self.reduce_bigram = nn.Linear(self.hidden_size * 2, self.hidden_size)\n            initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)\n            initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)\n        elif self.constituency_composition == ConstituencyComposition.ATTN:\n            self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)\n        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:\n            if self.args['reduce_position']:\n                # unsaved module so that if it grows, we don't save\n                # the larger version unnecessarily\n                # under any normal circumstances, the growth will\n                # happen early in training when the model is not\n                # behaving well, then will not be needed once the\n                # model learns not to make super degenerate\n                # constituents\n                self.add_unsaved_module(\"reduce_position\", ConcatSinusoidalEncoding(self.args['reduce_position'], 50))\n            else:\n                self.add_unsaved_module(\"reduce_position\", nn.Identity())\n            self.reduce_query 
= nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size, bias=False)\n            self.reduce_value = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size)\n            if self.constituency_composition == ConstituencyComposition.KEY:\n                self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))\n            else:\n                self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(len(constituent_opens), self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))\n        elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:\n            self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)\n        elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:\n            self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,\n                                                             embedding_dim = self.num_tree_lstm_layers * self.hidden_size)\n            self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)\n        else:\n            raise ValueError(\"Unhandled ConstituencyComposition: {}\".format(self.constituency_composition))\n\n        self.nonlinearity = build_nonlinearity(self.args['nonlinearity'])\n\n        # matrix for predicting the next transition using word/constituent/transition queues\n        # word size + constituency size + transition size\n        # TODO: .get() is only necessary until all models rebuilt with this param\n        self.maxout_k = self.args.get('maxout_k', 0)\n        self.output_layers = self.build_output_layers(self.args['num_output_layers'], 
len(transitions), self.maxout_k)\n\n    @staticmethod\n    def uses_lattn(args):\n        return args.get('use_lattn', True) and args.get('lattn_d_proj', 0) > 0 and args.get('lattn_d_l', 0) > 0\n\n    @staticmethod\n    def uses_pattn(args):\n        return args['pattn_num_heads'] > 0 and args['pattn_num_layers'] > 0\n\n    def copy_with_new_structure(self, other):\n        \"\"\"\n        Copy parameters from the other model to this model\n\n        word_lstm can change size if the other model didn't use pattn / lattn and this one does.\n        In that case, the new values are initialized to 0.\n        This will rebuild the model in such a way that the outputs will be\n        exactly the same as the previous model.\n        \"\"\"\n        if self.constituency_composition != other.constituency_composition and self.constituency_composition != ConstituencyComposition.UNTIED_MAX:\n            raise ValueError(\"Models are incompatible: self.constituency_composition == {}, other.constituency_composition == {}\".format(self.constituency_composition, other.constituency_composition))\n        for name, other_parameter in other.named_parameters():\n            # this allows other.constituency_composition == UNTIED_MAX to fall through\n            if name.startswith('reduce_linear.') and self.constituency_composition == ConstituencyComposition.UNTIED_MAX:\n                if name == 'reduce_linear.weight':\n                    my_parameter = self.reduce_linear_weight\n                elif name == 'reduce_linear.bias':\n                    my_parameter = self.reduce_linear_bias\n                else:\n                    raise ValueError(\"Unexpected other parameter name {}\".format(name))\n                for idx in range(len(self.constituent_opens)):\n                    my_parameter[idx].data.copy_(other_parameter.data)\n            elif name.startswith('word_lstm.weight_ih_l0'):\n                # bottom layer shape may have changed from adding a new pattn / lattn 
block\n                my_parameter = self.get_parameter(name)\n                # -1 so that it can be converted easier to a different parameter\n                copy_size = min(other_parameter.data.shape[-1], my_parameter.data.shape[-1])\n                #new_values = my_parameter.data.clone().detach()\n                new_values = torch.zeros_like(my_parameter.data)\n                new_values[..., :copy_size] = other_parameter.data[..., :copy_size]\n                my_parameter.data.copy_(new_values)\n            else:\n                try:\n                    self.get_parameter(name).data.copy_(other_parameter.data)\n                except AttributeError as e:\n                    raise AttributeError(\"Could not process %s\" % name) from e\n\n    def build_output_layers(self, num_output_layers, final_layer_size, maxout_k):\n        \"\"\"\n        Build a ModuleList of Linear transformations for the given num_output_layers\n\n        The final layer size can be specified.\n        Initial layer size is the combination of word, constituent, and transition vectors\n        Middle layer sizes are self.hidden_size\n        \"\"\"\n        middle_layers = num_output_layers - 1\n        # word_lstm:         hidden_size * num_tree_lstm_layers\n        # transition_stack:  transition_hidden_size\n        # constituent_stack: hidden_size\n        predict_input_size = [self.hidden_size + self.hidden_size * self.num_tree_lstm_layers + self.transition_hidden_size] + [self.hidden_size] * middle_layers\n        predict_output_size = [self.hidden_size] * middle_layers + [final_layer_size]\n        if not maxout_k:\n            output_layers = nn.ModuleList([nn.Linear(input_size, output_size)\n                                           for input_size, output_size in zip(predict_input_size, predict_output_size)])\n            for output_layer, input_size in zip(output_layers, predict_input_size):\n                initialize_linear(output_layer, self.args['nonlinearity'], 
input_size)\n        else:\n            output_layers = nn.ModuleList([MaxoutLinear(input_size, output_size, maxout_k)\n                                           for input_size, output_size in zip(predict_input_size, predict_output_size)])\n        return output_layers\n\n    def num_words_known(self, words):\n        return sum(word in self.vocab_map or word.lower() in self.vocab_map for word in words)\n\n    @property\n    def retag_method(self):\n        # TODO: make the method an enum\n        return self.args['retag_method']\n\n    def uses_xpos(self):\n        return self.args['retag_package'] is not None and self.args['retag_method'] == 'xpos'\n\n    def add_unsaved_module(self, name, module):\n        \"\"\"\n        Adds a module which will not be saved to disk\n\n        Best used for large models such as pretrained word embeddings\n        \"\"\"\n        self.unsaved_modules += [name]\n        setattr(self, name, module)\n        if module is not None and name in ('forward_charlm', 'backward_charlm'):\n            for _, parameter in module.named_parameters():\n                parameter.requires_grad = False\n\n    def is_unsaved_module(self, name):\n        return name.split('.')[0] in self.unsaved_modules\n\n    def get_norms(self):\n        lines = []\n        skip = set()\n        if self.constituency_composition == ConstituencyComposition.UNTIED_MAX:\n            skip = {'reduce_linear_weight', 'reduce_linear_bias'}\n            lines.append(\"reduce_linear:\")\n            for c_idx, c_open in enumerate(self.constituent_opens):\n                lines.append(\"  %s weight %.6g bias %.6g\" % (c_open, torch.norm(self.reduce_linear_weight[c_idx]).item(), torch.norm(self.reduce_linear_bias[c_idx]).item()))\n        active_params = [(name, param) for name, param in self.named_parameters() if param.requires_grad and name not in skip]\n        if len(active_params) == 0:\n            return lines\n        print(len(active_params))\n\n        max_name_len 
= max(len(name) for name, param in active_params)\n        max_norm_len = max(len(\"%.6g\" % torch.norm(param).item()) for name, param in active_params)\n        format_string = \"%-\" + str(max_name_len) + \"s   norm %\" + str(max_norm_len) + \"s  zeros %d / %d\"\n        for name, param in active_params:\n            zeros = torch.sum(param.abs() < 0.000001).item()\n            norm = \"%.6g\" % torch.norm(param).item()\n            lines.append(format_string % (name, norm, zeros, param.nelement()))\n        return lines\n\n    def log_norms(self):\n        lines = [\"NORMS FOR MODEL PARAMETERS\"]\n        lines.extend(self.get_norms())\n        logger.info(\"\\n\".join(lines))\n\n    def log_shapes(self):\n        lines = [\"NORMS FOR MODEL PARAMETERS\"]\n        for name, param in self.named_parameters():\n            if param.requires_grad:\n                lines.append(\"{} {}\".format(name, param.shape))\n        logger.info(\"\\n\".join(lines))\n\n    def initial_word_queues(self, tagged_word_lists):\n        \"\"\"\n        Produce initial word queues out of the model's LSTMs for use in the tagged word lists.\n\n        Operates in a batched fashion to reduce the runtime for the LSTM operations\n        \"\"\"\n        device = next(self.parameters()).device\n\n        vocab_map = self.vocab_map\n        def map_word(word):\n            idx = vocab_map.get(word, None)\n            if idx is not None:\n                return idx\n            return vocab_map.get(word.lower(), UNK_ID)\n\n        all_word_inputs = []\n        all_word_labels = [[word.children[0].label for word in tagged_words]\n                           for tagged_words in tagged_word_lists]\n\n        for sentence_idx, tagged_words in enumerate(tagged_word_lists):\n            word_labels = all_word_labels[sentence_idx]\n            word_idx = torch.stack([self.vocab_tensors[map_word(word.children[0].label)] for word in tagged_words])\n            word_input = self.embedding(word_idx)\n\n   
         # this occasionally learns UNK at train time\n            if self.training:\n                delta_labels = [None if word in self.rare_words and random.random() < self.args['rare_word_unknown_frequency'] else word\n                                for word in word_labels]\n            else:\n                delta_labels = word_labels\n            delta_idx = torch.stack([self.delta_tensors[self.delta_word_map.get(word, UNK_ID)] for word in delta_labels])\n\n            delta_input = self.delta_embedding(delta_idx)\n            word_inputs = [word_input, delta_input]\n\n            if self.tag_embedding_dim > 0:\n                if self.training:\n                    tag_labels = [None if random.random() < self.args['tag_unknown_frequency'] else word.label for word in tagged_words]\n                else:\n                    tag_labels = [word.label for word in tagged_words]\n                tag_idx = torch.stack([self.tag_tensors[self.tag_map.get(tag, UNK_ID)] for tag in tag_labels])\n                tag_input = self.tag_embedding(tag_idx)\n                word_inputs.append(tag_input)\n\n            all_word_inputs.append(word_inputs)\n\n        if self.forward_charlm is not None:\n            all_forward_chars = self.forward_charlm.build_char_representation(all_word_labels)\n            for word_inputs, forward_chars in zip(all_word_inputs, all_forward_chars):\n                word_inputs.append(forward_chars)\n        if self.backward_charlm is not None:\n            all_backward_chars = self.backward_charlm.build_char_representation(all_word_labels)\n            for word_inputs, backward_chars in zip(all_word_inputs, all_backward_chars):\n                word_inputs.append(backward_chars)\n\n        all_word_inputs = [torch.cat(word_inputs, dim=1) for word_inputs in all_word_inputs]\n        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:\n            word_start = self.word_start_embedding.unsqueeze(0)\n            word_end = 
self.word_end_embedding.unsqueeze(0)\n            all_word_inputs = [torch.cat([word_start, word_inputs, word_end], dim=0) for word_inputs in all_word_inputs]\n\n        if self.bert_model is not None:\n            # BERT embedding extraction\n            # result will be len+2 for each sentence\n            # we will take 1:-1 if we don't care about the endpoints\n            bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,\n                                                      keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,\n                                                      num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,\n                                                      detach=not self.args['bert_finetune'] and not self.args['stage1_bert_finetune'],\n                                                      peft_name=self.peft_name)\n            if self.bert_layer_mix is not None:\n                # add the average so that the default behavior is to\n                # take an average of the N layers, and anything else\n                # other than that needs to be learned\n                bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]\n\n            all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]\n\n        # Extract partitioned representation\n        if self.partitioned_transformer_module is not None:\n            partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)\n            all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, partitioned_embeddings)]\n\n        # Extract Labeled Representation\n        if self.label_attention_module is not None:\n            if 
self.args['lattn_combined_input']:\n                labeled_representations = self.label_attention_module(all_word_inputs, tagged_word_lists)\n            else:\n                labeled_representations = self.label_attention_module(partitioned_embeddings, tagged_word_lists)\n            all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, labeled_representations)]\n\n        if self.rel_attn_forward is not None or self.rel_attn_reverse is not None:\n            rattn_inputs = [[x] for x in all_word_inputs]\n\n            if self.rel_attn_forward is not None:\n                if self.args['rattn_use_endpoint_sinks']:\n                    rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0), x[0][0]).squeeze(0)] for x in rattn_inputs]\n                else:\n                    rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]\n            if self.rel_attn_reverse is not None:\n                if self.args['rattn_use_endpoint_sinks']:\n                    rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0), x[0][-1]).squeeze(0)] for x in rattn_inputs]\n                else:\n                    rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]\n\n            if self.args['rattn_cat']:\n                all_word_inputs = [torch.cat(x, axis=1) for x in rattn_inputs]\n            else:\n                rattn_inputs = [torch.stack(x, axis=2) for x in rattn_inputs]\n                all_word_inputs = [torch.sum(x, axis=2) for x in rattn_inputs]\n\n        all_word_inputs = [self.word_dropout(word_inputs) for word_inputs in all_word_inputs]\n        packed_word_input = torch.nn.utils.rnn.pack_sequence(all_word_inputs, enforce_sorted=False)\n        word_output, _ = self.word_lstm(packed_word_input)\n        # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear\n        # word_output 
will now be sentence x batch x 2*hidden_size\n        word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)\n        # now sentence x batch x hidden_size\n\n        word_queues = []\n        for sentence_idx, tagged_words in enumerate(tagged_word_lists):\n            if self.sentence_boundary_vectors is not SentenceBoundary.NONE:\n                sentence_output = word_output[:len(tagged_words)+2, sentence_idx, :]\n            else:\n                sentence_output = word_output[:len(tagged_words), sentence_idx, :]\n            sentence_output = self.word_to_constituent(sentence_output)\n            sentence_output = self.nonlinearity(sentence_output)\n            # TODO: this makes it so constituents downstream are\n            # build with the outputs of the LSTM, not the word\n            # embeddings themselves.  It is possible we want to\n            # transform the word_input to hidden_size in some way\n            # and use that instead\n            if self.sentence_boundary_vectors is not SentenceBoundary.NONE:\n                word_queue =  [WordNode(None, sentence_output[0, :])]\n                word_queue += [WordNode(tag_node, sentence_output[idx+1, :])\n                               for idx, tag_node in enumerate(tagged_words)]\n                word_queue.append(WordNode(None, sentence_output[len(tagged_words)+1, :]))\n            else:\n                word_queue =  [WordNode(None, self.word_zeros)]\n                word_queue += [WordNode(tag_node, sentence_output[idx, :])\n                                   for idx, tag_node in enumerate(tagged_words)]\n                word_queue.append(WordNode(None, self.word_zeros))\n\n            if self.reverse_sentence:\n                word_queue = list(reversed(word_queue))\n            word_queues.append(word_queue)\n\n        return word_queues\n\n    def initial_transitions(self):\n        \"\"\"\n        Return an initial TreeStack with no transitions\n        \"\"\"\n      
  return self.transition_stack.initial_state()\n\n    def initial_constituents(self):\n        \"\"\"\n        Return an initial TreeStack with no constituents\n        \"\"\"\n        return self.constituent_stack.initial_state(Constituent(None, self.constituent_zeros, self.constituent_zeros))\n\n    def get_word(self, word_node):\n        return word_node.value\n\n    def transform_word_to_constituent(self, state):\n        word_node = state.get_word(state.word_position)\n        word = word_node.value\n        if self.constituency_composition == ConstituencyComposition.TREE_LSTM:\n            return Constituent(word, word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size), self.word_zeros.view(self.num_tree_lstm_layers, self.hidden_size))\n        elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:\n            # the UNK tag will be trained thanks to occasionally dropping out tags\n            tag = word.label\n            tree_hx = word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size)\n            tag_tensor = self.tag_tensors[self.tag_map.get(tag, UNK_ID)]\n            tree_cx = self.constituent_reduce_embedding(tag_tensor)\n            tree_cx = tree_cx.view(self.num_tree_lstm_layers, self.hidden_size)\n            return Constituent(word, tree_hx, tree_cx * tree_hx)\n        else:\n            return Constituent(word, word_node.hx[:self.hidden_size].unsqueeze(0), None)\n\n    def dummy_constituent(self, dummy):\n        label = dummy.label\n        open_index = self.constituent_open_tensors[self.constituent_open_map[label]]\n        hx = self.dummy_embedding(open_index)\n        # the cx doesn't matter: the dummy will be discarded when building a new constituent\n        return Constituent(dummy, hx.unsqueeze(0), None)\n\n    def build_constituents(self, labels, children_lists):\n        \"\"\"\n        Build new constituents with the given label from the list of children\n\n        labels is a list of labels for each 
of the new nodes to construct\n        children_lists is a list of children that go under each of the new nodes\n        lists of each are used so that we can stack operations\n        \"\"\"\n        # at the end of each of these operations, we expect lstm_hx.shape\n        # is (L, N, hidden_size) for N lists of children\n        if (self.constituency_composition == ConstituencyComposition.BILSTM or\n            self.constituency_composition == ConstituencyComposition.BILSTM_MAX):\n            node_hx = [[child.value.tree_hx.squeeze(0) for child in children] for children in children_lists]\n            label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]\n\n            max_length = max(len(children) for children in children_lists)\n            zeros = torch.zeros(self.hidden_size, device=label_hx[0].device)\n            # weirdly, this is faster than using pack_sequence\n            unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx)) for lhx, nhx in zip(label_hx, node_hx)]\n            unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx]\n            packed_hx = torch.stack(unpacked_hx, axis=1)\n            packed_hx = torch.nn.utils.rnn.pack_padded_sequence(packed_hx, [len(x)+2 for x in children_lists], enforce_sorted=False)\n            lstm_output = self.constituent_reduce_lstm(packed_hx)\n            # take just the output of the final layer\n            #   result of lstm is output, (hx, cx)\n            #   so [1][0] gets hx\n            #      [1][0][-1] is the final output\n            # will be shape len(children_lists) * 2, hidden_size for bidirectional\n            # where forward outputs are -2 and backwards are -1\n            if self.constituency_composition == ConstituencyComposition.BILSTM:\n                lstm_output = lstm_output[1][0]\n                forward_hx = lstm_output[-2, :, :]\n                backward_hx = 
lstm_output[-1, :, :]\n                hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1))\n            else:\n                lstm_output, lstm_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_output[0])\n                lstm_output = [lstm_output[1:length-1, x, :] for x, length in zip(range(len(lstm_lengths)), lstm_lengths)]\n                lstm_output = torch.stack([torch.max(x, 0).values for x in lstm_output], axis=0)\n                hx = self.reduce_forward(lstm_output[:, :self.hidden_size]) + self.reduce_backward(lstm_output[:, self.hidden_size:])\n            lstm_hx = self.nonlinearity(hx).unsqueeze(0)\n            lstm_cx = None\n        elif self.constituency_composition == ConstituencyComposition.MAX:\n            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]\n            unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]\n            packed_hx = torch.stack(unpacked_hx, axis=1)\n            hx = self.reduce_linear(packed_hx)\n            lstm_hx = self.nonlinearity(hx)\n            lstm_cx = None\n        elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:\n            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]\n            unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]\n            # shape == len(labels),1,hidden_size after the stack\n            #packed_hx = torch.stack(unpacked_hx, axis=0)\n            label_indices = [self.constituent_open_map[label] for label in labels]\n            # we would like to stack the reduce_linear_weight calculations as follows:\n            #reduce_weight = self.reduce_linear_weight[label_indices]\n            #reduce_bias = self.reduce_linear_bias[label_indices]\n            # this would allow for faster vectorized operations.\n            # however, this runs out of memory on larger training 
examples,\n            # presumably because there are too many stacks in a row and each one\n            # has its own gradient kept for the entire calculation\n            # fortunately, this operation is not a huge part of the expense\n            hx = [torch.matmul(self.reduce_linear_weight[label_idx], hx_layer.squeeze(0)) + self.reduce_linear_bias[label_idx]\n                  for label_idx, hx_layer in zip(label_indices, unpacked_hx)]\n            hx = torch.stack(hx, axis=0)\n            hx = hx.unsqueeze(0)\n            lstm_hx = self.nonlinearity(hx)\n            lstm_cx = None\n        elif self.constituency_composition == ConstituencyComposition.BIGRAM:\n            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]\n            unpacked_hx = []\n            for nhx in node_hx:\n                # tanh or otherwise limit the size of the output?\n                stacked_nhx = self.lstm_input_dropout(torch.cat(nhx, axis=0))\n                if stacked_nhx.shape[0] > 1:\n                    bigram_hx = torch.cat((stacked_nhx[:-1, :], stacked_nhx[1:, :]), axis=1)\n                    bigram_hx = self.reduce_bigram(bigram_hx) / 2\n                    stacked_nhx = torch.cat((stacked_nhx, bigram_hx), axis=0)\n                unpacked_hx.append(torch.max(stacked_nhx, 0).values)\n            packed_hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)\n            hx = self.reduce_linear(packed_hx)\n            lstm_hx = self.nonlinearity(hx)\n            lstm_cx = None\n        elif self.constituency_composition == ConstituencyComposition.ATTN:\n            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]\n            label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]\n            unpacked_hx = [torch.stack(nhx) for nhx in node_hx]\n            unpacked_hx = [torch.cat((lhx.unsqueeze(0).unsqueeze(0), nhx), axis=0) 
for lhx, nhx in zip(label_hx, unpacked_hx)]\n            unpacked_hx = [self.reduce_attn(nhx, nhx, nhx)[0].squeeze(1) for nhx in unpacked_hx]\n            unpacked_hx = [self.lstm_input_dropout(torch.max(nhx, 0).values) for nhx in unpacked_hx]\n            hx = torch.stack(unpacked_hx, axis=0)\n            lstm_hx = self.nonlinearity(hx).unsqueeze(0)\n            lstm_cx = None\n        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:\n            node_hx = [torch.stack([child.value.tree_hx for child in children]) for children in children_lists]\n            # add a position vector to each node_hx\n            node_hx = [self.reduce_position(x.reshape(x.shape[0], -1)) for x in node_hx]\n            query_hx = [self.reduce_query(nhx) for nhx in node_hx]\n            # reshape query for MHA\n            query_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in query_hx]\n            if self.constituency_composition == ConstituencyComposition.KEY:\n                queries = [torch.matmul(nhx, self.reduce_key) for nhx in query_hx]\n            else:\n                label_indices = [self.constituent_open_map[label] for label in labels]\n                queries = [torch.matmul(nhx, self.reduce_key[label_idx]) for nhx, label_idx in zip(query_hx, label_indices)]\n            # softmax each head\n            weights = [torch.nn.functional.softmax(nhx, dim=1).transpose(1, 2) for nhx in queries]\n            value_hx = [self.reduce_value(nhx) for nhx in node_hx]\n            value_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in value_hx]\n            # use the softmaxes to add up the heads\n            unpacked_hx = [torch.matmul(weight, nhx).squeeze(1) for weight, nhx in zip(weights, value_hx)]\n            unpacked_hx = [nhx.reshape(-1) for nhx in unpacked_hx]\n            hx = torch.stack(unpacked_hx, 
axis=0).unsqueeze(0)\n            lstm_hx = self.nonlinearity(hx)\n            lstm_cx = None\n        elif self.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):\n            label_hx = [self.lstm_input_dropout(self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]])) for label in labels]\n            label_hx = torch.stack(label_hx).unsqueeze(0)\n\n            max_length = max(len(children) for children in children_lists)\n\n            # stacking will let us do elementwise multiplication faster, hopefully\n            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]\n            unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in node_hx]\n            unpacked_hx = [nhx.max(dim=0) for nhx in unpacked_hx]\n            packed_hx = torch.stack([nhx.values for nhx in unpacked_hx], axis=1)\n            #packed_hx = packed_hx.max(dim=0).values\n\n            node_cx = [torch.stack([child.value.tree_cx for child in children]) for children in children_lists]\n            node_cx_indices = [uhx.indices.unsqueeze(0) for uhx in unpacked_hx]\n            unpacked_cx = [ncx.gather(0, nci).squeeze(0) for ncx, nci in zip(node_cx, node_cx_indices)]\n            packed_cx = torch.stack(unpacked_cx, axis=1)\n\n            _, (lstm_hx, lstm_cx) = self.constituent_reduce_lstm(label_hx, (packed_hx, packed_cx))\n        else:\n            raise ValueError(\"Unhandled ConstituencyComposition: {}\".format(self.constituency_composition))\n\n        constituents = []\n        for idx, (label, children) in enumerate(zip(labels, children_lists)):\n            children = [child.value.value for child in children]\n            if isinstance(label, str):\n                node = Tree(label=label, children=children)\n            else:\n                for value in reversed(label):\n                    node = Tree(label=value, children=children)\n      
              children = node\n            constituents.append(Constituent(node, lstm_hx[:, idx, :], lstm_cx[:, idx, :] if lstm_cx is not None else None))\n        return constituents\n\n    def push_constituents(self, constituent_stacks, constituents):\n        # Another possibility here would be to use output[0, i, :]\n        # from the constituency lstm for the value of the new node.\n        # This might theoretically make the new constituent include\n        # information from neighboring constituents.  However, this\n        # lowers the scores of various models.\n        # For example, an experiment on ja_alt built this way,\n        # averaged over 5 trials, had the following loss in accuracy:\n        # 150 epochs: 0.8971 to 0.8953\n        # 200 epochs: 0.8985 to 0.8964\n        current_nodes = [stack.value for stack in constituent_stacks]\n\n        constituent_input = torch.stack([x.tree_hx[-1:] for x in constituents], axis=1)\n        #constituent_input = constituent_input.unsqueeze(0)\n        # the constituents are already Constituent(tree, tree_hx, tree_cx)\n        return self.constituent_stack.push_states(constituent_stacks, constituents, constituent_input)\n\n    def get_top_constituent(self, constituents):\n        \"\"\"\n        Extract only the top constituent from a state's constituent\n        sequence, even though it has multiple additional pieces of\n        information\n        \"\"\"\n        # TreeStack value -> LSTMTreeStack value -> Constituent value -> constituent\n        return constituents.value.value.value\n\n    def push_transitions(self, transition_stacks, transitions):\n        \"\"\"\n        Push all of the given transitions on to the stack as a batch operation.\n\n        Significantly faster than doing one transition at a time.\n        \"\"\"\n        transition_idx = torch.stack([self.transition_tensors[self.transition_map[transition]] for transition in transitions])\n        transition_input = 
self.transition_embedding(transition_idx).unsqueeze(0)\n        return self.transition_stack.push_states(transition_stacks, transitions, transition_input)\n\n    def get_top_transition(self, transitions):\n        \"\"\"\n        Extract only the top transition from a state's transition\n        sequence, even though it has multiple additional pieces of\n        information\n        \"\"\"\n        # TreeStack value -> LSTMTreeStack value -> transition\n        return transitions.value.value\n\n    def forward(self, states):\n        \"\"\"\n        Return logits for a prediction of what transition to make next\n\n        We've basically done all the work analyzing the state as\n        part of applying the transitions, so this method is very simple\n\n        return shape: (num_states, num_transitions)\n        \"\"\"\n        word_hx = torch.stack([state.get_word(state.word_position).hx for state in states])\n        transition_hx = torch.stack([self.transition_stack.output(state.transitions) for state in states])\n        # this .output() is the output of the constituent stack, not the\n        # constituent itself\n        # this way, we can, as an option, NOT include the constituents to the left\n        # when building the current vector for a constituent\n        # and the vector used for inference will still incorporate the entire LSTM\n        constituent_hx = torch.stack([self.constituent_stack.output(state.constituents) for state in states])\n\n        hx = torch.cat((word_hx, transition_hx, constituent_hx), axis=1)\n        for idx, output_layer in enumerate(self.output_layers):\n            hx = self.predict_dropout(hx)\n            # TODO: why self.output_layers - 1?\n            if not self.maxout_k and idx < len(self.output_layers) - 1:\n                hx = self.nonlinearity(hx)\n            hx = output_layer(hx)\n        return hx\n\n    def predict(self, states, is_legal=True):\n        \"\"\"\n        Generate and return predictions, along with 
the transitions those predictions represent\n\n        If is_legal is set to True, will only return legal transitions.\n        This means returning None if there are no legal transitions.\n        Hopefully the constraints prevent that from happening\n\n        Returns:\n          tensor(batch_size, num_transitions) - final output layer\n          list(Transition) - predicted transitions\n          tensor(batch_size) - the final output specifically for the chosen transition\n        \"\"\"\n        predictions = self.forward(states)\n        pred_max = torch.argmax(predictions, dim=1)\n        scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)\n        pred_max = pred_max.detach().cpu()\n\n        pred_trans = [self.transitions[pred_max[idx]] for idx in range(len(states))]\n        if is_legal:\n            for idx, (state, trans) in enumerate(zip(states, pred_trans)):\n                if not trans.is_legal(state, self):\n                    _, indices = predictions[idx, :].sort(descending=True)\n                    for index in indices:\n                        if self.transitions[index].is_legal(state, self):\n                            pred_trans[idx] = self.transitions[index]\n                            scores[idx] = predictions[idx, index]\n                            break\n                    else: # yeah, else on a for loop, deal with it\n                        pred_trans[idx] = None\n                        scores[idx] = None\n\n        return predictions, pred_trans, scores.squeeze(1)\n\n    def weighted_choice(self, states):\n        \"\"\"\n        Generate and return predictions, and randomly choose a prediction weighted by the scores\n\n        TODO: pass in a temperature\n        \"\"\"\n        predictions = self.forward(states)\n        pred_trans = []\n        all_scores = []\n        for state, prediction in zip(states, predictions):\n            legal_idx = [idx for idx in range(prediction.shape[0]) if 
self.transitions[idx].is_legal(state, self)]\n            if len(legal_idx) == 0:\n                pred_trans.append(None)\n                continue\n            scores = prediction[legal_idx]\n            scores = torch.softmax(scores, dim=0)\n            idx = torch.multinomial(scores, 1)\n            idx = legal_idx[idx]\n            pred_trans.append(self.transitions[idx])\n            all_scores.append(prediction[idx])\n        all_scores = torch.stack(all_scores)\n        return predictions, pred_trans, all_scores\n\n    def predict_gold(self, states):\n        \"\"\"\n        For each State, return the next item in the gold_sequence\n        \"\"\"\n        predictions = self.forward(states)\n        transitions = [y.gold_sequence[y.num_transitions] for y in states]\n        indices = torch.tensor([self.transition_map[t] for t in transitions], device=predictions.device)\n        scores = torch.take_along_dim(predictions, indices.unsqueeze(1), dim=1)\n        return predictions, transitions, scores.squeeze(1)\n\n    def get_params(self, skip_modules=True):\n        \"\"\"\n        Get a dictionary for saving the model\n        \"\"\"\n        model_state = self.state_dict()\n        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file\n        if skip_modules:\n            skipped = [k for k in model_state.keys() if self.is_unsaved_module(k)]\n            for k in skipped:\n                del model_state[k]\n        config = copy.deepcopy(self.args)\n        config['sentence_boundary_vectors'] = config['sentence_boundary_vectors'].name\n        config['constituency_composition'] = config['constituency_composition'].name\n        config['transition_stack'] = config['transition_stack'].name\n        config['constituent_stack'] = config['constituent_stack'].name\n        config['transition_scheme'] = config['transition_scheme'].name\n        assert isinstance(self.rare_words, set)\n        params = {\n    
        'model': model_state,\n            'model_type': \"LSTM\",\n            'config': config,\n            'transitions': [repr(x) for x in self.transitions],\n            'constituents': self.constituents,\n            'tags': self.tags,\n            'words': self.delta_words,\n            'rare_words': list(self.rare_words),\n            'root_labels': self.root_labels,\n            'constituent_opens': self.constituent_opens,\n            'unary_limit': self.unary_limit(),\n        }\n\n        return params\n\n"
  },
  {
    "path": "stanza/models/constituency/lstm_tree_stack.py",
    "content": "\"\"\"\nKeeps an LSTM in TreeStack form.\n\nThe TreeStack nodes keep the hx and cx for the LSTM, along with a\n\"value\" which represents whatever the user needs to store.\n\nThe TreeStacks can be popped to get back to the previous LSTM state.\n\nThe module itself implements three methods: initial_state, push_states, output\n\"\"\"\n\nfrom collections import namedtuple\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.constituency.tree_stack import TreeStack\n\nNode = namedtuple(\"Node\", ['value', 'lstm_hx', 'lstm_cx'])\n\nclass LSTMTreeStack(nn.Module):\n    def __init__(self, input_size, hidden_size, num_lstm_layers, dropout, uses_boundary_vector, input_dropout):\n        \"\"\"\n        Prepare LSTM and parameters\n\n        input_size: dimension of the inputs to the LSTM\n        hidden_size: LSTM internal & output dimension\n        num_lstm_layers: how many layers of LSTM to use\n        dropout: value of the LSTM dropout\n        uses_boundary_vector: if set, learn a start_embedding parameter.  otherwise, use zeros\n        input_dropout: an nn.Module to dropout inputs.  
TODO: allow a float parameter as well\n        \"\"\"\n        super().__init__()\n\n        self.uses_boundary_vector = uses_boundary_vector\n\n        # The start embedding needs to be input_size as we put it through the LSTM\n        if uses_boundary_vector:\n            self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))\n        else:\n            self.register_buffer('input_zeros',  torch.zeros(num_lstm_layers, 1, input_size))\n            self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))\n\n        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)\n        self.input_dropout = input_dropout\n\n\n    def initial_state(self, initial_value=None):\n        \"\"\"\n        Return an initial state, either based on zeros or based on the initial embedding and LSTM\n\n        Note that LSTM start operation is already batched, in a sense\n        The subsequent batch built this way will be used for batch_size trees\n\n        Returns a stack with None value, hx & cx either based on the\n        start_embedding or zeros, and no parent.\n        \"\"\"\n        if self.uses_boundary_vector:\n            start = self.start_embedding.unsqueeze(0).unsqueeze(0)\n            output, (hx, cx) = self.lstm(start)\n            start = output[0, 0, :]\n        else:\n            start = self.input_zeros\n            hx = self.hidden_zeros\n            cx = self.hidden_zeros\n        return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)\n\n    def push_states(self, stacks, values, inputs):\n        \"\"\"\n        Starting from a list of current stacks, put the inputs through the LSTM and build new stack nodes.\n\n        B = stacks.len() = values.len()\n\n        inputs must be of shape 1 x B x input_size\n        \"\"\"\n        inputs = self.input_dropout(inputs)\n\n        hx = 
torch.cat([t.value.lstm_hx for t in stacks], axis=1)\n        cx = torch.cat([t.value.lstm_cx for t in stacks], axis=1)\n        output, (hx, cx) = self.lstm(inputs, (hx, cx))\n        new_stacks = [stack.push(Node(transition, hx[:, i:i+1, :], cx[:, i:i+1, :]))\n                      for i, (stack, transition) in enumerate(zip(stacks, values))]\n        return new_stacks\n\n    def output(self, stack):\n        \"\"\"\n        Return the last layer of the lstm_hx as the output from a stack\n\n        Refactored so that alternate structures have an easy way of getting the output\n        \"\"\"\n        return stack.value.lstm_hx[-1, 0, :]\n"
  },
  {
    "path": "stanza/models/constituency/parse_transitions.py",
    "content": "\"\"\"\nDefines a series of transitions (open a constituent, close a constituent, etc)\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nimport ast\nfrom collections import defaultdict\nfrom enum import Enum\nimport functools\nimport logging\n\nfrom stanza.models.constituency.parse_tree import Tree\n\nlogger = logging.getLogger('stanza')\n\nclass TransitionScheme(Enum):\n    def __new__(cls, value, short_name):\n        obj = object.__new__(cls)\n        obj._value_ = value\n        obj.short_name = short_name\n        return obj\n\n\n    # top down, so the open transition comes before any constituents\n    # score on vi_vlsp22 with 5 different sizes of bert layers,\n    # bert tagger, no silver dataset:\n    #   0.8171\n    TOP_DOWN           = 1, \"top\"\n    # unary transitions are modeled as one entire transition\n    # version that uses one transform per item,\n    # score on experiment described above:\n    #   0.8157\n    # score using one combination step for an entire transition:\n    #   0.8178\n    TOP_DOWN_COMPOUND  = 2, \"topc\"\n    # unary is a separate transition.  
doesn't help\n    # score on experiment described above:\n    #   0.8128\n    TOP_DOWN_UNARY     = 3, \"topu\"\n\n    # open transition comes after the first constituent it cares about\n    # score on experiment described above:\n    #   0.8205\n    # note that this is with an oracle, whereas IN_ORDER_COMPOUND does\n    # not have a dynamic oracle, so there may be room for improvement\n    IN_ORDER           = 4, \"in\"\n\n    # in order, with unaries after preterminals represented as a single\n    # transition after the preterminal\n    # and unaries elsewhere tied to the rest of the constituent\n    # score: 0.8186\n    IN_ORDER_COMPOUND  = 5, \"inc\"\n\n    # in order, with CompoundUnary on both preterminals and internal nodes\n    # score: 0.8166\n    IN_ORDER_UNARY     = 6, \"inu\"\n\n@functools.total_ordering\nclass Transition(ABC):\n    \"\"\"\n    model is passed in as a dependency injection\n    for example, an LSTM model can update hidden & output vectors when transitioning\n    \"\"\"\n    @abstractmethod\n    def update_state(self, state, model):\n        \"\"\"\n        update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent\n\n        the return value should be a tuple:\n          updated word_position\n          updated constituents\n          new constituent to put on the queue and None\n            - note that the constituent shouldn't be on the queue yet\n              that allows putting it on as a batch operation, which\n              saves a significant amount of time in an LSTM, for example\n          OR\n          data used to make a new constituent and the method used\n            - for example, CloseConstituent can return the children needed\n              and itself.  
this allows a batch operation to build\n              the constituent\n        \"\"\"\n\n    def delta_opens(self):\n        return 0\n\n    def apply(self, state, model):\n        \"\"\"\n        return a new State transformed via this transition\n\n        convenience method to call bulk_apply, which is significantly\n        faster than single operations for an NN based model\n        \"\"\"\n        update = model.bulk_apply([state], [self])\n        return update[0]\n\n    @abstractmethod\n    def is_legal(self, state, model):\n        \"\"\"\n        assess whether or not this transition is legal in this state\n\n        at parse time, the parser might choose a transition which cannot be made\n        \"\"\"\n\n    def components(self):\n        \"\"\"\n        Return a list of transitions which could theoretically make up this transition\n\n        For example, an Open transition with multiple labels would\n        return a list of Opens with those labels\n        \"\"\"\n        return [self]\n\n    @abstractmethod\n    def short_name(self):\n        \"\"\"\n        A short name to identify this transition\n        \"\"\"\n\n    def short_label(self):\n        if not hasattr(self, \"label\"):\n            return self.short_name()\n\n        if isinstance(self.label, str):\n            label = self.label\n        elif len(self.label) == 1:\n            label = self.label[0]\n        else:\n            label = self.label\n        return \"{}({})\".format(self.short_name(), label)\n\n    def __lt__(self, other):\n        # put the Shift at the front of a list, and otherwise sort alphabetically\n        if self == other:\n            return False\n        if isinstance(self, Shift):\n            return True\n        if isinstance(other, Shift):\n            return False\n        return str(self) < str(other)\n\n\n    @staticmethod\n    def from_repr(desc):\n        \"\"\"\n        This method is to avoid using eval() or otherwise trying to\n        deserialize 
strings in a possibly untrusted manner when\n        loading from a checkpoint\n        \"\"\"\n        if desc == 'Shift':\n            return Shift()\n        if desc == 'CloseConstituent':\n            return CloseConstituent()\n        labels = desc.split(\"(\", maxsplit=1)\n        if labels[0] not in ('CompoundUnary', 'OpenConstituent', 'Finalize'):\n            raise ValueError(\"Unknown Transition %s\" % desc)\n        if len(labels) == 1:\n            raise ValueError(\"Unexpected Transition repr, %s needs labels\" % labels[0])\n        if labels[1][-1] != ')':\n            raise ValueError(\"Expected Transition repr for %s: %s(labels)\" % (labels[0], labels[0]))\n        trans_type = labels[0]\n        labels = labels[1][:-1]\n        labels = ast.literal_eval(labels)\n        if trans_type == 'CompoundUnary':\n            return CompoundUnary(*labels)\n        if trans_type == 'OpenConstituent':\n            return OpenConstituent(*labels)\n        if trans_type == 'Finalize':\n            return Finalize(*labels)\n        raise ValueError(\"Unexpected Transition %s\" % desc)\n\nclass Shift(Transition):\n    def update_state(self, state, model):\n        \"\"\"\n        This will handle all aspects of a shift transition\n\n        - push the top element of the word queue onto constituents\n        - pop the top element of the word queue\n        \"\"\"\n        new_constituent = model.transform_word_to_constituent(state)\n        return state.word_position+1, state.constituents, new_constituent, None\n\n    def is_legal(self, state, model):\n        \"\"\"\n        Disallow shifting when the word queue is empty or there are no opens to eventually eat this word\n        \"\"\"\n        if state.empty_word_queue():\n            return False\n        if model.is_top_down:\n            # top down transition sequences cannot shift if there are currently no\n            # Open transitions on the stack.  
in such a case, the new constituent\n            # will never be reduced\n            if state.num_opens == 0:\n                return False\n            if state.num_opens == 1:\n                # there must be at least one transition, since there is an open\n                assert state.transitions.parent is not None\n                if state.transitions.parent.parent is None:\n                    # only one transition\n                    trans = model.get_top_transition(state.transitions)\n                    # must be an Open, since there is one open and one transitions\n                    # note that an S, FRAG, etc could happen if we're using unary\n                    # and ROOT-S is possible in the case of compound Open\n                    # in both cases, Shift is legal\n                    # Note that the corresponding problem of shifting after the ROOT-S\n                    # has been closed to just ROOT is handled in CloseConstituent\n                    if len(trans.label) == 1 and trans.top_label in model.root_labels:\n                        # don't shift a word at the very start of a parse\n                        # we want there to be an extra layer below ROOT\n                        return False\n        else:\n            # in-order k==1 (the only other option currently)\n            # can shift ONCE, but note that there is no way to consume\n            # two items in a row if there is no Open on the stack.\n            # As long as there is one or more open transitions,\n            # everything can be eaten\n            if state.num_opens == 0:\n                if not state.empty_constituents:\n                    return False\n        return True\n\n    def short_name(self):\n        return \"Shift\"\n\n    def __repr__(self):\n        return \"Shift\"\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if isinstance(other, Shift):\n            return True\n        return False\n\n    def 
__hash__(self):\n        return hash(37)\n\nclass CompoundUnary(Transition):\n    def __init__(self, *label):\n        # the FIRST label will be the top of the tree\n        # so CompoundUnary that results in root will have root as labels[0], for example\n        self.label = tuple(label)\n\n    def update_state(self, state, model):\n        \"\"\"\n        Apply potentially multiple unary transitions to the same preterminal\n\n        It reuses the CloseConstituent machinery\n        \"\"\"\n        # only the top constituent is meaningful here\n        constituents = state.constituents\n        children = [constituents.value]\n        constituents = constituents.pop()\n        # unlike with CloseConstituent, our label is not on the stack.\n        # it is just our label\n        # ... but we do reuse CloseConstituent's update mechanism\n        return state.word_position, constituents, (self.label, children), CloseConstituent\n\n    def is_legal(self, state, model):\n        \"\"\"\n        Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT\n        \"\"\"\n        # can't unary transition nothing\n        tree = model.get_top_constituent(state.constituents)\n        if tree is None:\n            return False\n        # don't unary transition a dummy, dummy\n        # and don't stack CompoundUnary transitions\n        if isinstance(model.get_top_transition(state.transitions), (CompoundUnary, OpenConstituent)):\n            return False\n        # if we are doing IN_ORDER_COMPOUND, then we are only using these\n        # transitions to model changes from a tag node to a sequence of\n        # unary nodes.  
can only occur at preterminals\n        if model.transition_scheme() is TransitionScheme.IN_ORDER_COMPOUND:\n            return tree.is_preterminal()\n        if model.transition_scheme() is not TransitionScheme.TOP_DOWN_UNARY:\n            return True\n\n        is_root = self.label[0] in model.root_labels\n        if not state.empty_word_queue() or not state.has_one_constituent():\n            return not is_root\n        else:\n            return is_root\n\n    def components(self):\n        return [CompoundUnary(label) for label in self.label]\n\n    def short_name(self):\n        return \"Unary\"\n\n    def __repr__(self):\n        return \"CompoundUnary(%s)\" % \",\".join(self.label)\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if not isinstance(other, CompoundUnary):\n            return False\n        if self.label == other.label:\n            return True\n        return False\n\n    def __hash__(self):\n        return hash(self.label)\n\nclass Dummy():\n    \"\"\"\n    Takes a space on the constituent stack to represent where an Open transition occurred\n    \"\"\"\n    def __init__(self, label):\n        self.label = label\n\n    def is_preterminal(self):\n        return False\n\n    def __format__(self, spec):\n        if spec is None or spec == '' or spec == 'O':\n            return \"(%s ...)\" % self.label\n        if spec == 'T':\n            return r\"\\\\Tree [.%s ? 
]\" % self.label\n        raise ValueError(\"Unhandled spec: %s\" % spec)\n\n    def __str__(self):\n        return \"Dummy({})\".format(self.label)\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if not isinstance(other, Dummy):\n            return False\n        if self.label == other.label:\n            return True\n        return False\n\n    def __hash__(self):\n        return hash(self.label)\n\ndef too_many_unary_nodes(tree, unary_limit):\n    \"\"\"\n    Return True iff there are UNARY_LIMIT unary nodes in a tree in a row\n\n    helps prevent infinite open/close patterns\n    otherwise, the model can get stuck in essentially an infinite loop\n    \"\"\"\n    if tree is None:\n        return False\n    for _ in range(unary_limit + 1):\n        if len(tree.children) != 1:\n            return False\n        tree = tree.children[0]\n    return True\n\nclass OpenConstituent(Transition):\n    def __init__(self, *label):\n        self.label = tuple(label)\n        self.top_label = self.label[0]\n\n    def delta_opens(self):\n        return 1\n\n    def update_state(self, state, model):\n        # open a new constituent which can later be closed\n        # puts a DUMMY constituent on the stack to mark where the constituents end\n        return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None\n\n    def is_legal(self, state, model):\n        \"\"\"\n        disallow based on the length of the sentence\n        \"\"\"\n        if state.num_opens > state.sentence_length + 10:\n            # fudge a bit so we don't miss root nodes etc in very small trees\n            # also there's one really deep tree in CTB 9.0\n            return False\n        if model.is_top_down:\n            # If the model is top down, you can't Open if there are\n            # no words to eventually eat\n            if state.empty_word_queue():\n                return False\n            # Also, you can only 
Open a ROOT iff it is at the root position\n            # The assumption in the unary scheme is there will be no\n            # root open transitions\n            if not model.has_unary_transitions():\n                # TODO: maybe cache this value if this is an expensive operation\n                is_root = self.top_label in model.root_labels\n                if is_root:\n                    return state.empty_transitions()\n                else:\n                    return not state.empty_transitions()\n        else:\n            # in-order nodes can Open as long as there is at least one thing\n            # on the constituency stack\n            # since closing the in-order involves removing one more\n            # item before the open, and it can close at any time\n            # (a close immediately after the open represents a unary)\n            if state.empty_constituents:\n                return False\n            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):\n                # consecutive Opens don't make sense in the context of in-order\n                return False\n            if not model.transition_scheme() is TransitionScheme.IN_ORDER:\n                # eg, IN_ORDER_UNARY or IN_ORDER_COMPOUND\n                # if compound unary opens are used\n                # or the unary transitions are via CompoundUnary\n                # can always open as long as the word queue isn't empty\n                # if the word queue is empty, only close is allowed\n                return not state.empty_word_queue()\n            # one other restriction - we assume all parse trees\n            # start with (ROOT (first_real_con ...))\n            # therefore ROOT can only occur via Open after everything\n            # else has been pushed and processed\n            # there are no further restrictions\n            is_root = self.top_label in model.root_labels\n            if is_root:\n                # can't make a root node if it will be 
in the middle of the parse\n                # can't make a root node if there's still words to eat\n                # note that the second assumption wouldn't work,\n                # except we are assuming there will never be multiple\n                # nodes under one root\n                return state.num_opens == 0 and state.empty_word_queue()\n            else:\n                if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents), model.unary_limit()):\n                    # looks like we've been in a loop of lots of unary transitions\n                    # note that we check `num_opens > 0` because otherwise we might wind up stuck\n                    # in a state where the only legal transition is open, such as if the\n                    # constituent stack is otherwise empty, but the open is illegal because\n                    # it causes too many unaries\n                    # in such a case we can forbid the corresponding close instead...\n                    # if empty_word_queue, that means it is trying to make infinitely many\n                    # non-ROOT Open transitions instead of just transitioning ROOT\n                    return False\n                return True\n        return True\n\n    def components(self):\n        return [OpenConstituent(label) for label in self.label]\n\n    def short_name(self):\n        return \"Open\"\n\n    def __repr__(self):\n        return \"OpenConstituent({})\".format(self.label)\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if not isinstance(other, OpenConstituent):\n            return False\n        if self.label == other.label:\n            return True\n        return False\n\n    def __hash__(self):\n        return hash(self.label)\n\nclass Finalize(Transition):\n    \"\"\"\n    Specifically applies at the end of a parse sequence to add a ROOT\n\n    Seemed like the simplest way to remove 
ROOT from the\n    in_order_compound transitions while still using the mechanism of\n    the transitions to build the parse tree\n    \"\"\"\n    def __init__(self, *label):\n        self.label = tuple(label)\n\n    def update_state(self, state, model):\n        \"\"\"\n        Wrap the single remaining constituent in this transition's label(s), such as ROOT\n\n        It reuses the CloseConstituent machinery\n        \"\"\"\n        # only the top constituent is meaningful here\n        constituents = state.constituents\n        children = [constituents.value]\n        constituents = constituents.pop()\n        # unlike with CloseConstituent, our label is not on the stack.\n        # it is just our label\n        label = self.label\n\n        # ... but we do reuse CloseConstituent's update\n        return state.word_position, constituents, (label, children), CloseConstituent\n\n    def is_legal(self, state, model):\n        \"\"\"\n        Legal if & only if there is one tree, no more words, and no ROOT yet\n        \"\"\"\n        return state.empty_word_queue() and state.has_one_constituent() and not state.finished(model)\n\n    def short_name(self):\n        return \"Finalize\"\n\n    def __repr__(self):\n        return \"Finalize(%s)\" % \",\".join(self.label)\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if not isinstance(other, Finalize):\n            return False\n        return other.label == self.label\n\n    def __hash__(self):\n        return hash((53, self.label))\n\nclass CloseConstituent(Transition):\n    def delta_opens(self):\n        return -1\n\n    def update_state(self, state, model):\n        # pop constituents until we are done\n        children = []\n        constituents = state.constituents\n        while not isinstance(model.get_top_constituent(constituents), Dummy):\n            # keep the entire value from the stack - the model may need\n            # the whole thing 
to transform the children into a new node\n            children.append(constituents.value)\n            constituents = constituents.pop()\n        # the Dummy has the label on it\n        label = model.get_top_constituent(constituents).label\n        # pop past the Dummy as well\n        constituents = constituents.pop()\n        if not model.is_top_down:\n            # the alternative to TOP_DOWN_... is IN_ORDER\n            # in which case we want to pop one more constituent\n            children.append(constituents.value)\n            constituents = constituents.pop()\n        # the children are in the opposite order of what we expect\n        children.reverse()\n\n        return state.word_position, constituents, (label, children), CloseConstituent\n\n    @staticmethod\n    def build_constituents(model, data):\n        \"\"\"\n        builds new constituents out of the incoming data\n\n        data is a list of tuples: (label, children)\n        the model will batch the build operation\n        again, the purpose of this batching is to do multiple deep learning operations at once\n        \"\"\"\n        labels, children_lists = map(list, zip(*data))\n        new_constituents = model.build_constituents(labels, children_lists)\n        return new_constituents\n\n\n    def is_legal(self, state, model):\n        \"\"\"\n        Disallow if there is no Open on the stack yet\n\n        in TOP_DOWN, if the previous transition was the Open (nothing built yet)\n        in IN_ORDER, previous transition does not matter, except for one small corner case\n        \"\"\"\n        if state.num_opens <= 0:\n            return False\n        if model.is_top_down:\n            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):\n                return False\n            if state.num_opens <= 1 and not state.empty_word_queue():\n                # don't close the last open until all words have been used\n                return False\n            if 
model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND:\n                # when doing TOP_DOWN_COMPOUND, we assume all transitions\n                # at the ROOT level have an S, SQ, FRAG, etc underneath\n                # this is checked when the model is first trained\n                if state.num_opens == 1 and not state.empty_word_queue():\n                    return False\n            elif not model.has_unary_transitions():\n                # in fact, we have to leave the top level constituent\n                # under the ROOT open if unary transitions are not possible\n                if state.num_opens == 2 and not state.empty_word_queue():\n                    return False\n        elif model.transition_scheme() is TransitionScheme.IN_ORDER:\n            if not isinstance(model.get_top_transition(state.transitions), OpenConstituent):\n                # we're not stuck in a loop of unaries\n                return True\n            if state.num_opens > 1 or state.empty_word_queue():\n                # in either of these cases, the corresponding Open should be eliminated\n                # if we're stuck in a loop of unaries\n                return True\n            node = model.get_top_constituent(state.constituents.pop())\n            if too_many_unary_nodes(node, model.unary_limit()):\n                # at this point, we are in a situation where\n                # - multiple unaries have happened in a row\n                # - there is stuff on the word_queue, so a ROOT open isn't legal\n                # - there's only one constituent on the stack, so the only legal\n                #   option once there are no opens left will be an open\n                # this means we'll be stuck having to open again if we do close\n                # this node, so instead we make the Close illegal\n                return False\n        else:\n            # model.transition_scheme() == TransitionScheme.IN_ORDER_COMPOUND or\n            # model.transition_scheme() 
== TransitionScheme.IN_ORDER_UNARY:\n            # in both of these cases, we cannot do open/close\n            #   IN_ORDER_COMPOUND will use compound opens and preterminal unaries\n            #   IN_ORDER_UNARY will use compound unaries\n            # the only restriction here is that we can't close immediately after an open\n            #   internal unaries are handled by the opens being compound\n            #   preterminal unaries are handled with CompoundUnary\n            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):\n                return False\n        return True\n\n    def short_name(self):\n        return \"Close\"\n\n    def __repr__(self):\n        return \"CloseConstituent\"\n\n    def __eq__(self, other):\n        if self is other:\n            return True\n        if isinstance(other, CloseConstituent):\n            return True\n        return False\n\n    def __hash__(self):\n        return hash(93)\n\ndef check_transitions(train_transitions, other_transitions, treebank_name):\n    \"\"\"\n    Check that all the transitions in the other dataset are known in the train set\n\n    Weird nested unaries are warned rather than failed as long as the\n    components are all known\n\n    There is a tree in VLSP, for example, with three (!) 
nested NP nodes\n    If this is an unknown compound transition, we won't possibly get it\n    right when parsing, but at least we don't need to fail\n    \"\"\"\n    unknown_transitions = set()\n    for trans in other_transitions:\n        if trans not in train_transitions:\n            for component in trans.components():\n                if component not in train_transitions:\n                    raise RuntimeError(\"Found transition {} in the {} set which doesn't exist in the train set\".format(trans, treebank_name))\n            unknown_transitions.add(trans)\n    if len(unknown_transitions) > 0:\n        logger.warning(\"Found transitions where the components are all valid transitions, but the complete transition is unknown: %s\", sorted(unknown_transitions))\n"
  },
  {
    "path": "stanza/models/constituency/parse_tree.py",
    "content": "\"\"\"\nTree datastructure\n\"\"\"\n\nfrom collections import deque, Counter\nimport copy\nfrom enum import Enum\nfrom io import StringIO\nimport itertools\nimport re\nimport warnings\n\nfrom stanza.models.common.stanza_object import StanzaObject\n\n# useful more for the \"is\" functionality than the time savings\nCLOSE_PAREN = ')'\nSPACE_SEPARATOR = ' '\nOPEN_PAREN = '('\n\nEMPTY_CHILDREN = ()\n\n# used to split off the functional tags from various treebanks\n# for example, the Icelandic treebank (which we don't currently\n# incorporate) uses * to distinguish 'ADJP', 'ADJP*OC' but we treat\n# those as the same\nCONSTITUENT_SPLIT = re.compile(\"[-=#*]\")\n\n# These words occur in the VLSP dataset.\n# The documentation claims there might be *O*, although those don't\n# seem to exist in practice\nWORDS_TO_PRUNE = ('*E*', '*T*', '*O*')\n\nclass TreePrintMethod(Enum):\n    \"\"\"\n    Describes a few options for printing trees.\n\n    This probably doesn't need to be used directly.  See __format__\n    \"\"\"\n    ONE_LINE          = 1  # (ROOT (S ...  ))\n    LABELED_PARENS    = 2  # (_ROOT (_S ... )_S )_ROOT\n    PRETTY            = 3  # multiple lines\n    VLSP              = 4  # <s> (S ... ) </s>\n    LATEX_TREE        = 5  # \\Tree [.S [.NP ... 
] ]\n\n\nclass Tree(StanzaObject):\n    \"\"\"\n    A data structure to represent a parse tree\n    \"\"\"\n    def __init__(self, label=None, children=None):\n        if children is None:\n            self.children = EMPTY_CHILDREN\n        elif isinstance(children, Tree):\n            self.children = (children,)\n        else:\n            self.children = tuple(children)\n\n        self.label = label\n\n    def is_leaf(self):\n        return len(self.children) == 0\n\n    def is_preterminal(self):\n        return len(self.children) == 1 and len(self.children[0].children) == 0\n\n    def yield_preterminals(self):\n        \"\"\"\n        Yield the preterminals one at a time in order\n        \"\"\"\n        if self.is_preterminal():\n            yield self\n            return\n\n        if self.is_leaf():\n            raise ValueError(\"Attempted to iterate preterminals on non-internal node\")\n\n        iterator = iter(self.children)\n        node = next(iterator, None)\n        while node is not None:\n            if node.is_preterminal():\n                yield node\n            else:\n                iterator = itertools.chain(node.children, iterator)\n            node = next(iterator, None)\n\n    def leaf_labels(self):\n        \"\"\"\n        Get the labels of the leaves\n        \"\"\"\n        if self.is_leaf():\n            return [self.label]\n\n        words = [x.children[0].label for x in self.yield_preterminals()]\n        return words\n\n    def __len__(self):\n        return len(self.leaf_labels())\n\n    def all_leaves_are_preterminals(self):\n        \"\"\"\n        Returns True if all leaves are under preterminals, False otherwise\n        \"\"\"\n        if self.is_leaf():\n            return False\n\n        if self.is_preterminal():\n            return True\n\n        return all(t.all_leaves_are_preterminals() for t in self.children)\n\n    def pretty_print(self, normalize=None):\n        \"\"\"\n        Print with newlines & indentation on 
each line\n\n        Preterminals and nodes with all preterminal children go on their own line\n\n        You can pass in your own normalize() function.  If you do,\n        make sure the function updates the parens to be something\n        other than () or the brackets will be broken\n        \"\"\"\n        if normalize is None:\n            normalize = lambda x: x.replace(\"(\", \"-LRB-\").replace(\")\", \"-RRB-\")\n\n        indent = 0\n        with StringIO() as buf:\n            stack = deque()\n            stack.append(self)\n            while len(stack) > 0:\n                node = stack.pop()\n\n                if node is CLOSE_PAREN:\n                    # if we're trying to pretty print trees, pop all off close parens\n                    # then write a newline\n                    while node is CLOSE_PAREN:\n                        indent -= 1\n                        buf.write(CLOSE_PAREN)\n                        if len(stack) == 0:\n                            node = None\n                            break\n                        node = stack.pop()\n                    buf.write(\"\\n\")\n                    if node is None:\n                        break\n                    stack.append(node)\n                elif node.is_preterminal():\n                    buf.write(\"  \" * indent)\n                    buf.write(\"%s%s %s%s\" % (OPEN_PAREN, normalize(node.label), normalize(node.children[0].label), CLOSE_PAREN))\n                    if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:\n                        buf.write(\"\\n\")\n                elif all(x.is_preterminal() for x in node.children):\n                    buf.write(\"  \" * indent)\n                    buf.write(\"%s%s\" % (OPEN_PAREN, normalize(node.label)))\n                    for child in node.children:\n                        buf.write(\" %s%s %s%s\" % (OPEN_PAREN, normalize(child.label), normalize(child.children[0].label), CLOSE_PAREN))\n                    
buf.write(CLOSE_PAREN)\n                    if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:\n                        buf.write(\"\\n\")\n                else:\n                    buf.write(\"  \" * indent)\n                    buf.write(\"%s%s\\n\" % (OPEN_PAREN, normalize(node.label)))\n                    stack.append(CLOSE_PAREN)\n                    for child in reversed(node.children):\n                        stack.append(child)\n                    indent += 1\n\n            buf.seek(0)\n            return buf.read()\n\n    def __format__(self, spec):\n        \"\"\"\n        Turn the tree into a string representing the tree\n\n        Note that this is not a recursive traversal\n        Otherwise, a tree too deep might blow up the call stack\n\n        There is a type specific format:\n          O       -> one line PTB format, which is the default anyway\n          L       -> open and close brackets are labeled, spaces in the tokens are replaced with _\n          P       -> pretty print over multiple lines\n          V       -> surround lines with <s>...</s>, don't print ROOT, and turn () into L/RBKT\n          ?       -> spaces in the tokens are replaced with ? for any value of ? other than OLP\n                     warning: this may be removed in the future\n          ?{OLPV} -> specific format AND a custom space replacement\n          Vi      -> add an ID to the <s> in the V format.  
Also works with ?Vi\n        \"\"\"\n        space_replacement = \" \"\n        print_format = TreePrintMethod.ONE_LINE\n        if spec == 'L':\n            print_format = TreePrintMethod.LABELED_PARENS\n            space_replacement = \"_\"\n        elif spec and spec[-1] == 'L':\n            print_format = TreePrintMethod.LABELED_PARENS\n            space_replacement = spec[0]\n        elif spec == 'O':\n            print_format = TreePrintMethod.ONE_LINE\n        elif spec and spec[-1] == 'O':\n            print_format = TreePrintMethod.ONE_LINE\n            space_replacement = spec[0]\n        elif spec == 'P':\n            print_format = TreePrintMethod.PRETTY\n        elif spec and spec[-1] == 'P':\n            print_format = TreePrintMethod.PRETTY\n            space_replacement = spec[0]\n        elif spec and spec[0] == 'V':\n            print_format = TreePrintMethod.VLSP\n            use_tree_id = spec[-1] == 'i'\n        elif spec and len(spec) > 1 and spec[1] == 'V':\n            print_format = TreePrintMethod.VLSP\n            space_replacement = spec[0]\n            use_tree_id = spec[-1] == 'i'\n        elif spec == 'T':\n            print_format = TreePrintMethod.LATEX_TREE\n        elif spec and len(spec) > 1 and spec[1] == 'T':\n            print_format = TreePrintMethod.LATEX_TREE\n            space_replacement = spec[0]\n        elif spec:\n            space_replacement = spec[0]\n            warnings.warn(\"Use of a custom replacement without a format specifier is deprecated.  
Please use {}O instead\".format(space_replacement), stacklevel=2)\n\n        LRB = \"LBKT\" if print_format == TreePrintMethod.VLSP else \"-LRB-\"\n        RRB = \"RBKT\" if print_format == TreePrintMethod.VLSP else \"-RRB-\"\n        def normalize(text):\n            return text.replace(\" \", space_replacement).replace(\"(\", LRB).replace(\")\", RRB)\n\n        if print_format is TreePrintMethod.PRETTY:\n            return self.pretty_print(normalize)\n\n        with StringIO() as buf:\n            stack = deque()\n            if print_format == TreePrintMethod.VLSP:\n                if use_tree_id:\n                    buf.write(\"<s id={}>\\n\".format(self.tree_id))\n                else:\n                    buf.write(\"<s>\\n\")\n                if len(self.children) == 0:\n                    raise ValueError(\"Cannot print an empty tree with V format\")\n                elif len(self.children) > 1:\n                    raise ValueError(\"Cannot print a tree with %d branches with V format\" % len(self.children))\n                stack.append(self.children[0])\n            elif print_format == TreePrintMethod.LATEX_TREE:\n                buf.write(\"\\\\Tree \")\n                if len(self.children) == 0:\n                    raise ValueError(\"Cannot print an empty tree with T format\")\n                elif len(self.children) == 1 and len(self.children[0].children) == 0:\n                    buf.write(\"[.? 
\")\n                    buf.write(normalize(self.children[0].label))\n                    buf.write(\" ]\")\n                elif self.label == 'ROOT':\n                    stack.append(self.children[0])\n                else:\n                    stack.append(self)\n            else:\n                stack.append(self)\n            while len(stack) > 0:\n                node = stack.pop()\n\n                if isinstance(node, str):\n                    buf.write(node)\n                    continue\n                if len(node.children) == 0:\n                    if node.label is not None:\n                        buf.write(normalize(node.label))\n                    continue\n\n                if print_format is TreePrintMethod.LATEX_TREE:\n                    if node.is_preterminal():\n                        buf.write(normalize(node.children[0].label))\n                        continue\n                    buf.write(\"[.%s\" % normalize(node.label))\n                    stack.append(\" ]\")\n                elif print_format is TreePrintMethod.ONE_LINE or print_format is TreePrintMethod.VLSP:\n                    buf.write(OPEN_PAREN)\n                    if node.label is not None:\n                        buf.write(normalize(node.label))\n                    stack.append(CLOSE_PAREN)\n                elif print_format is TreePrintMethod.LABELED_PARENS:\n                    buf.write(\"%s_%s\" % (OPEN_PAREN, normalize(node.label)))\n                    stack.append(CLOSE_PAREN + \"_\" + normalize(node.label))\n                    stack.append(SPACE_SEPARATOR)\n\n                for child in reversed(node.children):\n                    stack.append(child)\n                    stack.append(SPACE_SEPARATOR)\n            if print_format == TreePrintMethod.VLSP:\n                buf.write(\"\\n</s>\")\n            buf.seek(0)\n            return buf.read()\n\n    def __repr__(self):\n        return \"{}\".format(self)\n\n    def __eq__(self, other):\n        if 
self is other:\n            return True\n        if not isinstance(other, Tree):\n            return False\n        if self.label != other.label:\n            return False\n        if len(self.children) != len(other.children):\n            return False\n        if any(c1 != c2 for c1, c2 in zip(self.children, other.children)):\n            return False\n        return True\n\n    def depth(self):\n        if not self.children:\n            return 0\n        return 1 + max(x.depth() for x in self.children)\n\n    def visit_preorder(self, internal=None, preterminal=None, leaf=None):\n        \"\"\"\n        Visit the tree in preorder\n\n        Applies the given functions to each node.\n        internal: if not None, applies this function to each non-leaf, non-preterminal node\n        preterminal: if not None, applies this function to each preterminal\n        leaf: if not None, applies this function to each leaf\n\n        The functions should *not* destructively alter the trees.\n        There is no attempt to interpret the results of calling these functions.\n        Rather, you can use visit_preorder to collect stats on trees, etc.\n        \"\"\"\n        if self.is_leaf():\n            if leaf:\n                leaf(self)\n        elif self.is_preterminal():\n            if preterminal:\n                preterminal(self)\n        else:\n            if internal:\n                internal(self)\n        for child in self.children:\n            child.visit_preorder(internal, preterminal, leaf)\n\n    @staticmethod\n    def get_unique_constituent_labels(trees):\n        \"\"\"\n        Walks over all of the trees and gets all of the unique constituent names from the trees\n        \"\"\"\n        if isinstance(trees, Tree):\n            trees = [trees]\n        constituents = Tree.get_constituent_counts(trees)\n        return sorted(set(constituents.keys()))\n\n    @staticmethod\n    def get_constituent_counts(trees):\n        \"\"\"\n        Walks over 
all of the trees and gets the count of the unique constituent names from the trees\n        \"\"\"\n        if isinstance(trees, Tree):\n            trees = [trees]\n\n        constituents = Counter()\n        for tree in trees:\n            tree.visit_preorder(internal = lambda x: constituents.update([x.label]))\n        return constituents\n\n    @staticmethod\n    def get_unique_tags(trees):\n        \"\"\"\n        Walks over all of the trees and gets all of the unique tags from the trees\n        \"\"\"\n        if isinstance(trees, Tree):\n            trees = [trees]\n\n        tags = set()\n        for tree in trees:\n            tree.visit_preorder(preterminal = lambda x: tags.add(x.label))\n        return sorted(tags)\n\n    @staticmethod\n    def get_unique_words(trees):\n        \"\"\"\n        Walks over all of the trees and gets all of the unique words from the trees\n        \"\"\"\n        if isinstance(trees, Tree):\n            trees = [trees]\n\n        words = set()\n        for tree in trees:\n            tree.visit_preorder(leaf = lambda x: words.add(x.label))\n        return sorted(words)\n\n    @staticmethod\n    def get_common_words(trees, num_words):\n        \"\"\"\n        Walks over all of the trees and gets the most frequently occurring words.\n        \"\"\"\n        if num_words == 0:\n            return set()\n\n        if isinstance(trees, Tree):\n            trees = [trees]\n\n        words = Counter()\n        for tree in trees:\n            tree.visit_preorder(leaf = lambda x: words.update([x.label]))\n        return sorted(x[0] for x in words.most_common()[:num_words])\n\n    @staticmethod\n    def get_rare_words(trees, threshold=0.05):\n        \"\"\"\n        Walks over all of the trees and gets the least frequently occurring words.\n\n        threshold: choose the bottom X percent\n        \"\"\"\n        if isinstance(trees, Tree):\n            trees = [trees]\n\n        words = Counter()\n        for tree in trees:\n        
    tree.visit_preorder(leaf = lambda x: words.update([x.label]))\n        threshold = max(int(len(words) * threshold), 1)\n        return sorted(x[0] for x in words.most_common()[:-threshold-1:-1])\n\n    @staticmethod\n    def get_root_labels(trees):\n        return sorted(set(x.label for x in trees))\n\n    @staticmethod\n    def get_compound_constituents(trees, separate_root=False):\n        constituents = set()\n        stack = deque()\n        for tree in trees:\n            if separate_root:\n                constituents.add((tree.label,))\n                for child in tree.children:\n                    stack.append(child)\n            else:\n                stack.append(tree)\n            while len(stack) > 0:\n                node = stack.pop()\n                if node.is_leaf() or node.is_preterminal():\n                    continue\n                labels = [node.label]\n                while len(node.children) == 1 and not node.children[0].is_preterminal():\n                    node = node.children[0]\n                    labels.append(node.label)\n                constituents.add(tuple(labels))\n                for child in node.children:\n                    stack.append(child)\n        return sorted(constituents)\n\n    # TODO: test different pattern\n    def simplify_labels(self, pattern=CONSTITUENT_SPLIT):\n        \"\"\"\n        Return a copy of the tree with the -=# removed\n\n        Leaves the text of the leaves alone.\n        \"\"\"\n        new_label = self.label\n        # check len(new_label) just in case it's a tag of - or =\n        if new_label and not self.is_leaf() and len(new_label) > 1 and new_label not in ('-LRB-', '-RRB-'):\n            new_label = pattern.split(new_label)[0]\n        new_children = [child.simplify_labels(pattern) for child in self.children]\n        return Tree(new_label, new_children)\n\n    def reverse(self):\n        \"\"\"\n        Flip a tree backwards\n\n        The intent is to train a parser backwards 
to see if the\n        forward and backwards parsers can augment each other\n        \"\"\"\n        if self.is_leaf():\n            return Tree(self.label)\n\n        new_children = [child.reverse() for child in reversed(self.children)]\n        return Tree(self.label, new_children)\n\n    def remap_constituent_labels(self, label_map):\n        \"\"\"\n        Copies the tree with some labels replaced.\n\n        Labels in the map are replaced with the mapped value.\n        Labels not in the map are unchanged.\n        \"\"\"\n        if self.is_leaf():\n            return Tree(self.label)\n        if self.is_preterminal():\n            return Tree(self.label, Tree(self.children[0].label))\n        new_label = label_map.get(self.label, self.label)\n        return Tree(new_label, [child.remap_constituent_labels(label_map) for child in self.children])\n\n    def remap_words(self, word_map):\n        \"\"\"\n        Copies the tree with some labels replaced.\n\n        Labels in the map are replaced with the mapped value.\n        Labels not in the map are unchanged.\n        \"\"\"\n        if self.is_leaf():\n            new_label = word_map.get(self.label, self.label)\n            return Tree(new_label)\n        if self.is_preterminal():\n            return Tree(self.label, self.children[0].remap_words(word_map))\n        return Tree(self.label, [child.remap_words(word_map) for child in self.children])\n\n    def replace_words(self, words):\n        \"\"\"\n        Replace all leaf words with the words in the given list (or iterable)\n\n        Returns a new tree\n        \"\"\"\n        word_iterator = iter(words)\n        def recursive_replace_words(subtree):\n            if subtree.is_leaf():\n                word = next(word_iterator, None)\n                if word is None:\n                    raise ValueError(\"Not enough words to replace all leaves\")\n                return Tree(word)\n            return Tree(subtree.label, [recursive_replace_words(x) for 
x in subtree.children])\n\n        new_tree = recursive_replace_words(self)\n        if any(True for _ in word_iterator):\n            raise ValueError(\"Too many words for the given tree\")\n        return new_tree\n\n\n    def replace_tags(self, tags):\n        if self.is_leaf():\n            raise ValueError(\"Must call replace_tags with non-leaf\")\n\n        if isinstance(tags, Tree):\n            tag_iterator = (x.label for x in tags.yield_preterminals())\n        else:\n            tag_iterator = iter(tags)\n\n        new_tree = copy.deepcopy(self)\n        queue = deque()\n        queue.append(new_tree)\n        while len(queue) > 0:\n            next_node = queue.pop()\n            if next_node.is_preterminal():\n                try:\n                    label = next(tag_iterator)\n                except StopIteration:\n                    raise ValueError(\"Not enough tags in sentence for given tree\")\n                next_node.label = label\n            elif next_node.is_leaf():\n                raise ValueError(\"Got a badly structured tree: {}\".format(self))\n            else:\n                queue.extend(reversed(next_node.children))\n\n        if any(True for _ in tag_iterator):\n            raise ValueError(\"Too many tags for the given tree\")\n\n        return new_tree\n\n\n    def prune_none(self):\n        \"\"\"\n        Return a copy of the tree, eliminating all nodes which are in one of two categories:\n            they are a preterminal -NONE-, such as appears in PTB\n              *E* shows up in a VLSP dataset\n            they have been pruned to 0 children by the recursive call\n        \"\"\"\n        if self.is_leaf():\n            return Tree(self.label)\n        if self.is_preterminal():\n            if self.label == '-NONE-' or self.children[0].label in WORDS_TO_PRUNE:\n                return None\n            return Tree(self.label, Tree(self.children[0].label))\n        # must be internal node\n        new_children = 
[child.prune_none() for child in self.children]\n        new_children = [child for child in new_children if child is not None]\n        if len(new_children) == 0:\n            return None\n        return Tree(self.label, new_children)\n\n    def count_unary_depth(self):\n        if self.is_preterminal() or self.is_leaf():\n            return 0\n        if len(self.children) == 1:\n            t = self\n            score = 0\n            while not t.is_preterminal() and not t.is_leaf() and len(t.children) == 1:\n                score = score + 1\n                t = t.children[0]\n            child_score = max(tc.count_unary_depth() for tc in t.children)\n            score = max(score, child_score)\n            return score\n        score = max(t.count_unary_depth() for t in self.children)\n        return score\n\n    @staticmethod\n    def write_treebank(trees, out_file, fmt=\"{}\"):\n        with open(out_file, \"w\", encoding=\"utf-8\") as fout:\n            for tree in trees:\n                fout.write(fmt.format(tree))\n                fout.write(\"\\n\")\n"
  },
  {
    "path": "stanza/models/constituency/parser_training.py",
    "content": "from collections import Counter, namedtuple\nimport copy\nimport logging\nimport os\nimport random\nimport re\n\nimport torch\nfrom torch import nn\n\n#from stanza.models.common import pretrain\n\nfrom stanza.models.common import utils\nfrom stanza.models.common.foundation_cache import FoundationCache, NoTransformerFoundationCache\nfrom stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss\nfrom stanza.models.common.utils import sort_with_indices, unsort\nfrom stanza.models.constituency import error_analysis_in_order\nfrom stanza.models.constituency import parse_transitions\nfrom stanza.models.constituency import transition_sequence\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.in_order_compound_oracle import InOrderCompoundOracle\nfrom stanza.models.constituency.in_order_oracle import InOrderOracle\nfrom stanza.models.constituency.lstm_model import LSTMModel\nfrom stanza.models.constituency.parse_transitions import TransitionScheme\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.models.constituency.top_down_oracle import TopDownOracle\nfrom stanza.models.constituency.trainer import Trainer\nfrom stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler, verify_transitions, get_open_nodes, check_constituents, check_root_labels, remove_duplicate_trees, remove_singleton_trees\nfrom stanza.server.parser_eval import EvaluateParser, ParseResult\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\ntlogger = logging.getLogger('stanza.constituency.trainer')\n\nTrainItem = namedtuple(\"TrainItem\", ['tree', 'gold_sequence', 'preterminals'])\n\nclass EpochStats(namedtuple(\"EpochStats\", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):\n    def __add__(self, other):\n        transitions_correct = self.transitions_correct + other.transitions_correct\n        
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect\n        repairs_used = self.repairs_used + other.repairs_used\n        fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used\n        epoch_loss = self.epoch_loss + other.epoch_loss\n        nans = self.nans + other.nans\n        return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)\n\ndef evaluate(args, model_file, retag_pipeline):\n    \"\"\"\n    Loads the given model file and tests the eval_file treebank.\n\n    May retag the trees using retag_pipeline\n    Uses a subprocess to run the Java EvalB code\n    \"\"\"\n    # we create the Evaluator here because otherwise the transformers\n    # library constantly complains about forking the process\n    # note that this won't help in the event of training multiple\n    # models in the same run, although since that would take hours\n    # or days, that's not a very common problem\n    if args['num_generate'] > 0:\n        kbest = args['num_generate'] + 1\n    else:\n        kbest = None\n\n    with EvaluateParser(kbest=kbest) as evaluator:\n        foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()\n        load_args = {\n            'wordvec_pretrain_file': args['wordvec_pretrain_file'],\n            'charlm_forward_file': args['charlm_forward_file'],\n            'charlm_backward_file': args['charlm_backward_file'],\n            'device': args['device'],\n        }\n        trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)\n\n        if args['log_shapes']:\n            trainer.log_shapes()\n\n        treebank = tree_reader.read_treebank(args['eval_file'])\n        tlogger.info(\"Read %d trees for evaluation\", len(treebank))\n\n        retagged_treebank = treebank\n        if retag_pipeline is not None:\n            retag_method = trainer.model.retag_method\n          
  retag_xpos = retag_method == 'xpos'\n            tlogger.info(\"Retagging trees using the %s tags from the %s package...\", retag_method, args['retag_package'])\n            retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos)\n            tlogger.info(\"Retagging finished\")\n\n        if args['log_norms']:\n            trainer.log_norms()\n        f1, kbestF1, _ = run_dev_set(trainer.model, retagged_treebank, treebank, args, evaluator, analyze_first_errors=True)\n        tlogger.info(\"F1 score on %s: %f\", args['eval_file'], f1)\n        if kbestF1 is not None:\n            tlogger.info(\"KBest F1 score on %s: %f\", args['eval_file'], kbestF1)\n\ndef remove_optimizer(args, model_save_file, model_load_file):\n    \"\"\"\n    A utility method to remove the optimizer from a save file\n\n    Will make the save file a lot smaller\n    \"\"\"\n    # TODO: kind of overkill to load in the pretrain rather than\n    # change the load/save to work without it, but probably this\n    # functionality isn't used that often anyway\n    load_args = {\n        'wordvec_pretrain_file': args['wordvec_pretrain_file'],\n        'charlm_forward_file': args['charlm_forward_file'],\n        'charlm_backward_file': args['charlm_backward_file'],\n        'device': args['device'],\n    }\n    trainer = Trainer.load(model_load_file, args=load_args, load_optimizer=False)\n    trainer.save(model_save_file)\n\ndef add_grad_clipping(trainer, grad_clipping):\n    \"\"\"\n    Adds a torch.clamp hook on each parameter if grad_clipping is not None\n    \"\"\"\n    if grad_clipping is not None:\n        for p in trainer.model.parameters():\n            if p.requires_grad:\n                p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping))\n\ndef build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file):\n    \"\"\"\n    Builds a Trainer (with model) and the train_sequences and transitions for the given trees.\n    
\"\"\"\n    train_constituents = Tree.get_unique_constituent_labels(train_trees)\n    tlogger.info(\"Unique constituents in training set: %s\", train_constituents)\n    if args['check_valid_states']:\n        check_constituents(train_constituents, dev_trees, \"dev\", fail=args['strict_check_constituents'])\n        check_constituents(train_constituents, silver_trees, \"silver\", fail=args['strict_check_constituents'])\n    constituent_counts = Tree.get_constituent_counts(train_trees)\n    tlogger.info(\"Constituent node counts: %s\", constituent_counts)\n\n    tags = Tree.get_unique_tags(train_trees)\n    if None in tags:\n        raise RuntimeError(\"Fatal problem: the tagger put None on some of the nodes!\")\n    tlogger.info(\"Unique tags in training set: %s\", tags)\n    # no need to fail for missing tags between train/dev set\n    # the model has an unknown tag embedding\n    for tag in Tree.get_unique_tags(dev_trees):\n        if tag not in tags:\n            tlogger.info(\"Found tag in dev set which does not exist in train set: %s  Continuing...\", tag)\n\n    unary_limit = max(max(t.count_unary_depth() for t in train_trees),\n                      max(t.count_unary_depth() for t in dev_trees)) + 1\n    if silver_trees:\n        unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))\n    tlogger.info(\"Unary limit: %d\", unary_limit)\n    train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, \"training\", args['transition_scheme'], args['reversed'])\n    dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, \"dev\", args['transition_scheme'], args['reversed'])\n    silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, \"silver\", args['transition_scheme'], args['reversed'])\n\n    tlogger.info(\"Total unique transitions in train set: %d\", len(train_transitions))\n    tlogger.info(\"Unique transitions in 
training set:\\n  %s\", \"\\n  \".join(map(str, train_transitions)))\n    expanded_train_transitions = set(train_transitions + [x for trans in train_transitions for x in trans.components()])\n    if args['check_valid_states']:\n        parse_transitions.check_transitions(expanded_train_transitions, dev_transitions, \"dev\")\n        # theoretically could just train based on the items in the silver dataset\n        parse_transitions.check_transitions(expanded_train_transitions, silver_transitions, \"silver\")\n\n    root_labels = Tree.get_root_labels(train_trees)\n    check_root_labels(root_labels, dev_trees, \"dev\")\n    check_root_labels(root_labels, silver_trees, \"silver\")\n    tlogger.info(\"Root labels in treebank: %s\", root_labels)\n\n    verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], \"train\", root_labels)\n    verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit, args['reversed'], \"dev\", root_labels)\n\n    # we don't check against the words in the dev set as it is\n    # expected there will be some UNK words\n    words = Tree.get_unique_words(train_trees)\n    rare_words = Tree.get_rare_words(train_trees, args['rare_word_threshold'])\n    # rare/unknown silver words will just get UNK if they are not already known\n    if silver_trees and args['use_silver_words']:\n        tlogger.info(\"Getting silver words to add to the delta embedding\")\n        silver_words = Tree.get_common_words(tqdm(silver_trees, postfix='Silver words'), len(words))\n        words = sorted(set(words + silver_words))\n\n    # also, it's not actually an error if there is a pattern of\n    # compound unary or compound open nodes which doesn't exist in the\n    # train set.  
it just means we probably won't ever get that right\n    open_nodes = get_open_nodes(train_trees, args['transition_scheme'])\n    tlogger.info(\"Using the following open nodes:\\n  %s\", \"\\n  \".join(map(str, open_nodes)))\n\n    # at this point we have:\n    # pretrain\n    # train_trees, dev_trees\n    # lists of transitions, internal nodes, and root states the parser needs to be aware of\n\n    trainer = Trainer.build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file)\n\n    trainer.log_num_words_known(words)\n    # grad clipping is not saved with the rest of the model,\n    # so even in the case of a model we saved,\n    # we now have to add the grad clipping\n    add_grad_clipping(trainer, args['grad_clipping'])\n\n    return trainer, train_sequences, silver_sequences, train_transitions\n\ndef train(args, model_load_file, retag_pipeline):\n    \"\"\"\n    Build a model, train it using the requested train & dev files\n    \"\"\"\n    utils.log_training_args(args, tlogger)\n\n    # we create the Evaluator here because otherwise the transformers\n    # library constantly complains about forking the process\n    # note that this won't help in the event of training multiple\n    # models in the same run, although since that would take hours\n    # or days, that's not a very common problem\n    if args['num_generate'] > 0:\n        kbest = args['num_generate'] + 1\n    else:\n        kbest = None\n\n    if args['wandb']:\n        global wandb\n        import wandb\n        wandb_name = args['wandb_name'] if args['wandb_name'] else \"%s_constituency\" % args['shorthand']\n        wandb.init(name=wandb_name, config=args)\n        wandb.run.define_metric('dev_score', summary='max')\n\n    with EvaluateParser(kbest=kbest) as evaluator:\n        utils.ensure_dir(args['save_dir'])\n\n        train_trees = tree_reader.read_treebank(args['train_file'])\n        
tlogger.info(\"Read %d trees for the training set\", len(train_trees))\n        if args['train_remove_duplicates']:\n            train_trees = remove_duplicate_trees(train_trees, \"train\")\n        train_trees = remove_singleton_trees(train_trees)\n\n        dev_trees = tree_reader.read_treebank(args['eval_file'])\n        tlogger.info(\"Read %d trees for the dev set\", len(dev_trees))\n        dev_trees = remove_duplicate_trees(dev_trees, \"dev\")\n\n        silver_trees = []\n        if args['silver_file']:\n            silver_trees = tree_reader.read_treebank(args['silver_file'])\n            tlogger.info(\"Read %d trees for the silver training set\", len(silver_trees))\n            if args['silver_remove_duplicates']:\n                silver_trees = remove_duplicate_trees(silver_trees, \"silver\")\n\n        if retag_pipeline is not None:\n            tlogger.info(\"Retagging trees using the %s tags from the %s package...\", args['retag_method'], args['retag_package'])\n            train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])\n            dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])\n            silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos'])\n            tlogger.info(\"Retagging finished\")\n\n        foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()\n        trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file)\n\n        if args['log_shapes']:\n            trainer.log_shapes()\n        trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator)\n\n    if args['wandb']:\n        wandb.finish()\n\n    return trainer\n\ndef compose_train_data(trees, sequences):\n    preterminal_lists = [[Tree(label=preterminal.label, 
children=Tree(label=preterminal.children[0].label))\n                          for preterminal in tree.yield_preterminals()]\n                         for tree in trees]\n    data = [TrainItem(*x) for x in zip(trees, sequences, preterminal_lists)]\n    return data\n\ndef next_epoch_data(leftover_training_data, train_data, epoch_size):\n    \"\"\"\n    Return the next epoch_size trees from the training data, starting\n    with leftover data from the previous epoch if there is any\n\n    The training loop generally operates on a fixed number of trees,\n    rather than going through all the trees in the training set\n    exactly once, and keeping the leftover training data via this\n    function ensures that each tree in the training set is touched\n    once before beginning to iterate again.\n    \"\"\"\n    if not train_data:\n        return [], []\n\n    epoch_data = leftover_training_data\n    while len(epoch_data) < epoch_size:\n        random.shuffle(train_data)\n        epoch_data.extend(train_data)\n    leftover_training_data = epoch_data[epoch_size:]\n    epoch_data = epoch_data[:epoch_size]\n\n    return leftover_training_data, epoch_data\n\ndef update_bert_learning_rate(args, optimizer, epochs_trained):\n    \"\"\"\n    Update the learning rate for the bert finetuning, if applicable\n    \"\"\"\n    # would be nice to have a parameter group specific scheduler\n    # however, there is an issue with the optimizer we had the most success with, madgrad\n    # when the learning rate is 0 for a group, it still learns by some\n    # small amount because of the eps parameter\n    # in fact, that is enough to make the learning for the bert in the\n    # second half broken\n    for base_param_group in optimizer.param_groups:\n        if base_param_group['param_group_name'] == 'base':\n            break\n    else:\n        raise AssertionError(\"There should always be a base parameter group\")\n    for param_group in optimizer.param_groups:\n        if 
param_group['param_group_name'] == 'bert':\n            # Occasionally a model goes haywire and forgets how to use the transformer\n            # So far we have only seen this happen with Electra on the non-NML version of PTB\n            # We tried fixing that with an increasing transformer learning rate, but that\n            # didn't fully resolve the problem\n            # Switching to starting the finetuning after a few epochs seems to help a lot, though\n            old_lr = param_group['lr']\n            if args['bert_finetune_begin_epoch'] is not None and epochs_trained < args['bert_finetune_begin_epoch']:\n                param_group['lr'] = 0.0\n            elif args['bert_finetune_end_epoch'] is not None and epochs_trained >= args['bert_finetune_end_epoch']:\n                param_group['lr'] = 0.0\n            elif args['multistage'] and epochs_trained < args['epochs'] // 2:\n                param_group['lr'] = base_param_group['lr'] * args['stage1_bert_learning_rate']\n            else:\n                param_group['lr'] = base_param_group['lr'] * args['bert_learning_rate']\n            if param_group['lr'] != old_lr:\n                tlogger.info(\"Setting %s finetuning rate from %f to %f\", param_group['param_group_name'], old_lr, param_group['lr'])\n\ndef iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator):\n    \"\"\"\n    Given an initialized model, a processed dataset, and a secondary dev dataset, train the model\n\n    The training is iterated in the following loop:\n      extract a batch of trees of the same length from the training set\n      convert those trees into initial parsing states\n      repeat until trees are done:\n        batch predict the model's interpretation of the current states\n        add the errors to the list of things to backprop\n        advance the parsing state for each of the trees\n    \"\"\"\n    # Somewhat unusual, but 
possibly related to the extreme variability in length of trees\n    # Various experiments generally show about 0.5 F1 loss on various\n    # datasets when using 'mean' instead of 'sum' for reduction\n    # (Remember to adjust the weight decay when rerunning that experiment)\n    if args['loss'] == 'cross':\n        tlogger.info(\"Building CrossEntropyLoss(sum)\")\n        process_outputs = lambda x: x\n        model_loss_function = nn.CrossEntropyLoss(reduction='sum')\n    elif args['loss'] == 'focal':\n        try:\n            from focal_loss.focal_loss import FocalLoss\n        except ImportError:\n            raise ImportError(\"focal_loss not installed.  Must `pip install focal_loss_torch` to use the --loss=focal feature\")\n        tlogger.info(\"Building FocalLoss, gamma=%f\", args['loss_focal_gamma'])\n        process_outputs = lambda x: torch.softmax(x, dim=1)\n        model_loss_function = FocalLoss(reduction='sum', gamma=args['loss_focal_gamma'])\n    elif args['loss'] == 'large_margin':\n        tlogger.info(\"Building LargeMarginInSoftmaxLoss(sum)\")\n        process_outputs = lambda x: x\n        model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum')\n    else:\n        raise ValueError(\"Unexpected loss term: %s\" % args['loss'])\n\n    device = trainer.device\n    model_loss_function.to(device)\n    transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)\n                          for (y, x) in enumerate(trainer.transitions)}\n    trainer.train()\n\n    train_data = compose_train_data(train_trees, train_sequences)\n    silver_data = compose_train_data(silver_trees, silver_sequences)\n\n    if not args['epoch_size']:\n        args['epoch_size'] = len(train_data)\n    if silver_data and not args['silver_epoch_size']:\n        args['silver_epoch_size'] = args['epoch_size']\n\n    if args['multistage']:\n        multistage_splits = {}\n        # if we're halfway, only do pattn.  
save lattn for next time\n        multistage_splits[args['epochs'] // 2] = (args['pattn_num_layers'], False)\n        if LSTMModel.uses_lattn(args):\n            multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True)\n\n    # TODO: refactor the oracle choice into the transition scheme?\n    oracle = None\n    if args['transition_scheme'] is TransitionScheme.IN_ORDER:\n        oracle = InOrderOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])\n    elif args['transition_scheme'] is TransitionScheme.IN_ORDER_COMPOUND:\n        oracle = InOrderCompoundOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])\n    elif args['transition_scheme'] is TransitionScheme.TOP_DOWN:\n        oracle = TopDownOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])\n\n    leftover_training_data = []\n    leftover_silver_data = []\n    if trainer.best_epoch > 0:\n        tlogger.info(\"Restarting trainer with a model trained for %d epochs.  
Best epoch %d, f1 %f\", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1)\n\n    # if we're training a new model, save the initial state so it can be inspected\n    if args['save_each_start'] == 0 and trainer.epochs_trained == 0:\n        trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=True)\n\n    # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1\n    for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1):\n        trainer.train()\n        tlogger.info(\"Starting epoch %d\", trainer.epochs_trained)\n        update_bert_learning_rate(args, trainer.optimizer, trainer.epochs_trained)\n\n        if args['log_norms']:\n            trainer.log_norms()\n        leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size'])\n        leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['silver_epoch_size'])\n        epoch_data = epoch_data + epoch_silver_data\n        epoch_data.sort(key=lambda x: len(x[1]))\n\n        epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args)\n\n        # print statistics\n        # by now we've forgotten about the original tags on the trees,\n        # but it doesn't matter for hill climbing\n        f1, _, _ = run_dev_set(trainer.model, dev_trees, dev_trees, args, evaluator)\n        if f1 > trainer.best_f1 or (trainer.best_epoch == 0 and trainer.best_f1 == 0.0):\n            # best_epoch == 0 to force a save of an initial model\n            # useful for tests which expect something, even when a\n            # very simple model didn't learn anything\n            tlogger.info(\"New best dev score: %.5f > %.5f\", f1, trainer.best_f1)\n            trainer.best_f1 = f1\n            trainer.best_epoch = trainer.epochs_trained\n            
trainer.save(args['save_name'], save_optimizer=False)\n        if epoch_stats.nans > 0:\n            tlogger.warning(\"Had to ignore %d batches with NaN\", epoch_stats.nans)\n        # TODO: refactor the logging?\n        total_correct = sum(v for _, v in epoch_stats.transitions_correct.items())\n        correct_transitions_str = \"\\n  \".join([\"%s: %d\" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct])\n        tlogger.info(\"Transitions correct: %d\\n  %s\", total_correct, correct_transitions_str)\n        total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items())\n        incorrect_transitions_str = \"\\n  \".join([\"%s: %d\" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect])\n        tlogger.info(\"Transitions incorrect: %d\\n  %s\", total_incorrect, incorrect_transitions_str)\n        if len(epoch_stats.repairs_used) > 0:\n            tlogger.info(\"Oracle repairs:\\n  %s\", \"\\n  \".join(\"%s (%s): %d\" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common()))\n        if epoch_stats.fake_transitions_used > 0:\n            tlogger.info(\"Fake transitions used: %d\", epoch_stats.fake_transitions_used)\n\n        stats_log_lines = [\n            \"Epoch %d finished\" % trainer.epochs_trained,\n            \"Transitions correct:   %d\" % total_correct,\n            \"Transitions incorrect: %d\" % total_incorrect,\n            \"Total loss for epoch: %.5f\" % epoch_stats.epoch_loss,\n            \"Dev score      (%5d): %8f\" % (trainer.epochs_trained, f1),\n            \"Best dev score (%5d): %8f\" % (trainer.best_epoch, trainer.best_f1)\n        ]\n        tlogger.info(\"\\n  \".join(stats_log_lines))\n\n        old_lr = trainer.optimizer.param_groups[0]['lr']\n        trainer.scheduler.step(f1)\n        new_lr = trainer.optimizer.param_groups[0]['lr']\n        if old_lr != new_lr:\n            tlogger.info(\"Updating learning rate from %f to %f\", 
old_lr, new_lr)\n\n        if args['wandb']:\n            wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)\n            if args['wandb_norm_regex']:\n                watch_regex = re.compile(args['wandb_norm_regex'])\n                for n, p in trainer.model.named_parameters():\n                    if watch_regex.search(n):\n                        wandb.log({n: torch.linalg.norm(p)})\n\n        if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']:\n            if any(x > 0.0 for x in (trainer.model.word_dropout.p, trainer.model.predict_dropout.p, trainer.model.lstm_input_dropout.p)):\n                tlogger.info(\"Setting dropout to 0.0 at epoch %d\", trainer.epochs_trained)\n            trainer.model.word_dropout.p = 0\n            trainer.model.predict_dropout.p = 0\n            trainer.model.lstm_input_dropout.p = 0\n\n        # recreate the optimizer and alter the model as needed if we hit a new multistage split\n        if args['multistage'] and trainer.epochs_trained in multistage_splits:\n            # we may be loading a saved model from an earlier epoch if the scores stopped increasing\n            epochs_trained = trainer.epochs_trained\n            batches_trained = trainer.batches_trained\n\n            stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]\n\n            # when loading the model, let the saved model determine whether it has pattn or lattn\n            temp_args = copy.deepcopy(trainer.model.args)\n            temp_args.pop('pattn_num_layers', None)\n            temp_args.pop('lattn_d_proj', None)\n            # overwriting the old trainer & model will hopefully free memory\n            # load a new bert, even in PEFT mode, mostly so that the bert model\n            # doesn't collect a whole bunch of PEFTs\n            # for one thing, two PEFTs would mean 2x the optimizer parameters,\n            # messing up saving and loading the optimizer 
without jumping\n            # through more hoops\n            # loading the trainer w/o the foundation_cache should create\n            # the necessary bert_model and bert_tokenizer, and then we\n            # can reuse those values when building out new LSTMModel\n            trainer = Trainer.load(args['save_name'], temp_args, load_optimizer=False)\n            model = trainer.model\n            tlogger.info(\"Finished stage at epoch %d.  Restarting optimizer\", epochs_trained)\n            tlogger.info(\"Previous best model was at epoch %d\", trainer.epochs_trained)\n\n            temp_args = dict(args)\n            tlogger.info(\"Switching to a model with %d pattn layers and %slattn\", stage_pattn_layers, \"\" if stage_uses_lattn else \"NO \")\n            temp_args['pattn_num_layers'] = stage_pattn_layers\n            if not stage_uses_lattn:\n                temp_args['lattn_d_proj'] = 0\n            pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])\n            forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])\n            backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])\n            new_model = LSTMModel(pt,\n                                  forward_charlm,\n                                  backward_charlm,\n                                  model.bert_model,\n                                  model.bert_tokenizer,\n                                  model.force_bert_saved,\n                                  model.peft_name,\n                                  model.transitions,\n                                  model.constituents,\n                                  model.tags,\n                                  model.delta_words,\n                                  model.rare_words,\n                                  model.root_labels,\n                                  model.constituent_opens,\n                                  model.unary_limit(),\n                                  
temp_args)\n            new_model.to(device)\n            new_model.copy_with_new_structure(model)\n\n            optimizer = build_optimizer(temp_args, new_model, False)\n            scheduler = build_scheduler(temp_args, optimizer)\n            trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch)\n            add_grad_clipping(trainer, args['grad_clipping'])\n\n        # checkpoint needs to be saved AFTER rebuilding the optimizer\n        # so that assumptions about the optimizer in the checkpoint\n        # can be made based on the end of the epoch\n        if args['checkpoint'] and args['checkpoint_save_name']:\n            trainer.save(args['checkpoint_save_name'], save_optimizer=True)\n        # same with the \"each filename\", actually, in case those are\n        # brought back for more training or even just for testing\n        if args['save_each_start'] is not None and args['save_each_start'] <= trainer.epochs_trained and trainer.epochs_trained % args['save_each_frequency'] == 0:\n            trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=args['save_each_optimizer'])\n\n    return trainer\n\ndef train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args):\n    interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))\n    random.shuffle(interval_starts)\n\n    optimizer = trainer.optimizer\n\n    epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)\n\n    for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix=\"Epoch %d\" % epoch)):\n        batch = epoch_data[interval_start:interval_start+args['train_batch_size']]\n        batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args)\n        trainer.batches_trained += 1\n\n        # Early in the training, some 
trees will be degenerate in a\n        # way that results in layers going up the tree amplifying the\n        # weights until they overflow.  Generally that problem\n        # resolves itself in a few iterations, so for now we just\n        # ignore those batches, but report how often it happens\n        if batch_stats.nans == 0:\n            optimizer.step()\n        optimizer.zero_grad()\n        epoch_stats = epoch_stats + batch_stats\n\n    return epoch_stats\n\ndef train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):\n    \"\"\"\n    Train the model for one batch\n\n    The model itself will be updated, and a bunch of stats are returned\n    It is unclear if this refactoring is useful in any way.  Might not be\n\n    ... although the indentation does get pretty ridiculous if this is\n    merged into train_model_one_epoch and then iterate_training\n    \"\"\"\n    # now we add the state to the trees in the batch\n    # the state is built as a bulk operation\n    current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch],\n                                                          [x.tree for x in training_batch],\n                                                          [x.gold_sequence for x in training_batch])\n\n    transitions_correct = Counter()\n    transitions_incorrect = Counter()\n    repairs_used = Counter()\n    fake_transitions_used = 0\n\n    all_errors = []\n    all_answers = []\n\n    # we iterate through the batch in the following sequence:\n    # predict the logits and the applied transition for each tree in the batch\n    # collect errors\n    #  - we always train to the desired one-hot vector\n    #    this was a noticeable improvement over training just the\n    #    incorrect transitions\n    # determine whether the training can continue using the \"student\" transition\n    #   or if we need to use teacher forcing\n    # 
update all states using either the gold or predicted transition\n    # any trees which are now finished are removed from the training cycle\n    while len(current_batch) > 0:\n        outputs, pred_transitions, _ = model.predict(current_batch, is_legal=False)\n        gold_transitions = [x.gold_sequence[x.num_transitions] for x in current_batch]\n        trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions]\n        all_errors.append(outputs)\n        all_answers.extend(trans_tensor)\n\n        new_batch = []\n        update_transitions = []\n        for pred_transition, gold_transition, state in zip(pred_transitions, gold_transitions, current_batch):\n            # forget teacher forcing vs scheduled sampling\n            # we're going with idiot forcing\n            if pred_transition == gold_transition:\n                transitions_correct[gold_transition.short_name()] += 1\n                if state.num_transitions + 1 < len(state.gold_sequence):\n                    if oracle is not None and epoch >= args['oracle_initial_epoch'] and random.random() < args['oracle_forced_errors']:\n                        # TODO: could randomly choose from the legal transitions\n                        # perhaps the second best scored transition\n                        fake_transition = random.choice(model.transitions)\n                        if fake_transition.is_legal(state, model):\n                            _, new_sequence = oracle.fix_error(fake_transition, model, state)\n                            if new_sequence is not None:\n                                new_batch.append(state._replace(gold_sequence=new_sequence))\n                                update_transitions.append(fake_transition)\n                                fake_transitions_used = fake_transitions_used + 1\n                                continue\n                    new_batch.append(state)\n                    update_transitions.append(gold_transition)\n       
         continue\n\n            transitions_incorrect[gold_transition.short_name(), pred_transition.short_name()] += 1\n            # if we are on the final operation, there are two choices:\n            #   - the parsing mode is IN_ORDER, and the final transition\n            #     is the close to end the sequence, which has no alternatives\n            #   - the parsing mode is something else, in which case\n            #     we have no oracle anyway\n            if state.num_transitions + 1 >= len(state.gold_sequence):\n                continue\n\n            if oracle is None or epoch < args['oracle_initial_epoch'] or not pred_transition.is_legal(state, model):\n                new_batch.append(state)\n                update_transitions.append(gold_transition)\n                continue\n\n            repair_type, new_sequence = oracle.fix_error(pred_transition, model, state)\n            # we can only reach here on an error\n            assert not repair_type.is_correct\n            repairs_used[repair_type] += 1\n            if new_sequence is not None and random.random() < args['oracle_frequency']:\n                new_batch.append(state._replace(gold_sequence=new_sequence))\n                update_transitions.append(pred_transition)\n            else:\n                new_batch.append(state)\n                update_transitions.append(gold_transition)\n\n        if len(current_batch) > 0:\n            # bulk update states - significantly faster\n            current_batch = model.bulk_apply(new_batch, update_transitions, fail=True)\n\n    errors = torch.cat(all_errors)\n    answers = torch.cat(all_answers)\n\n    errors = process_outputs(errors)\n    tree_loss = model_loss_function(errors, answers)\n    tree_loss.backward()\n    if args['watch_regex']:\n        matched = False\n        tlogger.info(\"Watching %s   ... 
epoch %d batch %d\", args['watch_regex'], epoch, batch_idx)\n        watch_regex = re.compile(args['watch_regex'])\n        for n, p in model.named_parameters():\n            if watch_regex.search(n):\n                matched = True\n                if p.requires_grad and p.grad is not None:\n                    tlogger.info(\"  %s norm: %f grad: %f\", n, torch.linalg.norm(p), torch.linalg.norm(p.grad))\n                elif p.requires_grad:\n                    tlogger.info(\"  %s norm: %f grad required, but is None!\", n, torch.linalg.norm(p))\n                else:\n                    tlogger.info(\"  %s norm: %f grad not required\", n, torch.linalg.norm(p))\n        if not matched:\n            tlogger.info(\"  (none found!)\")\n    if torch.any(torch.isnan(tree_loss)):\n        batch_loss = 0.0\n        nans = 1\n    else:\n        batch_loss = tree_loss.item()\n        nans = 0\n\n    return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)\n\ndef run_dev_set(model, retagged_trees, original_trees, args, evaluator=None, analyze_first_errors=False):\n    \"\"\"\n    This reparses a treebank and executes the CoreNLP Java EvalB code.\n\n    It only works if CoreNLP 4.3.0 or higher is in the classpath.\n    \"\"\"\n    tlogger.info(\"Processing %d trees from %s\", len(retagged_trees), args['eval_file'])\n    model.eval()\n\n    num_generate = args.get('num_generate', 0)\n    keep_scores = num_generate > 0\n\n    sorted_trees, original_indices = sort_with_indices(retagged_trees, key=len, reverse=True)\n    tree_iterator = iter(tqdm(sorted_trees))\n    treebank = model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.predict, keep_scores=keep_scores)\n    treebank = unsort(treebank, original_indices)\n    full_results = treebank\n\n    if num_generate > 0:\n        tlogger.info(\"Generating %d random analyses\", args['num_generate'])\n        
generated_treebanks = [treebank]\n        for i in tqdm(range(num_generate)):\n            tree_iterator = iter(tqdm(retagged_trees, leave=False, postfix=\"tb%03d\" % i))\n            generated_treebanks.append(model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.weighted_choice, keep_scores=keep_scores))\n\n        #best_treebank = [ParseResult(parses[0].gold, [max([p.predictions[0] for p in parses], key=itemgetter(1))], None, None)\n        #                 for parses in zip(*generated_treebanks)]\n        #generated_treebanks = [best_treebank] + generated_treebanks\n\n        # TODO: if the model is dropping trees, this will not work\n        full_results = [ParseResult(parses[0].gold, [p.predictions[0] for p in parses], None, None)\n                        for parses in zip(*generated_treebanks)]\n\n    if len(full_results) < len(retagged_trees):\n        tlogger.warning(\"Only evaluating %d trees instead of %d\", len(full_results), len(retagged_trees))\n    else:\n        full_results = [x._replace(gold=gold) for x, gold in zip(full_results, original_trees)]\n\n    if args.get('mode', None) == 'predict' and args['predict_file']:\n        utils.ensure_dir(args['predict_dir'], verbose=False)\n        pred_file = os.path.join(args['predict_dir'], args['predict_file'] + \".pred.mrg\")\n        orig_file = os.path.join(args['predict_dir'], args['predict_file'] + \".orig.mrg\")\n        if os.path.exists(pred_file):\n            tlogger.warning(\"Cowardly refusing to overwrite {}\".format(pred_file))\n        elif os.path.exists(orig_file):\n            tlogger.warning(\"Cowardly refusing to overwrite {}\".format(orig_file))\n        else:\n            with open(pred_file, 'w') as fout:\n                for tree in full_results:\n                    output_tree = tree.predictions[0].tree\n                    if args['predict_output_gold_tags']:\n                        output_tree = 
output_tree.replace_tags(tree.gold)\n                    fout.write(args['predict_format'].format(output_tree))\n                    fout.write(\"\\n\")\n\n            for i in range(num_generate):\n                pred_file = os.path.join(args['predict_dir'], args['predict_file'] + \".%03d.pred.mrg\" % i)\n                with open(pred_file, 'w') as fout:\n                    for tree in generated_treebanks[-(i+1)]:\n                        output_tree = tree.predictions[0].tree\n                        if args['predict_output_gold_tags']:\n                            output_tree = output_tree.replace_tags(tree.gold)\n                        fout.write(args['predict_format'].format(output_tree))\n                        fout.write(\"\\n\")\n\n            with open(orig_file, 'w') as fout:\n                for tree in full_results:\n                    fout.write(args['predict_format'].format(tree.gold))\n                    fout.write(\"\\n\")\n\n    if len(full_results) == 0:\n        # callers unpack three values (f1, kbestF1, treeF1)\n        return 0.0, 0.0, 0.0\n    if evaluator is None:\n        if num_generate > 0:\n            kbest = max(len(fr.predictions) for fr in full_results)\n        else:\n            kbest = None\n        with EvaluateParser(kbest=kbest) as evaluator:\n            response = evaluator.process(full_results)\n    else:\n        response = evaluator.process(full_results)\n\n    if analyze_first_errors and args['transition_scheme'] is TransitionScheme.IN_ORDER:\n        errors = Counter()\n        for result in full_results:\n            first_error = error_analysis_in_order.analyze_tree(result.gold, result.predictions[0].tree)\n            errors[first_error] += 1\n        log_lines = [\"%30s: %d\" % (key.name, errors[key]) for key in error_analysis_in_order.FirstError]\n        tlogger.info(\"First error frequency:\\n  %s\", \"\\n  \".join(log_lines))\n\n    kbestF1 = response.kbestF1 if response.HasField(\"kbestF1\") else None\n    return response.f1, kbestF1, response.treeF1\n"
  },
  {
    "path": "stanza/models/constituency/partitioned_transformer.py",
    "content": "\"\"\"\nTransformer with partitioned content and position features.\n\nSee section 3 of https://arxiv.org/pdf/1805.01052.pdf\n\"\"\"\n\nimport copy\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding\n\nclass FeatureDropoutFunction(torch.autograd.function.InplaceFunction):\n    @staticmethod\n    def forward(ctx, input, p=0.5, train=False, inplace=False):\n        if p < 0 or p > 1:\n            raise ValueError(\n                \"dropout probability has to be between 0 and 1, but got {}\".format(p)\n            )\n\n        ctx.p = p\n        ctx.train = train\n        ctx.inplace = inplace\n\n        if ctx.inplace:\n            ctx.mark_dirty(input)\n            output = input\n        else:\n            output = input.clone()\n\n        if ctx.p > 0 and ctx.train:\n            ctx.noise = torch.empty(\n                (input.size(0), input.size(-1)),\n                dtype=input.dtype,\n                layout=input.layout,\n                device=input.device,\n            )\n            if ctx.p == 1:\n                ctx.noise.fill_(0)\n            else:\n                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)\n            ctx.noise = ctx.noise[:, None, :]\n            output.mul_(ctx.noise)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        if ctx.p > 0 and ctx.train:\n            return grad_output.mul(ctx.noise), None, None, None\n        else:\n            return grad_output, None, None, None\n\n\nclass FeatureDropout(nn.Dropout):\n    \"\"\"\n    Feature-level dropout: takes an input of size len x num_features and drops\n    each feature with probabibility p. 
A feature is dropped across the full\n    portion of the input that corresponds to a single batch element.\n    \"\"\"\n\n    def forward(self, x):\n        if isinstance(x, tuple):\n            x_c, x_p = x\n            x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace)\n            x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace)\n            return x_c, x_p\n        else:\n            return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace)\n\n\n# TODO: this module apparently is not treated the same as the built-in\n# nonlinearity modules, as multiple uses of the same relu on different\n# tensors wind up mixing the gradients.  See if there is a way to\n# resolve that other than creating a new nonlinearity for each layer\nclass PartitionedReLU(nn.ReLU):\n    def forward(self, x):\n        if isinstance(x, tuple):\n            x_c, x_p = x\n        else:\n            x_c, x_p = torch.chunk(x, 2, dim=-1)\n        return super().forward(x_c), super().forward(x_p)\n\n\nclass PartitionedLinear(nn.Module):\n    def __init__(self, in_features, out_features, bias=True):\n        super().__init__()\n        self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias)\n        self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias)\n\n    def forward(self, x):\n        if isinstance(x, tuple):\n            x_c, x_p = x\n        else:\n            x_c, x_p = torch.chunk(x, 2, dim=-1)\n\n        out_c = self.linear_c(x_c)\n        out_p = self.linear_p(x_p)\n        return out_c, out_p\n\n\nclass PartitionedMultiHeadAttention(nn.Module):\n    def __init__(\n        self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02\n    ):\n        super().__init__()\n\n        self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))\n        self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))\n        self.w_o_c = 
nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))\n        self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))\n\n        bound = math.sqrt(3.0) * initializer_range\n        for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]:\n            nn.init.uniform_(param, -bound, bound)\n        self.scaling_factor = 1 / d_qkv ** 0.5\n\n        self.dropout = nn.Dropout(attention_dropout)\n\n    def forward(self, x, mask=None):\n        if isinstance(x, tuple):\n            x_c, x_p = x\n        else:\n            x_c, x_p = torch.chunk(x, 2, dim=-1)\n        qkv_c = torch.einsum(\"btf,hfca->bhtca\", x_c, self.w_qkv_c)\n        qkv_p = torch.einsum(\"btf,hfca->bhtca\", x_p, self.w_qkv_p)\n        q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)]\n        q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)]\n        q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor\n        k = torch.cat([k_c, k_p], dim=-1)\n        v = torch.cat([v_c, v_p], dim=-1)\n        dots = torch.einsum(\"bhqa,bhka->bhqk\", q, k)\n        if mask is not None:\n            dots.data.masked_fill_(~mask[:, None, None, :], -float(\"inf\"))\n        probs = F.softmax(dots, dim=-1)\n        probs = self.dropout(probs)\n        o = torch.einsum(\"bhqk,bhka->bhqa\", probs, v)\n        o_c, o_p = torch.chunk(o, 2, dim=-1)\n        out_c = torch.einsum(\"bhta,haf->btf\", o_c, self.w_o_c)\n        out_p = torch.einsum(\"bhta,haf->btf\", o_p, self.w_o_p)\n        return out_c, out_p\n\n\nclass PartitionedTransformerEncoderLayer(nn.Module):\n    def __init__(self,\n                 d_model,\n                 n_head,\n                 d_qkv,\n                 d_ff,\n                 ff_dropout,\n                 residual_dropout,\n                 attention_dropout,\n                 activation=PartitionedReLU(),\n    ):\n        super().__init__()\n        self.self_attn = PartitionedMultiHeadAttention(\n            
d_model, n_head, d_qkv, attention_dropout=attention_dropout\n        )\n        self.linear1 = PartitionedLinear(d_model, d_ff)\n        self.ff_dropout = FeatureDropout(ff_dropout)\n        self.linear2 = PartitionedLinear(d_ff, d_model)\n\n        self.norm_attn = nn.LayerNorm(d_model)\n        self.norm_ff = nn.LayerNorm(d_model)\n        self.residual_dropout_attn = FeatureDropout(residual_dropout)\n        self.residual_dropout_ff = FeatureDropout(residual_dropout)\n\n        self.activation = activation\n\n    def forward(self, x, mask=None):\n        residual = self.self_attn(x, mask=mask)\n        residual = torch.cat(residual, dim=-1)\n        residual = self.residual_dropout_attn(residual)\n        x = self.norm_attn(x + residual)\n        residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x))))\n        residual = torch.cat(residual, dim=-1)\n        residual = self.residual_dropout_ff(residual)\n        x = self.norm_ff(x + residual)\n        return x\n\n\nclass PartitionedTransformerEncoder(nn.Module):\n    def __init__(self,\n                 n_layers,\n                 d_model,\n                 n_head,\n                 d_qkv,\n                 d_ff,\n                 ff_dropout,\n                 residual_dropout,\n                 attention_dropout,\n                 activation=PartitionedReLU,\n    ):\n        super().__init__()\n        self.layers = nn.ModuleList([PartitionedTransformerEncoderLayer(d_model=d_model,\n                                                                        n_head=n_head,\n                                                                        d_qkv=d_qkv,\n                                                                        d_ff=d_ff,\n                                                                        ff_dropout=ff_dropout,\n                                                                        residual_dropout=residual_dropout,\n                                                           
             attention_dropout=attention_dropout,\n                                                                        activation=activation())\n                                     for i in range(n_layers)])\n\n    def forward(self, x, mask=None):\n        for layer in self.layers:\n            x = layer(x, mask=mask)\n        return x\n\n\nclass ConcatPositionalEncoding(nn.Module):\n    \"\"\"\n    Learns a position embedding\n    \"\"\"\n    def __init__(self, d_model=256, max_len=512):\n        super().__init__()\n        self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model))\n        nn.init.normal_(self.timing_table)\n\n    def forward(self, x):\n        timing = self.timing_table[:x.shape[1], :]\n        timing = timing.expand(x.shape[0], -1, -1)\n        out = torch.cat([x, timing], dim=-1)\n        return out\n\n#\nclass PartitionedTransformerModule(nn.Module):\n    def __init__(self,\n                 n_layers,\n                 d_model,\n                 n_head,\n                 d_qkv,\n                 d_ff,\n                 ff_dropout,\n                 residual_dropout,\n                 attention_dropout,\n                 word_input_size,\n                 bias,\n                 morpho_emb_dropout,\n                 timing,\n                 encoder_max_len,\n                 activation=PartitionedReLU()\n    ):\n        super().__init__()\n        self.project_pretrained = nn.Linear(\n            word_input_size, d_model // 2, bias=bias\n        )\n\n        self.pattention_morpho_emb_dropout = FeatureDropout(morpho_emb_dropout)\n        if timing == 'sin':\n            self.add_timing = ConcatSinusoidalEncoding(d_model=d_model // 2, max_len=encoder_max_len)\n        elif timing == 'learned':\n            self.add_timing = ConcatPositionalEncoding(d_model=d_model // 2, max_len=encoder_max_len)\n        else:\n            raise ValueError(\"Unhandled timing type: %s\" % timing)\n        self.transformer_input_norm = 
nn.LayerNorm(d_model)\n        self.pattn_encoder = PartitionedTransformerEncoder(\n            n_layers,\n            d_model=d_model,\n            n_head=n_head,\n            d_qkv=d_qkv,\n            d_ff=d_ff,\n            ff_dropout=ff_dropout,\n            residual_dropout=residual_dropout,\n            attention_dropout=attention_dropout,\n        )\n\n\n    #\n    def forward(self, attention_mask, bert_embeddings):\n        # Prepares attention mask for feeding into the self-attention\n        device = bert_embeddings[0].device\n        # compare against None: the truth value of a multi-element tensor is ambiguous\n        if attention_mask is not None:\n            valid_token_mask = attention_mask\n        else:\n            valids = []\n            for sent in bert_embeddings:\n                valids.append(torch.ones(len(sent), device=device))\n\n            padded_data = torch.nn.utils.rnn.pad_sequence(\n                valids,\n                batch_first=True,\n                padding_value=-100\n            )\n\n            valid_token_mask = padded_data != -100\n\n        valid_token_mask = valid_token_mask.to(device=device)\n        padded_embeddings = torch.nn.utils.rnn.pad_sequence(\n            bert_embeddings,\n            batch_first=True,\n            padding_value=0\n        )\n\n        # Project the pretrained embedding onto the desired dimension\n        extra_content_annotations = self.project_pretrained(padded_embeddings)\n\n        # Add positional information through the table\n        encoder_in = self.add_timing(self.pattention_morpho_emb_dropout(extra_content_annotations))\n        encoder_in = self.transformer_input_norm(encoder_in)\n        # Put the partitioned input through the partitioned attention\n        annotations = self.pattn_encoder(encoder_in, valid_token_mask)\n\n        return annotations\n\n"
  },
  {
    "path": "stanza/models/constituency/positional_encoding.py",
    "content": "\"\"\"\nBased on\nhttps://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model\n\"\"\"\n\nimport math\n\nimport torch\nfrom torch import nn\n\nclass SinusoidalEncoding(nn.Module):\n    \"\"\"\n    Uses sine & cosine to represent position\n    \"\"\"\n    def __init__(self, model_dim, max_len):\n        super().__init__()\n        self.register_buffer('pe', self.build_position(model_dim, max_len))\n\n    @staticmethod\n    def build_position(model_dim, max_len, device=None):\n        position = torch.arange(max_len).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))\n        pe = torch.zeros(max_len, model_dim)\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        if device is not None:\n            pe = pe.to(device=device)\n        return pe\n\n    def forward(self, x):\n        if max(x) >= self.pe.shape[0]:\n            # try to drop the reference first before creating a new encoding\n            # the goal being to save memory if we are close to the memory limit\n            device = self.pe.device\n            shape = self.pe.shape[1]\n            self.register_buffer('pe', None)\n            # TODO: this may result in very poor performance\n            # in the event of a model that increases size one at a time\n            self.register_buffer('pe', self.build_position(shape, max(x)+1, device=device))\n        return self.pe[x]\n\n    def max_len(self):\n        return self.pe.shape[0]\n\n\nclass AddSinusoidalEncoding(nn.Module):\n    \"\"\"\n    Uses sine & cosine to represent position.  
Adds the position to the given matrix\n\n    Default behavior is batch_first\n    \"\"\"\n    def __init__(self, d_model=256, max_len=512):\n        super().__init__()\n        self.encoding = SinusoidalEncoding(d_model, max_len)\n\n    def forward(self, x, scale=1.0):\n        \"\"\"\n        Adds the positional encoding to the input tensor\n\n        The tensor is expected to be of the shape B, N, D\n        Properly masking the output tensor is up to the caller\n        \"\"\"\n        if len(x.shape) == 3:\n            timing = self.encoding(torch.arange(x.shape[1], device=x.device))\n            timing = timing.expand(x.shape[0], -1, -1)\n        elif len(x.shape) == 2:\n            timing = self.encoding(torch.arange(x.shape[0], device=x.device))\n        return x + timing * scale\n\n\nclass ConcatSinusoidalEncoding(nn.Module):\n    \"\"\"\n    Uses sine & cosine to represent position.  Concats the position and returns a larger object\n\n    Default behavior is batch_first\n    \"\"\"\n    def __init__(self, d_model=256, max_len=512):\n        super().__init__()\n        self.encoding = SinusoidalEncoding(d_model, max_len)\n\n    def forward(self, x):\n        if len(x.shape) == 3:\n            timing = self.encoding(torch.arange(x.shape[1], device=x.device))\n            timing = timing.expand(x.shape[0], -1, -1)\n        else:\n            timing = self.encoding(torch.arange(x.shape[0], device=x.device))\n\n        out = torch.cat((x, timing), dim=-1)\n        return out\n"
  },
  {
    "path": "stanza/models/constituency/retagging.py",
    "content": "\"\"\"\nRefactor a few functions specifically for retagging trees\n\nRetagging is important because the gold tags will not be available at runtime\n\nNote that the method which does the actual retagging is in utils.py\nso as to avoid unnecessary circular imports\n(eg, Pipeline imports constituency/trainer which imports this which imports Pipeline)\n\"\"\"\n\nimport copy\nimport logging\n\nfrom stanza import Pipeline\n\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.common.vocab import VOCAB_PREFIX\nfrom stanza.resources.common import download_resources_json, load_resources_json, get_language_resources\n\ntlogger = logging.getLogger('stanza.constituency.trainer')\n\n# xpos tagger doesn't produce PP tag on the turin treebank,\n# so instead we use upos to avoid unknown tag errors\nRETAG_METHOD = {\n    \"da\": \"upos\",   # the DDT has no xpos tags anyway\n    \"de\": \"upos\",   # DE GSD is also missing a few punctuation tags\n    \"es\": \"upos\",   # AnCora has half-finished xpos tags\n    \"id\": \"upos\",   # GSD is missing a few punctuation tags - fixed in 2.12, though\n    \"it\": \"upos\",\n    \"pt\": \"upos\",   # default PT model has no xpos either\n    \"vi\": \"xpos\",   # the new version of UD can be merged with xpos from VLSP22\n}\n\ndef add_retag_args(parser):\n    \"\"\"\n    Arguments specifically for retagging treebanks\n    \"\"\"\n    parser.add_argument('--retag_package', default=\"default\", help='Which tagger shortname to use when retagging trees.  None for no retagging.  Retagging is recommended, as gold tags will not be available at pipeline time')\n    parser.add_argument('--retag_method', default=None, choices=['xpos', 'upos'], help='Which tags to use when retagging.  Default depends on the language')\n    parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use.  Will use a downloaded Stanza model by default.  
Can specify multiple taggers with ; in which case the majority vote wins')\n    parser.add_argument('--retag_pretrain_path', default=None, help='Use this for a pretrain path for the retagging pipeline.  Generally not needed unless using a custom POS model with a custom pretrain')\n    parser.add_argument('--retag_charlm_forward_file', default=None, help='Use this for a forward charlm path for the retagging pipeline.  Generally not needed unless using a custom POS model with a custom charlm')\n    parser.add_argument('--retag_charlm_backward_file', default=None, help='Use this for a backward charlm path for the retagging pipeline.  Generally not needed unless using a custom POS model with a custom charlm')\n    parser.add_argument('--no_retag', dest='retag_package', action=\"store_const\", const=None, help=\"Don't retag the trees\")\n\ndef postprocess_args(args):\n    \"\"\"\n    After parsing args, unify some settings\n    \"\"\"\n    # use a language specific default for retag_method if we know the language\n    # otherwise, use xpos\n    if args['retag_method'] is None and 'lang' in args and args['lang'] in RETAG_METHOD:\n        args['retag_method'] = RETAG_METHOD[args['lang']]\n    if args['retag_method'] is None:\n        args['retag_method'] = 'xpos'\n\n    if args['retag_method'] == 'xpos':\n        args['retag_xpos'] = True\n    elif args['retag_method'] == 'upos':\n        args['retag_xpos'] = False\n    else:\n        raise ValueError(\"Unknown retag method {}\".format(args['retag_method']))\n\ndef build_retag_pipeline(args):\n    \"\"\"\n    Builds retag pipelines based on the arguments\n\n    May alter the arguments if the pipeline is incompatible, such as\n    taggers with no xpos\n\n    Will return a list of one or more retag pipelines,\n    or None if retagging is turned off.\n    Multiple tagger models can be specified by having them\n    semicolon separated in retag_model_path.\n    \"\"\"\n    # some argument sets might not use 'mode'\n    if args['retag_package'] is not None and args.get('mode', 
None) != 'remove_optimizer':\n        download_resources_json()\n        resources = load_resources_json()\n\n        if '_' in args['retag_package']:\n            lang, package = args['retag_package'].split('_', 1)\n            lang_resources = get_language_resources(resources, lang)\n            if lang_resources is None and 'lang' in args:\n                lang_resources = get_language_resources(resources, args['lang'])\n                if lang_resources is not None and 'pos' in lang_resources and args['retag_package'] in lang_resources['pos']:\n                    lang = args['lang']\n                    package = args['retag_package']\n        else:\n            if 'lang' not in args:\n                raise ValueError(\"Retag package %s does not specify the language, and it is not clear from the arguments\" % args['retag_package'])\n            lang = args.get('lang', None)\n            package = args['retag_package']\n        foundation_cache = FoundationCache()\n        retag_args = {\"lang\": lang,\n                      \"processors\": \"tokenize, pos\",\n                      \"tokenize_pretokenized\": True,\n                      \"package\": {\"pos\": package}}\n        if args['retag_pretrain_path'] is not None:\n            retag_args['pos_pretrain_path'] = args['retag_pretrain_path']\n        if args['retag_charlm_forward_file'] is not None:\n            retag_args['pos_forward_charlm_path'] = args['retag_charlm_forward_file']\n        if args['retag_charlm_backward_file'] is not None:\n            retag_args['pos_backward_charlm_path'] = args['retag_charlm_backward_file']\n\n        def build(retag_args, path):\n            retag_args = copy.deepcopy(retag_args)\n            # we just downloaded the resources a moment ago\n            # no need to repeatedly download\n            retag_args['download_method'] = 'reuse_resources'\n            if path is not None:\n                retag_args['allow_unknown_language'] = True\n                
retag_args['pos_model_path'] = path\n                tlogger.debug('Creating retag pipeline using %s', path)\n            else:\n                tlogger.debug('Creating retag pipeline for %s package', package)\n\n            retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args)\n            if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):\n                tlogger.warning(\"XPOS for the %s tagger is empty.  Switching to UPOS\", package)\n                args['retag_xpos'] = False\n                args['retag_method'] = 'upos'\n            return retag_pipeline\n\n        if args['retag_model_path'] is None:\n            return [build(retag_args, None)]\n        paths = args['retag_model_path'].split(\";\")\n        # can be length 1 if only one tagger to work with\n        return [build(retag_args, path) for path in paths]\n\n    return None\n"
  },
  {
    "path": "stanza/models/constituency/score_converted_dependencies.py",
    "content": "\"\"\"\nScript which processes a dependency file by using the constituency parser, then converting with the CoreNLP converter\n\nCurrently this does not have the constituency parser as an option,\nalthough that is easy to add.\n\nOnly English is supported, as only English is available in the CoreNLP converter\n\"\"\"\n\nimport argparse\nimport os\nimport tempfile\n\nimport stanza\nfrom stanza.models.constituency import retagging\nfrom stanza.models.depparse import scorer\nfrom stanza.utils.conll import CoNLL\n\ndef score_converted_dependencies(args):\n    if args['lang'] != 'en':\n        raise ValueError(\"Converting and scoring dependencies is currently only supported for English\")\n\n    constituency_package = args['constituency_package']\n    pipeline_args = {'lang': args['lang'],\n                     'tokenize_pretokenized': True,\n                     'package': {'pos': args['retag_package'], 'depparse': 'converter', 'constituency': constituency_package},\n                     'processors': 'tokenize, pos, constituency, depparse'}\n    pipeline = stanza.Pipeline(**pipeline_args)\n\n    input_doc = CoNLL.conll2doc(args['eval_file'])\n    output_doc = pipeline(input_doc)\n    print(\"Processed %d sentences\" % len(output_doc.sentences))\n    # reload - the pipeline clobbered the gold values\n    input_doc = CoNLL.conll2doc(args['eval_file'])\n\n    scorer.score_named_dependencies(output_doc, input_doc)\n    with tempfile.TemporaryDirectory() as tempdir:\n        output_path = os.path.join(tempdir, \"converted.conll\")\n\n        CoNLL.write_doc2conll(output_doc, output_path)\n\n        _, _, score = scorer.score(output_path, args['eval_file'])\n\n        print(\"Parser score:\")\n        print(\"{} {:.2f}\".format(constituency_package, score*100))\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--lang', default='en', type=str, help='Language')\n    parser.add_argument('--eval_file', 
default=\"extern_data/ud2/ud-treebanks-v2.13/UD_English-EWT/en_ewt-ud-test.conllu\", help='Input file for data loader.')\n    parser.add_argument('--constituency_package', default=\"ptb3-revised_electra-large\", help='Which constituency parser to use for converting')\n\n    retagging.add_retag_args(parser)\n    args = parser.parse_args()\n\n    args = vars(args)\n    retagging.postprocess_args(args)\n\n    score_converted_dependencies(args)\n\nif __name__ == '__main__':\n    main()\n    \n"
  },
  {
    "path": "stanza/models/constituency/state.py",
    "content": "from collections import namedtuple\n\nclass State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence',\n                                 'sentence_length', 'num_opens', 'word_position', 'score', 'broken'])):\n    \"\"\"\n    Represents a partially completed transition parse\n\n    Includes stack/buffers for unused words, already executed transitions, and partially build constituents\n    At training time, also keeps track of the gold data we are reparsing\n\n    num_opens is useful for tracking\n       1) if the parser is in a stuck state where it is making infinite opens\n       2) if a close transition is impossible because there are no previous opens\n\n    sentence_length tracks how long the sentence is so we abort if we go infinite\n\n    non-stack information such as sentence_length and num_opens\n    will be copied from the original_state if possible, with the\n    exact arguments overriding the values in the original_state\n\n    gold_tree: the original tree, if made from a gold tree.  might be None\n    gold_sequence: the original transition sequence, if available\n    Note that at runtime, gold values will not be available\n\n    word_position tracks where in the word queue we are.  cheaper than\n      manipulating the list itself.  
this can be handled differently\n      from transitions and constituents as it is processed once\n      at the start of parsing\n\n    The word_queue should have both a start and an end word.\n    Those can be None in the case of the endpoints if they are unused.\n    \"\"\"\n    def empty_word_queue(self):\n        # the first element of each stack is a sentinel with no value\n        # and no parent\n        return self.word_position == self.sentence_length\n\n    def empty_transitions(self):\n        # the first element of each stack is a sentinel with no value\n        # and no parent\n        return self.transitions.parent is None\n\n    def has_one_constituent(self):\n        # a length of 1 represents no constituents\n        return self.constituents.length == 2\n\n    @property\n    def empty_constituents(self):\n        return self.constituents.parent is None\n\n    def num_constituents(self):\n        return self.constituents.length - 1\n\n    @property\n    def num_transitions(self):\n        # -1 for the sentinel value\n        return self.transitions.length - 1\n\n    def get_word(self, pos):\n        # +1 to handle the initial sentinel value\n        # (which you can actually get with pos=-1)\n        return self.word_queue[pos+1]\n\n    def finished(self, model):\n        return self.empty_word_queue() and self.has_one_constituent() and model.get_top_constituent(self.constituents).label in model.root_labels\n\n    def get_tree(self, model):\n        return model.get_top_constituent(self.constituents)\n\n    def all_transitions(self, model):\n        # TODO: rewrite this to be nicer / faster?  
or just refactor?\n        all_transitions = []\n        transitions = self.transitions\n        while transitions.parent is not None:\n            all_transitions.append(model.get_top_transition(transitions))\n            transitions = transitions.parent\n        return list(reversed(all_transitions))\n\n    def all_constituents(self, model):\n        # TODO: rewrite this to be nicer / faster?\n        all_constituents = []\n        constituents = self.constituents\n        while constituents.parent is not None:\n            all_constituents.append(model.get_top_constituent(constituents))\n            constituents = constituents.parent\n        return list(reversed(all_constituents))\n\n    def all_words(self, model):\n        return [model.get_word(x) for x in self.word_queue]\n\n    def to_string(self, model):\n        return \"State(\\n  buffer:%s\\n  transitions:%s\\n  constituents:%s\\n  word_position:%d num_opens:%d)\" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)), self.word_position, self.num_opens)\n\n    def __str__(self):\n        return \"State(\\n  buffer:%s\\n  transitions:%s\\n  constituents:%s)\" % (str(self.word_queue), str(self.transitions), str(self.constituents))\n\nclass MultiState(namedtuple('MultiState', ['states', 'gold_tree', 'gold_sequence', 'score'])):\n    def finished(self, ensemble):\n        return self.states[0].finished(ensemble.models[0])\n\n    def get_tree(self, ensemble):\n        return self.states[0].get_tree(ensemble.models[0])\n\n    @property\n    def empty_constituents(self):\n        return self.states[0].empty_constituents\n\n    def num_constituents(self):\n        return len(self.states[0].constituents) - 1\n\n    @property\n    def num_transitions(self):\n        # -1 for the sentinel value\n        return len(self.states[0].transitions) - 1\n\n    @property\n    def num_opens(self):\n        return self.states[0].num_opens\n\n    @property\n    def 
sentence_length(self):\n        return self.states[0].sentence_length\n\n    def empty_word_queue(self):\n        return self.states[0].empty_word_queue()\n\n    def empty_transitions(self):\n        return self.states[0].empty_transitions()\n\n    @property\n    def constituents(self):\n        # warning! if there is information in the constituents such as\n        # the embedding of the constituent, this will only contain the\n        # first such embedding\n        # the other models' constituent states won't be returned\n        return self.states[0].constituents\n\n    @property\n    def transitions(self):\n        # warning! if there is information in the transitions such as\n        # the embedding of the transition, this will only contain the\n        # first such embedding\n        # the other models' transition states won't be returned\n        return self.states[0].transitions\n"
  },
  {
    "path": "stanza/models/constituency/text_processing.py",
    "content": "import os\n\nimport logging\n\nfrom stanza.models.common import utils\nfrom stanza.models.constituency.utils import retag_tags\nfrom stanza.models.constituency.trainer import Trainer\nfrom stanza.models.constituency.tree_reader import read_trees\nfrom stanza.utils.get_tqdm import get_tqdm\n\nlogger = logging.getLogger('stanza')\ntqdm = get_tqdm()\n\ndef read_tokenized_file(tokenized_file):\n    \"\"\"\n    Read sentences from a tokenized file, potentially replacing _ with space for languages such as VI\n    \"\"\"\n    with open(tokenized_file, encoding='utf-8') as fin:\n        lines = fin.readlines()\n    lines = [x.strip() for x in lines]\n    lines = [x for x in lines if x]\n    docs = [[word if all(x == '_' for x in word) else word.replace(\"_\", \" \") for word in sentence.split()] for sentence in lines]\n    ids = [None] * len(docs)\n    return docs, ids\n\ndef read_xml_tree_file(tree_file):\n    \"\"\"\n    Read sentences from a file of the format unique to VLSP test sets\n\n    in particular, it should be multiple blocks of\n\n    <s id=1>\n      (tree ...)\n    </s>\n    \"\"\"\n    with open(tree_file, encoding='utf-8') as fin:\n        lines = fin.readlines()\n    lines = [x.strip() for x in lines]\n    lines = [x for x in lines if x]\n    docs = []\n    ids = []\n    tree_id = None\n    tree_text = []\n    for line in lines:\n        if line.startswith(\"<s\"):\n            tree_id = line.split(\"=\")\n            if len(tree_id) > 1:\n                tree_id = tree_id[1]\n                if tree_id.endswith(\">\"):\n                    tree_id = tree_id[:-1]\n                tree_id = int(tree_id)\n            else:\n                tree_id = None\n        elif line.startswith(\"</s\"):\n            if len(tree_text) == 0:\n                raise ValueError(\"Found a blank tree in %s\" % tree_file)\n            ids.append(tree_id)\n            tree_text = \"\\n\".join(tree_text)\n            trees = read_trees(tree_text)\n            # 
TODO: perhaps the processing can be put into read_trees instead\n            trees = [t.prune_none().simplify_labels() for t in trees]\n            if len(trees) != 1:\n                raise ValueError(\"Found a tree with %d trees in %s\" % (len(trees), tree_file))\n            tree = trees[0]\n            text = tree.leaf_labels()\n            text = [word if all(x == '_' for x in word) else word.replace(\"_\", \" \") for word in text]\n            docs.append(text)\n            tree_text = []\n            tree_id = None\n        else:\n            tree_text.append(line)\n\n    return docs, ids\n\n\ndef parse_tokenized_sentences(args, model, retag_pipeline, sentences):\n    \"\"\"\n    Parse the given sentences, return a list of ParseResult objects\n    \"\"\"\n    tags = retag_tags(sentences, retag_pipeline, model.uses_xpos())\n    words = [[(word, tag) for word, tag in zip(s_words, s_tags)] for s_words, s_tags in zip(sentences, tags)]\n    logger.info(\"Retagging finished.  Parsing tagged text\")\n\n    assert len(words) == len(sentences)\n    treebank = model.parse_sentences_no_grad(iter(tqdm(words)), model.build_batch_from_tagged_words, args['eval_batch_size'], model.predict, keep_scores=False)\n    return treebank\n\ndef parse_text(args, model, retag_pipeline, tokenized_file=None, predict_file=None):\n    \"\"\"\n    Use the given model to parse text and write it\n\n    refactored so it can be used elsewhere, such as Ensemble\n    \"\"\"\n    model.eval()\n\n    if predict_file is None:\n        if args['predict_file']:\n            predict_file = args['predict_file']\n            if args['predict_dir']:\n                predict_file = os.path.join(args['predict_dir'], predict_file)\n\n    if tokenized_file is None:\n        tokenized_file = args['tokenized_file']\n\n    docs, ids = None, None\n    if tokenized_file is not None:\n        docs, ids = read_tokenized_file(tokenized_file)\n    elif args['xml_tree_file']:\n        logger.info(\"Reading trees from 
%s\" % args['xml_tree_file'])\n        docs, ids = read_xml_tree_file(args['xml_tree_file'])\n\n    if not docs:\n        logger.error(\"No sentences to process!\")\n        return\n\n    logger.info(\"Processing %d sentences\", len(docs))\n\n    with utils.output_stream(predict_file) as fout:\n        chunk_size = 10000\n        for chunk_start in range(0, len(docs), chunk_size):\n            chunk = docs[chunk_start:chunk_start+chunk_size]\n            ids_chunk = ids[chunk_start:chunk_start+chunk_size]\n            logger.info(\"Processing trees %d to %d\", chunk_start, chunk_start+len(chunk))\n            treebank = parse_tokenized_sentences(args, model, retag_pipeline, chunk)\n\n            for result, tree_id in zip(treebank, ids_chunk):\n                tree = result.predictions[0].tree\n                if tree_id is not None:\n                    tree.tree_id = tree_id\n                fout.write(args['predict_format'].format(tree))\n                fout.write(\"\\n\")\n\ndef parse_dir(args, model, retag_pipeline, tokenized_dir, predict_dir):\n    os.makedirs(predict_dir, exist_ok=True)\n    for filename in os.listdir(tokenized_dir):\n        input_path = os.path.join(tokenized_dir, filename)\n        output_path = os.path.join(predict_dir, os.path.splitext(filename)[0] + \".mrg\")\n        logger.info(\"Processing %s to %s\", input_path, output_path)\n        parse_text(args, model, retag_pipeline, tokenized_file=input_path, predict_file=output_path)\n\n\ndef load_model_parse_text(args, model_file, retag_pipeline):\n    \"\"\"\n    Load a model, then parse text and write it to stdout or args['predict_file']\n\n    retag_pipeline: a list of Pipeline meant to use for retagging\n    \"\"\"\n    # FoundationCache is not imported at module level in this file,\n    # so import it here for the case where there is no retag pipeline\n    from stanza.models.common.foundation_cache import FoundationCache\n    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()\n    load_args = {\n        'wordvec_pretrain_file': args['wordvec_pretrain_file'],\n        'charlm_forward_file': args['charlm_forward_file'],\n        
'charlm_backward_file': args['charlm_backward_file'],\n        'device': args['device'],\n    }\n    trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)\n    model = trainer.model\n    model.eval()\n    logger.info(\"Loaded model from %s\", model_file)\n\n    if args['tokenized_dir']:\n        if not args['predict_dir']:\n            raise ValueError(\"Must specify --predict_dir to go with --tokenized_dir\")\n        parse_dir(args, model, retag_pipeline, args['tokenized_dir'], args['predict_dir'])\n    else:\n        parse_text(args, model, retag_pipeline)\n\n"
  },
  {
    "path": "stanza/models/constituency/top_down_oracle.py",
    "content": "from enum import Enum\nimport random\n\nfrom stanza.models.constituency.dynamic_oracle import advance_past_constituents, score_candidates_single_block, DynamicOracle, RepairEnum\nfrom stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent\n\ndef find_constituent_end(gold_sequence, cur_index):\n    \"\"\"\n    Find the Close which ends the next constituent opened at or after cur_index\n    \"\"\"\n    count = 0\n    while cur_index < len(gold_sequence):\n        if isinstance(gold_sequence[cur_index], OpenConstituent):\n            count = count + 1\n        elif isinstance(gold_sequence[cur_index], CloseConstituent):\n            count = count - 1\n            if count == 0:\n                return cur_index\n        cur_index += 1\n    raise AssertionError(\"Open constituent not closed starting from index %d in sequence %s\" % (cur_index, gold_sequence))\n\ndef fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Predicted a close when we should have shifted\n\n    The fix here is to remove the corresponding close from later in\n    the transition sequence.  
The rest of the tree building is the same,\n    including doing the missing Shift immediately after\n\n    Anything else would make the situation of one precision, one\n    recall error worse\n    \"\"\"\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    close_index = advance_past_constituents(gold_sequence, gold_index)\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + gold_sequence[close_index+1:]\n\ndef fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Predicted a close when we should have opened a constituent\n\n    In this case, the previous constituent is now a precision and\n    recall error, BUT we can salvage the constituent we were about to\n    open by proceeding as if everything else is still the same.\n\n    The next thing the model should do is open the transition it forgot about\n    \"\"\"\n    if not isinstance(pred_transition, CloseConstituent):\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    close_index = advance_past_constituents(gold_sequence, gold_index)\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + gold_sequence[close_index+1:]\n\ndef fix_one_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Predicted a shift when we should have opened a constituent\n\n    This causes a single recall error if we just pretend that\n    constituent didn't exist\n\n    Keep the shift where it was, remove the next shift\n    Also, scroll ahead, find the corresponding close, cut it out\n\n    For the corresponding multiple opens, shift error, see fix_multiple_open_shift\n    \"\"\"\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if 
not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_sequence[gold_index + 1], Shift):\n        return None\n\n    shift_index = gold_index + 1\n    close_index = advance_past_constituents(gold_sequence, gold_index + 1)\n    if close_index is None:\n        return None\n    # gold_index is the skipped open constituent\n    # close_index was the corresponding close\n    # shift_index is the shift to remove\n    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:shift_index] + gold_sequence[shift_index+1:close_index] + gold_sequence[close_index+1:]\n    #print(\"Input sequence: %s\\nIndex %d\\nGold %s Pred %s\\nUpdated sequence %s\" % (gold_sequence, gold_index, gold_transition, pred_transition, updated_sequence))\n    return updated_sequence\n\ndef fix_multiple_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Predicted a shift when we should have opened multiple constituents instead\n\n    This causes a single recall error per constituent if we just\n    pretend those constituents don't exist\n\n    For each open constituent, we find the corresponding close,\n    then remove both the open & close\n    \"\"\"\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    shift_index = gold_index\n    while shift_index < len(gold_sequence) and isinstance(gold_sequence[shift_index], OpenConstituent):\n        shift_index += 1\n    if shift_index >= len(gold_sequence):\n        raise AssertionError(\"Found a sequence of OpenConstituent at the end of a TOP_DOWN sequence!\")\n    if not isinstance(gold_sequence[shift_index], Shift):\n        raise AssertionError(\"Expected to find a Shift after a sequence of OpenConstituent.  
There should not be a %s\" % gold_sequence[shift_index])\n\n    #print(\"Input sequence: %s\\nIndex %d\\nGold %s Pred %s\" % (gold_sequence, gold_index, gold_transition, pred_transition))\n    updated_sequence = gold_sequence\n    while shift_index > gold_index:\n        close_index = advance_past_constituents(updated_sequence, shift_index)\n        if close_index is None:\n            raise AssertionError(\"Did not find a corresponding Close for this Open\")\n        # cut out the corresponding open and close\n        updated_sequence = updated_sequence[:shift_index-1] + updated_sequence[shift_index:close_index] + updated_sequence[close_index+1:]\n        shift_index -= 1\n        #print(\"  %s\" % updated_sequence)\n\n    #print(\"Final updated sequence: %s\" % updated_sequence)\n    return updated_sequence\n\ndef fix_nested_open_constituent(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    We were supposed to predict Open(X), then Open(Y), but predicted Open(Y) instead\n\n    We treat this as a single recall error.\n\n    We could even go crazy and turn it into a Unary,\n    such as Open(Y), Open(X), Open(Y)...\n    presumably that would be very confusing to the parser\n    not to mention ambiguous as to where to close the new constituent\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    assert len(gold_sequence) > gold_index + 1\n\n    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):\n        return None\n\n    # This replacement works if we skipped exactly one level\n    if gold_sequence[gold_index+1].label != pred_transition.label:\n        return None\n\n    close_index = advance_past_constituents(gold_sequence, gold_index+1)\n    assert close_index is not None\n    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + 
gold_sequence[close_index+1:]\n    return updated_sequence\n\ndef fix_shift_open_immediate_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    We were supposed to Shift, but instead we Opened\n\n    The biggest problem with this type of error is that the Close of\n    the Open is ambiguous.  We could put it immediately before the\n    next Close, immediately after the Shift, or anywhere in between.\n\n    One unambiguous case would be if the proper sequence was Shift - Close.\n    Then it is unambiguous that the only possible repair is Open - Shift - Close - Close.\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    assert len(gold_sequence) > gold_index + 1\n    if not isinstance(gold_sequence[gold_index+1], CloseConstituent):\n        # this is the ambiguous case\n        return None\n\n    return gold_sequence[:gold_index] + [pred_transition, gold_transition, CloseConstituent()] + gold_sequence[gold_index+1:]\n\ndef fix_shift_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    We were supposed to Shift, but instead we Opened\n\n    The biggest problem with this type of error is that the Close of\n    the Open is ambiguous.  
We could put it immediately before the\n    next Close, immediately after the Shift, or anywhere in between.\n\n    In this fix, we are testing what happens if we treat this Open as a Unary transition.\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    assert len(gold_sequence) > gold_index + 1\n    if isinstance(gold_sequence[gold_index+1], CloseConstituent):\n        # this is the unambiguous case, which should already be handled\n        return None\n\n    return gold_sequence[:gold_index] + [pred_transition, gold_transition, CloseConstituent()] + gold_sequence[gold_index+1:]\n\ndef fix_shift_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    We were supposed to Shift, but instead we Opened\n\n    The biggest problem with this type of error is that the Close of\n    the Open is ambiguous.  We could put it immediately before the\n    next Close, immediately after the Shift, or anywhere in between.\n\n    In this fix, we put the corresponding Close for this Open at the end of the enclosing bracket.\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    assert len(gold_sequence) > gold_index + 1\n    if isinstance(gold_sequence[gold_index+1], CloseConstituent):\n        # this is the unambiguous case, which should already be handled\n        return None\n\n    outer_close_index = advance_past_constituents(gold_sequence, gold_index)\n\n    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:outer_close_index] + [CloseConstituent()] + gold_sequence[outer_close_index:]\n\ndef fix_shift_open_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(pred_transition, 
OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, Shift):\n        return None\n\n    assert len(gold_sequence) > gold_index + 1\n    if isinstance(gold_sequence[gold_index+1], CloseConstituent):\n        # this is the unambiguous case, which should already be handled\n        return None\n\n    # at this point: have Opened a constituent which we don't want\n    # need to figure out where to Close it\n    # could close it after the shift or after any given block\n    candidates = []\n    current_index = gold_index\n    while not isinstance(gold_sequence[current_index], CloseConstituent):\n        if isinstance(gold_sequence[current_index], Shift):\n            end_index = current_index\n        else:\n            end_index = find_constituent_end(gold_sequence, current_index)\n        candidates.append((gold_sequence[:gold_index], [pred_transition], gold_sequence[gold_index:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))\n        current_index = end_index + 1\n\n    scores, best_idx, best_candidate = score_candidates_single_block(model, state, candidates, candidate_idx=3)\n    if best_idx == len(candidates) - 1:\n        best_idx = -1\n    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.name,\n                             value=\"%d.%d\" % (RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.value, best_idx),\n                             is_correct=False)\n    return repair_type, best_candidate\n\n\ndef fix_close_shift_ambiguous_immediate(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Instead of a Close, we predicted a Shift.  
This time, we immediately close no matter what comes after the next Shift.\n\n    An alternate strategy would be to Close at the closing of the outer constituent.\n    \"\"\"\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n\n    num_closes = 0\n    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):\n        num_closes += 1\n\n    if not isinstance(gold_sequence[gold_index + num_closes], Shift):\n        # TODO: we should be able to handle this case too (an Open)\n        # however, it will be rare once the parser gets going and it\n        # would cause a lot of errors, anyway\n        return None\n\n    if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent):\n        # this one should just have been satisfied in the non-ambiguous version\n        return None\n\n    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+1:]\n    return updated_sequence\n\n\ndef fix_close_shift_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    Instead of a Close, we predicted a Shift.  
This time, we close at the end of the outer bracket no matter what comes after the next Shift.\n\n    An alternate strategy would be to Close as soon as possible after the Shift.\n    \"\"\"\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n\n    num_closes = 0\n    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):\n        num_closes += 1\n\n    if not isinstance(gold_sequence[gold_index + num_closes], Shift):\n        # TODO: we should be able to handle this case too (an Open)\n        # however, it will be rare once the parser gets going and it\n        # would cause a lot of errors, anyway\n        return None\n\n    if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent):\n        # this one should just have been satisfied in the non-ambiguous version\n        return None\n\n    # outer_close_index is now where the constituent which the broken constituent(s) reside inside gets closed\n    outer_close_index = advance_past_constituents(gold_sequence, gold_index + num_closes)\n\n    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+num_closes:outer_close_index] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[outer_close_index:]\n    return updated_sequence\n\n\ndef fix_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, count_opens=False):\n    \"\"\"\n    We were supposed to Close, but instead did a Shift\n\n    In most cases, this will be ambiguous.  There is now a constituent\n    which has been missed, no matter what we do, and we are on the\n    hook for eventually closing this constituent, creating a precision\n    error as well.  
The ambiguity arises because there will be\n    multiple places where the Close could occur if there are more\n    constituents created between now and when the outer constituent is\n    Closed.\n\n    The non-ambiguous case is if the proper sequence was\n      Close - Shift - Close\n    similar cases are also non-ambiguous, such as\n      Close - Close - Shift - Close\n    for that matter, so is the following, although the Opens will be lost\n      Close - Open - Shift - Close - Close\n\n    count_opens is an option to make it easy to count with or without\n      Open as different oracle fixes\n    \"\"\"\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n\n    num_closes = 0\n    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):\n        num_closes += 1\n\n    # We may allow unary transitions here\n    # the opens will be lost in the repaired sequence\n    num_opens = 0\n    if count_opens:\n        while isinstance(gold_sequence[gold_index + num_closes + num_opens], OpenConstituent):\n            num_opens += 1\n\n    if not isinstance(gold_sequence[gold_index + num_closes + num_opens], Shift):\n        if count_opens:\n            raise AssertionError(\"Should have found a Shift after a sequence of Opens or a Close with no Open.  
Started counting at %d in sequence %s\" % (gold_index, gold_sequence))\n        return None\n\n    if not isinstance(gold_sequence[gold_index + num_closes + num_opens + 1], CloseConstituent):\n        return None\n    for idx in range(num_opens):\n        if not isinstance(gold_sequence[gold_index + num_closes + num_opens + idx + 1], CloseConstituent):\n            return None\n\n    # Now we know it is Close x num_closes, Shift, Close\n    # Since we have erroneously predicted a Shift now, the best we can\n    # do is to follow that, then add num_closes Closes\n    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+num_opens*2+1:]\n    return updated_sequence\n\ndef fix_close_shift_with_opens(*args, **kwargs):\n    return fix_close_shift(*args, **kwargs, count_opens=True)\n\ndef fix_close_next_correct_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    We were supposed to Close, but instead predicted Shift when the next transition is Shift\n\n    This differs from the previous Close-Shift in that this case does\n    not have an unambiguous place to put the Close.  Instead, we let\n    the model predict where to put the Close\n\n    Note that this can also work for Close-Open with the next Open correct\n\n    Not covered (yet?) 
is multiple Close in a row\n    \"\"\"\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, (Shift, OpenConstituent)):\n        return None\n    if gold_sequence[gold_index+1] != pred_transition:\n        return None\n\n    candidates = []\n    current_index = gold_index + 1\n    while not isinstance(gold_sequence[current_index], CloseConstituent):\n        if isinstance(gold_sequence[current_index], Shift):\n            end_index = current_index\n        else:\n            end_index = find_constituent_end(gold_sequence, current_index)\n        candidates.append((gold_sequence[:gold_index], gold_sequence[gold_index+1:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))\n        current_index = end_index + 1\n\n    scores, best_idx, best_candidate = score_candidates_single_block(model, state, candidates, candidate_idx=3)\n    if best_idx == len(candidates) - 1:\n        best_idx = -1\n    repair_type = RepairEnum(name=RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.name,\n                             value=\"%d.%d\" % (RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.value, best_idx),\n                             is_correct=False)\n    return repair_type, best_candidate\n\n\ndef fix_close_open_correct_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):\n    \"\"\"\n    We were supposed to Close, but instead did an Open\n\n    In general this is ambiguous (like close/shift), as we need to know when to close the incorrect constituent\n\n    A case that is not ambiguous is when exactly one constituent was\n    supposed to come after the Close and it matches the Open we just\n    created.  In that case, we treat that constituent as if it were\n    part of the non-Closed constituent.  
For example,\n    \"ate (NP spaghetti) (PP with a fork)\" ->\n    \"ate (NP spaghetti (PP with a fork))\"\n    (delicious)\n\n    There is also an option to not check for the Close after the first\n    constituent, in which case any number of constituents could have\n    been predicted.  This represents a solution of the ambiguous form\n    of the Close/Open transition where the Close could occur in\n    multiple places later in the sequence.\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n\n    if gold_sequence[gold_index+1] != pred_transition:\n        return None\n\n    close_index = find_constituent_end(gold_sequence, gold_index+1)\n    if check_close and not isinstance(gold_sequence[close_index+1], CloseConstituent):\n        return None\n\n    # at this point, we know we can put the Close at the end of the\n    # Open which was accidentally added\n    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index+1] + [gold_transition] + gold_sequence[close_index+1:]\n    return updated_sequence\n\ndef fix_close_open_correct_open_ambiguous_immediate(*args, **kwargs):\n    return fix_close_open_correct_open(*args, **kwargs, check_close=False)\n\ndef fix_close_open_correct_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):\n    \"\"\"\n    We were supposed to Close, but instead did an Open in an ambiguous context.  
Here we resolve it later in the tree\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n\n    if gold_sequence[gold_index+1] != pred_transition:\n        return None\n\n    # this will be the index of the Close for the surrounding constituent\n    close_index = advance_past_constituents(gold_sequence, gold_index+1)\n    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + [gold_transition] + gold_sequence[close_index:]\n    return updated_sequence\n\ndef fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    If there is an Open/Open error which is not covered by the unambiguous single recall error, we try fixing it as a Unary\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    if pred_transition == gold_transition:\n        return None\n    if gold_sequence[gold_index+1] == pred_transition:\n        # This case is covered by the nested open repair\n        return None\n\n    close_index = find_constituent_end(gold_sequence, gold_index)\n    assert close_index is not None\n    assert isinstance(gold_sequence[close_index], CloseConstituent)\n    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + [CloseConstituent()] + gold_sequence[close_index:]\n    return updated_sequence\n\ndef fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    If there is an Open/Open error which is not covered by the\n    unambiguous single recall error, we try fixing it by putting the\n    close at the end of the outer constituent\n\n    \"\"\"\n    if not isinstance(pred_transition, 
OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    if pred_transition == gold_transition:\n        return None\n    if gold_sequence[gold_index+1] == pred_transition:\n        # This case is covered by the nested open repair\n        return None\n\n    close_index = advance_past_constituents(gold_sequence, gold_index)\n    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + [CloseConstituent()] + gold_sequence[close_index:]\n    return updated_sequence\n\ndef fix_open_open_ambiguous_random(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    \"\"\"\n    If there is an Open/Open error which is not covered by the\n    unambiguous single recall error, we try fixing it by choosing\n    randomly (50/50) between putting the close at the end of the\n    outer constituent and treating the Open as a Unary\n\n    \"\"\"\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n\n    if pred_transition == gold_transition:\n        return None\n    if gold_sequence[gold_index+1] == pred_transition:\n        # This case is covered by the nested open repair\n        return None\n\n    # both repair functions require model and state as positional arguments\n    if random.random() < 0.5:\n        return fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)\n    else:\n        return fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)\n\n\ndef report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, Shift):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return RepairType.OTHER_SHIFT_OPEN, None\n\n\ndef report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n 
   if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, Shift):\n        return None\n\n    return RepairType.OTHER_CLOSE_SHIFT, None\n\ndef report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, CloseConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return RepairType.OTHER_CLOSE_OPEN, None\n\ndef report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):\n    if not isinstance(gold_transition, OpenConstituent):\n        return None\n    if not isinstance(pred_transition, OpenConstituent):\n        return None\n\n    return RepairType.OTHER_OPEN_OPEN, None\n\n\nclass RepairType(Enum):\n    \"\"\"\n    Keep track of which repair is used, if any, on an incorrect transition\n\n    A test of the top-down oracle with no charlm or transformer\n      (eg, word vectors only) on EN PTB3 goes as follows.\n      3x training rounds, best training parameters as of Jan. 
2024\n    unambiguous transitions only:\n        oracle scheme         dev        test\n      no oracle              0.9230     0.9194\n       +shift/close          0.9224     0.9180\n       +open/close           0.9225     0.9193\n       +open/shift (one)     0.9245     0.9207\n       +open/shift (mult)    0.9243     0.9211\n       +open/open nested     0.9258     0.9213\n       +shift/open           0.9266     0.9229\n       +close/shift (only)   0.9270     0.9230\n       +close/shift w/ opens 0.9262     0.9221\n       +close/open one con   0.9273     0.9230\n\n    Potential solutions for various ambiguous transitions:\n\n    close/open\n      can close immediately after the corresponding constituent or after any number of constituents\n\n    close/shift\n      can close immediately\n      can close anywhere up to the next close\n      any number of missed Opens are treated as recall errors\n\n    open/open\n      could treat as unary\n      could close at any number of positions after the next structures, up to the outer open's closing\n\n    shift/open ambiguity resolutions:\n      treat as unary\n      treat as wrapper around the next full constituent to build\n      treat as wrapper around everything to build until the next constituent\n\n    testing one at a time in addition to the full set of unambiguous corrections:\n       +close/open immediate   0.9259     0.9225\n       +close/open later       0.9258     0.9257\n       +close/shift immediate  0.9261     0.9219\n       +close/shift later      0.9270     0.9230\n       +open/open later        0.9269     0.9239\n       +open/open unary        0.9275     0.9246\n       +shift/open later       0.9263     0.9253\n       +shift/open unary       0.9264     0.9243\n\n    so there is some evidence that open/open or shift/open would be beneficial\n\n    Training by randomly choosing between the open/open, 50/50\n       +open/open random       0.9257     0.9235\n    so that didn't work great compared to the 
individual transitions\n\n    Testing deterministic resolutions of the ambiguous transitions\n    vs predicting the appropriate transition to use:\n    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR,CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR,CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR\n    SHIFT_OPEN_AMBIGUOUS_PREDICTED,CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED\n\n    EN ambiguous (no charlm or transformer)   0.9268   0.9231\n    EN predicted                              0.9270   0.9257\n    EN none of the above                      0.9268   0.9229\n\n    ZH ambiguous                              0.9137   0.9127\n    ZH predicted                              0.9148   0.9141\n    ZH none of the above                      0.9141   0.9143\n\n    DE ambiguous                              0.9579   0.9408\n    DE predicted                              0.9575   0.9406\n    DE none of the above                      0.9581   0.9411\n\n    ID ambiguous                              0.8889   0.8794\n    ID predicted                              0.8911   0.8801\n    ID none of the above                      0.8913   0.8822\n\n    IT ambiguous                              0.8404   0.8380\n    IT predicted                              0.8397   0.8398\n    IT none of the above                      0.8400   0.8409\n\n    VI ambiguous                              0.8290   0.7676\n    VI predicted                              0.8287   0.7682\n    VI none of the above                      0.8292   0.7691\n    \"\"\"\n    def __new__(cls, fn, correct=False, debug=False):\n        \"\"\"\n        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error\n        \"\"\"\n        value = len(cls.__members__)\n        obj = object.__new__(cls)\n        obj._value_ = value + 1\n        obj.fn = fn\n        obj.correct = correct\n        obj.debug = debug\n        return obj\n\n    @property\n    def is_correct(self):\n        return self.correct\n\n    # The parser chose to 
close a bracket instead of shift something\n    # into the bracket\n    # This causes both a precision and a recall error as there is now\n    # an incorrect bracket and a missing correct bracket\n    # Any bracket creation here would cause more wrong brackets, though\n    SHIFT_CLOSE_ERROR                      = (fix_shift_close,)\n\n    OPEN_CLOSE_ERROR                       = (fix_open_close,)\n\n    # open followed by shift was instead predicted to be shift\n    ONE_OPEN_SHIFT_ERROR                   = (fix_one_open_shift,)\n\n    # multiple opens followed by shift were instead predicted to be shift\n    MULTIPLE_OPEN_SHIFT_ERROR              = (fix_multiple_open_shift,)\n\n    # should have done Open(X), Open(Y)\n    # instead just did Open(Y)\n    NESTED_OPEN_OPEN_ERROR                 = (fix_nested_open_constituent,)\n\n    SHIFT_OPEN_ERROR                       = (fix_shift_open_immediate_close,)\n\n    CLOSE_SHIFT_ERROR                      = (fix_close_shift,)\n\n    CLOSE_SHIFT_WITH_OPENS_ERROR           = (fix_close_shift_with_opens,)\n\n    CLOSE_OPEN_ONE_CON_ERROR               = (fix_close_open_correct_open,)\n\n    CORRECT                                = (None, True)\n\n    UNKNOWN                                = None\n\n    CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR   = (fix_close_open_correct_open_ambiguous_immediate,)\n\n    CLOSE_OPEN_AMBIGUOUS_LATER_ERROR       = (fix_close_open_correct_open_ambiguous_later,)\n\n    CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR  = (fix_close_shift_ambiguous_immediate,)\n\n    CLOSE_SHIFT_AMBIGUOUS_LATER_ERROR      = (fix_close_shift_ambiguous_later,)\n\n    # can potentially fix either close/shift or close/open\n    # as long as the gold transition after the close\n    # was the same as the transition we just predicted\n    CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED = (fix_close_next_correct_predicted,)\n\n    OPEN_OPEN_AMBIGUOUS_UNARY_ERROR        = (fix_open_open_ambiguous_unary,)\n\n    OPEN_OPEN_AMBIGUOUS_LATER_ERROR        = 
(fix_open_open_ambiguous_later,)\n\n    OPEN_OPEN_AMBIGUOUS_RANDOM_ERROR       = (fix_open_open_ambiguous_random,)\n\n    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR       = (fix_shift_open_ambiguous_unary,)\n\n    SHIFT_OPEN_AMBIGUOUS_LATER_ERROR       = (fix_shift_open_ambiguous_later,)\n\n    SHIFT_OPEN_AMBIGUOUS_PREDICTED         = (fix_shift_open_ambiguous_predicted,)\n\n    OTHER_SHIFT_OPEN                       = (report_shift_open, False, True)\n\n    OTHER_CLOSE_SHIFT                      = (report_close_shift, False, True)\n\n    OTHER_CLOSE_OPEN                       = (report_close_open, False, True)\n\n    OTHER_OPEN_OPEN                        = (report_open_open, False, True)\n\nclass TopDownOracle(DynamicOracle):\n    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):\n        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)\n"
  },
  {
    "path": "stanza/models/constituency/trainer.py",
    "content": "\"\"\"\nThis file includes a variety of methods needed to train new\nconstituency parsers.  It also includes a method to load an\nalready-trained parser.\n\nSee the `train` method for the code block which starts from\n  raw treebank and returns a new parser.\n`evaluate` reads a treebank and gives a score for those trees.\n\"\"\"\n\nimport copy\nimport logging\nimport os\n\nimport torch\n\nfrom stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain, NoTransformerFoundationCache\nfrom stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper, pop_peft_args\nfrom stanza.models.constituency.base_trainer import BaseTrainer, ModelType\nfrom stanza.models.constituency.lstm_model import LSTMModel, SentenceBoundary, StackHistory, ConstituencyComposition\nfrom stanza.models.constituency.parse_transitions import Transition, TransitionScheme\nfrom stanza.models.constituency.utils import build_optimizer, build_scheduler\n# TODO: could put find_wordvec_pretrain, choose_charlm, etc in a more central place if it becomes widely used\nfrom stanza.utils.training.common import find_wordvec_pretrain, choose_charlm, find_charlm_file\nfrom stanza.resources.default_packages import default_charlms, default_pretrains\n\nlogger = logging.getLogger('stanza')\ntlogger = logging.getLogger('stanza.constituency.trainer')\n\nclass Trainer(BaseTrainer):\n    \"\"\"\n    Stores a constituency model and its optimizer\n\n    Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)\n    \"\"\"\n    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):\n        super().__init__(model, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)\n\n    def save(self, filename, save_optimizer=True):\n        \"\"\"\n        Save the model (and by default the optimizer) to 
the given path\n        \"\"\"\n        super().save(filename, save_optimizer)\n\n    def get_peft_params(self):\n        # Hide import so that peft dependency is optional\n        if self.model.args.get('use_peft', False):\n            from peft import get_peft_model_state_dict\n            return get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)\n        return None\n\n    @property\n    def model_type(self):\n        return ModelType.LSTM\n\n    @staticmethod\n    def find_and_load_pretrain(saved_args, foundation_cache):\n        if 'wordvec_pretrain_file' not in saved_args:\n            return None\n        if os.path.exists(saved_args['wordvec_pretrain_file']):\n            return load_pretrain(saved_args['wordvec_pretrain_file'], foundation_cache)\n        logger.info(\"Unable to find pretrain in %s  Will try to load from the default resources instead\", saved_args['wordvec_pretrain_file'])\n        language = saved_args['lang']\n        wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains)\n        return load_pretrain(wordvec_pretrain, foundation_cache)\n\n    @staticmethod\n    def find_and_load_charlm(charlm_file, direction, saved_args, foundation_cache):\n        try:\n            return load_charlm(charlm_file, foundation_cache)\n        except FileNotFoundError as e:\n            logger.info(\"Unable to load charlm from %s  Will try to load from the default resources instead\", charlm_file)\n            language = saved_args['lang']\n            dataset = saved_args['shorthand'].split(\"_\")[1]\n            charlm = choose_charlm(language, dataset, \"default\", default_charlms, {})\n            charlm_file = find_charlm_file(direction, language, charlm)\n            return load_charlm(charlm_file, foundation_cache)\n\n    def log_num_words_known(self, words):\n        tlogger.info(\"Number of words in the training set found in the embedding: %d out of %d\", self.model.num_words_known(words), 
len(words))\n\n    @staticmethod\n    def load_optimizer(model, checkpoint, first_optimizer, filename):\n        optimizer = build_optimizer(model.args, model, first_optimizer)\n        if checkpoint.get('optimizer_state_dict', None) is not None:\n            try:\n                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n            except ValueError as e:\n                raise ValueError(\"Failed to load optimizer from %s\" % filename) from e\n        else:\n            logger.info(\"Attempted to load optimizer to resume training, but optimizer not saved.  Creating new optimizer\")\n        return optimizer\n\n    @staticmethod\n    def load_scheduler(model, optimizer, checkpoint, first_optimizer):\n        scheduler = build_scheduler(model.args, optimizer, first_optimizer=first_optimizer)\n        if 'scheduler_state_dict' in checkpoint:\n            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n        return scheduler\n\n    @staticmethod\n    def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):\n        \"\"\"\n        Build a new model just from the saved params and some extra args\n\n        Refactoring allows other processors to include a constituency parser as a module\n        \"\"\"\n        saved_args = dict(params['config'])\n        if isinstance(saved_args['sentence_boundary_vectors'], str):\n            saved_args['sentence_boundary_vectors'] = SentenceBoundary[saved_args['sentence_boundary_vectors']]\n        if isinstance(saved_args['constituency_composition'], str):\n            saved_args['constituency_composition'] = ConstituencyComposition[saved_args['constituency_composition']]\n        if isinstance(saved_args['transition_stack'], str):\n            saved_args['transition_stack'] = StackHistory[saved_args['transition_stack']]\n        if isinstance(saved_args['constituent_stack'], str):\n            saved_args['constituent_stack'] = 
StackHistory[saved_args['constituent_stack']]\n        if isinstance(saved_args['transition_scheme'], str):\n            saved_args['transition_scheme'] = TransitionScheme[saved_args['transition_scheme']]\n\n        # some parameters which change the structure of a model have\n        # to be ignored, or the model will not function when it is\n        # reloaded from disk\n        if args is None: args = {}\n        update_args = copy.deepcopy(args)\n        pop_peft_args(update_args)\n        update_args.pop(\"bert_hidden_layers\", None)\n        update_args.pop(\"bert_model\", None)\n        update_args.pop(\"constituency_composition\", None)\n        update_args.pop(\"constituent_stack\", None)\n        update_args.pop(\"num_tree_lstm_layers\", None)\n        update_args.pop(\"transition_scheme\", None)\n        update_args.pop(\"transition_stack\", None)\n        update_args.pop(\"maxout_k\", None)\n        # if the pretrain or charlms are not specified, don't override the values in the model\n        # (if any), since the model won't even work without loading the same charlm\n        if 'wordvec_pretrain_file' in update_args and update_args['wordvec_pretrain_file'] is None:\n            update_args.pop('wordvec_pretrain_file')\n        if 'charlm_forward_file' in update_args and update_args['charlm_forward_file'] is None:\n            update_args.pop('charlm_forward_file')\n        if 'charlm_backward_file' in update_args and update_args['charlm_backward_file'] is None:\n            update_args.pop('charlm_backward_file')\n        # we don't pop bert_finetune, with the theory being that if\n        # the saved model has bert_finetune==True we can load the bert\n        # weights but then not further finetune if bert_finetune==False\n        saved_args.update(update_args)\n\n        # TODO: not needed if we rebuild the models\n        if saved_args.get(\"bert_finetune\", None) is None:\n            saved_args[\"bert_finetune\"] = False\n        if 
saved_args.get(\"stage1_bert_finetune\", None) is None:\n            saved_args[\"stage1_bert_finetune\"] = False\n\n        model_type = params['model_type']\n        if model_type == 'LSTM':\n            pt = Trainer.find_and_load_pretrain(saved_args, foundation_cache)\n            if saved_args.get('use_peft', False):\n                # if loading a peft model, we first load the base transformer\n                # then we load the weights using the saved weights in the file\n                if peft_name is None:\n                    bert_model, bert_tokenizer, peft_name = load_bert_with_peft(saved_args.get('bert_model', None), \"constituency\", foundation_cache)\n                else:\n                    bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)\n                bert_model = load_peft_wrapper(bert_model, peft_params, saved_args, logger, peft_name)\n                bert_saved = True\n            elif saved_args['bert_finetune'] or saved_args['stage1_bert_finetune'] or any(x.startswith(\"bert_model.\") for x in params['model'].keys()):\n                # if bert_finetune is True, don't use the cached model!\n                # otherwise, other uses of the cached model will be ruined\n                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None))\n                bert_saved = True\n            else:\n                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)\n                bert_saved = False\n            forward_charlm =  Trainer.find_and_load_charlm(saved_args[\"charlm_forward_file\"],  \"forward\",  saved_args, foundation_cache)\n            backward_charlm = Trainer.find_and_load_charlm(saved_args[\"charlm_backward_file\"], \"backward\", saved_args, foundation_cache)\n\n            # TODO: the isinstance will be unnecessary after 1.10.0\n            transitions = params['transitions']\n            if all(isinstance(x, str) for x in 
transitions):\n                transitions = [Transition.from_repr(x) for x in transitions]\n\n            model = LSTMModel(pretrain=pt,\n                              forward_charlm=forward_charlm,\n                              backward_charlm=backward_charlm,\n                              bert_model=bert_model,\n                              bert_tokenizer=bert_tokenizer,\n                              force_bert_saved=bert_saved,\n                              peft_name=peft_name,\n                              transitions=transitions,\n                              constituents=params['constituents'],\n                              tags=params['tags'],\n                              words=params['words'],\n                              rare_words=set(params['rare_words']),\n                              root_labels=params['root_labels'],\n                              constituent_opens=params['constituent_opens'],\n                              unary_limit=params['unary_limit'],\n                              args=saved_args)\n        else:\n            raise ValueError(\"Unknown model type {}\".format(model_type))\n        model.load_state_dict(params['model'], strict=False)\n        # model will stay on CPU if device==None\n        # can be moved elsewhere later, of course\n        model = model.to(args.get('device', None))\n        return model\n\n    @staticmethod\n    def build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file):\n        # TODO: turn finetune, relearn_structure, multistage into an enum?\n        # finetune just means continue learning, so checkpoint is sufficient\n        # relearn_structure is essentially a one stage multistage\n        # multistage with a checkpoint will have the proper optimizer for that epoch\n        # and no special learning mode means we are training a new model and should continue\n        if args['checkpoint'] 
and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']):\n            tlogger.info(\"Found checkpoint to continue training: %s\", args['checkpoint_save_name'])\n            trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache)\n            return trainer\n\n        # in the 'finetune' case, this will preload the models into foundation_cache,\n        # so the effort is not wasted\n        pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])\n        forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])\n        backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])\n\n        if args['finetune']:\n            tlogger.info(\"Loading model to finetune: %s\", model_load_file)\n            trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=NoTransformerFoundationCache(foundation_cache))\n            # a new finetuning will start with a new epochs_trained count\n            trainer.epochs_trained = 0\n            return trainer\n\n        if args['relearn_structure']:\n            tlogger.info(\"Loading model to continue training with new structure from %s\", model_load_file)\n            temp_args = dict(args)\n            # remove the pattn & lattn layers unless the saved model had them\n            temp_args.pop('pattn_num_layers', None)\n            temp_args.pop('lattn_d_proj', None)\n            trainer = Trainer.load(model_load_file, temp_args, load_optimizer=False, foundation_cache=NoTransformerFoundationCache(foundation_cache))\n\n            # using the model's current values works for if the new\n            # dataset is the same or smaller\n            # TODO: handle a larger dataset as well\n            model = LSTMModel(pt,\n                              forward_charlm,\n                              backward_charlm,\n                              trainer.model.bert_model,\n         
                     trainer.model.bert_tokenizer,\n                              trainer.model.force_bert_saved,\n                              trainer.model.peft_name,\n                              trainer.model.transitions,\n                              trainer.model.constituents,\n                              trainer.model.tags,\n                              trainer.model.delta_words,\n                              trainer.model.rare_words,\n                              trainer.model.root_labels,\n                              trainer.model.constituent_opens,\n                              trainer.model.unary_limit(),\n                              args)\n            model = model.to(args['device'])\n            model.copy_with_new_structure(trainer.model)\n            optimizer = build_optimizer(args, model, False)\n            scheduler = build_scheduler(args, optimizer)\n            trainer = Trainer(model, optimizer, scheduler)\n            return trainer\n\n        if args['multistage']:\n            # run adadelta over the model for half the time with no pattn or lattn\n            # training then switches to a different optimizer for the rest\n            # this works surprisingly well\n            tlogger.info(\"Warming up model for %d iterations using AdaDelta to train the embeddings\", args['epochs'] // 2)\n            temp_args = dict(args)\n            # remove the attention layers for the temporary model\n            temp_args['pattn_num_layers'] = 0\n            temp_args['lattn_d_proj'] = 0\n            args = temp_args\n\n        peft_name = None\n        if args['use_peft']:\n            peft_name = \"constituency\"\n            bert_model, bert_tokenizer = load_bert(args['bert_model'])\n            # pass args, not temp_args: temp_args only exists in the multistage\n            # branch, where args has already been rebound to it\n            bert_model = build_peft_wrapper(bert_model, args, tlogger, adapter_name=peft_name)\n        elif args['bert_finetune'] or args['stage1_bert_finetune']:\n            bert_model, bert_tokenizer = load_bert(args['bert_model'])\n        
else:\n            bert_model, bert_tokenizer = load_bert(args['bert_model'], foundation_cache)\n        model = LSTMModel(pt,\n                          forward_charlm,\n                          backward_charlm,\n                          bert_model,\n                          bert_tokenizer,\n                          False,\n                          peft_name,\n                          train_transitions,\n                          train_constituents,\n                          tags,\n                          words,\n                          rare_words,\n                          root_labels,\n                          open_nodes,\n                          unary_limit,\n                          args)\n        model = model.to(args['device'])\n\n        optimizer = build_optimizer(args, model, build_simple_adadelta=args['multistage'])\n        scheduler = build_scheduler(args, optimizer, first_optimizer=args['multistage'])\n\n        trainer = Trainer(model, optimizer, scheduler, first_optimizer=args['multistage'])\n        return trainer\n"
  },
  {
    "path": "stanza/models/constituency/transformer_tree_stack.py",
    "content": "\"\"\"\nBased on\n\nTransition-based Parsing with Stack-Transformers\nRamon Fernandez Astudillo, Miguel Ballesteros, Tahira Naseem,\n  Austin Blodget, and Radu Florian\nhttps://aclanthology.org/2020.findings-emnlp.89.pdf\n\"\"\"\n\nfrom collections import namedtuple\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.constituency.positional_encoding import SinusoidalEncoding\nfrom stanza.models.constituency.tree_stack import TreeStack\n\nNode = namedtuple(\"Node\", ['value', 'key_stack', 'value_stack', 'output'])\n\nclass TransformerTreeStack(nn.Module):\n    def __init__(self, input_size, output_size, input_dropout, length_limit=None, use_position=False, num_heads=1):\n        \"\"\"\n        Builds the internal matrices and start parameter\n\n        TODO: currently only one attention head, implement MHA\n        \"\"\"\n        super().__init__()\n\n        self.input_size = input_size\n        self.output_size = output_size\n        self.inv_sqrt_output_size = 1 / output_size ** 0.5\n        self.num_heads = num_heads\n\n        self.w_query = nn.Linear(input_size, output_size)\n        self.w_key   = nn.Linear(input_size, output_size)\n        self.w_value = nn.Linear(input_size, output_size)\n\n        self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))\n        if isinstance(input_dropout, nn.Module):\n            self.input_dropout = input_dropout\n        else:\n            self.input_dropout = nn.Dropout(input_dropout)\n\n        if length_limit is not None and length_limit < 1:\n            raise ValueError(\"length_limit < 1 makes no sense\")\n        self.length_limit = length_limit\n\n        self.use_position = use_position\n        if use_position:\n            self.position_encoding = SinusoidalEncoding(model_dim=self.input_size, max_len=512)\n\n    def attention(self, key, query, value, mask=None):\n        \"\"\"\n        Calculate attention for the given key, 
query value\n\n        Where B is the number of items stacked together, N is the length:\n        The key should be BxNxD\n        The query is BxD\n        The value is BxNxD\n\n        If mask is specified, it should be BxN of True/False values,\n        where True means that location is masked out\n\n        Reshapes and reorders are used to handle num_heads\n\n        Return will be softmax(query x key^T) * value\n        of size BxD\n        \"\"\"\n        B = key.shape[0]\n        N = key.shape[1]\n        D = key.shape[2]\n\n        H = self.num_heads\n\n        # query is now BxDx1\n        query = query.unsqueeze(2)\n        # BxHxD/Hx1\n        query = query.reshape((B, H, -1, 1))\n\n        # BxNxHxD/H\n        key = key.reshape((B, N, H, -1))\n        # BxHxNxD/H\n        key = key.transpose(1, 2)\n\n        # BxNxHxD/H\n        value = value.reshape((B, N, H, -1))\n        # BxHxNxD/H\n        value = value.transpose(1, 2)\n\n        # BxHxNxD/H x BxHxD/Hx1\n        # result shape: BxHxN\n        attn = torch.matmul(key, query).squeeze(3) * self.inv_sqrt_output_size\n        if mask is not None:\n            # mask goes from BxN -> Bx1xN\n            mask = mask.unsqueeze(1)\n            mask = mask.expand(-1, H, -1)\n            attn.masked_fill_(mask, float('-inf'))\n        # attn shape will now be BxHx1xN\n        attn = torch.softmax(attn, dim=2).unsqueeze(2)\n        # BxHx1xN x BxHxNxD/H -> BxHxD/H\n        output = torch.matmul(attn, value).squeeze(2)\n        output = output.reshape(B, -1)\n        return output\n\n    def initial_state(self, initial_value=None):\n        \"\"\"\n        Return an initial state based on a single layer of attention\n\n        Running attention might be overkill, but it is the simplest\n        way to put the Linears and start_embedding in the computation graph\n        \"\"\"\n        start = self.start_embedding\n        if self.use_position:\n            position = self.position_encoding([0]).squeeze(0)\n    
        start = start + position\n\n        # N=1\n        # shape: 1xD\n        key = self.w_key(start).unsqueeze(0)\n\n        # shape: D\n        query = self.w_query(start)\n\n        # shape: 1xD\n        value = self.w_value(start).unsqueeze(0)\n\n        # unsqueeze to make it look like we are part of a batch of size 1\n        output = self.attention(key.unsqueeze(0), query.unsqueeze(0), value.unsqueeze(0)).squeeze(0)\n        return TreeStack(value=Node(initial_value, key, value, output), parent=None, length=1)\n\n    def push_states(self, stacks, values, inputs):\n        \"\"\"\n        Push new inputs to the stacks and rerun attention on them\n\n        Where B is the number of items stacked together, I is input_size\n        stacks: B TreeStacks such as produced by initial_state and/or push_states\n        values: the new items to push on the stacks such as tree nodes or anything\n        inputs: BxI for the new input items\n\n        Runs attention starting from the existing keys & values\n        \"\"\"\n        device = self.w_key.weight.device\n\n        batch_len = len(stacks)   # B\n        positions = [x.value.key_stack.shape[0] for x in stacks]\n        max_len = max(positions)  # N\n\n        if self.use_position:\n            position_encodings = self.position_encoding(positions)\n            inputs = inputs + position_encodings\n\n        inputs = self.input_dropout(inputs)\n        if len(inputs.shape) == 3:\n            if inputs.shape[0] == 1:\n                inputs = inputs.squeeze(0)\n            else:\n                raise ValueError(\"Expected the inputs to be of shape 1xBxI, got {}\".format(inputs.shape))\n\n        new_keys = self.w_key(inputs)\n        key_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)\n        key_stack[:, -1, :] = new_keys\n        for stack_idx, stack in enumerate(stacks):\n            key_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.key_stack\n\n        new_values 
= self.w_value(inputs)\n        value_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)\n        value_stack[:, -1, :] = new_values\n        for stack_idx, stack in enumerate(stacks):\n            value_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.value_stack\n\n        query = self.w_query(inputs)\n\n        mask = torch.zeros(batch_len, max_len+1, device=device, dtype=torch.bool)\n        for stack_idx, stack in enumerate(stacks):\n            if len(stack) < max_len:\n                masked = max_len - positions[stack_idx]\n                mask[stack_idx, :masked] = True\n\n        batched_output = self.attention(key_stack, query, value_stack, mask)\n\n        new_stacks = []\n        for stack_idx, (stack, node_value, new_key, new_value, output) in enumerate(zip(stacks, values, key_stack, value_stack, batched_output)):\n            # max_len-len(stack) so that we ignore the padding at the start of shorter stacks\n            new_key_stack = new_key[max_len-positions[stack_idx]:, :]\n            new_value_stack = new_value[max_len-positions[stack_idx]:, :]\n            if self.length_limit is not None and new_key_stack.shape[0] > self.length_limit + 1:\n                new_key_stack = torch.cat([new_key_stack[:1, :], new_key_stack[2:, :]], axis=0)\n                new_value_stack = torch.cat([new_value_stack[:1, :], new_value_stack[2:, :]], axis=0)\n            new_stacks.append(stack.push(value=Node(node_value, new_key_stack, new_value_stack, output)))\n        return new_stacks\n\n    def output(self, stack):\n        \"\"\"\n        Return the attention output stored in the top node of the stack\n\n        Refactored so that alternate structures have an easy way of getting the output\n        \"\"\"\n        return stack.value.output\n"
  },
  {
    "path": "stanza/models/constituency/transition_sequence.py",
    "content": "\"\"\"\nBuild a transition sequence from parse trees.\n\nSupports multiple transition schemes - TOP_DOWN and variants, IN_ORDER\n\"\"\"\n\nimport logging\n\nfrom stanza.models.common import utils\nfrom stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize\nfrom stanza.models.constituency.tree_reader import read_trees\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nlogger = logging.getLogger('stanza.constituency.trainer')\n\ndef yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):\n    \"\"\"\n    For tree (X A B C D), yield Open(X) A B C D Close\n\n    The details are in how to treat unary transitions\n    Three possibilities handled by this method:\n      TOP_DOWN_UNARY:    (Y (X ...)) -> Open(X) ... Close Unary(Y)\n      TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close\n      TOP_DOWN:          (Y (X ...)) -> Open(Y) Open(X) ... Close Close\n    \"\"\"\n    if tree.is_preterminal():\n        yield Shift()\n        return\n\n    if tree.is_leaf():\n        return\n\n    if transition_scheme is TransitionScheme.TOP_DOWN_UNARY:\n        if len(tree.children) == 1:\n            labels = []\n            while not tree.is_preterminal() and len(tree.children) == 1:\n                labels.append(tree.label)\n                tree = tree.children[0]\n            for transition in yield_top_down_sequence(tree, transition_scheme):\n                yield transition\n            yield CompoundUnary(*labels)\n            return\n\n    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:\n        labels = [tree.label]\n        while len(tree.children) == 1 and not tree.children[0].is_preterminal():\n            tree = tree.children[0]\n            labels.append(tree.label)\n        yield OpenConstituent(*labels)\n    else:\n        yield OpenConstituent(tree.label)\n    for child in tree.children:\n        for transition 
in yield_top_down_sequence(child, transition_scheme):\n            yield transition\n    yield CloseConstituent()\n\ndef yield_in_order_sequence(tree):\n    \"\"\"\n    For tree (X A B C D), yield A Open(X) B C D Close\n    \"\"\"\n    if tree.is_preterminal():\n        yield Shift()\n        return\n\n    if tree.is_leaf():\n        return\n\n    for transition in yield_in_order_sequence(tree.children[0]):\n        yield transition\n\n    yield OpenConstituent(tree.label)\n\n    for child in tree.children[1:]:\n        for transition in yield_in_order_sequence(child):\n            yield transition\n\n    yield CloseConstituent()\n\n\n\ndef yield_in_order_compound_sequence(tree, transition_scheme):\n    def helper(tree):\n        if tree.is_leaf():\n            return\n\n        labels = []\n        while len(tree.children) == 1 and not tree.is_preterminal():\n            labels.append(tree.label)\n            tree = tree.children[0]\n\n        if tree.is_preterminal():\n            yield Shift()\n            if len(labels) > 0:\n                yield CompoundUnary(*labels)\n            return\n\n        for transition in helper(tree.children[0]):\n            yield transition\n\n        if transition_scheme is TransitionScheme.IN_ORDER_UNARY:\n            yield OpenConstituent(tree.label)\n        else:\n            labels.append(tree.label)\n            yield OpenConstituent(*labels)\n\n        for child in tree.children[1:]:\n            for transition in helper(child):\n                yield transition\n\n        yield CloseConstituent()\n\n        if transition_scheme is TransitionScheme.IN_ORDER_UNARY and len(labels) > 0:\n            yield CompoundUnary(*labels)\n\n    if len(tree.children) == 0:\n        raise ValueError(\"Cannot build {} on an empty tree\".format(transition_scheme))\n    if len(tree.children) != 1:\n        raise ValueError(\"Cannot build {} with a tree that has two top level nodes: {}\".format(transition_scheme, tree))\n\n    for t in 
helper(tree.children[0]):\n        yield t\n\n    yield Finalize(tree.label)\n\ndef build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):\n    \"\"\"\n    Turn a single tree into a list of transitions based on the TransitionScheme\n    \"\"\"\n    if transition_scheme is TransitionScheme.IN_ORDER:\n        return list(yield_in_order_sequence(tree))\n    elif (transition_scheme is TransitionScheme.IN_ORDER_COMPOUND or\n          transition_scheme is TransitionScheme.IN_ORDER_UNARY):\n        return list(yield_in_order_compound_sequence(tree, transition_scheme))\n    else:\n        return list(yield_top_down_sequence(tree, transition_scheme))\n\ndef build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, reverse=False):\n    \"\"\"\n    Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme\n    \"\"\"\n    if reverse:\n        return [build_sequence(tree.reverse(), transition_scheme) for tree in trees]\n    else:\n        return [build_sequence(tree, transition_scheme) for tree in trees]\n\ndef all_transitions(transition_lists):\n    \"\"\"\n    Given a list of transition lists, combine them all into a sorted list of unique transitions.\n    \"\"\"\n    transitions = set()\n    for trans_list in transition_lists:\n        transitions.update(trans_list)\n    return sorted(transitions)\n\ndef convert_trees_to_sequences(trees, treebank_name, transition_scheme, reverse=False):\n    \"\"\"\n    Wrap both build_treebank and all_transitions, possibly with a tqdm\n\n    Converts trees to a list of transition sequences, then returns both\n    the sequences and the sorted list of unique transitions found in them\n    \"\"\"\n    if len(trees) == 0:\n        return [], []\n    logger.info(\"Building %s transition sequences\", treebank_name)\n    if logger.getEffectiveLevel() <= logging.INFO:\n        trees = tqdm(trees)\n    sequences = build_treebank(trees, transition_scheme, reverse)\n    transitions = all_transitions(sequences)\n    return sequences, 
transitions\n\ndef main():\n    \"\"\"\n    Convert a sample tree and print its transitions\n    \"\"\"\n    text=\"( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n    #text = \"(WP Who)\"\n\n    tree = read_trees(text)[0]\n\n    print(tree)\n    transitions = build_sequence(tree)\n    print(transitions)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/constituency/tree_embedding.py",
    "content": "\"\"\"\nA module to use a Constituency Parser to make an embedding for a tree\n\nThe embedding can be produced just from the words and the top of the\ntree, or it can be done with a form of attention over the nodes\n\nCan be done over an existing parse tree or unparsed text\n\"\"\"\n\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.constituency.trainer import Trainer\n\nclass TreeEmbedding(nn.Module):\n    def __init__(self, constituency_parser, args):\n        super(TreeEmbedding, self).__init__()\n\n        self.config = {\n            \"all_words\":   args[\"all_words\"],\n            \"backprop\":    args[\"backprop\"],\n            #\"batch_norm\":  args[\"batch_norm\"],\n            \"node_attn\":   args[\"node_attn\"],\n            \"top_layer\":   args[\"top_layer\"],\n        }\n\n        self.constituency_parser = constituency_parser\n\n        # word_lstm:         hidden_size * num_tree_lstm_layers * 2 (start & end)\n        # transition_stack:  transition_hidden_size\n        # constituent_stack: hidden_size\n        self.hidden_size = self.constituency_parser.hidden_size + self.constituency_parser.transition_hidden_size\n        if self.config[\"all_words\"]:\n            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers\n        else:\n            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers * 2\n\n        if self.config[\"node_attn\"]:\n            self.query = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)\n            self.key = nn.Linear(self.hidden_size, self.constituency_parser.hidden_size)\n            self.value = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)\n\n            # TODO: cat transition and constituent hx as well?\n            self.output_size = self.constituency_parser.hidden_size * 
self.constituency_parser.num_tree_lstm_layers\n        else:\n            self.output_size = self.hidden_size\n\n        # TODO: maybe have batch_norm, maybe use Identity\n        #if self.config[\"batch_norm\"]:\n        #    self.input_norm = nn.BatchNorm1d(self.output_size)\n\n    def embed_trees(self, inputs):\n        if self.config[\"backprop\"]:\n            states = self.constituency_parser.analyze_trees(inputs)\n        else:\n            with torch.no_grad():\n                states = self.constituency_parser.analyze_trees(inputs)\n\n        constituent_lists = [x.constituents for x in states]\n        states = [x.state for x in states]\n\n        word_begin_hx = torch.stack([state.word_queue[0].hx for state in states])\n        word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states])\n        transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states])\n        # go down one layer to get the embedding off the top of the S, not the ROOT\n        # (in terms of the typical treebank)\n        # the idea being that the ROOT has no additional information\n        # and may even have 0s for the embedding in certain circumstances,\n        # such as after learning UNTIED_MAX long enough\n        if self.config[\"top_layer\"]:\n            constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states])\n        else:\n            constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0)\n\n        if self.config[\"all_words\"]:\n            # need B matrices of N x hidden_size\n            key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0)\n                   for state, thx, chx in zip(states, transition_hx, constituent_hx)]\n        else:\n            key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), 
dim=1).unsqueeze(1)\n\n        if not self.config[\"node_attn\"]:\n            return key\n        key = [self.key(x) for x in key]\n\n        node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists]\n        queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]\n        values = [self.value(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]\n        # TODO: could pad to make faster here\n        attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)]\n        attn = [torch.softmax(x, dim=0) for x in attn]\n        previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)]\n        return previous_layer\n\n    def forward(self, inputs):\n        return self.embed_trees(inputs)\n\n    def get_norms(self):\n        lines = [\"constituency_parser.\" + x for x in self.constituency_parser.get_norms()]\n        for name, param in self.named_parameters():\n            if param.requires_grad and not name.startswith('constituency_parser.'):\n                lines.append(\"%s %.6g\" % (name, torch.norm(param).item()))\n        return lines\n\n\n    def get_params(self, skip_modules=True):\n        model_state = self.state_dict()\n        # skip all of the constituency parameters here -\n        # we will add them by calling the model's get_params()\n        skipped = [k for k in model_state.keys() if k.startswith(\"constituency_parser.\")]\n        for k in skipped:\n            del model_state[k]\n\n        parser = self.constituency_parser.get_params(skip_modules)\n\n        params = {\n            'model':         model_state,\n            'constituency':  parser,\n            'config':        self.config,\n        }\n        return params\n\n    @staticmethod\n    def from_parser_file(args, foundation_cache=None):\n        constituency_parser = Trainer.load(args['model'], args, foundation_cache)\n        return 
TreeEmbedding(constituency_parser.model, args)\n\n    @staticmethod\n    def model_from_params(params, args, foundation_cache=None):\n        # TODO: integrate with peft\n        constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache)\n        model = TreeEmbedding(constituency_parser, params['config'])\n        model.load_state_dict(params['model'], strict=False)\n        return model\n"
  },
  {
    "path": "stanza/models/constituency/tree_reader.py",
    "content": "\"\"\"\nReads ParseTree objects from a file, string, or similar input\n\nWorks by first splitting the input into (, ), and all other tokens,\nthen recursively processing those tokens into trees.\n\"\"\"\n\nfrom collections import deque\nimport logging\nimport os\nimport re\n\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nOPEN_PAREN = \"(\"\nCLOSE_PAREN = \")\"\n\nlogger = logging.getLogger('stanza.constituency')\n\n# A few specific exception types to clarify parsing errors\n# They store the line number where the error occurred\n\nclass UnclosedTreeError(ValueError):\n    \"\"\"\n    A tree looked like (Foo\n    \"\"\"\n    def __init__(self, line_num):\n        super().__init__(\"Found an unfinished tree (missing close brackets).  Tree started on line %d\" % line_num)\n        self.line_num = line_num\n\nclass ExtraCloseTreeError(ValueError):\n    \"\"\"\n    A tree looked like (Foo))\n    \"\"\"\n    def __init__(self, line_num):\n        super().__init__(\"Found a broken tree (extra close brackets).  Tree started on line %d\" % line_num)\n        self.line_num = line_num\n\nclass UnlabeledTreeError(ValueError):\n    \"\"\"\n    A tree had no label, such as ((Foo) (Bar))\n\n    This does not actually happen at the root, btw, as ROOT is silently added\n    \"\"\"\n    def __init__(self, line_num):\n        super().__init__(\"Found a tree with no label on a node!  Line number %d\" % line_num)\n        self.line_num = line_num\n\nclass MixedTreeError(ValueError):\n    \"\"\"\n    Leaf and constituent children are mixed in the same node\n    \"\"\"\n    def __init__(self, line_num, child_label, children):\n        super().__init__(\"Found a tree with both text children and bracketed children!  
Line number {}  Child label {}  Children {}\".format(line_num, child_label, children))\n        self.line_num = line_num\n        self.child_label = child_label\n        self.children = children\n\ndef normalize(text):\n    return text.replace(\"-LRB-\", \"(\").replace(\"-RRB-\", \")\")\n\ndef read_single_tree(token_iterator, broken_ok):\n    \"\"\"\n    Build a tree from the tokens in the token_iterator\n    \"\"\"\n    # we were called here at a open paren, so start the stack of\n    # children with one empty list already on it\n    children_stack = deque()\n    children_stack.append([])\n    text_stack = deque()\n    text_stack.append([])\n\n    token = next(token_iterator, None)\n    token_iterator.set_mark()\n    while token is not None:\n        if token == OPEN_PAREN:\n            children_stack.append([])\n            text_stack.append([])\n        elif token == CLOSE_PAREN:\n            text = text_stack.pop()\n            children = children_stack.pop()\n            if text:\n                pieces = \" \".join(text).split()\n                if len(pieces) == 1:\n                    child = Tree(pieces[0], children)\n                else:\n                    # the assumption here is that a language such as VI may\n                    # have spaces in the words, but it still represents\n                    # just one child\n                    label = pieces[0]\n                    child_label = \" \".join(pieces[1:])\n                    if children:\n                        if broken_ok:\n                            child = Tree(label, children + [Tree(normalize(child_label))])\n                        else:\n                            raise MixedTreeError(token_iterator.line_num, child_label, children)\n                    else:\n                        child = Tree(label, Tree(normalize(child_label)))\n                if not children_stack:\n                    return child\n            else:\n                if not children_stack:\n                  
  return Tree(\"ROOT\", children)\n                elif broken_ok:\n                    child = Tree(None, children)\n                else:\n                    raise UnlabeledTreeError(token_iterator.line_num)\n            children_stack[-1].append(child)\n        else:\n            text_stack[-1].append(token)\n        token = next(token_iterator, None)\n    raise UnclosedTreeError(token_iterator.get_mark())\n\nLINE_SPLIT_RE = re.compile(r\"([()])\")\n\n\nclass TokenIterator:\n    \"\"\"\n    A specific iterator for reading trees from a tree file\n\n    The idea is that this will keep track of which line\n    we are processing, so that an error can be logged\n    from the correct line\n    \"\"\"\n    def __init__(self):\n        self.token_iterator = iter([])\n        self.line_num = -1\n        self.mark = None\n\n    def set_mark(self):\n        \"\"\"\n        The mark is used for determining where the start of a tree occurs for an error\n        \"\"\"\n        self.mark = self.line_num\n\n    def get_mark(self):\n        if self.mark is None:\n            raise ValueError(\"No mark set!\")\n        return self.mark\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        n = next(self.token_iterator, None)\n        while n is None:\n            self.line_num = self.line_num + 1\n            line = next(self.line_iterator)\n            if line is None:\n                raise StopIteration\n            line = line.strip()\n            if not line:\n                continue\n\n            pieces = LINE_SPLIT_RE.split(line)\n            pieces = [x.strip() for x in pieces]\n            pieces = [x for x in pieces if x]\n            self.token_iterator = iter(pieces)\n            n = next(self.token_iterator, None)\n\n        return n\n\n\nclass TextTokenIterator(TokenIterator):\n    def __init__(self, text, use_tqdm=True):\n        super().__init__()\n\n        self.lines = text.split(\"\\n\")\n        self.num_lines = 
len(self.lines)\n        if self.num_lines > 1000 and use_tqdm:\n            self.line_iterator = iter(tqdm(self.lines))\n        else:\n            self.line_iterator = iter(self.lines)\n\n\nclass FileTokenIterator(TokenIterator):\n    def __init__(self, filename):\n        super().__init__()\n        self.filename = filename\n\n    def __enter__(self):\n        # TODO: use the file_size instead of counting the lines\n        # file_size = Path(self.filename).stat().st_size\n        with open(self.filename) as fin:\n            num_lines = sum(1 for _ in fin)\n\n        self.file_obj = open(self.filename)\n        if num_lines > 1000:\n            self.line_iterator = iter(tqdm(self.file_obj, total=num_lines))\n        else:\n            self.line_iterator = iter(self.file_obj)\n        return self\n\n    def __exit__(self, exc_type, exc_value, exc_tb):\n        if self.file_obj:\n            self.file_obj.close()\n\ndef read_token_iterator(token_iterator, broken_ok, tree_callback):\n    trees = []\n    token = next(token_iterator, None)\n    while token:\n        if token == OPEN_PAREN:\n            next_tree = read_single_tree(token_iterator, broken_ok=broken_ok)\n            if next_tree is None:\n                raise ValueError(\"Tree reader somehow created a None tree!  Line number %d\" % token_iterator.line_num)\n            if tree_callback is not None:\n                transformed = tree_callback(next_tree)\n                if transformed is not None:\n                    trees.append(transformed)\n            else:\n                trees.append(next_tree)\n            token = next(token_iterator, None)\n        elif token == CLOSE_PAREN:\n            raise ExtraCloseTreeError(token_iterator.line_num)\n        else:\n            raise ValueError(\"Tree document had text between trees!  
Line number %d\" % token_iterator.line_num)\n\n    return trees\n\n\ndef read_trees(text, broken_ok=False, tree_callback=None, use_tqdm=True):\n    \"\"\"\n    Reads multiple trees from the text\n\n    TODO: some of the error cases we hit can be recovered from\n    \"\"\"\n    token_iterator = TextTokenIterator(text, use_tqdm)\n    return read_token_iterator(token_iterator, broken_ok=broken_ok, tree_callback=tree_callback)\n\ndef read_tree_file(filename, broken_ok=False, tree_callback=None):\n    \"\"\"\n    Read all of the trees in the given file\n    \"\"\"\n    with FileTokenIterator(filename) as token_iterator:\n        trees = read_token_iterator(token_iterator, broken_ok=broken_ok, tree_callback=tree_callback)\n    return trees\n\ndef read_directory(dirname, broken_ok=False, tree_callback=None):\n    \"\"\"\n    Read all of the trees in all of the files in a directory\n    \"\"\"\n    trees = []\n    for filename in sorted(os.listdir(dirname)):\n        full_name = os.path.join(dirname, filename)\n        trees.extend(read_tree_file(full_name, broken_ok, tree_callback))\n    return trees\n\ndef read_treebank(filename, tree_callback=None):\n    \"\"\"\n    Read a treebank and alter the trees to be a simpler format for learning to parse\n    \"\"\"\n    logger.info(\"Reading trees from %s\", filename)\n    trees = read_tree_file(filename, tree_callback=tree_callback)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n\n    illegal_trees = [t for t in trees if len(t.children) > 1]\n    if len(illegal_trees) > 0:\n        raise ValueError(\"Found {} tree(s) which had non-unary transitions at the ROOT.  First illegal tree: {:P}\".format(len(illegal_trees), illegal_trees[0]))\n\n    return trees\n\ndef main():\n    \"\"\"\n    Reads a sample tree\n    \"\"\"\n    text=\"( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n    trees = read_trees(text)\n    print(trees)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/constituency/tree_stack.py",
    "content": "\"\"\"\nA utilitiy class for keeping track of intermediate parse states\n\"\"\"\n\nfrom collections import namedtuple\n\nclass TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):\n    \"\"\"\n    A stack which can branch in several directions, as long as you\n    keep track of the branching heads\n\n    An example usage is when K constituents are removed at once\n    to create a new constituent, and then the LSTM which tracks the\n    values of the constituents is updated starting from the Kth\n    output of the LSTM with the new value.\n\n    We don't simply keep track of a single stack object using a deque\n    because versions of the parser which use a beam will want to be\n    able to branch in different directions from the same base stack\n\n    Another possible usage is if an oracle is used for training\n    in a manner where some fraction of steps are non-gold steps,\n    but we also want to take a gold step from the same state.\n    Eg, parser gets to state X, wants to make incorrect transition T\n    instead of gold transition G, and so we continue training both\n    X+G and X+T.  
If we only represent the state X with standard\n    python stacks, it would not be possible to track both of these\n    states at the same time without copying the entire thing.\n\n    Value can be a transition, a word, or a partially built constituent\n\n    Implemented as a namedtuple to make it a bit more efficient\n    \"\"\"\n    def pop(self):\n        return self.parent\n\n    def push(self, value):\n        # returns a new stack node which points to this\n        return TreeStack(value, self, self.length+1)\n\n    def __iter__(self):\n        stack = self\n        while stack.parent is not None:\n            yield stack.value\n            stack = stack.parent\n        yield stack.value\n\n    def __reversed__(self):\n        items = list(iter(self))\n        for item in reversed(items):\n            yield item\n\n    def __str__(self):\n        return \"TreeStack(%s)\" % \", \".join([str(x) for x in self])\n\n    def __len__(self):\n        return self.length\n"
  },
  {
    "path": "stanza/models/constituency/utils.py",
    "content": "\"\"\"\nCollects a few of the conparser utility methods which don't belong elsewhere\n\"\"\"\n\nfrom collections import Counter\nimport logging\nimport warnings\n\nimport torch.nn as nn\nfrom torch import optim\n\nfrom stanza.models.common.doc import TEXT, Document\nfrom stanza.models.common.utils import get_optimizer\nfrom stanza.models.constituency.base_model import SimpleModel\nfrom stanza.models.constituency.parse_transitions import TransitionScheme\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nDEFAULT_LEARNING_RATES = { \"adamw\": 0.0002, \"adadelta\": 1.0, \"sgd\": 0.001, \"adabelief\": 0.00005, \"madgrad\": 0.0000007 , \"mirror_madgrad\": 0.00005 }\nDEFAULT_LEARNING_EPS = { \"adabelief\": 1e-12, \"adadelta\": 1e-6, \"adamw\": 1e-8 }\nDEFAULT_LEARNING_RHO = 0.9\nDEFAULT_MOMENTUM = { \"madgrad\": 0.9, \"mirror_madgrad\": 0.9, \"sgd\": 0.9 }\n\ntlogger = logging.getLogger('stanza.constituency.trainer')\n\n# madgrad experiment for weight decay\n# with learning_rate set to 0.0000007 and momentum 0.9\n# on en_wsj, with a baseline model trained on adadela for 200,\n# then madgrad used to further improve that model\n#  0.00000002.out: 0.9590347746438835\n#  0.00000005.out: 0.9591378819960182\n#  0.0000001.out: 0.9595450596319405\n#  0.0000002.out: 0.9594603134479271\n#  0.0000005.out: 0.9591317672706594\n#  0.000001.out: 0.9592548741021389\n#  0.000002.out: 0.9598395477013945\n#  0.000003.out: 0.9594974271553495\n#  0.000004.out: 0.9596665982603754\n#  0.000005.out: 0.9591620720706487\nDEFAULT_WEIGHT_DECAY = { \"adamw\": 0.05, \"adadelta\": 0.02, \"sgd\": 0.01, \"adabelief\": 1.2e-6, \"madgrad\": 2e-6, \"mirror_madgrad\": 2e-6 }\n\ndef retag_tags(doc, pipelines, xpos):\n    \"\"\"\n    Returns a list of list of tags for the items in doc\n\n    doc can be anything which feeds into the pipeline(s)\n    pipelines are a list of 1 or more retag pipelines\n    if multiple 
pipelines are given, majority vote wins\n    \"\"\"\n    tag_lists = []\n    for pipeline in pipelines:\n        doc = pipeline(doc)\n        tag_lists.append([[x.xpos if xpos else x.upos for x in sentence.words] for sentence in doc.sentences])\n    # tag_lists: for N pipeline, S sentences\n    # we now have N lists of S sentences each\n    # for sentence in zip(*tag_lists): N lists of |s| tags for this given sentence s\n    # for tag in zip(*sentence): N predicted tags.\n    # most common one in the Counter will be chosen\n    tag_lists = [[Counter(tag).most_common(1)[0][0] for tag in zip(*sentence)]\n                 for sentence in zip(*tag_lists)]\n    return tag_lists\n\ndef retag_trees(trees, pipelines, xpos=True):\n    \"\"\"\n    Retag all of the trees using the given processor\n\n    Returns a list of new trees\n    \"\"\"\n    if len(trees) == 0:\n        return trees\n\n    new_trees = []\n    chunk_size = 1000\n    with tqdm(total=len(trees)) as pbar:\n        for chunk_start in range(0, len(trees), chunk_size):\n            chunk_end = min(chunk_start + chunk_size, len(trees))\n            chunk = trees[chunk_start:chunk_end]\n            sentences = []\n            try:\n                for idx, tree in enumerate(chunk):\n                    tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]\n                    sentences.append(tokens)\n            except ValueError as e:\n                raise ValueError(\"Unable to process tree %d\" % (idx + chunk_start)) from e\n\n            doc = Document(sentences)\n            tag_lists = retag_tags(doc, pipelines, xpos)\n\n            for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):\n                try:\n                    if any(tag is None for tag in tags):\n                        raise RuntimeError(\"Tagged tree #{} with a None tag!\\n{}\\n{}\".format(tree_idx, tree, tags))\n                    new_tree = tree.replace_tags(tags)\n                    
new_trees.append(new_tree)\n                    pbar.update(1)\n                except ValueError as e:\n                    raise ValueError(\"Failed to properly retag tree #{}: {}\".format(tree_idx, tree)) from e\n    if len(new_trees) != len(trees):\n        raise AssertionError(\"Retagged tree counts did not match: {} vs {}\".format(len(new_trees), len(trees)))\n    return new_trees\n\n\ndef build_optimizer(args, model, build_simple_adadelta=False):\n    \"\"\"\n    Build an optimizer based on the arguments given\n\n    If we are \"multistage\" training and epochs_trained < epochs // 2,\n    we build an AdaDelta optimizer instead of whatever was requested\n    The build_simple_adadelta parameter controls this\n    \"\"\"\n    bert_learning_rate = 0.0\n    bert_weight_decay = args['bert_weight_decay']\n    if build_simple_adadelta:\n        optim_type = 'adadelta'\n        bert_finetune = args.get('stage1_bert_finetune', False)\n        if bert_finetune:\n            bert_learning_rate = args['stage1_bert_learning_rate']\n        learning_beta2 = 0.999   # doesn't matter for AdaDelta\n        learning_eps = DEFAULT_LEARNING_EPS['adadelta']\n        learning_rate = args['stage1_learning_rate']\n        learning_rho = DEFAULT_LEARNING_RHO\n        momentum = None    # also doesn't matter for AdaDelta\n        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']\n    else:\n        optim_type = args['optim'].lower()\n        bert_finetune = args.get('bert_finetune', False)\n        if bert_finetune:\n            bert_learning_rate = args['bert_learning_rate']\n        learning_beta2 = args['learning_beta2']\n        learning_eps = args['learning_eps']\n        learning_rate = args['learning_rate']\n        learning_rho = args['learning_rho']\n        momentum = args['learning_momentum']\n        weight_decay = args['learning_weight_decay']\n\n    # TODO: allow rho as an arg for AdaDelta\n    return get_optimizer(name=optim_type,\n                         model=model,\n  
                       lr=learning_rate,\n                         betas=(0.9, learning_beta2),\n                         eps=learning_eps,\n                         momentum=momentum,\n                         weight_decay=weight_decay,\n                         bert_learning_rate=bert_learning_rate,\n                         bert_weight_decay=weight_decay*bert_weight_decay,\n                         is_peft=args.get('use_peft', False),\n                         bert_finetune_layers=args['bert_finetune_layers'],\n                         opt_logger=tlogger)\n\ndef build_scheduler(args, optimizer, first_optimizer=False):\n    \"\"\"\n    Build the scheduler for the conparser based on its args\n\n    Used to use a warmup for learning rate, but that wasn't working very well\n    Now, we just use a ReduceLROnPlateau, which does quite well\n    \"\"\"\n    #if args.get('learning_rate_warmup', 0) <= 0:\n    #    # TODO: is there an easier way to make an empty scheduler?\n    #    lr_lambda = lambda x: 1.0\n    #else:\n    #    warmup_end = args['learning_rate_warmup']\n    #    def lr_lambda(x):\n    #        if x >= warmup_end:\n    #            return 1.0\n    #        return x / warmup_end\n\n    #scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n\n    if first_optimizer:\n        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args['learning_rate_factor'], patience=args['learning_rate_patience'], cooldown=args['learning_rate_cooldown'], min_lr=args['stage1_learning_rate_min_lr'])\n    else:\n        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args['learning_rate_factor'], patience=args['learning_rate_patience'], cooldown=args['learning_rate_cooldown'], min_lr=args['learning_rate_min_lr'])\n    return scheduler\n\ndef initialize_linear(linear, nonlinearity, bias):\n    \"\"\"\n    Initializes the bias to a positive value, hopefully preventing dead neurons\n    \"\"\"\n    if nonlinearity 
in ('relu', 'leaky_relu'):\n        nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)\n        nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5)\n\ndef add_predict_output_args(parser):\n    \"\"\"\n    Args specifically for the output location of data\n    \"\"\"\n    parser.add_argument('--predict_dir', type=str, default=\".\", help='Where to write the predictions during --mode predict.  Pred and orig files will be written - the orig file will be retagged if that is requested.  Writing the orig file is useful for removing None and retagging')\n    parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions')\n    parser.add_argument('--predict_format', type=str, default=\"{:_O}\", help='Format to use when writing predictions')\n\n    parser.add_argument('--predict_output_gold_tags', default=False, action='store_true', help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB')\n\ndef postprocess_predict_output_args(args):\n    if len(args['predict_format']) <= 2 or (len(args['predict_format']) <= 3 and args['predict_format'].endswith(\"Vi\")):\n        args['predict_format'] = \"{:\" + args['predict_format'] + \"}\"\n\n\ndef get_open_nodes(trees, transition_scheme):\n    \"\"\"\n    Return a list of all open nodes in the given dataset.\n    Depending on the parameters, may be single or compound open transitions.\n    \"\"\"\n    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:\n        return Tree.get_compound_constituents(trees)\n    elif transition_scheme is TransitionScheme.IN_ORDER_COMPOUND:\n        return Tree.get_compound_constituents(trees, separate_root=True)\n    else:\n        return [(x,) for x in Tree.get_unique_constituent_labels(trees)]\n\n\ndef verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels):\n    \"\"\"\n    Given a list of trees and their transition sequences, verify that the sequences 
rebuild the trees\n    \"\"\"\n    model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels)\n    tlogger.info(\"Verifying the transition sequences for %d trees\", len(trees))\n\n    data = zip(trees, sequences)\n    if tlogger.getEffectiveLevel() <= logging.INFO:\n        data = tqdm(zip(trees, sequences), total=len(trees))\n\n    for tree_idx, (tree, sequence) in enumerate(data):\n        # TODO: make the SimpleModel have a parse operation?\n        state = model.initial_state_from_gold_trees([tree])[0]\n        for idx, trans in enumerate(sequence):\n            if not trans.is_legal(state, model):\n                raise RuntimeError(\"Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\\nOriginal tree: {}\\nTransitions: {}\".format(tree_idx, name, idx, trans, tree, sequence))\n            state = trans.apply(state, model)\n        result = model.get_top_constituent(state.constituents)\n        if reverse:\n            result = result.reverse()\n        if tree != result:\n            raise RuntimeError(\"Tree {} of {} failed: transition sequence did not match for a tree!\\nOriginal tree:{}\\nTransitions: {}\\nResult tree:{}\".format(tree_idx, name, tree, sequence, result))\n\ndef check_constituents(train_constituents, trees, treebank_name, fail=True):\n    \"\"\"\n    Check that all the constituents in the other dataset are known in the train set\n    \"\"\"\n    constituents = Tree.get_unique_constituent_labels(trees)\n    for con in constituents:\n        if con not in train_constituents:\n            first_error = None\n            num_errors = 0\n            for tree_idx, tree in enumerate(trees):\n                constituents = Tree.get_unique_constituent_labels(tree)\n                if con in constituents:\n                    num_errors += 1\n                    if first_error is None:\n                        first_error = tree_idx\n            error = \"Found constituent label {} in the {} set which don't 
exist in the train set.  This constituent label occurred in {} trees, with the first tree index at {} counting from 1\\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\\n{:P}\".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])\n            if fail:\n                raise RuntimeError(error)\n            else:\n                warnings.warn(error)\n\ndef check_root_labels(root_labels, other_trees, treebank_name):\n    \"\"\"\n    Check that all the root states in the other dataset are known in the train set\n    \"\"\"\n    for root_state in Tree.get_root_labels(other_trees):\n        if root_state not in root_labels:\n            raise RuntimeError(\"Found root state {} in the {} set which is not a ROOT state in the train set\".format(root_state, treebank_name))\n\ndef remove_duplicate_trees(trees, treebank_name):\n    \"\"\"\n    Filter duplicates from the given dataset\n    \"\"\"\n    new_trees = []\n    known_trees = set()\n    for tree in trees:\n        tree_str = \"{}\".format(tree)\n        if tree_str in known_trees:\n            continue\n        known_trees.add(tree_str)\n        new_trees.append(tree)\n    if len(new_trees) < len(trees):\n        tlogger.info(\"Filtered %d duplicates from %s dataset\", (len(trees) - len(new_trees)), treebank_name)\n    return new_trees\n\ndef remove_singleton_trees(trees):\n    \"\"\"\n    remove trees which are just a root and a single word\n\n    TODO: remove these trees in the conversion instead of here\n    \"\"\"\n    new_trees = [x for x in trees if\n                 len(x.children) > 1 or\n                 (len(x.children) == 1 and len(x.children[0].children) > 1) or\n                 (len(x.children) == 1 and len(x.children[0].children) == 1 and len(x.children[0].children[0].children) >= 1)]\n    if len(trees) - len(new_trees) > 0:\n        tlogger.info(\"Eliminated %d trees with missing structure\", (len(trees) - 
len(new_trees)))\n    return new_trees\n\n"
  },
  {
    "path": "stanza/models/constituency_parser.py",
    "content": "\"\"\"A command line interface to a shift reduce constituency parser.\n\nThis follows the work of\nRecurrent neural network grammars by Dyer et al\nIn-Order Transition-based Constituent Parsing by Liu & Zhang\n\nThe general outline is:\n\n  Train a model by taking a list of trees, converting them to\n    transition sequences, and learning a model which can predict the\n    next transition given a current state\n  Then, at inference time, repeatedly predict the next transition until parsing is complete\n\nThe \"transitions\" are variations on shift/reduce as per an\nintro-to-compilers class.  The idea is that you can treat all of the\nwords in a sentence as a buffer of tokens, then either \"shift\" them to\nrepresent a new constituent, or \"reduce\" one or more constituents to\nform a new constituent.\n\nIn order to make the runtime a more competitive speed, effort is taken\nto batch the transitions and apply multiple transitions at once.  At\ntrain time, batches are groups together by length, and at inference\ntime, new trees are added to the batch as previous trees on the batch\nfinish their inference.\n\nThere are a few minor differences in the model:\n  - The word input is a bi-lstm, not a uni-lstm.\n    This gave a small increase in accuracy.\n  - The combination of several constituents into one constituent is done\n    via a single bi-lstm rather than two separate lstms.  This increases\n    speed without a noticeable effect on accuracy.\n  - In fact, an even better (in terms of final model accuracy) method\n    is to combine the constituents with torch.max, believe it or not\n    See lstm_model.py for more details\n  - Initializing the embeddings with smaller values than pytorch default\n    For example, on a ja_alt dataset, scores went from 0.8980 to 0.8985\n    at 200 iterations averaged over 5 trials\n  - Training with AdaDelta first, then AdamW or madgrad later improves\n    results quite a bit.  
See --multistage\n\nA couple experiments which have been tried with little noticeable impact:\n  - Combining constituents using the method in the paper (only a trained\n    vector at the start instead of both ends) did not affect results\n    and is a little slower\n  - Using multiple layers of LSTM hidden state for the input to the final\n    classification layers didn't help\n  - Initializing Linear layers with He initialization and a positive bias\n    (to avoid dead connections) had no noticeable effect on accuracy\n    0.8396 on it_turin with the original initialization\n    0.8401 and 0.8427 on two runs with updated initialization\n    (so maybe a small improvement...)\n  - Initializing LSTM layers with different gates was slightly worse:\n    forget gates of 1.0\n    forget gates of 1.0, input gates of -1.0\n  - Replacing the LSTMs that make up the Transition and Constituent\n    LSTMs with Dynamic Skip LSTMs made no difference, but was slower\n  - Highway LSTMs also made no difference\n  - Putting labels on the shift transitions (the word or the tag shifted)\n    or putting labels on the close transitions didn't help\n  - Building larger constituents from the output of the constituent LSTM\n    instead of the children constituents hurts scores\n    For example, an experiment on ja_alt went from 0.8985 to 0.8964\n    when built that way\n  - The initial transition scheme implemented was TOP_DOWN.  We tried\n    a compound unary option, since this worked so well in the CoreNLP\n    constituency parser.  Unfortunately, this is far less effective\n    than IN_ORDER.  Both specialized unary matrices and reusing the\n    n-ary constituency combination fell short.  
On the ja_alt dataset:\n      IN_ORDER, max combination method:           0.8985\n      TOP_DOWN_UNARY, specialized matrices:       0.8501\n      TOP_DOWN_UNARY, max combination method:     0.8508\n  - Adding multiple layers of MLP to combine inputs for words made\n    no difference in the scores\n    Tried both before the LSTM and after\n    A simple single layer tensor multiply after the LSTM works well.\n    Replacing that with a two layer MLP on the English PTB\n    with roberta-base causes a notable drop in scores\n    First experiment didn't use the fancy Linear weight init,\n    but adding that barely made a difference\n      260 training iterations on en_wsj dev, roberta-base\n      model as of bb983fd5e912f6706ad484bf819486971742c3d1\n      two layer MLP:                    0.9409\n      two layer MLP, init weights:      0.9413\n      single layer:                     0.9467\n  - There is code to rebuild models with a new structure in lstm_model.py\n    As part of this, we tried to randomly reinitialize the transitions\n    if the transition embedding had gone to 0, which often happens\n    This didn't help at all\n  - We tried something akin to attention with just the query vector\n    over the bert embeddings as a way to mix them, but that did not\n    improve scores.\n    Example, with a self.bert_layer_mix of size bert_dim x 1:\n        mixed_bert_embeddings = []\n        for feature in bert_embeddings:\n            weighted_feature = self.bert_layer_mix(feature.transpose(1, 2))\n            weighted_feature = torch.softmax(weighted_feature, dim=1)\n            weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2)\n            mixed_bert_embeddings.append(weighted_feature)\n        bert_embeddings = mixed_bert_embeddings\n    It seems just finetuning the transformer is already enough\n    (in general, no need to mix layers at all when finetuning bert embeddings)\n\n\nThe code breakdown is as follows:\n\n  this file: main interface for 
training or evaluating models\n  constituency/trainer.py: contains the training & evaluation code\n  constituency/ensemble.py: evaluation code specifically for letting multiple models\n    vote on the correct next transition.  a modest improvement.\n  constituency/evaluate_treebanks.py: specifically to evaluate multiple parsed treebanks\n    against a gold.  in particular, reports whether the theoretical best from those\n    parsed treebanks is an improvement (eg, the k-best score as reported by CoreNLP)\n\n  constituency/parse_tree.py: a data structure for representing a parse tree and utility methods\n  constituency/tree_reader.py: a module which can read trees from a string or input file\n\n  constituency/tree_stack.py: a linked list which can branch in\n    different directions, which will be useful when implementing beam\n    search or a dynamic oracle\n  constituency/lstm_tree_stack.py: an LSTM over the elements of a TreeStack\n  constituency/transformer_tree_stack.py: attempts to run attention over the nodes\n    of a tree_stack.  
not as effective as the lstm_tree_stack in the initial experiments.\n    perhaps it could be refined to work better, though\n\n  constituency/parse_transitions.py: transitions and a State data structure to store them\n  constituency/transition_sequence.py: turns ParseTree objects into\n    the transition sequences needed to make them\n\n  constituency/base_model.py: operates on the transitions to turn them into constituents,\n    eventually forming one final parse tree composed of all of the constituents\n  constituency/lstm_model.py: adds LSTM features to the constituents to predict what the\n    correct transition to make is, allowing for predictions on previously unseen text\n\n  constituency/retagging.py: a couple utility methods specifically for retagging\n  constituency/utils.py: a couple utility methods\n\n  constituency/dynamic_oracle.py: a dynamic oracle which currently\n    only operates for the inorder transition sequence.\n    uses deterministic rules to redo the correct action sequence when\n    the parser makes an error.\n\n  constituency/partitioned_transformer.py: implementation of a transformer for self-attention.\n     presumably this should help, but we have yet to find a model structure where\n     this makes the scores go up.\n  constituency/label_attention.py: an even fancier form of transformer based on labeled attention:\n     https://arxiv.org/abs/1911.03875\n  constituency/positional_encoding.py: so far, just the sinusoidal is here.\n     a trained encoding is in partitioned_transformer.py.\n     this should probably be refactored to common, especially if used elsewhere.\n\n  stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline\n\n  stanza/utils/datasets/constituency: various scripts and tools for processing constituency datasets\n\nSome alternate optimizer methods:\n  adabelief: https://github.com/juntang-zhuang/Adabelief-Optimizer\n  madgrad: 
https://github.com/facebookresearch/madgrad\n\n\"\"\"\n\nimport argparse\nimport logging\nimport os\nimport random\nimport re\n\nimport torch\n\nimport stanza\nfrom stanza.models.common import constant\nfrom stanza.models.common import utils\nfrom stanza.models.common.peft_config import add_peft_args, resolve_peft_args\nfrom stanza.models.common.utils import NONLINEARITY\nfrom stanza.models.constituency import parser_training\nfrom stanza.models.constituency import retagging\nfrom stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory\nfrom stanza.models.constituency.parse_transitions import TransitionScheme\nfrom stanza.models.constituency.text_processing import load_model_parse_text\nfrom stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, add_predict_output_args, postprocess_predict_output_args\nfrom stanza.resources.common import DEFAULT_MODEL_DIR\n\nlogger = logging.getLogger('stanza')\ntlogger = logging.getLogger('stanza.constituency.trainer')\n\ndef build_argparse():\n    \"\"\"\n    Adds the arguments for building the con parser\n\n    For the most part, defaults are set to cross-validated values, at least for WSJ\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency data.')\n\n    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')\n    parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)\n\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward 
charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n\n    # BERT helps a lot and actually doesn't slow things down too much\n    # for VI, for example, use vinai/phobert-base\n    parser.add_argument('--bert_model', type=str, default=None, help=\"Use an external bert model (requires the transformers package)\")\n    parser.add_argument('--no_bert_model', dest='bert_model', action=\"store_const\", const=None, help=\"Don't use bert\")\n    parser.add_argument('--bert_hidden_layers', type=int, default=4, help=\"How many layers of hidden state to use from the transformer\")\n    parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')\n\n    # BERT finetuning (or any transformer finetuning)\n    # also helps quite a lot.\n    # Experimentally, finetuning all of the layers is the most effective\n    # On the id_icon dataset with the indolem transformer\n    # In this experiment, we trained for 150 iterations with AdaDelta,\n    # with the learning rate 0.01,\n    # then trained for another 150 with madgrad and no finetuning\n    #   1 layer        0.880753  (152)\n    #   2 layers       0.880453  (174)\n    #   3 layers       0.881774  (163)\n    #   4 layers       0.886915  (194)\n    #   5 layers       0.892064  (299)\n    #   6 layers       0.891825  (224)\n    #   7 layers       0.894373  (173)\n    #   8 layers       0.894505  (233)\n    #   9 layers       0.896676  (269)\n    #  10 layers       0.897525  (269)\n    #  11 layers       0.897348  (211)\n    #  12 layers       0.898729  (270)\n    #  everything      0.898855  (252)\n    # so the trend is clear that more finetuning is better\n    #\n    # We found that finetuning works very well on the AdaDelta portion\n    # of a multistage training, but less well on a madgrad second\n    # stage.  
The issue was that we literally could not set the\n    # learning rate low enough because madgrad used epsilon in the LR:\n    #  https://github.com/facebookresearch/madgrad/issues/16\n    #\n    # Possible values of the AdaDelta learning rate on the id_icon dataset\n    # In this experiment, we finetuned the entire transformer 150\n    # iterations on AdaDelta, then trained with madgrad for another\n    # 150 with no finetuning\n    #   0.0005:    0.89122   (155)\n    #   0.001:     0.889807  (241)\n    #   0.002:     0.894874  (202)\n    #   0.005:     0.896327  (270)\n    #   0.006:     0.898989  (246)\n    #   0.007:     0.896712  (167)\n    #   0.008:     0.900136  (237)\n    #   0.009:     0.898597  (169)\n    #   0.01:      0.898665  (251)\n    #   0.012:     0.89661   (274)\n    #   0.014:     0.899149  (283)\n    #   0.016:     0.896314  (230)\n    #   0.018:     0.897753  (257)\n    #   0.02:      0.893665  (256)\n    #   0.05:      0.849274  (159)\n    #   0.1:       0.850633  (183)\n    #   0.2:       0.847332  (176)\n    #\n    # The peak is somewhere around 0.008 to 0.014, with the further\n    # observation that at the 150 iteration mark, 0.009 was winning:\n    #   0.007:     0.894589  (33)\n    #   0.008:     0.894777  (53)\n    #   0.009:     0.896466  (56)\n    #   0.01:      0.895557  (71)\n    #   0.012:     0.893479  (45)\n    #   0.014:     0.89468  (116)\n    #   0.016:     0.893053 (128)\n    #   0.018:     0.893086  (48)\n    #\n    # Another option is to train for a few iterations with no\n    # finetuning, then begin finetuning.  
However, that was not\n    # beneficial at all.\n    # Start iteration on id_icon, same setup as above:\n    #   1:         0.898855  (252)\n    #   5:         0.897885  (217)\n    #   10:        0.895367  (215)\n    #   25:        0.896781  (193)\n    #   50:        0.895216  (193)\n    # Using adamw instead of madgrad:\n    #   1:         0.900594  (226)\n    #   5:         0.898153  (267)\n    #   10:        0.898756  (271)\n    #   25:        0.896867  (256)\n    #   50:        0.895025  (220)\n    #\n    #\n    # With the observation that very low learning rate is currently\n    # not working for madgrad, we tried to parameter sweep LR for\n    # AdamW, and got the following, using a first stage LR of 0.009:\n    #  0.0:     0.899706  (290)\n    #  0.00005: 0.899631  (176)\n    #  0.0001:  0.899851  (233)\n    #  0.0002:  0.898601  (207)\n    #  0.0003:  0.899258  (252)\n    #  0.0004:  0.90033  (187)\n    #  0.0005:  0.899091  (183)\n    #  0.001:   0.899791  (268)\n    #  0.002:   0.899453  (196)\n    #  0.003:   0.897029  (173)\n    #  0.004:   0.899566  (290)\n    #  0.005:   0.899285  (289)\n    #  0.01:    0.898938  (233)\n    #  0.02:    0.898983  (248)\n    #  0.03:    0.898571  (247)\n    #  0.04:    0.898466  (180)\n    #  0.05:    0.897448  (214)\n    # It should be noted that in the 0.0001 range, the epoch to epoch\n    # change of the Bert weights was almost negligible.  Weights would\n    # change in the 5th or 6th decimal place, if at all.\n    #\n    # The conclusion of all these experiments is that, if we are using\n    # bert_finetuning, the best approach is probably a stage1 learning\n    # rate of 0.009 or so and a second stage optimizer of adamw with\n    # no LR or a very low LR.  
This behavior is what happens with the\n    # --stage1_bert_finetune flag\n    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')\n    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help=\"Don't finetune the bert (or other transformer)\")\n    parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer')\n    parser.add_argument('--bert_finetune_begin_epoch', default=None, type=int, help='Which epoch to start finetuning the transformer')\n    parser.add_argument('--bert_finetune_end_epoch', default=None, type=int, help='Which epoch to stop finetuning the transformer')\n    parser.add_argument('--bert_learning_rate', default=0.009, type=float, help='Scale the learning rate for transformer finetuning by this much')\n    parser.add_argument('--stage1_bert_learning_rate', default=None, type=float, help=\"Scale the learning rate for transformer finetuning by this much only during an AdaDelta warmup\")\n    parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')\n    parser.add_argument('--stage1_bert_finetune', default=None, action='store_true', help=\"Finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune\")\n    parser.add_argument('--no_stage1_bert_finetune', dest='stage1_bert_finetune', action='store_false', help=\"Don't finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune\")\n\n    add_peft_args(parser)\n\n    parser.add_argument('--tag_embedding_dim', type=int, default=20, help=\"Embedding size for a tag.  
0 turns off the feature\")\n    # Smaller values also seem to work\n    # For example, after 700 iterations:\n    #   32: 0.9174\n    #   50: 0.9183\n    #   72: 0.9176\n    #  100: 0.9185\n    # not a huge difference regardless\n    # (these numbers were without retagging)\n    parser.add_argument('--delta_embedding_dim', type=int, default=100, help=\"Embedding size for a delta embedding\")\n\n    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--no_train_remove_duplicates', default=True, action='store_false', dest=\"train_remove_duplicates\", help=\"Do/don't remove duplicates from the training file.  Could be useful for intentionally reweighting some trees\")\n    parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.')\n    parser.add_argument('--silver_remove_duplicates', default=False, action='store_true', help=\"Do/don't remove duplicates from the silver training file.  Could be useful for intentionally reweighting some trees\")\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')\n    # TODO: possibly refactor --tokenized_file / --tokenized_dir from here & ensemble\n    parser.add_argument('--xml_tree_file', type=str, default=None, help='Input file of VLSP formatted trees for parsing with parse_text.')\n    parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')\n    parser.add_argument('--tokenized_dir', type=str, default=None, help='Input directory of tokenized text for parsing with parse_text.')\n    parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])\n    parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one')\n    add_predict_output_args(parser)\n\n    parser.add_argument('--lang', 
type=str, help='Language')\n    parser.add_argument('--shorthand', type=str, help=\"Treebank shorthand\")\n\n    parser.add_argument('--transition_embedding_dim', type=int, default=20, help=\"Embedding size for a transition\")\n    parser.add_argument('--transition_hidden_size', type=int, default=20, help=\"Embedding size for transition stack\")\n    parser.add_argument('--transition_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],\n                        help='How to track transitions over a parse.  {}'.format(\", \".join(x.name for x in StackHistory)))\n    parser.add_argument('--transition_heads', default=4, type=int, help=\"How many heads to use in MHA *if* the transition_stack is Attention\")\n\n    parser.add_argument('--constituent_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],\n                        help='How to track transitions over a parse.  {}'.format(\", \".join(x.name for x in StackHistory)))\n    parser.add_argument('--constituent_heads', default=8, type=int, help=\"How many heads to use in MHA *if* the transition_stack is Attention\")\n\n    # larger was more effective, up to a point\n    # substantially smaller, such as 128,\n    # is fine if bert & charlm are not available\n    parser.add_argument('--hidden_size', type=int, default=512, help=\"Size of the output layers for constituency stack and word queue\")\n\n    parser.add_argument('--epochs', type=int, default=400)\n    parser.add_argument('--epoch_size', type=int, default=5000, help=\"Runs this many trees in an 'epoch' instead of going through the training dataset exactly once.  Set to 0 to do the whole training set\")\n    parser.add_argument('--silver_epoch_size', type=int, default=None, help=\"Runs this many trees in a silver 'epoch'.  If not set, will match --epoch_size\")\n\n    # AdaDelta warmup for the conparser.  
Motivation: AdaDelta results in\n    # higher scores overall, but learns 0s for the weights of the pattn and\n    # lattn layers.  AdamW learns weights for pattn, and the models are more\n    # accurate than models trained without pattn using AdamW, but the models\n    # are lower scores overall than the AdaDelta models.\n    #\n    # This improves that by first running AdaDelta, then switching.\n    #\n    # Now, if --multistage is set, run AdaDelta for half the epochs with no\n    # pattn or lattn.  Then start the specified optimizer for the rest of\n    # the time with the full model.  If pattn and lattn are both present,\n    # the model is 1/2 no attn, 1/4 pattn, 1/4 pattn and lattn\n    #\n    # Improvement on the WSJ dev set can be seen from 94.8 to 95.3\n    # when 4 layers of pattn are trained this way.\n    # More experiments to follow.\n    parser.add_argument('--multistage', default=True, action='store_true', help='1/2 epochs with adadelta no pattn or lattn, 1/4 with chosen optim and no lattn, 1/4 full model')\n    parser.add_argument('--no_multistage', dest='multistage', action='store_false', help=\"don't do the multistage learning\")\n\n    # 1 seems to be the most effective, but we should cross-validate\n    parser.add_argument('--oracle_initial_epoch', type=int, default=1, help=\"Epoch where we start using the dynamic oracle to let the parser keep going with wrong decisions\")\n    parser.add_argument('--oracle_frequency', type=float, default=0.8, help=\"How often to use the oracle vs how often to force the correct transition\")\n    parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help=\"Occasionally have the model randomly walk through the state space to try to learn how to recover\")\n    parser.add_argument('--oracle_level', type=int, default=None, help='Restrict oracle transitions to this level or lower.  0 means off.  
None means use all oracle transitions.')\n    parser.add_argument('--additional_oracle_levels', type=str, default=None, help='Add some additional experimental oracle transitions.  Basically for A/B testing transitions we expect to be bad.')\n    parser.add_argument('--deactivated_oracle_levels', type=str, default=None, help='Temporarily turn off a default oracle level.  Basically for A/B testing transitions we expect to be bad.')\n\n    # 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ\n    # earlier version of the model (less accurate overall) had the following results with adadelta:\n    #  30: 0.9085\n    #  50: 0.9070\n    #  75: 0.9010\n    # 150: 0.8985\n    # as another data point, running a newer version with better constituency lstm behavior had:\n    #  30: 0.9111\n    #  50: 0.9094\n    # checking smaller batch sizes to see how this works, at 135 epochs, the values are\n    #  10: 0.8919\n    #  20: 0.9072\n    #  30: 0.9121\n    # obviously these experiments aren't the complete story, but it\n    # looks like 30 trees per batch is the best value for WSJ\n    # note that these numbers are for adadelta and might not apply\n    # to other optimizers\n    # eval batch should generally be faster the bigger the batch,\n    # up to a point, as it allows for more batching of the LSTM\n    # operations and the prediction step\n    parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step')\n    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')\n\n    parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_{finetune}_constituency.pt\", help=\"File name to save the model\")\n    parser.add_argument('--save_each_name', type=str, default=None, help=\"Save each 
model in sequence to this pattern.  Mostly for testing\")\n    parser.add_argument('--save_each_start', type=int, default=None, help=\"When to start saving each model\")\n    parser.add_argument('--save_each_frequency', type=int, default=1, help=\"How frequently to save each model\")\n    parser.add_argument('--no_save_each_optimizer', dest='save_each_optimizer', default=True, action='store_false', help=\"Don't save the optimizer when saving 'each' model\")\n\n    parser.add_argument('--seed', type=int, default=1234)\n    parser.add_argument('--no_seed', action='store_const', const=None, dest='seed', help='Remove the random seed, resulting in a randomly chosen random seed')\n\n    parser.add_argument('--no_check_valid_states', default=True, action='store_false', dest='check_valid_states', help=\"Don't check the constituents or transitions in the dev set when starting a new parser.  Warning: the parser will never guess unknown constituents\")\n    parser.add_argument('--no_strict_check_constituents', default=True, action='store_false', dest='strict_check_constituents', help=\"Don't check the constituents between the train & dev set.  
May result in untrainable transitions\")\n    utils.add_device_args(parser)\n\n    # Numbers are on a VLSP dataset, before adding attn or other improvements\n    # baseline is an 80.6 model that occurs when trained using adadelta, lr 1.0\n    #\n    # adabelief 0.1:      fails horribly\n    #           0.02:     converges very low scores\n    #           0.01:     very slow learning\n    #           0.002:    almost decent\n    #           0.001:    close, but about 1 f1 low on IT\n    #           0.0005:   79.71\n    #           0.0002:   80.11\n    #           0.0001:   79.85\n    #           0.00005:  80.40\n    #           0.00002:  80.02\n    #           0.00001:  78.95\n    #\n    # madgrad   0.005:    fails horribly\n    #           0.001:    low scores\n    #           0.0005:   still somewhat low\n    #           0.0002:   close, but about 1 f1 low on IT\n    #           0.0001:   80.04\n    #           0.00005:  79.91\n    #           0.00002:  80.15\n    #           0.00001:  80.44\n    #           0.000005: 80.34\n    #           0.000002: 80.39\n    #\n    # adamw experiment on a TR dataset (not necessarily the best test case)\n    # note that at that time, the expected best for adadelta was 0.816\n    #\n    #           0.00005 - 0.7925\n    #           0.0001  - 0.7889\n    #           0.0002  - 0.8110\n    #           0.00025 - 0.8108\n    #           0.0003  - 0.8050\n    #           0.0005  - 0.8076\n    #           0.001   - 0.8069\n\n    # Numbers on the VLSP Dataset, with --multistage and default learning rates and adabelief optimizer\n    # Gelu: 82.32\n    # Mish: 81.95\n    # ELU: 81.73\n    # Hardshrink: 0.3\n    # Hardsigmoid: 79.03\n    # Hardtanh: 81.44\n    # Hardswish: 81.67\n    # Logsigmoid: 80.91\n    # Prelu: 80.95 (terminated early)\n    # Relu6: 81.91\n    # RReLU: 77.00\n    # Selu: 81.17\n    # Celu: 81.43\n    # Silu: 81.90\n    # Softplus: 80.94\n    # Softshrink: 0.3\n    # Softsign: 81.63\n    # Softshrink: 13.74\n    #\n   
 # Tests with no_charlm, --multistage\n    # Gelu\n    # 0.00002 0.819746\n    # 0.00005 0.818\n    # 0.0001 0.818566\n    # 0.0002 0.819111\n    # 0.001 0.815609\n    #\n    # Mish\n    # 0.00002 0.816898\n    # 0.00005 0.821085\n    # 0.0001 0.817821\n    # 0.0002 0.818806\n    # 0.001 0.816494\n    #\n    # Relu\n    # 0.00002 0.818402\n    # 0.00005 0.819019\n    # 0.0001 0.821625\n    # 0.0002 0.820633\n    # 0.001 0.814315\n    #\n    # Relu6\n    # 0.00002 0.819719\n    # 0.00005 0.819871\n    # 0.0001 0.819018\n    # 0.0002 0.819506\n    # 0.001 0.819018\n\n    parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer.  Reasonable values are 1.0 for adadelta or 0.001 for SGD.  None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES))\n    parser.add_argument('--learning_eps', default=None, type=float, help='eps value to use in the optimizer.  None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_EPS))\n    parser.add_argument('--learning_momentum', default=None, type=float, help='Momentum.  None uses a default for the given optimizer: {}'.format(DEFAULT_MOMENTUM))\n    # weight decay values other than adadelta have not been thoroughly tested.\n    # When using adadelta, weight_decay of 0.01 to 0.001 had the best results.\n    # 0.1 was very clearly too high. 
0.0001 might have been okay.\n    # Running a series of 5x experiments on a VI dataset:\n    #    0.030:   0.8167018\n    #    0.025:   0.81659\n    #    0.020:   0.81722\n    #    0.015:   0.81721\n    #    0.010:   0.81474348\n    #    0.005:   0.81503\n    parser.add_argument('--learning_weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')\n    parser.add_argument('--learning_rho', default=DEFAULT_LEARNING_RHO, type=float, help='Rho parameter in Adadelta')\n    # A few experiments on beta2 didn't show much benefit from changing it\n    #   On an experiment with training WSJ with default parameters\n    #   AdaDelta for 200 iterations, then training AdamW for 200 more,\n    #   0.999, 0.997, 0.995 all wound up with 0.9588\n    #   values lower than 0.995 all had a slight dropoff\n    parser.add_argument('--learning_beta2', default=0.999, type=float, help='Beta2 argument for AdamW')\n    parser.add_argument('--optim', default=None, help='Optimizer type: SGD, AdamW, Adadelta, AdaBelief, Madgrad')\n\n    parser.add_argument('--stage1_learning_rate', default=None, type=float, help='Learning rate to use in the first stage of --multistage.  None means use default: {}'.format(DEFAULT_LEARNING_RATES['adadelta']))\n\n    parser.add_argument('--learning_rate_warmup', default=0, type=int, help=\"Number of epochs to ramp up learning rate from 0 to full.  Set to 0 to always use the chosen learning rate.  
Currently not functional, as it didn't do anything\")\n\n    parser.add_argument('--learning_rate_factor', default=0.6, type=float, help='Plateau learning rate decrease when plateaued')\n    parser.add_argument('--learning_rate_patience', default=5, type=int, help='Plateau learning rate patience')\n    parser.add_argument('--learning_rate_cooldown', default=10, type=int, help='Plateau learning rate cooldown')\n    parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum')\n    parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)')\n\n    parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount.  Use --no_grad_clipping to turn off grad clipping')\n    parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')\n\n    # Large Margin is from Large Margin In Softmax Cross-Entropy Loss\n    # it did not help on an Italian VIT test\n    # scores went from 0.8252 to 0.8248\n    parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal.  
Focal requires `pip install focal_loss_torch`')\n    parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')\n\n    # turn off dropout for word_dropout, predict_dropout, and lstm_input_dropout\n    # this mechanism doesn't actually turn off lstm_layer_dropout (yet)\n    # but that is set to a default of 0 anyway\n    # this is reusing the idea presented in\n    # https://arxiv.org/pdf/2303.01500v2\n    # \"Dropout Reduces Underfitting\"\n    # Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell\n    # Unfortunately, this does not consistently help results\n    # Averaged of 5 models w/ transformer, dev / test\n    # id_icon - improves a little\n    #  baseline           0.8823    0.8904\n    #  early_dropout 40   0.8835    0.8919\n    # ja_alt - worsens a little\n    #  baseline           0.9308    0.9355\n    #  early_dropout 40   0.9287    0.9345\n    # vi_vlsp23 - worsens a little\n    #  baseline           0.8262    0.8290\n    #  early_dropout 40   0.8255    0.8286\n    # We keep this as an available option for further experiments, if needed\n    parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout')\n    # When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:\n    # 0.0: 0.9085\n    # 0.2: 0.9165\n    # 0.4: 0.9162\n    # 0.5: 0.9123\n    # Letting 0.2 and 0.4 run for longer, along with 0.3 as another\n    # trial, continued to give extremely similar results over time.\n    # No attempt has been made to test the different dropouts separately...\n    parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding')\n    parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer')\n    # lstm_dropout has not been fully tested yet\n    # one experiment after 200 iterations (after 
retagging, so scores are lower than some other experiments):\n    # 0.0: 0.9093\n    # 0.1: 0.9094\n    # 0.2: 0.9094\n    # 0.3: 0.9076\n    # 0.4: 0.9077\n    parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers')\n    # one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations\n    # 0.0       0.9091\n    # 0.1       0.9095\n    # 0.2       0.9118\n    # 0.3       0.9123\n    # 0.4       0.9080\n    parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM')\n\n    parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],\n                        help='Transition scheme to use.  {}'.format(\", \".join(x.name for x in TransitionScheme)))\n\n    parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed')\n\n    # combining dummy and open node embeddings might be a slight improvement\n    # for example, after 550 iterations, one experiment had\n    # True:     0.9154\n    # False:    0.9150\n    # another (with a different structure) had 850 iterations\n    # True:     0.9155\n    # False:    0.9149\n    parser.add_argument('--combined_dummy_embedding', default=True, action='store_true', help=\"Use the same embedding for dummy nodes and the vectors used when combining constituents\")\n    parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help=\"Don't use the same embedding for dummy nodes and the vectors used when combining constituents\")\n\n    # relu gave at least 1 F1 improvement over tanh in various experiments\n    # relu & gelu seem roughly the same, but relu is clearly faster.\n    # relu, 496 iterations: 0.9176\n    # gelu, 467 iterations: 0.9181\n    # after the same clock time on the same hardware.  
the two had been\n    # trading places in terms of accuracy over those ~500 iterations.\n    # leaky_relu was not an improvement - a full run on WSJ led to 0.9181 f1 instead of 0.919\n    # See constituency/utils.py for more extensive comments on nonlinearity options\n    parser.add_argument('--nonlinearity', default='relu', choices=NONLINEARITY.keys(), help='Nonlinearity to use in the model.  relu is a noticeable improvement over tanh')\n    # In one experiment on an Italian dataset, VIT, we got the following:\n    #  0.8254 with relu as the nonlinearity   (10 trials)\n    #  0.8265 with maxout, k = 2              (15)\n    #  0.8253 with maxout, k = 3              (5)\n    # The speed in terms of trees/second might be slightly slower with maxout.\n    #  51.4 it/s on a Titan Xp with maxout 2 and 51.9 it/s with relu\n    # It might also be worth running some experiments with bigger\n    # output layers to see if that makes up for the difference in score.\n    parser.add_argument('--maxout_k', default=None, type=int, help=\"Use maxout layers instead of a nonlinearity for the output layers\")\n\n    parser.add_argument('--use_silver_words', default=True, dest='use_silver_words', action='store_true', help=\"Train/don't train word vectors for words only in the silver dataset\")\n    parser.add_argument('--no_use_silver_words', default=True, dest='use_silver_words', action='store_false', help=\"Train/don't train word vectors for words only in the silver dataset\")\n    parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training')\n    parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset')\n    parser.add_argument('--tag_unknown_frequency', default=0.001, type=float, help='How often to replace a tag with UNK when training')\n\n    parser.add_argument('--num_lstm_layers', default=2, type=int, 
help='How many layers to use in the LSTMs')\n    parser.add_argument('--num_tree_lstm_layers', default=None, type=int, help='How many layers to use in the TREE_LSTMs, if used.  This also increases the width of the word outputs to match the tree lstm inputs.  Default 2 if TREE_LSTM or TREE_LSTM_CX, 1 otherwise')\n    parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level')\n\n    parser.add_argument('--sentence_boundary_vectors', default=SentenceBoundary.EVERYTHING, type=lambda x: SentenceBoundary[x.upper()],\n                        help='Vectors to learn at the start & end of sentences.  {}'.format(\", \".join(x.name for x in SentenceBoundary)))\n    parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()],\n                        help='How to build a new composition from its children.  {}'.format(\", \".join(x.name for x in ConstituencyComposition)))\n    parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)')\n    parser.add_argument('--reduce_position', default=None, type=int, help=\"Dimension of position vector to use when reducing children.  None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)\")\n\n    parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn.  
One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')\n    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')\n    parser.add_argument('--checkpoint_save_name', type=str, default=None, help=\"File name to save the most recent checkpoint\")\n    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help=\"Don't save checkpoints\")\n    parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')\n    parser.add_argument('--load_package', type=str, default=None, help='Download an existing stanza package & use this for tests, finetuning, etc')\n\n    retagging.add_retag_args(parser)\n\n    # Partitioned Attention\n    parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality')\n    parser.add_argument('--pattn_morpho_emb_dropout', default=0.2, type=float, help='Dropout rate for morphological features obtained from pretrained model')\n    parser.add_argument('--pattn_encoder_max_len', default=512, type=int, help='Max length that can be put into the transformer attention layer')\n    parser.add_argument('--pattn_num_heads', default=8, type=int, help='Partitioned attention model number of attention heads')\n    parser.add_argument('--pattn_d_kv', default=64, type=int, help='Size of the query and key vector')\n    parser.add_argument('--pattn_d_ff', default=2048, type=int, help='Size of the intermediate vectors in the feed-forward sublayer')\n    parser.add_argument('--pattn_relu_dropout', default=0.1, type=float, help='ReLU dropout probability in feed-forward sublayer')\n    parser.add_argument('--pattn_residual_dropout', default=0.2, type=float, help='Residual dropout probability for all residual connections')\n    parser.add_argument('--pattn_attention_dropout', default=0.2, 
type=float, help='Attention dropout probability')\n    parser.add_argument('--pattn_num_layers', default=0, type=int, help='Number of layers for the Partitioned Attention.  Currently turned off')\n    parser.add_argument('--pattn_bias', default=False, action='store_true', help='Whether or not to learn an additive bias')\n    # Results seem relatively similar with learned position embeddings or sin/cos position embeddings\n    parser.add_argument('--pattn_timing', default='sin', choices=['learned', 'sin'], help='Use a learned embedding or a sin embedding')\n\n    # Label Attention\n    parser.add_argument('--lattn_d_input_proj', default=None, type=int, help='If set, project the non-positional inputs down to this size before proceeding.')\n    parser.add_argument('--lattn_d_kv', default=64, type=int, help='Dimension of the key/query vector')\n    parser.add_argument('--lattn_d_proj', default=64, type=int, help='Dimension of the output vector from each label attention head')\n    parser.add_argument('--lattn_resdrop', default=True, action='store_true', help='Whether or not to use Residual Dropout')\n    parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')\n    parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does')\n    parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')\n    parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')\n    parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. 
False means it does')\n    # currently unused - always assume 1/2 of pattn\n    #parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')\n    parser.add_argument('--lattn_d_l', default=32, type=int, help='Number of labels')\n    parser.add_argument('--lattn_attention_dropout', default=0.2, type=float, help='Dropout for attention layer')\n    parser.add_argument('--lattn_d_ff', default=2048, type=int, help='Dimension of the Feed-forward layer')\n    parser.add_argument('--lattn_relu_dropout', default=0.2, type=float, help='Relu dropout for the label attention')\n    parser.add_argument('--lattn_residual_dropout', default=0.2, type=float, help='Residual dropout for the label attention')\n    parser.add_argument('--lattn_combined_input', default=True, action='store_true', help='Combine all inputs for the lattn, not just the pattn')\n    parser.add_argument('--use_lattn', default=False, action='store_true', help='Use the lattn layers - currently turned off')\n    parser.add_argument('--no_use_lattn', dest='use_lattn', action='store_false', help='Use the lattn layers - currently turned off')\n    parser.add_argument('--no_lattn_combined_input', dest='lattn_combined_input', action='store_false', help=\"Don't combine all inputs for the lattn, not just the pattn\")\n\n    parser.add_argument('--use_rattn', default=False, action='store_true', help='Use a local attention layer')\n    parser.add_argument('--rattn_window', default=16, type=int, help='Number of tokens to use for context in the local attention')\n    # Ran an experiment on id_icon with in_order, peft, 200 epochs training\n    # Equivalent experiment with no rattn had an average of 0.8922 dev\n    # window 16, cat, dim 200, sinks 0\n    #   head      dev score\n    #     1         0.8915\n    #     2         0.8933\n    #     3         0.8918\n    #     4         0.8934\n    #     5         0.8924\n    #     6         0.8936\n    #     8         0.8920\n   
 #    10         0.8909\n    #    12         0.8939\n    #    14         0.8949\n    #    16         0.8952\n    #    18         0.8915\n    #    20         0.8925\n    #    25         0.8913\n    #    30         0.8913\n    #    40         0.8943\n    #    50         0.8931\n    #    75         0.8940\n    # The average here is 0.8928, which is a tiny bit higher...\n    parser.add_argument('--rattn_heads', default=16, type=int, help='Number of heads to use for context in the local attention')\n    parser.add_argument('--no_rattn_forward', default=True, action='store_false', dest='rattn_forward', help=\"Use or don't use the forward relative attention\")\n    parser.add_argument('--no_rattn_reverse', default=True, action='store_false', dest='rattn_reverse', help=\"Use or don't use the reverse relative attention\")\n    parser.add_argument('--no_rattn_cat', action='store_false', dest='rattn_cat', help='Stack the rattn layers instead of adding them')\n    parser.add_argument('--rattn_cat', default=True, action='store_true', help='Stack the rattn layers instead of adding them')\n    parser.add_argument('--rattn_dim', default=200, type=int, help='Dimension of the rattn output when cat')\n    parser.add_argument('--rattn_sinks', default=0, type=int, help='Number of attention sink tokens to learn')\n    parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true', help='Use the endpoints of the sentences as sinks')\n\n    parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training.  A very noisy option')\n    parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning')\n    parser.add_argument('--watch_regex', default=None, help='regex to describe which weights and biases to output, if any')\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  
Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')\n    parser.add_argument('--wandb_norm_regex', default=None, help='Log on wandb the norm of any tensor whose name matches this regex.  Might get cluttered?')\n\n    return parser\n\ndef build_model_filename(args):\n    embedding = utils.embedding_name(args)\n    maybe_finetune = \"finetuned\" if args['bert_finetune'] or args['stage1_bert_finetune'] else \"\"\n    transformer_finetune_begin = \"%d\" % args['bert_finetune_begin_epoch'] if args['bert_finetune_begin_epoch'] is not None else \"\"\n\n    rattn = \"\"\n    if args['use_rattn']:\n        if args['rattn_forward']: rattn = rattn + \"F\"\n        if args['rattn_reverse']: rattn = rattn + \"R\"\n        if rattn:\n            if args['rattn_cat']:\n                rattn += \"c\"\n            rattn += \"h%02d\" % args['rattn_heads']\n            rattn += \"w%02d\" % args['rattn_window']\n            if args['rattn_sinks'] > 0:\n                rattn += \"s%d\" % args['rattn_sinks']\n\n    model_save_file = args['save_name'].format(shorthand=args['shorthand'],\n                                               oracle_level=args['oracle_level'],\n                                               embedding=embedding,\n                                               finetune=maybe_finetune,\n                                               transformer_finetune_begin=transformer_finetune_begin,\n                                               transition_scheme=args['transition_scheme'].name.lower().replace(\"_\", \"\"),\n                                               tscheme=args['transition_scheme'].short_name,\n                                               trans_layers=args['bert_hidden_layers'],\n                                               rattn=rattn,\n                                               
seed=args['seed'])\n    model_save_file = re.sub(\"_+\", \"_\", model_save_file)\n    logger.info(\"Expanded save_name: %s\", model_save_file)\n\n    model_dir = os.path.split(model_save_file)[0]\n    if model_dir != args['save_dir']:\n        model_save_file = os.path.join(args['save_dir'], model_save_file)\n    return model_save_file\n\ndef parse_args(args=None):\n    parser = build_argparse()\n\n    args = parser.parse_args(args=args)\n    resolve_peft_args(args, logger, check_bert_finetune=False)\n    if not args.lang and args.shorthand and len(args.shorthand.split(\"_\", maxsplit=1)) == 2:\n        args.lang = args.shorthand.split(\"_\")[0]\n\n    if args.stage1_bert_learning_rate is None:\n        args.stage1_bert_learning_rate = args.bert_learning_rate\n\n    if args.optim is None and args.mode == 'train':\n        if not args.multistage:\n            # this seemed to work the best when not doing multistage\n            args.optim = \"adadelta\"\n            if args.use_peft and not args.bert_finetune:\n                logger.info(\"--use_peft set.  setting --bert_finetune as well\")\n                args.bert_finetune = True\n        elif args.bert_finetune or args.stage1_bert_finetune:\n            logger.info(\"Multistage training is set, optimizer is not chosen, and bert finetuning is active.  Will use AdamW as the second stage optimizer.\")\n            args.optim = \"adamw\"\n        else:\n            # if MADGRAD exists, use it\n            # otherwise, adamw\n            try:\n                import madgrad\n                args.optim = \"madgrad\"\n                logger.info(\"Multistage training is set, optimizer is not chosen, and MADGRAD is available.  Will use MADGRAD as the second stage optimizer.\")\n            except ModuleNotFoundError as e:\n                logger.warning(\"Multistage training is set.  Best models are with MADGRAD, but it is not installed.  Will use AdamW for the second stage optimizer.  
Consider installing MADGRAD\")\n                args.optim = \"adamw\"\n\n    if args.mode == 'train':\n        if args.learning_rate is None:\n            args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None)\n        if args.learning_eps is None:\n            args.learning_eps = DEFAULT_LEARNING_EPS.get(args.optim.lower(), None)\n        if args.learning_momentum is None:\n            args.learning_momentum = DEFAULT_MOMENTUM.get(args.optim.lower(), None)\n        if args.learning_weight_decay is None:\n            args.learning_weight_decay = DEFAULT_WEIGHT_DECAY.get(args.optim.lower(), None)\n\n        if args.stage1_learning_rate is None:\n            args.stage1_learning_rate = DEFAULT_LEARNING_RATES[\"adadelta\"]\n        if args.stage1_bert_finetune is None:\n            args.stage1_bert_finetune = args.bert_finetune\n\n        if args.learning_rate_min_lr is None:\n            args.learning_rate_min_lr = args.learning_rate * 0.02\n        if args.stage1_learning_rate_min_lr is None:\n            args.stage1_learning_rate_min_lr = args.stage1_learning_rate * 0.02\n\n    if args.reduce_position is None:\n        args.reduce_position = args.hidden_size // 4\n\n    if args.num_tree_lstm_layers is None:\n        if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):\n            args.num_tree_lstm_layers = 2\n        else:\n            args.num_tree_lstm_layers = 1\n\n    if args.wandb_name or args.wandb_norm_regex:\n        args.wandb = True\n\n    args = vars(args)\n\n    retagging.postprocess_args(args)\n    postprocess_predict_output_args(args)\n\n    if args['seed'] is None:\n        args['seed'] = random.randint(0, 1000000000)\n        logger.info(\"Using random seed %d\", args['seed'])\n\n    model_save_file = build_model_filename(args)\n    args['save_name'] = model_save_file\n\n    if args['save_each_name']:\n        model_save_each_file = os.path.join(args['save_dir'], 
args['save_each_name'])\n        model_save_each_file = utils.build_save_each_filename(model_save_each_file)\n        args['save_each_name'] = model_save_each_file\n    else:\n        # in the event that there is a start epoch setting,\n        # this will make a reasonable default for the path\n        pieces = os.path.splitext(args['save_name'])\n        model_save_each_file = pieces[0] + \"_%04d\" + pieces[1]\n        args['save_each_name'] = model_save_each_file\n\n    if args['checkpoint']:\n        args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name'])\n\n    return args\n\ndef main(args=None):\n    \"\"\"\n    Main function for building con parser\n\n    Processes args, calls the appropriate function for the chosen --mode\n    \"\"\"\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running constituency parser in %s mode\", args['mode'])\n    logger.debug(\"Using device: %s\", args['device'])\n\n    model_load_file = args['save_name']\n    if args['load_name']:\n        if os.path.exists(args['load_name']):\n            model_load_file = args['load_name']\n        else:\n            model_load_file = os.path.join(args['save_dir'], args['load_name'])\n    elif args['load_package']:\n        if args['lang'] is None:\n            lang_pieces = args['load_package'].split(\"_\", maxsplit=1)\n            try:\n                lang = constant.lang_to_langcode(lang_pieces[0])\n            except ValueError as e:\n                raise ValueError(\"--lang not specified, and the start of the --load_package name, %s, is not a known language.  
Please check the values of those parameters\" % args['load_package']) from e\n            args['lang'] = lang\n            args['load_package'] = lang_pieces[1]\n        stanza.download(args['lang'], processors=\"constituency\", package={\"constituency\": args['load_package']})\n        model_load_file = os.path.join(DEFAULT_MODEL_DIR, args['lang'], 'constituency', args['load_package'] + \".pt\")\n        if not os.path.exists(model_load_file):\n            raise FileNotFoundError(\"Expected the downloaded model file for language %s package %s to be in %s, but there is nothing there.  Perhaps the package name doesn't exist?\" % (args['lang'], args['load_package'], model_load_file))\n        else:\n            logger.info(\"Model for language %s package %s is in %s\", args['lang'], args['load_package'], model_load_file)\n\n    # TODO: when loading a saved model, we should default to whatever\n    # is in the model file for --retag_method, not the default for the language\n    if args['mode'] == 'train':\n        if tlogger.level == logging.NOTSET:\n            tlogger.setLevel(logging.DEBUG)\n            tlogger.debug(\"Set trainer logging level to DEBUG\")\n\n    retag_pipeline = retagging.build_retag_pipeline(args)\n\n    if args['mode'] == 'train':\n        parser_training.train(args, model_load_file, retag_pipeline)\n    elif args['mode'] == 'predict':\n        parser_training.evaluate(args, model_load_file, retag_pipeline)\n    elif args['mode'] == 'parse_text':\n        load_model_parse_text(args, model_load_file, retag_pipeline)\n    elif args['mode'] == 'remove_optimizer':\n        parser_training.remove_optimizer(args, args['save_name'], model_load_file)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/coref/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/coref/anaphoricity_scorer.py",
    "content": "\"\"\" Describes AnaphicityScorer, a torch module that for a matrix of\nmentions produces their anaphoricity scores.\n\"\"\"\nimport torch\n\nfrom stanza.models.coref import utils\nfrom stanza.models.coref.config import Config\n\n\nclass AnaphoricityScorer(torch.nn.Module):\n    \"\"\" Calculates anaphoricity scores by passing the inputs into a FFNN \"\"\"\n    def __init__(self,\n                 in_features: int,\n                 config: Config):\n        super().__init__()\n        hidden_size = config.hidden_size\n        if not config.n_hidden_layers:\n            hidden_size = in_features\n        layers = []\n        for i in range(config.n_hidden_layers):\n            layers.extend([torch.nn.Linear(hidden_size if i else in_features,\n                                           hidden_size),\n                           torch.nn.LeakyReLU(),\n                           torch.nn.Dropout(config.dropout_rate)])\n        self.hidden = torch.nn.Sequential(*layers)\n        self.out = torch.nn.Linear(hidden_size, out_features=1)\n\n        # are we going to predict singletons\n        self.predict_singletons = config.singletons\n\n        if self.predict_singletons:\n            # map to whether or not this is a start of a coref given all the\n            # antecedents; not used when config.singletons = False because\n            # we only need to know this for predicting singletons\n            self.start_map = torch.nn.Linear(config.rough_k, out_features=1, bias=False)\n\n\n    def forward(self, *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch\n                top_mentions: torch.Tensor,\n                mentions_batch: torch.Tensor,\n                pw_batch: torch.Tensor,\n                top_rough_scores_batch: torch.Tensor,\n                ) -> torch.Tensor:\n        \"\"\" Builds a pairwise matrix, scores the pairs and returns the scores.\n\n        Args:\n            all_mentions (torch.Tensor): [n_mentions, 
mention_emb]\n            mentions_batch (torch.Tensor): [batch_size, mention_emb]\n            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]\n            top_indices_batch (torch.Tensor): [batch_size, n_ants]\n            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]\n\n        Returns:\n            torch.Tensor [batch_size, n_ants + 1]\n                anaphoricity scores for the pairs + a dummy column\n        \"\"\"\n        # [batch_size, n_ants, pair_emb]\n        pair_matrix = self._get_pair_matrix(mentions_batch, pw_batch, top_mentions)\n\n        # [batch_size, n_ants] vs [batch_size, 1]\n        # first is coref scores, the second is whether its the start of a coref\n        if self.predict_singletons:\n            scores, start = self._ffnn(pair_matrix)\n            scores = utils.add_dummy(scores+top_rough_scores_batch, eps=True)\n\n            return torch.cat([start, scores], dim=1)\n        else:\n            scores = self._ffnn(pair_matrix)\n            return utils.add_dummy(scores+top_rough_scores_batch, eps=True)\n\n    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Calculates anaphoricity scores.\n\n        Args:\n            x: tensor of shape [batch_size, n_ants, n_features]\n\n        Returns:\n            tensor of shape [batch_size, n_ants]\n        \"\"\"\n        x = self.out(self.hidden(x))\n        x = x.squeeze(2)\n\n        if not self.predict_singletons:\n            return x\n\n        # because sometimes we only have the first 49 anaphoricities\n        start = x @ self.start_map.weight[:,:x.shape[1]].T\n        return x, start\n\n    @staticmethod\n    def _get_pair_matrix(mentions_batch: torch.Tensor,\n                         pw_batch: torch.Tensor,\n                         top_mentions: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Builds the matrix used as input for AnaphoricityScorer.\n\n        Args:\n            all_mentions (torch.Tensor): [n_mentions, 
mention_emb],\n                all the valid mentions of the document,\n                can be on a different device\n            mentions_batch (torch.Tensor): [batch_size, mention_emb],\n                the mentions of the current batch,\n                is expected to be on the current device\n            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],\n                pairwise features of the current batch,\n                is expected to be on the current device\n            top_indices_batch (torch.Tensor): [batch_size, n_ants],\n                indices of antecedents of each mention\n\n        Returns:\n            torch.Tensor: [batch_size, n_ants, pair_emb]\n        \"\"\"\n        emb_size = mentions_batch.shape[1]\n        n_ants = pw_batch.shape[1]\n\n        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)\n        b_mentions = top_mentions\n        similarity = a_mentions * b_mentions\n\n        out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)\n        return out\n"
  },
  {
    "path": "stanza/models/coref/bert.py",
    "content": "\"\"\"Functions related to BERT or similar models\"\"\"\n\nimport logging\nfrom typing import List, Tuple\n\nimport numpy as np                                 # type: ignore\nfrom transformers import AutoModel, AutoTokenizer  # type: ignore\n\nfrom stanza.models.coref.config import Config\nfrom stanza.models.coref.const import Doc\n\n\nlogger = logging.getLogger('stanza')\n\ndef get_subwords_batches(doc: Doc,\n                         config: Config,\n                         tok: AutoTokenizer\n                         ) -> np.ndarray:\n    \"\"\"\n    Turns a list of subwords to a list of lists of subword indices\n    of max length == batch_size (or shorter, as batch boundaries\n    should match sentence boundaries). Each batch is enclosed in cls and sep\n    special tokens.\n\n    Returns:\n        batches of bert tokens [n_batches, batch_size]\n    \"\"\"\n    batch_size = config.bert_window_size - 2  # to save space for CLS and SEP\n\n    subwords: List[str] = doc[\"subwords\"]\n    subwords_batches = []\n    start, end = 0, 0\n\n    while end < len(subwords):\n        # to prevent the case where a batch_size step forward\n        # doesn't capture more than 1 sentence, we will just cut\n        # that sequence\n        prev_end = end\n        end = min(end + batch_size, len(subwords))\n\n        # Move back till we hit a sentence end\n        if end < len(subwords):\n            sent_id = doc[\"sent_id\"][doc[\"word_id\"][end]]\n            while end and doc[\"sent_id\"][doc[\"word_id\"][end - 1]] == sent_id:\n                end -= 1\n\n        # this occurs IFF there was no sentence end found throughout\n        # the forward scan; this means that our sentence was waay too\n        # long (i.e. 
longer than the max length of the transformer).\n        #\n        # if so, we give up and just chop the sentence off at the max length\n        # that was given\n        if end == prev_end:\n            end = min(end + batch_size, len(subwords))\n\n        length = end - start\n        if tok.cls_token is None or tok.sep_token is None:\n            batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token]\n        else:\n            batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token]\n\n        # Padding to desired length\n        batch += [tok.pad_token] * (batch_size - length)\n\n        subwords_batches.append([tok.convert_tokens_to_ids(token)\n                                 for token in batch])\n        start += length\n\n    return np.array(subwords_batches)\n"
  },
  {
    "path": "stanza/models/coref/cluster_checker.py",
    "content": "\"\"\" Describes ClusterChecker, a class used to retrieve LEA scores.\nSee aclweb.org/anthology/P16-1060.pdf. \"\"\"\n\nfrom typing import Hashable, List, Tuple\n\nfrom stanza.models.coref.const import EPSILON\nimport numpy as np\n\nimport math\nimport logging\n\nlogger = logging.getLogger('stanza')\n\n\nclass ClusterChecker:\n    \"\"\" Collects information on gold and predicted clusters across documents.\n    Can be used to retrieve weighted LEA-score for them.\n    \"\"\"\n    def __init__(self):\n        self._lea_precision = 0.0\n        self._lea_recall = 0.0\n        self._lea_precision_weighting = 0.0\n        self._lea_recall_weighting = 0.0\n        self._num_preds = 0.0\n\n        # muc\n        self._muc_precision = 0.0\n        self._muc_recall = 0.0\n\n        # b3\n        self._b3_precision = 0.0\n        self._b3_recall = 0.0\n\n        # ceafe\n        self._ceafe_precision = 0.0\n        self._ceafe_recall = 0.0\n\n    @staticmethod\n    def _f1(p,r):\n        return (p * r) / (p+r + EPSILON) * 2\n    \n    def add_predictions(self,\n                        gold_clusters: List[List[Hashable]],\n                        pred_clusters: List[List[Hashable]]):\n        \"\"\"\n        Calculates LEA for the document's clusters and stores them to later\n        output weighted LEA across documents.\n\n        Returns:\n            LEA score for the document as a tuple of (f1, precision, recall)\n        \"\"\"\n\n        # if len(gold_clusters) == 0:\n            # breakpoint()\n\n        self._num_preds += 1\n        \n        recall, r_weight = ClusterChecker._lea(gold_clusters, pred_clusters)\n        precision, p_weight = ClusterChecker._lea(pred_clusters, gold_clusters)\n\n        self._muc_recall +=  ClusterChecker._muc(gold_clusters, pred_clusters)\n        self._muc_precision += ClusterChecker._muc(pred_clusters, gold_clusters)\n\n        self._b3_recall += ClusterChecker._b3(gold_clusters, pred_clusters)\n        
self._b3_precision += ClusterChecker._b3(pred_clusters, gold_clusters)\n\n        ceafe_precision, ceafe_recall = ClusterChecker._ceafe(pred_clusters, gold_clusters)\n        if math.isnan(ceafe_precision) and len(gold_clusters) > 0:\n            # because our model predicted no clusters\n            ceafe_precision = 0.0\n\n        self._ceafe_precision += ceafe_precision\n        self._ceafe_recall += ceafe_recall\n\n        self._lea_recall += recall\n        self._lea_recall_weighting += r_weight\n        self._lea_precision += precision\n        self._lea_precision_weighting += p_weight\n\n        doc_precision = precision / (p_weight + EPSILON)\n        doc_recall = recall / (r_weight + EPSILON)\n        doc_f1 = (doc_precision * doc_recall) \\\n            / (doc_precision + doc_recall + EPSILON) * 2\n        return doc_f1, doc_precision, doc_recall\n\n    @property\n    def bakeoff(self):\n        \"\"\" Get the F1 macroaverage score used by the bakeoff \"\"\"\n        return sum(self.mbc)/3\n\n    @property\n    def mbc(self):\n        \"\"\" Get the F1 average score of (muc, b3, ceafe) over docs \"\"\"\n        avg_precisions = [self._muc_precision, self._b3_precision, self._ceafe_precision]\n        avg_precisions = [i/(self._num_preds + EPSILON) for i in avg_precisions]\n\n        avg_recalls = [self._muc_recall, self._b3_recall, self._ceafe_recall]\n        avg_recalls = [i/(self._num_preds + EPSILON) for i in avg_recalls]\n\n        avg_f1s = [self._f1(p,r) for p,r in zip(avg_precisions, avg_recalls)]\n\n        return avg_f1s\n\n    @property\n    def total_lea(self):\n        \"\"\" Returns weighted LEA for all the documents as\n        (f1, precision, recall) \"\"\"\n        precision = self._lea_precision / (self._lea_precision_weighting + EPSILON)\n        recall = self._lea_recall / (self._lea_recall_weighting + EPSILON)\n        f1 = self._f1(precision, recall)\n        return f1, precision, recall\n\n    @staticmethod\n    def _lea(key: 
List[List[Hashable]],\n             response: List[List[Hashable]]) -> Tuple[float, float]:\n        \"\"\" See aclweb.org/anthology/P16-1060.pdf. \"\"\"\n        response_clusters = [set(cluster) for cluster in response]\n        response_map = {mention: cluster\n                        for cluster in response_clusters\n                        for mention in cluster}\n        importances = []\n        resolutions = []\n        for entity in key:\n            size = len(entity)\n            if size == 1:  # entities of size 1 are not annotated\n                continue\n            importances.append(size)\n            correct_links = 0\n            for i in range(size):\n                for j in range(i + 1, size):\n                    correct_links += int(entity[i]\n                                         in response_map.get(entity[j], {}))\n            resolutions.append(correct_links / (size * (size - 1) / 2))\n        res = sum(imp * res for imp, res in zip(importances, resolutions))\n        weight = sum(importances)\n        return res, weight\n\n    @staticmethod\n    def _muc(key: List[List[Hashable]],\n             response: List[List[Hashable]]) -> float:\n        \"\"\" See aclweb.org/anthology/P16-1060.pdf. \"\"\"\n\n        response_clusters = [set(cluster) for cluster in response]\n        response_map = {mention: cluster\n                        for cluster in response_clusters\n                        for mention in cluster}\n\n        top = 0 # sum over k of |k_i| - response_partitions(|k_i|)\n        bottom = 0 # sum over k of |k_i| - 1\n\n        for entity in key:\n            S = len(entity)\n            # we need to figure the number of DIFFERENT clusters \n            # the response assigns to members of the entity; ideally\n            # this number is 1 (i.e. 
they are all assigned the same\n            # coref).\n            response_clusters = [response_map.get(i, None) for i in entity]\n            # and deduplicate\n            deduped = []\n            for i in response_clusters:\n                if i is None:\n                    deduped.append(i)\n                elif i not in deduped:\n                    deduped.append(i)\n            # the \"partitions\" will then be the size of the deduped list\n            p_k = len(deduped)\n            top += (S - p_k)\n            bottom += (S - 1)\n        \n        try:\n            return top/bottom\n        except ZeroDivisionError:\n            logger.warning(\"muc got a zero division error because the model predicted no spans!\")\n            return 0 # +inf technically\n\n    @staticmethod\n    def _b3(key: List[List[Hashable]],\n            response: List[List[Hashable]]) -> float:\n        \"\"\" See aclweb.org/anthology/P16-1060.pdf. \"\"\"\n        \n        response_clusters = [set(cluster) for cluster in response]\n\n        top = 0 # sum over key and response of (|k intersect response|^2/|k|)\n        bottom = 0 # sum over k of |k_i|\n\n        for entity in key:\n            bottom += len(entity)\n            entity = set(entity)\n\n            for res_entity in response_clusters:\n                top += (len(entity.intersection(res_entity))**2)/len(entity)\n\n        try:\n            return top/bottom\n        except ZeroDivisionError:\n            logger.warning(\"b3 got a zero division error because the model predicted no spans!\")\n            return 0 # +inf technically\n\n\n\n    @staticmethod\n    def _phi4(c1, c2):\n        return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))\n\n    @staticmethod\n    def _ceafe(clusters: List[List[Hashable]], gold_clusters: List[List[Hashable]]):\n        \"\"\" see https://github.com/ufal/corefud-scorer/blob/main/coval/eval/evaluator.py \"\"\"\n\n        try:\n            from scipy.optimize import 
linear_sum_assignment\n        except ImportError:\n            raise ImportError(\"To perform CEAF scoring, please install scipy via `pip install scipy` for the Kuhn-Munkres linear assignment scheme.\")\n\n        clusters = [c for c in clusters]\n        scores = np.zeros((len(gold_clusters), len(clusters)))\n        for i in range(len(gold_clusters)):\n            for j in range(len(clusters)):\n                scores[i, j] = ClusterChecker._phi4(gold_clusters[i], clusters[j])\n        row_ind, col_ind = linear_sum_assignment(-scores)\n        similarity = scores[row_ind, col_ind].sum()\n\n        # precision, recall\n        try:\n            prec = similarity/len(clusters)\n        except ZeroDivisionError:\n            logger.warning(\"ceafe got a zero division error because the model predicted no spans!\")\n            prec = 0\n        recc = similarity/len(gold_clusters)\n        return prec, recc\n\n"
  },
  {
    "path": "stanza/models/coref/config.py",
    "content": "\"\"\" Describes Config, a simple namespace for config values.\n\nFor description of all config values, refer to config.toml.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Dict, List\n\n\n@dataclass\nclass Config:  # pylint: disable=too-many-instance-attributes, too-few-public-methods\n    \"\"\" Contains values needed to set up the coreference model. \"\"\"\n    section: str\n\n    # TODO: can either eliminate data_dir or use it for the train/dev/test data\n    data_dir: str\n    save_dir: str\n    save_name: str\n\n    train_data: str\n    dev_data: str\n    test_data: str\n\n    device: str\n\n    bert_model: str\n    bert_window_size: int\n\n    embedding_size: int\n    sp_embedding_size: int\n    a_scoring_batch_size: int\n    hidden_size: int\n    n_hidden_layers: int\n\n    max_span_len: int\n\n    rough_k: int\n\n    lora: bool\n    lora_alpha: int\n    lora_rank: int\n    lora_dropout: float\n\n    full_pairwise: bool\n\n    lora_target_modules: List[str]\n    lora_modules_to_save: List[str]\n\n    clusters_starts_are_singletons: bool\n    bert_finetune: bool\n    dropout_rate: float\n    learning_rate: float\n    bert_learning_rate: float\n    # we find that setting this to a small but non-zero number\n    # makes the model less likely to forget how to do anything\n    bert_finetune_begin_epoch: float\n    train_epochs: int\n    # if plateaued for this many epochs, stop training\n    plateau_epochs: int\n    bce_loss_weight: float\n\n    tokenizer_kwargs: Dict[str, dict]\n    conll_log_dir: str\n\n    save_each_checkpoint: bool\n    log_norms: bool\n    singletons: bool\n    \n    max_train_len: int\n    use_zeros: bool\n\n    lang_lr_attenuation: str\n    lang_lr_weights: str\n"
  },
  {
    "path": "stanza/models/coref/conll.py",
    "content": "\"\"\" Contains functions to produce conll-formatted output files with\npredicted spans and their clustering \"\"\"\n\nfrom collections import defaultdict\nfrom contextlib import contextmanager\nimport os\nfrom typing import List, TextIO\n\nfrom stanza.models.coref.config import Config\nfrom stanza.models.coref.const import Doc, Span\n\n\n# pylint: disable=too-many-locals\ndef write_conll(doc: Doc,\n                clusters: List[List[Span]],\n                heads: List[int],\n                f_obj: TextIO):\n    \"\"\" Writes span/cluster information to f_obj, which is assumed to be a file\n    object open for writing \"\"\"\n    placeholder = list(\"\\t_\" * 7)\n    # the nth token needs to be a number\n    placeholder[9] = \"0\"\n    placeholder = \"\".join(placeholder)\n    doc_id = doc[\"document_id\"].replace(\"-\", \"_\").replace(\"/\", \"_\").replace(\".\",\"_\")\n    words = doc[\"cased_words\"]\n    part_id = doc[\"part_id\"]\n    sents = doc[\"sent_id\"]\n\n    max_word_len = max(len(w) for w in words)\n\n    starts = defaultdict(lambda: [])\n    ends = defaultdict(lambda: [])\n    single_word = defaultdict(lambda: [])\n\n    for cluster_id, cluster in enumerate(clusters):\n        if len(heads[cluster_id]) != len(cluster):\n            # TODO debug this fact and why it occurs\n            # print(f\"cluster {cluster_id} doesn't have the same number of elements for word and span levels, skipping...\")\n            continue\n        for cluster_part, (start, end) in enumerate(cluster):\n            if end - start == 1:\n                single_word[start].append((cluster_part, cluster_id))\n            else:\n                starts[start].append((cluster_part, cluster_id))\n                ends[end - 1].append((cluster_part, cluster_id))\n\n    f_obj.write(f\"# newdoc id = {doc_id}\\n# global.Entity = eid-head\\n\")\n\n    word_number = 0\n    sent_id = 0\n    for word_id, word in enumerate(words):\n\n        cluster_info_lst = []\n        
for part, cluster_marker in starts[word_id]:\n            start, end = clusters[cluster_marker][part]\n            cluster_info_lst.append(f\"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)}\")\n        for part, cluster_marker in single_word[word_id]:\n            start, end = clusters[cluster_marker][part]\n            cluster_info_lst.append(f\"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)})\")\n        for part, cluster_marker in ends[word_id]:\n            cluster_info_lst.append(f\"e{cluster_marker})\")\n\n\n        # we need our clusters to be ordered such that the one that is closest the first change\n        # is listed last in the chains\n        def compare_sort(x):\n            split = x.split(\"-\")\n            if len(split) > 1: \n                return int(split[-1].replace(\")\", \"\").strip())  \n            else: \n                # we want everything that's a closer to be first\n                return float(\"inf\")\n\n        cluster_info_lst = sorted(cluster_info_lst, key=compare_sort, reverse=True)\n        cluster_info = \"\".join(cluster_info_lst) if cluster_info_lst else \"_\"\n\n        if word_id == 0 or sents[word_id] != sents[word_id - 1]:\n            f_obj.write(f\"# sent_id = {doc_id}-{sent_id}\\n\")\n            word_number = 0\n            sent_id += 1\n\n        if cluster_info != \"_\":\n            cluster_info = f\"Entity={cluster_info}\"\n\n        f_obj.write(f\"{word_id}\\t{word}{placeholder}\\t{cluster_info}\\n\")\n\n        word_number += 1\n\n    f_obj.write(\"\\n\")\n\n\n@contextmanager\ndef open_(config: Config, epochs: int, data_split: str):\n    \"\"\" Opens conll log files for writing in a safe way. 
\"\"\"\n    base_filename = f\"{config.section}_{data_split}_e{epochs}\"\n    conll_dir = config.conll_log_dir\n    kwargs = {\"mode\": \"w\", \"encoding\": \"utf8\"}\n\n    os.makedirs(conll_dir, exist_ok=True)\n\n    with open(os.path.join(  # type: ignore\n            conll_dir, f\"{base_filename}.gold.conll\"), **kwargs) as gold_f:\n        with open(os.path.join(  # type: ignore\n                conll_dir, f\"{base_filename}.pred.conll\"), **kwargs) as pred_f:\n            yield (gold_f, pred_f)\n"
  },
  {
    "path": "stanza/models/coref/const.py",
    "content": "\"\"\" Contains type aliases for coref module \"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Tuple\n\nimport torch\n\n\nEPSILON = 1e-7\nLARGE_VALUE = 1000  # used instead of inf due to bug #16762 in pytorch\n\nDoc = Dict[str, Any]\nSpan = Tuple[int, int]\n\n\n@dataclass\nclass CorefResult:\n    coref_scores: torch.Tensor = None                  # [n_words, k + 1]\n    coref_y: torch.Tensor = None                       # [n_words, k + 1]\n    rough_y: torch.Tensor = None                       # [n_words, n_words]\n\n    word_clusters: List[List[int]] = None\n    span_clusters: List[List[Span]] = None\n\n    rough_scores: torch.Tensor = None                  # [n_words, n_words]\n    span_scores: torch.Tensor = None                   # [n_heads, n_words, 2]\n    span_y: Tuple[torch.Tensor, torch.Tensor] = None   # [n_heads] x2\n\n    zero_scores: torch.Tensor = None\n"
  },
  {
    "path": "stanza/models/coref/coref_chain.py",
    "content": "\"\"\"\nCoref chain suitable for attaching to a Document after coref processing\n\"\"\"\n\n# by not using namedtuple, we can use this object as output from the json module\n# in the doc class as long as we wrap the encoder to print these out in dict() form\n# CorefMention = namedtuple('CorefMention', ['sentence', 'start_word', 'end_word'])\nclass CorefMention:\n    def __init__(self, sentence, start_word, end_word):\n        self.sentence = sentence\n        self.start_word = start_word\n        self.end_word = end_word\n\nclass CorefChain:\n    def __init__(self, index, mentions, representative_text, representative_index):\n        self.index = index\n        self.mentions = mentions\n        self.representative_text = representative_text\n        self.representative_index = representative_index\n\nclass CorefAttachment:\n    def __init__(self, chain, is_start, is_end, is_representative):\n        self.chain = chain\n        self.is_start = is_start\n        self.is_end = is_end\n        self.is_representative = is_representative\n\n    def to_json(self):\n        j = {\n            \"index\": self.chain.index,\n            \"representative_text\": self.chain.representative_text\n        }\n        if self.is_start:\n            j['is_start'] = True\n        if self.is_end:\n            j['is_end'] = True\n        if self.is_representative:\n            j['is_representative'] = True\n        return j\n"
  },
  {
    "path": "stanza/models/coref/coref_config.toml",
    "content": "# =============================================================================\n# Before you start changing anything here, read the comments.\n# All of them can be found below in the \"DEFAULT\" section\n\n[DEFAULT]\n\n# The directory that contains extracted files of everything you've downloaded.\ndata_dir = \"data/coref\"\n\n# where to put checkpoints and final models\nsave_dir = \"saved_models/coref\"\nsave_name = \"bert-large-cased\"\n\n# Train, dev and test jsonlines\n# train_data = \"data/coref/en_gum-ud.train.nosgl.json\"\n# dev_data = \"data/coref/en_gum-ud.dev.nosgl.json\"\n# test_data = \"data/coref/en_gum-ud.test.nosgl.json\"\n\ntrain_data = \"data/coref/corefud_concat_v1_0_langid.train.json\"\ndev_data = \"data/coref/corefud_concat_v1_0_langid.dev.json\"\ntest_data = \"data/coref/corefud_concat_v1_0_langid.dev.json\"\n\n#train_data = \"data/coref/english_train_head.jsonlines\"\n#dev_data = \"data/coref/english_development_head.jsonlines\"\n#test_data = \"data/coref/english_test_head.jsonlines\"\n\n# do not use the full pairwise encoding scheme\nfull_pairwise = false\n\n# The device where everything is to be placed. \"cuda:N\"/\"cpu\" are supported.\ndevice = \"cuda:0\"\n\nsave_each_checkpoint = false\nlog_norms = false\n\n# Bert settings ======================\n\n# Base bert model architecture and tokenizer\nbert_model = \"bert-large-cased\"\n\n# Controls max length of sequences passed through bert to obtain its\n# contextual embeddings\n# Must be less than or equal to 512\nbert_window_size = 512\n\n# General model settings =============\n\n# Controls the dimensionality of feature embeddings\nembedding_size = 20\n\n# Controls the dimensionality of distance embeddings used by SpanPredictor\nsp_embedding_size = 64\n\n# Controls the number of spans for which anaphoricity can be scores in one\n# batch. 
Only affects final scoring; mention extraction and rough scoring\n# are less memory intensive, so they are always done in just one batch.\na_scoring_batch_size = 128\n\n# AnaphoricityScorer FFNN parameters\nhidden_size = 1024\nn_hidden_layers = 1\n\n# Do you want to support singletons?\nsingletons = true\n\n\n# Mention extraction settings ========\n\n# Mention extractor will check spans up to max_span_len words\n# The default value is chosen to be big enough to hold any dev data span\nmax_span_len = 64\n\n\n# Pruning settings ===================\n\n# Controls how many pairs should be preserved per mention\n# after applying rough scoring.\nrough_k = 50\n\n\n# Lora settings ===================\n\n# LoRA settings\nlora = false\nlora_alpha = 128\nlora_dropout = 0.1\nlora_rank = 64\nlora_target_modules = []\nlora_modules_to_save = []\n\n\n# Training settings ==================\n\n# Controls whether the first dummy node predicts cluster starts or singletons\nclusters_starts_are_singletons = true\n\n# Controls whether to fine-tune bert_model\nbert_finetune = true\n\n# Controls the dropout rate throughout all models\ndropout_rate = 0.3\n\n# Bert learning rate (only used if bert_finetune is set)\nbert_learning_rate = 1e-6\nbert_finetune_begin_epoch = 0.5\n\n# Task learning rate\nlearning_rate = 3e-4\n\n# For how many epochs the training is done\ntrain_epochs = 32\n\n# plateau for this many epochs = early terminate\nplateau_epochs = 0\n\n# Controls the weight of binary cross entropy loss added to nlml loss\nbce_loss_weight = 0.5\n\n# The directory that will contain conll prediction files\nconll_log_dir = \"data/conll_logs\"\n\n# Skip any documents longer than this length\nmax_train_len = 5000\n\n# if this is set to false, the model will set its zero_predictor to, well, 0\nuse_zeros = true\n\n# two different methods for specifying how to weaken the LR for certain languages\n# however, in their current forms, on an HE experiment, neither worked\n# better than just mixing the 
two datasets together unweighted\n# Starting from the HE IAHLT dataset, and possibly mixing in the ger/rom ud coref,\n# averaging over 5 different seeds, we got the following results:\n#   HE only:    0.497\n#   Attenuated: 0.508\n#   Scaled:     0.517\n#   Mixed:      0.517\n# the attenuation scheme for that experiment was 1/epoch\n# These were the settings\n#   --lang_lr_weights es=0.2,en=0.2,de=0.2,ca=0.2,fr=0.2,no=0.2\n#   --lang_lr_attenuation es,en,de,ca,fr,no\nlang_lr_attenuation = \"\"\nlang_lr_weights = \"\"\n\n# =============================================================================\n# Extra keyword arguments to be passed to bert tokenizers of specified models\n[DEFAULT.tokenizer_kwargs]\n    [DEFAULT.tokenizer_kwargs.roberta-large]\n        \"add_prefix_space\" = true\n\n    [DEFAULT.tokenizer_kwargs.xlm-roberta-large]\n        \"add_prefix_space\" = true\n\n    [DEFAULT.tokenizer_kwargs.spanbert-large-cased]\n        \"do_lower_case\" = false\n\n    [DEFAULT.tokenizer_kwargs.bert-large-cased]\n        \"do_lower_case\" = false\n\n# =============================================================================\n# The sections listed here do not need to make use of all config variables\n# If a variable is omitted, its default value will be used instead\n\n[roberta]\nbert_model = \"roberta-large\"\n\n[roberta_lora]\nbert_model = \"roberta-large\"\nbert_learning_rate = 0.00005\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[scandibert_lora]\nbert_model = \"vesteinn/ScandiBERT\"\nbert_learning_rate = 0.0002\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[xlm_roberta]\nbert_model = \"FacebookAI/xlm-roberta-large\"\nbert_learning_rate = 0.00001\nbert_finetune = true\n\n[xlm_roberta_lora]\nbert_model = \"FacebookAI/xlm-roberta-large\"\nbert_learning_rate = 
0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[deeppavlov_slavic_bert_lora]\nbert_model = \"DeepPavlov/bert-base-bg-cs-pl-ru-cased\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[deberta_lora]\nbert_model = \"microsoft/deberta-v3-large\"\nbert_learning_rate = 0.00001\nlora = true\nlora_target_modules = [ \"query_proj\", \"value_proj\", \"output.dense\" ]\nlora_modules_to_save = [  ]\n\n[electra]\nbert_model = \"google/electra-large-discriminator\"\nbert_learning_rate = 0.00002\n\n[electra_lora]\nbert_model = \"google/electra-large-discriminator\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [  ]\n\n[hungarian_electra_lora]\n# TODO: experiment with tokenizer options for this to see if that's\n# why the results are so low using this transformer\nbert_model = \"NYTK/electra-small-discriminator-hungarian\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [  ]\n\n[muril_large_cased_lora]\nbert_model = \"google/muril-large-cased\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[muril_base_cased_lora]\nbert_model = \"google/muril-base-cased\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[indic_bert_lora]\nbert_model = \"ai4bharat/indic-bert\"\nbert_learning_rate = 0.0005\nlora = true\n# indic-bert is an albert with repeating layers of different 
names\nlora_target_modules = [ \"query\", \"value\", \"dense\", \"ffn\", \"full_layer\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[alephbertgimmel_lora]\nbert_model = \"imvladikon/alephbertgimmel-base-512\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[alephbert_lora]\nbert_model = \"onlplab/alephbert-base\"\nbert_learning_rate = 0.000025\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[hero_lora]\n# LR sweep on Hebrew IAHLT coref dev scores\n# (although there may be tokenization problems)\n# 0.000005    0.44202\n# 0.00001     0.45271\n# 0.000015    0.45771\n# 0.00002     0.45877\n# 0.000025    0.46076\n# 0.00003     0.45957\n# 0.000035    0.46187\n# 0.00004     0.46066\n# 0.000045    0.46132\n# 0.00005     0.46238\n# 0.000055    0.46084\n# 0.00006     0.46047\n# 0.000075    0.45772\n# 0.0001      0.44910\nbert_model = \"HeNLP/HeRo\"\nbert_learning_rate = 0.00005\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[bert_multilingual_cased_lora]\n# LR sweep on a Hindi dataset\n# 0.00001:  0.53238\n# 0.00002:  0.54012\n# 0.000025: 0.54206\n# 0.00003:  0.54050\n# 0.00004:  0.55081\n# 0.00005:  0.55135\n# 0.000075: 0.54482\n# 0.0001:   0.53888\nbert_model = \"google-bert/bert-base-multilingual-cased\"\nbert_learning_rate = 0.00005\nlora = true\nlora_target_modules = [ \"query\", \"value\", \"output.dense\", \"intermediate.dense\" ]\nlora_modules_to_save = [ \"pooler\" ]\n\n[t5_lora]\nbert_model = \"google-t5/t5-large\"\nbert_learning_rate = 0.000025\nbert_window_size = 1024\nlora = true\nlora_target_modules = [ \"q\", \"v\", \"o\", \"wi\", \"wo\" ]\nlora_modules_to_save = [  ]\n\n[mt5_lora]\nbert_model = \"google/mt5-base\"\nbert_learning_rate = 
0.000025\nlora_alpha = 64\nlora_rank = 32\nlora = true\nlora_target_modules = [ \"q\", \"v\", \"o\", \"wi\", \"wo\" ]\nlora_modules_to_save = [  ]\n\n[deepnarrow_t5_xl_lora]\nbert_model = \"google/t5-efficient-xl\"\nbert_learning_rate = 0.00025\nlora = true\nlora_target_modules = [ \"q\", \"v\", \"o\", \"wi\", \"wo\" ]\nlora_modules_to_save = [  ]\n\n[roberta_no_finetune]\nbert_model = \"roberta-large\"\nbert_finetune = false\n\n[roberta_no_bce]\nbert_model = \"roberta-large\"\nbce_loss_weight = 0.0\n\n[spanbert]\nbert_model = \"SpanBERT/spanbert-large-cased\"\n\n[spanbert_no_bce]\nbert_model = \"SpanBERT/spanbert-large-cased\"\nbce_loss_weight = 0.0\n\n[bert]\nbert_model = \"bert-large-cased\"\n\n[longformer]\nbert_model = \"allenai/longformer-large-4096\"\nbert_window_size = 2048\n\n[debug]\nbert_window_size = 384\nbert_finetune = false\ndevice = \"cpu:0\"\n\n[debug_gpu]\nbert_window_size = 384\nbert_finetune = false\n"
  },
  {
    "path": "stanza/models/coref/dataset.py",
    "content": "import json\nimport logging\nfrom torch.utils.data import Dataset\n\nfrom stanza.models.coref.tokenizer_customization import TOKENIZER_FILTERS, TOKENIZER_MAPS\n\nlogger = logging.getLogger('stanza')\n\nclass CorefDataset(Dataset):\n\n    def __init__(self, path, config, tokenizer):\n        self.config = config\n        self.tokenizer = tokenizer\n\n        # by default, this doesn't filter anything (see lambda _ True);\n        # however, there are some subword symbols which are standalone\n        # tokens which we don't want on models like Albert; hence we\n        # pass along a filter if needed.\n        self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,\n                                                   lambda _: True)\n        self.__token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})\n\n        try:\n            with open(path, encoding=\"utf-8\") as fin:\n                data_f = json.load(fin)\n        except json.decoder.JSONDecodeError:\n            # read the old jsonlines format if necessary\n            with open(path, encoding=\"utf-8\") as fin:\n                text = \"[\" + \",\\n\".join(fin) + \"]\"\n            data_f = json.loads(text)\n        logger.info(\"Processing %d docs from %s...\", len(data_f), path)\n        self.__raw = data_f\n        self.__avg_span = sum(len(doc[\"head2span\"]) for doc in self.__raw) / len(self.__raw)\n        self.__out = []\n        for doc in self.__raw:\n            doc[\"span_clusters\"] = [[tuple(mention) for mention in cluster]\n                                for cluster in doc[\"span_clusters\"]]\n            word2subword = []\n            subwords = []\n            word_id = []\n            for i, word in enumerate(doc[\"cased_words\"]):\n                tokenized = self.tokenizer.tokenize(word)\n                if len(tokenized) == 0:\n                    word = \"_\"\n                    doc[\"cased_words\"][i] = word\n                    tokenized = 
self.tokenizer.tokenize(word)\n                    assert len(tokenized) > 0\n                tokenized_word = self.__token_map.get(word, tokenized)\n                tokenized_word = list(filter(self.__filter_func, tokenized_word))\n                word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))\n                subwords.extend(tokenized_word)\n                word_id.extend([i] * len(tokenized_word))\n\n            doc[\"word2subword\"] = word2subword\n            doc[\"subwords\"] = subwords\n            doc[\"word_id\"] = word_id\n\n            self.__out.append(doc)\n        logger.info(\"Loaded %d docs from %s.\", len(data_f), path)\n\n    @property\n    def avg_span(self):\n        return self.__avg_span\n\n    def __getitem__(self, x):\n        return self.__out[x]\n\n    def __len__(self):\n        return len(self.__out)\n"
  },
  {
    "path": "stanza/models/coref/loss.py",
    "content": "\"\"\" Describes the loss function used to train the model, which is a weighted\nsum of NLML and BCE losses. \"\"\"\n\nimport torch\n\n\nclass CorefLoss(torch.nn.Module):\n    \"\"\" See the rationale for using NLML in Lee et al. 2017\n    https://www.aclweb.org/anthology/D17-1018/\n    The added weighted summand of BCE helps the model learn even after\n    converging on the NLML task. \"\"\"\n\n    def __init__(self, bce_weight: float):\n        assert 0 <= bce_weight <= 1\n        super().__init__()\n        self._bce_module = torch.nn.BCEWithLogitsLoss()\n        self._bce_weight = bce_weight\n\n    def forward(self,    # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch\n                input_: torch.Tensor,\n                target: torch.Tensor) -> torch.Tensor:\n        \"\"\" Returns a weighted sum of two losses as a torch.Tensor \"\"\"\n        return (self._nlml(input_, target)\n                + self._bce(input_, target) * self._bce_weight)\n\n    def _bce(self,\n             input_: torch.Tensor,\n             target: torch.Tensor) -> torch.Tensor:\n        \"\"\" For numerical stability, clamps the input before passing it to BCE.\n        \"\"\"\n        return self._bce_module(torch.clamp(input_, min=-50, max=50), target)\n\n    @staticmethod\n    def _nlml(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        gold = torch.logsumexp(input_ + torch.log(target), dim=1)\n        input_ = torch.logsumexp(input_, dim=1)\n        return (input_ - gold).mean()\n"
  },
  {
    "path": "stanza/models/coref/model.py",
    "content": "\"\"\" see __init__.py \"\"\"\n\nfrom datetime import datetime\nimport dataclasses\nimport json\nimport logging\nimport os\nimport random\nimport re\nfrom typing import Any, Dict, List, Optional, Set, Tuple\n\nimport numpy as np      # type: ignore\ntry:\n    import tomllib\nexcept ImportError:\n    import tomli as tomllib\nimport torch\nimport transformers     # type: ignore\n\nfrom pickle import UnpicklingError\nimport warnings\n\nfrom stanza.utils.get_tqdm import get_tqdm   # type: ignore\ntqdm = get_tqdm()\n\nfrom stanza.models.coref import bert, conll, utils\nfrom stanza.models.coref.anaphoricity_scorer import AnaphoricityScorer\nfrom stanza.models.coref.cluster_checker import ClusterChecker\nfrom stanza.models.coref.config import Config\nfrom stanza.models.coref.const import CorefResult, Doc\nfrom stanza.models.coref.loss import CorefLoss\nfrom stanza.models.coref.pairwise_encoder import PairwiseEncoder\nfrom stanza.models.coref.rough_scorer import RoughScorer\nfrom stanza.models.coref.span_predictor import SpanPredictor\nfrom stanza.models.coref.utils import GraphNode\nfrom stanza.models.coref.utils import sigmoid_focal_loss\nfrom stanza.models.coref.word_encoder import WordEncoder\nfrom stanza.models.coref.dataset import CorefDataset\nfrom stanza.models.coref.tokenizer_customization import *\n\nfrom stanza.models.common.bert_embedding import load_tokenizer\nfrom stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache\nfrom stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper\n\nimport torch.nn as nn\n\nlogger = logging.getLogger('stanza')\n\nclass CorefModel:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Combines all coref modules together to find coreferent spans.\n\n    Attributes:\n        config (coref.config.Config): the model's configuration,\n            see config.toml for the details\n        epochs_trained (int): number of epochs the model has 
been trained for\n        trainable (Dict[str, torch.nn.Module]): trainable submodules with their\n            names used as keys\n        training (bool): used to toggle train/eval modes\n\n    Submodules (in the order of their usage in the pipeline):\n        tokenizer (transformers.AutoTokenizer)\n        bert (transformers.AutoModel)\n        we (WordEncoder)\n        rough_scorer (RoughScorer)\n        pw (PairwiseEncoder)\n        a_scorer (AnaphoricityScorer)\n        sp (SpanPredictor)\n    \"\"\"\n    def __init__(self,\n                 epochs_trained: int = 0,\n                 build_optimizers: bool = True,\n                 config: Optional[dict] = None,\n                 foundation_cache=None):\n        \"\"\"\n        A newly created model is set to evaluation mode.\n\n        Args:\n            epochs_trained (int): the number of epochs finished\n                (useful for warm start)\n            build_optimizers (bool): whether to build the optimizers\n                and schedulers\n            config: the model's configuration (required)\n            foundation_cache: optional cache used when loading the transformer\n        \"\"\"\n        if config is None:\n            raise ValueError(\"Cannot create a model without a config\")\n        self.config = config\n        self.epochs_trained = epochs_trained\n        self._docs: Dict[str, List[Doc]] = {}\n        self._build_model(foundation_cache)\n\n        self.optimizers = {}\n        self.schedulers = {}\n\n        if build_optimizers:\n            self._build_optimizers()\n        self._set_training(False)\n\n        # final coreference resolution score\n        self._coref_criterion = CorefLoss(self.config.bce_loss_weight)\n        # score simply for the top-k choices out of the rough scorer\n        self._rough_criterion = CorefLoss(0)\n        # exact span matches\n        self._span_criterion = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n\n    @property\n    def training(self) -> bool:\n        \"\"\" Represents whether the model is in the training mode \"\"\"\n        
return self._training\n\n    @training.setter\n    def training(self, new_value: bool):\n        if self._training is new_value:\n            return\n        self._set_training(new_value)\n\n    # ========================================================== Public methods\n\n    @torch.no_grad()\n    def evaluate(self,\n                 data_split: str = \"dev\",\n                 word_level_conll: bool = False, \n                 eval_lang: Optional[str] = None\n                 ) -> Tuple[float, Tuple[float, float, float]]:\n        \"\"\" Evaluates the model on the data split provided.\n\n        Args:\n            data_split (str): one of 'dev'/'test'/'train'\n            word_level_conll (bool): if True, outputs conll files on word-level\n            eval_lang (str): which language to evaluate\n\n        Returns:\n            mean loss\n            span-level LEA: f1, precision, recall\n        \"\"\"\n        self.training = False\n        w_checker = ClusterChecker()\n        s_checker = ClusterChecker()\n        try:\n            data_split_data = f\"{data_split}_data\"\n            data_path = self.config.__dict__[data_split_data]\n            docs = self._get_docs(data_path)\n        except FileNotFoundError as e:\n            raise FileNotFoundError(\"Unable to find data split %s at file %s\" % (data_split_data, data_path)) from e\n        running_loss = 0.0\n        s_correct = 0\n        s_total = 0\n        z_correct = 0\n        z_total = 0\n\n        with conll.open_(self.config, self.epochs_trained, data_split) \\\n                as (gold_f, pred_f):\n            pbar = tqdm(docs, unit=\"docs\", ncols=0)\n            for doc in pbar:\n                if eval_lang and doc.get(\"lang\", \"\") != eval_lang:\n                    # skip that document, only used for ablation where we only\n                    # want to test evaluation on one language\n                    continue\n \n                res = self.run(doc, True)\n                # measure zero 
prediction accuracy\n                zero_preds = (res.zero_scores > 0).view(-1).to(device=res.zero_scores.device)\n                is_zero = doc.get(\"is_zero\")\n                if is_zero is None:\n                    zero_targets = torch.zeros_like(zero_preds, device=zero_preds.device)\n                else:\n                    zero_targets = torch.tensor(is_zero, device=zero_preds.device)\n                z_correct += (zero_preds == zero_targets).sum().item()\n                z_total += zero_targets.numel()\n\n                if (res.coref_y.argmax(dim=1) == 1).all():\n                    logger.warning(f\"EVAL: skipping document with no corefs...\")\n                    continue\n\n                running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item()\n                if res.word_clusters is None or res.span_clusters is None:\n                    logger.warning(f\"EVAL: skipping document with no clusters...\")\n                    continue\n\n                if res.span_y:\n                    pred_starts = res.span_scores[:, :, 0].argmax(dim=1)\n                    pred_ends = res.span_scores[:, :, 1].argmax(dim=1)\n                    s_correct += ((res.span_y[0] == pred_starts) * (res.span_y[1] == pred_ends)).sum().item()\n                    s_total += len(pred_starts)\n\n\n                if word_level_conll:\n                    raise NotImplementedError(\"We now write Conll-U conforming to UDCoref, which means that the span_clusters annotations will have headword info. 
word_level option is meaningless.\")\n                else:\n                    conll.write_conll(doc, doc[\"span_clusters\"], doc[\"word_clusters\"], gold_f)\n                    conll.write_conll(doc, res.span_clusters, res.word_clusters, pred_f)\n\n                w_checker.add_predictions(doc[\"word_clusters\"], res.word_clusters)\n                w_lea = w_checker.total_lea\n\n                s_checker.add_predictions(doc[\"span_clusters\"], res.span_clusters)\n                s_lea = s_checker.total_lea\n\n                del res\n\n                pbar.set_description(\n                    f\"{data_split}:\"\n                    f\" | WL: \"\n                    f\" loss: {running_loss / (pbar.n + 1):<.5f},\"\n                    f\" f1: {w_lea[0]:.5f},\"\n                    f\" p: {w_lea[1]:.5f},\"\n                    f\" r: {w_lea[2]:<.5f}\"\n                    f\" | SL: \"\n                    f\" sa: {s_correct / s_total:<.5f},\"\n                    f\" f1: {s_lea[0]:.5f},\"\n                    f\" p: {s_lea[1]:.5f},\"\n                    f\" r: {s_lea[2]:<.5f}\"\n                    f\" | ZA: {z_correct / z_total:<.5f}\"\n                )\n            logger.info(f\"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}\")\n            logger.info(f\"Zero prediction accuracy: {z_correct / z_total:.5f}\")\n\n        return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff)\n\n    def load_weights(self,\n                     path: Optional[str] = None,\n                     ignore: Optional[Set[str]] = None,\n                     map_location: Optional[str] = None,\n                     noexception: bool = False) -> None:\n        \"\"\"\n        Loads pretrained weights of modules saved in a file located at path.\n        If path is None, the last saved model with current configuration\n        in save_dir is loaded.\n        Assumes files are named like 
{configuration}_(e{epoch}_{time})*.pt.\n        \"\"\"\n        if path is None:\n            # pattern = rf\"{self.config.save_name}_\\(e(\\d+)_[^()]*\\).*\\.pt\"\n            # tries to load the last checkpoint in the same dir\n            pattern = rf\"{self.config.save_name}.*?\\.checkpoint\\.pt\"\n            files = []\n            os.makedirs(self.config.save_dir, exist_ok=True)\n            for f in os.listdir(self.config.save_dir):\n                match_obj = re.match(pattern, f)\n                if match_obj:\n                    files.append(f)\n            if not files:\n                if noexception:\n                    # logging methods do not accept a flush keyword argument\n                    logger.debug(\"No weights have been loaded\")\n                    return\n                raise OSError(f\"No weights found in {self.config.save_dir}!\")\n            path = sorted(files)[-1]\n            path = os.path.join(self.config.save_dir, path)\n\n        if map_location is None:\n            map_location = self.config.device\n        logger.debug(f\"Loading from {path}...\")\n        try:\n            state_dicts = torch.load(path, map_location=map_location, weights_only=True)\n        except UnpicklingError:\n            state_dicts = torch.load(path, map_location=map_location, weights_only=False)\n            warnings.warn(\"The saved coref model has an old format using Config instead of the Config mapped to dict to store weights.  This version of Stanza can support reading both the new and the old formats.  Future versions will only allow loading with weights_only=True.  
Please resave the coref model using this version ASAP.\")\n        self.epochs_trained = state_dicts.pop(\"epochs_trained\", 0)\n        # just ignore a config in the model, since we should already have one\n        # TODO: some config elements may be fixed parameters of the model,\n        # such as the dimensions of the head,\n        # so we would want to use the ones from the config even if the\n        # user created a weird shaped model\n        config = state_dicts.pop(\"config\", {})\n        self.load_state_dicts(state_dicts, ignore)\n\n    def load_state_dicts(self,\n                         state_dicts: dict,\n                         ignore: Optional[Set[str]] = None):\n        \"\"\"\n        Process the dictionaries from the save file\n\n        Loads the weights into the tensors of this model\n        May also have optimizer and/or schedule state\n        \"\"\"\n        for key, state_dict in state_dicts.items():\n            logger.debug(\"Loading state: %s\", key)\n            if not ignore or key not in ignore:\n                if key.endswith(\"_optimizer\"):\n                    self.optimizers[key].load_state_dict(state_dict)\n                elif key.endswith(\"_scheduler\"):\n                    self.schedulers[key].load_state_dict(state_dict)\n                elif key == \"bert_lora\":\n                    assert self.config.lora, \"Unable to load state dict of LoRA model into model initialized without LoRA!\"\n                    self.bert = load_peft_wrapper(self.bert, state_dict, vars(self.config), logger, self.peft_name)\n                else:\n                    self.trainable[key].load_state_dict(state_dict, strict=False)\n                logger.debug(f\"Loaded {key}\")\n        if self.config.log_norms:\n            self.log_norms()\n\n    def build_doc(self, doc: dict) -> dict:\n        filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,\n                                            lambda _: True)\n        token_map = 
TOKENIZER_MAPS.get(self.config.bert_model, {})\n\n        word2subword = []\n        subwords = []\n        word_id = []\n        for i, word in enumerate(doc[\"cased_words\"]):\n            tokenized_word = (token_map[word]\n                              if word in token_map\n                              else self.tokenizer.tokenize(word))\n            tokenized_word = list(filter(filter_func, tokenized_word))\n            word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))\n            subwords.extend(tokenized_word)\n            word_id.extend([i] * len(tokenized_word))\n        doc[\"word2subword\"] = word2subword\n        doc[\"subwords\"] = subwords\n        doc[\"word_id\"] = word_id\n\n        doc[\"head2span\"] = []\n        if \"speaker\" not in doc:\n            doc[\"speaker\"] = [\"_\" for _ in doc[\"cased_words\"]]\n        doc[\"word_clusters\"] = []\n        doc[\"span_clusters\"] = []\n\n        return doc\n\n\n    @staticmethod\n    def load_model(path: str,\n                   map_location: str = \"cpu\",\n                   ignore: Optional[Set[str]] = None,\n                   config_update: Optional[dict] = None,\n                   foundation_cache = None):\n        if not path:\n            raise FileNotFoundError(\"coref model got an invalid path |%s|\" % path)\n        if not os.path.exists(path):\n            raise FileNotFoundError(\"coref model file %s not found\" % path)\n        try:\n            state_dicts = torch.load(path, map_location=map_location, weights_only=True)\n        except UnpicklingError:\n            state_dicts = torch.load(path, map_location=map_location, weights_only=False)\n            warnings.warn(\"The saved coref model has an old format using Config instead of the Config mapped to dict to store weights.  This version of Stanza can support reading both the new and the old formats.  Future versions will only allow loading with weights_only=True.  
Please resave the coref model using this version ASAP.\")\n        epochs_trained = state_dicts.pop(\"epochs_trained\", 0)\n        config = state_dicts.pop('config', None)\n        if config is None:\n            raise ValueError(\"Cannot load this format model without config in the dicts\")\n        if 'max_train_len' not in config:\n            # TODO: this is to keep old models working.\n            # Can get rid of it if those models are rebuilt\n            config['max_train_len'] = 5000\n        if isinstance(config, dict):\n            config = Config(**config)\n        if config_update:\n            for key, value in config_update.items():\n                setattr(config, key, value)\n        model = CorefModel(config=config, build_optimizers=False,\n                           epochs_trained=epochs_trained, foundation_cache=foundation_cache)\n        model.load_state_dicts(state_dicts, ignore)\n        return model\n\n\n    def run(self,  # pylint: disable=too-many-locals\n            doc: Doc,\n            use_gold_spans_for_zeros = False\n            ) -> CorefResult:\n        \"\"\"\n        This is a massive method, but it made sense to me to not split it into\n        several ones to let one see the data flow.\n\n        Args:\n            doc (Doc): a dictionary with the document data.\n\n        Returns:\n            CorefResult (see const.py)\n        \"\"\"\n        # Encode words with bert\n        # words           [n_words, span_emb]\n        # cluster_ids     [n_words]\n        words, cluster_ids = self.we(doc, self._bertify(doc))\n\n        # Obtain bilinear scores and leave only top-k antecedents for each word\n        # top_rough_scores  [n_words, n_ants]\n        # top_indices       [n_words, n_ants]\n        top_rough_scores, top_indices, rough_scores = self.rough_scorer(words)\n\n        # Get pairwise features [n_words, n_ants, n_pw_features]\n        pw = self.pw(top_indices, doc)\n\n        batch_size = 
self.config.a_scoring_batch_size\n        a_scores_lst: List[torch.Tensor] = []\n\n        for i in range(0, len(words), batch_size):\n            pw_batch = pw[i:i + batch_size]\n            words_batch = words[i:i + batch_size]\n            top_indices_batch = top_indices[i:i + batch_size]\n            top_rough_scores_batch = top_rough_scores[i:i + batch_size]\n\n            # a_scores_batch    [batch_size, n_ants]\n            a_scores_batch = self.a_scorer(\n                top_mentions=words[top_indices_batch], mentions_batch=words_batch,\n                pw_batch=pw_batch, top_rough_scores_batch=top_rough_scores_batch\n            )\n            a_scores_lst.append(a_scores_batch)\n\n        res = CorefResult()\n\n        # coref_scores  [n_spans, n_ants]\n        res.coref_scores = torch.cat(a_scores_lst, dim=0)\n\n        res.coref_y = self._get_ground_truth(\n            cluster_ids, top_indices, (top_rough_scores > float(\"-inf\")),\n            self.config.clusters_starts_are_singletons,\n            self.config.singletons\n        )\n\n        res.word_clusters = self._clusterize(\n            doc, res.coref_scores, top_indices,\n            self.config.singletons\n        )\n\n        res.span_scores, res.span_y = self.sp.get_training_data(doc, words)\n\n        if not self.training:\n            res.span_clusters = self.sp.predict(doc, words, res.word_clusters)\n\n        if not self.training and not use_gold_spans_for_zeros:\n            zero_words = words[[word_id\n                                for cluster in res.word_clusters\n                                for word_id in cluster]]\n        else:\n            zero_words = words[[i[0] for i in sorted(doc[\"head2span\"])]]\n        res.zero_scores = self.zeros_predictor(zero_words)\n\n        return res\n\n    def save_weights(self, save_path=None, save_optimizers=True):\n        \"\"\" Saves trainable models as state dicts. 
\"\"\"\n        to_save: List[Tuple[str, Any]] = \\\n            [(key, value) for key, value in self.trainable.items()\n             if (self.config.bert_finetune and not self.config.lora) or key != \"bert\"]\n        if save_optimizers:\n            to_save.extend(self.optimizers.items())\n            to_save.extend(self.schedulers.items())\n\n        time = datetime.strftime(datetime.now(), \"%Y.%m.%d_%H.%M\")\n        if save_path is None:\n            save_path = os.path.join(self.config.save_dir,\n                                     f\"{self.config.save_name}\"\n                                     f\"_e{self.epochs_trained}_{time}.pt\")\n        savedict = {name: module.state_dict() for name, module in to_save}\n        if self.config.lora:\n            # so that this dependency remains optional\n            from peft import get_peft_model_state_dict\n            savedict[\"bert_lora\"] = get_peft_model_state_dict(self.bert, adapter_name=\"coref\")\n        savedict[\"epochs_trained\"] = self.epochs_trained  # type: ignore\n        # save as a dictionary because the weights_only=True load option\n        # doesn't allow for arbitrary @dataclass configs\n        savedict[\"config\"] = dataclasses.asdict(self.config)\n        save_dir = os.path.split(save_path)[0]\n        if save_dir:\n            os.makedirs(save_dir, exist_ok=True)\n        torch.save(savedict, save_path)\n\n    def log_norms(self):\n        lines = [\"NORMS FOR MODEL PARAMTERS\"]\n        for t_name, trainable in self.trainable.items():\n            for name, param in trainable.named_parameters():\n                if param.requires_grad:\n                    lines.append(\"  %s: %s %.6g  (%d)\" % (t_name, name, torch.norm(param).item(), param.numel()))\n        logger.info(\"\\n\".join(lines))\n\n\n    def train(self, log=False):\n        \"\"\"\n        Trains all the trainable blocks in the model using the config provided.\n\n        log: whether or not to log using wandb\n        
skip_lang: str if we want to skip training this language (used for ablation)\n        \"\"\"\n\n        if log:\n            import wandb\n            wandb.watch((self.bert, self.pw,\n                         self.a_scorer, self.we,\n                         self.rough_scorer, self.sp))\n\n        docs = self._get_docs(self.config.train_data)\n        docs_ids = list(range(len(docs)))\n        avg_spans = docs.avg_span\n\n        # for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros\n        training_has_zeros = any('is_zero' in doc for doc in docs)\n        if not training_has_zeros:\n            logger.info(\"No zeros found in the dataset.  The zeros predictor will be set to 0\")\n            if self.epochs_trained == 0:\n                # new model, set it to always predict not-zero\n                self.disable_zeros_predictor()\n\n        attenuated_languages = set()\n        if self.config.lang_lr_attenuation:\n            attenuated_languages = self.config.lang_lr_attenuation.split(\",\")\n            logger.info(\"Attenuating LR for the following languages: %s\", attenuated_languages)\n\n        lr_scaled_languages = dict()\n        if self.config.lang_lr_weights:\n            scaled_languages = self.config.lang_lr_weights.split(\",\")\n            for piece in scaled_languages:\n                pieces = piece.split(\"=\")\n                lr_scaled_languages[pieces[0]] = float(pieces[1])\n            logger.info(\"Scaling LR for the following languages: %s\", lr_scaled_languages)\n\n        best_f1 = None\n        best_epoch = self.epochs_trained\n        for epoch in range(self.epochs_trained, self.config.train_epochs):\n            self.training = True\n            if self.config.log_norms:\n                self.log_norms()\n            running_c_loss = 0.0\n            running_s_loss = 0.0\n            running_z_loss = 0.0\n            random.shuffle(docs_ids)\n            pbar = tqdm(docs_ids, unit=\"docs\", ncols=0)\n 
           for doc_indx, doc_id in enumerate(pbar):\n                doc = docs[doc_id]\n\n                # skip very long documents during training time\n                if len(doc[\"subwords\"]) > self.config.max_train_len:\n                    continue\n\n                for optim in self.optimizers.values():\n                    optim.zero_grad()\n\n                res = self.run(doc)\n\n                if res.zero_scores.size(0) == 0 or not training_has_zeros:\n                    z_loss = 0.0 # since there are no corefs\n                else:\n                    is_zero = doc.get(\"is_zero\")\n                    if is_zero is None:\n                        is_zero = torch.zeros_like(res.zero_scores.squeeze(-1), device=res.zero_scores.device, dtype=torch.float)\n                    else:\n                        is_zero = torch.tensor(is_zero).to(res.zero_scores.device).float()\n                    z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), is_zero, reduction=\"mean\")\n\n                c_loss = self._coref_criterion(res.coref_scores, res.coref_y)\n\n                if res.span_y:\n                    s_loss = (self._span_criterion(res.span_scores[:, :, 0], res.span_y[0])\n                              + self._span_criterion(res.span_scores[:, :, 1], res.span_y[1])) / avg_spans / 2\n                else:\n                    s_loss = torch.zeros_like(c_loss)\n\n                lr_scale = lr_scaled_languages.get(doc.get(\"lang\"), 1.0)\n                if doc.get(\"lang\") in attenuated_languages:\n                    lr_scale = lr_scale / max(epoch, 1.0)\n                c_loss = c_loss * lr_scale\n                s_loss = s_loss * lr_scale\n                z_loss = z_loss * lr_scale\n\n                (c_loss + s_loss + z_loss).backward()\n\n                running_c_loss += c_loss.item()\n                running_s_loss += s_loss.item()\n                if res.zero_scores.size(0) != 0 and training_has_zeros:\n                    
running_z_loss += z_loss.item()\n\n                # log every 100 docs\n                if log and doc_indx % 100 == 0:\n                    logged = {\n                        'train_c_loss': c_loss.item(),\n                        'train_s_loss': s_loss.item(),\n                    }\n                    if res.zero_scores.size(0) != 0 and training_has_zeros:\n                        logged['train_z_loss'] = z_loss.item()\n                    wandb.log(logged)\n\n                del c_loss, s_loss, z_loss, res\n\n                for optim in self.optimizers.values():\n                    optim.step()\n                for scheduler in self.schedulers.values():\n                    scheduler.step()\n\n                pbar.set_description(\n                    f\"Epoch {epoch + 1}:\"\n                    f\" {doc['document_id']:26}\"\n                    f\" c_loss: {running_c_loss / (pbar.n + 1):<.5f}\"\n                    f\" s_loss: {running_s_loss / (pbar.n + 1):<.5f}\"\n                    f\" z_loss: {running_z_loss / (pbar.n + 1):<.5f}\"\n                )\n\n            self.epochs_trained += 1\n            scores = self.evaluate()\n            prev_best_f1 = best_f1\n            if log:\n                wandb.log({'dev_score': scores[1]})\n                wandb.log({'dev_bakeoff': scores[-1]})\n\n            if best_f1 is None or scores[1] > best_f1:\n                best_epoch = epoch\n                if best_f1 is None:\n                    logger.info(\"Saving new best model: F1 %.4f\", scores[1])\n                else:\n                    logger.info(\"Saving new best model: F1 %.4f > %.4f\", scores[1], best_f1)\n                best_f1 = scores[1]\n                if self.config.save_name.endswith(\".pt\"):\n                    save_path = os.path.join(self.config.save_dir,\n                                             f\"{self.config.save_name}\")\n                else:\n                    save_path = os.path.join(self.config.save_dir,\n           
                                  f\"{self.config.save_name}.pt\")\n                self.save_weights(save_path, save_optimizers=False)\n            if self.config.save_each_checkpoint:\n                self.save_weights()\n            else:\n                if self.config.save_name.endswith(\".pt\"):\n                    checkpoint_path = os.path.join(self.config.save_dir,\n                                                   f\"{self.config.save_name[:-3]}.checkpoint.pt\")\n                else:\n                    checkpoint_path = os.path.join(self.config.save_dir,\n                                                   f\"{self.config.save_name}.checkpoint.pt\")\n                self.save_weights(checkpoint_path)\n            if prev_best_f1 is not None and prev_best_f1 != best_f1:\n                logger.info(\"Epoch %d finished.\\nSentence F1 %.5f p %.5f r %.5f\\nBest F1 %.5f\\nPrevious best F1 %.5f\", self.epochs_trained, scores[1], scores[2], scores[3], best_f1, prev_best_f1)\n            else:\n                logger.info(\"Epoch %d finished.\\nSentence F1 %.5f p %.5f r %.5f\\nBest F1 %.5f\", self.epochs_trained, scores[1], scores[2], scores[3], best_f1)\n            if self.config.plateau_epochs > 0 and best_epoch + self.config.plateau_epochs < epoch:\n                logger.info(\"Have plateaued for too long (%d epochs).  
Will terminate training\", self.config.plateau_epochs)\n                break\n\n    # ========================================================= Private methods\n\n    def _bertify(self, doc: Doc) -> torch.Tensor:\n        all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer)\n\n        # we index the batches n at a time to prevent oom\n        result = []\n        for i in range(0, all_batches.shape[0], 1024):\n            subwords_batches = all_batches[i:i+1024]\n\n            special_tokens = np.array([self.tokenizer.cls_token_id,\n                                       self.tokenizer.sep_token_id,\n                                       self.tokenizer.pad_token_id,\n                                       self.tokenizer.eos_token_id])\n            subword_mask = ~(np.isin(subwords_batches, special_tokens))\n\n            subwords_batches_tensor = torch.tensor(subwords_batches,\n                                                device=self.config.device,\n                                                dtype=torch.long)\n            subword_mask_tensor = torch.tensor(subword_mask,\n                                            device=self.config.device)\n\n            # Obtain bert output for selected batches only\n            attention_mask = (subwords_batches != self.tokenizer.pad_token_id)\n            if \"t5\" in self.config.bert_model:\n                out = self.bert.encoder(\n                        input_ids=subwords_batches_tensor,\n                        attention_mask=torch.tensor(\n                            attention_mask, device=self.config.device))\n            else:\n                out = self.bert(\n                        subwords_batches_tensor,\n                        attention_mask=torch.tensor(\n                            attention_mask, device=self.config.device))\n\n            out = out['last_hidden_state']\n            # [n_subwords, bert_emb]\n            result.append(out[subword_mask_tensor])\n\n        # stack 
returns and return\n        return torch.cat(result)\n\n    def _build_model(self, foundation_cache):\n        if hasattr(self.config, 'lora') and self.config.lora:\n            self.bert, self.tokenizer, peft_name = load_bert_with_peft(self.config.bert_model, \"coref\", foundation_cache)\n            # vars() converts a dataclass to a dict, used for being able to index things like args[\"lora_*\"]\n            self.bert = build_peft_wrapper(self.bert, vars(self.config), logger, adapter_name=peft_name)\n            self.peft_name = peft_name\n        else:\n            if self.config.bert_finetune:\n                logger.debug(\"Coref model requested a finetuned transformer; we are not using the foundation model cache, to avoid accidentally leaking the finetuning weights elsewhere.\")\n                foundation_cache = NoTransformerFoundationCache(foundation_cache)\n            self.bert, self.tokenizer = load_bert(self.config.bert_model, foundation_cache)\n\n        base_bert_name = self.config.bert_model.split(\"/\")[-1]\n        tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {})\n        if tokenizer_kwargs:\n            logger.debug(f\"Using tokenizer kwargs: {tokenizer_kwargs}\")\n        # we just downloaded the tokenizer, so for simplicity, we don't make another request to HF\n        self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True)\n\n        if self.config.bert_finetune or (hasattr(self.config, 'lora') and self.config.lora):\n            self.bert = self.bert.train()\n\n        self.bert = self.bert.to(self.config.device)\n        self.pw = PairwiseEncoder(self.config).to(self.config.device)\n\n        bert_emb = self.bert.config.hidden_size\n        pair_emb = bert_emb * 3 + self.pw.shape\n\n        # pylint: disable=line-too-long\n        self.a_scorer = AnaphoricityScorer(pair_emb, self.config).to(self.config.device)\n        self.we = WordEncoder(bert_emb, 
self.config).to(self.config.device)\n        self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)\n        self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)\n        self.zeros_predictor = nn.Sequential(\n            nn.Linear(bert_emb, bert_emb),\n            nn.ReLU(),\n            nn.Linear(bert_emb, 1)\n        ).to(self.config.device)\n        if not hasattr(self.config, 'use_zeros') or not self.config.use_zeros:\n            self.disable_zeros_predictor()\n\n        self.trainable: Dict[str, torch.nn.Module] = {\n            \"bert\": self.bert, \"we\": self.we,\n            \"rough_scorer\": self.rough_scorer,\n            \"pw\": self.pw, \"a_scorer\": self.a_scorer,\n            \"sp\": self.sp, \"zeros_predictor\": self.zeros_predictor\n        }\n\n    def disable_zeros_predictor(self):\n        nn.init.zeros_(self.zeros_predictor[-1].weight)\n        nn.init.zeros_(self.zeros_predictor[-1].bias)\n\n    def _build_optimizers(self):\n        n_docs = len(self._get_docs(self.config.train_data))\n        self.optimizers: Dict[str, torch.optim.Optimizer] = {}\n        self.schedulers: Dict[str, torch.optim.lr_scheduler.LRScheduler] = {}\n\n        if not getattr(self.config, 'lora', False):\n            for param in self.bert.parameters():\n                param.requires_grad = self.config.bert_finetune\n\n        if self.config.bert_finetune:\n            logger.debug(\"Making bert optimizer with LR of %f\", self.config.bert_learning_rate)\n            self.optimizers[\"bert_optimizer\"] = torch.optim.Adam(\n                self.bert.parameters(), lr=self.config.bert_learning_rate\n            )\n            start_finetuning = int(n_docs * self.config.bert_finetune_begin_epoch)\n            if start_finetuning > 0:\n                logger.info(\"Will begin finetuning transformer at iteration %d\", start_finetuning)\n            zero_scheduler = 
torch.optim.lr_scheduler.ConstantLR(self.optimizers[\"bert_optimizer\"], factor=0, total_iters=start_finetuning)\n            warmup_scheduler = transformers.get_linear_schedule_with_warmup(\n                self.optimizers[\"bert_optimizer\"],\n                start_finetuning, n_docs * self.config.train_epochs - start_finetuning)\n            self.schedulers[\"bert_scheduler\"] = torch.optim.lr_scheduler.SequentialLR(\n                self.optimizers[\"bert_optimizer\"],\n                schedulers=[zero_scheduler, warmup_scheduler],\n                milestones=[start_finetuning])\n\n        # Must ensure the same ordering of parameters between launches\n        modules = sorted((key, value) for key, value in self.trainable.items()\n                         if key != \"bert\")\n        params = []\n        for _, module in modules:\n            for param in module.parameters():\n                param.requires_grad = True\n                params.append(param)\n\n        self.optimizers[\"general_optimizer\"] = torch.optim.Adam(\n            params, lr=self.config.learning_rate)\n        self.schedulers[\"general_scheduler\"] = \\\n            transformers.get_linear_schedule_with_warmup(\n                self.optimizers[\"general_optimizer\"],\n                0, n_docs * self.config.train_epochs\n            )\n\n    def _clusterize(self, doc: Doc, scores: torch.Tensor, top_indices: torch.Tensor,\n                    singletons: bool = True):\n        if singletons:\n            antecedents = scores[:,1:].argmax(dim=1) - 1\n            # set the dummy values to -1, so that they are not coref to themselves\n            is_start = (scores[:, :2].argmax(dim=1) == 0)\n        else:\n            antecedents = scores.argmax(dim=1) - 1\n\n        not_dummy = antecedents >= 0\n        coref_span_heads = torch.arange(0, len(scores), device=not_dummy.device)[not_dummy]\n        antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]\n\n        nodes = 
[GraphNode(i) for i in range(len(doc[\"cased_words\"]))]\n        for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):\n            nodes[i].link(nodes[j])\n            assert nodes[i] is not nodes[j]\n\n        visited = {}\n\n        clusters = []\n        for node in nodes:\n            if len(node.links) > 0 and not node.visited:\n                cluster = []\n                stack = [node]\n                while stack:\n                    current_node = stack.pop()\n                    current_node.visited = True\n                    cluster.append(current_node.id)\n                    stack.extend(link for link in current_node.links if not link.visited)\n                assert len(cluster) > 1\n                for i in cluster:\n                    visited[i] = True\n                clusters.append(sorted(cluster))\n\n        if singletons:\n            # go through the is_start nodes; if no clusters contain that node\n            # i.e. visited[i] == False, we add it as a singleton\n            for indx, i in enumerate(is_start):\n                if i and not visited.get(indx, False):\n                    clusters.append([indx])\n\n        return sorted(clusters)\n\n    def _get_docs(self, path: str) -> List[Doc]:\n        if path not in self._docs:\n            self._docs[path] = CorefDataset(path, self.config, self.tokenizer)\n        return self._docs[path]\n\n    @staticmethod\n    def _get_ground_truth(cluster_ids: torch.Tensor,\n                          top_indices: torch.Tensor,\n                          valid_pair_map: torch.Tensor,\n                          cluster_starts: bool,\n                          singletons:bool = True) -> torch.Tensor:\n        \"\"\"\n        Args:\n            cluster_ids: tensor of shape [n_words], containing cluster indices\n                for each word. 
Non-gold words have cluster id of zero.\n            top_indices: tensor of shape [n_words, n_ants],\n                indices of antecedents of each word\n            valid_pair_map: boolean tensor of shape [n_words, n_ants],\n                whether for pair at [i, j] (i-th word and j-th word)\n                j < i is True\n\n        Returns:\n            tensor of shape [n_words, n_ants + 1] (dummy added),\n                containing 1 at position [i, j] if i-th and j-th words corefer.\n        \"\"\"\n        y = cluster_ids[top_indices] * valid_pair_map  # [n_words, n_ants]\n        y[y == 0] = -1                                 # -1 for non-gold words\n        y = utils.add_dummy(y)                         # [n_words, n_cands + 1]\n\n        if singletons:\n            if not cluster_starts:\n                unique, counts = cluster_ids.unique(return_counts=True)\n                singleton_clusters = unique[(counts == 1) & (unique != 0)]\n                first_corefs = [(cluster_ids == i).nonzero().flatten()[0] for i in singleton_clusters]\n                if len(first_corefs) > 0:\n                    first_coref = torch.stack(first_corefs)\n                else:\n                    first_coref = torch.tensor([]).to(cluster_ids.device).long()\n            else:\n                # I apologize for this abuse of everything that's good about PyTorch.\n                # in essence, this line finds the INDEX of FIRST OCCURENCE of each NON-ZERO value\n                # from cluster_ids. 
We need this information because we use it to mark the\n                # special \"is-start-of-ref\" marker used to detect singletons.\n                first_coref = (cluster_ids ==\n                            cluster_ids.unique().sort().values[1:].unsqueeze(1)\n                            ).float().topk(k=1, dim=1).indices.squeeze()\n        y = (y == cluster_ids.unsqueeze(1))            # True if coreferent\n        # For all rows with no gold antecedents, set the dummy to True\n        y[y.sum(dim=1) == 0, 0] = True\n\n        if singletons:\n            # add another dummy for first coref\n            y = utils.add_dummy(y)                         # [n_words, n_cands + 2]\n            # for every row that is a first coref, set its first-coref dummy to True\n            # and reset the non-coref dummy to False\n            y[first_coref, 0] = True\n            y[first_coref, 1] = False\n        return y.to(torch.float)\n\n    @staticmethod\n    def _load_config(config_path: str,\n                     section: str) -> Config:\n        with open(config_path, \"rb\") as fin:\n            config = tomllib.load(fin)\n        default_section = config[\"DEFAULT\"]\n        current_section = config[section]\n        unknown_keys = (set(current_section.keys())\n                        - set(default_section.keys()))\n        if unknown_keys:\n            raise ValueError(f\"Unexpected config keys: {unknown_keys}\")\n        return Config(section, **{**default_section, **current_section})\n\n    def _set_training(self, value: bool):\n        self._training = value\n        for module in self.trainable.values():\n            module.train(self._training)\n"
  },
  {
    "path": "stanza/models/coref/pairwise_encoder.py",
    "content": "\"\"\" Describes PairwiseEncodes, that transforms pairwise features, such as\ndistance between the mentions, same/different speaker into feature embeddings\n\"\"\"\nfrom typing import List\n\nimport torch\n\nfrom stanza.models.coref.config import Config\nfrom stanza.models.coref.const import Doc\n\n\nclass PairwiseEncoder(torch.nn.Module):\n    \"\"\" A Pytorch module to obtain feature embeddings for pairwise features\n\n    Usage:\n        encoder = PairwiseEncoder(config)\n        pairwise_features = encoder(pair_indices, doc)\n    \"\"\"\n    def __init__(self, config: Config):\n        super().__init__()\n        emb_size = config.embedding_size\n\n        self.genre2int = {g: gi for gi, g in enumerate([\"bc\", \"bn\", \"mz\", \"nw\",\n                                                        \"pt\", \"tc\", \"wb\"])}\n        self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size)\n\n        # each position corresponds to a bucket:\n        #   [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8),\n        #    (8, 16), (16, 32), (32, 64), (64, float(\"inf\"))]\n        self.distance_emb = torch.nn.Embedding(9, emb_size)\n\n        # two possibilities: same vs different speaker\n        self.speaker_emb = torch.nn.Embedding(2, emb_size)\n\n        self.dropout = torch.nn.Dropout(config.dropout_rate)\n\n        self.__full_pw = config.full_pairwise\n\n        if self.__full_pw:\n            self.shape = emb_size * 2  # distance, speaker\n        else:\n            self.shape = emb_size # distance only\n\n    @property\n    def device(self) -> torch.device:\n        \"\"\" A workaround to get current device (which is assumed to be the\n        device of the first parameter of one of the submodules) \"\"\"\n        return next(self.genre_emb.parameters()).device\n\n    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch\n                top_indices: torch.Tensor,\n                doc: Doc) -> torch.Tensor:\n        
word_ids = torch.arange(0, len(doc[\"cased_words\"]), device=self.device)\n\n        # bucketing the distance (see __init__())\n        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]\n                    ).clamp_min_(min=1)\n        log_distance = distance.to(torch.float).log2().floor_()\n        log_distance = log_distance.clamp_max_(max=6).to(torch.long)\n        distance = torch.where(distance < 5, distance - 1, log_distance + 2)\n        distance = self.distance_emb(distance)\n\n        if not self.__full_pw:\n            return self.dropout(distance)\n\n        # calculate speaker embeddings\n        speaker_map = torch.tensor(self._speaker_map(doc), device=self.device)\n        same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1))\n        same_speaker = self.speaker_emb(same_speaker.to(torch.long))\n\n        return self.dropout(torch.cat((same_speaker, distance), dim=2))\n\n    @staticmethod\n    def _speaker_map(doc: Doc) -> List[int]:\n        \"\"\"\n        Returns a list where the i-th element is the speaker id of the i-th word.\n        \"\"\"\n        # if speaker is not found in the doc, simply use \"speaker#1\" for all the words\n        # and embed them using the same ID\n        speakers = doc.get(\"speaker\", [\"speaker#1\"] * len(doc[\"cased_words\"]))\n\n        # speaker string -> speaker id\n        str2int = {s: i for i, s in enumerate(set(speakers))}\n\n        # word id -> speaker id\n        return [str2int[s] for s in speakers]\n"
  },
  {
    "path": "stanza/models/coref/predict.py",
    "content": "import argparse\n\nimport json\nimport torch\nfrom tqdm import tqdm\n\nfrom stanza.models.coref.model import CorefModel\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\"experiment\")\n    argparser.add_argument(\"input_file\")\n    argparser.add_argument(\"output_file\")\n    argparser.add_argument(\"--config-file\", default=\"config.toml\")\n    argparser.add_argument(\"--batch-size\", type=int,\n                           help=\"Adjust to override the config value if you're\"\n                                \" experiencing out-of-memory issues\")\n    argparser.add_argument(\"--weights\",\n                           help=\"Path to file with weights to load.\"\n                                \" If not supplied, in the latest\"\n                                \" weights of the experiment will be loaded;\"\n                                \" if there aren't any, an error is raised.\")\n    args = argparser.parse_args()\n\n    model = CorefModel.load_model(path=args.weights,\n                                  map_location=\"cpu\",\n                                  ignore={\"bert_optimizer\", \"general_optimizer\",\n                                          \"bert_scheduler\", \"general_scheduler\"})\n    if args.batch_size:\n        model.config.a_scoring_batch_size = args.batch_size\n    model.training = False\n\n    try:\n        with open(args.input_file, encoding=\"utf-8\") as fin:\n            input_data = json.load(fin)\n    except json.decoder.JSONDecodeError:\n        # read the old jsonlines format if necessary\n        with open(args.input_file, encoding=\"utf-8\") as fin:\n            text = \"[\" + \",\\n\".join(fin) + \"]\"\n        input_data = json.loads(text)\n    docs = [model.build_doc(doc) for doc in input_data]\n\n    with torch.no_grad():\n        for doc in tqdm(docs, unit=\"docs\"):\n            result = model.run(doc)\n            doc[\"span_clusters\"] = 
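result.span_clusters\n            doc[\"word_clusters\"] = result.word_clusters\n\n            for key in (\"word2subword\", \"subwords\", \"word_id\", \"head2span\"):\n                del doc[key]\n\n    # write one JSON document per line so that the jsonlines fallback above can read the output back\n    with open(args.output_file, mode=\"w\", encoding=\"utf-8\") as fout:\n        for doc in docs:\n            json.dump(doc, fout)\n            fout.write(\"\\n\")\n"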
  },
  {
    "path": "stanza/models/coref/rough_scorer.py",
    "content": "\"\"\" Describes RoughScorer, a simple bilinear module to calculate rough\nanaphoricity scores.\n\"\"\"\n\nfrom typing import Tuple\n\nimport torch\n\nfrom stanza.models.coref.config import Config\n\n\nclass RoughScorer(torch.nn.Module):\n    \"\"\"\n    Is needed to give a roughly estimate of the anaphoricity of two candidates,\n    only top scoring candidates are considered on later steps to reduce\n    computational complexity.\n    \"\"\"\n    def __init__(self, features: int, config: Config):\n        super().__init__()\n        self.dropout = torch.nn.Dropout(config.dropout_rate)\n        self.bilinear = torch.nn.Linear(features, features)\n\n        self.k = config.rough_k\n\n    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch\n                mentions: torch.Tensor,\n                ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns rough anaphoricity scores for candidates, which consist of\n        the bilinear output of the current model summed with mention scores.\n        \"\"\"\n        # [n_mentions, n_mentions]\n        pair_mask = torch.arange(mentions.shape[0])\n        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)\n        pair_mask = torch.log((pair_mask > 0).to(torch.float))\n        pair_mask = pair_mask.to(mentions.device)\n\n        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)\n\n        rough_scores = pair_mask + bilinear_scores\n\n        return self._prune(rough_scores)\n\n    def _prune(self,\n               rough_scores: torch.Tensor\n               ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Selects top-k rough antecedent scores for each mention.\n\n        Args:\n            rough_scores: tensor of shape [n_mentions, n_mentions], containing\n                rough antecedent scores of each mention-antecedent pair.\n\n        Returns:\n            FloatTensor of shape [n_mentions, k], top rough 
scores\n            LongTensor of shape [n_mentions, k], top indices\n            FloatTensor of shape [n_mentions, n_mentions], the unpruned\n                rough scores\n        \"\"\"\n        top_scores, indices = torch.topk(rough_scores,\n                                         k=min(self.k, len(rough_scores)),\n                                         dim=1, sorted=False)\n        return top_scores, indices, rough_scores\n"
  },
  {
    "path": "stanza/models/coref/span_predictor.py",
    "content": "\"\"\" Describes SpanPredictor which aims to predict spans by taking as input\nhead word and context embeddings.\n\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nfrom stanza.models.coref.const import Doc, Span\nimport torch\n\n\nclass SpanPredictor(torch.nn.Module):\n    def __init__(self, input_size: int, distance_emb_size: int):\n        super().__init__()\n        self.ffnn = torch.nn.Sequential(\n            torch.nn.Linear(input_size * 2 + 64, input_size),\n            torch.nn.ReLU(),\n            torch.nn.Dropout(0.3),\n            torch.nn.Linear(input_size, 256),\n            torch.nn.ReLU(),\n            torch.nn.Dropout(0.3),\n            torch.nn.Linear(256, 64),\n        )\n        self.conv = torch.nn.Sequential(\n            torch.nn.Conv1d(64, 4, 3, 1, 1),\n            torch.nn.Conv1d(4, 2, 3, 1, 1)\n        )\n        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far\n\n    @property\n    def device(self) -> torch.device:\n        \"\"\" A workaround to get current device (which is assumed to be the\n        device of the first parameter of one of the submodules) \"\"\"\n        return next(self.ffnn.parameters()).device\n\n    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch\n                doc: Doc,\n                words: torch.Tensor,\n                heads_ids: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Calculates span start/end scores of words for each span head in\n        heads_ids\n\n        Args:\n            doc (Doc): the document data\n            words (torch.Tensor): contextual embeddings for each word in the\n                document, [n_words, emb_size]\n            heads_ids (torch.Tensor): word indices of span heads\n\n        Returns:\n            torch.Tensor: span start/end scores, [n_heads, n_words, 2]\n        \"\"\"\n        # Obtain distance embedding indices, [n_heads, n_words]\n        relative_positions = 
(heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))\n        emb_ids = relative_positions + 63               # make all valid distances positive\n        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127  # \"too_far\"\n\n        # Obtain \"same sentence\" boolean mask, [n_heads, n_words]\n        sent_id = torch.tensor(doc[\"sent_id\"], device=words.device)\n        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))\n\n        # To save memory, only pass candidates from one sentence for each head\n        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb\n        # for each candidate among the words in the same sentence as span_head\n        # [n_heads, input_size * 2 + distance_emb_size]\n        rows, cols = same_sent.nonzero(as_tuple=True)\n        pair_matrix = torch.cat((\n            words[heads_ids[rows]],\n            words[cols],\n            self.emb(emb_ids[rows, cols]),\n        ), dim=1)\n\n        lengths = same_sent.sum(dim=1)\n        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)\n        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]\n\n        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]\n        # This is necessary to allow the convolution layer to look at several\n        # word scores\n        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)\n        padded_pairs[padding_mask] = pair_matrix\n\n        res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]\n        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]\n\n        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)\n        scores[rows, cols] = res[padding_mask]\n\n        # Make sure that start <= head <= end during inference\n        if not self.training:\n            
valid_starts = torch.log((relative_positions >= 0).to(torch.float))\n            valid_ends = torch.log((relative_positions <= 0).to(torch.float))\n            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)\n            return scores + valid_positions\n        return scores\n\n    def get_training_data(self,\n                          doc: Doc,\n                          words: torch.Tensor\n                          ) -> Tuple[Optional[torch.Tensor],\n                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:\n        \"\"\" Returns span starts/ends for gold mentions in the document. \"\"\"\n        head2span = sorted(doc[\"head2span\"])\n        if not head2span:\n            return None, None\n        heads, starts, ends = zip(*head2span)\n        heads = torch.tensor(heads, device=self.device)\n        starts = torch.tensor(starts, device=self.device)\n        ends = torch.tensor(ends, device=self.device) - 1\n        return self(doc, words, heads), (starts, ends)\n\n    def predict(self,\n                doc: Doc,\n                words: torch.Tensor,\n                clusters: List[List[int]]) -> List[List[Span]]:\n        \"\"\"\n        Predicts span clusters based on the word clusters.\n\n        Args:\n            doc (Doc): the document data\n            words (torch.Tensor): [n_words, emb_size] matrix containing\n                embeddings for each of the words in the text\n            clusters (List[List[int]]): a list of clusters where each cluster\n                is a list of word indices\n\n        Returns:\n            List[List[Span]]: span clusters\n        \"\"\"\n        if not clusters:\n            return []\n\n        heads_ids = torch.tensor(\n            sorted(i for cluster in clusters for i in cluster),\n            device=self.device\n        )\n\n        scores = self(doc, words, heads_ids)\n        starts = scores[:, :, 0].argmax(dim=1).tolist()\n        ends = (scores[:, :, 1].argmax(dim=1) 
+ 1).tolist()\n\n        head2span = {\n            head: (start, end)\n            for head, start, end in zip(heads_ids.tolist(), starts, ends)\n        }\n\n        return [[head2span[head] for head in cluster]\n                for cluster in clusters]\n"
  },
  {
    "path": "stanza/models/coref/tokenizer_customization.py",
    "content": "\"\"\" This file defines functions used to modify the default behaviour\nof transformers.AutoTokenizer. These changes are necessary, because some\ntokenizers are meant to be used with raw text, while the OntoNotes documents\nhave already been split into words.\nAll the functions are used in coref_model.CorefModel._get_docs. \"\"\"\n\n\n# Filters out unwanted tokens produced by the tokenizer\nTOKENIZER_FILTERS = {\n    \"albert-xxlarge-v2\": (lambda token: token != \"▁\"),  # U+2581, not just \"_\"\n    \"albert-large-v2\": (lambda token: token != \"▁\"),\n}\n\n# Maps some words to tokens directly, without a tokenizer\nTOKENIZER_MAPS = {\n    \"roberta-large\": {\".\": [\".\"], \",\": [\",\"], \"!\": [\"!\"], \"?\": [\"?\"],\n                      \":\":[\":\"], \";\":[\";\"], \"'s\": [\"'s\"]}\n}\n"
  },
  {
    "path": "stanza/models/coref/utils.py",
    "content": "\"\"\" Contains functions not directly linked to coreference resolution \"\"\"\n\nfrom typing import List, Set\n\nimport torch\nimport torch.nn.functional as F\n\nfrom stanza.models.coref.const import EPSILON\n\n\nclass GraphNode:\n    def __init__(self, node_id: int):\n        self.id = node_id\n        self.links: Set[GraphNode] = set()\n        self.visited = False\n\n    def link(self, another: \"GraphNode\"):\n        self.links.add(another)\n        another.links.add(self)\n\n    def __repr__(self) -> str:\n        return str(self.id)\n\n\ndef add_dummy(tensor: torch.Tensor, eps: bool = False):\n    \"\"\" Prepends zeros (or a very small value if eps is True)\n    to the first (not zeroth) dimension of tensor.\n    \"\"\"\n    kwargs = dict(device=tensor.device, dtype=tensor.dtype)\n    shape: List[int] = list(tensor.shape)\n    shape[1] = 1\n    if not eps:\n        dummy = torch.zeros(shape, **kwargs)          # type: ignore\n    else:\n        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore\n    return torch.cat((dummy, tensor), dim=1)\n\ndef sigmoid_focal_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        alpha: float = 0.25,\n        gamma: float = 2,\n        reduction: str = \"none\",\n) -> torch.Tensor:\n    \"\"\"\n    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.\n\n    Args:\n        inputs (Tensor): A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary\n                classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        alpha (float): Weighting factor in range [0, 1] to balance\n                positive vs negative examples or -1 for ignore. 
Default: ``0.25``.\n        gamma (float): Exponent of the modulating factor (1 - p_t) to\n                balance easy vs hard examples. Default: ``2``.\n        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``\n                ``'none'``: No reduction will be applied to the output.\n                ``'mean'``: The output will be averaged.\n                ``'sum'``: The output will be summed. Default: ``'none'``.\n    Returns:\n        Loss tensor with the reduction option applied.\n    \"\"\"\n    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py\n\n    if not (0 <= alpha <= 1) and alpha != -1:\n        raise ValueError(f\"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.\")\n\n    p = torch.sigmoid(inputs)\n    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    p_t = p * targets + (1 - p) * (1 - targets)\n    loss = ce_loss * ((1 - p_t) ** gamma)\n\n    if alpha >= 0:\n        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)\n        loss = alpha_t * loss\n\n    # Check reduction option and return loss accordingly\n    if reduction == \"none\":\n        pass\n    elif reduction == \"mean\":\n        loss = loss.mean()\n    elif reduction == \"sum\":\n        loss = loss.sum()\n    else:\n        raise ValueError(\n            f\"Invalid Value for arg 'reduction': '{reduction}' \\n Supported reduction modes: 'none', 'mean', 'sum'\"\n        )\n    return loss\n"
  },
  {
    "path": "stanza/models/coref/word_encoder.py",
    "content": "\"\"\" Describes WordEncoder. Extracts mention vectors from bert-encoded text.\n\"\"\"\n\nfrom typing import Tuple\n\nimport torch\n\nfrom stanza.models.coref.config import Config\nfrom stanza.models.coref.const import Doc\n\n\nclass WordEncoder(torch.nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\" Receives bert contextual embeddings of a text, extracts all the\n    possible mentions in that text. \"\"\"\n\n    def __init__(self, features: int, config: Config):\n        \"\"\"\n        Args:\n            features (int): the number of featues in the input embeddings\n            config (Config): the configuration of the current session\n        \"\"\"\n        super().__init__()\n        self.attn = torch.nn.Linear(in_features=features, out_features=1)\n        self.dropout = torch.nn.Dropout(config.dropout_rate)\n\n    @property\n    def device(self) -> torch.device:\n        \"\"\" A workaround to get current device (which is assumed to be the\n        device of the first parameter of one of the submodules) \"\"\"\n        return next(self.attn.parameters()).device\n\n    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch\n                doc: Doc,\n                x: torch.Tensor,\n                ) -> Tuple[torch.Tensor, ...]:\n        \"\"\"\n        Extracts word representations from text.\n\n        Args:\n            doc: the document data\n            x: a tensor containing bert output, shape (n_subtokens, bert_dim)\n\n        Returns:\n            words: a Tensor of shape [n_words, mention_emb];\n                mention representations\n            cluster_ids: tensor of shape [n_words], containing cluster indices\n                for each word. 
Non-coreferent words have cluster id of zero.\n        \"\"\"\n        word_boundaries = torch.tensor(doc[\"word2subword\"], device=self.device)\n        starts = word_boundaries[:, 0]\n        ends = word_boundaries[:, 1]\n\n        # [n_mentions, features]\n        words = self._attn_scores(x, starts, ends).mm(x)\n\n        words = self.dropout(words)\n\n        return (words, self._cluster_ids(doc))\n\n    def _attn_scores(self,\n                     bert_out: torch.Tensor,\n                     word_starts: torch.Tensor,\n                     word_ends: torch.Tensor) -> torch.Tensor:\n        \"\"\" Calculates attention scores for each of the mentions.\n\n        Args:\n            bert_out (torch.Tensor): [n_subwords, bert_emb], bert embeddings\n                for each of the subwords in the document\n            word_starts (torch.Tensor): [n_words], start indices of words\n            word_ends (torch.Tensor): [n_words], end indices of words\n\n        Returns:\n            torch.Tensor: [description]\n        \"\"\"\n        n_subtokens = len(bert_out)\n        n_words = len(word_starts)\n\n        # [n_mentions, n_subtokens]\n        # with 0 at positions belonging to the words and -inf elsewhere\n        attn_mask = torch.arange(0, n_subtokens, device=self.device).expand((n_words, n_subtokens))\n        attn_mask = ((attn_mask >= word_starts.unsqueeze(1))\n                     * (attn_mask < word_ends.unsqueeze(1)))\n\n        # if first row all False, set col 0 to True\n        # otherwise, set the row to be the previous row?\n        word_lengths = torch.sum(attn_mask, dim=1)\n        if torch.any(word_lengths == 0):\n            raise ValueError(\"Found a blank word in training data!  
This will break everything, starting with the attention masks, as some rows of the scoring table will be set to entirely -inf and then softmax to NaN.\")\n\n        attn_mask = torch.log(attn_mask.to(torch.float))\n\n        attn_scores = self.attn(bert_out).T  # [1, n_subtokens]\n        attn_scores = attn_scores.expand((n_words, n_subtokens))\n        attn_scores = attn_mask + attn_scores\n        del attn_mask\n        return torch.softmax(attn_scores, dim=1)  # [n_words, n_subtokens]\n\n    def _cluster_ids(self, doc: Doc) -> torch.Tensor:\n        \"\"\"\n        Args:\n            doc: document information\n\n        Returns:\n            torch.Tensor of shape [n_words], containing cluster indices for\n                each word. Non-coreferent words have cluster id of zero.\n        \"\"\"\n        word2cluster = {word_i: i\n                        for i, cluster in enumerate(doc[\"word_clusters\"], start=1)\n                        for word_i in cluster}\n\n        return torch.tensor(\n            [word2cluster.get(word_i, 0)\n             for word_i in range(len(doc[\"cased_words\"]))],\n            device=self.device\n        )\n"
  },
  {
    "path": "stanza/models/depparse/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/depparse/data.py",
    "content": "import random\nimport logging\nimport torch\n\nfrom stanza.models.common.bert_embedding import filter_data, needs_length_filter\nfrom stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all\nfrom stanza.models.common.utils import DEFAULT_WORD_CUTOFF, simplify_punct\nfrom stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab\nfrom stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab\nfrom stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory\nfrom stanza.models.common.doc import *\n\nlogger = logging.getLogger('stanza')\n\ndef data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):\n    \"\"\"\n    Given a list of lists, where the first element of each sublist\n    represents the sentence, group the sentences into batches.\n\n    During training mode (not eval_mode) the sentences are sorted by\n    length with a bit of random shuffling.  
During eval mode, the\n    sentences are sorted by length if sort_during_eval is true.\n\n    Refactored from the data structure in case other models could use\n    it and for ease of testing.\n\n    Returns (batches, original_order), where original_order is None\n    when in train mode or when unsorted and represents the original\n    location of each sentence in the sort\n    \"\"\"\n    res = []\n\n    if not eval_mode:\n        # sort sentences (roughly) by length for better memory utilization\n        data = sorted(data, key = lambda x: len(x[0]), reverse=random.random() > .5)\n        data_orig_idx = None\n    elif sort_during_eval:\n        (data, ), data_orig_idx = sort_all([data], [len(x[0]) for x in data])\n    else:\n        data_orig_idx = None\n\n    current = []\n    currentlen = 0\n    for x in data:\n        if min_length_to_batch_separately is not None and len(x[0]) > min_length_to_batch_separately:\n            if currentlen > 0:\n                res.append(current)\n                current = []\n                currentlen = 0\n            res.append([x])\n        else:\n            if len(x[0]) + currentlen > batch_size and currentlen > 0:\n                res.append(current)\n                current = []\n                currentlen = 0\n            current.append(x)\n            currentlen += len(x[0])\n\n    if currentlen > 0:\n        res.append(current)\n\n    return res, data_orig_idx\n\n\nclass DataLoader:\n\n    def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None):\n        self.batch_size = batch_size\n        self.min_length_to_batch_separately=min_length_to_batch_separately\n        self.args = args\n        self.eval = evaluation\n        self.shuffled = not self.eval\n        self.sort_during_eval = sort_during_eval\n        self.doc = doc\n        data = self.load_doc(doc)\n\n        # handle vocab\n        if vocab is 
None:\n            self.vocab = self.init_vocab(data)\n        else:\n            self.vocab = vocab\n        \n        # filter out the long sentences if bert is used\n        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):\n            data = filter_data(self.args['bert_model'], data, bert_tokenizer)\n\n        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None\n        self.pretrain_vocab = None\n        if pretrain is not None and args['pretrain']:\n            self.pretrain_vocab = pretrain.vocab\n\n        # filter and sample data\n        if args.get('sample_train', 1.0) < 1.0 and not self.eval:\n            keep = int(args['sample_train'] * len(data))\n            data = random.sample(data, keep)\n            logger.debug(\"Subsample training set with rate {:g}\".format(args['sample_train']))\n\n        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)\n        # shuffle for training\n        if self.shuffled:\n            random.shuffle(data)\n        self.num_examples = len(data)\n\n        # chunk into batches\n        self.data = self.chunk_batches(data)\n        logger.debug(\"{} batches created.\".format(len(self.data)))\n\n    def init_vocab(self, data):\n        assert self.eval == False # for eval vocab must exist\n        cutoff = self.args['word_cutoff'] if self.args.get('word_cutoff') is not None else DEFAULT_WORD_CUTOFF\n        charvocab = CharVocab(data, self.args['shorthand'])\n        wordvocab = WordVocab(data, self.args['shorthand'], cutoff=cutoff, lower=True)\n        uposvocab = WordVocab(data, self.args['shorthand'], idx=1)\n        xposvocab = xpos_vocab_factory(data, self.args['shorthand'])\n        featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3)\n        lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=cutoff, idx=4, lower=True)\n        deprelvocab = WordVocab(data, self.args['shorthand'], idx=6)\n        
vocab = MultiVocab({'char': charvocab,\n                            'word': wordvocab,\n                            'upos': uposvocab,\n                            'xpos': xposvocab,\n                            'feats': featsvocab,\n                            'lemma': lemmavocab,\n                            'deprel': deprelvocab})\n        return vocab\n\n    def preprocess(self, data, vocab, pretrain_vocab, args):\n        processed = []\n        xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID]\n        feats_replacement = [[ROOT_ID] * len(vocab['feats'])]\n        for sent in data:\n            processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])]\n            processed_sent += [[[ROOT_ID]] + [vocab['char'].map([x for x in w[0]]) for w in sent]]\n            processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])]\n            processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])]\n            processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])]\n            if pretrain_vocab is not None:\n                # always use lowercase lookup in pretrained vocab\n                processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() for w in sent])]\n            else:\n                processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)]\n            processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])]\n            processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]]\n            processed_sent += [vocab['deprel'].map([w[6] for w in sent])]\n            processed_sent.append([w[0] for w in sent])\n            processed.append(processed_sent)\n        return processed\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, key):\n        \"\"\" Get a batch with index. 
\"\"\"\n        if not isinstance(key, int):\n            raise TypeError\n        if key < 0 or key >= len(self.data):\n            raise IndexError\n        batch = self.data[key]\n        batch_size = len(batch)\n        batch = list(zip(*batch))\n        assert len(batch) == 10\n\n        # sort sentences by lens for easy RNN operations\n        lens = [len(x) for x in batch[0]]\n        batch, orig_idx = sort_all(batch, lens)\n\n        # sort words by lens for easy char-RNN operations\n        batch_words = [w for sent in batch[1] for w in sent]\n        word_lens = [len(x) for x in batch_words]\n        batch_words, word_orig_idx = sort_all([batch_words], word_lens)\n        batch_words = batch_words[0]\n        word_lens = [len(x) for x in batch_words]\n\n        # convert to tensors\n        words = batch[0]\n        words = get_long_tensor(words, batch_size)\n        words_mask = torch.eq(words, PAD_ID)\n        wordchars = get_long_tensor(batch_words, len(word_lens))\n        wordchars_mask = torch.eq(wordchars, PAD_ID)\n\n        upos = get_long_tensor(batch[2], batch_size)\n        xpos = get_long_tensor(batch[3], batch_size)\n        ufeats = get_long_tensor(batch[4], batch_size)\n        pretrained = get_long_tensor(batch[5], batch_size)\n        sentlens = [len(x) for x in batch[0]]\n        lemma = get_long_tensor(batch[6], batch_size)\n        head = get_long_tensor(batch[7], batch_size)\n        deprel = get_long_tensor(batch[8], batch_size)\n        text = batch[9]\n        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text\n\n    def load_doc(self, doc):\n        data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True)\n        data = self.resolve_none(data)\n        data = simplify_punct(data)\n        return data\n\n    def resolve_none(self, data):\n        # replace None to '_'\n        for sent_idx in 
range(len(data)):\n            for tok_idx in range(len(data[sent_idx])):\n                for feat_idx in range(len(data[sent_idx][tok_idx])):\n                    if data[sent_idx][tok_idx][feat_idx] is None:\n                        data[sent_idx][tok_idx][feat_idx] = '_'\n        return data\n\n    def __iter__(self):\n        for i in range(self.__len__()):\n            yield self.__getitem__(i)\n\n    def set_batch_size(self, batch_size):\n        self.batch_size = batch_size\n\n    def reshuffle(self):\n        data = [y for x in self.data for y in x]\n        self.data = self.chunk_batches(data)\n        random.shuffle(self.data)\n\n    def chunk_batches(self, data):\n        batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size,\n                                                 eval_mode=self.eval, sort_during_eval=self.sort_during_eval,\n                                                 min_length_to_batch_separately=self.min_length_to_batch_separately)\n        # data_orig_idx might be None at train time, since we don't anticipate unsorting\n        self.data_orig_idx = data_orig_idx\n        return batches\n\ndef to_int(string, ignore_error=False):\n    try:\n        res = int(string)\n    except ValueError as err:\n        if ignore_error:\n            return 0\n        else:\n            raise err\n    return res\n\n"
  },
  {
    "path": "stanza/models/depparse/model.py",
    "content": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence\n\nfrom stanza.models.common.bert_embedding import extract_bert_embeddings\nfrom stanza.models.common.biaffine import DeepBiaffineScorer\nfrom stanza.models.common.foundation_cache import load_charlm\nfrom stanza.models.common.hlstm import HighwayLSTM\nfrom stanza.models.common.dropout import WordDropout\nfrom stanza.models.common.utils import attach_bert_model\nfrom stanza.models.common.vocab import CompositeVocab\nfrom stanza.models.common.char_model import CharacterModel, CharacterLanguageModel\nfrom stanza.models.common import utils\n\nlogger = logging.getLogger('stanza')\n\nclass Parser(nn.Module):\n    def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):\n        super().__init__()\n\n        self.vocab = vocab\n        self.args = args\n        self.unsaved_modules = []\n\n        # input layers\n        input_size = 0\n        if self.args['word_emb_dim'] > 0:\n            # frequent word embeddings\n            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)\n            self.lemma_emb = nn.Embedding(len(vocab['lemma']), self.args['word_emb_dim'], padding_idx=0)\n            input_size += self.args['word_emb_dim'] * 2\n\n        if self.args['tag_emb_dim'] > 0:\n            if self.args.get('use_upos', True):\n                self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)\n            if self.args.get('use_xpos', True):\n                if not isinstance(vocab['xpos'], CompositeVocab):\n                    self.xpos_emb = nn.Embedding(len(vocab['xpos']), self.args['tag_emb_dim'], padding_idx=0)\n                else:\n          
          self.xpos_emb = nn.ModuleList()\n\n                    for l in vocab['xpos'].lens():\n                        self.xpos_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))\n            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):\n                input_size += self.args['tag_emb_dim']\n\n            if self.args.get('use_ufeats', True):\n                self.ufeats_emb = nn.ModuleList()\n\n                for l in vocab['feats'].lens():\n                    self.ufeats_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))\n\n                input_size += self.args['tag_emb_dim']\n\n        if self.args['char'] and self.args['char_emb_dim'] > 0:\n            if self.args.get('charlm', None):\n                if self.args['charlm_forward_file'] is None or not os.path.exists(self.args['charlm_forward_file']):\n                    raise FileNotFoundError('Could not find forward character model: {}  Please specify with --charlm_forward_file'.format(self.args['charlm_forward_file']))\n                if self.args['charlm_backward_file'] is None or not os.path.exists(self.args['charlm_backward_file']):\n                    raise FileNotFoundError('Could not find backward character model: {}  Please specify with --charlm_backward_file'.format(self.args['charlm_backward_file']))\n                logger.debug(\"Depparse model loading charmodels: %s and %s\", self.args['charlm_forward_file'], self.args['charlm_backward_file'])\n                self.add_unsaved_module('charmodel_forward', load_charlm(self.args['charlm_forward_file'], foundation_cache=foundation_cache))\n                self.add_unsaved_module('charmodel_backward', load_charlm(self.args['charlm_backward_file'], foundation_cache=foundation_cache))\n                input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()\n            else:\n                self.charmodel = CharacterModel(self.args, vocab)\n        
        self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)\n                input_size += self.args['transformed_dim']\n\n        self.peft_name = peft_name\n        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)\n        if self.args.get('bert_model', None):\n            # TODO: refactor bert_hidden_layers between the different models\n            if self.args.get('bert_hidden_layers', False):\n                # The average will be offset by 1/N so that the default zeros\n                # represents an average of the N layers\n                self.bert_layer_mix = nn.Linear(self.args['bert_hidden_layers'], 1, bias=False)\n                nn.init.zeros_(self.bert_layer_mix.weight)\n            else:\n                # an average of layers 2, 3, 4 will be used\n                # (for historic reasons)\n                self.bert_layer_mix = None\n            input_size += self.bert_model.config.hidden_size\n\n        if self.args['pretrain']:\n            # pretrained embeddings, by default this won't be saved into model file\n            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))\n            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)\n            input_size += self.args['transformed_dim']\n\n        # recurrent layers\n        self.parserlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)\n        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))\n        self.parserlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))\n        self.parserlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, 
self.args['hidden_dim']))\n\n        # dropout\n        self.drop = nn.Dropout(self.args['dropout'])\n        self.worddrop = WordDropout(self.args['word_dropout'])\n\n        # classifiers\n        # args.get to preserve old models, including models other people might have created\n        if self.args.get('use_arc_embedding'):\n            logger.debug(\"Using arc embedding enhancement\")\n            self.arc_embedding = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], self.args['deep_biaff_output_dim'], pairwise=True, dropout=self.args['dropout'])\n            self.unlabeled_linear = nn.Sequential(self.drop,\n                                                  nn.Linear(self.args['deep_biaff_output_dim'], 1))\n            self.deprel_linear = nn.Sequential(self.drop,\n                                               nn.Linear(self.args['deep_biaff_output_dim'], 2 * self.args['deep_biaff_output_dim']),\n                                               nn.ReLU(),\n                                               self.drop,\n                                               nn.Linear(self.args['deep_biaff_output_dim'] * 2, len(vocab['deprel'])))\n        else:\n            logger.debug(\"Not using arc embedding enhancement\")\n            self.unlabeled = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])\n            self.deprel = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], len(vocab['deprel']), pairwise=True, dropout=self.args['dropout'])\n        if self.args['linearization']:\n            self.linearization = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])\n        if self.args['distance']:\n            self.distance = 
DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])\n\n        # criterion\n        self.crit = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') # ignore padding\n\n\n    def add_unsaved_module(self, name, module):\n        self.unsaved_modules += [name]\n        setattr(self, name, module)\n\n    def log_norms(self):\n        utils.log_norms(self)\n\n    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):\n        def pack(x):\n            return pack_padded_sequence(x, sentlens, batch_first=True)\n\n        inputs = []\n        if self.args['pretrain']:\n            pretrained_emb = self.pretrained_emb(pretrained)\n            pretrained_emb = self.trans_pretrained(pretrained_emb)\n            pretrained_emb = pack(pretrained_emb)\n            inputs += [pretrained_emb]\n\n        #def pad(x):\n        #    return pad_packed_sequence(PackedSequence(x, pretrained_emb.batch_sizes), batch_first=True)[0]\n\n        if self.args['word_emb_dim'] > 0:\n            word_emb = self.word_emb(word)\n            word_emb = pack(word_emb)\n            lemma_emb = self.lemma_emb(lemma)\n            lemma_emb = pack(lemma_emb)\n            inputs += [word_emb, lemma_emb]\n\n        if self.args['tag_emb_dim'] > 0:\n            if self.args.get('use_upos', True):\n                pos_emb = self.upos_emb(upos)\n            else:\n                pos_emb = 0\n\n            if self.args.get('use_xpos', True):\n                if isinstance(self.vocab['xpos'], CompositeVocab):\n                    for i in range(len(self.vocab['xpos'])):\n                        pos_emb += self.xpos_emb[i](xpos[:, :, i])\n                else:\n                    pos_emb += self.xpos_emb(xpos)\n\n            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):\n     
           pos_emb = pack(pos_emb)\n                inputs += [pos_emb]\n\n            if self.args.get('use_ufeats', True):\n                feats_emb = 0\n                for i in range(len(self.vocab['feats'])):\n                    feats_emb += self.ufeats_emb[i](ufeats[:, :, i])\n                feats_emb = pack(feats_emb)\n\n                inputs += [feats_emb]\n\n        if self.args['char'] and self.args['char_emb_dim'] > 0:\n            if self.args.get('charlm', None):\n                # \\n is to add a somewhat neutral \"word\" for the ROOT\n                charlm_text = [[\"\\n\"] + x for x in text]\n                all_forward_chars = self.charmodel_forward.build_char_representation(charlm_text)\n                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))\n                all_backward_chars = self.charmodel_backward.build_char_representation(charlm_text)\n                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))\n                inputs += [all_forward_chars, all_backward_chars]\n            else:\n                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)\n                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)\n                inputs += [char_reps]\n\n        if self.bert_model is not None:\n            device = next(self.parameters()).device\n            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=True,\n                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,\n                                                     detach=not self.args.get('bert_finetune', False) or not self.training,\n                                                     peft_name=self.peft_name)\n            if self.bert_layer_mix is not None:\n      
          # use a linear layer to weighted average the embedding dynamically\n                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]\n\n            # we are using the first endpoint from the transformer as the \"word\" for ROOT\n            processed_bert = [x[:-1, :] for x in processed_bert]\n            processed_bert = pad_sequence(processed_bert, batch_first=True)\n            inputs += [pack(processed_bert)]\n\n        lstm_inputs = torch.cat([x.data for x in inputs], 1)\n\n        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)\n        lstm_inputs = self.drop(lstm_inputs)\n\n        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)\n\n        lstm_outputs, _ = self.parserlstm(lstm_inputs, sentlens, hx=(self.parserlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.parserlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))\n        lstm_outputs, _ = pad_packed_sequence(lstm_outputs, batch_first=True)\n\n        if self.args.get('use_arc_embedding'):\n            arc_scores = self.arc_embedding(self.drop(lstm_outputs), self.drop(lstm_outputs))\n            unlabeled_scores = self.unlabeled_linear(arc_scores).squeeze(3)\n            deprel_scores = self.deprel_linear(arc_scores).squeeze(3)\n        else:\n            unlabeled_scores = self.unlabeled(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)\n            deprel_scores = self.deprel(self.drop(lstm_outputs), self.drop(lstm_outputs))\n\n        #goldmask = head.new_zeros(*head.size(), head.size(-1)+1, dtype=torch.uint8)\n        #goldmask.scatter_(2, head.unsqueeze(2), 1)\n\n        if self.args['linearization'] or self.args['distance']:\n            head_offset = torch.arange(word.size(1), device=head.device).view(1, 1, -1).expand(word.size(0), -1, -1) - 
torch.arange(word.size(1), device=head.device).view(1, -1, 1).expand(word.size(0), -1, -1)\n\n        if self.args['linearization']:\n            lin_scores = self.linearization(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)\n            unlabeled_scores += F.logsigmoid(lin_scores * torch.sign(head_offset).float()).detach()\n\n        if self.args['distance']:\n            dist_scores = self.distance(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)\n            dist_pred = 1 + F.softplus(dist_scores)\n            dist_target = torch.abs(head_offset)\n            dist_kld = -torch.log((dist_target.float() - dist_pred)**2/2 + 1)\n            unlabeled_scores += dist_kld.detach()\n\n        diag = torch.eye(head.size(-1)+1, dtype=torch.bool, device=head.device).unsqueeze(0)\n        unlabeled_scores.masked_fill_(diag, -float('inf'))\n\n        preds = []\n\n        if self.training:\n            unlabeled_scores = unlabeled_scores[:, 1:, :] # exclude attachment for the root symbol\n            unlabeled_scores = unlabeled_scores.masked_fill(word_mask.unsqueeze(1), -float('inf'))\n            unlabeled_target = head.masked_fill(word_mask[:, 1:], -1)\n            loss = self.crit(unlabeled_scores.contiguous().view(-1, unlabeled_scores.size(2)), unlabeled_target.view(-1))\n\n            deprel_scores = deprel_scores[:, 1:] # exclude attachment for the root symbol\n            #deprel_scores = deprel_scores.masked_select(goldmask.unsqueeze(3)).view(-1, len(self.vocab['deprel']))\n            deprel_scores = torch.gather(deprel_scores, 2, head.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, len(self.vocab['deprel']))).view(-1, len(self.vocab['deprel']))\n            deprel_target = deprel.masked_fill(word_mask[:, 1:], -1)\n            loss += self.crit(deprel_scores.contiguous(), deprel_target.view(-1))\n\n            if self.args['linearization']:\n                #lin_scores = lin_scores[:, 1:].masked_select(goldmask)\n                lin_scores = 
torch.gather(lin_scores[:, 1:], 2, head.unsqueeze(2)).view(-1)\n                lin_scores = torch.cat([-lin_scores.unsqueeze(1)/2, lin_scores.unsqueeze(1)/2], 1)\n                #lin_target = (head_offset[:, 1:] > 0).long().masked_select(goldmask)\n                lin_target = torch.gather((head_offset[:, 1:] > 0).long(), 2, head.unsqueeze(2))\n                loss += self.crit(lin_scores.contiguous(), lin_target.view(-1))\n\n            if self.args['distance']:\n                #dist_kld = dist_kld[:, 1:].masked_select(goldmask)\n                # dist_kld[:, 1:] so that the root isn't included in the distance calculation\n                dist_kld = torch.gather(dist_kld[:, 1:], 2, head.unsqueeze(2))\n                loss -= dist_kld.sum()\n\n            loss /= wordchars.size(0) # number of words\n        else:\n            loss = 0\n            preds.append(F.log_softmax(unlabeled_scores, 2).detach().cpu().numpy())\n            preds.append(deprel_scores.max(3)[1].detach().cpu().numpy())\n\n        return loss, preds\n"
  },
  {
    "path": "stanza/models/depparse/scorer.py",
    "content": "\"\"\"\nUtils and wrappers for scoring parsers.\n\"\"\"\n\nfrom collections import Counter\nimport logging\n\nfrom stanza.models.common.utils import ud_scores\n\nlogger = logging.getLogger('stanza')\n\ndef score_named_dependencies(pred_doc, gold_doc, output_latex=False):\n    if len(pred_doc.sentences) != len(gold_doc.sentences):\n        logger.warning(\"Not evaluating individual dependency F1 on accound of document length mismatch\")\n        return\n    for sent_idx, (x, y) in enumerate(zip(pred_doc.sentences, gold_doc.sentences)):\n        if len(x.words) != len(y.words):\n            logger.warning(\"Not evaluating individual dependency F1 on accound of sentence length mismatch\")\n            return\n\n    tp = Counter()\n    fp = Counter()\n    fn = Counter()\n    for pred_sentence, gold_sentence in zip(pred_doc.sentences, gold_doc.sentences):\n        for pred_word, gold_word in zip(pred_sentence.words, gold_sentence.words):\n            if pred_word.head == gold_word.head and pred_word.deprel == gold_word.deprel:\n                tp[gold_word.deprel] = tp[gold_word.deprel] + 1\n            else:\n                fn[gold_word.deprel] = fn[gold_word.deprel] + 1\n                fp[pred_word.deprel] = fp[pred_word.deprel] + 1\n\n    labels = sorted(set(tp.keys()).union(fp.keys()).union(fn.keys()))\n    max_len = max(len(x) for x in labels)\n    log_lines = []\n    #log_line_fmt = \"%\" + str(max_len) + \"s: p %.4f r %.4f f1 %.4f (%d actual)\"\n    if output_latex:\n        log_lines.append(r\"\\begin{tabular}{lrr}\")\n        log_lines.append(r\"Reln & F1 & Total \\\\\")\n        log_line_fmt = \"{label} & {f1:0.4f} & {actual} \\\\\\\\\"\n    else:\n        log_line_fmt = \"{label:>\" + str(max_len) + \"s}: p {precision:0.4f} r {recall:0.4f} f1 {f1:0.4f} ({actual} actual)\"\n    for label in labels:\n        if tp[label] == 0:\n            precision = 0\n            recall = 0\n            f1 = 0\n        else:\n            precision = 
tp[label] / (tp[label] + fp[label])\n            recall = tp[label] / (tp[label] + fn[label])\n            f1 = 2 * (precision * recall) / (precision + recall)\n        actual = tp[label] + fn[label]\n        template = {\n            'label': label,\n            'precision': precision,\n            'recall': recall,\n            'f1': f1,\n            'actual': actual\n        }\n        log_lines.append(log_line_fmt.format(**template))\n    if output_latex:\n        log_lines.append(r\"\\end{tabular}\")\n    logger.info(\"F1 scores for each dependency:\\n  Note that unlabeled attachment errors hurt the labeled attachment scores\\n%s\" % \"\\n\".join(log_lines))\n\ndef score(system_conllu_file, gold_conllu_file, verbose=True):\n    \"\"\" Wrapper for UD parser scorer. \"\"\"\n    evaluation = ud_scores(gold_conllu_file, system_conllu_file)\n    el = evaluation['LAS']\n    p = el.precision\n    r = el.recall\n    f = el.f1\n    if verbose:\n        scores = [evaluation[k].f1 * 100 for k in ['LAS', 'MLAS', 'BLEX']]\n        logger.info(\"LAS\\tMLAS\\tBLEX\")\n        logger.info(\"{:.2f}\\t{:.2f}\\t{:.2f}\".format(*scores))\n    return p, r, f\n\n"
  },
  {
    "path": "stanza/models/depparse/trainer.py",
    "content": "\"\"\"\nA trainer class to handle training and testing of models.\n\"\"\"\n\nimport copy\nimport sys\nimport logging\nimport torch\nfrom torch import nn\n\ntry:\n    import transformers\nexcept ImportError:\n    pass\n\nfrom stanza.models.common.trainer import Trainer as BaseTrainer\nfrom stanza.models.common import utils, loss\nfrom stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache\nfrom stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root\nfrom stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper\nfrom stanza.models.common.vocab import VOCAB_PREFIX_SIZE\nfrom stanza.models.depparse.model import Parser\nfrom stanza.models.pos.vocab import MultiVocab\n\nlogger = logging.getLogger('stanza')\n\ndef unpack_batch(batch, device):\n    \"\"\" Unpack a batch from the data loader. \"\"\"\n    inputs = [b.to(device) if b is not None else None for b in batch[:11]]\n    orig_idx = batch[11]\n    word_orig_idx = batch[12]\n    sentlens = batch[13]\n    wordlens = batch[14]\n    text = batch[15]\n    return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text\n\nclass Trainer(BaseTrainer):\n    \"\"\" A trainer for training models. 
\"\"\"\n    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,\n                 device=None, foundation_cache=None, ignore_model_config=False, reset_history=False):\n        self.global_step = 0\n        self.last_best_step = 0\n        self.dev_score_history = []\n\n        orig_args = copy.deepcopy(args)\n        # whether the training is in primary or secondary stage\n        # during FT (loading weights), etc., the training is considered to be in \"secondary stage\"\n        # during this time, we (optionally) use a different set of optimizers than that during \"primary stage\".\n        #\n        # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary\n\n        if model_file is not None:\n            # load everything from file\n            self.load(model_file, pretrain, args, foundation_cache, device)\n\n            if reset_history:\n                self.global_step = 0\n                self.last_best_step = 0\n                self.dev_score_history = []\n        else:\n            # build model from scratch\n            self.args = args\n            self.vocab = vocab\n\n            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])\n            peft_name = None\n            if self.args['use_peft']:\n                # fine tune the bert if we're using peft\n                self.args['bert_finetune'] = True\n                peft_name = \"depparse\"\n                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)\n\n            self.model = Parser(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)\n            self.model = self.model.to(device)\n            self.__init_optim()\n\n        self.fallback = self.vocab['deprel'].unit2id('dep') if 'dep' in self.vocab['deprel'] 
else None\n\n        if ignore_model_config:\n            self.args = orig_args\n\n        if self.args.get('wandb'):\n            import wandb\n            # track gradients!\n            wandb.watch(self.model, log_freq=4, log=\"all\", log_graph=True)\n\n    def __init_optim(self):\n        # TODO: can get rid of args.get when models are rebuilt\n        if (self.args.get(\"second_stage\", False) and self.args.get('second_optim')):\n            self.optimizer = utils.get_split_optimizer(self.args['second_optim'], self.model,\n                                                       self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,\n                                                       bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0),\n                                                       is_peft=self.args.get('use_peft', False),\n                                                       bert_finetune_layers=self.args.get('bert_finetune_layers', None))\n        else:\n            self.optimizer = utils.get_split_optimizer(self.args['optim'], self.model,\n                                                       self.args['lr'], betas=(0.9, self.args['beta2']),\n                                                       eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0),\n                                                       weight_decay=self.args.get('weight_decay', None),\n                                                       bert_weight_decay=self.args.get('bert_weight_decay', 0.0),\n                                                       is_peft=self.args.get('use_peft', False),\n                                                       bert_finetune_layers=self.args.get('bert_finetune_layers', None))\n        self.scheduler = {}\n        if self.args.get(\"second_stage\", False) and self.args.get('second_optim'):\n            if self.args.get('second_warmup_steps', None):\n                for name, optimizer in 
self.optimizer.items():\n                    name = name + \"_scheduler\"\n                    warmup_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, self.args['second_warmup_steps'])\n                    self.scheduler[name] = warmup_scheduler\n        else:\n            if \"bert_optimizer\" in self.optimizer:\n                zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer[\"bert_optimizer\"], factor=0, total_iters=self.args['bert_start_finetuning'])\n                warmup_scheduler = transformers.get_constant_schedule_with_warmup(\n                    self.optimizer[\"bert_optimizer\"],\n                    self.args['bert_warmup_steps'])\n                self.scheduler[\"bert_scheduler\"] = torch.optim.lr_scheduler.SequentialLR(\n                    self.optimizer[\"bert_optimizer\"],\n                    schedulers=[zero_scheduler, warmup_scheduler],\n                    milestones=[self.args['bert_start_finetuning']])\n\n    def update(self, batch, eval=False):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)\n        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs\n\n        if eval:\n            self.model.eval()\n        else:\n            self.model.train()\n            for opt in self.optimizer.values():\n                opt.zero_grad()\n        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)\n        loss_val = loss.data.item()\n        if eval:\n            return loss_val\n\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])\n        for opt in self.optimizer.values():\n            opt.step()\n        for scheduler in self.scheduler.values():\n            scheduler.step()\n     
   return loss_val\n\n    def predict(self, batch, unsort=True):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)\n        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs\n\n        self.model.eval()\n        batch_size = word.size(0)\n        _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)\n        # TODO: would be cleaner for the model to not have the capability to produce predictions < VOCAB_PREFIX_SIZE\n        if self.fallback is not None:\n            preds[1][preds[1] < VOCAB_PREFIX_SIZE] = self.fallback\n        head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root\n        deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)]\n\n        pred_tokens = [[[head_seqs[i][j], deprel_seqs[i][j]] for j in range(sentlens[i]-1)] for i in range(batch_size)]\n        if unsort:\n            pred_tokens = utils.unsort(pred_tokens, orig_idx)\n        return pred_tokens\n\n    def save(self, filename, skip_modules=True, save_optimizer=False):\n        model_state = self.model.state_dict()\n        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file\n        if skip_modules:\n            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]\n            for k in skipped:\n                del model_state[k]\n        params = {\n                'model': model_state,\n                'vocab': self.vocab.state_dict(),\n                'config': self.args,\n                'global_step': self.global_step,\n                'last_best_step': self.last_best_step,\n               
 'dev_score_history': self.dev_score_history,\n                }\n        if self.args.get('use_peft', False):\n            # Hide import so that peft dependency is optional\n            from peft import get_peft_model_state_dict\n            params[\"bert_lora\"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)\n\n        if save_optimizer and self.optimizer is not None:\n            params['optimizer_state_dict'] = {k: opt.state_dict() for k, opt in self.optimizer.items()}\n            params['scheduler_state_dict'] = {k: scheduler.state_dict() for k, scheduler in self.scheduler.items()}\n\n        try:\n            torch.save(params, filename, _use_new_zipfile_serialization=False)\n            logger.info(\"Model saved to {}\".format(filename))\n        except BaseException as e:\n            logger.warning(\"Saving failed... continuing anyway.  Error was: %s\" % e)\n\n    def load(self, filename, pretrain, args=None, foundation_cache=None, device=None):\n        \"\"\"\n        Load a model from file, with preloaded pretrain embeddings. 
Here we allow the pretrain to be None or a dummy input,\n        and the actual use of pretrain embeddings will depend on the boolean config \"pretrain\" in the loaded args.\n        \"\"\"\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        except BaseException:\n            logger.error(\"Cannot load model from {}\".format(filename))\n            raise\n        self.args = checkpoint['config']\n        if args is not None: self.args.update(args)\n\n        # preserve old models which were created before transformers were added\n        if 'bert_model' not in self.args:\n            self.args['bert_model'] = None\n\n        lora_weights = checkpoint.get('bert_lora')\n        if lora_weights:\n            logger.debug(\"Found peft weights for depparse; loading a peft adapter\")\n            self.args[\"use_peft\"] = True\n\n        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])\n        # load model\n        emb_matrix = None\n        if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None\n            emb_matrix = pretrain.emb\n\n        # TODO: refactor this common block of code with NER\n        force_bert_saved = False\n        peft_name = None\n        if self.args.get('use_peft', False):\n            force_bert_saved = True\n            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], \"depparse\", foundation_cache)\n            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)\n            logger.debug(\"Loaded peft with name %s\", peft_name)\n        else:\n            if any(x.startswith(\"bert_model.\") for x in checkpoint['model'].keys()):\n                logger.debug(\"Model %s has a finetuned transformer.  
Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere\", filename)\n                foundation_cache = NoTransformerFoundationCache(foundation_cache)\n                force_bert_saved = True\n            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)\n\n        self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)\n        self.model.load_state_dict(checkpoint['model'], strict=False)\n\n        if device is not None:\n            self.model = self.model.to(device)\n\n        self.__init_optim()\n        optim_state_dict = checkpoint.get(\"optimizer_state_dict\")\n        if optim_state_dict:\n            for k, state in optim_state_dict.items():\n                self.optimizer[k].load_state_dict(state)\n\n        scheduler_state_dict = checkpoint.get(\"scheduler_state_dict\")\n        if scheduler_state_dict:\n            for k, state in scheduler_state_dict.items():\n                self.scheduler[k].load_state_dict(state)\n\n        self.global_step = checkpoint.get(\"global_step\", 0)\n        self.last_best_step = checkpoint.get(\"last_best_step\", 0)\n        self.dev_score_history = checkpoint.get(\"dev_score_history\", list())\n"
  },
  {
    "path": "stanza/models/identity_lemmatizer.py",
    "content": "\"\"\"\nAn identity lemmatizer that mimics the behavior of a normal lemmatizer but uses each word directly as its lemma.\n\"\"\"\n\nimport io\nimport os\nimport argparse\nimport logging\nimport random\n\nfrom stanza.models.lemma.data import DataLoader\nfrom stanza.models.lemma import scorer\nfrom stanza.models.common import utils\nfrom stanza.models.common.doc import *\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models import _training_logging\n\nlogger = logging.getLogger('stanza')\n\ndef parse_args(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')\n    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')\n    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file to score against.')\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--shorthand', type=str, help='Shorthand')\n\n    parser.add_argument('--batch_size', type=int, default=50)\n    parser.add_argument('--seed', type=int, default=1234)\n\n    args = parser.parse_args(args=args)\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    random.seed(args.seed)\n\n    args = vars(args)\n\n    logger.info(\"[Launching identity lemmatizer...]\")\n\n    if args['mode'] == 'train':\n        logger.info(\"[No training is required; will only generate evaluation output...]\")\n    \n    document = CoNLL.conll2doc(input_file=args['eval_file'])\n    batch = DataLoader(document, args['batch_size'], args, evaluation=True, conll_only=True)\n    system_pred_file = args['output_file']\n    gold_file = args['gold_file']\n\n    # use identity mapping for prediction\n    preds = 
batch.doc.get([TEXT])\n\n    # write to file and score\n    batch.doc.set([LEMMA], preds)\n    if system_pred_file is not None:\n        CoNLL.write_doc2conll(batch.doc, system_pred_file)\n    if gold_file is not None:\n        system_pred_file = \"{:C}\\n\\n\".format(batch.doc)\n        system_pred_file = io.StringIO(system_pred_file)\n        _, _, score = scorer.score(system_pred_file, gold_file)\n\n        logger.info(\"Lemma score:\")\n        logger.info(\"{} {:.2f}\".format(args['shorthand'], score*100))\n\n    return None, batch.doc\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/lang_identifier.py",
    "content": "\"\"\"\nEntry point for training and evaluating a Bi-LSTM language identifier\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport os\nimport random\nimport torch\n\nfrom datetime import datetime\nfrom stanza.models.common import utils\nfrom stanza.models.langid.data import DataLoader\nfrom stanza.models.langid.trainer import Trainer\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nlogger = logging.getLogger('stanza')\n\ndef parse_args(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--batch_mode\", help=\"custom settings when running in batch mode\", action=\"store_true\")\n    parser.add_argument(\"--batch_size\", help=\"batch size for training\", type=int, default=64)\n    parser.add_argument(\"--eval_length\", help=\"length of strings to eval on\", type=int, default=None)\n    parser.add_argument(\"--eval_set\", help=\"eval on dev or test\", default=\"test\")\n    parser.add_argument(\"--data_dir\", help=\"directory with train/dev/test data\", default=None)\n    parser.add_argument(\"--load_name\", help=\"path to load model from\", default=None)\n    parser.add_argument(\"--mode\", help=\"train or eval\", default=\"train\")\n    parser.add_argument(\"--num_epochs\", help=\"number of epochs for training\", type=int, default=50)\n    parser.add_argument(\"--randomize\", help=\"take random substrings of samples\", action=\"store_true\")\n    parser.add_argument(\"--randomize_lengths_range\", help=\"range of lengths to use when random sampling text\",\n                        type=randomize_lengths_range, default=\"5,20\")\n    parser.add_argument(\"--merge_labels_for_eval\",\n                        help=\"merge some language labels for eval (e.g. 
\\\"zh-hans\\\" and \\\"zh-hant\\\" to \\\"zh\\\")\", \n                        action=\"store_true\")\n    parser.add_argument(\"--save_best_epochs\", help=\"save model for every epoch with new best score\", action=\"store_true\")\n    parser.add_argument(\"--save_name\", help=\"where to save model\", default=None)\n    utils.add_device_args(parser)\n    args = parser.parse_args(args=args)\n    return args\n\n\ndef randomize_lengths_range(range_list):\n    \"\"\"\n    Range of lengths for random samples\n    \"\"\"\n    range_boundaries = [int(x) for x in range_list.split(\",\")]\n    assert range_boundaries[0] < range_boundaries[1], f\"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})\"\n    return range_boundaries\n\n\ndef main(args=None):\n    args = parse_args(args=args)\n    torch.manual_seed(0)\n    if args.mode == \"train\":\n        train_model(args)\n    else:\n        eval_model(args)\n\n\ndef build_indexes(args):\n    tag_to_idx = {}\n    char_to_idx = {}\n    train_files = [f\"{args.data_dir}/{x}\" for x in os.listdir(args.data_dir) if \"train\" in x]\n    for train_file in train_files:\n        with open(train_file) as curr_file:\n            lines = curr_file.read().strip().split(\"\\n\")\n        examples = [json.loads(line) for line in lines if line.strip()]\n        for example in examples:\n            label = example[\"label\"]\n            if label not in tag_to_idx:\n                tag_to_idx[label] = len(tag_to_idx)\n            sequence = example[\"text\"]\n            for char in list(sequence):\n                if char not in char_to_idx:\n                    char_to_idx[char] = len(char_to_idx)\n    char_to_idx[\"UNK\"] = len(char_to_idx)\n    char_to_idx[\"<PAD>\"] = len(char_to_idx)\n\n    return tag_to_idx, char_to_idx\n\n\ndef train_model(args):\n    # set up indexes\n    tag_to_idx, char_to_idx = build_indexes(args)\n    # load training data\n    train_data = DataLoader(args.device)\n    train_files = 
[f\"{args.data_dir}/{x}\" for x in os.listdir(args.data_dir) if \"train\" in x]\n    train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)\n    # load dev data\n    dev_data = DataLoader(args.device)\n    dev_files = [f\"{args.data_dir}/{x}\" for x in os.listdir(args.data_dir) if \"dev\" in x]\n    dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False, \n                       max_length=args.eval_length)\n    # set up trainer\n    trainer_config = {\n        \"model_path\": args.save_name,\n        \"char_to_idx\": char_to_idx,\n        \"tag_to_idx\": tag_to_idx,\n        \"batch_size\": args.batch_size,\n        \"lang_weights\": train_data.lang_weights\n    }\n    if args.load_name:\n        trainer_config[\"load_name\"] = args.load_name\n        logger.info(f\"{datetime.now()}\\tLoading model from: {args.load_name}\")\n    trainer = Trainer(trainer_config, load_model=args.load_name is not None, device=args.device)\n    # run training\n    best_accuracy = 0.0\n    for epoch in range(1, args.num_epochs+1):\n        logger.info(f\"{datetime.now()}\\tEpoch {epoch}\")\n        logger.info(f\"{datetime.now()}\\tNum training batches: {len(train_data.batches)}\")\n\n        batches = train_data.batches\n        if not args.batch_mode:\n            batches = tqdm(batches)\n        for train_batch in batches:\n            inputs = (train_batch[\"sentences\"], train_batch[\"targets\"])\n            trainer.update(inputs)\n\n        logger.info(f\"{datetime.now()}\\tEpoch complete. Evaluating on dev data.\")\n        curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \\\n            eval_trainer(trainer, dev_data, batch_mode=args.batch_mode)\n        logger.info(f\"{datetime.now()}\\tCurrent dev accuracy: {curr_dev_accuracy}\")\n        if curr_dev_accuracy > best_accuracy:\n            logger.info(f\"{datetime.now()}\\tNew best score. 
Saving model.\")\n            model_label = f\"epoch{epoch}\" if args.save_best_epochs else None\n            trainer.save(label=model_label)\n            with open(score_log_path(args.save_name), \"w\") as score_log_file:\n                for score_log in [{\"dev_accuracy\": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions,\n                                  curr_recalls, curr_f1s]:\n                    score_log_file.write(json.dumps(score_log) + \"\\n\")\n            best_accuracy = curr_dev_accuracy\n\n        # reload training data\n        logger.info(f\"{datetime.now()}\\tResampling training data.\")\n        train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)\n\n\ndef score_log_path(file_path):\n    \"\"\"\n    Helper that determines the score log file corresponding to a model file\n    (e.g. /path/to/demo.pt maps to /path/to/demo.json)\n    \"\"\"\n    # replace the extension, if any, with .json\n    return f\"{os.path.splitext(file_path)[0]}.json\"\n\n\ndef eval_model(args):\n    # set up trainer\n    trainer_config = {\n        \"model_path\": None,\n        \"load_name\": args.load_name,\n        \"batch_size\": args.batch_size\n    }\n    trainer = Trainer(trainer_config, load_model=True, device=args.device)\n    # load test data\n    test_data = DataLoader(args.device)\n    test_files = [f\"{args.data_dir}/{x}\" for x in os.listdir(args.data_dir) if args.eval_set in x]\n    test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx, \n                        randomize=False, max_length=args.eval_length)\n    curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \\\n        eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval)\n    logger.info(f\"{datetime.now()}\\t{args.eval_set} accuracy: 
{curr_accuracy}\")\n    eval_save_path = args.save_name if args.save_name else score_log_path(args.load_name)\n    if not os.path.exists(eval_save_path) or args.save_name:\n        with open(eval_save_path, \"w\") as score_log_file:\n            for score_log in [{\"dev_accuracy\": curr_accuracy}, curr_confusion_matrix, curr_precisions,\n                              curr_recalls, curr_f1s]:\n                score_log_file.write(json.dumps(score_log) + \"\\n\")\n        \n\n\ndef eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True):\n    \"\"\"\n    Produce dev accuracy and confusion matrix for a trainer\n    \"\"\"\n\n    # set up confusion matrix\n    tag_to_idx = dev_data.tag_to_idx\n    idx_to_tag = dev_data.idx_to_tag\n    confusion_matrix = {}\n    for row_label in tag_to_idx:\n        confusion_matrix[row_label] = {}\n        for col_label in tag_to_idx:\n            confusion_matrix[row_label][col_label] = 0\n\n    # process dev batches\n    batches = dev_data.batches\n    if not batch_mode:\n        batches = tqdm(batches)\n    for dev_batch in batches:\n        inputs = (dev_batch[\"sentences\"], dev_batch[\"targets\"])\n        predictions = trainer.predict(inputs)\n        for target_idx, prediction in zip(dev_batch[\"targets\"], predictions):\n            prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split(\"-\")[0]\n            confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1\n\n    # calculate dev accuracy\n    total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix])\n    total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix])\n    dev_accuracy = float(total_correct) / float(total_examples)\n\n    # calculate precision, recall, F1\n    precision_scores = {\"type\": \"precision\"}\n    recall_scores = {\"type\": \"recall\"}\n    f1_scores = {\"type\": \"f1\"}\n    for prediction_label in tag_to_idx:\n        total 
= sum([confusion_matrix[k][prediction_label] for k in tag_to_idx])\n        if total != 0.0:\n            precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total)\n        else:\n            precision_scores[prediction_label] = 0.0\n    for target_label in tag_to_idx:\n        total = sum([confusion_matrix[target_label][k] for k in tag_to_idx])\n        if total != 0:\n            recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total)\n        else:\n            recall_scores[target_label] = 0.0\n    for label in tag_to_idx:\n        if precision_scores[label] == 0.0 and recall_scores[label] == 0.0:\n            f1_scores[label] = 0.0\n        else:\n            f1_scores[label] = \\\n                2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label])\n\n    return dev_accuracy, confusion_matrix, precision_scores, recall_scores, f1_scores\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/models/langid/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/langid/create_ud_data.py",
    "content": "\"\"\"\nScript for producing training/dev/test data from UD data or sentences\n\nExample output data format (one example per line):\n\n{\"text\": \"Hello world.\", \"label\": \"en\"}\n\nThis is an attempt to recreate data pre-processing in https://github.com/AU-DIS/LSTM_langid\n\nSpecifically borrows methods from https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py\n\nData format is same as LSTM_langid as well.\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport os\nimport re\nimport sys\n\nfrom pathlib import Path\nfrom random import randint, random, shuffle\nfrom string import digits\nfrom tqdm import tqdm\n\nfrom stanza.models.common.constant import treebank_to_langid\n\nlogger = logging.getLogger('stanza')\n\nDEFAULT_LANGUAGES = \"af,ar,be,bg,bxr,ca,cop,cs,cu,da,de,el,en,es,et,eu,fa,fi,fr,fro,ga,gd,gl,got,grc,he,hi,hr,hsb,hu,hy,id,it,ja,kk,kmr,ko,la,lt,lv,lzh,mr,mt,nl,nn,no,olo,orv,pl,pt,ro,ru,sk,sl,sme,sr,sv,swl,ta,te,tr,ug,uk,ur,vi,wo,zh-hans,zh-hant\".split(\",\")\n\ndef parse_args(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-format\", help=\"input data format\", choices=[\"ud\", \"one-per-line\"], default=\"ud\")\n    parser.add_argument(\"--eval-length\", help=\"length of eval strings\", type=int, default=10)\n    parser.add_argument(\"--languages\", help=\"list of languages to use, or \\\"all\\\"\", default=DEFAULT_LANGUAGES)\n    parser.add_argument(\"--min-window\", help=\"minimal training example length\", type=int, default=10)\n    parser.add_argument(\"--max-window\", help=\"maximum training example length\", type=int, default=50)\n    parser.add_argument(\"--ud-path\", help=\"path to ud data\")\n    parser.add_argument(\"--save-path\", help=\"path to save data\", default=\".\")\n    parser.add_argument(\"--splits\", help=\"size of train/dev/test splits in percentages\", type=splits_from_list, \n                        default=\"0.8,0.1,0.1\")\n    args = 
parser.parse_args(args=args)\n    return args\n\n\ndef splits_from_list(value_list):\n    return [float(x) for x in value_list.split(\",\")]\n\n\ndef main(args=None):\n    args = parse_args(args=args)\n    if isinstance(args.languages, str):\n        args.languages = args.languages.split(\",\")\n    data_paths = [f\"{args.save_path}/{data_split}.jsonl\" for data_split in [\"train\", \"dev\", \"test\"]]\n    lang_to_files = collect_files(args.ud_path, args.languages, data_format=args.data_format)\n    logger.info(f\"Building UD data for languages: {','.join(args.languages)}\")\n    for lang_id in tqdm(lang_to_files):\n        lang_examples = generate_examples(lang_id, lang_to_files[lang_id], splits=args.splits, \n                                          min_window=args.min_window, max_window=args.max_window, \n                                          eval_length=args.eval_length, data_format=args.data_format)\n        for (data_set, save_path) in zip(lang_examples, data_paths):\n            with open(save_path, \"a\") as json_file:\n                for json_entry in data_set:\n                    json.dump(json_entry, json_file, ensure_ascii=False)\n                    json_file.write(\"\\n\")\n\n\ndef collect_files(ud_path, languages, data_format=\"ud\"):\n    \"\"\" \n    Given path to UD, collect files\n    If data_format = \"ud\", expects files to be of form *.conllu\n    If data_format = \"one-per-line\", expects files to be of form \"*.sentences.txt\"\n    In all cases, the UD path should be a directory with subdirectories for each language\n    \"\"\"\n    data_format_to_search_path = {\"ud\": \"*/*.conllu\", \"one-per-line\": \"*/*sentences.txt\"}\n    ud_files = Path(ud_path).glob(data_format_to_search_path[data_format])\n    lang_to_files = {}\n    for ud_file in ud_files:\n        if data_format == \"ud\":\n            lang_id = treebank_to_langid(ud_file.parent.name)\n        else:\n            lang_id = ud_file.name.split(\"_\")[0]\n        if lang_id 
not in languages and \"all\" not in languages:\n            continue\n        if lang_id not in lang_to_files:\n            lang_to_files[lang_id] = []\n        lang_to_files[lang_id].append(ud_file)\n    return lang_to_files\n\n\ndef generate_examples(lang_id, list_of_files, splits=(0.8,0.1,0.1), min_window=10, max_window=50, \n                      eval_length=10, data_format=\"ud\"):\n    \"\"\"\n    Generate train/dev/test examples for a given language\n    \"\"\"\n    examples = []\n    for ud_file in list_of_files:\n        sentences = sentences_from_file(ud_file, data_format=data_format)\n        for sentence in sentences:\n            sentence = clean_sentence(sentence)\n            if validate_sentence(sentence, min_window):\n                examples += sentence_to_windows(sentence, min_window=min_window, max_window=max_window)\n    shuffle(examples)\n    train_idx = int(splits[0] * len(examples))\n    train_set = [example_json(lang_id, example) for example in examples[:train_idx]]\n    dev_idx = int(splits[1] * len(examples)) + train_idx\n    dev_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[train_idx:dev_idx]]\n    test_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[dev_idx:]]\n    return train_set, dev_set, test_set\n\n\ndef sentences_from_file(ud_file_path, data_format=\"ud\"):\n    \"\"\"\n    Retrieve all sentences from a UD file\n    \"\"\"\n    if data_format == \"ud\":\n        with open(ud_file_path) as ud_file:\n            ud_file_contents = ud_file.read().strip()\n            assert \"# text = \" in ud_file_contents, \\\n                   f\"{ud_file_path} does not have expected format, \\\"# text =\\\" does not appear\"\n            sentences = [x[9:] for x in ud_file_contents.split(\"\\n\") if x.startswith(\"# text = \")]\n    elif data_format == \"one-per-line\":\n        with open(ud_file_path) as ud_file:\n            sentences = [x for x in 
ud_file.read().strip().split(\"\\n\") if x]\n    return sentences\n\n\ndef sentence_to_windows(sentence, min_window, max_window):\n    \"\"\"\n    Create window size chunks from a sentence, always starting with a word\n    \"\"\"\n    windows = []\n    words = sentence.split(\" \")\n    curr_window = \"\"\n    for idx, word in enumerate(words):\n        curr_window += (\" \" + word)\n        curr_window = curr_window.lstrip()\n        next_word_len = len(words[idx+1]) + 1 if idx+1 < len(words) else 0\n        if len(curr_window) + next_word_len > max_window:\n            curr_window = clean_sentence(curr_window)\n            if validate_sentence(curr_window, min_window):\n                windows.append(curr_window.strip())\n            curr_window = \"\"\n    if len(curr_window) >= min_window:\n        windows.append(curr_window)\n    return windows\n\n\ndef validate_sentence(current_window, min_window):\n    \"\"\"\n    Sentence validation from: LSTM-LID\n    GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py\n    \"\"\"\n    if len(current_window) < min_window:\n        return False\n    return True\n\ndef find(s, ch):\n    \"\"\" \n    Helper for clean_sentence from LSTM-LID\n    GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py \n    \"\"\"\n    return [i for i, ltr in enumerate(s) if ltr == ch]\n\n\ndef clean_sentence(line):\n    \"\"\" \n    Sentence cleaning from LSTM-LID\n    GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py\n    \"\"\"\n    # We remove some special characters and fix small errors in the data, to improve the quality of the data\n    line = line.replace(\"\\n\", '') #{\"text\": \"- Mor.\\n\", \"label\": \"da\"}\n    line = line.replace(\"- \", '') #{\"text\": \"- Mor.\", \"label\": \"da\"}\n    line = line.replace(\"_\", '') #{\"text\": \"- Mor.\", \"label\": \"da\"}\n    line = line.replace(\"\\\\\", '')\n    line = line.replace(\"\\\"\", '')\n    line = 
line.replace(\"  \", \" \")\n    remove_digits = str.maketrans('', '', digits)\n    line = line.translate(remove_digits)\n    words = line.split()\n    new_words = []\n    # The loop below replaces a capital I that was probably meant to be a lowercase l. It does not catch every case, but it should rarely introduce new mistakes\n    for word in words:\n        clean_word = word\n        s = clean_word\n        if \"I\" in clean_word[1:]:\n            indices = find(clean_word, \"I\")\n            for indx in indices:\n                if clean_word[indx-1].islower():\n                    if len(clean_word) > indx + 1:\n                        if clean_word[indx+1].islower():\n                            s = s[:indx] + \"l\" + s[indx + 1:]\n                    else:\n                        s = s[:indx] + \"l\" + s[indx + 1:]\n        new_words.append(s)\n    new_line = \" \".join(new_words)\n    return new_line\n\n\ndef example_json(lang_id, text, eval_length=None):\n    if eval_length is not None:\n        text = text[:eval_length]\n    return {\"text\": text.strip(), \"label\": lang_id}\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/models/langid/data.py",
    "content": "import json\nimport random\nimport torch\n\n\nclass DataLoader:\n    \"\"\"\n    Class for loading language id data and providing batches\n\n    Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid\n\n    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py\n\n    Data format is same as LSTM_langid\n    \"\"\"\n\n    def __init__(self, device=None):\n        self.batches = None\n        self.batches_iter = None\n        self.tag_to_idx = None\n        self.idx_to_tag = None\n        self.lang_weights = None\n        self.device = device\n\n    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),\n                  max_length=None):\n        \"\"\"\n        Load sequence data and labels, calculate weights for weighted cross entropy loss.\n        Data is stored in a file, 1 example per line\n        Example: {\"text\": \"Hello world.\", \"label\": \"en\"}\n        \"\"\"\n\n        # set up examples from data files\n        examples = []\n        for data_file in data_files:\n            examples += [x for x in open(data_file).read().split(\"\\n\") if x.strip()]\n        random.shuffle(examples)\n        examples = [json.loads(x) for x in examples]\n\n        # add additional labels in this data set to tag index\n        tag_index = dict(tag_index)\n        new_labels = set([x[\"label\"] for x in examples]) - set(tag_index.keys())\n        for new_label in new_labels:\n            tag_index[new_label] = len(tag_index)\n        self.tag_to_idx = tag_index\n        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]\n        \n        # set up lang counts used for weights for cross entropy loss\n        lang_counts = [0 for _ in tag_index]\n\n        # optionally limit text to max length\n        if max_length is not None:\n            examples = [{\"text\": x[\"text\"][:max_length], \"label\": 
x[\"label\"]} for x in examples]\n\n        # randomize data\n        if randomize:\n            split_examples = []\n            for example in examples:\n                sequence = example[\"text\"]\n                label = example[\"label\"]\n                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1], \n                                                      lower_lim=randomize_range[0])\n                split_examples += [{\"text\": seq, \"label\": label} for seq in sequences]\n            examples = split_examples\n            random.shuffle(examples)\n\n        # break into equal length batches\n        batch_lengths = {}\n        for example in examples:\n            sequence = example[\"text\"]\n            label = example[\"label\"]\n            if len(sequence) not in batch_lengths:\n                batch_lengths[len(sequence)] = []\n            sequence_as_list = [char_index.get(c, char_index[\"UNK\"]) for c in list(sequence)]\n            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))\n            lang_counts[tag_index[label]] += 1\n        for length in batch_lengths:\n            random.shuffle(batch_lengths[length])\n\n        # create final set of batches\n        batches = []\n        for length in batch_lengths:\n            for sublist in [batch_lengths[length][i:i + batch_size] for i in\n                            range(0, len(batch_lengths[length]), batch_size)]:\n                batches.append(sublist)\n\n        self.batches = [self.build_batch_tensors(batch) for batch in batches]\n\n        # set up lang weights\n        most_frequent = max(lang_counts)\n        # set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise\n        lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts]\n        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)\n\n        # shuffle batches to mix up lengths\n        
random.shuffle(self.batches)\n        self.batches_iter = iter(self.batches)\n\n    @staticmethod\n    def randomize_data(sentences, upper_lim=20, lower_lim=5):\n        \"\"\"\n        Takes the original data and creates random length examples with length between upper limit and lower limit\n        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py\n        \"\"\"\n\n        new_data = []\n        for sentence in sentences:\n            remaining = sentence\n            while lower_lim < len(remaining):\n                lim = random.randint(lower_lim, upper_lim)\n                m = min(len(remaining), lim)\n                new_sentence = remaining[:m]\n                new_data.append(new_sentence)\n                split = remaining[m:].split(\" \", 1)\n                if len(split) <= 1:\n                    break\n                remaining = split[1]\n        random.shuffle(new_data)\n        return new_data\n\n    def build_batch_tensors(self, batch):\n        \"\"\"\n        Helper to turn batches into tensors\n        \"\"\"\n\n        batch_tensors = dict()\n        batch_tensors[\"sentences\"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)\n        batch_tensors[\"targets\"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)\n\n        return batch_tensors\n\n    def next(self):\n        return next(self.batches_iter)\n\n"
  },
  {
    "path": "stanza/models/langid/model.py",
    "content": "import os\n\nimport torch\nimport torch.nn as nn\n\n\nclass LangIDBiLSTM(nn.Module):\n    \"\"\"\n    Multi-layer BiLSTM model for language detecting. A recreation of \"A reproduction of Apple's bi-directional LSTM models\n    for language identification in short strings.\" (Toftrup et al 2021)\n\n    Arxiv: https://arxiv.org/abs/2102.06282\n    GitHub: https://github.com/AU-DIS/LSTM_langid\n\n    This class is similar to https://github.com/AU-DIS/LSTM_langid/blob/main/src/LSTMLID.py\n    \"\"\"\n\n    def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None, \n                 dropout=0.0, lang_subset=None):\n        super(LangIDBiLSTM, self).__init__()\n        self.num_layers = num_layers\n        self.embedding_dim = embedding_dim\n        self.hidden_dim = hidden_dim\n        self.char_to_idx = char_to_idx\n        self.vocab_size = len(char_to_idx)\n        self.tag_to_idx = tag_to_idx\n        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]\n        self.lang_subset = lang_subset\n        self.padding_idx = char_to_idx[\"<PAD>\"]\n        self.tagset_size = len(tag_to_idx)\n        self.batch_size = batch_size\n        self.loss_train = nn.CrossEntropyLoss(weight=weights)\n        self.dropout_prob = dropout\n        \n        # embeddings for chars\n        self.char_embeds = nn.Embedding(\n                num_embeddings=self.vocab_size, \n                embedding_dim=self.embedding_dim,\n                padding_idx=self.padding_idx\n        )\n\n        # the bidirectional LSTM\n        self.lstm = nn.LSTM(\n                self.embedding_dim, \n                self.hidden_dim,\n                num_layers=self.num_layers,\n                bidirectional=True,\n                batch_first=True\n        )\n\n        # convert output to tag space\n        self.hidden_to_tag = nn.Linear(\n                self.hidden_dim * 2, \n                
self.tagset_size\n        )\n\n        # dropout layer\n        self.dropout = nn.Dropout(p=self.dropout_prob)\n\n    def build_lang_mask(self, device):\n        \"\"\"\n        Build language mask if a lang subset is specified (e.g. [\"en\", \"fr\"])\n\n        The mask will be added to the results to set the prediction scores of illegal languages to -inf\n        \"\"\"\n        if self.lang_subset:\n            lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag]\n            self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)\n        else:\n            self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float)\n\n    def loss(self, Y_hat, Y):\n        return self.loss_train(Y_hat, Y)\n\n    def forward(self, x):\n        # embed input\n        x = self.char_embeds(x)\n        \n        # run through LSTM\n        x, _ = self.lstm(x)\n        \n        # run through linear layer\n        x = self.hidden_to_tag(x)\n        \n        # sum character outputs for each sequence\n        x = torch.sum(x, dim=1)\n\n        return x\n\n    def prediction_scores(self, x):\n        prediction_probs = self(x)\n        if self.lang_subset:\n            prediction_batch_size = prediction_probs.size()[0]\n            batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)])\n            prediction_probs = prediction_probs + batch_mask\n        return torch.argmax(prediction_probs, dim=1)\n\n    def save(self, path):\n        \"\"\" Save a model at path \"\"\"\n        checkpoint = {\n            \"char_to_idx\": self.char_to_idx,\n            \"tag_to_idx\": self.tag_to_idx,\n            \"num_layers\": self.num_layers,\n            \"embedding_dim\": self.embedding_dim,\n            \"hidden_dim\": self.hidden_dim,\n            \"model_state_dict\": self.state_dict()\n        }\n        torch.save(checkpoint, path)\n    \n    @classmethod\n    def 
load(cls, path, device=None, batch_size=64, lang_subset=None):\n        \"\"\" Load a serialized model located at path \"\"\"\n        if path is None:\n            raise FileNotFoundError(\"Trying to load langid model, but path not specified!  Try --load_name\")\n        if not os.path.exists(path):\n            raise FileNotFoundError(\"Trying to load langid model from path which does not exist: %s\" % path)\n        checkpoint = torch.load(path, map_location=torch.device(\"cpu\"), weights_only=True)\n        weights = checkpoint[\"model_state_dict\"][\"loss_train.weight\"]\n        model = cls(checkpoint[\"char_to_idx\"], checkpoint[\"tag_to_idx\"], checkpoint[\"num_layers\"],\n                    checkpoint[\"embedding_dim\"], checkpoint[\"hidden_dim\"], batch_size=batch_size, weights=weights,\n                    lang_subset=lang_subset)\n        model.load_state_dict(checkpoint[\"model_state_dict\"])\n        model = model.to(device)\n        model.build_lang_mask(device)\n        return model\n\n"
  },
  {
    "path": "stanza/models/langid/trainer.py",
    "content": "import torch\nimport torch.optim as optim\n\nfrom stanza.models.langid.model import LangIDBiLSTM\n\n\nclass Trainer:\n\n    DEFAULT_BATCH_SIZE = 64\n    DEFAULT_LAYERS = 2\n    DEFAULT_EMBEDDING_DIM = 150\n    DEFAULT_HIDDEN_DIM = 150\n\n    def __init__(self, config, load_model=False, device=None):\n        self.model_path = config[\"model_path\"]\n        self.batch_size = config.get(\"batch_size\", Trainer.DEFAULT_BATCH_SIZE)\n        if load_model:\n            self.load(config[\"load_name\"], device)\n        else:\n            self.model = LangIDBiLSTM(config[\"char_to_idx\"], config[\"tag_to_idx\"], Trainer.DEFAULT_LAYERS, \n                                      Trainer.DEFAULT_EMBEDDING_DIM,\n                                      Trainer.DEFAULT_HIDDEN_DIM,\n                                      batch_size=self.batch_size,\n                                      weights=config[\"lang_weights\"]).to(device)\n        self.optimizer = optim.AdamW(self.model.parameters())\n\n    def update(self, inputs):\n        self.model.train()\n        sentences, targets = inputs\n        self.optimizer.zero_grad()\n        y_hat = self.model.forward(sentences)\n        loss = self.model.loss(y_hat, targets)\n        loss.backward()\n        self.optimizer.step()\n\n    def predict(self, inputs):\n        self.model.eval()\n        sentences, targets = inputs\n        return torch.argmax(self.model(sentences), dim=1)\n\n    def save(self, label=None):\n        # save a copy of model with label\n        if label:\n            self.model.save(f\"{self.model_path[:-3]}-{label}.pt\")\n        self.model.save(self.model_path)\n\n    def load(self, model_path=None, device=None):\n        if not model_path:\n            model_path = self.model_path\n        self.model = LangIDBiLSTM.load(model_path, device, self.batch_size)\n\n"
  },
  {
    "path": "stanza/models/lemma/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/lemma/attach_lemma_classifier.py",
    "content": "import argparse\n\nfrom stanza.models.lemma.trainer import Trainer\nfrom stanza.models.lemma_classifier.base_model import LemmaClassifier\n\ndef attach_classifier(input_filename, output_filename, classifiers):\n    trainer = Trainer(model_file=input_filename)\n\n    for classifier in classifiers:\n        classifier = LemmaClassifier.load(classifier)\n        trainer.contextual_lemmatizers.append(classifier)\n\n    trainer.save(output_filename)\n\ndef main(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from')\n    parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer')\n    parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach')\n    args = parser.parse_args(args)\n\n    attach_classifier(args.input, args.output, args.classifier)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/lemma/data.py",
    "content": "import random\nimport numpy as np\nimport os\nfrom collections import Counter\nimport logging\nimport torch\n\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all\nfrom stanza.models.common.vocab import DeltaVocab\nfrom stanza.models.lemma.vocab import Vocab, MultiVocab\nfrom stanza.models.lemma import edit\nfrom stanza.models.common.doc import *\n\nlogger = logging.getLogger('stanza')\n\nclass DataLoader:\n    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, conll_only=False, skip=None, expand_unk_vocab=False):\n        self.batch_size = batch_size\n        self.args = args\n        self.eval = evaluation\n        self.shuffled = not self.eval\n        self.doc = doc\n\n        data = self.raw_data()\n\n        if conll_only: # only load conll file\n            return\n\n        if skip is not None:\n            assert len(data) == len(skip)\n            data = [x for x, y in zip(data, skip) if not y]\n\n        # handle vocab\n        if vocab is not None:\n            if expand_unk_vocab:\n                pos_vocab = vocab['pos']\n                char_vocab = DeltaVocab(data, vocab['char'])\n                self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})\n            else:\n                self.vocab = vocab\n        else:\n            self.vocab = dict()\n            char_vocab, pos_vocab = self.init_vocab(data)\n            self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})\n\n        # filter and sample data\n        if args.get('sample_train', 1.0) < 1.0 and not self.eval:\n            keep = int(args['sample_train'] * len(data))\n            data = random.sample(data, keep)\n            logger.debug(\"Subsample training set with rate {:g}\".format(args['sample_train']))\n\n        data = self.preprocess(data, self.vocab['char'], self.vocab['pos'], args)\n        # shuffle for training\n       
 if self.shuffled:\n            indices = list(range(len(data)))\n            random.shuffle(indices)\n            data = [data[i] for i in indices]\n        self.num_examples = len(data)\n\n        # chunk into batches\n        data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]\n        self.data = data\n        logger.debug(\"{} batches created.\".format(len(data)))\n\n    def init_vocab(self, data):\n        assert self.eval is False, \"Vocab file must exist for evaluation\"\n        char_data = \"\".join(d[0] + d[2] for d in data)\n        char_vocab = Vocab(char_data, self.args['lang'])\n        pos_data = [d[1] for d in data]\n        pos_vocab = Vocab(pos_data, self.args['lang'])\n        return char_vocab, pos_vocab\n\n    def preprocess(self, data, char_vocab, pos_vocab, args):\n        processed = []\n        for d in data:\n            edit_type = edit.EDIT_TO_ID[edit.get_edit_type(d[0], d[2])]\n            src = list(d[0])\n            src = [constant.SOS] + src + [constant.EOS]\n            src = char_vocab.map(src)\n            pos = d[1]\n            pos = pos_vocab.unit2id(pos)\n            tgt = list(d[2])\n            tgt_in = char_vocab.map([constant.SOS] + tgt)\n            tgt_out = char_vocab.map(tgt + [constant.EOS])\n            processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]]\n        return processed\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, key):\n        \"\"\" Get a batch with index. 
\"\"\"\n        if not isinstance(key, int):\n            raise TypeError\n        if key < 0 or key >= len(self.data):\n            raise IndexError\n        batch = self.data[key]\n        batch_size = len(batch)\n        batch = list(zip(*batch))\n        assert len(batch) == 6\n\n        # sort all fields by lens for easy RNN operations\n        lens = [len(x) for x in batch[0]]\n        batch, orig_idx = sort_all(batch, lens)\n\n        # convert to tensors\n        src = batch[0]\n        src = get_long_tensor(src, batch_size)\n        src_mask = torch.eq(src, constant.PAD_ID)\n        tgt_in = get_long_tensor(batch[1], batch_size)\n        tgt_out = get_long_tensor(batch[2], batch_size)\n        pos = torch.LongTensor(batch[3])\n        edits = torch.LongTensor(batch[4])\n        text = batch[5]\n        assert tgt_in.size(1) == tgt_out.size(1), \"Target input and output sequence sizes do not match.\"\n        return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text\n\n    def __iter__(self):\n        for i in range(self.__len__()):\n            yield self.__getitem__(i)\n\n    def raw_data(self):\n        return self.load_doc(self.doc, self.args.get('caseless', False), self.args.get('skip_blank_lemmas', False), self.eval)\n\n    @staticmethod\n    def load_doc(doc, caseless, skip_blank_lemmas, evaluation):\n        if evaluation:\n            data = doc.get([TEXT, UPOS, LEMMA])\n        else:\n            data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True)\n            data = DataLoader.remove_goeswith(data)\n            data = DataLoader.extract_correct_forms(data)\n        data = DataLoader.resolve_none(data)\n        if not evaluation and skip_blank_lemmas:\n            data = DataLoader.skip_blank_lemmas(data)\n        if caseless:\n            data = DataLoader.lowercase_data(data)\n        return data\n\n    @staticmethod\n    def extract_correct_forms(data):\n        \"\"\"\n        Here we go through the raw data and 
use the CorrectForm of words tagged with CorrectForm\n\n        In addition, if the incorrect form of the word is not present in the training data,\n        we keep the incorrect form for the lemmatizer to learn from.\n        This way, it can occasionally get things right in misspelled input text.\n\n        We do check for and eliminate words where the incorrect form is already known as the\n        lemma for a different word.  For example, in the English datasets, there is a \"busy\"\n        which was meant to be \"buys\", and we don't want the model to learn to lemmatize \"busy\" to \"buy\"\n        \"\"\"\n        new_data = []\n        incorrect_forms = []\n        for word in data:\n            misc = word[-1]\n            if not misc:\n                new_data.append(word[:3])\n                continue\n            misc = misc.split(\"|\")\n            for piece in misc:\n                if piece.startswith(\"CorrectForm=\"):\n                    cf = piece.split(\"=\", maxsplit=1)[1]\n                    # treat the CorrectForm as the desired word\n                    new_data.append((cf, word[1], word[2]))\n                    # and save the broken one for later in case it wasn't used anywhere else\n                    incorrect_forms.append((cf, word))\n                    break\n            else:\n                # if no CorrectForm, just keep the word as normal\n                new_data.append(word[:3])\n        known_words = {x[0] for x in new_data}\n        for correct_form, word in incorrect_forms:\n            if word[0] not in known_words:\n                new_data.append(word[:3])\n        return new_data\n\n    @staticmethod\n    def remove_goeswith(data):\n        \"\"\"\n        This method specifically removes words that goeswith something else, along with the something else\n\n        The purpose is to eliminate text such 
as\n\n1\tKen\tkenrice@enroncommunications\tX\tGW\tTypo=Yes\t0\troot\t0:root\t_\n2\tRice@ENRON\t_\tX\tGW\t_\t1\tgoeswith\t1:goeswith\t_\n3\tCOMMUNICATIONS\t_\tX\tADD\t_\t1\tgoeswith\t1:goeswith\t_\n        \"\"\"\n        filtered_data = []\n        remove_indices = set()\n        for sentence in data:\n            remove_indices.clear()\n            for word_idx, word in enumerate(sentence):\n                if word[4] == 'goeswith':\n                    remove_indices.add(word_idx)\n                    remove_indices.add(word[3]-1)\n            filtered_data.extend([x for idx, x in enumerate(sentence) if idx not in remove_indices])\n        return filtered_data\n\n    @staticmethod\n    def lowercase_data(data):\n        for token in data:\n            token[0] = token[0].lower()\n        return data\n\n    @staticmethod\n    def skip_blank_lemmas(data):\n        data = [x for x in data if x[2] != '_']\n        return data\n\n    @staticmethod\n    def resolve_none(data):\n        # replace None with '_'\n        for tok_idx in range(len(data)):\n            for feat_idx in range(len(data[tok_idx])):\n                if data[tok_idx][feat_idx] is None:\n                    data[tok_idx][feat_idx] = '_'\n        return data\n"
  },
  {
    "path": "stanza/models/lemma/edit.py",
    "content": "\"\"\"\nUtilities for calculating edits between word and lemma forms.\n\"\"\"\n\nEDIT_TO_ID = {'none': 0, 'identity': 1, 'lower': 2}\n\ndef get_edit_type(word, lemma):\n    \"\"\" Calculate edit types. \"\"\"\n    if lemma == word:\n        return 'identity'\n    elif lemma == word.lower():\n        return 'lower'\n    return 'none'\n\ndef edit_word(word, pred, edit_id):\n    \"\"\"\n    Edit a word, given edit and seq2seq predictions.\n    \"\"\"\n    if edit_id == 1:\n        return word\n    elif edit_id == 2:\n        return word.lower()\n    elif edit_id == 0:\n        return pred\n    else:\n        raise Exception(\"Unrecognized edit ID: {}\".format(edit_id))\n\n"
  },
  {
    "path": "stanza/models/lemma/scorer.py",
    "content": "\"\"\"\nUtils and wrappers for scoring lemmatizers.\n\"\"\"\n\nimport logging\n\nfrom stanza.models.common.utils import ud_scores\n\nlogger = logging.getLogger('stanza')\n\ndef score(system_conllu_file, gold_conllu_file):\n    \"\"\" Wrapper for lemma scorer. \"\"\"\n    logger.debug(\"Evaluating system file %s vs gold file %s\", system_conllu_file, gold_conllu_file)\n    evaluation = ud_scores(gold_conllu_file, system_conllu_file)\n    el = evaluation[\"Lemmas\"]\n    p, r, f = el.precision, el.recall, el.f1\n    return p, r, f\n\n"
  },
  {
    "path": "stanza/models/lemma/trainer.py",
    "content": "\"\"\"\nA trainer class to handle training and testing of models.\n\"\"\"\n\nimport os\nimport sys\nimport numpy as np\nfrom collections import Counter\nimport logging\nimport torch\nfrom torch import nn\nimport torch.nn.init as init\n\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.doc import TEXT, UPOS\nfrom stanza.models.common.foundation_cache import load_charlm\nfrom stanza.models.common.seq2seq_model import Seq2SeqModel\nfrom stanza.models.common.char_model import CharacterLanguageModelWordAdapter\nfrom stanza.models.common import utils, loss\nfrom stanza.models.lemma import edit\nfrom stanza.models.lemma.vocab import MultiVocab\nfrom stanza.models.lemma_classifier.base_model import LemmaClassifier\n\nlogger = logging.getLogger('stanza')\n\ndef unpack_batch(batch, device):\n    \"\"\" Unpack a batch from the data loader. \"\"\"\n    inputs = [b.to(device) if b is not None else None for b in batch[:6]]\n    orig_idx = batch[6]\n    text = batch[7]\n    return inputs, orig_idx, text\n\nclass Trainer(object):\n    \"\"\" A trainer for training models. 
\"\"\"\n    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None, lemma_classifier_args=None):\n        if model_file is not None:\n            # load everything from file\n            self.load(model_file, args, foundation_cache, lemma_classifier_args)\n        else:\n            # build model from scratch\n            self.args = args\n            if args['dict_only']:\n                self.model = None\n            else:\n                self.model = self.build_seq2seq(args, emb_matrix, foundation_cache)\n            self.vocab = vocab\n            # dict-based components\n            self.word_dict = dict()\n            self.composite_dict = dict()\n            self.contextual_lemmatizers = []\n\n        self.caseless = self.args.get('caseless', False)\n\n        if not self.args['dict_only']:\n            self.model = self.model.to(device)\n            if self.args.get('edit', False):\n                self.crit = loss.MixLoss(self.vocab['char'].size, self.args['alpha']).to(device)\n                logger.debug(\"Running seq2seq lemmatizer with edit classifier...\")\n            else:\n                self.crit = loss.SequenceLoss(self.vocab['char'].size).to(device)\n            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])\n\n    def build_seq2seq(self, args, emb_matrix, foundation_cache):\n        charmodel = None\n        charlms = []\n        if args is not None and args.get('charlm_forward_file', None):\n            charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)\n            charlms.append(charmodel_forward)\n        if args is not None and args.get('charlm_backward_file', None):\n            charmodel_backward = load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache)\n            charlms.append(charmodel_backward)\n        if len(charlms) > 0:\n            charlms = nn.ModuleList(charlms)\n 
           charmodel = CharacterLanguageModelWordAdapter(charlms)\n        model = Seq2SeqModel(args, emb_matrix=emb_matrix, contextual_embedding=charmodel)\n        return model\n\n    def update(self, batch, eval=False):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, text = unpack_batch(batch, device)\n        src, src_mask, tgt_in, tgt_out, pos, edits = inputs\n\n        if eval:\n            self.model.eval()\n        else:\n            self.model.train()\n            self.optimizer.zero_grad()\n        log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos, raw=text)\n        if self.args.get('edit', False):\n            assert edit_logits is not None\n            loss = self.crit(log_probs.view(-1, self.vocab['char'].size), tgt_out.view(-1), \\\n                    edit_logits, edits)\n        else:\n            loss = self.crit(log_probs.view(-1, self.vocab['char'].size), tgt_out.view(-1))\n        loss_val = loss.data.item()\n        if eval:\n            return loss_val\n\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])\n        self.optimizer.step()\n        return loss_val\n\n    def predict(self, batch, beam_size=1, vocab=None):\n        if vocab is None:\n            vocab = self.vocab\n\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, text = unpack_batch(batch, device)\n        src, src_mask, tgt, tgt_mask, pos, edits = inputs\n\n        self.model.eval()\n        batch_size = src.size(0)\n        preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size, raw=text)\n        pred_seqs = [vocab['char'].unmap(ids) for ids in preds] # unmap to tokens\n        pred_seqs = utils.prune_decoded_seqs(pred_seqs)\n        pred_tokens = [\"\".join(seq) for seq in pred_seqs] # join chars to be tokens\n        pred_tokens = utils.unsort(pred_tokens, orig_idx)\n        if self.args.get('edit', 
False):\n            assert edit_logits is not None\n            edits = np.argmax(edit_logits.data.cpu().numpy(), axis=1).reshape([batch_size]).tolist()\n            edits = utils.unsort(edits, orig_idx)\n        else:\n            edits = None\n        return pred_tokens, edits\n\n    def postprocess(self, words, preds, edits=None):\n        \"\"\" Postprocess, mainly for handling edits. \"\"\"\n        assert len(words) == len(preds), \"Lemma predictions must have same length as words.\"\n        edited = []\n        if self.args.get('edit', False):\n            assert edits is not None and len(words) == len(edits)\n            for w, p, e in zip(words, preds, edits):\n                lem = edit.edit_word(w, p, e)\n                edited += [lem]\n        else:\n            edited = preds # do not edit\n        # final sanity check\n        assert len(edited) == len(words)\n        final = []\n        for lem, w in zip(edited, words):\n            if len(lem) == 0 or constant.UNK in lem:\n                final += [w] # invalid prediction, fall back to word\n            else:\n                final += [lem]\n        return final\n\n    def has_contextual_lemmatizers(self):\n        return self.contextual_lemmatizers is not None and len(self.contextual_lemmatizers) > 0\n\n    def predict_contextual(self, sentence_words, sentence_tags, preds):\n        if len(self.contextual_lemmatizers) == 0:\n            return preds\n\n        # reversed so that the first lemmatizer has priority\n        for contextual in reversed(self.contextual_lemmatizers):\n            pred_idx = []\n            pred_sent_words = []\n            pred_sent_tags = []\n            pred_sent_ids = []\n            for sent_id, (words, tags) in enumerate(zip(sentence_words, sentence_tags)):\n                indices = contextual.target_indices(words, tags)\n                for idx in indices:\n                    pred_idx.append(idx)\n                    pred_sent_words.append(words)\n               
     pred_sent_tags.append(tags)\n                    pred_sent_ids.append(sent_id)\n            if len(pred_idx) == 0:\n                continue\n            contextual_predictions = contextual.predict(pred_idx, pred_sent_words, pred_sent_tags)\n            for sent_id, word_id, pred in zip(pred_sent_ids, pred_idx, contextual_predictions):\n                preds[sent_id][word_id] = pred\n        return preds\n\n    def update_contextual_preds(self, doc, preds):\n        \"\"\"\n        Update a flat list of preds with the output of the contextual lemmatizers\n\n        - First, it unflattens the preds based on the lengths of the sentences\n        - Then it uses the contextual lemmatizers\n        - Finally, it reflattens the preds into the format expected by the caller\n        \"\"\"\n        if len(self.contextual_lemmatizers) == 0:\n            return preds\n\n        sentence_words = doc.get([TEXT], as_sentences=True)\n        sentence_tags = doc.get([UPOS], as_sentences=True)\n        sentence_preds = []\n        start_index = 0\n        for sent in sentence_words:\n            end_index = start_index + len(sent)\n            sentence_preds.append(preds[start_index:end_index])\n            start_index += len(sent)\n        preds = self.predict_contextual(sentence_words, sentence_tags, sentence_preds)\n        preds = [lemma for sentence in preds for lemma in sentence]\n        return preds\n\n    def update_lr(self, new_lr):\n        utils.change_lr(self.optimizer, new_lr)\n\n    def train_dict(self, triples, update_word_dict=True):\n        \"\"\"\n        Train a dict lemmatizer given training (word, pos, lemma) triples.\n\n        Can update only the composite_dict (word/pos) in situations where\n        the data might be limited from the tags, such as when adding more\n        words at pipeline time\n        \"\"\"\n        # accumulate counter\n        ctr = Counter()\n        ctr.update([(p[0], p[1], p[2]) for p in triples])\n        # find the most 
frequent mappings\n        for p, _ in ctr.most_common():\n            w, pos, l = p\n            if (w,pos) not in self.composite_dict:\n                self.composite_dict[(w,pos)] = l\n            if update_word_dict and w not in self.word_dict:\n                self.word_dict[w] = l\n        return\n\n    def predict_dict(self, pairs):\n        \"\"\" Predict a list of lemmas using the dict model given (word, pos) pairs. \"\"\"\n        lemmas = []\n        for p in pairs:\n            w, pos = p\n            if self.caseless:\n                w = w.lower()\n            if (w,pos) in self.composite_dict:\n                lemmas += [self.composite_dict[(w,pos)]]\n            elif w in self.word_dict:\n                lemmas += [self.word_dict[w]]\n            else:\n                lemmas += [w]\n        return lemmas\n\n    def skip_seq2seq(self, pairs):\n        \"\"\" Determine if we can skip the seq2seq module when ensembling with the frequency lexicon. \"\"\"\n\n        skip = []\n        for p in pairs:\n            w, pos = p\n            if self.caseless:\n                w = w.lower()\n            if (w,pos) in self.composite_dict:\n                skip.append(True)\n            elif w in self.word_dict:\n                skip.append(True)\n            else:\n                skip.append(False)\n        return skip\n\n    def ensemble(self, pairs, other_preds):\n        \"\"\" Ensemble the dict with statistical model predictions. 
\"\"\"\n        lemmas = []\n        assert len(pairs) == len(other_preds)\n        for p, pred in zip(pairs, other_preds):\n            w, pos = p\n            if self.caseless:\n                w = w.lower()\n            if (w,pos) in self.composite_dict:\n                lemma = self.composite_dict[(w,pos)]\n            elif w in self.word_dict:\n                lemma = self.word_dict[w]\n            else:\n                lemma = pred\n            if lemma is None:\n                lemma = w\n            lemmas.append(lemma)\n        return lemmas\n\n    def save(self, filename, skip_modules=True):\n        model_state = None\n        if self.model is not None:\n            model_state = self.model.state_dict()\n            # skip saving modules like the pretrained charlm\n            if skip_modules:\n                skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]\n                for k in skipped:\n                    del model_state[k]\n        params = {\n            'model': model_state,\n            'dicts': (self.word_dict, self.composite_dict),\n            'vocab': self.vocab.state_dict(),\n            'config': self.args,\n            'contextual': [],\n        }\n        for contextual in self.contextual_lemmatizers:\n            params['contextual'].append(contextual.get_save_dict())\n        save_dir = os.path.split(filename)[0]\n        if save_dir:\n            os.makedirs(os.path.split(filename)[0], exist_ok=True)\n        torch.save(params, filename, _use_new_zipfile_serialization=False)\n        logger.info(\"Model saved to {}\".format(filename))\n\n    def load(self, filename, args, foundation_cache, lemma_classifier_args=None):\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        except BaseException:\n            logger.error(\"Cannot load model from {}\".format(filename))\n            raise\n        self.args = 
checkpoint['config']\n        if args is not None:\n            self.args['charlm_forward_file'] = args.get('charlm_forward_file', self.args['charlm_forward_file'])\n            self.args['charlm_backward_file'] = args.get('charlm_backward_file', self.args['charlm_backward_file'])\n        self.word_dict, self.composite_dict = checkpoint['dicts']\n        if not self.args['dict_only']:\n            self.model = self.build_seq2seq(self.args, None, foundation_cache)\n            # could remove strict=False after rebuilding all models,\n            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False\n            self.model.load_state_dict(checkpoint['model'], strict=False)\n        else:\n            self.model = None\n        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])\n        self.contextual_lemmatizers = []\n        for contextual in checkpoint.get('contextual', []):\n            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual, args=lemma_classifier_args))\n"
  },
  {
    "path": "stanza/models/lemma/vocab.py",
    "content": "from collections import Counter\n\nfrom stanza.models.common.vocab import BaseVocab, BaseMultiVocab\nfrom stanza.models.common.seq2seq_constant import VOCAB_PREFIX\n\nclass Vocab(BaseVocab):\n    def build_vocab(self):\n        counter = Counter(self.data)\n        self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\nclass MultiVocab(BaseMultiVocab):\n    @classmethod\n    def load_state_dict(cls, state_dict):\n        new = cls()\n        for k,v in state_dict.items():\n            new[k] = Vocab.load_state_dict(v)\n        return new\n"
  },
  {
    "path": "stanza/models/lemma_classifier/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/lemma_classifier/base_model.py",
    "content": "\"\"\"\nBase class for the LemmaClassifier types.\n\nVersions include LSTM and Transformer varieties\n\"\"\"\n\nimport logging\n\nfrom abc import ABC, abstractmethod\n\nimport os\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.common.foundation_cache import load_pretrain\nfrom stanza.models.lemma_classifier.constants import ModelType\n\nfrom typing import List\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass LemmaClassifier(ABC, nn.Module):\n    def __init__(self, label_decoder, target_words, target_upos, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self.label_decoder = label_decoder\n        self.label_encoder = {y: x for x, y in label_decoder.items()}\n        self.target_words = target_words\n        self.target_upos = target_upos\n        self.unsaved_modules = []\n\n    def add_unsaved_module(self, name, module):\n        self.unsaved_modules += [name]\n        setattr(self, name, module)\n\n    def is_unsaved_module(self, name):\n        return name.split('.')[0] in self.unsaved_modules\n\n    def save(self, save_name):\n        \"\"\"\n        Save the model to the given path, possibly with some args\n        \"\"\"\n        save_dir = os.path.split(save_name)[0]\n        if save_dir:\n            os.makedirs(save_dir, exist_ok=True)\n        save_dict = self.get_save_dict()\n        torch.save(save_dict, save_name)\n        return save_dict\n\n    @abstractmethod\n    def model_type(self):\n        \"\"\"\n        return a ModelType\n        \"\"\"\n\n    def target_indices(self, words, tags):\n        return [idx for idx, (word, tag) in enumerate(zip(words, tags)) if word.lower() in self.target_words and tag in self.target_upos]\n\n    def predict(self, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[str]]=[]) -> torch.Tensor:\n        upos_tags = self.convert_tags(upos_tags)\n        with torch.no_grad():\n            logits = 
self.forward(position_indices, sentences, upos_tags)  # should be size (batch_size, output_size)\n            predicted_class = torch.argmax(logits, dim=1)  # should be size (batch_size, 1)\n        predicted_class = [self.label_encoder[x.item()] for x in predicted_class]\n        return predicted_class\n\n    @staticmethod\n    def from_checkpoint(checkpoint, args=None):\n        model_type = ModelType[checkpoint['model_type']]\n        if model_type is ModelType.LSTM:\n            # TODO: if anyone can suggest a way to avoid this circular import\n            # (or better yet, avoid the load method knowing about subclasses)\n            # please do so\n            # maybe the subclassing is not necessary and we just put\n            # save & load in the trainer\n            from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM\n\n            saved_args = checkpoint['args']\n            # other model args are part of the model and cannot be changed for evaluation or pipeline\n            # the file paths might be relevant, though\n            keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']\n            for arg in keep_args:\n                if args is not None and args.get(arg, None) is not None:\n                    saved_args[arg] = args[arg]\n\n            # TODO: refactor loading the pretrain (also done in the trainer)\n            pt = load_pretrain(saved_args['wordvec_pretrain_file'])\n\n            use_charlm = saved_args['use_charlm']\n            charlm_forward_file = saved_args.get('charlm_forward_file', None)\n            charlm_backward_file = saved_args.get('charlm_backward_file', None)\n\n            model = LemmaClassifierLSTM(model_args=saved_args,\n                                        output_dim=len(checkpoint['label_decoder']),\n                                        pt_embedding=pt,\n                                        label_decoder=checkpoint['label_decoder'],\n                     
                   upos_to_id=checkpoint['upos_to_id'],\n                                        known_words=checkpoint['known_words'],\n                                        target_words=set(checkpoint['target_words']),\n                                        target_upos=set(checkpoint['target_upos']),\n                                        use_charlm=use_charlm,\n                                        charlm_forward_file=charlm_forward_file,\n                                        charlm_backward_file=charlm_backward_file)\n        elif model_type is ModelType.TRANSFORMER:\n            from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer\n\n            output_dim = len(checkpoint['label_decoder'])\n            saved_args = checkpoint['args']\n            bert_model = saved_args['bert_model']\n            model = LemmaClassifierWithTransformer(model_args=saved_args,\n                                                   output_dim=output_dim,\n                                                   transformer_name=bert_model,\n                                                   label_decoder=checkpoint['label_decoder'],\n                                                   target_words=set(checkpoint['target_words']),\n                                                   target_upos=set(checkpoint['target_upos']))\n        else:\n            raise ValueError(\"Unknown model type %s\" % model_type)\n\n        # strict=False to accommodate missing parameters from the transformer or charlm\n        model.load_state_dict(checkpoint['params'], strict=False)\n        return model\n\n    @staticmethod\n    def load(filename, args=None):\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage)\n        except BaseException:\n            logger.exception(\"Cannot load model from %s\", filename)\n            raise\n\n        logger.debug(\"Loading LemmaClassifier model from %s\", filename)\n\n        return 
LemmaClassifier.from_checkpoint(checkpoint)\n"
  },
  {
    "path": "stanza/models/lemma_classifier/base_trainer.py",
    "content": "\nfrom abc import ABC, abstractmethod\nimport logging\nimport os\nfrom typing import List, Tuple, Any, Mapping\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom stanza.models.common.utils import default_device\nfrom stanza.models.lemma_classifier import utils\nfrom stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE\nfrom stanza.models.lemma_classifier.evaluate_models import evaluate_model\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass BaseLemmaClassifierTrainer(ABC):\n    def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):\n        \"\"\"\n        If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.\n        The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the\n        frequency of the class in the set. E.g. 
classes with lower frequency will have higher weight.\n        \"\"\"\n        weights = [0 for _ in label_decoder.keys()]  # each key in the label decoder is one class, we have one weight per class\n        total_samples = sum(counts.values())\n        for class_idx in counts:\n            weights[class_idx] = total_samples / (counts[class_idx] * len(counts))  # weight_i = total / (# examples in class i * num classes)\n        weights = torch.tensor(weights)\n        logger.info(f\"Using weights {weights} for weighted loss.\")\n        self.criterion = nn.BCEWithLogitsLoss(weight=weights)\n\n    @abstractmethod\n    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):\n        \"\"\"\n        Build a model using pieces of the dataset to determine some of the model shape\n        \"\"\"\n\n    def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:\n        \"\"\"\n        Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.\n\n        Args:\n            num_epochs (int): Number of training epochs\n            save_name (str): Path to file where trained model should be saved.\n            eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.\n            train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.\n        \"\"\"\n        # Put model on GPU (if possible)\n        device = default_device()\n\n        if not train_file:\n            raise ValueError(\"Cannot train model - no train_file supplied!\")\n\n        dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get(\"batch_size\", DEFAULT_BATCH_SIZE))\n        label_decoder = dataset.label_decoder\n        upos_to_id = dataset.upos_to_id\n        self.output_dim = len(label_decoder)\n        logger.info(f\"Loaded 
dataset successfully from {train_file}\")\n        logger.info(f\"Using label decoder: {label_decoder}  Output dimension: {self.output_dim}\")\n        logger.info(f\"Target words: {dataset.target_words}\")\n\n        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words, set(dataset.target_upos))\n        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)\n\n        self.model.to(device)\n        logger.info(f\"Training model on device: {device}. {next(self.model.parameters()).device}\")\n\n        if os.path.exists(save_name) and not args.get('force', False):\n            raise FileExistsError(f\"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...\")\n\n        if self.weighted_loss:\n            self.configure_weighted_loss(label_decoder, dataset.counts)\n\n        # Put the criterion on GPU too\n        logger.debug(f\"Criterion on {next(self.model.parameters()).device}\")\n        self.criterion = self.criterion.to(next(self.model.parameters()).device)\n\n        best_model, best_f1 = None, float(\"-inf\")  # Used for saving checkpoints of the model\n        for epoch in range(num_epochs):\n            # go over entire dataset with each epoch\n            for sentences, positions, upos_tags, labels in tqdm(dataset):\n                assert len(sentences) == len(positions) == len(labels), f\"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})\"\n\n                self.optimizer.zero_grad()\n                outputs = self.model(positions, sentences, upos_tags)\n\n                # Compute loss, which is different if using CE or BCEWithLogitsLoss\n                if self.weighted_loss:  # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.\n                    # TODO: three classes?\n                    targets = torch.stack([torch.tensor([1, 
0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)\n                    # should be shape size (batch_size, 2)\n                else:  # CELoss accepts target as just raw label\n                    targets = labels.to(device)\n\n                loss = self.criterion(outputs, targets)\n\n                loss.backward()\n                self.optimizer.step()\n\n            logger.info(f\"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}\")\n            if eval_file:\n                # Evaluate model on dev set to see if it should be saved.\n                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)\n                logger.info(f\"Weighted f1 for model: {f1}\")\n                if f1 > best_f1:\n                    best_f1 = f1\n                    self.model.save(save_name)\n                    logger.info(f\"New best model: weighted f1 score of {f1}.\")\n            else:\n                self.model.save(save_name)\n\n"
  },
  {
    "path": "stanza/models/lemma_classifier/baseline_model.py",
    "content": "\"\"\"\nBaseline model for the existing lemmatizer which always predicts \"be\" and never \"have\" on the \"'s\" token.\n\nThe BaselineModel class can be updated to any arbitrary token and predicton lemma, not just \"be\" on the \"s\" token.\n\"\"\"\n\nimport stanza\nimport os\nfrom stanza.models.lemma_classifier.evaluate_models import evaluate_sequences\nfrom stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file\n\nclass BaselineModel:\n\n    def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos):\n        self.token_to_lemmatize = token_to_lemmatize\n        self.prediction_lemma = prediction_lemma\n        self.prediction_upos = prediction_upos\n\n    def predict(self, token):\n        if token == self.token_to_lemmatize:\n            return self.prediction_lemma\n\n    def evaluate(self, conll_path):\n        \"\"\"\n        Evaluates the baseline model against the test set defined in conll_path.\n\n        Returns a map where the keys are each class and the values are another map including the precision, recall and f1 scores\n        for that class.\n\n        Also returns confusion matrix. 
Keys are gold tags and inner keys are predicted tags\n        \"\"\"\n        doc = load_doc_from_conll_file(conll_path)\n        gold_tags, pred_tags = [], []\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                if word.upos in self.prediction_upos and word.text == self.token_to_lemmatize:\n                    gold_tags.append(word.lemma)\n                    pred_tags.append(self.prediction_lemma)\n\n        # evaluate_sequences expects flat lists of hashable tags plus a label decoder.\n        # The tags here are already lemma strings, so each lemma maps to itself.\n        label_decoder = {lemma: lemma for lemma in set(gold_tags) | set(pred_tags)}\n        multiclass_result, confusion_mtx, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, label_decoder)\n        return multiclass_result, confusion_mtx\n\n\nif __name__ == \"__main__\":\n\n    bl_model = BaselineModel(\"'s\", \"be\", [\"AUX\"])\n    coNLL_path = os.path.join(os.path.dirname(__file__), \"en_gum-ud-train.conllu\")\n    bl_model.evaluate(coNLL_path)\n\n"
  },
  {
    "path": "stanza/models/lemma_classifier/constants.py",
    "content": "from enum import Enum\n\nUNKNOWN_TOKEN = \"unk\"  # token name for unknown tokens\nUNKNOWN_TOKEN_IDX = -1   # custom index we apply to unknown tokens\n\n# TODO: ModelType could just be LSTM and TRANSFORMER\n# and then the transformer baseline would have the transformer as another argument\nclass ModelType(Enum):\n    LSTM               = 1\n    TRANSFORMER        = 2\n    BERT               = 3\n    ROBERTA            = 4\n\nDEFAULT_BATCH_SIZE = 16"
  },
  {
    "path": "stanza/models/lemma_classifier/evaluate_many.py",
    "content": "\"\"\"\nUtils to evaluate many models of the same type at once\n\"\"\"\nimport argparse\nimport os\nimport logging\n\nfrom stanza.models.lemma_classifier.evaluate_models import main as evaluate_main\n\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\ndef evaluate_n_models(path_to_models_dir, args):\n\n    total_results = {\n        \"be\": 0.0,\n        \"have\": 0.0,\n        \"accuracy\": 0.0,\n        \"weighted_f1\": 0.0\n    }\n    paths = os.listdir(path_to_models_dir)\n    num_models = len(paths)\n    for model_path in paths:\n        full_path = os.path.join(path_to_models_dir, model_path)\n        args.save_name = full_path\n        mcc_results, confusion, acc, weighted_f1 = evaluate_main(predefined_args=args)\n\n        for lemma in mcc_results:\n\n            lemma_f1 = mcc_results.get(lemma, None).get(\"f1\") * 100\n            total_results[lemma] += lemma_f1\n\n        total_results[\"accuracy\"] += acc\n        total_results[\"weighted_f1\"] += weighted_f1\n\n    total_results[\"be\"] /= num_models\n    total_results[\"have\"] /= num_models\n    total_results[\"accuracy\"] /= num_models\n    total_results[\"weighted_f1\"] /= num_models\n\n    logger.info(f\"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\\nLemma 'be' had f1: {total_results['be']}\\nLemma 'have' had f1: {total_results['have']}.\\nAccuracy: {100 * total_results['accuracy']}.\\n ({num_models} models evaluated).\")\n    return total_results\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--vocab_size\", type=int, default=10000, help=\"Number of tokens in vocab\")\n    parser.add_argument(\"--embedding_dim\", type=int, default=100, help=\"Number of dimensions in word embeddings (currently using GloVe)\")\n    parser.add_argument(\"--hidden_dim\", type=int, default=256, help=\"Size of hidden layer\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, 
help='Exact name of the pretrain file to read')\n    parser.add_argument(\"--charlm\", action='store_true', default=False, help=\"Whether to use the charlm embeddings\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument(\"--charlm_forward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_forward.pt\"), help=\"Path to forward charlm file\")\n    parser.add_argument(\"--charlm_backward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_backwards.pt\"), help=\"Path to backward charlm file\")\n    parser.add_argument(\"--save_name\", type=str, default=os.path.join(os.path.dirname(__file__), \"saved_models\", \"lemma_classifier_model.pt\"), help=\"Path to model save file\")\n    parser.add_argument(\"--model_type\", type=str, default=\"roberta\", help=\"Which model type to use ('bert', 'roberta', or 'lstm')\")\n    parser.add_argument(\"--bert_model\", type=str, default=None, help=\"Use a specific transformer instead of the default bert/roberta\")\n    parser.add_argument(\"--eval_file\", type=str, help=\"path to evaluation file\")\n\n    # Args specific to several model eval\n    parser.add_argument(\"--base_path\", type=str, default=None, help=\"path to dir for eval\")\n\n    args = parser.parse_args()\n    evaluate_n_models(args.base_path, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/lemma_classifier/evaluate_models.py",
    "content": "import os\nimport sys\n\nparentdir = os.path.dirname(__file__)\nparentdir = os.path.dirname(parentdir)\nparentdir = os.path.dirname(parentdir)\nsys.path.append(parentdir)\n\nimport logging\nimport argparse\nimport os\n\nfrom typing import Any, List, Tuple, Mapping\nfrom collections import defaultdict\nfrom numpy import random\n\nimport torch\nimport torch.nn as nn\n\nimport stanza\n\nfrom stanza.models.common.utils import default_device\nfrom stanza.models.lemma_classifier import utils\nfrom stanza.models.lemma_classifier.base_model import LemmaClassifier\nfrom stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM\nfrom stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer\nfrom stanza.utils.confusion import format_confusion\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\n\ndef get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float:\n    \"\"\"\n    Computes the weighted F1 score across an evaluation set.\n\n    The weight of a class's F1 score is equal to the number of examples in evaluation. This makes classes that have more\n    examples in the evaluation more impactful to the weighted f1.\n    \"\"\"\n    num_total_examples = 0\n    weighted_f1 = 0\n\n    for class_id in mcc_results:\n        class_f1 = mcc_results.get(class_id).get(\"f1\")\n        num_class_examples = sum(confusion.get(class_id).values())\n        weighted_f1 += class_f1 * num_class_examples\n        num_total_examples += num_class_examples\n\n    return weighted_f1 / num_total_examples\n\n\ndef evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):\n    \"\"\"\n    Evaluates a model's predicted tags against a set of gold tags. 
Computes precision, recall, and f1 for all classes.\n\n    Precision = true positives / true positives + false positives\n    Recall = true positives / true positives + false negatives\n    F1 = 2 * (Precision * Recall) / (Precision + Recall)\n\n    Returns:\n        1. Multi class result dictionary, where each class is a key and maps to another map of its F1, precision, and recall scores.\n           e.g. multiclass_results[0][\"precision\"] would give class 0's precision.\n        2. Confusion matrix, where each key is a gold tag and its value is another map with a key of the predicted tag with value of that (gold, pred) count.\n           e.g. confusion[0][1] = 6 would mean that for gold tag 0, the model predicted tag 1 a total of 6 times.\n    \"\"\"\n    assert len(gold_tag_sequences) == len(pred_tag_sequences), \\\n    f\"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}\"\n\n    confusion = defaultdict(lambda: defaultdict(int))\n\n    reverse_label_decoder = {y: x for x, y in label_decoder.items()}\n    for gold, pred in zip(gold_tag_sequences, pred_tag_sequences):\n        confusion[reverse_label_decoder[gold]][reverse_label_decoder[pred]] += 1\n\n    multi_class_result = defaultdict(lambda: defaultdict(float))\n    # compute precision, recall and f1 for each class and store inside of `multi_class_result`\n    for gold_tag in confusion.keys():\n\n        try:\n            prec = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum([confusion.get(k, {}).get(gold_tag, 0) for k in confusion.keys()])\n        except ZeroDivisionError:\n            prec = 0.0\n\n        try:\n            recall = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum(confusion.get(gold_tag, {}).values())\n        except ZeroDivisionError:\n            recall = 0.0\n\n        try:\n            f1 = 2 * (prec * recall) / (prec + recall)\n        except ZeroDivisionError:\n            f1 = 0.0\n\n        
multi_class_result[gold_tag] = {\n            \"precision\": prec,\n            \"recall\": recall,\n            \"f1\": f1\n        }\n\n    if verbose:\n        for lemma in multi_class_result:\n            logger.info(f\"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}\")\n\n    weighted_f1 = get_weighted_f1(multi_class_result, confusion)\n\n    return multi_class_result, confusion, weighted_f1\n\n\ndef model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]]=[]) -> torch.Tensor:\n    \"\"\"\n    A LemmaClassifierLSTM or LemmaClassifierWithTransformer is used to predict on a single text example, given the position index of the target token.\n\n    Args:\n        model (LemmaClassifier): A trained LemmaClassifier that is able to predict on a target token.\n        position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in `text` for each example in the batch.\n        sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences.\n\n    Returns:\n        (int): The index of the predicted class in `model`'s output.\n    \"\"\"\n    with torch.no_grad():\n        logits = model(position_indices, sentences, upos_tags)  # should be size (batch_size, output_size)\n        predicted_class = torch.argmax(logits, dim=1)  # should be size (batch_size, 1)\n\n    return predicted_class\n\n\ndef evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:\n    \"\"\"\n    Helper function for model evaluation\n\n    Args:\n        model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): An instance of the LemmaClassifier class that has architecture initialized which matches the model saved in `model_path`.\n        model_path (str): 
Path to the saved model weights that will be loaded into `model`.\n        eval_path (str): Path to the saved evaluation dataset.\n        verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True.\n        is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode.\n\n    Returns:\n        1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is\n                                                                    another map with key of \"f1\", \"precision\", or \"recall\" with corresponding values.\n        2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the\n                                                               map with the key as the predicted tag and corresponding count of that (gold, pred) pair.\n        3. 
Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set.\n        4. Weighted F1 (float): the F1 score averaged over classes, weighted by the number of gold examples of each class.\n    \"\"\"\n    # move the model to the default device\n    device = default_device()\n    model.to(device)\n\n    if not is_training:\n        model.eval()  # set to eval mode\n\n    # load in eval data\n    dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)\n\n    logger.info(f\"Evaluating on evaluation file {eval_path}\")\n\n    correct, total = 0, 0\n    gold_tags, pred_tags = dataset.labels, []\n\n    # run eval on each batch from dataset\n    for sentences, pos_indices, upos_tags, labels in tqdm(dataset, \"Evaluating examples from data file\"):\n        pred = model_predict(model, pos_indices, sentences, upos_tags)  # Pred should be size (batch_size, )\n        correct_preds = pred == labels.to(device)\n        correct += torch.sum(correct_preds)\n        total += len(correct_preds)\n        pred_tags += pred.tolist()\n\n    logger.info(\"Finished evaluating on dataset. Computing scores...\")\n    accuracy = correct / total\n\n    # gold_tags and pred_tags are flat lists with one label id per example\n    mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)\n    if verbose:\n        logger.info(f\"Accuracy: {accuracy} ({correct}/{total})\")\n        logger.info(f\"Label decoder: {dataset.label_decoder}\")\n\n    return mc_results, confusion, accuracy, weighted_f1\n\n\ndef main(args=None, predefined_args=None):\n\n    # TODO: can unify this script with train_lstm_model.py?\n    # TODO: can save the model type in the model .pt, then\n    # automatically figure out what type of model we are using by\n    # looking in the file\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--vocab_size\", type=int, default=10000, help=\"Number of tokens in vocab\")\n    parser.add_argument(\"--embedding_dim\", type=int, default=100, 
help=\"Number of dimensions in word embeddings (currently using GloVe)\")\n    parser.add_argument(\"--hidden_dim\", type=int, default=256, help=\"Size of hidden layer\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument(\"--charlm\", action='store_true', default=False, help=\"Whether or not to use the charlm embeddings\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument(\"--charlm_forward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_forward.pt\"), help=\"Path to forward charlm file\")\n    parser.add_argument(\"--charlm_backward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_backwards.pt\"), help=\"Path to backward charlm file\")\n    parser.add_argument(\"--save_name\", type=str, default=os.path.join(os.path.dirname(__file__), \"saved_models\", \"lemma_classifier_model.pt\"), help=\"Path to model save file\")\n    parser.add_argument(\"--model_type\", type=str, default=\"roberta\", help=\"Which model type to use ('bert', 'roberta', or 'lstm')\")\n    parser.add_argument(\"--bert_model\", type=str, default=None, help=\"Use a specific transformer instead of the default bert/roberta\")\n    parser.add_argument(\"--eval_file\", type=str, help=\"Path to evaluation file\")\n\n    args = parser.parse_args(args) if not predefined_args else predefined_args\n\n    logger.info(\"Running evaluation script with the following args:\")\n    args = vars(args)\n    for arg in args:\n        logger.info(f\"{arg}: {args[arg]}\")\n    logger.info(\"------------------------------------------------------------\")\n\n    logger.info(f\"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}\")\n    model = LemmaClassifier.load(args['save_name'], 
args)\n\n    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file'])\n\n    logger.info(f\"Multi-class results: {dict(mcc_results)}\")\n    logger.info(\"______________________________________________\")\n    logger.info(\"Confusion:\\n%s\", format_confusion(confusion))\n    logger.info(\"______________________________________________\")\n    logger.info(f\"Accuracy: {acc}\")\n    logger.info(\"______________________________________________\")\n    logger.info(f\"Weighted f1: {weighted_f1}\")\n\n    return mcc_results, confusion, acc, weighted_f1\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/lemma_classifier/lstm_model.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport logging\nimport math\nfrom torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence\nfrom stanza.models.common.char_model import CharacterModel, CharacterLanguageModel\nfrom typing import List, Tuple\n\nfrom stanza.models.common.vocab import UNK_ID\nfrom stanza.models.lemma_classifier import utils\nfrom stanza.models.lemma_classifier.base_model import LemmaClassifier\nfrom stanza.models.lemma_classifier.constants import ModelType\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass LemmaClassifierLSTM(LemmaClassifier):\n    \"\"\"\n    Model architecture:\n        Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding.\n        From the LSTM output, we get the embedding of the specific token that we classify on. That embedding\n        is fed into an MLP for classification.\n    \"\"\"\n    def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,\n                 use_charlm=False, charlm_forward_file=None, charlm_backward_file=None):\n        \"\"\"\n        Args:\n            vocab_size (int): Size of the vocab being used (if custom vocab)\n            output_dim (int): Size of output vector from MLP layer\n            upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs\n            pt_embedding (Pretrain): pretrained embeddings\n            known_words (list(str)): Words which are in the training data\n            target_words (set(str)): a set of the words which might need lemmatization\n            use_charlm (bool): Whether or not to use the charlm embeddings\n            charlm_forward_file (str): The path to the forward pass model for the character language model\n            charlm_backward_file (str): The path to the backward pass model for the character language model.\n\n        Kwargs:\n      
      upos_emb_dim (int): The size of the UPOS tag embeddings\n            num_heads (int): The number of heads to use for attention. If there are more than 0 heads, attention will be used instead of the LSTM.\n\n        Raises:\n            FileNotFoundError: if the forward or backward charlm file cannot be found.\n        \"\"\"\n        super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words, target_upos)\n        self.model_args = model_args\n\n        self.hidden_dim = model_args['hidden_dim']\n        self.input_size = 0\n        self.num_heads = self.model_args['num_heads']\n\n        emb_matrix = pt_embedding.emb\n        self.add_unsaved_module(\"embeddings\", nn.Embedding.from_pretrained(emb_matrix, freeze=True))\n        self.vocab_map = { word.replace('\\xa0', ' '): i for i, word in enumerate(pt_embedding.vocab) }\n        self.vocab_size = emb_matrix.shape[0]\n        self.embedding_dim = emb_matrix.shape[1]\n\n        self.known_words = known_words\n        self.known_word_map = {word: idx for idx, word in enumerate(known_words)}\n        self.delta_embedding = nn.Embedding(num_embeddings=len(known_words)+1,\n                                            embedding_dim=self.embedding_dim,\n                                            padding_idx=0)\n        nn.init.normal_(self.delta_embedding.weight, std=0.01)\n\n        self.input_size += self.embedding_dim\n\n        # Optionally, include charlm embeddings\n        self.use_charlm = use_charlm\n\n        if self.use_charlm:\n            if charlm_forward_file is None or not os.path.exists(charlm_forward_file):\n                raise FileNotFoundError(f'Could not find forward character model: {charlm_forward_file}')\n            if charlm_backward_file is None or not os.path.exists(charlm_backward_file):\n                raise FileNotFoundError(f'Could not find backward character model: {charlm_backward_file}')\n            self.add_unsaved_module('charmodel_forward', 
CharacterLanguageModel.load(charlm_forward_file, finetune=False))\n            self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(charlm_backward_file, finetune=False))\n\n            self.input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()\n\n        self.upos_emb_dim = self.model_args[\"upos_emb_dim\"]\n        self.upos_to_id = upos_to_id\n        if self.upos_emb_dim > 0 and self.upos_to_id is not None:\n            # TODO: should leave space for unknown POS?\n            self.upos_emb = nn.Embedding(num_embeddings=len(self.upos_to_id),\n                                         embedding_dim=self.upos_emb_dim,\n                                         padding_idx=0)\n            self.input_size += self.upos_emb_dim\n\n        device = next(self.parameters()).device\n        # Determine if attn or LSTM should be used\n        if self.num_heads > 0:\n            self.input_size = utils.round_up_to_multiple(self.input_size, self.num_heads)\n            self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_size, num_heads=self.num_heads, batch_first=True).to(device)\n            logger.debug(f\"Using attention mechanism with embed dim {self.input_size} and {self.num_heads} attention heads.\")\n        else:\n            self.lstm = nn.LSTM(self.input_size,\n                                self.hidden_dim,\n                                batch_first=True,\n                                bidirectional=True)\n            logger.debug(f\"Using LSTM mechanism.\")\n\n        mlp_input_size = self.hidden_dim * 2 if self.num_heads == 0 else self.input_size\n        self.mlp = nn.Sequential(\n            nn.Linear(mlp_input_size, 64),\n            nn.ReLU(),\n            nn.Linear(64, output_dim)\n        )\n\n    def get_save_dict(self):\n        save_dict = {\n            \"params\": self.state_dict(),\n            \"label_decoder\": self.label_decoder,\n            \"model_type\": 
self.model_type().name,\n            \"args\": self.model_args,\n            \"upos_to_id\": self.upos_to_id,\n            \"known_words\": self.known_words,\n            \"target_words\": list(self.target_words),\n            \"target_upos\": list(self.target_upos),\n        }\n        skipped = [k for k in save_dict[\"params\"].keys() if self.is_unsaved_module(k)]\n        for k in skipped:\n            del save_dict[\"params\"][k]\n        return save_dict\n\n    def convert_tags(self, upos_tags: List[List[str]]):\n        if self.upos_to_id is not None:\n            return [[self.upos_to_id[x] for x in sentence] for sentence in upos_tags]\n        return None\n\n    def forward(self, pos_indices: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):\n        \"\"\"\n        Computes the forward pass of the neural net\n\n        Args:\n            pos_indices (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.\n            sentences (List[List[str]]): A list of the token-split sentences of the input data.\n            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence.\n\n        Returns:\n            torch.tensor: Output logits of the neural network, where the shape is  (n, output_size) where n is the number of sentences.\n        \"\"\"\n        device = next(self.parameters()).device\n        batch_size = len(sentences)\n        token_ids = []\n        delta_token_ids = []\n        for words in sentences:\n            sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words]\n            sentence_token_ids = torch.tensor(sentence_token_ids, device=device)\n            token_ids.append(sentence_token_ids)\n\n            sentence_delta_token_ids = [self.known_word_map.get(word.lower(), 0) for word in words]\n            sentence_delta_token_ids = torch.tensor(sentence_delta_token_ids, device=device)\n            
delta_token_ids.append(sentence_delta_token_ids)\n\n        token_ids = pad_sequence(token_ids, batch_first=True)\n        delta_token_ids = pad_sequence(delta_token_ids, batch_first=True)\n        embedded = self.embeddings(token_ids) + self.delta_embedding(delta_token_ids)\n\n        if self.upos_emb_dim > 0:\n            upos_tags = [torch.tensor(sentence_tags) for sentence_tags in upos_tags]  # convert internal lists to tensors\n            upos_tags = pad_sequence(upos_tags, batch_first=True, padding_value=0).to(device)\n            pos_emb = self.upos_emb(upos_tags)\n            embedded = torch.cat((embedded, pos_emb), 2).to(device)\n\n        if self.use_charlm:\n            char_reps_forward = self.charmodel_forward.build_char_representation(sentences)  # takes [[str]]\n            char_reps_backward = self.charmodel_backward.build_char_representation(sentences)\n\n            char_reps_forward = pad_sequence(char_reps_forward, batch_first=True)\n            char_reps_backward = pad_sequence(char_reps_backward, batch_first=True)\n\n            embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2)\n\n        if self.num_heads > 0:\n\n            def positional_encoding(seq_len, d_model, device):\n                encoding = torch.zeros(seq_len, d_model, device=device)\n                position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)\n                div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)\n\n                encoding[:, 0::2] = torch.sin(position * div_term)\n                encoding[:, 1::2] = torch.cos(position * div_term)\n\n                # Add a new dimension to fit the batch size\n                encoding = encoding.unsqueeze(0)\n                return encoding\n\n            seq_len, d_model = embedded.shape[1], embedded.shape[2]\n            pos_enc = positional_encoding(seq_len, d_model, device=device)\n\n            embedded += 
pos_enc.expand_as(embedded)\n\n        padded_sequences = pad_sequence(embedded, batch_first=True)\n        lengths = torch.tensor([len(seq) for seq in embedded])\n\n        if self.num_heads > 0:\n            target_seq_length, src_seq_length = padded_sequences.size(1), padded_sequences.size(1)\n            attn_mask = torch.triu(torch.ones(batch_size * self.num_heads, target_seq_length, src_seq_length, dtype=torch.bool), diagonal=1)\n\n            attn_mask = attn_mask.view(batch_size, self.num_heads, target_seq_length, src_seq_length)\n            attn_mask = attn_mask.repeat(1, 1, 1, 1).view(batch_size * self.num_heads, target_seq_length, src_seq_length).to(device)\n\n            attn_output, attn_weights = self.multihead_attn(padded_sequences, padded_sequences, padded_sequences, attn_mask=attn_mask)\n            # Extract the hidden state at the index of the token to classify\n            token_reps = attn_output[torch.arange(attn_output.size(0)), pos_indices]\n\n        else:\n            packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True)\n            lstm_out, (hidden, _) = self.lstm(packed_sequences)\n            # Extract the hidden state at the index of the token to classify\n            unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True)\n            token_reps = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices]\n\n        # MLP forward pass\n        output = self.mlp(token_reps)\n        return output\n\n    def model_type(self):\n        return ModelType.LSTM\n"
  },
  {
    "path": "stanza/models/lemma_classifier/prepare_dataset.py",
    "content": "import argparse\nimport json\nimport os\nimport re\n\nimport stanza\nfrom stanza.models.lemma_classifier import utils\n\nfrom typing import List, Tuple, Any\n\n\"\"\"\nThe code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain the target token.\nFurthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma.\n\"\"\"\n\n\ndef load_doc_from_conll_file(path: str):\n    \"\"\"\"\n    loads in a Stanza document object from a path to a CoNLL file containing annotated sentences.\n    \"\"\"\n    return stanza.utils.conll.CoNLL.conll2doc(path)\n\n\nclass DataProcessor():\n\n    def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str):\n        self.target_word = target_word\n        self.target_word_regex = re.compile(target_word)\n        self.target_upos = target_upos\n        self.allowed_lemmas = re.compile(allowed_lemmas)\n\n    def keep_sentence(self, sentence):\n        for word in sentence.words:\n            if self.target_word_regex.fullmatch(word.text) and word.upos in self.target_upos:\n                return True\n        return False\n\n    def find_all_occurrences(self, sentence) -> List[int]:\n        \"\"\"\n        Finds all occurrences of self.target_word in tokens and returns the index(es) of such occurrences.\n        \"\"\"\n        occurrences = []\n        for idx, token in enumerate(sentence.words):\n            if self.target_word_regex.fullmatch(token.text) and token.upos in self.target_upos:\n                occurrences.append(idx)\n        return occurrences\n\n    @staticmethod\n    def write_output_file(save_name, target_upos, sentences):\n        with open(save_name, \"w+\", encoding=\"utf-8\") as output_f:\n            output_f.write(\"{\\n\")\n            output_f.write('  \"upos\": %s,\\n' % json.dumps(target_upos))\n            output_f.write('  \"sentences\": [')\n            
wrote_sentence = False\n            for sentence in sentences:\n                if not wrote_sentence:\n                    output_f.write(\"\\n    \")\n                    wrote_sentence = True\n                else:\n                    output_f.write(\",\\n    \")\n                output_f.write(json.dumps(sentence))\n            output_f.write(\"\\n  ]\\n}\\n\")\n\n    def process_document(self, doc, save_name: str) -> None:\n        \"\"\"\n        Takes any sentence from `doc` that meets the condition of `keep_sentence` and writes its tokens, index of target word, and lemma to `save_name`\n\n        Sentences that meet `keep_sentence` and contain `self.target_word` multiple times have each instance in a different example in the output file.\n\n        Args:\n            doc (Stanza.doc): Document object that represents the file to be analyzed\n            save_name (str): Path to the file for storing output\n        \"\"\"\n        sentences = []\n        for sentence in doc.sentences:\n            # for each sentence, we need to determine if it should be added to the output file.\n            # if the sentence fulfills keep_sentence, then we will save it along with the target word's index and its corresponding lemma\n            if self.keep_sentence(sentence):\n                tokens = [token.text for token in sentence.words]\n                indexes = self.find_all_occurrences(sentence)\n                for idx in indexes:\n                    if self.allowed_lemmas.fullmatch(sentence.words[idx].lemma):\n                        # for each example found, we write the tokens,\n                        # their respective upos tags, the target token index,\n                        # and the target lemma\n                        upos_tags = [sentence.words[i].upos for i in range(len(sentence.words))]\n                        num_tokens = len(upos_tags)\n                        sentences.append({\n                            \"words\": tokens,\n                   
         \"upos_tags\": upos_tags,\n                            \"index\": idx,\n                            \"lemma\": sentence.words[idx].lemma\n                        })\n\n        if save_name:\n            self.write_output_file(save_name, self.target_upos, sentences)\n        return sentences\n\ndef main(args=None):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--conll_path\", type=str, default=os.path.join(os.path.dirname(__file__), \"en_gum-ud-train.conllu\"), help=\"path to the conll file to translate\")\n    parser.add_argument(\"--target_word\", type=str, default=\"'s\", help=\"Token to classify on, e.g. 's.\")\n    parser.add_argument(\"--target_upos\", type=str, default=\"AUX\", help=\"upos on target token\")\n    parser.add_argument(\"--output_path\", type=str, default=\"test_output.txt\", help=\"Path for output file\")\n    parser.add_argument(\"--allowed_lemmas\", type=str, default=\".*\", help=\"A regex for allowed lemmas.  If not set, all lemmas are allowed\")\n\n    args = parser.parse_args(args)\n\n    conll_path = args.conll_path\n    target_upos = args.target_upos\n    output_path = args.output_path\n    allowed_lemmas = args.allowed_lemmas\n\n    args = vars(args)\n    for arg in args:\n        print(f\"{arg}: {args[arg]}\")\n\n    doc = load_doc_from_conll_file(conll_path)\n    processor = DataProcessor(target_word=args['target_word'], target_upos=[target_upos], allowed_lemmas=allowed_lemmas)\n\n    return processor.process_document(doc, output_path)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/lemma_classifier/train_lstm_model.py",
    "content": "\"\"\"\nThe code in this file works to train a lemma classifier for 's\n\"\"\"\n\nimport argparse\nimport logging\nimport os\n\nimport torch\nimport torch.nn as nn\n\nfrom stanza.models.common.foundation_cache import load_pretrain\nfrom stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer\nfrom stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE\nfrom stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass LemmaClassifierTrainer(BaseLemmaClassifierTrainer):\n    \"\"\"\n    Class to assist with training a LemmaClassifierLSTM\n    \"\"\"\n\n    def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None):\n        \"\"\"\n        Initializes the LemmaClassifierTrainer class.\n\n        Args:\n            model_args (dict): Various model shape parameters\n            embedding_file (str): What word embeddings file to use.  Use a Stanza pretrain .pt\n            use_charlm (bool, optional): Whether to use charlm embeddings as well. 
Defaults to False.\n            charlm_forward_file (str): Path to the forward pass embeddings for the charlm\n            charlm_backward_file (str): Path to the backward pass embeddings for the charlm\n            upos_emb_dim (int): The dimension size of UPOS tag embeddings\n            num_heads (int): The number of attention heads to use.\n            lr (float): Learning rate, defaults to 0.001.\n            loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce')\n\n        Raises:\n            FileNotFoundError: If the forward charlm file is not present\n            FileNotFoundError: If the backward charlm file is not present\n        \"\"\"\n        super().__init__()\n\n        self.model_args = model_args\n\n        # Load word embeddings\n        pt = load_pretrain(embedding_file)\n        self.pt_embedding = pt\n\n        # Load CharLM embeddings\n        if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file):\n            raise FileNotFoundError(f\"Could not find forward charlm file: {charlm_forward_file}\")\n        if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file):\n            raise FileNotFoundError(f\"Could not find backward charlm file: {charlm_backward_file}\")\n\n        # TODO: just pass around the args instead\n        self.use_charlm = use_charlm\n        self.charlm_forward_file = charlm_forward_file\n        self.charlm_backward_file = charlm_backward_file\n        self.lr = lr\n\n        # Find loss function\n        if loss_func == \"ce\":\n            self.criterion = nn.CrossEntropyLoss()\n            self.weighted_loss = False\n            logger.debug(\"Using CE loss\")\n        elif loss_func == \"weighted_bce\":\n            self.criterion = nn.BCEWithLogitsLoss()\n            self.weighted_loss = True  # used to add weights during train time.\n            logger.debug(\"Using Weighted BCE loss\")\n        else:\n            
raise ValueError(\"Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')\")\n\n    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):\n        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,\n                                   use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hidden_dim\", type=int, default=256, help=\"Size of hidden layer\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), \"pretrain\", \"glove.pt\"), help='Exact name of the pretrain file to read')\n    parser.add_argument(\"--charlm\", action='store_true', dest='use_charlm', default=False, help=\"Whether or not to use the charlm embeddings\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument(\"--charlm_forward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_forward.pt\"), help=\"Path to forward charlm file\")\n    parser.add_argument(\"--charlm_backward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_backwards.pt\"), help=\"Path to backward charlm file\")\n    parser.add_argument(\"--upos_emb_dim\", type=int, default=20, help=\"Dimension size for UPOS tag embeddings.\")\n    parser.add_argument(\"--use_attn\", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')\n    parser.add_argument(\"--num_heads\", type=int, default=0, help=\"Number of heads to use for multihead attention.\")\n    parser.add_argument(\"--save_name\", type=str, 
default=os.path.join(os.path.dirname(__file__), \"saved_models\", \"lemma_classifier_model_weighted_loss_charlm_new.pt\"), help=\"Path to model save file\")\n    parser.add_argument(\"--lr\", type=float, default=0.001, help=\"learning rate\")\n    parser.add_argument(\"--num_epochs\", type=float, default=10, help=\"Number of training epochs\")\n    parser.add_argument(\"--batch_size\", type=int, default=DEFAULT_BATCH_SIZE, help=\"Number of examples to include in each batch\")\n    parser.add_argument(\"--train_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"data\", \"processed_ud_en\", \"combined_train.txt\"), help=\"Full path to training file\")\n    parser.add_argument(\"--weighted_loss\", action='store_true', dest='weighted_loss', default=False, help=\"Whether to use weighted loss during training.\")\n    parser.add_argument(\"--eval_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"data\", \"processed_ud_en\", \"combined_dev.txt\"), help=\"Path to dev file used to evaluate model for saves\")\n    parser.add_argument(\"--force\", action='store_true', default=False, help='Whether or not to clobber an existing save file')\n    return parser\n\ndef main(args=None, predefined_args=None):\n    parser = build_argparse()\n    args = parser.parse_args(args) if predefined_args is None else predefined_args\n\n    wordvec_pretrain_file = args.wordvec_pretrain_file\n    use_charlm = args.use_charlm\n    charlm_forward_file = args.charlm_forward_file\n    charlm_backward_file = args.charlm_backward_file\n    upos_emb_dim = args.upos_emb_dim\n    use_attention = args.attn\n    num_heads = args.num_heads\n    save_name = args.save_name\n    lr = args.lr\n    num_epochs = args.num_epochs\n    train_file = args.train_file\n    weighted_loss = args.weighted_loss\n    eval_file = args.eval_file\n\n    args = vars(args)\n\n    if os.path.exists(save_name) and not args.get('force', False):\n        raise FileExistsError(f\"Save name {save_name} 
already exists. Training would overwrite existing data. Aborting...\")\n    if not os.path.exists(train_file):\n        raise FileNotFoundError(f\"Training file {train_file} not found. Try again with a valid path.\")\n\n    logger.info(\"Running training script with the following args:\")\n    for arg in args:\n        logger.info(f\"{arg}: {args[arg]}\")\n    logger.info(\"------------------------------------------------------------\")\n\n    trainer = LemmaClassifierTrainer(model_args=args,\n                                     embedding_file=wordvec_pretrain_file,\n                                     use_charlm=use_charlm,\n                                     charlm_forward_file=charlm_forward_file,\n                                     charlm_backward_file=charlm_backward_file,\n                                     lr=lr,\n                                     loss_func=\"weighted_bce\" if weighted_loss else \"ce\",\n                                     )\n\n    trainer.train(\n        num_epochs=num_epochs, save_name=save_name, args=args, eval_file=eval_file, train_file=train_file\n    )\n\n    return trainer\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/models/lemma_classifier/train_many.py",
    "content": "\"\"\"\nUtils for training and evaluating multiple models simultaneously\n\"\"\"\n\nimport argparse\nimport os\n\nfrom stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main\nfrom stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main\nfrom stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE\n\n\nchange_params_map = {\n    \"lstm_layer\": [16, 32, 64, 128, 256, 512],\n    \"upos_emb_dim\": [5, 10, 20, 30],\n    \"training_size\": [150, 300, 450, 600, 'full'],\n}  # TODO: Add attention\n\ndef train_n_models(num_models: int, base_path: str, args):\n\n    if args.change_param == \"lstm_layer\":\n        for num_layers in change_params_map.get(\"lstm_layer\", None):\n            for i in range(num_models):\n                new_save_name = os.path.join(base_path, f\"{num_layers}_{i}.pt\")\n                args.save_name = new_save_name\n                args.hidden_dim = num_layers\n                train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"upos_emb_dim\":\n        for upos_dim in change_params_map.get(\"upos_emb_dim\", None):\n            for i in range(num_models):\n                new_save_name = os.path.join(base_path, f\"dim_{upos_dim}_{i}.pt\")\n                args.save_name = new_save_name\n                args.upos_emb_dim = upos_dim\n                train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"training_size\":\n        for size in change_params_map.get(\"training_size\", None):\n            for i in range(num_models):\n                new_save_name = os.path.join(base_path, f\"{size}_examples_{i}.pt\")\n                new_train_file = os.path.join(os.path.dirname(__file__), \"data\", \"processed_ud_en\", \"combined_train.txt\")\n                args.save_name = new_save_name\n                args.train_file = new_train_file\n                train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"base\":\n        
for i in range(num_models):\n            new_save_name = os.path.join(base_path, f\"lstm_model_{i}.pt\")\n            args.save_name = new_save_name\n            args.weighted_loss = False\n            train_lstm_main(predefined_args=args)\n\n            if not args.weighted_loss:\n                args.weighted_loss = True\n                new_save_name = os.path.join(base_path, f\"lstm_model_wloss_{i}.pt\")\n                args.save_name = new_save_name\n                train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"base_charlm\":\n        for i in range(num_models):\n            new_save_name = os.path.join(base_path, f\"lstm_charlm_{i}.pt\")\n            args.save_name = new_save_name\n            train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"base_charlm_upos\":\n        for i in range(num_models):\n            new_save_name = os.path.join(base_path, f\"lstm_charlm_upos_{i}.pt\")\n            args.save_name = new_save_name\n            train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"base_upos\":\n        for i in range(num_models):\n            new_save_name = os.path.join(base_path, f\"lstm_upos_{i}.pt\")\n            args.save_name = new_save_name\n            train_lstm_main(predefined_args=args)\n\n    if args.change_param == \"attn_model\":\n        for i in range(num_models):\n            new_save_name = os.path.join(base_path, f\"attn_model_{args.num_heads}_heads_{i}.pt\")\n            args.save_name = new_save_name\n            train_lstm_main(predefined_args=args)\n\ndef train_n_tfmrs(num_models: int, base_path: str, args):\n\n    if args.multi_train_type == \"tfmr\":\n\n        for i in range(num_models):\n\n            if args.change_param == \"bert\":\n                new_save_name = os.path.join(base_path, f\"bert_{i}.pt\")\n                args.save_name = new_save_name\n                args.loss_fn = \"ce\"\n                train_tfmr_main(predefined_args=args)\n\n            
    new_save_name = os.path.join(base_path, f\"bert_wloss_{i}.pt\")\n                args.save_name = new_save_name\n                args.loss_fn = \"weighted_bce\"\n                train_tfmr_main(predefined_args=args)\n\n            elif args.change_param == \"roberta\":\n                new_save_name = os.path.join(base_path, f\"roberta_{i}.pt\")\n                args.save_name = new_save_name\n                args.loss_fn = \"ce\"\n                train_tfmr_main(predefined_args=args)\n\n                new_save_name = os.path.join(base_path, f\"roberta_wloss_{i}.pt\")\n                args.save_name = new_save_name\n                args.loss_fn = \"weighted_bce\"\n                train_tfmr_main(predefined_args=args)\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hidden_dim\", type=int, default=256, help=\"Size of hidden layer\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), \"pretrain\", \"glove.pt\"), help='Exact name of the pretrain file to read')\n    parser.add_argument(\"--charlm\", action='store_true', dest='use_charlm', default=False, help=\"Whether to use the charlm embeddings\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument(\"--charlm_forward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_forward.pt\"), help=\"Path to forward charlm file\")\n    parser.add_argument(\"--charlm_backward_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"charlm_files\", \"1billion_backwards.pt\"), help=\"Path to backward charlm file\")\n    parser.add_argument(\"--upos_emb_dim\", type=int, default=20, help=\"Dimension size for UPOS tag embeddings.\")\n    parser.add_argument(\"--use_attn\", action='store_true', dest='attn', default=False, help='Whether to use multihead attention 
instead of LSTM.')\n    parser.add_argument(\"--num_heads\", type=int, default=0, help=\"Number of heads to use for multihead attention.\")\n    parser.add_argument(\"--save_name\", type=str, default=os.path.join(os.path.dirname(__file__), \"saved_models\", \"lemma_classifier_model_weighted_loss_charlm_new.pt\"), help=\"Path to model save file\")\n    parser.add_argument(\"--lr\", type=float, default=0.001, help=\"Learning rate\")\n    parser.add_argument(\"--num_epochs\", type=int, default=10, help=\"Number of training epochs\")\n    parser.add_argument(\"--batch_size\", type=int, default=DEFAULT_BATCH_SIZE, help=\"Number of examples to include in each batch\")\n    parser.add_argument(\"--train_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"data\", \"processed_ud_en\", \"combined_train.txt\"), help=\"Full path to training file\")\n    parser.add_argument(\"--weighted_loss\", action='store_true', dest='weighted_loss', default=False, help=\"Whether to use weighted loss during training.\")\n    parser.add_argument(\"--eval_file\", type=str, default=os.path.join(os.path.dirname(__file__), \"data\", \"processed_ud_en\", \"combined_dev.txt\"), help=\"Path to dev file used to evaluate model for saves\")\n    # Tfmr-specific args\n    parser.add_argument(\"--model_type\", type=str, default=\"roberta\", help=\"Which transformer to use ('bert' or 'roberta')\")\n    parser.add_argument(\"--bert_model\", type=str, default=None, help=\"Use a specific transformer instead of the default bert/roberta\")\n    parser.add_argument(\"--loss_fn\", type=str, default=\"weighted_bce\", help=\"Which loss function to train with (e.g. 
'ce' or 'weighted_bce')\")\n    # Multi-model train args\n    parser.add_argument(\"--multi_train_type\", type=str, default=\"lstm\", help=\"Whether you are attempting to multi-train an LSTM or transformer\")\n    parser.add_argument(\"--multi_train_count\", type=int, default=5, help=\"Number of each model to build\")\n    parser.add_argument(\"--base_path\", type=str, default=None, help=\"Path to start generating model type for.\")\n    parser.add_argument(\"--change_param\", type=str, default=None, help=\"Which hyperparameter to change when training\")\n\n\n    args = parser.parse_args()\n\n    if args.multi_train_type == \"lstm\":\n        train_n_models(num_models=args.multi_train_count,\n                       base_path=args.base_path,\n                       args=args)\n    elif args.multi_train_type == \"tfmr\":\n        train_n_tfmrs(num_models=args.multi_train_count,\n                      base_path=args.base_path,\n                      args=args)\n    else:\n        raise ValueError(f\"Improper input {args.multi_train_type}\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/lemma_classifier/train_transformer_model.py",
    "content": "\"\"\"\nThis file contains code used to train a baseline transformer model to classify on a lemma of a particular token.\n\"\"\"\n\nimport argparse\nimport os\nimport sys\nimport logging\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer\nfrom stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE\nfrom stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer\nfrom stanza.models.common.utils import default_device\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass TransformerBaselineTrainer(BaseLemmaClassifierTrainer):\n    \"\"\"\n    Class to assist with training a baseline transformer model to classify on token lemmas.\n    To find the model spec, refer to `model.py` in this directory.\n    \"\"\"\n\n    def __init__(self, model_args: dict, transformer_name: str = \"roberta\", loss_func: str = \"ce\", lr: int = 0.001):\n        \"\"\"\n        Creates the Trainer object\n\n        Args:\n            transformer_name (str, optional): What kind of transformer to use for embeddings. Defaults to \"roberta\".\n            loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to \"ce\".\n            lr (int, optional): learning rate for the optimizer. Defaults to 0.001.\n        \"\"\"\n        super().__init__()\n\n        self.model_args = model_args\n\n        # Find loss function\n        if loss_func == \"ce\":\n            self.criterion = nn.CrossEntropyLoss()\n            self.weighted_loss = False\n        elif loss_func == \"weighted_bce\":\n            self.criterion = nn.BCEWithLogitsLoss()\n            self.weighted_loss = True  # used to add weights during train time.\n        else:\n            raise ValueError(\"Must enter a valid loss function (e.g. 
'ce' or 'weighted_bce')\")\n\n        self.transformer_name = transformer_name\n        self.lr = lr\n\n    def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> optim.Optimizer:\n        \"\"\"\n        Sets learning rates for each layer of the model.\n        Currently, the model has the transformer layer and the MLP layer, so these are tweakable.\n\n        Returns (optim.Optimizer): An Adam optimizer with the learning rates adjusted per layer.\n\n        Currently unused - could be refactored into the parent class's train method,\n        or the parent class could call a build_optimizer and this subclass would use the optimizer\n        \"\"\"\n        transformer_params, mlp_params = [], []\n        for name, param in self.model.named_parameters():\n            if 'transformer' in name:\n                transformer_params.append(param)\n            elif 'mlp' in name:\n                mlp_params.append(param)\n        optimizer = optim.Adam([\n            {\"params\": transformer_params, \"lr\": transformer_lr},\n            {\"params\": mlp_params, \"lr\": mlp_lr}\n        ])\n        return optimizer\n\n    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):\n        return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words, target_upos=target_upos)\n\n\ndef main(args=None, predefined_args=None):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--save_name\", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), \"saved_models\", \"big_model_roberta_weighted_loss.pt\"), help=\"Path to model save file\")\n    parser.add_argument(\"--num_epochs\", type=int, default=10, help=\"Number of training epochs\")\n    parser.add_argument(\"--train_file\", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), \"test_sets\", 
\"combined_train.txt\"), help=\"Full path to training file\")\n    parser.add_argument(\"--model_type\", type=str, default=\"roberta\", help=\"Which transformer to use ('bert' or 'roberta')\")\n    parser.add_argument(\"--bert_model\", type=str, default=None, help=\"Use a specific transformer instead of the default bert/roberta\")\n    parser.add_argument(\"--loss_fn\", type=str, default=\"weighted_bce\", help=\"Which loss function to train with (e.g. 'ce' or 'weighted_bce')\")\n    parser.add_argument(\"--batch_size\", type=int, default=DEFAULT_BATCH_SIZE, help=\"Number of examples to include in each batch\")\n    parser.add_argument(\"--eval_file\", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), \"test_sets\", \"combined_dev.txt\"), help=\"Path to dev file used to evaluate model for saves\")\n    parser.add_argument(\"--lr\", type=float, default=0.001, help=\"Learning rate for the optimizer.\")\n    parser.add_argument(\"--force\", action='store_true', default=False, help='Whether or not to clobber an existing save file')\n\n    args = parser.parse_args(args) if predefined_args is None else predefined_args\n\n    save_name = args.save_name\n    num_epochs = args.num_epochs\n    train_file = args.train_file\n    loss_fn = args.loss_fn\n    eval_file = args.eval_file\n    lr = args.lr\n\n    args = vars(args)\n\n    if args['model_type'] == 'bert':\n        args['bert_model'] = 'bert-base-uncased'\n    elif args['model_type'] == 'roberta':\n        args['bert_model'] = 'roberta-base'\n    elif args['model_type'] == 'transformer':\n        if args['bert_model'] is None:\n            raise ValueError(\"Need to specify a bert_model for model_type transformer!\")\n    else:\n        raise ValueError(\"Unknown model type \" + args['model_type'])\n\n    if os.path.exists(save_name) and not args.get('force', False):\n        raise FileExistsError(f\"Save name {save_name} already exists. Training would override existing data. 
Aborting...\")\n    if not os.path.exists(train_file):\n        raise FileNotFoundError(f\"Training file {train_file} not found. Try again with a valid path.\")\n\n    logger.info(\"Running training script with the following args:\")\n    for arg in args:\n        logger.info(f\"{arg}: {args[arg]}\")\n    logger.info(\"------------------------------------------------------------\")\n\n    trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr)\n\n    trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file)\n    return trainer\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/lemma_classifier/transformer_model.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport sys\nimport logging\n\nfrom transformers import AutoTokenizer, AutoModel\nfrom typing import Mapping, List, Tuple, Any\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence\nfrom stanza.models.common.bert_embedding import extract_bert_embeddings\nfrom stanza.models.lemma_classifier.base_model import LemmaClassifier\nfrom stanza.models.lemma_classifier.constants import ModelType\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass LemmaClassifierWithTransformer(LemmaClassifier):\n    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set, target_upos: set):\n        \"\"\"\n        Model architecture:\n\n            Use a transformer (BERT or RoBERTa) to extract contextual embedding over a sentence.\n            Get the embedding for the word that is to be classified on, and feed the embedding\n            as input to an MLP classifier that has 2 linear layers, and a prediction head.\n\n        Args:\n            model_args (dict): args for the model\n            output_dim (int): Dimension of the output from the MLP\n            transformer_name (str): name of the HF transformer to use\n            label_decoder (dict): a map of the labels available to the model\n            target_words (set(str)): a set of the words which might need lemmatization\n        \"\"\"\n        super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words, target_upos)\n        self.model_args = model_args\n\n        # Choose transformer\n        self.transformer_name = transformer_name\n        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)\n        self.add_unsaved_module(\"transformer\", AutoModel.from_pretrained(transformer_name))\n        config = self.transformer.config\n\n        embedding_size = config.hidden_size\n\n        # 
define an MLP layer\n        self.mlp = nn.Sequential(\n            nn.Linear(embedding_size, 64),\n            nn.ReLU(),\n            nn.Linear(64, output_dim)\n        )\n\n    def get_save_dict(self):\n        save_dict = {\n            \"params\": self.state_dict(),\n            \"label_decoder\": self.label_decoder,\n            \"target_words\": list(self.target_words),\n            \"target_upos\": list(self.target_upos),\n            \"model_type\": self.model_type().name,\n            \"args\": self.model_args,\n        }\n        skipped = [k for k in save_dict[\"params\"].keys() if self.is_unsaved_module(k)]\n        for k in skipped:\n            del save_dict[\"params\"][k]\n        return save_dict\n\n    def convert_tags(self, upos_tags: List[List[str]]):\n        return None\n\n    def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):\n        \"\"\"\n        Computes the forward pass of the transformer baselines\n\n        Args:\n            idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.\n            sentences (List[List[str]]): A list of the token-split sentences of the input data.\n            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility\n\n        Returns:\n            torch.tensor: Output logits of the neural network, where the shape is  (n, output_size) where n is the number of sentences.\n        \"\"\"\n        device = next(self.transformer.parameters()).device\n        bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,\n                                                  keep_endpoints=False, num_layers=1, detach=True)\n        embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)]\n        embeddings = torch.stack(embeddings, dim=0)[:, :, 
0]\n        # pass to the MLP\n        output = self.mlp(embeddings)\n        return output\n\n    def model_type(self):\n        return ModelType.TRANSFORMER\n"
  },
  {
    "path": "stanza/models/lemma_classifier/utils.py",
    "content": "from collections import Counter\nimport json\nimport logging\nimport os\nimport random\nfrom typing import List, Tuple, Any, Mapping\n\nimport stanza\nimport torch\n\nfrom stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE\n\nlogger = logging.getLogger('stanza.lemmaclassifier')\n\nclass Dataset:\n    def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):\n        \"\"\"\n        Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence.\n\n        Args:\n            data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.\n            batch_size (int): Size of each batch of examples\n            get_counts (optional, bool): Whether there should be a map of the label index to counts\n\n        Returns:\n            1. List[List[List[str]]]: Batches of sentences, where each token is a separate entry in each sentence\n            2. List[torch.tensor[int]]: A batch of indexes for the target token corresponding to its sentence\n            3. List[torch.tensor[int]]: A batch of labels for the target token's lemma\n            4. List[List[int]]: A batch of UPOS IDs for the target token (this is a List of Lists, not a tensor. It should be padded later.)\n            5 (Optional): A mapping of label ID to counts in the dataset.\n            6. Mapping[str, int]: A map between the labels and their indexes\n            7. 
Mapping[str, int]: A map between the UPOS tags and their corresponding IDs found in the UPOS batches\n        \"\"\"\n\n        if data_path is None or not os.path.exists(data_path):\n            raise FileNotFoundError(f\"Data file {data_path} could not be found.\")\n\n        if label_decoder is None:\n            label_decoder = {}\n        else:\n            # if labels in the test set aren't in the original model,\n            # the model will never predict those labels,\n            # but we can still use those labels in a confusion matrix\n            label_decoder = dict(label_decoder)\n\n        logger.debug(\"Final label decoder: %s  Should be strings to ints\", label_decoder)\n\n        # words which we are analyzing\n        target_words = set()\n\n        # all known words in the dataset, not just target words\n        known_words = set()\n\n        # the file is only read, so open it read-only rather than \"r+\"\n        with open(data_path, \"r\", encoding=\"utf-8\") as fin:\n            sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}\n\n            input_json = json.load(fin)\n            sentences_data = input_json['sentences']\n            self.target_upos = input_json['upos']\n\n            for idx, sentence in enumerate(sentences_data):\n                # TODO Could replace this with sentence.values(), but need to know whether Stanza requires Python 3.7 or later, for backward compatibility reasons\n                words, target_idx, upos_tags, label = sentence.get(\"words\"), sentence.get(\"index\"), sentence.get(\"upos_tags\"), sentence.get(\"lemma\")\n                if None in [words, target_idx, upos_tags, label]:\n                    raise ValueError(f\"Expected data to be complete but found a null value in sentence {idx}: {sentence}\")\n\n                label_id = label_decoder.get(label, None)\n                if label_id is None:\n                    label_decoder[label] = len(label_decoder)  # create a new ID for the unknown label\n\n                converted_upos_tags = 
[]  # convert upos tags to upos IDs\n                for upos_tag in upos_tags:\n                    if upos_tag not in upos_to_id:\n                        upos_to_id[upos_tag] = len(upos_to_id)  # create a new ID for the unknown UPOS tag\n                    converted_upos_tags.append(upos_to_id[upos_tag])\n\n                sentences.append(words)\n                indices.append(target_idx)\n                upos_ids.append(converted_upos_tags)\n                labels.append(label_decoder[label])\n\n                if get_counts:\n                    counts[label_decoder[label]] += 1\n\n                target_words.add(words[target_idx])\n                known_words.update(words)\n\n        self.sentences = sentences\n        self.indices = indices\n        self.upos_ids = upos_ids\n        self.labels = labels\n\n        self.counts = counts\n        self.label_decoder = label_decoder\n        self.upos_to_id = upos_to_id\n\n        self.batch_size = batch_size\n        self.shuffle = shuffle\n\n        self.known_words = [x.lower() for x in sorted(known_words)]\n        self.target_words = set(x.lower() for x in target_words)\n\n    def __len__(self):\n        \"\"\"\n        Number of batches, rounded up to nearest batch\n        \"\"\"\n        return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)\n\n    def __iter__(self):\n        num_sentences = len(self.sentences)\n        indices = list(range(num_sentences))\n        if self.shuffle:\n            random.shuffle(indices)\n        for i in range(self.__len__()):\n            batch_start = self.batch_size * i\n            batch_end = min(batch_start + self.batch_size, num_sentences)\n\n            batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]\n            batch_indices =   torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])\n            batch_upos_ids =  [self.upos_ids[x] for x in indices[batch_start:batch_end]]\n      
      batch_labels =    torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])\n            yield batch_sentences, batch_indices, batch_upos_ids, batch_labels\n\ndef extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]:\n    \"\"\"\n    Extracts the indices within `tokenized_indices` which match `unknown_token_idx`\n\n    Args:\n        tokenized_indices (torch.tensor): A tensor filled with tokenized indices of words that have been mapped to vector indices.\n        unknown_token_idx (int): The special index for which unknown tokens are marked in the word vectors.\n\n    Returns:\n        List[int]: A list of indices in `tokenized_indices` which match `unknown_token_idx`\n    \"\"\"\n    return [idx for idx, token_index in enumerate(tokenized_indices) if token_index == unknown_token_idx]\n\n\ndef get_device():\n    \"\"\"\n    Get the device to run computations on: CUDA if available, then MPS, otherwise CPU\n    \"\"\"\n    # is_available must be called; the bare function object is always truthy\n    if torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        device = torch.device(\"mps\")\n    else:\n        device = torch.device(\"cpu\")\n\n    return device\n\n\ndef round_up_to_multiple(number, multiple):\n    if multiple == 0:\n        # raise instead of returning an error string, so callers never get a str where a number is expected\n        raise ValueError(\"The second number (multiple) cannot be zero.\")\n\n    # Calculate the remainder when dividing the number by the multiple\n    remainder = number % multiple\n\n    # If remainder is non-zero, round up to the next multiple\n    if remainder != 0:\n        rounded_number = number + (multiple - remainder)\n    else:\n        rounded_number = number  # No rounding needed\n\n    return rounded_number\n\n\ndef main():\n    default_test_path = os.path.join(os.path.dirname(__file__), \"test_sets\", \"processed_ud_en\", \"combined_dev.txt\")   # get the GUM stuff\n    # load_dataset is not defined in this module; the Dataset class above does the loading\n    dataset = Dataset(default_test_path, get_counts=True)\n\nif __name__ == 
\"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/lemmatizer.py",
    "content": "\"\"\"\nEntry point for training and evaluating a lemmatizer.\n\nThis lemmatizer combines a neural sequence-to-sequence architecture with an `edit` classifier \nand two dictionaries to produce robust lemmas from word forms.\nFor details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.\n\"\"\"\n\nimport logging\nimport sys\nimport os\nimport shutil\nimport time\nfrom datetime import datetime\nimport argparse\nimport numpy as np\nimport random\nimport torch\nfrom torch import nn, optim\n\nfrom stanza.models.lemma.data import DataLoader\nfrom stanza.models.lemma.vocab import Vocab\nfrom stanza.models.lemma.trainer import Trainer\nfrom stanza.models.lemma import scorer, edit\nfrom stanza.models.common import utils\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.doc import *\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models import _training_logging\n\nlogger = logging.getLogger('stanza')\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')\n    parser.add_argument('--train_file', type=str, default=None, help='Training input file for data loader.')\n    parser.add_argument('--eval_file', type=str, default=None, help='Evaluation input file for data loader.')\n    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--shorthand', type=str, help='Shorthand for the dataset to use.  lang_dataset')\n\n    parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. 
By default use ensemble.')\n    parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based lemmatizer.')\n\n    parser.add_argument('--hidden_dim', type=int, default=200)\n    parser.add_argument('--emb_dim', type=int, default=50)\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--emb_dropout', type=float, default=0.5)\n    parser.add_argument('--dropout', type=float, default=0.5)\n    parser.add_argument('--max_dec_len', type=int, default=50)\n    parser.add_argument('--beam_size', type=int, default=1)\n\n    parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')\n    parser.add_argument('--pos_dim', type=int, default=50)\n    parser.add_argument('--pos_dropout', type=float, default=0.5)\n    parser.add_argument('--no_edit', dest='edit', action='store_false', help='Do not use edit classifier in lemmatization. By default use an edit classifier.')\n    parser.add_argument('--num_edit', type=int, default=len(edit.EDIT_TO_ID))\n    parser.add_argument('--alpha', type=float, default=1.0)\n    parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.')\n    parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. 
By default copy mechanism is used to improve generalization.')\n\n    parser.add_argument('--charlm', action='store_true', help=\"Turn on contextualized char embedding using pretrained character-level language model.\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n\n    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')\n    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')\n    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')\n    parser.add_argument('--lr_decay', type=float, default=0.9)\n    parser.add_argument('--decay_epoch', type=int, default=30, help=\"Decay the lr starting from this epoch.\")\n    parser.add_argument('--num_epoch', type=int, default=60)\n    parser.add_argument('--batch_size', type=int, default=50)\n    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')\n    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')\n    parser.add_argument('--save_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_lemmatizer.pt\", help=\"File name to save the model\")\n\n    parser.add_argument('--caseless', default=False, action='store_true', help='Lowercase everything first before processing.  This will happen automatically if 100%% of the data is caseless')\n    parser.add_argument('--skip_blank_lemmas', default=False, action='store_true', help='Skip blank entries in the data files.  
Useful for training a lemmatizer from a partially annotated dataset')\n\n    parser.add_argument('--seed', type=int, default=1234)\n    utils.add_device_args(parser)\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')\n    return parser\n\ndef parse_args(args=None):\n    parser = build_argparse()\n    args = parser.parse_args(args=args)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    args = vars(args)\n    # when building the vocab, we keep track of the original language name\n    lang = args['shorthand'].split(\"_\")[0] if args['shorthand'] else \"\"\n    args['lang'] = lang\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running lemmatizer in {} mode\".format(args['mode']))\n\n    if args['mode'] == 'train':\n        train(args)\n    else:\n        evaluate(args)\n\ndef all_lowercase(doc):\n    for sentence in doc.sentences:\n        for word in sentence.words:\n            if word.text.lower() != word.text:\n                return False\n    return True\n\ndef build_model_filename(args):\n    embedding = \"nocharlm\"\n    if args['charlm'] and args['charlm_forward_file']:\n        embedding = \"charlm\"\n    model_file = args['save_name'].format(shorthand=args['shorthand'],\n                                          embedding=embedding)\n    model_dir = os.path.split(model_file)[0]\n    if not model_dir.startswith(args['save_dir']):\n        model_file = os.path.join(args['save_dir'], model_file)\n    return model_file\n\ndef train(args):\n    # load data\n    logger.info(\"[Loading data with batch size {}...]\".format(args['batch_size']))\n    train_doc = 
CoNLL.conll2doc(input_file=args['train_file'])\n    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)\n    vocab = train_batch.vocab\n    args['vocab_size'] = vocab['char'].size\n    args['pos_vocab_size'] = vocab['pos'].size\n    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)\n\n    utils.ensure_dir(args['save_dir'])\n    model_file = build_model_filename(args)\n    logger.info(\"Using full savename: %s\", model_file)\n\n    # gold path\n    gold_file = args['eval_file']\n\n    utils.print_config(args)\n\n    # skip training if the language does not have training or dev data\n    if len(train_batch) == 0 or len(dev_batch) == 0:\n        logger.warning(\"[Skip training because no training data available...]\")\n        return\n\n    if not args['caseless'] and all_lowercase(train_doc):\n        logger.info(\"Building a caseless model, as all of the training data is caseless\")\n        args['caseless'] = True\n\n    # start training\n    # train a dictionary-based lemmatizer\n    logger.info(\"Building lemmatizer in %s\", model_file)\n    trainer = Trainer(args=args, vocab=vocab, device=args['device'])\n    logger.info(\"[Training dictionary-based lemmatizer...]\")\n    trainer.train_dict(train_batch.raw_data())\n    logger.info(\"Evaluating on dev set...\")\n    dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))\n    dev_batch.doc.set([LEMMA], dev_preds)\n    system_pred_file = \"{:C}\\n\\n\".format(dev_batch.doc)\n    system_pred_file = io.StringIO(system_pred_file)\n    _, _, dev_f = scorer.score(system_pred_file, gold_file)\n    logger.info(\"Dev F1 = {:.2f}\".format(dev_f * 100))\n\n    if args.get('dict_only', False):\n        # save dictionaries\n        trainer.save(model_file)\n    else:\n        if args['wandb']:\n            import wandb\n            wandb_name = args['wandb_name'] if args['wandb_name'] else 
\"%s_lemmatizer\" % args['shorthand']\n            wandb.init(name=wandb_name, config=args)\n            wandb.run.define_metric('train_loss', summary='min')\n            wandb.run.define_metric('dev_score', summary='max')\n\n        # train a seq2seq model\n        logger.info(\"[Training seq2seq-based lemmatizer...]\")\n        global_step = 0\n        max_steps = len(train_batch) * args['num_epoch']\n        dev_score_history = []\n        best_dev_preds = []\n        current_lr = args['lr']\n        global_start_time = time.time()\n        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'\n\n        # start training\n        for epoch in range(1, args['num_epoch']+1):\n            train_loss = 0\n            for i, batch in enumerate(train_batch):\n                start_time = time.time()\n                global_step += 1\n                loss = trainer.update(batch, eval=False) # update step\n                train_loss += loss\n                if global_step % args['log_step'] == 0:\n                    duration = time.time() - start_time\n                    logger.info(format_str.format(datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), global_step,\n                                                  max_steps, epoch, args['num_epoch'], loss, duration, current_lr))\n\n            # eval on dev\n            logger.info(\"Evaluating on dev set...\")\n            dev_preds = []\n            dev_edits = []\n            for i, batch in enumerate(dev_batch):\n                preds, edits = trainer.predict(batch, args['beam_size'])\n                dev_preds += preds\n                if edits is not None:\n                    dev_edits += edits\n            dev_preds = trainer.postprocess(dev_batch.doc.get([TEXT]), dev_preds, edits=dev_edits)\n\n            # try ensembling with dict if necessary\n            if args.get('ensemble_dict', False):\n                logger.info(\"[Ensembling dict with seq2seq model...]\")\n           
     dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)\n            dev_batch.doc.set([LEMMA], dev_preds)\n            system_pred_file = \"{:C}\\n\\n\".format(dev_batch.doc)\n            system_pred_file = io.StringIO(system_pred_file)\n            _, _, dev_score = scorer.score(system_pred_file, gold_file)\n\n            train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch\n            logger.info(\"epoch {}: train_loss = {:.6f}, dev_score = {:.4f}\".format(epoch, train_loss, dev_score))\n\n            if args['wandb']:\n                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})\n\n            # save best model\n            if epoch == 1 or dev_score > max(dev_score_history):\n                trainer.save(model_file)\n                logger.info(\"new best model saved.\")\n                best_dev_preds = dev_preds\n\n            # lr schedule\n            if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1] and \\\n                    args['optim'] in ['sgd', 'adagrad']:\n                current_lr *= args['lr_decay']\n                trainer.update_lr(current_lr)\n\n            dev_score_history += [dev_score]\n            logger.info(\"\")\n\n        logger.info(\"Training ended with {} epochs.\".format(epoch))\n\n        if args['wandb']:\n            wandb.finish()\n\n        best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1\n        logger.info(\"Best dev F1 = {:.2f}, at epoch = {}\".format(best_f, best_epoch))\n\ndef evaluate(args):\n    # file paths\n    system_pred_file = args['output_file']\n    model_file = build_model_filename(args)\n\n    # load model\n    trainer = Trainer(model_file=model_file, device=args['device'], args=args)\n    loaded_args, vocab = trainer.args, trainer.vocab\n\n    for k in args:\n        if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']:\n            loaded_args[k] = args[k]\n\n    
# load data\n    logger.info(\"Loading data with batch size {}...\".format(args['batch_size']))\n    doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)\n\n    # skip eval if dev data does not exist\n    if len(batch) == 0:\n        logger.warning(\"Skip evaluation because no dev data is available...\\nLemma score:\\n{} \".format(args['shorthand']))\n        return\n\n    dict_preds = trainer.predict_dict(batch.doc.get([TEXT, UPOS]))\n\n    if loaded_args.get('dict_only', False):\n        preds = dict_preds\n    else:\n        logger.info(\"Running the seq2seq model...\")\n        preds = []\n        edits = []\n        for i, b in enumerate(batch):\n            ps, es = trainer.predict(b, args['beam_size'])\n            preds += ps\n            if es is not None:\n                edits += es\n        preds = trainer.postprocess(batch.doc.get([TEXT]), preds, edits=edits)\n\n        if loaded_args.get('ensemble_dict', False):\n            logger.info(\"[Ensembling dict with seq2seq lemmatizer...]\")\n            preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds)\n\n        if trainer.has_contextual_lemmatizers():\n            preds = trainer.update_contextual_preds(batch.doc, preds)\n\n    # write to file and score\n    batch.doc.set([LEMMA], preds)\n    if system_pred_file:\n        CoNLL.write_doc2conll(batch.doc, system_pred_file)\n\n    system_pred_file = \"{:C}\\n\\n\".format(batch.doc)\n    system_pred_file = io.StringIO(system_pred_file)\n    _, _, score = scorer.score(system_pred_file, args['eval_file'])\n    logger.info(\"Finished evaluation\\nLemma score:\\n{} {:.2f}\".format(args['shorthand'], score*100))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/mwt/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/mwt/character_classifier.py",
    "content": "\"\"\"\nClassify characters based on an LSTM with learned character representations\n\"\"\"\n\nimport logging\n\nimport torch\nfrom torch import nn\n\nimport stanza.models.common.seq2seq_constant as constant\n\nlogger = logging.getLogger('stanza')\n\nclass CharacterClassifier(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n\n        self.vocab_size = args['vocab_size']\n        self.emb_dim = args['emb_dim']\n        self.hidden_dim = args['hidden_dim']\n        self.nlayers = args['num_layers'] # lstm encoder layers\n        self.pad_token = constant.PAD_ID\n        self.enc_hidden_dim = self.hidden_dim // 2   # since it is bidirectional\n\n        self.num_outputs = 2\n\n        self.args = args\n\n        self.emb_dropout = args.get('emb_dropout', 0.0)\n        self.emb_drop = nn.Dropout(self.emb_dropout)\n        self.dropout = args['dropout']\n\n        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)\n        self.input_dim = self.emb_dim\n        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \\\n                               bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)\n\n        self.output_layer = nn.Sequential(\n            nn.Linear(self.hidden_dim, self.hidden_dim),\n            nn.ReLU(),\n            nn.Linear(self.hidden_dim, self.num_outputs))\n\n    def encode(self, enc_inputs, lens):\n        \"\"\" Encode source sequence. 
\"\"\"\n        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)\n        packed_h_in, (hn, cn) = self.encoder(packed_inputs)\n        return packed_h_in\n\n    def embed(self, src, src_mask):\n        # the input data could have characters outside the known range\n        # of characters in cases where the vocabulary was temporarily\n        # expanded (note that this model does nothing with those chars)\n        embed_src = src.clone()\n        embed_src[embed_src >= self.vocab_size] = constant.UNK_ID\n        enc_inputs = self.emb_drop(self.embedding(embed_src))\n        batch_size = enc_inputs.size(0)\n        src_lens = list(src_mask.data.eq(self.pad_token).long().sum(1))\n        return enc_inputs, batch_size, src_lens, src_mask\n\n    def forward(self, src, src_mask):\n        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask)\n        encoded = self.encode(enc_inputs, src_lens)\n        encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True)\n        logits = self.output_layer(encoded)\n        return logits\n"
  },
  {
    "path": "stanza/models/mwt/data.py",
    "content": "import random\nimport numpy as np\nimport os\nfrom collections import Counter, namedtuple\nimport logging\n\nimport torch\nfrom torch.nn.utils.rnn import pad_sequence\nfrom torch.utils.data import DataLoader as DL\n\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all\nfrom stanza.models.common.vocab import DeltaVocab\nfrom stanza.models.mwt.vocab import Vocab\nfrom stanza.models.common.doc import Document\n\nlogger = logging.getLogger('stanza')\n\nDataSample = namedtuple(\"DataSample\", \"src tgt_in tgt_out orig_text\")\nDataBatch = namedtuple(\"DataBatch\", \"src src_mask tgt_in tgt_out orig_text orig_idx\")\n\n# enforce that the MWT splitter knows about a couple different alternate apostrophes\n# including covering some potential \" typos\n# setting the augmentation to a very low value should be enough to teach it\n# about the unknown characters without messing up the predictions for other text\n#\n#      0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07\nAPOS = ('\"',  \"'\",    'ʼ',    'ˊ',    '՚',    'ߴ',    '’',   '＇')\n\nclass DataLoader:\n    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):\n        self.batch_size = batch_size\n        self.args = args\n        self.augment_apos = args.get('augment_apos', 0.0)\n        self.evaluation = evaluation\n        self.doc = doc\n\n        data = self.load_doc(self.doc, evaluation=self.evaluation)\n\n        # handle vocab\n        if vocab is None:\n            assert self.evaluation == False # for eval vocab must exist\n            self.vocab = self.init_vocab(data)\n            if self.augment_apos > 0 and any(x in self.vocab for x in APOS):\n                for apos in APOS:\n                    self.vocab.add_unit(apos)\n        elif expand_unk_vocab:\n            self.vocab = DeltaVocab(data, vocab)\n        else:\n            self.vocab 
= vocab\n\n        # filter and sample data\n        if args.get('sample_train', 1.0) < 1.0 and not self.evaluation:\n            keep = int(args['sample_train'] * len(data))\n            data = random.sample(data, keep)\n            logger.debug(\"Subsample training set with rate {:g}\".format(args['sample_train']))\n\n        # shuffle for training\n        if not self.evaluation:\n            indices = list(range(len(data)))\n            random.shuffle(indices)\n            data = [data[i] for i in indices]\n\n        self.data = data\n        self.num_examples = len(data)\n\n    def init_vocab(self, data):\n        assert self.evaluation == False # for eval vocab must exist\n        vocab = Vocab(data, self.args['shorthand'])\n        return vocab\n\n    def maybe_augment_apos(self, datum):\n        for original in APOS:\n            if original in datum[0]:\n                if random.uniform(0,1) < self.augment_apos:\n                    replacement = random.choice(APOS)\n                    datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement))\n                break\n        return datum\n\n    def process(self, sample):\n        if not self.evaluation and self.augment_apos > 0:\n            sample = self.maybe_augment_apos(sample)\n        src = list(sample[0])\n        src = [constant.SOS] + src + [constant.EOS]\n        tgt_in, tgt_out = self.prepare_target(self.vocab, sample)\n        src = self.vocab.map(src)\n        processed = [src, tgt_in, tgt_out, sample[0]]\n        return processed\n\n    def prepare_target(self, vocab, datum):\n        if self.evaluation:\n            tgt = list(datum[0])  # as a placeholder\n        else:\n            tgt = list(datum[1])\n        tgt_in = vocab.map([constant.SOS] + tgt)\n        tgt_out = vocab.map(tgt + [constant.EOS])\n        return tgt_in, tgt_out\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, key):\n        \"\"\" Get a batch with 
index. \"\"\"\n        if not isinstance(key, int):\n            raise TypeError\n        if key < 0 or key >= len(self.data):\n            raise IndexError\n        sample = self.data[key]\n        sample = self.process(sample)\n        assert len(sample) == 4\n\n        src = torch.tensor(sample[0])\n        tgt_in = torch.tensor(sample[1])\n        tgt_out = torch.tensor(sample[2])\n        orig_text = sample[3]\n        result = DataSample(src, tgt_in, tgt_out, orig_text), key\n        return result\n\n    @staticmethod\n    def __collate_fn(data):\n        (data, idx) = zip(*data)\n        (src, tgt_in, tgt_out, orig_text) = zip(*data)\n\n        # collate_fn is given a list of length batch size\n        batch_size = len(data)\n\n        # need to sort by length of src to properly handle\n        # the batching in the model itself\n        lens = [len(x) for x in src]\n        (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens)\n        lens = [len(x) for x in src]\n\n        # convert to tensors\n        src = pad_sequence(src, True, constant.PAD_ID)\n        src_mask = torch.eq(src, constant.PAD_ID)\n        tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID)\n        tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID)\n        assert tgt_in.size(1) == tgt_out.size(1), \\\n                \"Target input and output sequence sizes do not match.\"\n        return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)\n\n    def __iter__(self):\n        for i in range(self.__len__()):\n            yield self.__getitem__(i)\n\n    def to_loader(self):\n        \"\"\"Converts self to a DataLoader \"\"\"\n\n        batch_size = self.batch_size\n        shuffle = not self.evaluation\n        return DL(self,\n                  collate_fn=self.__collate_fn,\n                  batch_size=batch_size,\n                  shuffle=shuffle)\n\n    def load_doc(self, doc, evaluation=False):\n        data = 
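doc.get_mwt_expansions(evaluation)\n        if evaluation:\n            # wrap each item in a list so it can be indexed as datum[0], like the training pairs\n            data = [[e] for e in data]\n        return data\n\nclass BinaryDataLoader(DataLoader):\n    \"\"\"\n    This version of the DataLoader performs the same tasks as the regular DataLoader,\n    except the targets are arrays of 0/1 indicating whether a character is the first\n    character of a new word, i.e. the location of an MWT split\n    \"\"\"\n    def prepare_target(self, vocab, datum):\n        src = datum[0] if self.evaluation else datum[1]\n        # the leading 0 lines up with the SOS symbol added to the source\n        binary = [0]\n        has_space = False\n        for char in src:\n            if char == ' ':\n                # a space emits no label of its own; it marks the next character as a split\n                has_space = True\n            elif has_space:\n                has_space = False\n                binary.append(1)\n            else:\n                binary.append(0)\n        # the trailing 0 lines up with the EOS symbol\n        binary.append(0)\n        return binary, binary\n\n"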
  },
  {
    "path": "stanza/models/mwt/scorer.py",
    "content": "\"\"\"\nUtils and wrappers for scoring MWT\n\"\"\"\nfrom stanza.models.common.utils import ud_scores\n\ndef score(system_conllu_file, gold_conllu_file):\n    \"\"\" Wrapper for word segmenter scorer. \"\"\"\n    evaluation = ud_scores(gold_conllu_file, system_conllu_file)\n    el = evaluation[\"Words\"]\n    p, r, f = el.precision, el.recall, el.f1\n    return p, r, f\n\n"
  },
  {
    "path": "stanza/models/mwt/trainer.py",
    "content": "\"\"\"\nA trainer class to handle training and testing of models.\n\"\"\"\n\nimport sys\nimport numpy as np\nfrom collections import Counter\nimport logging\nimport torch\nfrom torch import nn\nimport torch.nn.init as init\n\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.trainer import Trainer as BaseTrainer\nfrom stanza.models.common.seq2seq_model import Seq2SeqModel\nfrom stanza.models.common import utils, loss\nfrom stanza.models.mwt.character_classifier import CharacterClassifier\nfrom stanza.models.mwt.vocab import Vocab\n\nlogger = logging.getLogger('stanza')\n\ndef unpack_batch(batch, device):\n    \"\"\" Unpack a batch from the data loader. \"\"\"\n    inputs = [b.to(device) if b is not None else None for b in batch[:4]]\n    orig_text = batch[4]\n    orig_idx = batch[5]\n    return inputs, orig_text, orig_idx\n\nclass Trainer(BaseTrainer):\n    \"\"\" A trainer for training models. \"\"\"\n    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None):\n        if model_file is not None:\n            # load from file\n            self.load(model_file)\n        else:\n            self.args = args\n            if args['dict_only']:\n                self.model = None\n            elif args.get('force_exact_pieces', False):\n                self.model = CharacterClassifier(args)\n            else:\n                self.model = Seq2SeqModel(args, emb_matrix=emb_matrix)\n            self.vocab = vocab\n            self.expansion_dict = dict()\n        if not self.args['dict_only']:\n            self.model = self.model.to(device)\n            if self.args.get('force_exact_pieces', False):\n                self.crit = nn.CrossEntropyLoss()\n            else:\n                self.crit = loss.SequenceLoss(self.vocab.size).to(device)\n            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])\n\n    def update(self, batch, eval=False):\n        
device = next(self.model.parameters()).device\n        # the original text is not needed when training:\n        # the model can still try to learn the correct values, even if we eventually\n        # copy directly from the original text\n        inputs, _, orig_idx = unpack_batch(batch, device)\n        src, src_mask, tgt_in, tgt_out = inputs\n\n        if eval:\n            self.model.eval()\n        else:\n            self.model.train()\n            self.optimizer.zero_grad()\n        if self.args.get('force_exact_pieces', False):\n            # despite the name, the CharacterClassifier returns unnormalized logits\n            log_probs = self.model(src, src_mask)\n            # src_mask is True at padding, so comparing it with PAD_ID counts the\n            # unpadded positions (this relies on PAD_ID being 0)\n            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))\n            packed_output = nn.utils.rnn.pack_padded_sequence(log_probs, src_lens, batch_first=True)\n            packed_tgt = nn.utils.rnn.pack_padded_sequence(tgt_in, src_lens, batch_first=True)\n            loss = self.crit(packed_output.data, packed_tgt.data)\n        else:\n            log_probs, _ = self.model(src, src_mask, tgt_in)\n            loss = self.crit(log_probs.view(-1, self.vocab.size), tgt_out.view(-1))\n        loss_val = loss.data.item()\n        if eval:\n            return loss_val\n\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])\n        self.optimizer.step()\n        return loss_val\n\n    def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None):\n        if vocab is None:\n            vocab = self.vocab\n\n        device = next(self.model.parameters()).device\n        inputs, orig_text, orig_idx = unpack_batch(batch, device)\n        # the targets are not used at prediction time\n        src, src_mask, tgt_in, tgt_out = inputs\n\n        self.model.eval()\n        batch_size = src.size(0)\n        if self.args.get('force_exact_pieces', False):\n            log_probs = self.model(src, src_mask)\n            # cut before a character when the split score beats the no-split score\n            cuts = log_probs[:, :, 1] > log_probs[:, :, 0]\n            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))\n            pred_tokens = []\n            for 
src_ids, cut, src_len in zip(src, cuts, src_lens):\n                src_chars = vocab.unmap(src_ids)\n                pred_seq = []\n                for char_idx in range(1, src_len-1):\n                    if cut[char_idx]:\n                        pred_seq.append(' ')\n                    pred_seq.append(src_chars[char_idx])\n                pred_seq = \"\".join(pred_seq).strip()\n                pred_tokens.append(pred_seq)\n        else:\n            preds, _ = self.model.predict(src, src_mask, self.args['beam_size'], never_decode_unk=never_decode_unk)\n            pred_seqs = [vocab.unmap(ids) for ids in preds] # unmap to tokens\n            pred_seqs = utils.prune_decoded_seqs(pred_seqs)\n\n            pred_tokens = [\"\".join(seq) for seq in pred_seqs] # join chars to be tokens\n            # if any tokens are predicted to expand to blank,\n            # that is likely an error.  use the original text\n            # this originally came up with the Spanish model turning 's' into a blank\n            # furthermore, if there are no spaces predicted by the seq2seq,\n            # might as well use the original in case the seq2seq went crazy\n            # this particular error came up training a Hebrew MWT\n            pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)]\n        if unsort:\n            pred_tokens = utils.unsort(pred_tokens, orig_idx)\n        return pred_tokens\n\n    def train_dict(self, pairs):\n        \"\"\" Train a MWT expander given training word-expansion pairs. 
\"\"\"\n        # accumulate counter\n        ctr = Counter()\n        ctr.update([(p[0], p[1]) for p in pairs])\n        seen = set()\n        # find the most frequent mappings\n        for p, _ in ctr.most_common():\n            w, l = p\n            if w not in seen and w != l:\n                self.expansion_dict[w] = l\n            seen.add(w)\n        return\n\n    def dict_expansion(self, word):\n        \"\"\"\n        Check the expansion dictionary for the word along with a couple common lowercasings of the word\n\n        (Leadingcase and UPPERCASE)\n        \"\"\"\n        expansion = self.expansion_dict.get(word)\n        if expansion is not None:\n            return expansion\n\n        if word.isupper():\n            expansion = self.expansion_dict.get(word.lower())\n            if expansion is not None:\n                return expansion.upper()\n\n        if word[0].isupper() and word[1:].islower():\n            expansion = self.expansion_dict.get(word.lower())\n            if expansion is not None:\n                return expansion[0].upper() + expansion[1:]\n\n        # could build a truecasing model of some kind to handle cRaZyCaSe...\n        # but that's probably too much effort\n        return None\n\n    def predict_dict(self, words):\n        \"\"\" Predict a list of expansions given words. \"\"\"\n        expansions = []\n        for w in words:\n            expansion = self.dict_expansion(w)\n            if expansion is not None:\n                expansions.append(expansion)\n            else:\n                expansions.append(w)\n        return expansions\n\n    def ensemble(self, cands, other_preds):\n        \"\"\" Ensemble the dict with statistical model predictions. 
\"\"\"\n        expansions = []\n        assert len(cands) == len(other_preds)\n        for c, pred in zip(cands, other_preds):\n            expansion = self.dict_expansion(c)\n            if expansion is not None:\n                expansions.append(expansion)\n            else:\n                expansions.append(pred)\n        return expansions\n\n    def save(self, filename):\n        params = {\n                'model': self.model.state_dict() if self.model is not None else None,\n                'dict': self.expansion_dict,\n                'vocab': self.vocab.state_dict(),\n                'config': self.args\n                }\n        try:\n            torch.save(params, filename, _use_new_zipfile_serialization=False)\n            logger.info(\"Model saved to {}\".format(filename))\n        except BaseException:\n            logger.warning(\"Saving failed... continuing anyway.\")\n\n    def load(self, filename):\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        except BaseException:\n            logger.error(\"Cannot load model from {}\".format(filename))\n            raise\n        self.args = checkpoint['config']\n        self.expansion_dict = checkpoint['dict']\n        if not self.args['dict_only']:\n            if self.args.get('force_exact_pieces', False):\n                self.model = CharacterClassifier(self.args)\n            else:\n                self.model = Seq2SeqModel(self.args)\n            # could remove strict=False after rebuilding all models,\n            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False\n            self.model.load_state_dict(checkpoint['model'], strict=False)\n        else:\n            self.model = None\n        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])\n\n"
  },
  {
    "path": "stanza/models/mwt/utils.py",
    "content": "import stanza\n\nfrom stanza.models.common import doc\nfrom stanza.models.tokenization.data import TokenizationDataset\nfrom stanza.models.tokenization.utils import predict, decode_predictions\n\ndef mwts_composed_of_words(doc):\n    \"\"\"\n    Return True/False if the MWTs in the doc are all exactly composed of the text in their words\n    \"\"\"\n    for sent_idx, sentence in enumerate(doc.sentences):\n        for token_idx, token in enumerate(sentence.tokens):\n            if len(token.words) > 1:\n                expected = \"\".join(x.text for x in token.words)\n                if token.text != expected:\n                    return False\n    return True\n\n\ndef resplit_mwt(tokens, pipeline, keep_tokens=True):\n    \"\"\"\n    Uses the tokenize processor and the mwt processor in the pipeline to resplit tokens into MWT\n\n    tokens: a list of list of string\n    pipeline: a Stanza pipeline which contains, at a minimum, tokenize and mwt\n\n    keep_tokens: if True, enforce the old token boundaries by modify\n      the results of the tokenize inference.\n      Otherwise, use whatever new boundaries the model comes up with.\n\n    between running the tokenize model and breaking the text into tokens,\n    we can update all_preds to use the original token boundaries\n    (if and only if keep_tokens == True)\n\n    This method returns a Document with just the tokens and words annotated.\n    \"\"\"\n    if \"tokenize\" not in pipeline.processors:\n        raise ValueError(\"Need a Pipeline with a valid tokenize processor\")\n    if \"mwt\" not in pipeline.processors:\n        raise ValueError(\"Need a Pipeline with a valid mwt processor\")\n    tokenize_processor = pipeline.processors[\"tokenize\"]\n    mwt_processor = pipeline.processors[\"mwt\"]\n    fake_text = \"\\n\\n\".join(\" \".join(sentence) for sentence in tokens)\n\n    # set up batches\n    batches = TokenizationDataset(tokenize_processor.config,\n                                  
input_text=fake_text,\n                                  vocab=tokenize_processor.vocab,\n                                  evaluation=True,\n                                  dictionary=tokenize_processor.trainer.dictionary)\n\n    all_preds, all_raw = predict(trainer=tokenize_processor.trainer,\n                                 data_generator=batches,\n                                 batch_size=tokenize_processor.trainer.args['batch_size'],\n                                 max_seqlen=tokenize_processor.config.get('max_seqlen', tokenize_processor.MAX_SEQ_LENGTH_DEFAULT),\n                                 use_regex_tokens=True,\n                                 num_workers=tokenize_processor.config.get('num_workers', 0))\n\n    if keep_tokens:\n        # overwrite the tokenizer's predictions so that the original token boundaries are kept\n        for sentence, pred in zip(tokens, all_preds):\n            char_idx = 0\n            for word in sentence:\n                if len(word) > 0:\n                    # no boundary is allowed inside an original token\n                    pred[char_idx:char_idx+len(word)-1] = 0\n                # the last character must end a token; a nonzero prediction is kept as-is\n                if pred[char_idx+len(word)-1] == 0:\n                    pred[char_idx+len(word)-1] = 1\n                # skip the space which separates the tokens in fake_text\n                char_idx += len(word) + 1\n\n    _, _, document = decode_predictions(vocab=tokenize_processor.vocab,\n                                        mwt_dict=None,\n                                        orig_text=fake_text,\n                                        all_raw=all_raw,\n                                        all_preds=all_preds,\n                                        no_ssplit=True,\n                                        skip_newline=tokenize_processor.trainer.args['skip_newline'],\n                                        use_la_ittb_shorthand=tokenize_processor.trainer.args['shorthand'] == 'la_ittb')\n\n    document = doc.Document(document, fake_text)\n    mwt_processor.process(document)\n    return document\n\ndef main():\n    pipe = stanza.Pipeline(\"en\", processors=\"tokenize,mwt\", package=\"gum\")\n    tokens = [[\"I\", \"can't\", \"believe\", \"it\"], [\"I can't\", 
\"sleep\"]]\n    doc = resplit_mwt(tokens, pipe)\n    print(doc)\n\n    doc = resplit_mwt(tokens, pipe, keep_tokens=False)\n    print(doc)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/mwt/vocab.py",
    "content": "from collections import Counter\n\nfrom stanza.models.common.vocab import BaseVocab\nimport stanza.models.common.seq2seq_constant as constant\n\nclass Vocab(BaseVocab):\n    def build_vocab(self):\n        pairs = self.data\n        allchars = \"\".join([src + tgt for src, tgt in pairs])\n        counter = Counter(allchars)\n\n        self._id2unit = constant.VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\n    def add_unit(self, unit):\n        if unit in self._unit2id:\n            return\n        self._unit2id[unit] = len(self._id2unit)\n        self._id2unit.append(unit)\n"
  },
  {
    "path": "stanza/models/mwt_expander.py",
    "content": "\"\"\"\nEntry point for training and evaluating a multi-word token (MWT) expander.\n\nThis MWT expander combines a neural sequence-to-sequence architecture with a dictionary\nto decode the token into multiple words.\nFor details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf\n\nIn the case of a dataset where all of the MWT exactly split into the words\ncomposing the MWT, a classifier over the characters is used instead of the seq2seq\n\"\"\"\n\nimport io\nimport sys\nimport os\nimport shutil\nimport time\nfrom datetime import datetime\nimport argparse\nimport logging\nimport math\nimport numpy as np\nimport random\nimport torch\nfrom torch import nn, optim\nimport copy\n\nfrom stanza.models.mwt.data import DataLoader, BinaryDataLoader\nfrom stanza.models.mwt.utils import mwts_composed_of_words\nfrom stanza.models.mwt.vocab import Vocab\nfrom stanza.models.mwt.trainer import Trainer\nfrom stanza.models.mwt import scorer\nfrom stanza.models.common import utils\nimport stanza.models.common.seq2seq_constant as constant\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models import _training_logging\n\nlogger = logging.getLogger('stanza')\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_dir', type=str, default='data/mwt', help='Root dir for saving models.')\n    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')\n    parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.')\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--lang', type=str, help='Language')\n    parser.add_argument('--shorthand', type=str, 
help=\"Treebank shorthand\")\n\n    parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default ensemble a dict.')\n    parser.add_argument('--ensemble_early_stop', action='store_true', help='Early stopping based on ensemble performance.')\n    parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based MWT expander.')\n\n    parser.add_argument('--hidden_dim', type=int, default=100)\n    parser.add_argument('--emb_dim', type=int, default=50)\n    parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder.  Defaults to 1 for seq2seq, 2 for classifier')\n    parser.add_argument('--emb_dropout', type=float, default=0.5)\n    parser.add_argument('--dropout', type=float, default=0.5)\n    parser.add_argument('--max_dec_len', type=int, default=50)\n    parser.add_argument('--beam_size', type=int, default=1)\n    parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')\n    parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')\n\n    parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\\'| to |\"| |’| |ʼ|')\n    parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself.  (By default, this is determined from the dataset.)')\n    parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help=\"Don't make the text of the pieces of the MWT add up to the token itself.  
(By default, this is determined from the dataset.)\")\n\n    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')\n    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')\n    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')\n    parser.add_argument('--lr_decay', type=float, default=0.9)\n    parser.add_argument('--decay_epoch', type=int, default=30, help=\"Decay the lr starting from this epoch.\")\n    parser.add_argument('--num_epoch', type=int, default=30)\n    parser.add_argument('--batch_size', type=int, default=50)\n    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')\n    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')\n    parser.add_argument('--save_dir', type=str, default='saved_models/mwt', help='Root dir for saving models.')\n    parser.add_argument('--save_name', type=str, default=None, help=\"File name to save the model\")\n    parser.add_argument('--save_each_name', type=str, default=None, help=\"Save each model in sequence to this pattern.  Mostly for testing\")\n\n    parser.add_argument('--seed', type=int, default=1234)\n    utils.add_device_args(parser)\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  
Will default to the dataset short name')\n    return parser\n\ndef parse_args(args=None):\n    parser = build_argparse()\n    args = parser.parse_args(args=args)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args.seed)\n\n    args = vars(args)\n    logger.info(\"Running MWT expander in {} mode\".format(args['mode']))\n\n    if args['mode'] == 'train':\n        return train(args)\n    else:\n        return evaluate(args)\n\ndef train(args):\n    # load data\n    logger.debug('max_dec_len: %d' % args['max_dec_len'])\n    logger.debug(\"Loading data with batch size {}...\".format(args['batch_size']))\n    train_doc = CoNLL.conll2doc(input_file=args['train_file'])\n    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)\n    vocab = train_batch.vocab\n    args['vocab_size'] = vocab.size\n    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)\n\n    utils.ensure_dir(args['save_dir'])\n    save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])\n    model_file = os.path.join(args['save_dir'], save_name)\n\n    save_each_name = None\n    if args['save_each_name']:\n        save_each_name = os.path.join(args['save_dir'], args['save_each_name'])\n        save_each_name = utils.build_save_each_filename(save_each_name)\n\n    # pred and gold path\n    gold_file = args['gold_file']\n\n    # skip training if the language does not have training or dev data\n    if len(train_batch) == 0:\n        logger.warning(\"Skip training because no data available...\")\n        return\n    dev_mwt = dev_doc.get_mwt_expansions(False)\n    if len(dev_batch) == 0 and args.get('dict_only', False):\n        logger.warning(\"Training data available, but dev data has no MWTs.  
Only training a dict based MWT\")\n        args['dict_only'] = True\n\n    if args['force_exact_pieces'] and not mwts_composed_of_words(train_doc):\n        raise ValueError(\"Cannot train model with --force_exact_pieces, as the MWT in this dataset are not entirely composed of their subwords\")\n\n    if args['force_exact_pieces'] is None and mwts_composed_of_words(train_doc):\n        # the force_exact_pieces mechanism trains a separate version of the MWT expander in the Trainer\n        # (the training loop here does not need to change)\n        # in this model, a classifier distinguishes whether or not a location is a split\n        # and the text is copied exactly from the input rather than created via seq2seq\n        # this behavior can be turned off at training time with --no_force_exact_pieces\n        logger.info(\"Train MWTs entirely composed of their subwords.  Training the MWT to match that paradigm as closely as possible\")\n        args['force_exact_pieces'] = True\n\n    if args['force_exact_pieces']:\n        logger.info(\"Reconverting to BinaryDataLoader\")\n        train_batch = BinaryDataLoader(train_doc, args['batch_size'], args, evaluation=False)\n        vocab = train_batch.vocab\n        args['vocab_size'] = vocab.size\n        dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)\n\n    if args['num_layers'] is None:\n        if args['force_exact_pieces']:\n            args['num_layers'] = 2\n        else:\n            args['num_layers'] = 1\n\n    # train a dictionary-based MWT expander\n    trainer = Trainer(args=args, vocab=vocab, device=args['device'])\n    logger.info(\"Training dictionary-based MWT expander...\")\n    trainer.train_dict(train_batch.doc.get_mwt_expansions(evaluation=False))\n    logger.info(\"Evaluating on dev set...\")\n    dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))\n    doc = copy.deepcopy(dev_batch.doc)\n    
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)\n    system_preds = \"{:C}\\n\\n\".format(doc)\n    system_preds = io.StringIO(system_preds)\n    _, _, dev_f = scorer.score(system_preds, gold_file)\n    logger.info(\"Dev F1 = {:.2f}\".format(dev_f * 100))\n\n    if args.get('dict_only', False):\n        # save dictionaries\n        trainer.save(model_file)\n    else:\n        # train a seq2seq model\n        logger.info(\"Training seq2seq-based MWT expander...\")\n        global_step = 0\n        steps_per_epoch = math.ceil(len(train_batch) / args['batch_size'])\n        max_steps = steps_per_epoch * args['num_epoch']\n        dev_score_history = []\n        best_dev_preds = []\n        current_lr = args['lr']\n        global_start_time = time.time()\n        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'\n\n        if args['wandb']:\n            import wandb\n            wandb_name = args['wandb_name'] if args['wandb_name'] else \"%s_mwt\" % args['shorthand']\n            wandb.init(name=wandb_name, config=args)\n            wandb.run.define_metric('train_loss', summary='min')\n            wandb.run.define_metric('dev_score', summary='max')\n\n        # start training\n        for epoch in range(1, args['num_epoch']+1):\n            train_loss = 0\n            for i, batch in enumerate(train_batch.to_loader()):\n                start_time = time.time()\n                global_step += 1\n                loss = trainer.update(batch, eval=False) # update step\n                train_loss += loss\n                if global_step % args['log_step'] == 0:\n                    duration = time.time() - start_time\n                    logger.info(format_str.format(datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), global_step,\\\n                                                  max_steps, epoch, args['num_epoch'], loss, duration, current_lr))\n\n            if save_each_name:\n                trainer.save(save_each_name % 
epoch)\n                logger.info(\"Saved epoch %d model to %s\" % (epoch, save_each_name % epoch))\n\n            # eval on dev\n            logger.info(\"Evaluating on dev set...\")\n            dev_preds = []\n            for i, batch in enumerate(dev_batch.to_loader()):\n                preds = trainer.predict(batch)\n                dev_preds += preds\n            if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False):\n                logger.info(\"[Ensembling dict with seq2seq model...]\")\n                dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)\n            doc = copy.deepcopy(dev_batch.doc)\n            doc.set_mwt_expansions(dev_preds, fake_dependencies=True)\n            system_preds = \"{:C}\\n\\n\".format(doc)\n            system_preds = io.StringIO(system_preds)\n            _, _, dev_score = scorer.score(system_preds, gold_file)\n            train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch\n            logger.info(\"epoch {}: train_loss = {:.6f}, dev_score = {:.4f}\".format(epoch, train_loss, dev_score))\n\n            if args['wandb']:\n                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})\n\n            # save best model\n            if epoch == 1 or dev_score > max(dev_score_history):\n                trainer.save(model_file)\n                logger.info(\"new best model saved.\")\n                best_dev_preds = dev_preds\n\n            # lr schedule\n            if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1]:\n                current_lr *= args['lr_decay']\n                trainer.change_lr(current_lr)\n\n            dev_score_history += [dev_score]\n\n        logger.info(\"Training ended with {} epochs.\".format(epoch))\n\n        if args['wandb']:\n            wandb.finish()\n\n        best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1\n        
logger.info(\"Best dev F1 = {:.2f}, at epoch = {}\".format(best_f, best_epoch))\n\n        # try ensembling with dict if necessary\n        if args.get('ensemble_dict', False):\n            logger.info(\"[Ensembling dict with seq2seq model...]\")\n            dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)\n            doc = copy.deepcopy(dev_batch.doc)\n            doc.set_mwt_expansions(dev_preds, fake_dependencies=True)\n            system_preds = \"{:C}\\n\\n\".format(doc)\n            system_preds = io.StringIO(system_preds)\n            _, _, dev_score = scorer.score(system_preds, gold_file)\n            logger.info(\"Ensemble dev F1 = {:.2f}\".format(dev_score*100))\n            best_f = max(best_f, dev_score)\n\n    return trainer, _\n\ndef evaluate(args):\n    # file paths\n    system_pred_file = args['output_file']\n    gold_file = args['gold_file']\n    model_file = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])\n    if args['save_dir'] and not model_file.startswith(args['save_dir']) and not os.path.exists(model_file):\n        model_file = os.path.join(args['save_dir'], model_file)\n\n    # load model\n    trainer = Trainer(model_file=model_file, device=args['device'])\n    loaded_args, vocab = trainer.args, trainer.vocab\n\n    for k in args:\n        if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']:\n            loaded_args[k] = args[k]\n    logger.debug('max_dec_len: %d' % loaded_args['max_dec_len'])\n\n    # load data\n    logger.debug(\"Loading data with batch size {}...\".format(args['batch_size']))\n    doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)\n\n    if len(batch) > 0:\n        dict_preds = trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))\n        # decide trainer type and run eval\n        if 
loaded_args['dict_only']:\n            preds = dict_preds\n        else:\n            logger.info(\"Running the seq2seq model...\")\n            preds = []\n            for i, b in enumerate(batch.to_loader()):\n                preds += trainer.predict(b)\n\n            if loaded_args.get('ensemble_dict', False):\n                preds = trainer.ensemble(batch.doc.get_mwt_expansions(evaluation=True), preds)\n    else:\n        # skip eval if dev data does not exist\n        preds = []\n\n    # write to file and score\n    doc = copy.deepcopy(batch.doc)\n    doc.set_mwt_expansions(preds, fake_dependencies=True)\n    if system_pred_file is not None:\n        CoNLL.write_doc2conll(doc, system_pred_file)\n    else:\n        system_pred_file = \"{:C}\\n\\n\".format(doc)\n        system_pred_file = io.StringIO(system_pred_file)\n\n    if gold_file is not None:\n        _, _, score = scorer.score(system_pred_file, gold_file)\n\n        logger.info(\"MWT expansion score: {} {:.2f}\".format(args['shorthand'], score*100))\n\n    return trainer, doc\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/ner/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/ner/data.py",
    "content": "import random\nimport logging\nimport torch\n\nfrom stanza.models.common.bert_embedding import filter_data, needs_length_filter\nfrom stanza.models.common.data import map_to_ids, get_long_tensor, sort_all\nfrom stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX\nfrom stanza.models.pos.vocab import CharVocab, CompositeVocab, WordVocab\nfrom stanza.models.ner.vocab import MultiVocab\nfrom stanza.models.common.doc import *\nfrom stanza.models.ner.utils import process_tags, normalize_empty_tags\n\nlogger = logging.getLogger('stanza')\n\nclass DataLoader:\n    def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=False, preprocess_tags=True, bert_tokenizer=None, scheme=None, max_batch_words=None):\n        self.max_batch_words = max_batch_words\n        self.batch_size = batch_size\n        self.args = args\n        self.eval = evaluation\n        self.shuffled = not self.eval\n        self.doc = doc\n        self.preprocess_tags = preprocess_tags\n\n        data = self._load_doc(self.doc, scheme)\n\n        # filter out the long sentences if bert is used\n        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):\n            data = filter_data(self.args['bert_model'], data, bert_tokenizer)\n\n        self.tags = [[w[1] for w in sent] for sent in data]\n        # handle vocab\n        self.pretrain = pretrain\n        if vocab is None:\n            self.vocab = self.init_vocab(data)\n        else:\n            self.vocab = vocab\n\n        # filter and sample data\n        if args.get('sample_train', 1.0) < 1.0 and not self.eval:\n            keep = int(args['sample_train'] * len(data))\n            data = random.sample(data, keep)\n            logger.debug(\"Subsample training set with rate {:g}\".format(args['sample_train']))\n\n        data = self.preprocess(data, self.vocab, args)\n        # shuffle for training\n        if self.shuffled:\n            random.shuffle(data)\n        
self.num_examples = len(data)\n\n        # chunk into batches\n        self.data = self.chunk_batches(data)\n        logger.debug(\"{} batches created.\".format(len(self.data)))\n\n    def init_vocab(self, data):\n        def from_model(model_filename):\n            \"\"\" Try loading vocab from charLM model file. \"\"\"\n            state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True)\n            if 'vocab' in state_dict:\n                return state_dict['vocab']\n            if 'model' in state_dict and 'vocab' in state_dict['model']:\n                return state_dict['model']['vocab']\n            raise ValueError(\"Cannot find vocab in charLM model file %s\" % model_filename)\n\n        if self.eval:\n            raise AssertionError(\"Vocab must exist for evaluation.\")\n        if self.args['charlm']:\n            charvocab = CharVocab.load_state_dict(from_model(self.args['charlm_forward_file']))\n        else:\n            charvocab = CharVocab(data, self.args['shorthand'])\n        wordvocab = self.pretrain.vocab if self.pretrain is not None else None\n        tag_data = [[(x[1],) for x in sentence] for sentence in data]\n        tagvocab = CompositeVocab(tag_data, self.args['shorthand'], idx=0, sep=None)\n        ignore = None\n        if self.args['emb_finetune_known_only']:\n            if self.pretrain is None:\n                raise ValueError(\"Cannot train emb_finetune_known_only with no pretrain of known words\")\n            if self.args['lowercase']:\n                ignore = set([w[0].lower() for sent in data for w in sent if w[0] not in wordvocab and w[0].lower() not in wordvocab])\n            else:\n                ignore = set([w[0] for sent in data for w in sent if w[0] not in wordvocab])\n            logger.debug(\"Ignoring %d in the delta vocab as they did not appear in the original embedding\", len(ignore))\n        deltavocab = WordVocab(data, self.args['shorthand'], cutoff=1, 
lower=self.args['lowercase'], ignore=ignore)\n        logger.debug(\"Creating delta vocab of size %s\", len(deltavocab))\n        vocabs = {'char': charvocab,\n                  'delta': deltavocab,\n                  'tag': tagvocab}\n        if wordvocab is not None:\n            vocabs['word'] = wordvocab\n        vocab = MultiVocab(vocabs)\n        return vocab\n\n    def preprocess(self, data, vocab, args):\n        processed = []\n        if args.get('char_lowercase', False): # handle character case\n            char_case = lambda x: x.lower()\n        else:\n            char_case = lambda x: x\n        for sent_idx, sent in enumerate(data):\n            processed_sent = [[w[0] for w in sent]]\n            processed_sent += [[vocab['char'].map([char_case(x) for x in w[0]]) for w in sent]]\n            processed_sent += [vocab['tag'].map([w[1] for w in sent])]\n            processed.append(processed_sent)\n        return processed\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, key):\n        \"\"\" Get a batch with index. 
\"\"\"\n        if not isinstance(key, int):\n            raise TypeError\n        if key < 0 or key >= len(self.data):\n            raise IndexError\n        batch = self.data[key]\n        batch_size = len(batch)\n        batch = list(zip(*batch))\n        assert len(batch) == 3 # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[List[int]]]\n\n        # sort sentences by lens for easy RNN operations\n        sentlens = [len(x) for x in batch[0]]\n        batch, orig_idx = sort_all(batch, sentlens)\n        sentlens = [len(x) for x in batch[0]]\n\n        # sort chars by lens for easy char-LM operations\n        chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1])\n        chars_sorted, char_orig_idx = sort_all([chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens)\n        chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted\n        charlens = [len(sent) for sent in chars_forward]\n\n        # sort words by lens for easy char-RNN operations\n        batch_words = [w for sent in batch[1] for w in sent]\n        wordlens = [len(x) for x in batch_words]\n        batch_words, word_orig_idx = sort_all([batch_words], wordlens)\n        batch_words = batch_words[0]\n        wordlens = [len(x) for x in batch_words]\n\n        words = batch[0]\n        \n        wordchars = get_long_tensor(batch_words, len(wordlens))\n        wordchars_mask = torch.eq(wordchars, PAD_ID)\n        chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' '))\n        chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' '))\n        chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)]) # padded forward and backward char idx\n        charoffsets = [charoffsets_forward, charoffsets_backward] # idx for forward and backward lm to get word 
representation\n        tags = get_long_tensor(batch[2], batch_size)\n\n        return words, wordchars, wordchars_mask, chars, tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets\n\n    def __iter__(self):\n        for i in range(self.__len__()):\n            yield self.__getitem__(i)\n\n    def _load_doc(self, doc, scheme):\n        # preferentially load the MULTI_NER in case we are training /\n        # testing a model with multiple layers of tags\n        data = doc.get([TEXT, NER, MULTI_NER], as_sentences=True, from_token=True)\n        data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)] for token in sentence] for sentence in data]\n        if self.preprocess_tags: # preprocess tags\n            if scheme is None:\n                data = process_tags(data, self.args.get('scheme', 'bio'))\n            data = normalize_empty_tags(data)\n        return data\n\n    def process_chars(self, sents):\n        start_id, end_id = self.vocab['char'].unit2id('\\n'), self.vocab['char'].unit2id(' ') # special token\n        start_offset, end_offset = 1, 1\n        chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = [], [], [], []\n        # get char representation for each sentence\n        for sent in sents:\n            chars_forward_sent, chars_backward_sent, charoffsets_forward_sent, charoffsets_backward_sent = [start_id], [start_id], [], []\n            # forward lm\n            for word in sent:\n                chars_forward_sent += word\n                charoffsets_forward_sent = charoffsets_forward_sent + [len(chars_forward_sent)] # add each token offset in the last for forward lm\n                chars_forward_sent += [end_id]\n            # backward lm\n            for word in sent[::-1]:\n                chars_backward_sent += word[::-1]\n                charoffsets_backward_sent = [len(chars_backward_sent)] + charoffsets_backward_sent # add each offset in the first for backward lm\n   
             chars_backward_sent += [end_id]\n            # store each sentence\n            chars_forward.append(chars_forward_sent)\n            chars_backward.append(chars_backward_sent)\n            charoffsets_forward.append(charoffsets_forward_sent)\n            charoffsets_backward.append(charoffsets_backward_sent)\n        charlens = [len(sent) for sent in chars_forward] # forward lm and backward lm should have the same lengths\n        return chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens\n\n    def reshuffle(self):\n        data = [y for x in self.data for y in x]\n        random.shuffle(data)\n        self.data = self.chunk_batches(data)\n\n    def chunk_batches(self, data):\n        if self.max_batch_words is None:\n            return [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]\n        batches = []\n        next_batch = []\n        for item in data:\n            next_batch.append(item)\n            if len(next_batch) >= self.batch_size:\n                batches.append(next_batch)\n                next_batch = []\n            if sum(len(x[0]) for x in next_batch) >= self.max_batch_words:\n                batches.append(next_batch)\n                next_batch = []\n        if len(next_batch) > 0:\n            batches.append(next_batch)\n        return batches\n"
  },
  {
    "path": "stanza/models/ner/model.py",
    "content": "import os\nimport logging\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence\n\nfrom stanza.models.common.data import map_to_ids, get_long_tensor\nfrom stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError\nfrom stanza.models.common.packed_lstm import PackedLSTM\nfrom stanza.models.common.dropout import WordDropout, LockedDropout\nfrom stanza.models.common.char_model import CharacterModel, CharacterLanguageModel\nfrom stanza.models.common.crf import CRFLoss\nfrom stanza.models.common.foundation_cache import load_bert\nfrom stanza.models.common.utils import attach_bert_model\nfrom stanza.models.common.vocab import PAD_ID, UNK_ID, EMPTY_ID\nfrom stanza.models.common.bert_embedding import extract_bert_embeddings\n\nlogger = logging.getLogger('stanza')\n\n# this gets created in two places in trainer\n# in both places, pass in the bert model & tokenizer\nclass NERTagger(nn.Module):\n    def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):\n        super().__init__()\n\n        self.vocab = vocab\n        self.args = args\n        self.unsaved_modules = []\n\n        # input layers\n        input_size = 0\n        if self.args['word_emb_dim'] > 0:\n            emb_finetune = self.args.get('emb_finetune', True)\n\n            if 'word' in self.vocab:\n                # load pretrained embeddings if specified\n                word_emb = nn.Embedding(len(self.vocab['word']), self.args['word_emb_dim'], PAD_ID)\n                # if a model trained with no 'delta' vocab is loaded, and\n                # emb_finetune is off, any resaving of the model will need\n                # the updated vectors.  
this is accounted for in load()\n                if not emb_finetune or 'delta' in self.vocab:\n                    # if emb_finetune is off\n                    # or if the delta embedding is present\n                    # then we won't fine tune the original embedding\n                    self.add_unsaved_module('word_emb', word_emb)\n                    self.word_emb.weight.detach_()\n                else:\n                    self.word_emb = word_emb\n                if emb_matrix is not None:\n                    self.init_emb(emb_matrix)\n\n            # TODO: allow for expansion of delta embedding if new\n            # training data has new words in it?\n            self.delta_emb = None\n            if 'delta' in self.vocab:\n                # zero inits seems to work better\n                # note that the gradient will flow to the bottom and then adjust the 0 weights\n                # as opposed to a 0 matrix cutting off the gradient if higher up in the model\n                self.delta_emb = nn.Embedding(len(self.vocab['delta']), self.args['word_emb_dim'], PAD_ID)\n                nn.init.zeros_(self.delta_emb.weight)\n                # if the model was trained with a delta embedding, but emb_finetune is off now,\n                # then we will detach the delta embedding\n                if not emb_finetune:\n                    self.delta_emb.weight.detach_()\n\n            input_size += self.args['word_emb_dim']\n\n        self.peft_name = peft_name\n        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)\n        if self.args.get('bert_model', None):\n            # TODO: refactor bert_hidden_layers between the different models\n            if args.get('bert_hidden_layers', False):\n                # The average will be offset by 1/N so that the default zeros\n                # represents an average of the N layers\n                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, 
bias=False)\n                nn.init.zeros_(self.bert_layer_mix.weight)\n            else:\n                # an average of layers 2, 3, 4 will be used\n                # (for historic reasons)\n                self.bert_layer_mix = None\n            input_size += self.bert_model.config.hidden_size\n\n        if self.args['char'] and self.args['char_emb_dim'] > 0:\n            if self.args['charlm']:\n                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):\n                    raise ForwardCharlmNotFoundError('Could not find forward character model: {}  Please specify with --charlm_forward_file'.format(args['charlm_forward_file']), args['charlm_forward_file'])\n                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):\n                    raise BackwardCharlmNotFoundError('Could not find backward character model: {}  Please specify with --charlm_backward_file'.format(args['charlm_backward_file']), args['charlm_backward_file'])\n                self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))\n                self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))\n                input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()\n            else:\n                self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False)\n                input_size += self.args['char_hidden_dim'] * 2\n\n        # optionally add a input transformation layer\n        if self.args.get('input_transform', False):\n            self.input_transform = nn.Linear(input_size, input_size)\n        else:\n            self.input_transform = None\n       \n        # recurrent layers\n        self.taggerlstm = PackedLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, 
\\\n                bidirectional=True, dropout=0 if self.args['num_layers'] == 1 else self.args['dropout'])\n        # self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))\n        self.drop_replacement = None\n        self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)\n        self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)\n\n        # tag classifier\n        tag_lengths = self.vocab['tag'].lens()\n        self.num_output_layers = len(tag_lengths)\n        if self.args.get('connect_output_layers'):\n            tag_clfs = [nn.Linear(self.args['hidden_dim']*2, tag_lengths[0])]\n            for prev_length, next_length in zip(tag_lengths[:-1], tag_lengths[1:]):\n                tag_clfs.append(nn.Linear(self.args['hidden_dim']*2 + prev_length, next_length))\n            self.tag_clfs = nn.ModuleList(tag_clfs)\n        else:\n            self.tag_clfs = nn.ModuleList([nn.Linear(self.args['hidden_dim']*2, num_tag) for num_tag in tag_lengths])\n        for tag_clf in self.tag_clfs:\n            tag_clf.bias.data.zero_()\n        self.crits = nn.ModuleList([CRFLoss(num_tag) for num_tag in tag_lengths])\n\n        self.drop = nn.Dropout(args['dropout'])\n        self.worddrop = WordDropout(args['word_dropout'])\n        self.lockeddrop = LockedDropout(args['locked_dropout'])\n\n    def init_emb(self, emb_matrix):\n        if isinstance(emb_matrix, np.ndarray):\n            emb_matrix = torch.from_numpy(emb_matrix)\n        vocab_size = len(self.vocab['word'])\n        dim = self.args['word_emb_dim']\n        assert emb_matrix.size() == (vocab_size, dim), \\\n            \"Input embedding matrix must match size: {} x {}, found {}\".format(vocab_size, dim, emb_matrix.size())\n        self.word_emb.weight.data.copy_(emb_matrix)\n\n    def add_unsaved_module(self, name, module):\n       
 self.unsaved_modules += [name]\n        setattr(self, name, module)\n\n    def log_norms(self):\n        lines = [\"NORMS FOR MODEL PARAMETERS\"]\n        for name, param in self.named_parameters():\n            if param.requires_grad and name.split(\".\")[0] not in ('charmodel_forward', 'charmodel_backward'):\n                lines.append(\"  %s %.6g\" % (name, torch.norm(param).item()))\n        logger.info(\"\\n\".join(lines))\n\n    def forward(self, sentences, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx):\n        device = next(self.parameters()).device\n\n        def pack(x):\n            return pack_padded_sequence(x, sentlens, batch_first=True)\n        \n        inputs = []\n        batch_size = len(sentences)\n\n        has_embedding = False\n        if self.args['word_emb_dim'] > 0:\n            #extract static embeddings\n            if 'word' in self.vocab:\n                static_words, word_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['word'])\n\n                word_mask = word_mask.to(device)\n                static_words = static_words.to(device)\n\n                word_static_emb = self.word_emb(static_words)\n                has_embedding = True\n\n            if 'delta' in self.vocab and self.delta_emb is not None:\n                # masks should be the same\n                delta_words, delta_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['delta'])\n                delta_words = delta_words.to(device)\n                # unclear whether to treat words in the main embedding\n                # but not in delta as unknown\n                # simple heuristic though - treating them as not\n                # unknown keeps existing models the same when\n                # separating models into the base WV and delta WV\n                # also, note that at training time, words like this\n                # did not show up in the training 
data, but are\n                # not exactly UNK, so it makes sense\n                if has_embedding:\n                    delta_unk_mask = torch.eq(delta_words, UNK_ID)\n                    static_unk_mask = torch.not_equal(static_words, UNK_ID)\n                    unk_mask = delta_unk_mask * static_unk_mask\n                    delta_words[unk_mask] = PAD_ID\n                else:\n                    word_mask = delta_mask.to(device)\n\n                delta_emb = self.delta_emb(delta_words)\n                if has_embedding:\n                    word_static_emb = word_static_emb + delta_emb\n                else:\n                    has_embedding = True\n                    word_static_emb = delta_emb\n\n            if has_embedding:\n                word_emb = pack(word_static_emb)\n                inputs += [word_emb]\n\n        if self.bert_model is not None:\n            device = next(self.parameters()).device\n            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, sentences, device, keep_endpoints=False,\n                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,\n                                                     detach=not self.args.get('bert_finetune', False),\n                                                     peft_name=self.peft_name)\n            if self.bert_layer_mix is not None:\n                # use a linear layer to weighted average the embedding dynamically\n                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]\n\n            processed_bert = pad_sequence(processed_bert, batch_first=True)\n            inputs += [pack(processed_bert)]\n\n        if self.args['char'] and self.args['char_emb_dim'] > 0:\n            if self.args.get('charlm', None):\n                char_reps_forward = 
self.charmodel_forward.get_representation(chars[0], charoffsets[0], charlens, char_orig_idx)\n                char_reps_forward = PackedSequence(char_reps_forward.data, char_reps_forward.batch_sizes)\n                char_reps_backward = self.charmodel_backward.get_representation(chars[1], charoffsets[1], charlens, char_orig_idx)\n                char_reps_backward = PackedSequence(char_reps_backward.data, char_reps_backward.batch_sizes)\n                inputs += [char_reps_forward, char_reps_backward]\n            else:\n                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)\n                char_reps = PackedSequence(char_reps.data, char_reps.batch_sizes)\n                inputs += [char_reps]\n\n        batch_sizes = inputs[0].batch_sizes\n        def pad(x):\n            return pad_packed_sequence(PackedSequence(x, batch_sizes), batch_first=True)[0]\n\n        lstm_inputs = torch.cat([x.data for x in inputs], 1)\n        if self.args['word_dropout'] > 0:\n            lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)\n        lstm_inputs = self.drop(lstm_inputs)\n        lstm_inputs = pad(lstm_inputs)\n        lstm_inputs = self.lockeddrop(lstm_inputs)\n        lstm_inputs = pack(lstm_inputs).data\n\n        if self.input_transform:\n            lstm_inputs = self.input_transform(lstm_inputs)\n\n        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)\n        lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(\\\n                self.taggerlstm_h_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous(), \\\n                self.taggerlstm_c_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous()))\n        lstm_outputs = lstm_outputs.data\n\n\n        # prediction layer\n        lstm_outputs = self.drop(lstm_outputs)\n        lstm_outputs = pad(lstm_outputs)\n        lstm_outputs = 
self.lockeddrop(lstm_outputs)\n        lstm_outputs = pack(lstm_outputs).data\n\n        loss = 0\n        logits = []\n        trans = []\n        for idx, (tag_clf, crit) in enumerate(zip(self.tag_clfs, self.crits)):\n            if not self.args.get('connect_output_layers') or idx == 0:\n                next_logits = pad(tag_clf(lstm_outputs)).contiguous()\n            else:\n                # here we pack the output of the previous round, then append it\n                packed_logits = pack(next_logits).data\n                input_logits = torch.cat([lstm_outputs, packed_logits], axis=1)\n                next_logits = pad(tag_clf(input_logits)).contiguous()\n            # the tag_mask lets us avoid backprop on a blank tag\n            tag_mask = torch.eq(tags[:, :, idx], EMPTY_ID)\n            if has_embedding:\n                tag_mask = torch.bitwise_or(tag_mask, word_mask)\n            else:\n                tag_mask = torch.bitwise_or(tag_mask, torch.eq(tags[:, :, idx], PAD_ID))\n            next_loss, next_trans = crit(next_logits, tag_mask, tags[:, :, idx])\n            loss = loss + next_loss\n            logits.append(next_logits)\n            trans.append(next_trans)\n\n        return loss, logits, trans\n\n    @staticmethod\n    def extract_static_embeddings(args, sents, vocab):\n        processed = []\n        if args.get('lowercase', True): # handle word case\n            case = lambda x: x.lower()\n        else:\n            case = lambda x: x\n        for idx, sent in enumerate(sents):\n            processed_sent = [vocab.map([case(w) for w in sent])]\n            processed.append(processed_sent[0])\n\n        words = get_long_tensor(processed, len(sents))\n        words_mask = torch.eq(words, PAD_ID)\n\n        return words, words_mask\n\n"
  },
  {
    "path": "stanza/models/ner/scorer.py",
    "content": "\"\"\"\nAn NER scorer that calculates F1 score given gold and predicted tags.\n\"\"\"\nimport sys\nimport os\nimport logging\nfrom collections import Counter, defaultdict\n\nfrom stanza.models.ner.utils import decode_from_bioes\n\nlogger = logging.getLogger('stanza')\n\ndef score_by_entity(pred_tag_sequences, gold_tag_sequences, verbose=True, ignore_tags=None):\n    \"\"\" Score predicted tags at the entity level.\n\n    Args:\n        pred_tags_sequences: a list of list of predicted tags for each word\n        gold_tags_sequences: a list of list of gold tags for each word\n        verbose: print log with results\n        ignore_tags: a list or a string with a comma-separated list of tags to ignore\n    \n    Returns:\n        Precision, recall and F1 scores.\n    \"\"\"\n    assert(len(gold_tag_sequences) == len(pred_tag_sequences)), \\\n        \"Number of predicted tag sequences does not match gold sequences.\"\n    \n    def decode_all(tag_sequences):\n        # decode from all sequences, each sequence with a unique id\n        ents = []\n        for sent_id, tags in enumerate(tag_sequences):\n            for ent in decode_from_bioes(tags):\n                ent['sent_id'] = sent_id\n                ents += [ent]\n        return ents\n\n    ignore_tag_set = set()\n    if ignore_tags:\n        if isinstance(ignore_tags, str):\n            ignore_tag_set.update(ignore_tags.split(\",\"))\n        else:\n            ignore_tag_set.update(ignore_tags)\n\n    gold_ents = decode_all(gold_tag_sequences)\n    gold_ents = [x for x in gold_ents if x['type'] not in ignore_tag_set]\n\n    pred_ents = decode_all(pred_tag_sequences)\n    pred_ents = [x for x in pred_ents if x['type'] not in ignore_tag_set]\n\n    # scoring\n    true_positive_by_type = Counter()\n    false_positive_by_type = Counter()\n    false_negative_by_type = Counter()\n    guessed_by_type = Counter()\n    gold_by_type = Counter()\n\n    for p in pred_ents:\n        
guessed_by_type[p['type']] += 1\n        if p in gold_ents:\n            true_positive_by_type[p['type']] += 1\n        else:\n            false_positive_by_type[p['type']] += 1\n    for g in gold_ents:\n        gold_by_type[g['type']] += 1\n        if g not in pred_ents:\n            false_negative_by_type[g['type']] += 1\n\n    entities = sorted(set(list(true_positive_by_type.keys()) + list(false_positive_by_type.keys()) + list(false_negative_by_type.keys())))\n    entity_f1 = {}\n    for entity in entities:\n        entity_f1[entity] = 2 * true_positive_by_type[entity] / (2 * true_positive_by_type[entity] + false_positive_by_type[entity] + false_negative_by_type[entity])\n\n    prec_micro = 0.0\n    if sum(guessed_by_type.values()) > 0:\n        prec_micro = sum(true_positive_by_type.values()) * 1.0 / sum(guessed_by_type.values())\n    rec_micro = 0.0\n    if sum(gold_by_type.values()) > 0:\n        rec_micro = sum(true_positive_by_type.values()) * 1.0 / sum(gold_by_type.values())\n    f_micro = 0.0\n    if prec_micro + rec_micro > 0:\n        f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro)\n    \n    if verbose:\n        logger.info(\"Score by entity:\\nPrec.\\tRec.\\tF1\\n{:.2f}\\t{:.2f}\\t{:.2f}\".format(\n            prec_micro*100, rec_micro*100, f_micro*100))\n    return prec_micro, rec_micro, f_micro, entity_f1\n\n\ndef score_by_token(pred_tag_sequences, gold_tag_sequences, verbose=True, ignore_tags=None):\n    \"\"\" Score predicted tags at the token level.\n\n    Args:\n        pred_tag_sequences: a list of lists of predicted tags for each word\n        gold_tag_sequences: a list of lists of gold tags for each word\n        verbose: if True, log the results\n        ignore_tags: a list or a string with a comma-separated list of tags to ignore\n    \n    Returns:\n        Micro-averaged precision, recall and F1 scores, along with a confusion matrix\n    \"\"\"\n    assert(len(gold_tag_sequences) == len(pred_tag_sequences)), \\\n        \"Number of 
predicted tag sequences does not match gold sequences.\"\n    \n    ignore_tag_set = set()\n    if ignore_tags:\n        if isinstance(ignore_tags, str):\n            ignore_tag_set.update(ignore_tags.split(\",\"))\n        else:\n            ignore_tag_set.update(ignore_tags)\n\n    def ignore_tag(tag):\n        if tag == 'O':\n            return True\n        if len(tag) > 2 and (tag[1] == '_' or tag[1] == '-'):\n            tag = tag[2:]\n        if tag in ignore_tag_set:\n            return True\n        return False\n\n    correct_by_tag = Counter()\n    guessed_by_tag = Counter()\n    gold_by_tag = Counter()\n    confusion = defaultdict(lambda: defaultdict(int))\n\n    for gold_tags, pred_tags in zip(gold_tag_sequences, pred_tag_sequences):\n        assert(len(gold_tags) == len(pred_tags)), \\\n            \"Number of predicted tags does not match gold.\"\n        for g, p in zip(gold_tags, pred_tags):\n            if ignore_tag(g):\n                g = 'O'\n            if ignore_tag(p):\n                p = 'O'\n            confusion[g][p] = confusion[g][p] + 1\n            if g == 'O' and p == 'O':\n                continue\n            elif g == 'O' and p != 'O':\n                guessed_by_tag[p] += 1\n            elif g != 'O' and p == 'O':\n                gold_by_tag[g] += 1\n            else:\n                guessed_by_tag[p] += 1\n                gold_by_tag[g] += 1\n                if g == p:\n                    correct_by_tag[p] += 1\n    \n    prec_micro = 0.0\n    if sum(guessed_by_tag.values()) > 0:\n        prec_micro = sum(correct_by_tag.values()) * 1.0 / sum(guessed_by_tag.values())\n    rec_micro = 0.0\n    if sum(gold_by_tag.values()) > 0:\n        rec_micro = sum(correct_by_tag.values()) * 1.0 / sum(gold_by_tag.values())\n    f_micro = 0.0\n    if prec_micro + rec_micro > 0:\n        f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro)\n    \n    if verbose:\n        logger.info(\"Score by 
token:\\nPrec.\\tRec.\\tF1\\n{:.2f}\\t{:.2f}\\t{:.2f}\".format(\n            prec_micro*100, rec_micro*100, f_micro*100))\n    return prec_micro, rec_micro, f_micro, confusion\n\ndef test():\n    pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'],\n                    ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']]\n    gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'],\n                    ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']]\n    print(score_by_token(pred_sequences, gold_sequences))\n    print(score_by_entity(pred_sequences, gold_sequences))\n\nif __name__ == '__main__':\n    test()\n\n"
  },
  {
    "path": "stanza/models/ner/trainer.py",
    "content": "\"\"\"\nA trainer class to handle training and testing of models.\n\"\"\"\n\nimport sys\nimport logging\nimport torch\nfrom torch import nn\n\nfrom stanza.models.common.foundation_cache import NoTransformerFoundationCache, load_bert, load_bert_with_peft\nfrom stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper\nfrom stanza.models.common.trainer import Trainer as BaseTrainer\nfrom stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE\nfrom stanza.models.common import utils, loss\nfrom stanza.models.ner.model import NERTagger\nfrom stanza.models.ner.vocab import MultiVocab\nfrom stanza.models.common.crf import viterbi_decode\n\n\nlogger = logging.getLogger('stanza')\n\ndef unpack_batch(batch, device):\n    \"\"\" Unpack a batch from the data loader. \"\"\"\n    inputs = [batch[0]]\n    inputs += [b.to(device) if b is not None else None for b in batch[1:5]]\n    orig_idx = batch[5]\n    word_orig_idx = batch[6]\n    char_orig_idx = batch[7]\n    sentlens = batch[8]\n    wordlens = batch[9]\n    charlens = batch[10]\n    charoffsets = batch[11]\n    return inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets\n\ndef fix_singleton_tags(tags):\n    \"\"\"\n    If there are any singleton B- or E- tags, convert them to S-\n    \"\"\"\n    new_tags = list(tags)\n    # first update all I- tags at the start or end of sequence to B- or E- as appropriate\n    for idx, tag in enumerate(new_tags):\n        if (tag.startswith(\"I-\") and\n            (idx == len(new_tags) - 1 or\n             (new_tags[idx+1] != \"I-\" + tag[2:] and new_tags[idx+1] != \"E-\" + tag[2:]))):\n            new_tags[idx] = \"E-\" + tag[2:]\n        if (tag.startswith(\"I-\") and\n            (idx == 0 or\n             (new_tags[idx-1] != \"B-\" + tag[2:] and new_tags[idx-1] != \"I-\" + tag[2:]))):\n            new_tags[idx] = \"B-\" + tag[2:]\n    # now make another pass through the data to update any singleton 
tags,\n    # including ones which were turned into singletons by the previous operation\n    for idx, tag in enumerate(new_tags):\n        if (tag.startswith(\"B-\") and\n            (idx == len(new_tags) - 1 or\n             (new_tags[idx+1] != \"I-\" + tag[2:] and new_tags[idx+1] != \"E-\" + tag[2:]))):\n            new_tags[idx] = \"S-\" + tag[2:]\n        if (tag.startswith(\"E-\") and\n            (idx == 0 or\n             (new_tags[idx-1] != \"B-\" + tag[2:] and new_tags[idx-1] != \"I-\" + tag[2:]))):\n            new_tags[idx] = \"S-\" + tag[2:]\n    return new_tags\n\nclass Trainer(BaseTrainer):\n    \"\"\" A trainer for training models. \"\"\"\n    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None,\n                 train_classifier_only=False, foundation_cache=None, second_optim=False):\n        if model_file is not None:\n            # load everything from file\n            self.load(model_file, pretrain, args, foundation_cache)\n        else:\n            assert args is not None\n            assert vocab is not None\n            # build model from scratch\n            self.args = args\n            self.vocab = vocab\n            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])\n            peft_name = None\n            if self.args['use_peft']:\n                # fine tune the bert if we're using peft\n                self.args['bert_finetune'] = True\n                peft_name = \"ner\"\n                # peft the lovely model\n                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)\n\n            emb_matrix=None\n            if pretrain is not None:\n                emb_matrix = pretrain.emb\n\n            self.model = NERTagger(args, vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)\n\n            # IMPORTANT: gradient 
checkpointing BREAKS peft if applied before\n            # 1. Apply PEFT FIRST (looksie! it's above this line)\n            # 2. Run gradient checkpointing\n            # https://github.com/huggingface/peft/issues/742\n            if self.args.get(\"gradient_checkpointing\", False) and self.args.get(\"bert_finetune\", False):\n                self.model.bert_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n\n        # if this wasn't set anywhere, we use a default of the 0th tagset\n        # we don't set this as a default in the options so that\n        # we can distinguish \"intentionally set to 0\" and \"not set at all\"\n        if self.args.get('predict_tagset', None) is None:\n            self.args['predict_tagset'] = 0\n\n        if train_classifier_only:\n            logger.info('Disabling gradient for non-classifier layers')\n            exclude = ['tag_clf', 'crit']\n            for pname, p in self.model.named_parameters():\n                if pname.split('.')[0] not in exclude:\n                    p.requires_grad = False\n        self.model = self.model.to(device)\n        if not second_optim:\n            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get(\"use_peft\"))\n        else:\n            self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model, self.args['second_lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), is_peft=self.args.get(\"use_peft\"))\n\n    def update(self, batch, eval=False):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device)\n        word, wordchars, wordchars_mask, chars, tags = inputs\n\n        if eval:\n            
self.model.eval()\n        else:\n            self.model.train()\n            self.optimizer.zero_grad()\n        loss, _, _ = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)\n        loss_val = loss.data.item()\n        if eval:\n            return loss_val\n\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])\n        self.optimizer.step()\n        return loss_val\n\n    def predict(self, batch, unsort=True):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device)\n        word, wordchars, wordchars_mask, chars, tags = inputs\n\n        self.model.eval()\n        #batch_size = word.size(0)\n        _, logits, trans = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)\n\n        # decode\n        # TODO: might need to decode multiple columns of output for\n        # models with multiple layers\n        trans = [x.data.cpu().numpy() for x in trans]\n        logits = [x.data.cpu().numpy() for x in logits]\n        batch_size = logits[0].shape[0]\n        if any(x.shape[0] != batch_size for x in logits):\n            raise AssertionError(\"Expected all of the logits to have the same size\")\n        tag_seqs = []\n        predict_tagset = self.args['predict_tagset']\n        for i in range(batch_size):\n            # for each tag column in the output, decode the tag assignments\n            tags = [viterbi_decode(x[i, :sentlens[i]], y)[0] for x, y in zip(logits, trans)]\n            # TODO: this is to patch that the model can sometimes predict < \"O\"\n            tags = [[x if x >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for x in y] for y in tags]\n            # that gives us N lists of |sent| tags, whereas we want 
|sent| lists of N tags\n            tags = list(zip(*tags))\n            # now unmap that to the tags in the vocab\n            tags = self.vocab['tag'].unmap(tags)\n            # for now, allow either TagVocab or CompositeVocab\n            # TODO: we might want to return all of the predictions\n            # rather than a single column\n            tags = [x[predict_tagset] if isinstance(x, list) else x for x in tags]\n            tags = fix_singleton_tags(tags)\n            tag_seqs += [tags]\n\n        if unsort:\n            tag_seqs = utils.unsort(tag_seqs, orig_idx)\n        return tag_seqs\n\n    def save(self, filename, skip_modules=True):\n        model_state = self.model.state_dict()\n        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file\n        if skip_modules:\n            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]\n            for k in skipped:\n                del model_state[k]\n        params = {\n                'model': model_state,\n                'vocab': self.vocab.state_dict(),\n                'config': self.args\n                }\n\n        if self.args[\"use_peft\"]:\n            from peft import get_peft_model_state_dict\n            params[\"bert_lora\"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)\n        try:\n            torch.save(params, filename, _use_new_zipfile_serialization=False)\n            logger.info(\"Model saved to {}\".format(filename))\n        except (KeyboardInterrupt, SystemExit):\n            raise\n        except:\n            logger.warning(\"Saving failed... 
continuing anyway.\")\n\n    def load(self, filename, pretrain=None, args=None, foundation_cache=None):\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        except BaseException:\n            logger.error(\"Cannot load model from {}\".format(filename))\n            raise\n        self.args = checkpoint['config']\n        if args: self.args.update(args)\n        # if predict_tagset was not explicitly set in the args,\n        # we use the value the model was trained with\n        for keep_arg in ('predict_tagset', 'train_scheme', 'scheme'):\n            if self.args.get(keep_arg, None) is None:\n                self.args[keep_arg] = checkpoint['config'].get(keep_arg, None)\n\n        lora_weights = checkpoint.get('bert_lora')\n        if lora_weights:\n            logger.debug(\"Found peft weights for NER; loading a peft adapter\")\n            self.args[\"use_peft\"] = True\n\n        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])\n\n        emb_matrix=None\n        if pretrain is not None:\n            emb_matrix = pretrain.emb\n\n        force_bert_saved = False\n        peft_name = None\n        if self.args.get('use_peft', False):\n            force_bert_saved = True\n            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], \"ner\", foundation_cache)\n            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)\n            logger.debug(\"Loaded peft with name %s\", peft_name)\n        else:\n            if any(x.startswith(\"bert_model.\") for x in checkpoint['model'].keys()):\n                logger.debug(\"Model %s has a finetuned transformer.  
Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere\", filename)\n                foundation_cache = NoTransformerFoundationCache(foundation_cache)\n                force_bert_saved = True\n            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)\n\n        if any(x.startswith(\"crit.\") for x in checkpoint['model'].keys()):\n            logger.debug(\"Old model format detected.  Updating to the new format with one column of tags\")\n            checkpoint['model']['crits.0._transitions'] = checkpoint['model'].pop('crit._transitions')\n            checkpoint['model']['tag_clfs.0.weight'] = checkpoint['model'].pop('tag_clf.weight')\n            checkpoint['model']['tag_clfs.0.bias'] = checkpoint['model'].pop('tag_clf.bias')\n        self.model = NERTagger(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)\n        self.model.load_state_dict(checkpoint['model'], strict=False)\n\n        # there is a possible issue with the delta embeddings.\n        # specifically, with older models trained without the delta\n        # embedding matrix\n        # if those models have been trained with the embedding\n        # modifications saved as part of the base embedding,\n        # we need to resave the model with the updated embedding\n        # otherwise the resulting model will be broken\n        if 'delta' not in self.model.vocab and 'word_emb.weight' in checkpoint['model'].keys() and 'word_emb' in self.model.unsaved_modules:\n            logger.debug(\"Removing word_emb from unsaved_modules so that resaving %s will keep the saved embedding\", filename)\n            self.model.unsaved_modules.remove('word_emb')\n\n    def get_known_tags(self):\n        \"\"\"\n        Return the tags known by this model\n\n        Removes the S-, B-, 
etc, and does not include O\n        \"\"\"\n        tags = set()\n        for tag in self.vocab['tag'].items(0):\n            if tag in VOCAB_PREFIX:\n                continue\n            if tag == 'O':\n                continue\n            if len(tag) > 2 and tag[:2] in ('S-', 'B-', 'I-', 'E-'):\n                tag = tag[2:]\n            tags.add(tag)\n        return sorted(tags)\n"
  },
  {
    "path": "stanza/models/ner/utils.py",
    "content": "\"\"\"\nUtility functions for dealing with NER tagging.\n\"\"\"\n\nimport logging\n\nfrom stanza.models.common.vocab import EMPTY\n\nlogger = logging.getLogger('stanza')\n\nEMPTY_TAG = ('_', '-', '', None)\nEMPTY_OR_O_TAG = tuple(list(EMPTY_TAG) + ['O'])\n\ndef is_basic_scheme(all_tags):\n    \"\"\"\n    Check if a basic tagging scheme is used. Return True if so.\n\n    Args:\n        all_tags: a list of NER tags\n\n    Returns:\n        True if the tagging scheme does not use B-, I-, etc, otherwise False\n    \"\"\"\n    for tag in all_tags:\n        if len(tag) > 2 and tag[:2] in ('B-', 'I-', 'S-', 'E-', 'B_', 'I_', 'S_', 'E_'):\n            return False\n    return True\n\n\ndef is_bio_scheme(all_tags):\n    \"\"\"\n    Check if BIO tagging scheme is used. Return True if so.\n\n    Args:\n        all_tags: a list of NER tags\n    \n    Returns:\n        True if the tagging scheme is BIO, otherwise False\n    \"\"\"\n    for tag in all_tags:\n        if tag in EMPTY_OR_O_TAG:\n            continue\n        elif len(tag) > 2 and tag[:2] in ('B-', 'I-', 'B_', 'I_'):\n            continue\n        else:\n            return False\n    return True\n\ndef to_bio2(tags):\n    \"\"\"\n    Convert the original tag sequence to BIO2 format. 
If the input is already in BIO2 format,\n    the original input is returned.\n\n    Args:\n        tags: a list of tags in either BIO or BIO2 format\n    \n    Returns:\n        new_tags: a list of tags in BIO2 format\n    \"\"\"\n    new_tags = []\n    for i, tag in enumerate(tags):\n        if tag in EMPTY_OR_O_TAG:\n            new_tags.append(tag)\n        elif tag[0] == 'I':\n            if i == 0 or tags[i-1] == 'O' or tags[i-1][1:] != tag[1:]:\n                new_tags.append('B' + tag[1:])\n            else:\n                new_tags.append(tag)\n        else:\n            new_tags.append(tag)\n    return new_tags\n\ndef basic_to_bio(tags):\n    \"\"\"\n    Convert a basic tag sequence into a BIO sequence.\n    You can compose this with bio2_to_bioes to convert to bioes\n\n    Args:\n        tags: a list of tags in basic (no B-, I-, etc) format\n\n    Returns:\n        new_tags: a list of tags in BIO format\n    \"\"\"\n    new_tags = []\n    for i, tag in enumerate(tags):\n        if tag in EMPTY_OR_O_TAG:\n            new_tags.append(tag)\n        elif i == 0 or tags[i-1] == 'O' or tags[i-1] != tag:\n            new_tags.append('B-' + tag)\n        else:\n            new_tags.append('I-' + tag)\n    return new_tags\n\n\ndef bio2_to_bioes(tags):\n    \"\"\"\n    Convert the BIO2 tag sequence into a BIOES sequence.\n\n    Args:\n        tags: a list of tags in BIO2 format\n\n    Returns:\n        new_tags: a list of tags in BIOES format\n    \"\"\"\n    new_tags = []\n    for i, tag in enumerate(tags):\n        if tag in EMPTY_OR_O_TAG:\n            new_tags.append(tag)\n        else:\n            if len(tag) < 2:\n                raise Exception(f\"Invalid BIO2 tag found: {tag}\")\n            else:\n                if tag[:2] in ('I-', 'I_'): # convert to E- if next tag is not I-\n                    if i+1 < len(tags) and tags[i+1][:2] in ('I-', 'I_'):\n                        new_tags.append('I-' + tag[2:]) # compensate for underscores\n                 
   else:\n                        new_tags.append('E-' + tag[2:])\n                elif tag[:2] in ('B-', 'B_'): # convert to S- if next tag is not I-\n                    if i+1 < len(tags) and tags[i+1][:2] in ('I-', 'I_'):\n                        new_tags.append('B-' + tag[2:])\n                    else:\n                        new_tags.append('S-' + tag[2:])\n                else:\n                    raise Exception(f\"Invalid IOB tag found: {tag}\")\n    return new_tags\n\ndef normalize_empty_tags(sentences):\n    \"\"\"\n    If any tags are None, _, -, or blank, turn them into EMPTY\n\n    The input should be a list(sentence) of list(word) of tuple(text, list(tag))\n    which is the typical format for the data at the time data.py is preprocessing the tags\n    \"\"\"\n    new_sentences = [[(word[0], tuple(EMPTY if x in EMPTY_TAG else x for x in word[1])) for word in sentence]\n                     for sentence in sentences]\n    return new_sentences\n\ndef process_tags(sentences, scheme):\n    \"\"\"\n    Convert tags in these sentences to bioes\n\n    We allow empty tags ('_', '-', None), which will represent tags\n    that do not get any gradient when training\n    \"\"\"\n    all_words = []\n    all_tags = []\n    converted_tuples = False\n    for sent_idx, sent in enumerate(sentences):\n        words, tags = zip(*sent)\n        all_words.append(words)\n        # if we got one dimension tags w/o tuples or lists, make them tuples\n        # but we also check that the format is consistent,\n        # as otherwise the result being converted might be confusing\n        if not converted_tuples and any(tag is None or isinstance(tag, str) for tag in tags):\n            if sent_idx > 0:\n                raise ValueError(\"Got a mix of tags and lists of tags.  
First non-list was in sentence %d\" % sent_idx)\n            converted_tuples = True\n        if converted_tuples:\n            if not all(tag is None or isinstance(tag, str) for tag in tags):\n                raise ValueError(\"Got a mix of tags and lists of tags.  First tag as a list was in sentence %d\" % sent_idx)\n            tags = [(tag,) for tag in tags]\n        all_tags.append(tags)\n\n    max_columns = max(len(x) for tags in all_tags for x in tags)\n    for sent_idx, tags in enumerate(all_tags):\n        if any(len(x) < max_columns for x in tags):\n            raise ValueError(\"NER tags not uniform in length at sentence %d.  TODO: extend those columns with O\" % sent_idx)\n\n    all_convert_bio_to_bioes = []\n    all_convert_basic_to_bioes = []\n\n    for column_idx in range(max_columns):\n        # check if tag conversion is needed for each column\n        # we treat each column separately, although practically\n        # speaking it would be pretty weird for a dataset to have BIO\n        # in one column and basic in another, for example\n        convert_bio_to_bioes = False\n        convert_basic_to_bioes = False\n        tag_column = [x[column_idx] for sent in all_tags for x in sent]\n        is_bio = is_bio_scheme(tag_column)\n        is_basic = not is_bio and is_basic_scheme(tag_column)\n        if is_bio and scheme.lower() == 'bioes':\n            convert_bio_to_bioes = True\n            logger.debug(\"BIO tagging scheme found in input at column %d; converting into BIOES scheme...\" % column_idx)\n        elif is_basic and scheme.lower() == 'bioes':\n            convert_basic_to_bioes = True\n            logger.debug(\"Basic tagging scheme found in input at column %d; converting into BIOES scheme...\" % column_idx)\n        all_convert_bio_to_bioes.append(convert_bio_to_bioes)\n        all_convert_basic_to_bioes.append(convert_basic_to_bioes)\n\n    result = []\n    for words, tags in zip(all_words, all_tags):\n        # TODO: add a 
convert_basic_to_bio option as well\n        # process tags\n        # tags is a list of each column of tags for each word in this sentence\n        # copy the tags to a list so we can edit them\n        tags = [[x for x in sentence_tags] for sentence_tags in tags]\n        for column_idx, (convert_bio_to_bioes, convert_basic_to_bioes) in enumerate(zip(all_convert_bio_to_bioes, all_convert_basic_to_bioes)):\n            tag_column = [x[column_idx] for x in tags]\n            if convert_basic_to_bioes:\n                # if basic, convert tags -> bio -> bioes\n                tag_column = bio2_to_bioes(basic_to_bio(tag_column))\n            else:\n                # first ensure BIO2 scheme\n                tag_column = to_bio2(tag_column)\n                # then convert to BIOES\n                if convert_bio_to_bioes:\n                    tag_column = bio2_to_bioes(tag_column)\n            for tag_idx, tag in enumerate(tag_column):\n                tags[tag_idx][column_idx] = tag\n        result.append([(w,tuple(t)) for w,t in zip(words, tags)])\n\n    if converted_tuples:\n        result = [[(word[0], word[1][0]) for word in sentence] for sentence in result]\n    return result\n\n\ndef decode_from_bioes(tags):\n    \"\"\"\n    Decode from a sequence of BIOES tags, assuming default tag is 'O'.\n    Args:\n        tags: a list of BIOES tags\n    \n    Returns:\n        A list of dict with start_idx, end_idx, and type values.\n    \"\"\"\n    res = []\n    ent_idxs = []\n    cur_type = None\n\n    def flush():\n        if len(ent_idxs) > 0:\n            res.append({\n                'start': ent_idxs[0], \n                'end': ent_idxs[-1], \n                'type': cur_type})\n\n    for idx, tag in enumerate(tags):\n        if tag is None:\n            tag = 'O'\n        if tag == 'O':\n            flush()\n            ent_idxs = []\n        elif tag.startswith('B-'): # start of new ent\n            flush()\n            ent_idxs = [idx]\n            cur_type = 
tag[2:]\n        elif tag.startswith('I-'): # continue last ent\n            ent_idxs.append(idx)\n            cur_type = tag[2:]\n        elif tag.startswith('E-'): # end last ent\n            ent_idxs.append(idx)\n            cur_type = tag[2:]\n            flush()\n            ent_idxs = []\n        elif tag.startswith('S-'): # start single word ent\n            flush()\n            ent_idxs = [idx]\n            cur_type = tag[2:]\n            flush()\n            ent_idxs = []\n    # flush after whole sentence\n    flush()\n    return res\n\n\ndef merge_tags(*sequences):\n    \"\"\"\n    Merge multiple sequences of NER tags into one sequence\n\n    Only O is replaced, and the earlier tags have precedence\n    \"\"\"\n    tags = list(sequences[0])\n    for sequence in sequences[1:]:\n        idx = 0\n        while idx < len(sequence):\n            # skip empty tags in the later sequences\n            if sequence[idx] == 'O':\n                idx += 1\n                continue\n\n            # check for singletons.  copy if not O in the original\n            if sequence[idx].startswith(\"S-\"):\n                if tags[idx] == 'O':\n                    tags[idx] = sequence[idx]\n                idx += 1\n                continue\n\n            # at this point, we know we have a B-... 
sequence\n            if not sequence[idx].startswith(\"B-\"):\n                raise ValueError(\"Got unexpected tag sequence at idx {}: {}\".format(idx, sequence))\n\n            # take the block of tags which are B- through E-\n            start_idx = idx\n            end_idx = start_idx + 1\n            while end_idx < len(sequence):\n                if sequence[end_idx][2:] != sequence[start_idx][2:]:\n                    raise ValueError(\"Unexpected tag sequence at idx {}: {}\".format(end_idx, sequence))\n                if sequence[end_idx].startswith(\"E-\"):\n                    break\n                if not sequence[end_idx].startswith(\"I-\"):\n                    raise ValueError(\"Unexpected tag sequence at idx {}: {}\".format(end_idx, sequence))\n                end_idx += 1\n            if end_idx == len(sequence):\n                raise ValueError(\"Got a sequence with an unclosed tag: {}\".format(sequence))\n            end_idx = end_idx + 1\n\n            # if all tags in the original are O, we can overwrite\n            # otherwise, keep the originals\n            if all(x == 'O' for x in tags[start_idx:end_idx]):\n                tags[start_idx:end_idx] = sequence[start_idx:end_idx]\n            idx = end_idx\n\n    return tags\n"
  },
  {
    "path": "stanza/models/ner/vocab.py",
    "content": "from collections import Counter, OrderedDict\n\nfrom stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab, CompositeVocab\nfrom stanza.models.common.vocab import VOCAB_PREFIX\nfrom stanza.models.common.pretrain import PretrainedWordVocab\nfrom stanza.models.pos.vocab import WordVocab\n\nclass TagVocab(BaseVocab):\n    \"\"\" A vocab for the output tag sequence. \"\"\"\n    def build_vocab(self):\n        counter = Counter([w[self.idx] for sent in self.data for w in sent])\n\n        self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\ndef convert_tag_vocab(state_dict):\n    if state_dict['lower']:\n        raise AssertionError(\"Did not expect an NER vocab with 'lower' set to True\")\n    items = state_dict['_id2unit'][len(VOCAB_PREFIX):]\n    # this looks silly, but the vocab builder treats this as words with multiple fields\n    # (we set it to look for field 0 with idx=0)\n    # and then the label field is expected to be a list or tuple of items\n    items = [[[[x]]] for x in items]\n    vocab = CompositeVocab(data=items, lang=state_dict['lang'], idx=0, sep=None)\n    if len(vocab._id2unit[0]) != len(state_dict['_id2unit']):\n        raise AssertionError(\"Failed to construct a new vocab of the same length as the original\")\n    if vocab._id2unit[0] != state_dict['_id2unit']:\n        raise AssertionError(\"Failed to construct a new vocab in the same order as the original\")\n    return vocab\n\nclass MultiVocab(BaseMultiVocab):\n    def state_dict(self):\n        \"\"\" Also save a vocab name to class name mapping in state dict. 
\"\"\"\n        state = OrderedDict()\n        key2class = OrderedDict()\n        for k, v in self._vocabs.items():\n            state[k] = v.state_dict()\n            key2class[k] = type(v).__name__\n        state['_key2class'] = key2class\n        return state\n\n    @classmethod\n    def load_state_dict(cls, state_dict):\n        class_dict = {'CharVocab': CharVocab.load_state_dict,\n                      'PretrainedWordVocab': PretrainedWordVocab.load_state_dict,\n                      'TagVocab': convert_tag_vocab,\n                      'CompositeVocab': CompositeVocab.load_state_dict,\n                      'WordVocab': WordVocab.load_state_dict}\n        new = cls()\n        assert '_key2class' in state_dict, \"Cannot find class name mapping in state dict!\"\n        key2class = state_dict['_key2class']\n        for k,v in state_dict.items():\n            if k == '_key2class':\n                continue\n            classname = key2class[k]\n            new[k] = class_dict[classname](v)\n        return new\n\n"
  },
  {
    "path": "stanza/models/ner_tagger.py",
    "content": "\"\"\"\nEntry point for training and evaluating an NER tagger.\n\nThis tagger uses BiLSTM layers with character and word-level representations, and a CRF decoding layer \nto produce NER predictions.\nFor details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.\n\"\"\"\n\nimport sys\nimport os\nimport time\nfrom datetime import datetime\nimport argparse\nimport logging\nimport numpy as np\nimport random\nimport re\nimport json\nimport torch\nfrom torch import nn, optim\n\nfrom stanza.models.ner.data import DataLoader\nfrom stanza.models.ner.trainer import Trainer\nfrom stanza.models.ner import scorer\nfrom stanza.models.common import utils\nfrom stanza.models.common.pretrain import Pretrain\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.doc import *\nfrom stanza.models import _training_logging\n\nfrom stanza.models.common.peft_config import add_peft_args, resolve_peft_args\nfrom stanza.utils.confusion import confusion_to_weighted_f1, format_confusion\n\nlogger = logging.getLogger('stanza')\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_dir', type=str, default='data/ner', help='Directory of NER data.')\n    parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors')\n    parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--eval_output_file', type=str, default=None, help='Where to write results: text, gold, pred.  
If None, no results file printed')\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path')\n    parser.add_argument('--finetune_load_name', type=str, default=None, help='Model to load when finetuning')\n    parser.add_argument('--train_classifier_only', action='store_true',\n                        help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.')\n    parser.add_argument('--shorthand', type=str, help=\"Treebank shorthand\")\n\n    parser.add_argument('--hidden_dim', type=int, default=256)\n    parser.add_argument('--char_hidden_dim', type=int, default=100)\n    parser.add_argument('--word_emb_dim', type=int, default=100)\n    parser.add_argument('--char_emb_dim', type=int, default=100)\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--char_num_layers', type=int, default=1)\n    parser.add_argument('--pretrain_max_vocab', type=int, default=100000)\n    parser.add_argument('--word_dropout', type=float, default=0.01, help=\"How often to remove a word at training time.  
Set to a small value to train unk when finetuning word embeddings\")\n    parser.add_argument('--locked_dropout', type=float, default=0.0)\n    parser.add_argument('--dropout', type=float, default=0.5)\n    parser.add_argument('--rec_dropout', type=float, default=0, help=\"Word recurrent dropout\")\n    parser.add_argument('--char_rec_dropout', type=float, default=0, help=\"Character recurrent dropout\")\n    parser.add_argument('--char_dropout', type=float, default=0, help=\"Character-level language model dropout\")\n    parser.add_argument('--no_char', dest='char', action='store_false', help=\"Turn off training a character model.\")\n    parser.add_argument('--charlm', action='store_true', help=\"Turn on contextualized char embedding using pretrained character-level language model.\")\n    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help=\"Root dir for pretrained character-level language model.\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help=\"Use lowercased characters in character model.\")\n    parser.add_argument('--no_lowercase', dest='lowercase', action='store_false', help=\"Use cased word vectors.\")\n    parser.add_argument('--no_emb_finetune', dest='emb_finetune', action='store_false', help=\"Turn off finetuning of the embedding matrix.\")\n    parser.add_argument('--emb_finetune_known_only', dest='emb_finetune_known_only', action='store_true', help=\"Finetune the embedding matrix only for words in the embedding.  
(Default: finetune words not in the embedding as well)  This may be useful for very large datasets where obscure words are only trained once in a while, such as French-WikiNER\")\n    parser.add_argument('--no_input_transform', dest='input_transform', action='store_false', help=\"Do not use input transformation layer before tagger lstm.\")\n    parser.add_argument('--scheme', type=str, default='bioes', help=\"The tagging scheme to use: bio or bioes.\")\n    parser.add_argument('--train_scheme', type=str, default=None, help=\"The tagging scheme to use when training: bio or bioes.  Overrides --scheme for the training set\")\n\n    parser.add_argument('--bert_model', type=str, default=None, help=\"Use an external bert model (requires the transformers package)\")\n    parser.add_argument('--no_bert_model', dest='bert_model', action=\"store_const\", const=None, help=\"Don't use bert\")\n    parser.add_argument('--bert_hidden_layers', type=int, default=None, help=\"How many layers of hidden state to use from the transformer\")\n    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')\n    parser.add_argument('--gradient_checkpointing', default=False, action='store_true', help='Checkpoint intermediate gradients between layers to save memory at the cost of training steps')\n    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help=\"Don't finetune the bert (or other transformer)\")\n    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')\n    parser.add_argument('--second_optim', type=str, default=None, help='once first optimizer converged, tune the model again. 
with: sgd, adagrad, adam or adamax.')\n    parser.add_argument('--second_bert_learning_rate', default=0, type=float, help='Secondary stage transformer finetuning learning rate scale')\n\n    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help=\"Turn off pretrained embeddings.\")\n\n    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')\n    parser.add_argument('--optim', type=str, default='sgd', help='sgd, adagrad, adam or adamax.')\n    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate.')\n    parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate to stop training.')\n    parser.add_argument('--second_lr', type=float, default=5e-3, help='Secondary learning rate')\n    parser.add_argument('--momentum', type=float, default=0, help='Momentum for SGD.')\n    parser.add_argument('--lr_decay', type=float, default=0.5, help=\"LR decay rate.\")\n    parser.add_argument('--patience', type=int, default=3, help=\"Patience for LR decay.\")\n\n    parser.add_argument('--connect_output_layers', action='store_true', default=False, help='Connect one output layer to the input of the next output layer.  By default, those layers are all separate')\n    parser.add_argument('--predict_tagset', type=int, default=None, help='Which tagset to predict if there are multiple tagsets.  Will default to 0.  
Default of None allows the model to remember the value from training time, but be overridden at test time')\n\n    parser.add_argument('--ignore_tag_scores', type=str, default=None, help=\"Which tags to ignore, if any, when scoring dev & test sets\")\n\n    parser.add_argument('--max_steps', type=int, default=200000)\n    parser.add_argument('--max_steps_no_improve', type=int, default=2500, help='if the model doesn\\'t improve after this many steps, give up or switch to new optimizer.')\n    parser.add_argument('--eval_interval', type=int, default=500)\n    parser.add_argument('--batch_size', type=int, default=32)\n    parser.add_argument('--max_batch_words', type=int, default=800, help='Long sentences can overwhelm even a large GPU when finetuning a transformer on otherwise reasonable batch sizes.  This cuts off those batches early')\n    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')\n    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')\n    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')\n    parser.add_argument('--save_dir', type=str, default='saved_models/ner', help='Root dir for saving models.')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_{finetune}_nertagger.pt\", help=\"File name to save the model\")\n\n    parser.add_argument('--seed', type=int, default=1234)\n    utils.add_device_args(parser)\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  
Will default to the dataset short name')\n    return parser\n\ndef parse_args(args=None):\n    parser = build_argparse()\n    add_peft_args(parser)\n    args = parser.parse_args(args=args)\n    resolve_peft_args(args, logger)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    args = vars(args)\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running NER tagger in {} mode\".format(args['mode']))\n\n    if args['mode'] == 'train':\n        return train(args)\n    else:\n        evaluate(args)\n\ndef load_pretrain(args):\n    # load pretrained vectors\n    if not args['pretrain']:\n        return None\n\n    if args['wordvec_pretrain_file']:\n        pretrain_file = args['wordvec_pretrain_file']\n        pretrain = Pretrain(pretrain_file, None, args['pretrain_max_vocab'], save_to_file=False)\n    else:\n        if len(args['wordvec_file']) == 0:\n            vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])\n        else:\n            vec_file = args['wordvec_file']\n        # do not save pretrained embeddings individually\n        pretrain = Pretrain(None, vec_file, args['pretrain_max_vocab'], save_to_file=False)\n    return pretrain\n\ndef model_file_name(args):\n    return utils.standard_model_file_name(args, \"nertagger\")\n\ndef get_known_tags(tags):\n    \"\"\"\n    Tags are stored in the dataset as a list of list of tags\n\n    This returns a sorted list for each column of tags in the dataset\n    \"\"\"\n    max_columns = max(len(word) for sent in tags for word in sent)\n    known_tags = [set() for _ in range(max_columns)]\n    for sent in tags:\n        for word in sent:\n            for tag_idx, tag in enumerate(word):\n                known_tags[tag_idx].add(tag)\n    return [sorted(x) for x in known_tags]\n\ndef warn_missing_tags(tag_vocab, data_tags, error_msg, bioes_to_bio=False):\n    \"\"\"\n    Check for tags missing from the 
tag_vocab.\n\n    Given a tag_vocab and the known tags in the format used by\n    ner.data, go through the tags in the dataset and look for any\n    which aren't in the tag_vocab.\n\n    error_msg is something like \"training set\" or \"eval file\" to\n    indicate where the missing tags came from.\n    \"\"\"\n    tag_depth = max(max(len(tags) for tags in sentence) for sentence in data_tags)\n\n    if tag_depth != len(tag_vocab.lens()):\n        logger.warning(\"Test dataset has a different number of tag types compared to the model: %d vs %d\", tag_depth, len(tag_vocab.lens()))\n    for tag_set_idx in range(min(tag_depth, len(tag_vocab.lens()))):\n        tag_set = tag_vocab.items(tag_set_idx)\n        if len(tag_vocab.lens()) > 1:\n            current_error_msg = error_msg + \" tag set %d\" % tag_set_idx\n        else:\n            current_error_msg = error_msg\n\n        current_tags = set([word[tag_set_idx] for sentence in data_tags for word in sentence])\n        if bioes_to_bio:\n            current_tags = set([re.sub(\"^E-\", \"I-\", re.sub(\"^S-\", \"B-\", x)) for x in current_tags])\n        utils.warn_missing_tags(tag_set, current_tags, current_error_msg)\n\ndef train(args):\n    model_file = model_file_name(args)\n\n    save_dir, save_name = os.path.split(model_file)\n    utils.ensure_dir(save_dir)\n    if args['save_dir'] is None:\n        args['save_dir'] = save_dir\n    args['save_name'] = save_name\n\n    utils.log_training_args(args, logger)\n\n    pretrain = None\n    vocab = None\n    trainer = None\n\n    if args['finetune'] and args['finetune_load_name']:\n        logger.warning('Finetune is ON. Using model from \"{}\"'.format(args['finetune_load_name']))\n        _, trainer, vocab = load_model(args, args['finetune_load_name'])\n    elif args['finetune'] and os.path.exists(model_file):\n        logger.warning('Finetune is ON. 
Using model from \"{}\"'.format(model_file))\n        _, trainer, vocab = load_model(args, model_file)\n    else:\n        if args['finetune']:\n            raise FileNotFoundError('Finetune is set to true but model file is not found: {}'.format(model_file))\n\n        pretrain = load_pretrain(args)\n\n        if pretrain is not None:\n            word_emb_dim = pretrain.emb.shape[1]\n            if args['word_emb_dim'] and args['word_emb_dim'] != word_emb_dim:\n                logger.warning(\"Embedding file has a dimension of {}.  Model will be built with that size instead of {}\".format(word_emb_dim, args['word_emb_dim']))\n            args['word_emb_dim'] = word_emb_dim\n\n        if args['charlm']:\n            if args['charlm_shorthand'] is None:\n                raise ValueError(\"CharLM Shorthand is required for loading pretrained CharLM model...\")\n            logger.info('Using pretrained contextualized char embedding')\n            if not args['charlm_forward_file']:\n                args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])\n            if not args['charlm_backward_file']:\n                args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])\n\n    # load data\n    logger.info(\"Loading training data with batch size %d from %s\", args['batch_size'], args['train_file'])\n    with open(args['train_file']) as fin:\n        train_doc = Document(json.load(fin))\n    logger.info(\"Loaded %d sentences of training data\", len(train_doc.sentences))\n    if len(train_doc.sentences) == 0:\n        raise ValueError(\"File %s exists but has no usable training data\" % args['train_file'])\n    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])\n    vocab = train_batch.vocab\n    logger.info(\"Loading dev data 
from %s\", args['eval_file'])\n    with open(args['eval_file']) as fin:\n        dev_doc = Document(json.load(fin))\n    logger.info(\"Loaded %d sentences of dev data\", len(dev_doc.sentences))\n    if len(dev_doc.sentences) == 0:\n        raise ValueError(\"File %s exists but has no usable dev data\" % args['eval_file'])\n    dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True)\n\n    train_tags = get_known_tags(train_batch.tags)\n    logger.info(\"Training data has %d columns of tags\", len(train_tags))\n    for tag_idx, tags in enumerate(train_tags):\n        logger.info(\"Tags present in training set at column %d:\\n  Tags without BIES markers: %s\\n  Tags with B-, I-, E-, or S-: %s\",\n                    tag_idx,\n                    \" \".join(sorted(set(i for i in tags if i[:2] not in ('B-', 'I-', 'E-', 'S-')))),\n                    \" \".join(sorted(set(i[2:] for i in tags if i[:2] in ('B-', 'I-', 'E-', 'S-')))))\n\n    # skip training if the language does not have training or dev data\n    if len(train_batch) == 0 or len(dev_batch) == 0:\n        logger.info(\"Skip training because no data available...\")\n        return\n\n    logger.info(\"Training tagger...\")\n    if trainer is None: # init if model was not loaded previously from file\n        trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],\n                          train_classifier_only=args['train_classifier_only'])\n\n    if args['finetune']:\n        warn_missing_tags(trainer.vocab['tag'], train_batch.tags, \"training set\")\n    # the evaluation will coerce the tags to the proper scheme,\n    # so we won't need to alert for not having S- or E- tags\n    bioes_to_bio = args['train_scheme'] == 'bio' and args['scheme'] == 'bioes'\n    warn_missing_tags(trainer.vocab['tag'], dev_batch.tags, \"dev set\", bioes_to_bio=bioes_to_bio)\n\n    # TODO: might still want to add multiple layers of tag evaluation to the scorer\n  
  dev_gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in dev_batch.tags]\n\n    logger.info(trainer.model)\n\n    global_step = 0\n    max_steps = args['max_steps']\n    dev_score_history = []\n    best_dev_preds = []\n    current_lr = trainer.optimizer.param_groups[0]['lr']\n    global_start_time = time.time()\n    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'\n\n    # LR scheduling\n    if args['lr_decay'] > 0:\n        # learning rate changes on plateau -- no improvement on model for patience number of epochs\n        # change is made as a factor of the learning rate decay\n        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, mode='max', factor=args['lr_decay'],\n                                                               patience=args['patience'], min_lr=args['min_lr'])\n    else:\n        scheduler = None\n\n    if args['wandb']:\n        import wandb\n        wandb_name = args['wandb_name'] if args['wandb_name'] else \"%s_ner\" % args['shorthand']\n        wandb.init(name=wandb_name, config=args)\n        wandb.run.define_metric('train_loss', summary='min')\n        wandb.run.define_metric('dev_score', summary='max')\n        # track gradients!\n        wandb.watch(trainer.model, log_freq=4, log=\"gradients\")\n\n    # start training\n    last_best_step = 0\n    train_loss = 0\n    is_second_optim = False\n    while True:\n        should_stop = False\n        for i, batch in enumerate(train_batch):\n            start_time = time.time()\n            global_step += 1\n            loss = trainer.update(batch, eval=False) # update step\n            train_loss += loss\n            if global_step % args['log_step'] == 0:\n                duration = time.time() - start_time\n                logger.info(format_str.format(datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), global_step,\n                                              max_steps, loss, duration, current_lr))\n            if 
global_step % args['eval_interval'] == 0:\n                # eval on dev\n                logger.info(\"Evaluating on dev set...\")\n                dev_preds = []\n                for batch in dev_batch:\n                    preds = trainer.predict(batch)\n                    dev_preds += preds\n                _, _, dev_score, _ = scorer.score_by_entity(dev_preds, dev_gold_tags, ignore_tags=args['ignore_tag_scores'])\n\n                train_loss = train_loss / args['eval_interval'] # avg loss per batch\n                logger.info(\"step {}: train_loss = {:.6f}, dev_score = {:.4f}\".format(global_step, train_loss, dev_score))\n                if args['wandb']:\n                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})\n                train_loss = 0\n\n                # save best model\n                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):\n                    trainer.save(model_file)\n                    last_best_step = global_step\n                    logger.info(\"New best model saved.\")\n                    best_dev_preds = dev_preds\n\n                dev_score_history += [dev_score]\n                logger.info(\"\")\n\n                # lr schedule\n                if scheduler is not None:\n                    scheduler.step(dev_score)\n            \n                if args['log_norms']:\n                    trainer.model.log_norms()\n\n            # check stopping\n            current_lr = trainer.optimizer.param_groups[0]['lr']\n            if (global_step - last_best_step) >= args['max_steps_no_improve'] or global_step >= args['max_steps'] or current_lr <= args['min_lr']:\n                if (global_step - last_best_step) >= args['max_steps_no_improve']:\n                    logger.info(\"{} steps without improvement...\".format((global_step - last_best_step)))\n                if not is_second_optim and args['second_optim'] is not None:\n                    logger.info(\"Switching to second 
optimizer: {}\".format(args['second_optim']))\n                    logger.info('Reloading best model to continue from current local optimum')\n                    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],\n                                      train_classifier_only=args['train_classifier_only'], model_file=model_file, second_optim=True)\n                    is_second_optim = True\n                    last_best_step = global_step\n                    current_lr = trainer.optimizer.param_groups[0]['lr']\n                else:\n                    logger.info(\"stopping...\")\n                    should_stop = True\n                    break\n\n        if should_stop:\n            break\n\n        train_batch.reshuffle()\n\n    logger.info(\"Training ended with {} steps.\".format(global_step))\n\n    if args['wandb']:\n        wandb.finish()\n\n    if len(dev_score_history) > 0:\n        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1\n        logger.info(\"Best dev F1 = {:.2f}, at iteration = {}\".format(best_f, best_eval * args['eval_interval']))\n    else:\n        logger.info(\"Dev set never evaluated.  
Saving final model.\")\n        trainer.save(model_file)\n\n    return trainer\n\ndef write_ner_results(filename, batch, preds, predict_tagset):\n    if len(batch.tags) != len(preds):\n        raise ValueError(\"Unexpected batch vs pred lengths: %d vs %d\" % (len(batch.tags), len(preds)))\n\n    with open(filename, \"w\", encoding=\"utf-8\") as fout:\n        tag_idx = 0\n        for b in batch:\n            # b[0] is words, b[5] is orig_idx\n            # a namedtuple would make this cleaner without being much slower\n            text = utils.unsort(b[0], b[5])\n            for sentence in text:\n                # TODO: if we change the predict_tagset mechanism, will have to change this\n                sentence_gold = [x[predict_tagset] for x in batch.tags[tag_idx]]\n                sentence_pred = preds[tag_idx]\n                tag_idx += 1\n                for word, gold, pred in zip(sentence, sentence_gold, sentence_pred):\n                    fout.write(\"%s\\t%s\\t%s\\n\" % (word, gold, pred))\n                fout.write(\"\\n\")\n\ndef evaluate(args):\n    # file paths\n    model_file = model_file_name(args)\n\n    loaded_args, trainer, vocab = load_model(args, model_file)\n    return evaluate_model(loaded_args, trainer, vocab, args['eval_file'])\n\ndef evaluate_model(loaded_args, trainer, vocab, eval_file):\n    if loaded_args['log_norms']:\n        trainer.model.log_norms()\n\n    model_file = os.path.join(loaded_args['save_dir'], loaded_args['save_name'])\n    logger.debug(\"Loaded model for eval from %s\", model_file)\n    logger.debug(\"Using the %d tagset for evaluation\", loaded_args['predict_tagset'])\n\n    # load data\n    logger.info(\"Loading data with batch size {}...\".format(loaded_args['batch_size']))\n    with open(eval_file) as fin:\n        doc = Document(json.load(fin))\n    batch = DataLoader(doc, loaded_args['batch_size'], loaded_args, vocab=vocab, evaluation=True, bert_tokenizer=trainer.model.bert_tokenizer)\n    bioes_to_bio = 
loaded_args['train_scheme'] == 'bio' and loaded_args['scheme'] == 'bioes'\n    warn_missing_tags(trainer.vocab['tag'], batch.tags, \"eval_file\", bioes_to_bio=bioes_to_bio)\n\n    logger.info(\"Start evaluation...\")\n    preds = []\n    for i, b in enumerate(batch):\n        preds += trainer.predict(b)\n\n    gold_tags = batch.tags\n    # TODO: might still want to add multiple layers of tag evaluation to the scorer\n    gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in gold_tags]\n\n    _, _, score, entity_f1 = scorer.score_by_entity(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])\n    _, _, _, confusion = scorer.score_by_token(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])\n    logger.info(\"Weighted f1 for non-O tokens: %5f\", confusion_to_weighted_f1(confusion, exclude=[\"O\"]))\n\n    logger.info(\"NER tagger score: %s %s %s %.2f\", loaded_args['shorthand'], model_file, eval_file, score*100)\n    entity_f1_lines = [\"%s: %.2f\" % (x, y*100) for x, y in entity_f1.items()]\n    logger.info(\"NER Entity F1 scores:\\n  %s\", \"\\n  \".join(entity_f1_lines))\n    logger.info(\"NER token confusion matrix:\\n{}\".format(format_confusion(confusion)))\n\n    if loaded_args['eval_output_file']:\n        write_ner_results(loaded_args['eval_output_file'], batch, preds, trainer.args['predict_tagset'])\n\n    return confusion\n\ndef load_model(args, model_file):\n    # load model\n    charlm_args = {}\n    if 'charlm_forward_file' in args:\n        charlm_args['charlm_forward_file'] = args['charlm_forward_file']\n    if 'charlm_backward_file' in args:\n        charlm_args['charlm_backward_file'] = args['charlm_backward_file']\n    if args['predict_tagset'] is not None:\n        charlm_args['predict_tagset'] = args['predict_tagset']\n    pretrain = load_pretrain(args)\n    trainer = Trainer(args=charlm_args, model_file=model_file, pretrain=pretrain, device=args['device'], 
train_classifier_only=args['train_classifier_only'])\n    loaded_args, vocab = trainer.args, trainer.vocab\n\n    # load config\n    for k in args:\n        if k.endswith('_dir') or k.endswith('_file') or k in ['batch_size', 'ignore_tag_scores', 'log_norms', 'mode', 'scheme', 'shorthand']:\n            loaded_args[k] = args[k]\n    save_dir, save_name = os.path.split(model_file)\n    loaded_args['save_dir'] = save_dir\n    loaded_args['save_name'] = save_name\n    return loaded_args, trainer, vocab\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/parser.py",
    "content": "\"\"\"\nEntry point for training and evaluating a dependency parser.\n\nThis implementation combines a deep biaffine graph-based parser with linearization and distance features.\nFor details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.\n\"\"\"\n\n\"\"\"\nTraining and evaluation for the parser.\n\"\"\"\n\nimport io\nimport sys\nimport os\nimport copy\nimport shutil\nimport time\nimport argparse\nimport logging\nimport numpy as np\nimport random\nimport zipfile\n\nimport torch\nfrom torch import nn, optim\n\nimport stanza.models.depparse.data as data\nfrom stanza.models.depparse.data import DataLoader\nfrom stanza.models.depparse.trainer import Trainer\nfrom stanza.models.depparse import scorer\nfrom stanza.models.common import utils\nfrom stanza.models.common import pretrain\nfrom stanza.models.common.data import augment_punct\nfrom stanza.models.common.doc import *\nfrom stanza.models.common.peft_config import add_peft_args, resolve_peft_args\nfrom stanza.models.common.utils import log_training_args\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models import _training_logging\n\nlogger = logging.getLogger('stanza')\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_dir', type=str, default='data/depparse', help='Root dir for saving models.')\n    parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors.')\n    parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')\n    
parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help=\"Don't score the eval file - perhaps it has no gold labels, for example.  Cannot be used at training time\")\n    parser.add_argument('--output_latex', default=False, action='store_true', help='Output the per-relation table in Latex form')\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--lang', type=str, help='Language')\n    parser.add_argument('--shorthand', type=str, help=\"Treebank shorthand\")\n\n    parser.add_argument('--hidden_dim', type=int, default=400)\n    parser.add_argument('--char_hidden_dim', type=int, default=400)\n    parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)\n    parser.add_argument('--deep_biaff_output_dim', type=int, default=160)\n    # As an additional option, we implement arc embeddings\n    #  described in https://arxiv.org/pdf/2501.09451\n    #  Scaling Graph-Based Dependency Parsing with Arc Vectorization and Attention-Based Refinement\n    #  Nicolas Floquet, Joseph Le Roux, Nadi Tomeh, Thierry Charnois\n    # Unfortunately, the current implementation and hyperparameters do not seem to help\n    # when combined with a transformer as the input embedding\n    # LAS Scores on a few dev sets, UD 2.17, averaged over 5 seeds\n    # This is with a version where the arc -> unlabeled is one layer, arc -> label is two layers\n    # Using two layers for the arc -> unlabeled hurts scores a bit more\n    #  treebank   w/     w/o\n    #   en_ewt  93.46  93.47\n    #   de_gsd  89.02  89.12\n    #   it_vit  90.15  90.19\n    # However, this is without the transformer over the arcs, which is\n    # an important component of making the arcs more useful\n    parser.add_argument('--use_arc_embedding', action='store_true', default=False, help='Use arc embeddings, as per Scaling Graph-Based Dependency Parsing')\n    parser.add_argument('--no_use_arc_embedding', dest='use_arc_embedding', 
action='store_false', help=\"Don't use arc embeddings\")\n    parser.add_argument('--word_emb_dim', type=int, default=75)\n    parser.add_argument('--word_cutoff', type=int, default=None, help='How common a word must be to include it in the finetuned word embedding.  If not set, small word vector files will be 0, larger will be %d' % utils.DEFAULT_WORD_CUTOFF)\n    parser.add_argument('--char_emb_dim', type=int, default=100)\n    parser.add_argument('--tag_emb_dim', type=int, default=50)\n    parser.add_argument('--no_upos', dest='use_upos', action='store_false', default=True, help=\"Don't use upos tags as part of the tag embedding\")\n    parser.add_argument('--no_xpos', dest='use_xpos', action='store_false', default=True, help=\"Don't use xpos tags as part of the tag embedding\")\n    parser.add_argument('--no_ufeats', dest='use_ufeats', action='store_false', default=True, help=\"Don't use ufeats as part of the tag embedding\")\n    parser.add_argument('--transformed_dim', type=int, default=125)\n    parser.add_argument('--num_layers', type=int, default=3)\n    parser.add_argument('--char_num_layers', type=int, default=1)\n    parser.add_argument('--checkpoint_save_name', type=str, default=None, help=\"File name to save the most recent checkpoint\")\n    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help=\"Don't save checkpoints\")\n    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)\n    parser.add_argument('--word_dropout', type=float, default=0.33)\n    parser.add_argument('--dropout', type=float, default=0.5)\n    parser.add_argument('--rec_dropout', type=float, default=0, help=\"Recurrent dropout\")\n    parser.add_argument('--char_rec_dropout', type=float, default=0, help=\"Recurrent dropout\")\n\n    parser.add_argument('--no_char', dest='char', action='store_false', help=\"Turn off character model.\")\n    parser.add_argument('--charlm', action='store_true', help=\"Turn on contextualized char 
embedding using pretrained character-level language model.\")\n    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help=\"Root dir for pretrained character-level language model.\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n\n    parser.add_argument('--bert_model', type=str, default=None, help=\"Use an external bert model (requires the transformers package)\")\n    parser.add_argument('--no_bert_model', dest='bert_model', action=\"store_const\", const=None, help=\"Don't use bert\")\n    parser.add_argument('--bert_hidden_layers', type=int, default=4, help=\"How many layers of hidden state to use from the transformer\")\n    parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')\n    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')\n    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help=\"Don't finetune the bert (or other transformer)\")\n    parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer')\n    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')\n    parser.add_argument('--second_bert_learning_rate', default=1e-3, type=float, help='Secondary stage transformer finetuning learning rate scale')\n    parser.add_argument('--bert_start_finetuning', default=200, type=int, help='When to start finetuning the 
transformer')\n    parser.add_argument('--bert_warmup_steps', default=200, type=int, help='How many steps for a linear warmup when finetuning the transformer')\n    parser.add_argument('--bert_weight_decay', default=0.0, type=float, help='Weight decay bert parameters by this much')\n\n    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help=\"Turn off pretrained embeddings.\")\n    parser.add_argument('--no_linearization', dest='linearization', action='store_false', help=\"Turn off linearization term.\")\n    parser.add_argument('--no_distance', dest='distance', action='store_false', help=\"Turn off distance term.\")\n\n    # Originally, we used a single adam optimizer, stopping after 1000 stalled iterations,\n    # with a couple other hyperparameters corresponding to:  TODO\n    #   --max_steps_before_stop 1000\n    #   --beta2 0.95\n    #   --lr 3e-3\n    #   --weight_decay 0.0\n    #   --optim adam\n    #   --no_second_optim\n    # Later experiments found the current defaults helped the results\n    # on several different datasets (using a transformer as the input embedding)\n    # These experiments are averaged across 5 models,\n    # with multiple early stopping values as well\n    #   5 model dev avg LAS  1 stage  1 stage 2k  1 stage 4k   2 stage\n    # de_gsd                  89.03    89.50       89.71        89.83\n    # en_ewt                  93.47    93.69       93.74        93.89\n    # fi_tdt                  92.16    92.56       92.69        93.15\n    # it_vit                  90.12    90.37       90.44        90.60\n    # ta_ttb                  71.26    71.39       71.45        72.19\n    # zh-hans_gsdsimp         85.47    85.69       85.76        85.89\n    #\n    #   5 model test avg LAS 1 stage  1 stage 2k  1 stage 4k   2 stage\n    # de_gsd                  86.60    86.96       87.04        87.09\n    # en_ewt                  93.37    93.51       93.55        93.72\n    # fi_tdt                  92.56    92.92      
 93.10        93.47\n    # it_vit                  90.51    90.74       90.75        90.88\n    # ta_ttb                  68.22    68.27       68.42        69.06\n    # zh-hans_gsdsimp         85.66    85.92       86.04        86.34\n    #\n    # In addition to these experiments, we ran multiple alternate optimizer combinations, none of which\n    # were a clear improvement over AdaDelta+Adam\n    #\n    # rmsprop  --weight_decay 1e-5 --lr 0.0001\n    # adamw    --second_lr 0.0001\n    # madgrad  --second_lr 0.00008\n    #   5 model dev avg LAS   ada+adam   rms+adam    ada+adamw  ada+madgrad\n    # de_gsd                 89.83      89.80       89.67      89.55\n    # en_ewt                 93.89      93.97       93.92      93.90\n    # fi_tdt                 93.15      92.95       93.03      93.08\n    # it_vit                 90.60      90.64       90.58      90.54\n    # ta_ttb                 72.19      71.86       72.18      72.24\n    # zh-hans_gsdsimp        85.89      85.60       85.97      85.92\n    #\n    #   5 model test avg LAS    ada+adam  rms+adam   ada+adamw  ada+madgrad\n    # de_gsd                   87.09     87.26      87.06      87.08\n    # en_ewt                   93.72     93.73      93.75      93.73\n    # fi_tdt                   93.47     93.30      93.43      93.44\n    # it_vit                   90.88     90.95      90.90      90.85\n    # ta_ttb                   69.06     68.45      69.05      69.26\n    # zh-hans_gsdsimp          86.34     85.86      86.27      86.23\n\n    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')\n    parser.add_argument('--optim', type=str, default='adadelta', help='sgd, adagrad, adam or adamax.')\n    parser.add_argument('--second_optim', type=str, default=\"adam\", help='sgd, adagrad, adam or adamax.')\n    parser.add_argument('--no_second_optim', dest='second_optim', action='store_const', const=None, help=\"Don't use the second optimizer\")\n    
parser.add_argument('--lr', type=float, default=2.0, help='Learning rate')\n    parser.add_argument('--second_lr', type=float, default=0.0002, help='Secondary stage learning rate')\n    parser.add_argument('--weight_decay', type=float, default=0.00001, help='Weight decay for the first optimizer')\n    parser.add_argument('--beta2', type=float, default=0.999)\n    parser.add_argument('--second_optim_start_step', type=int, default=10000, help='If set, switch to the second optimizer when stalled or at this step regardless of performance.  Normally, the optimizer only switches when the dev scores have stalled for --max_steps_before_stop steps')\n    parser.add_argument('--second_warmup_steps', type=int, default=200, help=\"If set, give the 2nd optimizer a linear warmup.  Idea being that the optimizer won't have a good grasp on the initial gradients and square gradients when it first starts\")\n\n    parser.add_argument('--max_steps', type=int, default=50000)\n    parser.add_argument('--eval_interval', type=int, default=100)\n    parser.add_argument('--checkpoint_interval', type=int, default=500)\n    parser.add_argument('--max_steps_before_stop', type=int, default=2000)\n    parser.add_argument('--batch_size', type=int, default=5000)\n    parser.add_argument('--second_batch_size', type=int, default=None, help='Use a different batch size for the second optimizer.  
Can be relevant for models with different transformer finetuning settings between optimizers, for example, where the larger batch size is impossible when finetuning the transformer')\n    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')\n    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')\n    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')\n    parser.add_argument('--save_dir', type=str, default='saved_models/depparse', help='Root dir for saving models.')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_parser.pt\", help=\"File name to save the model\")\n    parser.add_argument('--continue_from', type=str, default=None, help=\"File name to preload the model to continue training from\")\n\n    parser.add_argument('--seed', type=int, default=1234)\n    add_peft_args(parser)\n    utils.add_device_args(parser)\n\n    parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct.  Default of None will aim for roughly 10%%')\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  
Will default to the dataset short name')\n\n    parser.add_argument('--train_size', type=int, default=None, help='If specified, randomly select this many sentences from the training data')\n    return parser\n\ndef parse_args(args=None):\n    parser = build_argparse()\n    args = parser.parse_args(args=args)\n    resolve_peft_args(args, logger)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    args = vars(args)\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running parser in {} mode\".format(args['mode']))\n\n    if args['mode'] == 'train':\n        return train(args)\n    else:\n        return evaluate(args)\n\ndef model_file_name(args):\n    return utils.standard_model_file_name(args, \"parser\")\n\n# TODO: refactor with everywhere\ndef load_pretrain(args):\n    pt = None\n    if args['pretrain']:\n        pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])\n        if os.path.exists(pretrain_file):\n            vec_file = None\n        else:\n            vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])\n        pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])\n    return pt\n\ndef predict_dataset(trainer, dev_batch):\n    dev_preds = []\n    if len(dev_batch) > 0:\n        for batch in dev_batch:\n            preds = trainer.predict(batch)\n            dev_preds += preds\n        dev_preds = utils.unsort(dev_preds, dev_batch.data_orig_idx)\n    return dev_preds\n\ndef train(args):\n    model_file = model_file_name(args)\n    utils.ensure_dir(os.path.split(model_file)[0])\n\n    # load pretrained vectors if needed\n    pretrain = load_pretrain(args)\n    args['word_cutoff'] = utils.update_word_cutoff(pretrain, args['word_cutoff'])\n\n    # TODO: refactor.  
the exact same thing is done in the tagger\n    if args['charlm']:\n        if args['charlm_shorthand'] is None:\n            raise ValueError(\"CharLM Shorthand is required for loading pretrained CharLM model...\")\n        logger.info('Using pretrained contextualized char embedding')\n        if not args['charlm_forward_file']:\n            args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])\n        if not args['charlm_backward_file']:\n            args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])\n\n    utils.log_training_args(args, logger)\n\n    # load data\n    logger.info(\"Loading data with batch size {}...\".format(args['batch_size']))\n    train_file = args['train_file']\n    if zipfile.is_zipfile(train_file):\n        logger.info(\"Decompressing %s\" % train_file)\n        train_data = []\n        with zipfile.ZipFile(train_file) as zin:\n            for zipped_train_file in zin.namelist():\n                with zin.open(zipped_train_file) as fin:\n                    logger.info(\"Reading %s from %s\" % (zipped_train_file, train_file))\n                    train_str = fin.read()\n                    train_str = train_str.decode(\"utf-8\")\n                    train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str)\n                    logger.info(\"Train File {} from {}, Data Size: {}\".format(zipped_train_file, train_file, len(train_file_data)))\n                    train_data.extend(train_file_data)\n    else:\n        train_data, _, _ = CoNLL.conll2dict(input_file=args['train_file'])\n        logger.info(\"Train File {}, Data Size: {}\".format(train_file, len(train_data)))\n    # possibly augment the training data with some amount of fake data\n    # based on the options chosen\n    logger.info(\"Original data size: {}\".format(len(train_data)))\n    if args['train_size']:\n        if len(train_data) > 
args['train_size']:\n            random.shuffle(train_data)\n            train_data = train_data[:args['train_size']]\n            logger.info(\"Limiting training data to %d entries\", len(train_data))\n        else:\n            logger.info(\"Train data less than %d already, not limiting train data\", args['train_size'])\n    # build the training data once, before augmentation, so that random variation\n    # (which might be different based on the random seed)\n    # doesn't have an effect on the vocab being cut off at the word limit\n    # otherwise different models will have different vocabs\n    # based on how often the words were duplicated in the augmentation\n    # TODO: put the augmentation into the dataloader,\n    # such as is done with the POS or the tokenizer\n    train_doc = Document(train_data)\n    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, evaluation=False)\n    vocab = train_batch.vocab\n    train_data.extend(augment_punct(train_data, args['augment_nopunct'],\n                                    keep_original_sentences=False))\n    logger.info(\"Augmented data size: {}\".format(len(train_data)))\n    train_doc = Document(train_data)\n    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False)\n    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)\n\n    # skip training if the language does not have training or dev data\n    if len(train_batch) == 0 or len(dev_batch) == 0:\n        logger.info(\"Skip training because no data available...\")\n        sys.exit(0)\n\n    if args['wandb']:\n        import wandb\n        wandb_name = args['wandb_name'] if args['wandb_name'] else \"%s_depparse\" % args['shorthand']\n        wandb.init(name=wandb_name, config=args)\n        wandb.run.define_metric('train_loss', summary='min')\n        
wandb.run.define_metric('dev_score', summary='max')\n\n    logger.info(\"Training parser...\")\n    checkpoint_file = None\n    if args.get(\"checkpoint\"):\n        # calculate checkpoint file name from the save filename\n        checkpoint_file = utils.checkpoint_name(args.get(\"save_dir\"), model_file, args.get(\"checkpoint_save_name\"))\n        args[\"checkpoint_save_name\"] = checkpoint_file\n\n    if args.get(\"checkpoint\") and os.path.exists(args[\"checkpoint_save_name\"]):\n        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args[\"checkpoint_save_name\"], device=args['device'], ignore_model_config=True)\n        if len(trainer.dev_score_history) > 0:\n            logger.info(\"Continuing from checkpoint %s  Model was previously trained for %d steps, with a best dev score of %.4f\", args[\"checkpoint_save_name\"], trainer.global_step, max(trainer.dev_score_history))\n    elif args[\"continue_from\"]:\n        if not os.path.exists(args[\"continue_from\"]):\n            raise FileNotFoundError(\"--continue_from specified, but the file %s does not exist\" % args[\"continue_from\"])\n        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args[\"continue_from\"], device=args['device'], ignore_model_config=True, reset_history=True)\n    else:\n        trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])\n\n    max_steps = args['max_steps']\n    current_lr = args['lr']\n    global_start_time = time.time()\n    format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'\n\n    is_second_stage = False\n    # start training\n    train_loss = 0\n    if args['log_norms']:\n        trainer.model.log_norms()\n    while True:\n        do_break = False\n        for i, batch in enumerate(train_batch):\n            start_time = time.time()\n            trainer.global_step += 1\n            loss = trainer.update(batch, eval=False) # update step\n            train_loss 
+= loss\n\n            # will checkpoint if we switch optimizers or score a new best score\n            force_checkpoint = False\n            if trainer.global_step % args['log_step'] == 0:\n                duration = time.time() - start_time\n                logger.info(format_str.format(trainer.global_step, max_steps, loss, duration, current_lr))\n\n            if trainer.global_step % args['eval_interval'] == 0:\n                # eval on dev\n                logger.info(\"Evaluating on dev set...\")\n                dev_preds = predict_dataset(trainer, dev_batch)\n\n                dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])\n\n                system_pred_file = \"{:C}\\n\\n\".format(dev_batch.doc)\n                system_pred_file = io.StringIO(system_pred_file)\n                _, _, dev_score = scorer.score(system_pred_file, args['eval_file'])\n\n                train_loss = train_loss / args['eval_interval'] # avg loss per batch\n                logger.info(\"step {}: train_loss = {:.6f}, dev_score = {:.4f}\".format(trainer.global_step, train_loss, dev_score))\n\n                if args['wandb']:\n                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})\n\n                train_loss = 0\n\n                # save best model\n                trainer.dev_score_history += [dev_score]\n                if dev_score >= max(trainer.dev_score_history):\n                    trainer.last_best_step = trainer.global_step\n                    trainer.save(model_file)\n                    logger.info(\"new best model saved.\")\n                    force_checkpoint = True\n\n                for scheduler_name, scheduler in trainer.scheduler.items():\n                    logger.info('scheduler %s learning rate: %s', scheduler_name, scheduler.get_last_lr())\n                if args['log_norms']:\n                    trainer.model.log_norms()\n\n            if not is_second_stage and args.get('second_optim', None) is not 
None:\n                if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and trainer.global_step >= args['second_optim_start_step']):\n                    logger.info(\"Switching to second optimizer: {}\".format(args.get('second_optim', None)))\n                    global_step = trainer.global_step\n                    args[\"second_stage\"] = True\n                    # if the loader gets a model file, it uses secondary optimizer\n                    # (because of the second_stage = True argument)\n                    trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,\n                                      model_file=model_file, device=args['device'])\n                    logger.info('Reloading best model to continue from current local optimum')\n\n                    dev_preds = predict_dataset(trainer, dev_batch)\n                    dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])\n                    system_pred_file = \"{:C}\\n\\n\".format(dev_batch.doc)\n                    system_pred_file = io.StringIO(system_pred_file)\n                    _, _, dev_score = scorer.score(system_pred_file, args['eval_file'])\n                    logger.info(\"Reloaded model with dev score %.4f\", dev_score)\n\n                    is_second_stage = True\n                    trainer.global_step = global_step\n                    trainer.last_best_step = global_step\n                    if args['second_batch_size'] is not None:\n                        train_batch.set_batch_size(args['second_batch_size'])\n                    force_checkpoint = True\n            else:\n                if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop']:\n                    do_break = True\n                    break\n\n            if trainer.global_step % args['eval_interval'] == 0 or force_checkpoint:\n                # if we need to save 
checkpoint, do so\n                # (save after switching the optimizer, if applicable, so that\n                # the new optimizer is the optimizer used if a restart happens)\n                if checkpoint_file is not None:\n                    trainer.save(checkpoint_file, save_optimizer=True)\n                    logger.info(\"new model checkpoint saved.\")\n\n            if trainer.global_step >= args['max_steps']:\n                do_break = True\n                break\n\n        if do_break: break\n\n        train_batch.reshuffle()\n\n    logger.info(\"Training ended with {} steps.\".format(trainer.global_step))\n\n    if args['wandb']:\n        wandb.finish()\n\n    if len(trainer.dev_score_history) > 0:\n        # TODO: technically the iteration position will be wrong if\n        # the eval_interval changed when running from a checkpoint\n        # could fix this by saving step & score instead of just score\n        best_f, best_eval = max(trainer.dev_score_history)*100, np.argmax(trainer.dev_score_history)+1\n        logger.info(\"Best dev F1 = {:.2f}, at iteration = {}\".format(best_f, best_eval * args['eval_interval']))\n    else:\n        logger.info(\"Dev set never evaluated.  
Saving final model.\")\n        trainer.save(model_file)\n\n    return trainer, _\n\ndef evaluate(args):\n    model_file = model_file_name(args)\n    # load pretrained vectors if needed\n    pretrain = load_pretrain(args)\n\n    load_args = {'charlm_forward_file': args.get('charlm_forward_file', None),\n                 'charlm_backward_file': args.get('charlm_backward_file', None)}\n\n    # load model\n    logger.info(\"Loading model from: {}\".format(model_file))\n    trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)\n    if args['log_norms']:\n        trainer.model.log_norms()\n    return trainer, evaluate_trainer(args, trainer, pretrain)\n\ndef evaluate_trainer(args, trainer, pretrain):\n    system_pred_file = args['output_file']\n    loaded_args, vocab = trainer.args, trainer.vocab\n\n    # load config\n    for k in args:\n        if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand'] or k == 'mode':\n            loaded_args[k] = args[k]\n\n    # load data\n    logger.info(\"Loading data with batch size {}...\".format(args['batch_size']))\n    doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    batch = DataLoader(doc, args['batch_size'], loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)\n\n    preds = predict_dataset(trainer, batch)\n\n    # write to file and score\n    batch.doc.set([HEAD, DEPREL], [y for x in preds for y in x])\n    if system_pred_file:\n        CoNLL.write_doc2conll(batch.doc, system_pred_file)\n\n    if args['gold_labels']:\n        gold_doc = CoNLL.conll2doc(input_file=args['eval_file'])\n\n        # Check for None ... 
otherwise an inscrutable error occurs later in the scorer\n        for sent_idx, sentence in enumerate(gold_doc.sentences):\n            for word_idx, word in enumerate(sentence.words):\n                if word.deprel is None:\n                    raise ValueError(\"Gold document {} has a None at sentence {} word {}\\n{:C}\".format(args['eval_file'], sent_idx, word_idx, sentence))\n\n        scorer.score_named_dependencies(batch.doc, gold_doc, args['output_latex'])\n        system_pred_file = \"{:C}\\n\\n\".format(batch.doc)\n        system_pred_file = io.StringIO(system_pred_file)            \n        _, _, score = scorer.score(system_pred_file, args['eval_file'])\n\n        logger.info(\"Parser score on %s file %s: %.2f\", args['shorthand'], args['eval_file'], score*100)\n\n    return batch.doc\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/pos/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/pos/build_xpos_vocab_factory.py",
    "content": "import argparse\nfrom collections import defaultdict\nimport logging\nimport os\nimport re\nimport sys\nfrom zipfile import ZipFile\n\nfrom stanza.models.common.constant import treebank_to_short_name\nfrom stanza.models.pos.xpos_vocab_utils import DEFAULT_KEY, choose_simplest_factory, XPOSType\nfrom stanza.models.common.doc import *\nfrom stanza.utils.conll import CoNLL\nfrom stanza.utils import default_paths\n\nSHORTNAME_RE = re.compile(\"[a-z-]+_[a-z0-9]+\")\nDATA_DIR = default_paths.get_default_paths()['POS_DATA_DIR']\n\nlogger = logging.getLogger('stanza')\n\ndef get_xpos_factory(shorthand, fn):\n    logger.info('Resolving vocab option for {}...'.format(shorthand))\n    doc = None\n    train_file = os.path.join(DATA_DIR, '{}.train.in.conllu'.format(shorthand))\n    if os.path.exists(train_file):\n        doc = CoNLL.conll2doc(input_file=train_file)\n    else:\n        zip_file = os.path.join(DATA_DIR, '{}.train.in.zip'.format(shorthand))\n        if os.path.exists(zip_file):\n            with ZipFile(zip_file) as zin:\n                for train_file in zin.namelist():\n                    doc = CoNLL.conll2doc(input_file=train_file, zip_file=zip_file)\n                    if any(word.xpos for sentence in doc.sentences for word in sentence.words):\n                        break\n                else:\n                    raise ValueError('Found training data in {}, but none of the files contained had xpos'.format(zip_file))\n\n    if doc is None:\n        raise FileNotFoundError('Training data for {} not found.  
To generate the XPOS vocabulary '\n                                'for this treebank properly, please run the following command first:\\n'\n                                '  python3 stanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn))\n        # without the training file, there's not much we can do\n        key = DEFAULT_KEY\n        return key\n\n    data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)\n    return choose_simplest_factory(data, shorthand)\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--treebanks', type=str, default=DATA_DIR, help=\"Treebanks to process - directory with processed datasets or a file with a list\")\n    parser.add_argument('--output_file', type=str, default=\"stanza/models/pos/xpos_vocab_factory.py\", help=\"Where to write the results\")\n    args = parser.parse_args()\n\n    output_file = args.output_file\n    if os.path.isdir(args.treebanks):\n        # if the path is a directory of datasets (which is the default if --treebanks is not set)\n        # we use those datasets to prepare the xpos factories\n        treebanks = os.listdir(args.treebanks)\n        treebanks = [x.split(\".\", maxsplit=1)[0] for x in treebanks]\n        treebanks = sorted(set(treebanks))\n    elif os.path.exists(args.treebanks):\n        # maybe it's a file with a list of names\n        with open(args.treebanks) as fin:\n            treebanks = sorted(set([x.strip() for x in fin.readlines() if x.strip()]))\n    else:\n        raise ValueError(\"Cannot figure out which treebanks to use.   
Please set the --treebanks parameter\")\n\n    logger.info(\"Processing the following treebanks: %s\" % \" \".join(treebanks))\n\n    shorthands = []\n    fullnames = []\n    for treebank in treebanks:\n        fullnames.append(treebank)\n        if SHORTNAME_RE.match(treebank):\n            shorthands.append(treebank)\n        else:\n            shorthands.append(treebank_to_short_name(treebank))\n\n    # For each treebank, we would like to find the XPOS Vocab configuration that minimizes\n    # the number of total classes needed to predict by all tagger classifiers. This is\n    # achieved by enumerating different options of separators that different treebanks might\n    # use, and comparing that to treating the XPOS tags as separate categories (using a\n    # WordVocab).\n    mapping = defaultdict(list)\n    for sh, fn in zip(shorthands, fullnames):\n        factory = get_xpos_factory(sh, fn)\n        mapping[factory].append(sh)\n        if sh == 'zh-hans_gsdsimp':\n            mapping[factory].append('zh_gsdsimp')\n        elif sh == 'no_bokmaal':\n            mapping[factory].append('nb_bokmaal')\n\n    mapping[DEFAULT_KEY].append('en_test')\n\n    # Generate code. 
This takes the XPOS vocabulary classes selected above, and generates the\n    # actual factory class as seen in models.pos.xpos_vocab_factory.\n    first = True\n    with open(output_file, 'w') as f:\n        max_len = max(max(len(x) for x in mapping[key]) for key in mapping)\n        print('''# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.\n# Please don't edit it!\n\nimport logging\n\nfrom stanza.models.pos.vocab import WordVocab, XPOSVocab\nfrom stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory\n\n# using a sublogger makes it easier to test in the unittests\nlogger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')\n\nXPOS_DESCRIPTIONS = {''', file=f)\n\n        for key_idx, key in enumerate(mapping):\n            if key_idx > 0:\n                print(file=f)\n            for shorthand in sorted(mapping[key]):\n                # +2 to max_len for the ''\n                # this format string is left justified (either would be okay, probably)\n                if key.sep is None:\n                    sep = 'None'\n                else:\n                    sep = \"'%s'\" % key.sep\n                print((\"    {:%ds}: XPOSDescription({}, {}),\" % (max_len+2)).format(\"'%s'\" % shorthand, key.xpos_type, sep), file=f)\n\n        print('''}\n\ndef xpos_vocab_factory(data, shorthand):\n    if shorthand not in XPOS_DESCRIPTIONS:\n        logger.warning(\"%s is not a known dataset.  Examining the data to choose which xpos vocab to use\", shorthand)\n    desc = choose_simplest_factory(data, shorthand)\n    if shorthand in XPOS_DESCRIPTIONS:\n        if XPOS_DESCRIPTIONS[shorthand] != desc:\n            # log instead of throw\n            # otherwise, updating datasets would be unpleasant\n            logger.error(\"XPOS tagset in %s has apparently changed!  
Was %s, is now %s\", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)\n    else:\n        logger.warning(\"Chose %s for the xpos factory for %s\", desc, shorthand)\n    return build_xpos_vocab(desc, data, shorthand)\n''', file=f)\n\n    logger.info('Done!')\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/models/pos/data.py",
    "content": "import random\nimport logging\nimport copy\nimport torch\nfrom collections import namedtuple\n\nfrom torch.utils.data import DataLoader as DL\nfrom torch.utils.data.sampler import Sampler\nfrom torch.nn.utils.rnn import pad_sequence\n\nfrom stanza.models.common.bert_embedding import filter_data, needs_length_filter\nfrom stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all\nfrom stanza.models.common.utils import DEFAULT_WORD_CUTOFF, simplify_punct\nfrom stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, CharVocab\nfrom stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab\nfrom stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory\nfrom stanza.models.common.doc import *\n\nlogger = logging.getLogger('stanza')\n\nDataSample = namedtuple(\"DataSample\", \"word char upos xpos feats pretrain text\")\nDataBatch = namedtuple(\"DataBatch\", \"words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx\")\n\nclass Dataset:\n    def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):\n        self.args = args\n        self.eval = evaluation\n        self.shuffled = not self.eval\n        self.sort_during_eval = sort_during_eval\n        self.doc = doc\n\n        if vocab is None:\n            self.vocab = Dataset.init_vocab([doc], args)\n        else:\n            self.vocab = vocab\n\n        self.has_upos = not all(x is None or x == '_' for x in doc.get(UPOS, as_sentences=False))\n        self.has_xpos = not all(x is None or x == '_' for x in doc.get(XPOS, as_sentences=False))\n        self.has_feats = not all(x is None or x == '_' for x in doc.get(FEATS, as_sentences=False))\n\n        data = self.load_doc(self.doc)\n        # filter out the long sentences if bert is used\n        if self.args.get('bert_model', None) and 
needs_length_filter(self.args['bert_model']):\n            data = filter_data(self.args['bert_model'], data, bert_tokenizer)\n\n        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None\n        self.pretrain_vocab = None\n        if pretrain is not None and args['pretrain']:\n            self.pretrain_vocab = pretrain.vocab\n\n        # filter and sample data\n        if args.get('sample_train', 1.0) < 1.0 and not self.eval:\n            keep = int(args['sample_train'] * len(data))\n            data = random.sample(data, keep)\n            logger.debug(\"Subsample training set with rate {:g}\".format(args['sample_train']))\n\n        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)\n\n        self.data = data\n\n        self.num_examples = len(data)\n        self.__punct_tags = self.vocab[\"upos\"].map([\"PUNCT\"])\n        self.augment_nopunct = self.args.get(\"augment_nopunct\", 0.0)\n\n    @staticmethod\n    def init_vocab(docs, args):\n        cutoff = args['word_cutoff'] if args.get('word_cutoff') is not None else DEFAULT_WORD_CUTOFF\n        data = [x for doc in docs for x in Dataset.load_doc(doc)]\n        charvocab = CharVocab(data, args['shorthand'])\n        wordvocab = WordVocab(data, args['shorthand'], cutoff=cutoff, lower=True)\n        uposvocab = WordVocab(data, args['shorthand'], idx=1)\n        xposvocab = xpos_vocab_factory(data, args['shorthand'])\n        try:\n            featsvocab = FeatureVocab(data, args['shorthand'], idx=3)\n        except ValueError as e:\n            raise ValueError(\"Unable to build features vocab.  
Please check the Features column of your data for an error which may match the following description.\") from e\n        vocab = MultiVocab({'char': charvocab,\n                            'word': wordvocab,\n                            'upos': uposvocab,\n                            'xpos': xposvocab,\n                            'feats': featsvocab})\n        return vocab\n\n    def preprocess(self, data, vocab, pretrain_vocab, args):\n        processed = []\n        for sent in data:\n            processed_sent = DataSample(\n                word = [vocab['word'].map([w[0] for w in sent])],\n                char = [[vocab['char'].map([x for x in w[0]]) for w in sent]],\n                upos = [vocab['upos'].map([w[1] for w in sent])],\n                xpos = [vocab['xpos'].map([w[2] for w in sent])],\n                feats = [vocab['feats'].map([w[3] for w in sent])],\n                pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]\n                            if pretrain_vocab is not None\n                           else [[PAD_ID] * len(sent)]),\n                text = [w[0] for w in sent]\n            )\n            processed.append(processed_sent)\n\n        return processed\n\n    def __len__(self):\n        return len(self.data)\n\n    def __mask(self, upos):\n        \"\"\"Returns a torch boolean about which elements should be masked out\"\"\"\n\n        # creates all false mask\n        mask = torch.zeros_like(upos, dtype=torch.bool)\n\n        ### augmentation 1: punctuation augmentation ###\n        # tags that needs to be checked, currently only PUNCT\n        if random.uniform(0,1) < self.augment_nopunct:\n            for i in self.__punct_tags:\n                # generate a mask for the last element\n                last_element = torch.zeros_like(upos, dtype=torch.bool)\n                last_element[..., -1] = True\n                # we or the bitmask against the existing mask\n                # if it satisfies, we remove the word by 
masking it\n                # to true\n                #\n                # if your input is just a lone punctuation, we perform\n                # no masking\n                if not torch.all(upos.eq(torch.tensor([[i]]))):\n                    mask |= ((upos == i) & (last_element))\n\n        return mask\n\n    def __getitem__(self, key):\n        \"\"\"Retrieves a sample from the dataset.\n\n        Retrieves a sample from the dataset. This function, for the\n        most part, is spent performing ad-hoc data augmentation and\n        restoration. It receives a DataSample object from the storage,\n        and returns an almost-identical DataSample object that may\n        have been augmented with /possibly/ (depending on augment_nopunct\n        settings) PUNCT chopped.\n\n        **Important Note**\n        ------------------\n        If you would like to load the data into a model, please convert\n        this Dataset object into a DataLoader via self.to_loader(). Then,\n        you can use the resulting object like any other PyTorch data\n        loader. As masks are calculated ad-hoc given the batch, the samples\n        returned from this object don't have the appropriate masking.\n\n        Motivation\n        ----------\n        Why is this here? Every time you call next(iter(dataloader)), it calls\n        this function. 
Therefore, if we augment each sample on each iteration,\n        the model will see dynamically generated augmentation.\n        Furthermore, PyTorch dataloader handles shuffling natively.\n\n        Parameters\n        ----------\n        key : int\n            the integer index of the sample to retrieve.\n\n        Returns\n        -------\n        tuple of (DataSample, int)\n            The sample of data you requested, with augmentation,\n            and the key it was retrieved with.\n        \"\"\"\n        # get a sample of the input data\n        sample = self.data[key]\n\n        # some data augmentation requires constructing a mask based on upos.\n        # For instance, sometimes we'd like to mask out ending sentence punctuation.\n        # We copy the other items here so that any edits made because\n        # of the mask don't clobber the version owned by the Dataset\n        # convert to tensors\n        # TODO: only store single lists per data entry?\n        words = torch.tensor(sample.word[0])\n        # convert the rest to tensors\n        upos = torch.tensor(sample.upos[0]) if self.has_upos else None\n        xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None\n        ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None\n        pretrained = torch.tensor(sample.pretrain[0])\n\n        # and deal with char & raw_text\n        char = sample.char[0]\n        raw_text = sample.text\n\n        # some data augmentation requires constructing a mask based on\n        # upos. For instance, sometimes we'd like to mask out ending\n        # sentence punctuation. 
The mask is True if we want to remove the element\n        if self.has_upos and upos is not None and not self.eval:\n            # perform actual masking\n            mask = self.__mask(upos)\n        else:\n            # no mask: nothing will be removed\n            mask = None\n        if mask is not None:\n            mask_index = mask.nonzero()\n\n            # mask out the elements that we need to mask out\n            for mask in mask_index:\n                mask = mask.item()\n                words[mask] = PAD_ID\n                if upos is not None:\n                    upos[mask] = PAD_ID\n                if xpos is not None:\n                    # TODO: test the multi-dimension xpos\n                    xpos[mask, ...] = PAD_ID\n                if ufeats is not None:\n                    ufeats[mask, ...] = PAD_ID\n                pretrained[mask] = PAD_ID\n                char = char[:mask] + char[mask+1:]\n                raw_text = raw_text[:mask] + raw_text[mask+1:]\n\n        # get each character from the input sentence\n        # chars = [w for sent in char for w in sent]\n\n        return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key\n\n    def __iter__(self):\n        for i in range(self.__len__()):\n            yield self.__getitem__(i)\n\n    def to_loader(self, **kwargs):\n        \"\"\"Converts self to a DataLoader \"\"\"\n\n        return DL(self,\n                  collate_fn=Dataset.__collate_fn,\n                  **kwargs)\n\n    def to_length_limited_loader(self, batch_size, maximum_tokens):\n        sampler = LengthLimitedBatchSampler(self, batch_size, maximum_tokens)\n        return DL(self,\n                  collate_fn=Dataset.__collate_fn,\n                  batch_sampler = sampler)\n\n    @staticmethod\n    def __collate_fn(data):\n        \"\"\"Function used by DataLoader to pack data\"\"\"\n        (data, idx) = zip(*data)\n        (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)\n\n       
 # collate_fn is given a list of length batch size\n        batch_size = len(data)\n\n        # sort sentences by lens for easy RNN operations\n        lens = [torch.sum(x != PAD_ID) for x in words]\n        (words, wordchars, upos, xpos,\n         ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,\n                                                         ufeats, pretrained, text), lens)\n        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN\n\n        # combine all words into one large list, and sort for easy charRNN ops\n        wordchars = [w for sent in wordchars for w in sent]\n        word_lens = [len(x) for x in wordchars]\n        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)\n        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN\n\n        # We now pad everything\n        words = pad_sequence(words, True, PAD_ID)\n        if None not in upos:\n            upos = pad_sequence(upos, True, PAD_ID)\n        else:\n            upos = None\n        if None not in xpos:\n            xpos = pad_sequence(xpos, True, PAD_ID)\n        else:\n            xpos = None\n        if None not in ufeats:\n            ufeats = pad_sequence(ufeats, True, PAD_ID)\n        else:\n            ufeats = None\n        pretrained = pad_sequence(pretrained, True, PAD_ID)\n        wordchars = get_long_tensor(wordchars, len(word_lens))\n\n        # and finally create masks for the padding indices\n        words_mask = torch.eq(words, PAD_ID)\n        wordchars_mask = torch.eq(wordchars, PAD_ID)\n\n        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,\n                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)\n\n    @staticmethod\n    def load_doc(doc):\n        data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)\n        data = Dataset.resolve_none(data)\n        data = 
simplify_punct(data)\n        return data\n\n    @staticmethod\n    def resolve_none(data):\n        # replace None to '_'\n        for sent_idx in range(len(data)):\n            for tok_idx in range(len(data[sent_idx])):\n                for feat_idx in range(len(data[sent_idx][tok_idx])):\n                    if data[sent_idx][tok_idx][feat_idx] is None:\n                        data[sent_idx][tok_idx][feat_idx] = '_'\n        return data\n\nclass LengthLimitedBatchSampler(Sampler):\n    \"\"\"\n    Batches up the text in batches of batch_size, but cuts off each time a batch reaches maximum_tokens\n\n    Intent is to avoid GPU OOM in situations where one sentence is significantly longer than expected,\n    leaving a batch too large to fit in the GPU\n\n    Sentences which are longer than maximum_tokens by themselves are put in their own batches\n    \"\"\"\n    def __init__(self, data, batch_size, maximum_tokens):\n        \"\"\"\n        Precalculate the batches, making it so len and iter just read off the precalculated batches\n        \"\"\"\n        self.data = data\n        self.batch_size = batch_size\n        self.maximum_tokens = maximum_tokens\n\n        self.batches = []\n        current_batch = []\n        current_length = 0\n\n        for item, item_idx in data:\n            item_len = len(item.word)\n            if maximum_tokens and item_len > maximum_tokens:\n                if len(current_batch) > 0:\n                    self.batches.append(current_batch)\n                    current_batch = []\n                    current_length = 0\n                self.batches.append([item_idx])\n                continue\n            if len(current_batch) + 1 > batch_size or (maximum_tokens and item_len + current_length > maximum_tokens):\n                self.batches.append(current_batch)\n                current_batch = []\n                current_length = 0\n            current_batch.append(item_idx)\n            current_length += item_len\n\n        if 
len(current_batch) > 0:\n            self.batches.append(current_batch)\n\n    def __len__(self):\n        return len(self.batches)\n\n    def __iter__(self):\n        for batch in self.batches:\n            current_batch = []\n            for idx in batch:\n                current_batch.append(idx)\n            yield current_batch\n\n\nclass ShuffledDataset:\n    \"\"\"A wrapper around one or more datasets which shuffles the data in batch_size chunks\n\n    This means that if multiple datasets are passed in, the batches\n    from each dataset are shuffled together, with one batch being\n    entirely members of the same dataset.\n\n    The main use case of this is that in the tagger, there are cases\n    where batches from different datasets will have different\n    properties, such as having or not having UPOS tags.  We found that\n    it is actually somewhat tricky to make the model's loss function\n    (in model.py) properly represent batches with mixed w/ and w/o\n    property, whereas keeping one entire batch together makes it a lot\n    easier to process.\n\n    The mechanism for the shuffling is that the iterator first makes a\n    list long enough to represent each batch from each dataset,\n    tracking the index of the dataset it is coming from, then shuffles\n    that list.  
Another alternative would be to use a weighted\n    randomization approach, but this is very simple and the memory\n    requirements are not too onerous.\n\n    Note that the batch indices are wasteful in the case of only one\n    underlying dataset, which is actually the most common use case,\n    but the overhead is small enough that it probably isn't worth\n    special casing the one dataset version.\n    \"\"\"\n    def __init__(self, datasets, batch_size):\n        self.batch_size = batch_size\n        self.datasets = datasets\n        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]\n\n    def __iter__(self):\n        iterators = [iter(x) for x in self.loaders]\n        lengths = [len(x) for x in self.loaders]\n        indices = [[x] * y for x, y in enumerate(lengths)]\n        indices = [idx for inner in indices for idx in inner]\n        random.shuffle(indices)\n\n        for idx in indices:\n            yield(next(iterators[idx]))\n\n    def __len__(self):\n        return sum(len(x) for x in self.datasets)\n"
  },
  {
    "path": "stanza/models/pos/model.py",
    "content": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence\n\nfrom stanza.models.common.bert_embedding import extract_bert_embeddings\nfrom stanza.models.common.biaffine import BiaffineScorer\nfrom stanza.models.common.foundation_cache import load_bert, load_charlm\nfrom stanza.models.common.hlstm import HighwayLSTM\nfrom stanza.models.common.dropout import WordDropout\nfrom stanza.models.common.utils import attach_bert_model\nfrom stanza.models.common.vocab import CompositeVocab\nfrom stanza.models.common.char_model import CharacterModel\nfrom stanza.models.common import utils\n\nlogger = logging.getLogger('stanza')\n\nclass Tagger(nn.Module):\n    def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):\n        super().__init__()\n\n        self.vocab = vocab\n        self.args = args\n        self.share_hid = share_hid\n        self.unsaved_modules = []\n\n        # input layers\n        input_size = 0\n        if self.args['word_emb_dim'] > 0:\n            # frequent word embeddings\n            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)\n            input_size += self.args['word_emb_dim']\n\n        if not share_hid:\n            # upos embeddings\n            self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)\n\n        if self.args['char'] and self.args['char_emb_dim'] > 0:\n            if self.args.get('charlm', None):\n                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):\n                    raise FileNotFoundError('Could not find forward character model: {}  Please specify with 
--charlm_forward_file'.format(args['charlm_forward_file']))\n                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):\n                    raise FileNotFoundError('Could not find backward character model: {}  Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))\n                logger.debug(\"POS model loading charmodels: %s and %s\", args['charlm_forward_file'], args['charlm_backward_file'])\n                self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))\n                self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))\n                # optionally add a input transformation layer\n                if self.args.get('charlm_transform_dim', 0):\n                    self.charmodel_forward_transform = nn.Linear(self.charmodel_forward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)\n                    self.charmodel_backward_transform = nn.Linear(self.charmodel_backward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)\n                    input_size += self.args['charlm_transform_dim'] * 2\n                else:\n                    self.charmodel_forward_transform = None\n                    self.charmodel_backward_transform = None\n                    input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()\n            else:\n                bidirectional = args.get('char_bidirectional', False)\n                self.charmodel = CharacterModel(args, vocab, bidirectional=bidirectional)\n                if bidirectional:\n                    self.trans_char = nn.Linear(self.args['char_hidden_dim'] * 2, self.args['transformed_dim'], bias=False)\n                else:\n                    self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)\n 
               input_size += self.args['transformed_dim']\n\n        self.peft_name = peft_name\n        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)\n        if self.args.get('bert_model', None):\n            # TODO: refactor bert_hidden_layers between the different models\n            if args.get('bert_hidden_layers', False):\n                # The average will be offset by 1/N so that the default zeros\n                # represents an average of the N layers\n                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)\n                nn.init.zeros_(self.bert_layer_mix.weight)\n            else:\n                # an average of layers 2, 3, 4 will be used\n                # (for historic reasons)\n                self.bert_layer_mix = None\n            input_size += self.bert_model.config.hidden_size\n\n        if self.args['pretrain']:\n            # pretrained embeddings, by default this won't be saved into model file\n            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))\n            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)\n            input_size += self.args['transformed_dim']\n        \n        # recurrent layers\n        self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)\n        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))\n        self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))\n        self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))\n\n        # classifiers\n        self.upos_hid = nn.Linear(self.args['hidden_dim'] * 2, 
self.args['deep_biaff_hidden_dim'])\n        self.upos_clf = nn.Linear(self.args['deep_biaff_hidden_dim'], len(vocab['upos']))\n        self.upos_clf.weight.data.zero_()\n        self.upos_clf.bias.data.zero_()\n\n        if share_hid:\n            clf_constructor = lambda insize, outsize: nn.Linear(insize, outsize)\n        else:\n            self.xpos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'] if not isinstance(vocab['xpos'], CompositeVocab) else self.args['composite_deep_biaff_hidden_dim'])\n            self.ufeats_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['composite_deep_biaff_hidden_dim'])\n            clf_constructor = lambda insize, outsize: BiaffineScorer(insize, self.args['tag_emb_dim'], outsize)\n\n        if isinstance(vocab['xpos'], CompositeVocab):\n            self.xpos_clf = nn.ModuleList()\n            for l in vocab['xpos'].lens():\n                self.xpos_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))\n        else:\n            self.xpos_clf = clf_constructor(self.args['deep_biaff_hidden_dim'], len(vocab['xpos']))\n            if share_hid:\n                self.xpos_clf.weight.data.zero_()\n                self.xpos_clf.bias.data.zero_()\n\n        self.ufeats_clf = nn.ModuleList()\n        for l in vocab['feats'].lens():\n            if share_hid:\n                self.ufeats_clf.append(clf_constructor(self.args['deep_biaff_hidden_dim'], l))\n                self.ufeats_clf[-1].weight.data.zero_()\n                self.ufeats_clf[-1].bias.data.zero_()\n            else:\n                self.ufeats_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))\n\n        # criterion\n        self.crit = nn.CrossEntropyLoss(ignore_index=0) # ignore padding\n\n        self.drop = nn.Dropout(args['dropout'])\n        self.worddrop = WordDropout(args['word_dropout'])\n\n    def add_unsaved_module(self, name, module):\n        self.unsaved_modules += 
[name]\n        setattr(self, name, module)\n\n    def log_norms(self):\n        utils.log_norms(self)\n\n    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text):\n        \n        def pack(x):\n            return pack_padded_sequence(x, sentlens, batch_first=True)\n\n        inputs = []\n        if self.args['word_emb_dim'] > 0:\n            word_emb = self.word_emb(word)\n            word_emb = pack(word_emb)\n            inputs += [word_emb]\n\n        if self.args['pretrain']:\n            pretrained_emb = self.pretrained_emb(pretrained)\n            pretrained_emb = self.trans_pretrained(pretrained_emb)\n            pretrained_emb = pack(pretrained_emb)\n            inputs += [pretrained_emb]\n\n        def pad(x):\n            return pad_packed_sequence(PackedSequence(x, inputs[0].batch_sizes), batch_first=True)[0]\n\n        if self.args['char'] and self.args['char_emb_dim'] > 0:\n            if self.args.get('charlm', None):\n                all_forward_chars = self.charmodel_forward.build_char_representation(text)\n                assert isinstance(all_forward_chars, list)\n                if self.charmodel_forward_transform is not None:\n                    all_forward_chars = [self.charmodel_forward_transform(x) for x in all_forward_chars]\n                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))\n\n                all_backward_chars = self.charmodel_backward.build_char_representation(text)\n                if self.charmodel_backward_transform is not None:\n                    all_backward_chars = [self.charmodel_backward_transform(x) for x in all_backward_chars]\n                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))\n\n                inputs += [all_forward_chars, all_backward_chars]\n            else:\n                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, 
wordlens)\n                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)\n                inputs += [char_reps]\n\n        if self.bert_model is not None:\n            device = next(self.parameters()).device\n            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=False,\n                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,\n                                                     detach=not self.args.get('bert_finetune', False) or not self.training,\n                                                     peft_name=self.peft_name)\n\n            if self.bert_layer_mix is not None:\n                # add the average so that the default behavior is to\n                # take an average of the N layers, and anything else\n                # other than that needs to be learned\n                # TODO: refactor this\n                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]\n\n            processed_bert = pad_sequence(processed_bert, batch_first=True)\n            inputs += [pack(processed_bert)]\n\n        lstm_inputs = torch.cat([x.data for x in inputs], 1)\n        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)\n        lstm_inputs = self.drop(lstm_inputs)\n        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)\n\n        lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))\n        lstm_outputs = lstm_outputs.data\n\n        upos_hid = F.relu(self.upos_hid(self.drop(lstm_outputs)))\n     
   upos_pred = self.upos_clf(self.drop(upos_hid))\n\n        preds = [pad(upos_pred).max(2)[1]]\n\n        if upos is not None:\n            upos = pack(upos).data\n            loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))\n        else:\n            loss = 0.0\n\n        if self.share_hid:\n            xpos_hid = upos_hid\n            ufeats_hid = upos_hid\n\n            clffunc = lambda clf, hid: clf(self.drop(hid))\n        else:\n            xpos_hid = F.relu(self.xpos_hid(self.drop(lstm_outputs)))\n            ufeats_hid = F.relu(self.ufeats_hid(self.drop(lstm_outputs)))\n\n            if self.training and upos is not None:\n                upos_emb = self.upos_emb(upos)\n            else:\n                upos_emb = self.upos_emb(upos_pred.max(1)[1])\n\n            clffunc = lambda clf, hid: clf(self.drop(hid), self.drop(upos_emb))\n\n        if xpos is not None: xpos = pack(xpos).data\n        if isinstance(self.vocab['xpos'], CompositeVocab):\n            xpos_preds = []\n            for i in range(len(self.vocab['xpos'])):\n                xpos_pred = clffunc(self.xpos_clf[i], xpos_hid)\n                if xpos is not None:\n                    loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1))\n                xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1])\n            preds.append(torch.cat(xpos_preds, 2))\n        else:\n            xpos_pred = clffunc(self.xpos_clf, xpos_hid)\n            if xpos is not None:\n                loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1))\n            preds.append(pad(xpos_pred).max(2)[1])\n\n        ufeats_preds = []\n        if ufeats is not None: ufeats = pack(ufeats).data\n        for i in range(len(self.vocab['feats'])):\n            ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid)\n            if ufeats is not None:\n                loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1))\n  
          ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1])\n        preds.append(torch.cat(ufeats_preds, 2))\n\n        return loss, preds\n"
  },
  {
    "path": "stanza/models/pos/scorer.py",
    "content": "\"\"\"\nUtils and wrappers for scoring taggers.\n\"\"\"\nimport logging\n\nfrom stanza.models.common.utils import ud_scores\n\nlogger = logging.getLogger('stanza')\n\ndef score(system_conllu_file, gold_conllu_file, verbose=True, eval_type='AllTags'):\n    \"\"\" Wrapper for tagger scorer. \"\"\"\n    evaluation = ud_scores(gold_conllu_file, system_conllu_file)\n    el = evaluation[eval_type]\n    p = el.precision\n    r = el.recall\n    f = el.f1\n    if verbose:\n        scores = [evaluation[k].f1 * 100 for k in ['UPOS', 'XPOS', 'UFeats', 'AllTags']]\n        logger.info(\"UPOS\\tXPOS\\tUFeats\\tAllTags\")\n        logger.info(\"{:.2f}\\t{:.2f}\\t{:.2f}\\t{:.2f}\".format(*scores))\n    return p, r, f\n\n"
  },
  {
    "path": "stanza/models/pos/trainer.py",
    "content": "\"\"\"\nA trainer class to handle training and testing of models.\n\"\"\"\n\nimport sys\nimport logging\nimport torch\nfrom torch import nn\n\nfrom stanza.models.common.trainer import Trainer as BaseTrainer\nfrom stanza.models.common import utils, loss\nfrom stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache\nfrom stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper\nfrom stanza.models.pos.model import Tagger\nfrom stanza.models.pos.vocab import MultiVocab\n\nlogger = logging.getLogger('stanza')\n\ndef unpack_batch(batch, device):\n    \"\"\" Unpack a batch from the data loader. \"\"\"\n    inputs = [b.to(device) if b is not None else None for b in batch[:8]]\n    orig_idx = batch[8]\n    word_orig_idx = batch[9]\n    sentlens = batch[10]\n    wordlens = batch[11]\n    text = batch[12]\n    return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text\n\nclass Trainer(BaseTrainer):\n    \"\"\" A trainer for training models. 
\"\"\"\n    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None):\n        if model_file is not None:\n            # load everything from file\n            self.load(model_file, pretrain, args=args, foundation_cache=foundation_cache)\n        else:\n            # build model from scratch\n            self.args = args\n            self.vocab = vocab\n\n            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])\n            peft_name = None\n            if self.args['use_peft']:\n                # fine tune the bert if we're using peft\n                self.args['bert_finetune'] = True\n                peft_name = \"pos\"\n                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)\n\n            self.model = Tagger(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, share_hid=args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)\n\n        self.model = self.model.to(device)\n        self.optimizers = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, weight_decay=self.args.get('initial_weight_decay', None), bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get(\"use_peft\", False))\n\n        self.schedulers = {}\n\n        if self.args.get('bert_finetune', None):\n            import transformers\n            warmup_scheduler = transformers.get_linear_schedule_with_warmup(\n                self.optimizers[\"bert_optimizer\"],\n                # todo late starting?\n                0, self.args[\"max_steps\"])\n            self.schedulers[\"bert_scheduler\"] = warmup_scheduler\n\n    def update(self, batch, eval=False):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, word_orig_idx, sentlens, 
wordlens, text = unpack_batch(batch, device)\n        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs\n\n        if eval:\n            self.model.eval()\n        else:\n            self.model.train()\n            for optimizer in self.optimizers.values():\n                optimizer.zero_grad()\n        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)\n        if loss == 0.0:\n            return loss\n\n        loss_val = loss.data.item()\n        if eval:\n            return loss_val\n\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])\n\n        for optimizer in self.optimizers.values():\n            optimizer.step()\n        for scheduler in self.schedulers.values():\n            scheduler.step()\n        return loss_val\n\n    def predict(self, batch, unsort=True):\n        device = next(self.model.parameters()).device\n        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)\n        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs\n\n        self.model.eval()\n        batch_size = word.size(0)\n        _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)\n        upos_seqs = [self.vocab['upos'].unmap(sent) for sent in preds[0].tolist()]\n        xpos_seqs = [self.vocab['xpos'].unmap(sent) for sent in preds[1].tolist()]\n        feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[2].tolist()]\n\n        pred_tokens = [[[upos_seqs[i][j], xpos_seqs[i][j], feats_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)]\n        if unsort:\n            pred_tokens = utils.unsort(pred_tokens, orig_idx)\n        return pred_tokens\n\n    def save(self, filename, skip_modules=True):\n        
model_state = self.model.state_dict()\n        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file\n        if skip_modules:\n            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]\n            for k in skipped:\n                del model_state[k]\n        params = {\n                'model': model_state,\n                'vocab': self.vocab.state_dict(),\n                'config': self.args\n                }\n        if self.args.get('use_peft', False):\n            # Hide import so that peft dependency is optional\n            from peft import get_peft_model_state_dict\n            params[\"bert_lora\"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)\n\n        try:\n            torch.save(params, filename, _use_new_zipfile_serialization=False)\n            logger.info(\"Model saved to {}\".format(filename))\n        except (KeyboardInterrupt, SystemExit):\n            raise\n        except Exception as e:\n            logger.warning(f\"Saving failed... {e} continuing anyway.\")\n\n    def load(self, filename, pretrain, args=None, foundation_cache=None):\n        \"\"\"\n        Load a model from file, with preloaded pretrain embeddings. 
Here we allow the pretrain to be None or a dummy input,\n        and the actual use of pretrain embeddings will depend on the boolean config \"pretrain\" in the loaded args.\n        \"\"\"\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        except BaseException:\n            logger.error(\"Cannot load model from {}\".format(filename))\n            raise\n        self.args = checkpoint['config']\n        if args is not None: self.args.update(args)\n\n        # preserve old models which were created before transformers were added\n        if 'bert_model' not in self.args:\n            self.args['bert_model'] = None\n\n        lora_weights = checkpoint.get('bert_lora')\n        if lora_weights:\n            logger.debug(\"Found peft weights for POS; loading a peft adapter\")\n            self.args[\"use_peft\"] = True\n\n        # TODO: refactor this common block of code with NER\n        force_bert_saved = False\n        peft_name = None\n        if self.args.get('use_peft', False):\n            force_bert_saved = True\n            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], \"pos\", foundation_cache)\n            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)\n            logger.debug(\"Loaded peft with name %s\", peft_name)\n        else:\n            if any(x.startswith(\"bert_model.\") for x in checkpoint['model'].keys()):\n                logger.debug(\"Model %s has a finetuned transformer.  
Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere\", filename)\n                foundation_cache = NoTransformerFoundationCache(foundation_cache)\n                force_bert_saved = True\n            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)\n\n        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])\n        # load model\n        emb_matrix = None\n        if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None\n            emb_matrix = pretrain.emb\n        if any(x.startswith(\"bert_model.\") for x in checkpoint['model'].keys()):\n            logger.debug(\"Model %s has a finetuned transformer.  Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere\", filename)\n            foundation_cache = NoTransformerFoundationCache(foundation_cache)\n        self.model = Tagger(self.args, self.vocab, emb_matrix=emb_matrix, share_hid=self.args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)\n        self.model.load_state_dict(checkpoint['model'], strict=False)\n"
  },
  {
    "path": "stanza/models/pos/vocab.py",
    "content": "from collections import Counter, OrderedDict\n\nfrom stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab\nfrom stanza.models.common.vocab import CompositeVocab, VOCAB_PREFIX, EMPTY, EMPTY_ID\n\nclass WordVocab(BaseVocab):\n    def __init__(self, data=None, lang=\"\", idx=0, cutoff=0, lower=False, ignore=None):\n        self.ignore = ignore if ignore is not None else []\n        super().__init__(data, lang=lang, idx=idx, cutoff=cutoff, lower=lower)\n        self.state_attrs += ['ignore']\n\n    def id2unit(self, id):\n        if len(self.ignore) > 0 and id == EMPTY_ID:\n            return '_'\n        else:\n            return super().id2unit(id)\n\n    def unit2id(self, unit):\n        if len(self.ignore) > 0 and unit in self.ignore:\n            return self._unit2id[EMPTY]\n        else:\n            return super().unit2id(unit)\n\n    def build_vocab(self):\n        if self.lower:\n            counter = Counter([w[self.idx].lower() for sent in self.data for w in sent])\n        else:\n            counter = Counter([w[self.idx] for sent in self.data for w in sent])\n        for k in list(counter.keys()):\n            if counter[k] < self.cutoff or k in self.ignore:\n                del counter[k]\n\n        self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\n    def __iter__(self):\n        # the EMPTY shenanigans above make list() look really weird\n        # when using the __len__ / __getitem__ paradigm,\n        # but yielding items like this works fine\n        for x in self._id2unit:\n            yield x\n\n    def __str__(self):\n        return \"<{}: {}>\".format(type(self), \",\".join(\"|%s|\" % x for x in self._id2unit))\n\nclass XPOSVocab(CompositeVocab):\n    def __init__(self, data=None, lang=\"\", idx=0, sep=\"\", keyed=False):\n        super().__init__(data, lang, idx=idx, sep=sep, 
keyed=keyed)\n\nclass FeatureVocab(CompositeVocab):\n    def __init__(self, data=None, lang=\"\", idx=0, sep=\"|\", keyed=True):\n        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)\n\nclass MultiVocab(BaseMultiVocab):\n    def state_dict(self):\n        \"\"\" Also save a vocab name to class name mapping in state dict. \"\"\"\n        state = OrderedDict()\n        key2class = OrderedDict()\n        for k, v in self._vocabs.items():\n            state[k] = v.state_dict()\n            key2class[k] = type(v).__name__\n        state['_key2class'] = key2class\n        return state\n\n    @classmethod\n    def load_state_dict(cls, state_dict):\n        class_dict = {'CharVocab': CharVocab,\n                      'WordVocab': WordVocab,\n                      'XPOSVocab': XPOSVocab,\n                      'FeatureVocab': FeatureVocab}\n        new = cls()\n        assert '_key2class' in state_dict, \"Cannot find class name mapping in state dict!\"\n        key2class = state_dict['_key2class']\n        for k,v in state_dict.items():\n            if k == '_key2class':\n                continue\n            classname = key2class[k]\n            new[k] = class_dict[classname].load_state_dict(v)\n        return new\n\n"
  },
  {
    "path": "stanza/models/pos/xpos_vocab_factory.py",
    "content": "# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.\n# Please don't edit it!\n\nimport logging\n\nfrom stanza.models.pos.vocab import WordVocab, XPOSVocab\nfrom stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory\n\n# using a sublogger makes it easier to test in the unittests\nlogger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')\n\nXPOS_DESCRIPTIONS = {\n    'af_afribooms'   : XPOSDescription(XPOSType.XPOS, ''),\n    'ar_padt'        : XPOSDescription(XPOSType.XPOS, ''),\n    'bg_btb'         : XPOSDescription(XPOSType.XPOS, ''),\n    'ca_ancora'      : XPOSDescription(XPOSType.XPOS, ''),\n    'cs_cac'         : XPOSDescription(XPOSType.XPOS, ''),\n    'cs_cltt'        : XPOSDescription(XPOSType.XPOS, ''),\n    'cs_fictree'     : XPOSDescription(XPOSType.XPOS, ''),\n    'cs_pdt'         : XPOSDescription(XPOSType.XPOS, ''),\n    'en_partut'      : XPOSDescription(XPOSType.XPOS, ''),\n    'es_ancora'      : XPOSDescription(XPOSType.XPOS, ''),\n    'es_combined'    : XPOSDescription(XPOSType.XPOS, ''),\n    'fr_partut'      : XPOSDescription(XPOSType.XPOS, ''),\n    'gd_arcosg'      : XPOSDescription(XPOSType.XPOS, ''),\n    'gl_ctg'         : XPOSDescription(XPOSType.XPOS, ''),\n    'gl_treegal'     : XPOSDescription(XPOSType.XPOS, ''),\n    'grc_perseus'    : XPOSDescription(XPOSType.XPOS, ''),\n    'hr_set'         : XPOSDescription(XPOSType.XPOS, ''),\n    'is_gc'          : XPOSDescription(XPOSType.XPOS, ''),\n    'is_icepahc'     : XPOSDescription(XPOSType.XPOS, ''),\n    'is_modern'      : XPOSDescription(XPOSType.XPOS, ''),\n    'it_combined'    : XPOSDescription(XPOSType.XPOS, ''),\n    'it_isdt'        : XPOSDescription(XPOSType.XPOS, ''),\n    'it_markit'      : XPOSDescription(XPOSType.XPOS, ''),\n    'it_parlamint'   : XPOSDescription(XPOSType.XPOS, ''),\n    'it_partut'      : XPOSDescription(XPOSType.XPOS, 
''),\n    'it_postwita'    : XPOSDescription(XPOSType.XPOS, ''),\n    'it_twittiro'    : XPOSDescription(XPOSType.XPOS, ''),\n    'it_vit'         : XPOSDescription(XPOSType.XPOS, ''),\n    'la_perseus'     : XPOSDescription(XPOSType.XPOS, ''),\n    'la_udante'      : XPOSDescription(XPOSType.XPOS, ''),\n    'lt_alksnis'     : XPOSDescription(XPOSType.XPOS, ''),\n    'lv_lvtb'        : XPOSDescription(XPOSType.XPOS, ''),\n    'ro_nonstandard' : XPOSDescription(XPOSType.XPOS, ''),\n    'ro_rrt'         : XPOSDescription(XPOSType.XPOS, ''),\n    'ro_simonero'    : XPOSDescription(XPOSType.XPOS, ''),\n    'sk_snk'         : XPOSDescription(XPOSType.XPOS, ''),\n    'sl_ssj'         : XPOSDescription(XPOSType.XPOS, ''),\n    'sl_sst'         : XPOSDescription(XPOSType.XPOS, ''),\n    'sr_set'         : XPOSDescription(XPOSType.XPOS, ''),\n    'ta_ttb'         : XPOSDescription(XPOSType.XPOS, ''),\n    'uk_iu'          : XPOSDescription(XPOSType.XPOS, ''),\n\n    'be_hse'         : XPOSDescription(XPOSType.WORD, None),\n    'bxr_bdt'        : XPOSDescription(XPOSType.WORD, None),\n    'cop_scriptorium': XPOSDescription(XPOSType.WORD, None),\n    'cu_proiel'      : XPOSDescription(XPOSType.WORD, None),\n    'cy_ccg'         : XPOSDescription(XPOSType.WORD, None),\n    'da_ddt'         : XPOSDescription(XPOSType.WORD, None),\n    'de_gsd'         : XPOSDescription(XPOSType.WORD, None),\n    'de_hdt'         : XPOSDescription(XPOSType.WORD, None),\n    'el_gdt'         : XPOSDescription(XPOSType.WORD, None),\n    'el_gud'         : XPOSDescription(XPOSType.WORD, None),\n    'en_atis'        : XPOSDescription(XPOSType.WORD, None),\n    'en_combined'    : XPOSDescription(XPOSType.WORD, None),\n    'en_craft'       : XPOSDescription(XPOSType.WORD, None),\n    'en_eslspok'     : XPOSDescription(XPOSType.WORD, None),\n    'en_ewt'         : XPOSDescription(XPOSType.WORD, None),\n    'en_genia'       : XPOSDescription(XPOSType.WORD, None),\n    'en_gum'         : 
XPOSDescription(XPOSType.WORD, None),\n    'en_gumreddit'   : XPOSDescription(XPOSType.WORD, None),\n    'en_mimic'       : XPOSDescription(XPOSType.WORD, None),\n    'en_test'        : XPOSDescription(XPOSType.WORD, None),\n    'es_gsd'         : XPOSDescription(XPOSType.WORD, None),\n    'et_edt'         : XPOSDescription(XPOSType.WORD, None),\n    'et_ewt'         : XPOSDescription(XPOSType.WORD, None),\n    'eu_bdt'         : XPOSDescription(XPOSType.WORD, None),\n    'fa_perdt'       : XPOSDescription(XPOSType.WORD, None),\n    'fa_seraji'      : XPOSDescription(XPOSType.WORD, None),\n    'fi_tdt'         : XPOSDescription(XPOSType.WORD, None),\n    'fr_combined'    : XPOSDescription(XPOSType.WORD, None),\n    'fr_gsd'         : XPOSDescription(XPOSType.WORD, None),\n    'fr_parisstories': XPOSDescription(XPOSType.WORD, None),\n    'fr_rhapsodie'   : XPOSDescription(XPOSType.WORD, None),\n    'fr_sequoia'     : XPOSDescription(XPOSType.WORD, None),\n    'fro_profiterole': XPOSDescription(XPOSType.WORD, None),\n    'ga_idt'         : XPOSDescription(XPOSType.WORD, None),\n    'ga_twittirish'  : XPOSDescription(XPOSType.WORD, None),\n    'got_proiel'     : XPOSDescription(XPOSType.WORD, None),\n    'grc_proiel'     : XPOSDescription(XPOSType.WORD, None),\n    'grc_ptnk'       : XPOSDescription(XPOSType.WORD, None),\n    'gv_cadhan'      : XPOSDescription(XPOSType.WORD, None),\n    'hbo_ptnk'       : XPOSDescription(XPOSType.WORD, None),\n    'he_combined'    : XPOSDescription(XPOSType.WORD, None),\n    'he_htb'         : XPOSDescription(XPOSType.WORD, None),\n    'he_iahltknesset': XPOSDescription(XPOSType.WORD, None),\n    'he_iahltwiki'   : XPOSDescription(XPOSType.WORD, None),\n    'hi_hdtb'        : XPOSDescription(XPOSType.WORD, None),\n    'hsb_ufal'       : XPOSDescription(XPOSType.WORD, None),\n    'hu_szeged'      : XPOSDescription(XPOSType.WORD, None),\n    'hy_armtdp'      : XPOSDescription(XPOSType.WORD, None),\n    'hy_bsut'        : 
XPOSDescription(XPOSType.WORD, None),\n    'hyw_armtdp'     : XPOSDescription(XPOSType.WORD, None),\n    'id_csui'        : XPOSDescription(XPOSType.WORD, None),\n    'it_old'         : XPOSDescription(XPOSType.WORD, None),\n    'ka_glc'         : XPOSDescription(XPOSType.WORD, None),\n    'kk_ktb'         : XPOSDescription(XPOSType.WORD, None),\n    'kmr_mg'         : XPOSDescription(XPOSType.WORD, None),\n    'kpv_lattice'    : XPOSDescription(XPOSType.WORD, None),\n    'ky_ktmu'        : XPOSDescription(XPOSType.WORD, None),\n    'la_proiel'      : XPOSDescription(XPOSType.WORD, None),\n    'lij_glt'        : XPOSDescription(XPOSType.WORD, None),\n    'lt_hse'         : XPOSDescription(XPOSType.WORD, None),\n    'lzh_kyoto'      : XPOSDescription(XPOSType.WORD, None),\n    'mr_ufal'        : XPOSDescription(XPOSType.WORD, None),\n    'mt_mudt'        : XPOSDescription(XPOSType.WORD, None),\n    'myv_jr'         : XPOSDescription(XPOSType.WORD, None),\n    'nb_bokmaal'     : XPOSDescription(XPOSType.WORD, None),\n    'nds_lsdc'       : XPOSDescription(XPOSType.WORD, None),\n    'nn_nynorsk'     : XPOSDescription(XPOSType.WORD, None),\n    'nn_nynorsklia'  : XPOSDescription(XPOSType.WORD, None),\n    'no_bokmaal'     : XPOSDescription(XPOSType.WORD, None),\n    'orv_birchbark'  : XPOSDescription(XPOSType.WORD, None),\n    'orv_rnc'        : XPOSDescription(XPOSType.WORD, None),\n    'orv_torot'      : XPOSDescription(XPOSType.WORD, None),\n    'ota_boun'       : XPOSDescription(XPOSType.WORD, None),\n    'pcm_nsc'        : XPOSDescription(XPOSType.WORD, None),\n    'pt_bosque'      : XPOSDescription(XPOSType.WORD, None),\n    'pt_cintil'      : XPOSDescription(XPOSType.WORD, None),\n    'pt_dantestocks' : XPOSDescription(XPOSType.WORD, None),\n    'pt_gsd'         : XPOSDescription(XPOSType.WORD, None),\n    'pt_petrogold'   : XPOSDescription(XPOSType.WORD, None),\n    'pt_porttinari'  : XPOSDescription(XPOSType.WORD, None),\n    'qpm_philotis'   : 
XPOSDescription(XPOSType.WORD, None),\n    'qtd_sagt'       : XPOSDescription(XPOSType.WORD, None),\n    'ru_gsd'         : XPOSDescription(XPOSType.WORD, None),\n    'ru_poetry'      : XPOSDescription(XPOSType.WORD, None),\n    'ru_syntagrus'   : XPOSDescription(XPOSType.WORD, None),\n    'ru_taiga'       : XPOSDescription(XPOSType.WORD, None),\n    'sa_vedic'       : XPOSDescription(XPOSType.WORD, None),\n    'sme_giella'     : XPOSDescription(XPOSType.WORD, None),\n    'swl_sslc'       : XPOSDescription(XPOSType.WORD, None),\n    'sq_staf'        : XPOSDescription(XPOSType.WORD, None),\n    'te_mtg'         : XPOSDescription(XPOSType.WORD, None),\n    'tr_atis'        : XPOSDescription(XPOSType.WORD, None),\n    'tr_boun'        : XPOSDescription(XPOSType.WORD, None),\n    'tr_framenet'    : XPOSDescription(XPOSType.WORD, None),\n    'tr_imst'        : XPOSDescription(XPOSType.WORD, None),\n    'tr_kenet'       : XPOSDescription(XPOSType.WORD, None),\n    'tr_penn'        : XPOSDescription(XPOSType.WORD, None),\n    'tr_tourism'     : XPOSDescription(XPOSType.WORD, None),\n    'ug_udt'         : XPOSDescription(XPOSType.WORD, None),\n    'uk_parlamint'   : XPOSDescription(XPOSType.WORD, None),\n    'vi_vtb'         : XPOSDescription(XPOSType.WORD, None),\n    'wo_wtb'         : XPOSDescription(XPOSType.WORD, None),\n    'xcl_caval'      : XPOSDescription(XPOSType.WORD, None),\n    'zh-hans_gsdsimp': XPOSDescription(XPOSType.WORD, None),\n    'zh-hant_gsd'    : XPOSDescription(XPOSType.WORD, None),\n    'zh_gsdsimp'     : XPOSDescription(XPOSType.WORD, None),\n\n    'en_lines'       : XPOSDescription(XPOSType.XPOS, '-'),\n    'fo_farpahc'     : XPOSDescription(XPOSType.XPOS, '-'),\n    'ja_gsd'         : XPOSDescription(XPOSType.XPOS, '-'),\n    'ja_gsdluw'      : XPOSDescription(XPOSType.XPOS, '-'),\n    'sv_lines'       : XPOSDescription(XPOSType.XPOS, '-'),\n    'ur_udtb'        : XPOSDescription(XPOSType.XPOS, '-'),\n\n    'fi_ftb'         : 
XPOSDescription(XPOSType.XPOS, ','),\n    'orv_ruthenian'  : XPOSDescription(XPOSType.XPOS, ','),\n\n    'id_gsd'         : XPOSDescription(XPOSType.XPOS, '+'),\n    'ko_gsd'         : XPOSDescription(XPOSType.XPOS, '+'),\n    'ko_kaist'       : XPOSDescription(XPOSType.XPOS, '+'),\n    'ko_ksl'         : XPOSDescription(XPOSType.XPOS, '+'),\n    'qaf_arabizi'    : XPOSDescription(XPOSType.XPOS, '+'),\n\n    'la_ittb'        : XPOSDescription(XPOSType.XPOS, '|'),\n    'la_llct'        : XPOSDescription(XPOSType.XPOS, '|'),\n    'nl_alpino'      : XPOSDescription(XPOSType.XPOS, '|'),\n    'nl_lassysmall'  : XPOSDescription(XPOSType.XPOS, '|'),\n    'sv_talbanken'   : XPOSDescription(XPOSType.XPOS, '|'),\n\n    'pl_lfg'         : XPOSDescription(XPOSType.XPOS, ':'),\n    'pl_pdb'         : XPOSDescription(XPOSType.XPOS, ':'),\n}\n\ndef xpos_vocab_factory(data, shorthand):\n    if shorthand not in XPOS_DESCRIPTIONS:\n        logger.warning(\"%s is not a known dataset.  Examining the data to choose which xpos vocab to use\", shorthand)\n    desc = choose_simplest_factory(data, shorthand)\n    if shorthand in XPOS_DESCRIPTIONS:\n        if XPOS_DESCRIPTIONS[shorthand] != desc:\n            # log instead of throw\n            # otherwise, updating datasets would be unpleasant\n            logger.error(\"XPOS tagset in %s has apparently changed!  Was %s, is now %s\", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)\n    else:\n        logger.warning(\"Chose %s for the xpos factory for %s\", desc, shorthand)\n    return build_xpos_vocab(desc, data, shorthand)\n\n"
  },
  {
    "path": "stanza/models/pos/xpos_vocab_utils.py",
    "content": "from collections import namedtuple\nfrom enum import Enum\nimport logging\nimport os\n\nfrom stanza.models.common.vocab import VOCAB_PREFIX\nfrom stanza.models.pos.vocab import XPOSVocab, WordVocab\n\nclass XPOSType(Enum):\n    XPOS     = 1\n    WORD     = 2\n\nXPOSDescription = namedtuple('XPOSDescription', ['xpos_type', 'sep'])\nDEFAULT_KEY = XPOSDescription(XPOSType.WORD, None)\n\nlogger = logging.getLogger('stanza')\n\ndef filter_data(data, idx):\n    data_filtered = []\n    for sentence in data:\n        flag = True\n        for token in sentence:\n            if token[idx] is None:\n                flag = False\n        if flag: data_filtered.append(sentence)\n    return data_filtered\n\ndef choose_simplest_factory(data, shorthand):\n    logger.info(f'Original length = {len(data)}')\n    data = filter_data(data, idx=2)\n    logger.info(f'Filtered length = {len(data)}')\n    vocab = WordVocab(data, shorthand, idx=2, ignore=[\"_\"])\n    key = DEFAULT_KEY\n    best_size = len(vocab) - len(VOCAB_PREFIX)\n    if best_size > 20:\n        for sep in ['', '-', '+', '|', ',', ':']: # separators\n            vocab = XPOSVocab(data, shorthand, idx=2, sep=sep)\n            length = sum(len(x) - len(VOCAB_PREFIX) for x in vocab._id2unit.values())\n            if length < best_size:\n                key = XPOSDescription(XPOSType.XPOS, sep)\n                best_size = length\n    return key\n\ndef build_xpos_vocab(description, data, shorthand):\n    if description.xpos_type is XPOSType.WORD:\n        return WordVocab(data, shorthand, idx=2, ignore=[\"_\"])\n\n    return XPOSVocab(data, shorthand, idx=2, sep=description.sep)\n"
  },
  {
    "path": "stanza/models/tagger.py",
    "content": "\"\"\"\nEntry point for training and evaluating a POS/morphological features tagger.\n\nThis tagger uses highway BiLSTM layers with character and word-level representations, and biaffine classifiers\nto produce consistent POS and UFeats predictions.\nFor details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.\n\"\"\"\n\nimport argparse\nimport logging\nimport io\nimport os\nimport time\nimport zipfile\n\nimport numpy as np\nimport torch\nfrom torch import nn, optim\n\nfrom stanza.models.pos.data import Dataset, ShuffledDataset\nfrom stanza.models.pos.trainer import Trainer\nfrom stanza.models.pos import scorer\nfrom stanza.models.common import utils\nfrom stanza.models.common import pretrain\nfrom stanza.models.common.doc import *\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.common.peft_config import add_peft_args, resolve_peft_args\nfrom stanza.models import _training_logging\nfrom stanza.utils.conll import CoNLL\n\nlogger = logging.getLogger('stanza')\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_dir', type=str, default='data/pos', help='Root dir for saving models.')\n    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.')\n    parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument('--train_file', type=str, default=None, help='Input file for training.')\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for scoring.')\n    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')\n    parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help=\"Don't score the eval file - perhaps it has no gold 
labels, for example.  Cannot be used at training time\")\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--lang', type=str, help='Language')\n    parser.add_argument('--shorthand', type=str, help=\"Treebank shorthand\")\n\n    parser.add_argument('--hidden_dim', type=int, default=200)\n    parser.add_argument('--char_hidden_dim', type=int, default=400)\n    parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)\n    parser.add_argument('--composite_deep_biaff_hidden_dim', type=int, default=100)\n    parser.add_argument('--word_emb_dim', type=int, default=75, help='Dimension of the finetuned word embedding.  Set to 0 to turn off')\n    parser.add_argument('--word_cutoff', type=int, default=None, help='How common a word must be to include it in the finetuned word embedding.  If not set, small word vector files will be 0, larger will be %d' % utils.DEFAULT_WORD_CUTOFF)\n    parser.add_argument('--char_emb_dim', type=int, default=100)\n    parser.add_argument('--tag_emb_dim', type=int, default=50)\n    parser.add_argument('--charlm_transform_dim', type=int, default=None, help='Transform the pretrained charlm to this dimension.  
If not set, no transform is used')\n    parser.add_argument('--transformed_dim', type=int, default=125)\n    parser.add_argument('--num_layers', type=int, default=2)\n    parser.add_argument('--char_num_layers', type=int, default=1)\n    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)\n    parser.add_argument('--word_dropout', type=float, default=0.33)\n    parser.add_argument('--dropout', type=float, default=0.5)\n    parser.add_argument('--rec_dropout', type=float, default=0, help=\"Recurrent dropout\")\n    parser.add_argument('--char_rec_dropout', type=float, default=0, help=\"Recurrent dropout\")\n\n    # TODO: refactor charlm arguments for models which use it?\n    parser.add_argument('--no_char', dest='char', action='store_false', help=\"Turn off character model.\")\n    parser.add_argument('--char_bidirectional', dest='char_bidirectional', action='store_true', help=\"Use a bidirectional version of the non-pretrained charlm.  Doesn't help much, makes the models larger\")\n    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help=\"Use lowercased characters in character model.\")\n    parser.add_argument('--charlm', action='store_true', help=\"Turn on contextualized char embedding using pretrained character-level language model.\")\n    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help=\"Root dir for pretrained character-level language model.\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n\n    parser.add_argument('--bert_model', type=str, default=None, help=\"Use an external bert model (requires the transformers package)\")\n    
parser.add_argument('--no_bert_model', dest='bert_model', action=\"store_const\", const=None, help=\"Don't use bert\")\n    parser.add_argument('--bert_hidden_layers', type=int, default=None, help=\"How many layers of hidden state to use from the transformer\")\n    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')\n    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help=\"Don't finetune the bert (or other transformer)\")\n    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')\n\n    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help=\"Turn off pretrained embeddings.\")\n    parser.add_argument('--share_hid', action='store_true', help=\"Share hidden representations for UPOS, XPOS and UFeats.\")\n    parser.set_defaults(share_hid=False)\n\n    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')\n    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamw, adamax, or adadelta.  madgrad as an optional dependency')\n    parser.add_argument('--second_optim', type=str, default='amsgrad', help='Optimizer for the second half of training.  Default is Adam with AMSGrad')\n    parser.add_argument('--second_optim_reload', default=False, action='store_true', help='Reload the best model instead of continuing from current model if the first optimizer stalls out.  
This does not seem to help, but might be useful for further experiments')\n    parser.add_argument('--no_second_optim', action='store_const', const=None, dest='second_optim', help=\"Don't use a second optimizer - only use the first optimizer\")\n    parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')\n    parser.add_argument('--second_lr', type=float, default=None, help='Alternate learning rate for the second optimizer')\n    parser.add_argument('--initial_weight_decay', type=float, default=None, help='Optimizer weight decay for the first optimizer')\n    parser.add_argument('--second_weight_decay', type=float, default=None, help='Optimizer weight decay for the second optimizer')\n    parser.add_argument('--beta2', type=float, default=0.95)\n\n    parser.add_argument('--max_steps', type=int, default=50000)\n    parser.add_argument('--eval_interval', type=int, default=100)\n    parser.add_argument('--fix_eval_interval', dest='adapt_eval_interval', action='store_false', \\\n            help=\"Use fixed evaluation interval for all treebanks, otherwise by default the interval will be increased for larger treebanks.\")\n    parser.add_argument('--max_steps_before_stop', type=int, default=3000, help='Changes learning method or early terminates after this many steps if the dev scores are not improving')\n    parser.add_argument('--batch_size', type=int, default=250)\n    parser.add_argument('--batch_maximum_tokens', type=int, default=5000, help='When run in a Pipeline, limit a batch to this many tokens to help avoid OOM for long sentences')\n    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')\n    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')\n    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')\n    parser.add_argument('--save_dir', type=str, default='saved_models/pos', help='Root dir for 
saving models.')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_tagger.pt\", help=\"File name to save the model\")\n    parser.add_argument('--save_each', default=False, action='store_true', help=\"Save each checkpoint to its own model.  Will take up a bunch of space\")\n\n    parser.add_argument('--seed', type=int, default=1234)\n    add_peft_args(parser)\n    utils.add_device_args(parser)\n\n    parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct.  Default of None will aim for roughly 50%%')\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')\n    return parser\n\ndef parse_args(args=None):\n    parser = build_argparse()\n    args = parser.parse_args(args=args)\n    resolve_peft_args(args, logger)\n\n    if args.augment_nopunct is None:\n        args.augment_nopunct = 0.25\n\n    if args.wandb_name:\n        args.wandb = True\n\n    if not args.share_hid and args.tag_emb_dim == 0:\n        raise ValueError(\"Cannot have tag_emb_dim==0 with share_hid==False, as the tags will be embedded for the next layer\")\n\n    args = vars(args)\n    return args\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running tagger in {} mode\".format(args['mode']))\n\n    if args['mode'] == 'train':\n        return train(args)\n    else:\n        return evaluate(args)\n\ndef model_file_name(args):\n    return utils.standard_model_file_name(args, \"tagger\")\n\ndef save_each_file_name(args):\n    model_file = model_file_name(args)\n    pieces = os.path.splitext(model_file)\n    
return pieces[0] + \"_%05d\" + pieces[1]\n\ndef load_pretrain(args):\n    pt = None\n    if args['pretrain']:\n        pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])\n        if os.path.exists(pretrain_file):\n            vec_file = None\n        else:\n            vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])\n        pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])\n    return pt\n\ndef get_eval_type(dev_batch):\n    \"\"\"\n    If there is only one column to score in the dev set, use that instead of AllTags\n    \"\"\"\n    if dev_batch.has_xpos and not dev_batch.has_upos and not dev_batch.has_feats:\n        return \"XPOS\"\n    elif dev_batch.has_upos and not dev_batch.has_xpos and not dev_batch.has_feats:\n        return \"UPOS\"\n    else:\n        return \"AllTags\"\n\ndef load_training_data(args, pretrain):\n    train_docs = []\n    raw_train_files = args['train_file'].split(\";\")\n    train_files = []\n    for train_file in raw_train_files:\n        if zipfile.is_zipfile(train_file):\n            logger.info(\"Decompressing %s\" % train_file)\n            with zipfile.ZipFile(train_file) as zin:\n                for zipped_train_file in zin.namelist():\n                    with zin.open(zipped_train_file) as fin:\n                        logger.info(\"Reading %s from %s\" % (zipped_train_file, train_file))\n                        train_str = fin.read()\n                        train_str = train_str.decode(\"utf-8\")\n                        train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str)\n                        logger.info(\"Train File {} from {}, Data Size: {}\".format(zipped_train_file, train_file, len(train_file_data)))\n                        train_docs.append(Document(train_file_data))\n                        train_files.append(\"%s %s\" % 
(train_file, zipped_train_file))\n        else:\n            logger.info(\"Reading %s\" % train_file)\n            # train_file_data is now a list of sentences, where each sentence is a\n            # list of words, in which each word is a dict of conll attributes\n            train_file_data, _, _ = CoNLL.conll2dict(input_file=train_file)\n            logger.info(\"Train File {}, Data Size: {}\".format(train_file, len(train_file_data)))\n            train_docs.append(Document(train_file_data))\n            train_files.append(train_file)\n    if sum(len(x.sentences) for x in train_docs) == 0:\n        raise RuntimeError(\"Training data for the tagger is empty: %s\" % args['train_file'])\n    # we want to ensure that the model is able to output _ for empty columns,\n    # but create batches whereby if a doc has upos/xpos tags we include them all.\n    # therefore, we create separate datasets and loaders for each input training file,\n    # which ensures the system is able to see batches with upos available\n    # and batches with upos unavailable, depending on what each file provides.\n    vocab = Dataset.init_vocab(train_docs, args)\n    train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)\n                  for i in train_docs]\n    for train_file, td in zip(train_files, train_data):\n        if not td.has_upos:\n            logger.info(\"No UPOS in %s\" % train_file)\n        if not td.has_xpos:\n            logger.info(\"No XPOS in %s\" % train_file)\n        if not td.has_feats:\n            logger.info(\"No feats in %s\" % train_file)\n\n    # reject partially tagged upos or xpos documents\n    # otherwise, the model will learn to output blanks for some words,\n    # which is probably a confusing result\n    # (and definitely throws off the depparse)\n    # another option would be to treat those as masked out\n    for td_idx, td in enumerate(train_data):\n        if td.has_upos:\n            upos_data = td.doc.get(UPOS, 
as_sentences=True)\n            for sentence_idx, sentence in enumerate(upos_data):\n                for word_idx, upos in enumerate(sentence):\n                    if upos == '_' or upos is None:\n                        conll = \"{:C}\".format(td.doc.sentences[sentence_idx])\n                        raise RuntimeError(\"Found a blank tag in the UPOS at sentence %d word %d of %s.\\n%s\" % ((sentence_idx+1), (word_idx+1), train_files[td_idx], conll))\n\n    # here we make sure the model will learn to output _ for empty columns\n    # if *any* dataset has data for the upos, xpos, or feature column,\n    # we consider that data enough to train the model on that column\n    # otherwise, we want to train the model to always output blanks\n    if not any(td.has_upos for td in train_data):\n        for td in train_data:\n            td.has_upos = True\n    if not any(td.has_xpos for td in train_data):\n        for td in train_data:\n            td.has_xpos = True\n    if not any(td.has_feats for td in train_data):\n        for td in train_data:\n            td.has_feats = True\n    # calculate the batches\n    train_batches = ShuffledDataset(train_data, args[\"batch_size\"])\n    return vocab, train_data, train_batches\n\ndef train(args):\n    model_file = model_file_name(args)\n    utils.ensure_dir(os.path.split(model_file)[0])\n\n    if args['save_each']:\n        # so models.pt -> models_0001.pt, etc\n        model_save_each_file = save_each_file_name(args)\n        logger.info(\"Saving each checkpoint to %s\" % model_save_each_file)\n\n    # load pretrained vectors if needed\n    pretrain = load_pretrain(args)\n    args['word_cutoff'] = utils.update_word_cutoff(pretrain, args['word_cutoff'])\n\n    if args['charlm']:\n        if args['charlm_shorthand'] is None:\n            raise ValueError(\"CharLM Shorthand is required for loading pretrained CharLM model...\")\n        logger.info('Using pretrained contextualized char embedding')\n        if not 
args['charlm_forward_file']:\n            args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])\n        if not args['charlm_backward_file']:\n            args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])\n\n    # load data\n    logger.info(\"Loading data with batch size {}...\".format(args['batch_size']))\n    vocab, train_data, train_batches = load_training_data(args, pretrain)\n\n    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)\n    dev_batch = dev_data.to_loader(batch_size=args[\"batch_size\"])\n\n    eval_type = get_eval_type(dev_data)\n\n    # skip training if the language does not have training or dev data\n    # sum(...) to check if all of the training files are empty\n    if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:\n        logger.info(\"Skip training because no data available...\")\n        return None, None\n\n    if args['wandb']:\n        import wandb\n        wandb_name = args['wandb_name'] if args['wandb_name'] else \"%s_tagger\" % args['shorthand']\n        wandb.init(name=wandb_name, config=args)\n        wandb.run.define_metric('train_loss', summary='min')\n        wandb.run.define_metric('dev_score', summary='max')\n\n    logger.info(\"Training tagger...\")\n    foundation_cache = FoundationCache()\n    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], foundation_cache=foundation_cache)\n\n    global_step = 0\n    max_steps = args['max_steps']\n    dev_score_history = []\n    best_dev_preds = []\n    current_lr = args['lr']\n    global_start_time = time.time()\n    format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'\n\n    logger.debug(\"Training model on device %s\", next(trainer.model.parameters()).device)\n\n    if 
args['adapt_eval_interval']:\n        args['eval_interval'] = utils.get_adaptive_eval_interval(dev_data.num_examples, 2000, args['eval_interval'])\n        logger.info(\"Evaluating the model every {} steps...\".format(args['eval_interval']))\n\n    if args['save_each']:\n        logger.info(\"Saving initial checkpoint to %s\" % (model_save_each_file % global_step))\n        trainer.save(model_save_each_file % global_step)\n\n    using_amsgrad = False\n    last_best_step = 0\n    # start training\n    train_loss = 0\n    if args['log_norms']:\n        trainer.model.log_norms()\n    while True:\n        do_break = False\n        for i, batch in enumerate(train_batches):\n            start_time = time.time()\n            global_step += 1\n            loss = trainer.update(batch, eval=False) # update step\n            train_loss += loss\n            if global_step % args['log_step'] == 0:\n                duration = time.time() - start_time\n                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))\n                if args['log_norms']:\n                    trainer.model.log_norms()\n\n            if global_step % args['eval_interval'] == 0:\n                # eval on dev\n                logger.info(\"Evaluating on dev set...\")\n                dev_preds = []\n                indices = []\n                for batch in dev_batch:\n                    preds = trainer.predict(batch)\n                    dev_preds += preds\n                    indices.extend(batch[-1])\n                dev_preds = utils.unsort(dev_preds, indices)\n                dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])\n\n                system_pred_file = \"{:C}\\n\\n\".format(dev_data.doc)\n                system_pred_file = io.StringIO(system_pred_file)\n\n                _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)\n\n                train_loss = train_loss / args['eval_interval'] # 
avg loss per batch\n                logger.info(\"step {}: train_loss = {:.6f}, dev_score = {:.4f}\".format(global_step, train_loss, dev_score))\n\n                if args['wandb']:\n                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})\n\n                train_loss = 0\n\n                if args['save_each']:\n                    logger.info(\"Saving checkpoint to %s\" % (model_save_each_file % global_step))\n                    trainer.save(model_save_each_file % global_step)\n\n                # save best model\n                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):\n                    last_best_step = global_step\n                    trainer.save(model_file)\n                    logger.info(\"new best model saved.\")\n                    best_dev_preds = dev_preds\n\n                dev_score_history += [dev_score]\n\n            if global_step - last_best_step >= args['max_steps_before_stop']:\n                if not using_amsgrad and args['second_optim'] is not None:\n                    logger.info(\"Switching to second optimizer: {}\".format(args['second_optim']))\n                    if args['second_optim_reload']:\n                        logger.info('Reloading best model to continue from current local optimum')\n                        trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain, model_file=model_file, device=args['device'], foundation_cache=foundation_cache)\n                    last_best_step = global_step\n                    using_amsgrad = True\n                    lr = args['second_lr']\n                    if lr is None:\n                        lr = args['lr']\n                    trainer.optimizer = utils.get_optimizer(args['second_optim'], trainer.model, lr=lr, betas=(.9, args['beta2']), eps=1e-6, weight_decay=args['second_weight_decay'])\n                else:\n                    logger.info(\"Early termination: have not improved in {} 
steps\".format(args['max_steps_before_stop']))\n                    do_break = True\n                    break\n\n            if global_step >= args['max_steps']:\n                do_break = True\n                break\n\n        if do_break: break\n\n    logger.info(\"Training ended with {} steps.\".format(global_step))\n\n    if args['wandb']:\n        wandb.finish()\n\n    if len(dev_score_history) > 0:\n        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1\n        logger.info(\"Best dev F1 = {:.2f}, at iteration = {}\".format(best_f, best_eval * args['eval_interval']))\n    else:\n        logger.info(\"Dev set never evaluated.  Saving final model.\")\n        trainer.save(model_file)\n\n    return trainer, _\n\ndef evaluate(args):\n    # file paths\n    model_file = model_file_name(args)\n\n    pretrain = load_pretrain(args)\n\n    load_args = {'charlm_forward_file': args.get('charlm_forward_file', None),\n                 'charlm_backward_file': args.get('charlm_backward_file', None)}\n\n    # load model\n    logger.info(\"Loading model from: {}\".format(model_file))\n    trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)\n    result_doc = evaluate_trainer(args, trainer, pretrain)\n    return trainer, result_doc\n\ndef evaluate_trainer(args, trainer, pretrain):\n    system_pred_file = args['output_file']\n    loaded_args, vocab = trainer.args, trainer.vocab\n\n    # load config\n    for k in args:\n        if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand'] or k == 'mode':\n            loaded_args[k] = args[k]\n\n    # load data\n    logger.info(\"Loading data with batch size {}...\".format(args['batch_size']))\n    doc = CoNLL.conll2doc(input_file=args['eval_file'])\n    dev_data = Dataset(doc, loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)\n    dev_batch = dev_data.to_loader(batch_size=args['batch_size'])\n    eval_type = 
get_eval_type(dev_data)\n    if len(dev_batch) > 0:\n        logger.info(\"Start evaluation...\")\n        preds = []\n        indices = []\n        with torch.no_grad():\n            for b in dev_batch:\n                preds += trainer.predict(b)\n                indices.extend(b[-1])\n    else:\n        # skip eval if dev data does not exist\n        # indices must still be defined for the unsort call below\n        preds = []\n        indices = []\n    preds = utils.unsort(preds, indices)\n\n    # write to file and score\n    dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])\n    if system_pred_file:\n        CoNLL.write_doc2conll(dev_data.doc, system_pred_file)\n\n    if args['gold_labels']:\n        system_pred_file = \"{:C}\\n\\n\".format(dev_data.doc)\n        system_pred_file = io.StringIO(system_pred_file)\n\n        _, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)\n\n        logger.info(\"POS Tagger score: %s %.2f\", args['shorthand'], score*100)\n\n    return dev_data.doc\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/tokenization/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/models/tokenization/data.py",
    "content": "from bisect import bisect_right\nfrom collections import defaultdict\nfrom copy import copy\nimport numpy as np\nimport random\nimport logging\nimport re\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom stanza.models.common.utils import sort_with_indices, unsort\nfrom stanza.models.tokenization.vocab import Vocab\n\nlogger = logging.getLogger('stanza')\n\ndef filter_consecutive_whitespaces(para):\n    filtered = []\n    for i, (char, label) in enumerate(para):\n        if i > 0:\n            if char == ' ' and para[i-1][0] == ' ':\n                continue\n\n        filtered.append((char, label))\n\n    return filtered\n\nNEWLINE_WHITESPACE_RE = re.compile(r'\\n\\s*\\n')\n# this was (r'^([\\d]+[,\\.]*)+$')\n# but the runtime on that can explode exponentially\n# for example, on 111111111111111111111111a\nNUMERIC_RE = re.compile(r'^[\\d]+([,\\.]+[\\d]+)*[,\\.]*$')\nWHITESPACE_RE = re.compile(r'\\s')\n\nclass TokenizationDataset:\n    def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):\n        super().__init__(*args, **kwargs)  # forwards all unused arguments\n        self.args = tokenizer_args\n        self.eval = evaluation\n        self.dictionary = dictionary\n        self.vocab = vocab\n\n        # get input files\n        txt_file = input_files['txt']\n        label_file = input_files['label']\n\n        # Load data and process it\n        # set up text from file or input string\n        assert txt_file is not None or input_text is not None\n        if input_text is None:\n            with open(txt_file, encoding=\"utf-8\") as f:\n                text = ''.join(f.readlines()).rstrip()\n        else:\n            text = input_text\n\n        text_chunks = NEWLINE_WHITESPACE_RE.split(text)\n        text_chunks = [pt.rstrip() for pt in text_chunks]\n        text_chunks = [pt for pt in text_chunks if pt]\n        if label_file is 
not None:\n            with open(label_file, encoding=\"utf-8\") as f:\n                labels = ''.join(f.readlines()).rstrip()\n                labels = NEWLINE_WHITESPACE_RE.split(labels)\n                labels = [pt.rstrip() for pt in labels]\n                labels = [map(int, pt) for pt in labels if pt]\n        else:\n            labels = [[0 for _ in pt] for pt in text_chunks]\n\n        skip_newline = self.args.get('skip_newline', False)\n        self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces\n                      for char, label in zip(pt, pc) if not (skip_newline and char == '\\n')] # check if newline needs to be eaten\n                     for pt, pc in zip(text_chunks, labels)]\n\n        # remove consecutive whitespaces\n        self.data = [filter_consecutive_whitespaces(x) for x in self.data]\n\n    def labels(self):\n        \"\"\"\n        Returns a list of the labels for all of the sentences in this DataLoader\n\n        Used at eval time to compare to the results, for example\n        \"\"\"\n        return [np.array(list(x[1] for x in sent)) for sent in self.data]\n\n    def extract_dict_feat(self, para, idx):\n        \"\"\"\n        This function is to extract dictionary features for each character\n        \"\"\"\n        length = len(para)\n\n        dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]\n        dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]\n        forward_word = para[idx][0]\n        backward_word = para[idx][0]\n        prefix = True\n        suffix = True\n        for window in range(1,self.args['num_dict_feat']+1):\n            # concatenate each character and check if words found in dict not, stop if prefix not found\n            #check if idx+t is out of bound and if the prefix is already not found\n            if (idx + window) <= length-1 and prefix:\n                forward_word += para[idx+window][0].lower()\n                #check in 
json file if the word is present as prefix or word or None.\n                feat = 1 if forward_word in self.dictionary[\"words\"] else 0\n                #feat is 1 if the candidate is a complete word in the dictionary, 0 otherwise\n                dict_forward_feats[window-1] = feat\n                #if the candidate is not a known prefix, no longer word can match, so stop looking forward.\n                if forward_word not in self.dictionary[\"prefixes\"]:\n                    prefix = False\n            #backward check: similar to forward\n            if (idx - window) >= 0 and suffix:\n                backward_word = para[idx-window][0].lower() + backward_word\n                feat = 1 if backward_word in self.dictionary[\"words\"] else 0\n                dict_backward_feats[window-1] = feat\n                if backward_word not in self.dictionary[\"suffixes\"]:\n                    suffix = False\n            #if neither a prefix nor a suffix can still match, exit the loop\n            if not prefix and not suffix:\n                break\n\n        return dict_forward_feats + dict_backward_feats\n\n    def para_to_sentences(self, para):\n        \"\"\" Convert a paragraph to a list of processed sentences. 
\"\"\"\n        res = []\n        funcs = []\n        for feat_func in self.args['feat_funcs']:\n            if feat_func == 'end_of_para' or feat_func == 'start_of_para':\n                # skip for position-dependent features\n                continue\n            if feat_func == 'space_before':\n                func = lambda x: 1 if x.startswith(' ') else 0\n            elif feat_func == 'capitalized':\n                func = lambda x: 1 if x[0].isupper() else 0\n            elif feat_func == 'numeric':\n                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0\n            else:\n                raise ValueError('Feature function \"{}\" is undefined.'.format(feat_func))\n\n            funcs.append(func)\n\n        # stacking all featurize functions\n        composite_func = lambda x: [f(x) for f in funcs]\n\n        def process_sentence(sent_units, sent_labels, sent_feats):\n            return (np.array([self.vocab.unit2id(y) for y in sent_units]),\n                    np.array(sent_labels),\n                    np.array(sent_feats),\n                    list(sent_units))\n\n        use_end_of_para = 'end_of_para' in self.args['feat_funcs']\n        use_start_of_para = 'start_of_para' in self.args['feat_funcs']\n        use_dictionary = self.args['use_dictionary']\n        current_units = []\n        current_labels = []\n        current_feats = []\n        for i, (unit, label) in enumerate(para):\n            feats = composite_func(unit)\n            # position-dependent features\n            if use_end_of_para:\n                f = 1 if i == len(para)-1 else 0\n                feats.append(f)\n            if use_start_of_para:\n                f = 1 if i == 0 else 0\n                feats.append(f)\n\n            #if dictionary feature is selected\n            if use_dictionary:\n                dict_feats = self.extract_dict_feat(para, i)\n                feats = feats + dict_feats\n\n            current_units.append(unit)\n            
current_labels.append(label)\n            current_feats.append(feats)\n            if not self.eval and (label == 2 or label == 4): # end of sentence\n                if len(current_units) <= self.args['max_seqlen']:\n                    # get rid of sentences that are too long during training of the tokenizer\n                    res.append(process_sentence(current_units, current_labels, current_feats))\n                current_units.clear()\n                current_labels.clear()\n                current_feats.clear()\n\n        if len(current_units) > 0:\n            if self.eval or len(current_units) <= self.args['max_seqlen']:\n                res.append(process_sentence(current_units, current_labels, current_feats))\n\n        return res\n\n    def advance_old_batch(self, eval_offsets, old_batch):\n        \"\"\"\n        Advance to a new position in a batch where we have partially processed the batch\n\n        If we have previously built a batch of data and made predictions on them, then when we are trying to make\n        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch\n        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.\n        In this case, eval_offsets index within the old_batch to advance the strings to process.\n        \"\"\"\n        unkid = self.vocab.unit2id('<UNK>')\n        padid = self.vocab.unit2id('<PAD>')\n\n        ounits, olabels, ofeatures, oraw = old_batch\n        feat_size = ofeatures.shape[-1]\n        lens = (ounits != padid).sum(1).tolist()\n        pad_len = max(l-i for i, l in zip(eval_offsets, lens))\n\n        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)\n        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)\n        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)\n        raw_units = []\n\n        for i in range(len(ounits)):\n        
    eval_offsets[i] = min(eval_offsets[i], lens[i])\n            units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]\n            labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]\n            features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]\n            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))\n\n        return units, labels, features, raw_units\n\ndef build_move_punct_set(data, move_back_prob):\n    move_punct = {',', ':', '!', '.', '?', '\"', '(', ')'}\n    for chunk in data:\n        # ignore positions at the start and end of a chunk\n        for idx in range(1, len(chunk)-1):\n            if chunk[idx][0] not in move_punct:\n                continue\n            if chunk[idx][1] == 0:\n                if chunk[idx+1][0].isspace() and not chunk[idx-1][0].isdigit():\n                    # this check removes punct which isn't ending a word...\n                    # honestly that's a rather unusual situation\n                    # VI has |3, 5| as a complete token\n                    # so we also eliminate isdigit()\n                    move_punct.remove(chunk[idx][0])\n                continue\n            # we skip isdigit() because we will intentionally not\n            # create things that look like decimal numbers\n            if not chunk[idx-1][0].isspace() and chunk[idx-1][0] not in move_punct and not chunk[idx-1][0].isdigit():\n                # this check eliminates things like '.' 
after 'Mr.'\n                move_punct.remove(chunk[idx][0])\n                continue\n    return move_punct\n\ndef build_known_mwt(data, mwt_expansions):\n    known_mwts = set()\n    for chunk in data:\n        for idx, unit in enumerate(chunk):\n            if unit[1] != 3:\n                continue\n            # found an MWT\n            prev_idx = idx - 1\n            while prev_idx >= 0 and chunk[prev_idx][1] == 0:\n                prev_idx -= 1\n            prev_idx += 1\n            while chunk[prev_idx][0].isspace():\n                prev_idx += 1\n            if prev_idx == idx:\n                continue\n            mwt = \"\".join(x[0] for x in chunk[prev_idx:idx+1])\n            if mwt not in mwt_expansions:\n                continue\n            if len(mwt_expansions[mwt]) > 2:\n                # TODO: could split 3 word tokens as well\n                continue\n            known_mwts.add(mwt)\n    return known_mwts\n\nclass DataLoader(TokenizationDataset):\n    \"\"\"\n    This is the training version of the dataset.\n    \"\"\"\n    def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):\n        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)\n\n        self.vocab = vocab if vocab is not None else self.init_vocab()\n\n        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.\n        # At evaluation time, each paragraph is treated as single \"sentence\" as we don't know a priori where\n        # sentence breaks occur. 
We make prediction from left to right for each paragraph and move forward to\n        # the last predicted sentence break to start afresh.\n        self.sentences = [self.para_to_sentences(para) for para in self.data]\n\n        self.init_sent_ids()\n        logger.debug(f\"{len(self.sentence_ids)} sentences loaded.\")\n\n        punct_move_back_prob = args.get('punct_move_back_prob', 0.0)\n        if punct_move_back_prob > 0.0:\n            self.move_punct = build_move_punct_set(self.data, punct_move_back_prob)\n            if len(self.move_punct) > 0:\n                logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct))\n            else:\n                logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')\n\n        split_mwt_prob = args.get('split_mwt_prob', 0.0)\n        if split_mwt_prob > 0.0 and not evaluation:\n            self.mwt_expansions = mwt_expansions\n            self.known_mwt = build_known_mwt(self.data, mwt_expansions)\n            if len(self.known_mwt) > 0:\n                logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))\n            else:\n                logger.debug('Based on the training data, there are NO MWT to split at training time')\n\n        augment_final_punct_prob = 0.0 if evaluation else args.get('augment_final_punct_prob', 0.0)\n        if augment_final_punct_prob > 0:\n            self.augmentations = defaultdict(list)\n            AUGMENT_PAIRS = [(\"?\", \"？\"),\n                             (\"?\", \"︖\"),\n                             (\"?\", \"﹖\"),\n                             (\"?\", \"⁇\"),\n                             (\"!\", \"！\"),\n                             (\"!\", \"︕\"),\n                             (\"!\", \"﹗\"),\n                             (\"!\", \"‼\"),]\n            for orig, target in AUGMENT_PAIRS:\n                if 
self.augment_vocab(self.vocab, self.data, orig, target):\n                    logger.debug('Based on the training data, augmenting |%s| to |%s|' % (orig, target))\n                    self.augmentations[orig].append(target)\n                if self.augment_vocab(self.vocab, self.data, target, orig):\n                    logger.debug('Based on the training data, augmenting |%s| to |%s|' % (target, orig))\n                    self.augmentations[target].append(orig)\n\n    def __len__(self):\n        return len(self.sentence_ids)\n\n    def init_vocab(self):\n        vocab = Vocab(self.data, self.args['lang'])\n        return vocab\n\n    @staticmethod\n    def augment_vocab(vocab, data, existing_unit, new_unit):\n        if existing_unit not in vocab:\n            return False\n        new_unit_count = 0\n        existing_unit_count = 0\n        for sentence in data:\n            unit = sentence[-1][0]\n            if unit == new_unit:\n                new_unit_count += 1\n            elif unit == existing_unit:\n                existing_unit_count += 1\n        if existing_unit_count == 0:\n            return False\n        if new_unit_count > 0:\n            return False\n        if new_unit not in vocab:\n            vocab.append(new_unit)\n        logger.debug(\"Found %d |%s| and %d |%s|\", new_unit_count, new_unit, existing_unit_count, existing_unit)\n        return True\n\n    def init_sent_ids(self):\n        self.sentence_ids = []\n        self.cumlen = [0]\n        for i, para in enumerate(self.sentences):\n            for j in range(len(para)):\n                self.sentence_ids += [(i, j)]\n                self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]\n\n    def has_mwt(self):\n        # presumably this only needs to be called either 0 or 1 times,\n        # 1 when training and 0 any other time, so no effort is put\n        # into caching the result\n        for sentence in self.data:\n            for word in sentence:\n                if 
word[1] > 2:\n                    return True\n        return False\n\n    def shuffle(self):\n        for para in self.sentences:\n            random.shuffle(para)\n        self.init_sent_ids()\n\n    def move_last_char(self, sentence):\n        if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen'] and sentence[1][-1] == 2 and sentence[1][-2] != 0:\n            new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])]\n            new_units.extend([(' ', 0), (sentence[3][-1], int(sentence[1][-1]))])\n            encoded = self.para_to_sentences(new_units)\n            return encoded\n        return None\n\n    def split_mwt(self, sentence):\n        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:\n            return None\n\n        # if we find a token in the sentence which ends with label 3,\n        # eg it is an MWT,\n        # with some probability we split it into two tokens\n        # and treat the split tokens as both label 1 instead of 3\n        # in this manner, we teach the tokenizer not to treat the\n        # entire sequence of characters with added spaces as an MWT,\n        # which weirdly can happen in some corner cases\n\n        mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3]\n        if len(mwt_ends) == 0:\n            return None\n        random_end = random.randint(0, len(mwt_ends)-1)\n        mwt_end = mwt_ends[random_end]\n        mwt_start = mwt_end - 1\n        while mwt_start >= 0 and sentence[1][mwt_start] == 0:\n            mwt_start -= 1\n        mwt_start += 1\n        while sentence[3][mwt_start].isspace():\n            mwt_start += 1\n        if mwt_start == mwt_end:\n            return None\n        mwt = \"\".join(x for x in sentence[3][mwt_start:mwt_end+1])\n        if mwt not in self.mwt_expansions:\n            return None\n\n        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]\n        w0_units = [(x, 0) for x in 
self.mwt_expansions[mwt][0]]\n        w0_units[-1] = (w0_units[-1][0], 1)\n        w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]]\n        w1_units[-1] = (w1_units[-1][0], 1)\n        split_units = w0_units + [(' ', 0)] + w1_units\n        new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:]\n        encoded = self.para_to_sentences(new_units)\n        return encoded\n\n    def move_punct_back(self, sentence):\n        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:\n            return None\n\n        # check that we are not accidentally creating decimal numbers\n        #   idx == 1 or not sentence[3][idx-2].isdigit()\n        # one disadvantage of checking for sentence[1][idx] == 0\n        #   would be that tokens of all punct, such as '...',\n        #   should move but would not move if this is eliminated\n        commas = [idx for idx, c in enumerate(sentence[3])\n                  if c in self.move_punct and idx > 0 and sentence[3][idx-1].isspace() and (idx == 1 or not sentence[3][idx-2].isdigit())]\n        if len(commas) == 0:\n            return None\n\n        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]\n        new_units = []\n\n        span_start = 0\n        for span_end in commas:\n            new_units.extend(all_units[span_start:span_end-1])\n            span_start = span_end\n        if span_end < len(sentence[3]):\n            new_units.extend(all_units[span_end:])\n\n        encoded = self.para_to_sentences(new_units)\n        return encoded\n\n    def augment_final_punct(self, sentence):\n        if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen']:\n            if sentence[3][-1] in self.augmentations:\n                augmented = random.choice(self.augmentations[sentence[3][-1]])\n                new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])]\n                new_units.append((augmented, sentence[1][-1]))\n            
else:\n                return None\n            encoded = self.para_to_sentences(new_units)\n            return encoded\n        return None\n\n\n    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):\n        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. '''\n        feat_size = len(self.sentences[0][0][2][0])\n        unkid = self.vocab.unit2id('<UNK>')\n        padid = self.vocab.unit2id('<PAD>')\n\n        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):\n            # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed \n            # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences\n            # from the entire dataset until we reach max_seqlen.\n            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))\n            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))\n            move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)\n            move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)\n            split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)\n            augment_final_punct_prob = 0.0 if self.eval else self.args.get('augment_final_punct_prob', 0.0)\n\n            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)\n            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]\n            total_len = len(sentences[0][0])\n\n            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! 
{}'.format(self.args['max_seqlen'], total_len, ' '.join([\"{}/{}\".format(*x) for x in zip(self.sentences[pid][sid])]))\n            if self.eval:\n                for sid1 in range(sid+1, len(self.sentences[pid])):\n                    total_len += len(self.sentences[pid][sid1][0])\n                    sentences.append(self.sentences[pid][sid1])\n\n                    if total_len >= self.args['max_seqlen']:\n                        break\n            else:\n                while True:\n                    pid1, sid1 = random.choice(self.sentence_ids)\n                    total_len += len(self.sentences[pid1][sid1][0])\n                    sentences.append(self.sentences[pid1][sid1])\n\n                    if total_len >= self.args['max_seqlen']:\n                        break\n\n            if move_last_char_prob > 0.0:\n                for sentence_idx, sentence in enumerate(sentences):\n                    if random.random() < move_last_char_prob:\n                        # the sentence might not be eligible, such as\n                        # already having a space or not having a sentence final punct,\n                        # so we need to do a two step checking process here\n                        new_sentence = self.move_last_char(sentence)\n                        if new_sentence is not None:\n                            sentences[sentence_idx] = new_sentence[0]\n                            total_len += 1\n\n            if move_punct_back_prob > 0.0:\n                for sentence_idx, sentence in enumerate(sentences):\n                    if random.random() < move_punct_back_prob:\n                        # the sentence might not be eligible, such as\n                        # not having a space separated punct,\n                        # so we need to do a two step checking process here\n                        new_sentence = self.move_punct_back(sentence)\n                        if new_sentence is not None:\n                            total_len = 
total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])\n                            sentences[sentence_idx] = new_sentence[0]\n\n            if split_mwt_prob > 0.0:\n                for sentence_idx, sentence in enumerate(sentences):\n                    if random.random() < split_mwt_prob:\n                        new_sentence = self.split_mwt(sentence)\n                        if new_sentence is not None:\n                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])\n                            sentences[sentence_idx] = new_sentence[0]\n\n            if augment_final_punct_prob > 0.0:\n                for sentence_idx, sentence in enumerate(sentences):\n                    if random.random() < augment_final_punct_prob:\n                        new_sentence = self.augment_final_punct(sentence)\n                        if new_sentence is not None:\n                            sentences[sentence_idx] = new_sentence[0]\n\n            if drop_sents and len(sentences) > 1:\n                if total_len > self.args['max_seqlen']:\n                    sentences = sentences[:-1]\n                if len(sentences) > 1:\n                    p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability\n                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]\n                    sentences = sentences[:cutoff+1]\n\n            units = np.concatenate([s[0] for s in sentences])\n            labels = np.concatenate([s[1] for s in sentences])\n            feats = np.concatenate([s[2] for s in sentences])\n            raw_units = [x for s in sentences for x in s[3]]\n\n            if not self.eval:\n                cutoff = self.args['max_seqlen']\n                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]\n\n            if drop_last_char:  # can only happen in non-eval 
mode\n                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):\n                    # training text ended with a sentence end position\n                    # and that word was a single character\n                    # and the previous character ended the word\n                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]\n                    # word end -> sentence end, mwt end -> sentence mwt end\n                    labels[-1] = labels[-1] + 1\n\n            return units, labels, feats, raw_units\n\n        if eval_offsets is not None:\n            # find max padding length\n            pad_len = 0\n            for eval_offset in eval_offsets:\n                if eval_offset < self.cumlen[-1]:\n                    pair_id = bisect_right(self.cumlen, eval_offset) - 1\n                    pair = self.sentence_ids[pair_id]\n                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))\n\n            pad_len += 1\n            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]\n            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]\n            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]\n\n            offsets_pairs = list(zip(offsets, pairs))\n        else:\n            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))\n            offsets_pairs = [(0, x) for x in id_pairs]\n            pad_len = self.args['max_seqlen']\n\n        # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors\n        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)\n        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)\n        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)\n        raw_units = []\n        
for i, (offset, pair) in enumerate(offsets_pairs):\n            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)\n            units[i, :len(u_)] = u_\n            labels[i, :len(l_)] = l_\n            features[i, :len(f_), :] = f_\n            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))\n\n        if unit_dropout > 0 and not self.eval:\n            # dropout characters/units at training time and replace them with UNKs\n            mask = np.random.random_sample(units.shape) < unit_dropout\n            mask[units == padid] = 0\n            units[mask] = unkid\n            for i in range(len(raw_units)):\n                for j in range(len(raw_units[i])):\n                    if mask[i, j]:\n                        raw_units[i][j] = '<UNK>'\n\n        # dropout unit feature vector in addition to only torch.dropout in the model.\n        # experiments showed that only torch.dropout hurts the model\n        # we believe it is because the dict feature vector is mostly sparse so it makes\n        # more sense to drop out the whole vector instead of only a single element.\n        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:\n            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout\n            mask_feat[units == padid] = 0\n            for i in range(len(raw_units)):\n                for j in range(len(raw_units[i])):\n                    if mask_feat[i,j]:\n                        features[i,j,:] = 0\n                        \n        units = torch.from_numpy(units)\n        labels = torch.from_numpy(labels)\n        features = torch.from_numpy(features)\n\n        return units, labels, features, raw_units\n\nclass SortedDataset(Dataset):\n    \"\"\"\n    Holds a TokenizationDataset for use in a torch DataLoader\n\n    The torch DataLoader is different from the DataLoader defined here\n    and allows for cpu & gpu parallelism.  
Updating output_predictions\n    to use this class as a wrapper to a TokenizationDataset means the\n    calculation of features can happen in parallel, saving quite a\n    bit of time.\n    \"\"\"\n    def __init__(self, dataset):\n        super().__init__()\n\n        self.dataset = dataset\n        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, index):\n        # This will return a single sample\n        #   np: index in character map\n        #   np: tokenization label\n        #   np: features\n        #   list: original text as one length strings\n        return self.dataset.para_to_sentences(self.data[index])\n\n    def unsort(self, arr):\n        return unsort(arr, self.indices)\n\n    def collate(self, samples):\n        if any(len(x) > 1 for x in samples):\n            raise ValueError(\"Expected all paragraphs to have no preset sentence splits!\")\n        feat_size = samples[0][0][2].shape[-1]\n        padid = self.dataset.vocab.unit2id('<PAD>')\n\n        # +1 so that all samples end with at least one pad\n        pad_len = max(len(x[0][3]) for x in samples) + 1\n\n        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)\n        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)\n        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)\n        raw_units = []\n        for i, sample in enumerate(samples):\n            u_, l_, f_, r_ = sample[0]\n            units[i, :len(u_)] = torch.from_numpy(u_)\n            labels[i, :len(l_)] = torch.from_numpy(l_)\n            features[i, :len(f_), :] = torch.from_numpy(f_)\n            raw_units.append(r_ + ['<PAD>'])\n\n        return units, labels, features, raw_units\n\n"
  },
  {
    "path": "stanza/models/tokenization/model.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence\n\nfrom stanza.models.common.char_model import CharacterLanguageModelWordAdapter\nfrom stanza.models.common.foundation_cache import load_charlm\n\nclass Tokenizer(nn.Module):\n    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None):\n        super().__init__()\n\n        self.unsaved_modules = []\n\n        self.args = args\n        feat_dim = args['feat_dim']\n\n        self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)\n\n        self.input_dim = emb_dim + feat_dim\n\n        charmodel = None\n        if args is not None and args.get('charlm_forward_file', None):\n            charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)\n            charmodels = nn.ModuleList([charmodel_forward])\n            charmodel = CharacterLanguageModelWordAdapter(charmodels)\n            self.input_dim += charmodel.hidden_dim()\n        self.add_unsaved_module(\"charmodel\", charmodel)\n\n        self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)\n\n        if self.args['conv_res'] is not None:\n            self.conv_res = nn.ModuleList()\n            self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]\n\n            for si, size in enumerate(self.conv_sizes):\n                l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))\n                self.conv_res.append(l)\n\n            if self.args.get('hier_conv_res', False):\n                self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)\n        self.tok_clf = nn.Linear(hidden_dim * 2, 1)\n        
self.sent_clf = nn.Linear(hidden_dim * 2, 1)\n        if self.args['use_mwt']:\n            self.mwt_clf = nn.Linear(hidden_dim * 2, 1)\n\n        if args['hierarchical']:\n            in_dim = hidden_dim * 2\n            self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)\n            self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)\n            self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)\n            if self.args['use_mwt']:\n                self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)\n\n        self.dropout = nn.Dropout(dropout)\n        self.dropout_feat = nn.Dropout(feat_dropout)\n\n        self.toknoise = nn.Dropout(self.args['tok_noise'])\n\n    def add_unsaved_module(self, name, module):\n        self.unsaved_modules += [name]\n        setattr(self, name, module)\n\n    def forward(self, x, feats, lengths, raw=None):\n        emb = self.embeddings(x)\n\n        if self.charmodel is not None and raw is not None:\n            char_emb = self.charmodel(raw, wrap=False)\n            emb = torch.cat([emb, char_emb], axis=2)\n\n        emb = self.dropout(emb)\n        feats = self.dropout_feat(feats)\n\n        emb = torch.cat([emb, feats], 2)\n        emb = pack_padded_sequence(emb, lengths, batch_first=True)\n        inp, _ = self.rnn(emb)\n        inp, _ = pad_packed_sequence(inp, batch_first=True)\n\n        if self.args['conv_res'] is not None:\n            conv_input = emb.transpose(1, 2).contiguous()\n            if not self.args.get('hier_conv_res', False):\n                for l in self.conv_res:\n                    inp = inp + l(conv_input).transpose(1, 2).contiguous()\n            else:\n                hid = []\n                for l in self.conv_res:\n                    hid += [l(conv_input)]\n                hid = torch.cat(hid, 1)\n                hid = F.relu(hid)\n                hid = self.dropout(hid)\n                inp = inp + self.conv_res2(hid).transpose(1, 
2).contiguous()\n\n        inp = self.dropout(inp)\n\n        tok0 = self.tok_clf(inp)\n        sent0 = self.sent_clf(inp)\n        if self.args['use_mwt']:\n            mwt0 = self.mwt_clf(inp)\n\n        if self.args['hierarchical']:\n            inp2 = inp\n            if self.args['hier_invtemp'] > 0:\n                inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp'])))\n            inp2 = pack_padded_sequence(inp2, lengths, batch_first=True)\n            inp2, _ = self.rnn2(inp2)\n            inp2, _ = pad_packed_sequence(inp2, batch_first=True)\n\n            inp2 = self.dropout(inp2)\n\n            tok0 = tok0 + self.tok_clf2(inp2)\n            sent0 = sent0 + self.sent_clf2(inp2)\n            if self.args['use_mwt']:\n                mwt0 = mwt0 + self.mwt_clf2(inp2)\n\n        nontok = F.logsigmoid(-tok0)\n        tok = F.logsigmoid(tok0)\n        nonsent = F.logsigmoid(-sent0)\n        sent = F.logsigmoid(sent0)\n        if self.args['use_mwt']:\n            nonmwt = F.logsigmoid(-mwt0)\n            mwt = F.logsigmoid(mwt0)\n            pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)\n        else:\n            pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)\n\n        return pred\n"
  },
  {
    "path": "stanza/models/tokenization/tokenize_files.py",
    "content": "\"\"\"Use a Stanza tokenizer to turn a text file into one tokenized paragraph per line\n\nFor example, the output of this script is suitable for Glove\n\nCurrently this *only* supports tokenization, no MWT splitting.\nIt also would be beneficial to have an option to convert spaces into\nNBSP, underscore, or some other marker to make it easier to process\nlanguages such as VI which have spaces in them\n\"\"\"\n\n\nimport argparse\nimport io\nimport os\nimport time\nimport re\nimport zipfile\n\nimport torch\n\nimport stanza\nfrom stanza.models.common.utils import open_read_text, default_device\nfrom stanza.models.tokenization.data import TokenizationDataset\nfrom stanza.models.tokenization.utils import output_predictions\nfrom stanza.pipeline.tokenize_processor import TokenizeProcessor\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nNEWLINE_SPLIT_RE = re.compile(r\"\\n\\s*\\n\")\n\ndef tokenize_to_file(tokenizer, fin, fout, chunk_size=500):\n    raw_text = fin.read()\n    documents = NEWLINE_SPLIT_RE.split(raw_text)\n    for chunk_start in tqdm(range(0, len(documents), chunk_size), leave=False):\n        chunk_end = min(chunk_start + chunk_size, len(documents))\n        chunk = documents[chunk_start:chunk_end]\n        in_docs = [stanza.Document([], text=d) for d in chunk]\n        out_docs = tokenizer.bulk_process(in_docs)\n        for document in out_docs:\n            for sent_idx, sentence in enumerate(document.sentences):\n                if sent_idx > 0:\n                    fout.write(\" \")\n                fout.write(\" \".join(x.text for x in sentence.tokens))\n            fout.write(\"\\n\")\n\ndef main(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--lang\", type=str, default=\"sd\", help=\"Which language to use for tokenization\")\n    parser.add_argument(\"--tokenize_model_path\", type=str, default=None, help=\"Specific tokenizer model to use\")\n    parser.add_argument(\"input_files\", 
type=str, nargs=\"+\", help=\"Which input files to tokenize\")\n    parser.add_argument(\"--output_file\", type=str, default=\"glove.txt\", help=\"Where to write the tokenized output\")\n    parser.add_argument(\"--model_dir\", type=str, default=None, help=\"Where to get models for a Pipeline (None => default models dir)\")\n    parser.add_argument(\"--chunk_size\", type=int, default=500, help=\"How many 'documents' to use in a chunk when tokenizing.  This is separate from the tokenizer batching - this limits how much memory gets used at once, since we don't need to store an entire file in memory at once\")\n    args = parser.parse_args(args=args)\n\n    if os.path.exists(args.output_file):\n        print(\"Cowardly refusing to overwrite existing output file %s\" % args.output_file)\n        return\n\n    if args.tokenize_model_path:\n        config = { \"model_path\": args.tokenize_model_path,\n                   \"check_requirements\": False }\n        tokenizer = TokenizeProcessor(config, pipeline=None, device=default_device())\n    else:\n        pipe = stanza.Pipeline(lang=args.lang, processors=\"tokenize\", model_dir=args.model_dir)\n        tokenizer = pipe.processors[\"tokenize\"]\n\n    with open(args.output_file, \"w\", encoding=\"utf-8\") as fout:\n        for filename in tqdm(args.input_files):\n            if filename.endswith(\".zip\"):\n                with zipfile.ZipFile(filename) as zin:\n                    input_names = zin.namelist()\n                    for input_name in tqdm(input_names, leave=False):\n                        # open each member of the archive in turn, not just the first one\n                        with zin.open(input_name) as fin:\n                            fin = io.TextIOWrapper(fin, encoding='utf-8')\n                            tokenize_to_file(tokenizer, fin, fout, args.chunk_size)\n            else:\n                with open_read_text(filename, encoding=\"utf-8\") as fin:\n                    tokenize_to_file(tokenizer, fin, fout, args.chunk_size)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/tokenization/trainer.py",
    "content": "import sys\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom stanza.models.common import utils\nfrom stanza.models.common.trainer import Trainer as BaseTrainer\nfrom stanza.models.tokenization.utils import create_dictionary\n\nfrom .model import Tokenizer\nfrom .vocab import Vocab\n\nlogger = logging.getLogger('stanza')\n\nclass Trainer(BaseTrainer):\n    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None):\n        # TODO: make a test of the training w/ and w/o charlm\n        if model_file is not None:\n            # load everything from file\n            self.load(model_file, args, foundation_cache)\n        else:\n            # build model from scratch\n            self.args = args\n            self.vocab = vocab\n            self.lexicon = list(lexicon) if lexicon is not None else None\n            self.dictionary = dictionary\n            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])\n        self.model = self.model.to(device)\n        self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)\n        self.optimizer = utils.get_optimizer(\"adam\", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])\n        self.feat_funcs = self.args.get('feat_funcs', None)\n        self.lang = self.args['lang'] # language determines how token normalization is done\n\n    def update(self, inputs):\n        self.model.train()\n        units, labels, features, text = inputs\n        lengths = [len(x) for x in text]\n\n        device = next(self.model.parameters()).device\n        units = units.to(device)\n        labels = labels.to(device)\n        features = features.to(device)\n\n        pred = self.model(units, features, lengths, text)\n\n        self.optimizer.zero_grad()\n  
      classes = pred.size(2)\n        loss = self.criterion(pred.view(-1, classes), labels.view(-1))\n\n        loss.backward()\n        nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])\n        self.optimizer.step()\n\n        return loss.item()\n\n    def predict(self, inputs):\n        self.model.eval()\n        units, _, features, text = inputs\n        lengths = [len(x) for x in text]\n\n        device = next(self.model.parameters()).device\n        units = units.to(device)\n        features = features.to(device)\n\n        pred = self.model(units, features, lengths, text)\n\n        return pred.data.cpu().numpy()\n\n    def save(self, filename, skip_modules=True):\n        model_state = None\n        if self.model is not None:\n            model_state = self.model.state_dict()\n            # skip saving modules like the pretrained charlm\n            if skip_modules:\n                skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]\n                for k in skipped:\n                    del model_state[k]\n\n        params = {\n            'model': model_state,\n            'vocab': self.vocab.state_dict(),\n            # save and load lexicon as list instead of set so\n            # we can use weights_only=True\n            'lexicon': list(self.lexicon) if self.lexicon is not None else None,\n            'config': self.args\n        }\n        try:\n            torch.save(params, filename, _use_new_zipfile_serialization=False)\n            logger.info(\"Model saved to {}\".format(filename))\n        except BaseException:\n            logger.warning(\"Saving failed... 
continuing anyway.\")\n\n    def load(self, filename, args, foundation_cache):\n        try:\n            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        except BaseException:\n            logger.error(\"Cannot load model from {}\".format(filename))\n            raise\n        self.args = checkpoint['config']\n        if args is not None and args.get('charlm_forward_file', None) is not None:\n            if checkpoint['config'].get('charlm_forward_file') is None:\n                # if the saved model didn't use a charlm, we skip the charlm here\n                # otherwise the loaded model weights won't fit in the newly created model\n                self.args['charlm_forward_file'] = None\n            else:\n                self.args['charlm_forward_file'] = args['charlm_forward_file']\n        if self.args.get('use_mwt', None) is None:\n            # Default to True as many currently saved models\n            # were built with mwt layers\n            self.args['use_mwt'] = True\n        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])\n        self.model.load_state_dict(checkpoint['model'], strict=False)\n        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])\n        self.lexicon = checkpoint['lexicon']\n\n        if self.lexicon is not None:\n            self.lexicon = set(self.lexicon)\n            self.dictionary = create_dictionary(self.lexicon)\n        else:\n            self.dictionary = None\n"
  },
  {
    "path": "stanza/models/tokenization/utils.py",
    "content": "from collections import Counter\nfrom copy import copy\nimport json\nimport numpy as np\nimport re\nimport logging\nimport os\n\nfrom torch.utils.data import DataLoader as TorchDataLoader\n\nimport stanza.utils.default_paths as default_paths\nfrom stanza.models.common.utils import ud_scores, harmonic_mean\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.doc import *\nfrom stanza.models.tokenization.data import SortedDataset\n\nlogger = logging.getLogger('stanza')\npaths = default_paths.get_default_paths()\n\ndef create_dictionary(lexicon):\n    \"\"\"\n    This function is to create a new dictionary used for improving tokenization model for multi-syllable words languages\n    such as vi, zh or th. This function takes the lexicon as input and output a dictionary that contains three set:\n    words, prefixes and suffixes where prefixes set should contains all the prefixes in the lexicon and similar for suffixes.\n    The point of having prefixes/suffixes sets in the  dictionary is just to make it easier to check during data preparation.\n\n    :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp\n    :param lexicon - set of words used to create dictionary\n    :return a dictionary object that contains words and their prefixes and suffixes.\n    \"\"\"\n    \n    dictionary = {\"words\":set(), \"prefixes\":set(), \"suffixes\":set()}\n    \n    def add_word(word):\n        if word not in dictionary[\"words\"]:\n            dictionary[\"words\"].add(word)\n            prefix = \"\"\n            suffix = \"\"\n            for i in range(0,len(word)-1):\n                prefix = prefix + word[i]\n                suffix = word[len(word) - i - 1] + suffix\n                dictionary[\"prefixes\"].add(prefix)\n                dictionary[\"suffixes\"].add(suffix)\n\n    for word in lexicon:\n        if len(word)>1:\n            add_word(word)\n\n    return dictionary\n\ndef 
create_lexicon(shorthand=None, train_path=None, external_path=None):\n    \"\"\"\n    Create a lexicon of all the words from the training set and the external dictionary.\n    The lexicon is saved with the model and is used to rebuild the dictionary when the model is loaded.\n    Keeping the lexicon and the dictionary as two separate phases is a tradeoff between time and space.\n    Note that the longest words are dropped: only words no longer than the 95th-percentile word length\n    are kept in the lexicon.\n\n    :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp\n    :param train_path - path to conllu train file\n    :param external_path - path to external dict, expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt\n    :return a tuple of the lexicon (a set of all distinct words) and the number of dictionary features\n    \"\"\"\n    lexicon = set()\n    length_freq = []\n    # this regex checks whether a character is an actual Thai character, as the .isalpha() python method doesn't seem to pick up Thai accent characters\n    pattern_thai = re.compile(r\"(?:[^\\d\\W]+)|\\s\")\n    \n    def check_valid_word(shorthand, word):\n        \"\"\"\n        Check that the word is a multi-syllable word and not a number. 
\n        For vi, whitespaces are syllabe-separator.\n        \"\"\"\n        if shorthand.startswith(\"vi_\"):\n            return True if len(word.split(\" \")) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False\n        elif shorthand.startswith(\"th_\"):\n            return True if len(word) > 1 and any(map(pattern_thai.match, word)) and not any(map(str.isdigit, word)) else False\n        else:\n            return True if len(word) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False\n\n    #checking for words in the training set to add them to lexicon.\n    if train_path is not None:\n        if not os.path.isfile(train_path):\n            raise FileNotFoundError(f\"Cannot open train set at {train_path}\")\n\n        train_doc = CoNLL.conll2doc(input_file=train_path)\n\n        for train_sent in train_doc.sentences:\n            train_words = [x.text for x in train_sent.tokens if x.is_mwt()] + [x.text for x in train_sent.words]\n            for word in train_words:\n                word = word.lower()\n                if check_valid_word(shorthand, word) and word not in lexicon:\n                    lexicon.add(word)\n                    length_freq.append(len(word))\n        count_word = len(lexicon)\n        logger.info(f\"Added {count_word} words from the training data to the lexicon.\")\n\n    #checking for external dictionary and add them to lexicon.\n    if external_path is not None:\n        if not os.path.isfile(external_path):\n            raise FileNotFoundError(f\"Cannot open external dictionary at {external_path}\")\n\n        with open(external_path, \"r\", encoding=\"utf-8\") as external_file:\n            lines = external_file.readlines()\n        for line in lines:\n            word = line.lower()\n            word = word.replace(\"\\n\",\"\")\n            if check_valid_word(shorthand, word) and word not in lexicon:\n                lexicon.add(word)\n                
length_freq.append(len(word))\n        logger.info(f\"Added another {len(lexicon) - count_word} words from the external dict to dictionary.\")\n        \n\n    #automatically calculate the number of dictionary features (window size to look for words) based on the frequency of word length\n    #take the length at 95-percentile to eliminate all the longest (maybe) compounds words in the lexicon\n    num_dict_feat = int(np.percentile(length_freq, 95))\n    lexicon = {word for word in lexicon if len(word) <= num_dict_feat }\n    logger.info(f\"Final lexicon consists of {len(lexicon)} words after getting rid of long words.\")\n\n    return lexicon, num_dict_feat\n\ndef load_lexicon(args):\n    \"\"\"\n    This function is to create a new dictionary and load it to training.\n    The external dictionary is expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt\n    For example, vi_vlsp-externaldict.txt\n    \"\"\"\n    shorthand = args[\"shorthand\"]\n    tokenize_dir = paths[\"TOKENIZE_DATA_DIR\"]\n    train_path = f\"{tokenize_dir}/{shorthand}.train.gold.conllu\"\n    external_dict_path = f\"{tokenize_dir}/{shorthand}-externaldict.txt\"\n    if not os.path.exists(external_dict_path):\n        logger.info(f\"External dictionary not found! 
Looked in {external_dict_path}  Checking training data...\")\n        external_dict_path = None\n    if not os.path.exists(train_path):\n        logger.info(f\"Training dataset does not exist, thus cannot create dictionary {shorthand}\")\n        train_path = None\n    if train_path is None and external_dict_path is None:\n        raise FileNotFoundError(f\"Cannot find training set / external dictionary at {train_path} and {external_dict_path}\")\n\n    return create_lexicon(shorthand, train_path, external_dict_path)\n\n\ndef load_mwt_dict(filename):\n    \"\"\"\n    Returns a dict from an MWT to its most common expansion and count.\n\n    Other less common expansions are discarded.\n    \"\"\"\n    if filename is None:\n        return None\n\n    with open(filename, 'r') as f:\n        mwt_dict0 = json.load(f)\n\n    mwt_dict = dict()\n    for item in mwt_dict0:\n        (key, expansion), count = item\n\n        if key not in mwt_dict or mwt_dict[key][1] < count:\n            mwt_dict[key] = (expansion, count)\n\n    return mwt_dict\n\ndef process_sentence(sentence, mwt_dict=None):\n    sent = []\n    i = 0\n    for tok, p, position_info in sentence:\n        expansion = None\n        if (p == 3 or p == 4) and mwt_dict is not None:\n            # MWT found, (attempt to) expand it!\n            if tok in mwt_dict:\n                expansion = mwt_dict[tok][0]\n            elif tok.lower() in mwt_dict:\n                expansion = mwt_dict[tok.lower()][0]\n        if expansion is not None:\n            sent.append({ID: (i+1, i+len(expansion)), TEXT: tok})\n            if position_info is not None:\n                sent[-1][START_CHAR] = position_info[0]\n                sent[-1][END_CHAR] = position_info[1]\n            for etok in expansion:\n                sent.append({ID: (i+1, ), TEXT: etok})\n                i += 1\n        else:\n            if len(tok) <= 0:\n                continue\n            sent.append({ID: (i+1, ), TEXT: tok})\n            if 
position_info is not None:\n                sent[-1][START_CHAR] = position_info[0]\n                sent[-1][END_CHAR] = position_info[1]\n            if p == 3 or p == 4:# MARK\n                sent[-1][MISC] = 'MWT=Yes'\n            i += 1\n    return sent\n\n\n# https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression\nEMAIL_RAW_RE = r\"\"\"(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|\"(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x21\\x23-\\x5b\\x5d-\\x7f]|\\\\[\\x01-\\x09\\x0b\\x0c\\x0e-\\x7f])*\")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\\[(?:(?:(?:2(?:5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\\.){3}(?:(?:2(?:5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x21-\\x5a\\x53-\\x7f]|\\\\[\\x01-\\x09\\x0b\\x0c\\x0e-\\x7f])+)\\])\"\"\"\n\n# https://stackoverflow.com/questions/3809401/what-is-a-good-regular-expression-to-match-a-url\n# modification: disallow \" as opposed to all ^\\s\nURL_RAW_RE = r\"\"\"(?:https?:\\/\\/(?:www\\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\\.[^\\s\"]{2,}|www\\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\\.[^\\s\"]{2,}|https?:\\/\\/(?:www\\.|(?!www))[a-zA-Z0-9]+\\.[^\\s\"]{2,}|www\\.[a-zA-Z0-9]+\\.[^\\s\"]{2,})|[a-zA-Z0-9]+\\.(?:gov|org|edu|net|com|co)(?:\\.[^\\s\"]{2,})\"\"\"\n\nMASK_RE = re.compile(f\"(?:{EMAIL_RAW_RE}|{URL_RAW_RE})\")\n\ndef find_spans(raw):\n    \"\"\"\n    Return spans of text which don't contain <PAD> and are split by <PAD>\n    \"\"\"\n    pads = [idx for idx, char in enumerate(raw) if char == '<PAD>']\n    if len(pads) == 0:\n        spans = [(0, len(raw))]\n    else:\n        prev = 0\n        spans = []\n        for pad in pads:\n            if pad != prev:\n                spans.append( (prev, pad) )\n            prev = pad + 1\n        if prev < len(raw):\n            spans.append( (prev, len(raw)) )\n    return spans\n\ndef update_pred_regex(raw, pred):\n    
\"\"\"\n    Update the results of a tokenization batch by checking the raw text against a couple regular expressions\n\n    Currently, emails and urls are handled\n    TODO: this might work better as a constraint on the inference\n\n    for efficiency pred is modified in place\n    \"\"\"\n    spans = find_spans(raw)\n\n    for span_begin, span_end in spans:\n        text = \"\".join(raw[span_begin:span_end])\n        for match in MASK_RE.finditer(text):\n            match_begin, match_end = match.span()\n            # first, update all characters touched by the regex to not split\n            # with the exception of the last character...\n            for char in range(match_begin+span_begin, match_end+span_begin-1):\n                pred[char] = 0\n            # if the last character is not currently a split, make it a word split\n            if pred[match_end+span_begin-1] == 0:\n                pred[match_end+span_begin-1] = 1\n\n    return pred\n\nSPACE_RE = re.compile(r'\\s')\nSPACE_SPLIT_RE = re.compile(r'( *[^ ]+)')\n\ndef predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, num_workers):\n    \"\"\"\n    The guts of the prediction method\n\n    Calls trainer.predict() over and over until we have predictions for all of the text\n    \"\"\"\n    all_preds = []\n    all_raw = []\n\n    sorted_data = SortedDataset(data_generator)\n    dataloader = TorchDataLoader(sorted_data, batch_size=batch_size, collate_fn=sorted_data.collate, num_workers=num_workers)\n    for batch_idx, batch in enumerate(dataloader):\n        num_sentences = len(batch[3])\n        # being sorted by descending length, we need to use 0 as the longest sentence\n        N = len(batch[3][0])\n        for paragraph in batch[3]:\n            all_raw.append(list(paragraph))\n\n        if N <= max_seqlen:\n            pred = np.argmax(trainer.predict(batch), axis=2)\n        else:\n            # TODO: we could shortcircuit some processing of\n            # long strings of PAD 
by tracking which rows are finished\n            idx = [0] * num_sentences\n            adv = [0] * num_sentences\n            para_lengths = [x.index('<PAD>') for x in batch[3]]\n            pred = [[] for _ in range(num_sentences)]\n            while True:\n                ens = [min(N - idx1, max_seqlen) for idx1, N in zip(idx, para_lengths)]\n                en = max(ens)\n                batch1 = batch[0][:, :en], batch[1][:, :en], batch[2][:, :en], [x[:en] for x in batch[3]]\n                pred1 = np.argmax(trainer.predict(batch1), axis=2)\n\n                for j in range(num_sentences):\n                    sentbreaks = np.where((pred1[j] == 2) + (pred1[j] == 4))[0]\n                    if len(sentbreaks) <= 0 or idx[j] >= para_lengths[j] - max_seqlen:\n                        advance = ens[j]\n                    else:\n                        advance = np.max(sentbreaks) + 1\n\n                    pred[j] += [pred1[j, :advance]]\n                    idx[j] += advance\n                    adv[j] = advance\n\n                if all([idx1 >= N for idx1, N in zip(idx, para_lengths)]):\n                    break\n                # once we've made predictions on a certain number of characters for each paragraph (recorded in `adv`),\n                # we skip the first `adv` characters to make the updated batch\n                batch = data_generator.advance_old_batch(adv, batch)\n\n            pred = [np.concatenate(p, 0) for p in pred]\n\n        for par_idx in range(num_sentences):\n            offset = batch_idx * batch_size + par_idx\n\n            raw = all_raw[offset]\n            par_len = raw.index('<PAD>')\n            raw = raw[:par_len]\n            all_raw[offset] = raw\n            if pred[par_idx][par_len-1] < 2:\n                pred[par_idx][par_len-1] = 2\n            elif pred[par_idx][par_len-1] > 2:\n                pred[par_idx][par_len-1] = 4\n            if use_regex_tokens:\n                all_preds.append(update_pred_regex(raw, 
pred[par_idx][:par_len]))\n            else:\n                all_preds.append(pred[par_idx][:par_len])\n\n    all_preds = sorted_data.unsort(all_preds)\n    all_raw = sorted_data.unsort(all_raw)\n\n    return all_preds, all_raw\n\ndef output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, max_seqlen=1000, orig_text=None, no_ssplit=False, use_regex_tokens=True, num_workers=0, postprocessor=None):\n    batch_size = trainer.args['batch_size']\n    max_seqlen = max(1000, max_seqlen)\n\n    all_preds, all_raw = predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, num_workers)\n\n    use_la_ittb_shorthand = trainer.args['shorthand'] == 'la_ittb'\n    skip_newline = trainer.args['skip_newline']\n    oov_count, offset, doc = decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand)\n\n    # If we are provided a postprocessor, we prepare a list of pre-tokenized words and mwt flags and\n    # call the postprocessor for analysis.\n    if postprocessor:\n        doc = postprocess_doc(doc, postprocessor, orig_text, all_raw)\n\n    if output_file: CoNLL.dict2conll(doc, output_file)\n    return oov_count, offset, all_preds, doc\n\ndef postprocess_doc(doc, postprocessor, orig_text=None, all_raw=None):\n    \"\"\"Applies a postprocessor on the doc\n\n    all_raw is only needed to rebuild the raw text when orig_text is not given\n    \"\"\"\n\n    # get a list of all the words in the \"draft\" document to pass to the postprocessor\n    # the words array looks like [[\"words\", \"words\", \"words\"], [\"words\", (\"i_am_a_mwt\", True), \"I_am_not\"]]\n    # and the postprocessor is expected to return in the same format\n    words = [[((word[\"text\"], True)\n                if word.get(\"misc\") == \"MWT=Yes\"\n                else word[\"text\"]) for word in sentence]\n            for sentence in doc]\n    if not orig_text:\n        raw_text = \"\".join(\"\".join(i) for i in all_raw) # template to compare the stitched text against\n    else:\n        raw_text = orig_text\n\n    # perform correction 
with the postprocessor\n    postprocessor_return = postprocessor(words)\n\n    # collect the words and MWTs separately\n    corrected_words = []\n    corrected_mwts = []\n    corrected_expansions = []\n\n    # for each word, if its just a string (without the (\"word\", mwt_bool) format)\n    # we default that the word is not a MWT.\n    for sent in postprocessor_return:\n        sent_words = []\n        sent_mwts = []\n        sent_expansions = []\n        for word in sent:\n            if isinstance(word, str):\n                sent_words.append(word)\n                sent_mwts.append(False)\n                sent_expansions.append(None)\n            else:\n                if isinstance(word[1], bool):\n                    sent_words.append(word[0])\n                    sent_mwts.append(word[1])\n                    sent_expansions.append(None)\n                else:\n                    sent_words.append(word[0])\n                    sent_mwts.append(True)\n                    # expansions are marked in a space-separated list, which\n                    # `stanza.common.doc.set_mwt_expansions` reads and splits again\n                    # by splitting by spaces. Therefore, to serialize the users' supplied MWT\n                    # information, we join them by spaces to be split later by\n                    # `set_mwt_expansions`.\n                    sent_expansions.append(\" \".join(word[1]))\n        corrected_words.append(sent_words)\n        corrected_mwts.append(sent_mwts)\n        corrected_expansions.append(sent_expansions)\n        \n    # check postprocessor output\n    token_lens = [len(i) for i in corrected_words]\n    mwt_lens = [len(i) for i in corrected_mwts]\n    assert token_lens == mwt_lens, \"Postprocessor returned token and MWT lists of different length! Token list lengths %s, MWT list lengths %s\" % (token_lens, mwt_lens)\n    \n    # reassemble document. 
offsets and oov shouldn't change\n    doc = reassemble_doc_from_tokens(corrected_words, corrected_mwts,\n                                     corrected_expansions, raw_text)\n\n    return doc\n\ndef reassemble_doc_from_tokens(tokens, mwts, expansions, raw_text):\n    \"\"\"Assemble a Stanza document list format from a list of string tokens, calculating offsets as needed.\n\n    Parameters\n    ----------\n    tokens : List[List[str]]\n        A list of sentences, each of which is a list of string tokens.\n    mwts : List[List[bool]]\n        Whether each token is an MWT to be analyzed by the MWT system.\n    expansions : List[List[Optional[str]]]\n        The space-separated expansion for each MWT, or None if no user-defined expansion\n        is given.\n    raw_text : str\n        The raw text in which each token is located to compute its offsets.\n\n    Returns\n    -------\n    List[List[Dict]]\n        List of words and their offsets, used as `doc`.\n    \"\"\"\n\n    # oov count and offset stays the same; doc gets regenerated\n    new_offset = 0\n    corrected_doc = []\n\n    for sent_words, sent_mwts, sent_expansions in zip(tokens, mwts, expansions):\n        sentence_doc = []\n\n        for indx, (word, mwt, expansion) in enumerate(zip(sent_words, sent_mwts, sent_expansions)):\n            try:\n                offset_index = raw_text.index(word, new_offset)\n            except ValueError as e:\n                sub_start = max(0, new_offset - 20)\n                sub_end = min(len(raw_text), new_offset + 20)\n                sub = raw_text[sub_start:sub_end]\n                raise ValueError(\"Could not find word |%s| starting from char_offset %d.  Surrounding text: |%s|. 
\\n Hint: did you accidentally add/subtract a symbol/character such as a space when combining tokens?\" % (word, new_offset, sub)) from e\n\n            wd = {\n                \"id\": (indx+1,), \"text\": word,\n                \"start_char\":  offset_index,\n                \"end_char\":    offset_index+len(word)\n            }\n            if expansion:\n                wd[\"manual_expansion\"] = True\n            elif mwt:\n                wd[\"misc\"] = \"MWT=Yes\"\n\n            sentence_doc.append(wd)\n\n            # start the next search after the previous word ended\n            new_offset = offset_index+len(word)\n\n        corrected_doc.append(sentence_doc)\n\n    # use the built in MWT system to expand MWTs\n    doc = Document(corrected_doc, raw_text)\n    doc.set_mwt_expansions([j\n                            for i in expansions\n                            for j in i if j],\n                           process_manual_expanded=True)\n    return doc.to_dict()\n\ndef decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand):\n    \"\"\"\n    Decode the predictions into a document of words\n\n    Once everything is fed through the tokenizer model, it's time to decode the predictions\n    into actual tokens and sentences that the rest of the pipeline uses\n    \"\"\"\n    offset = 0\n    oov_count = 0\n    doc = []\n\n    text = SPACE_RE.sub(' ', orig_text) if orig_text is not None else None\n    char_offset = 0\n\n    if vocab is not None:\n        UNK_ID = vocab.unit2id('<UNK>')\n\n    for raw, pred in zip(all_raw, all_preds):\n        current_tok = ''\n        current_sent = []\n\n        for t, p in zip(raw, pred):\n            if t == '<PAD>':\n                break\n            # hack la_ittb\n            if use_la_ittb_shorthand and t in (\":\", \";\"):\n                p = 2\n            offset += 1\n            if vocab is not None and vocab.unit2id(t) == UNK_ID:\n                oov_count 
+= 1\n\n            current_tok += t\n            if p >= 1:\n                if vocab is not None:\n                    tok = vocab.normalize_token(current_tok)\n                else:\n                    tok = current_tok\n                assert '\\t' not in tok, tok\n                if len(tok) <= 0:\n                    current_tok = ''\n                    continue\n                if orig_text is not None:\n                    st = -1\n                    tok_len = 0\n                    for part in SPACE_SPLIT_RE.split(current_tok):\n                        if len(part) == 0: continue\n                        if skip_newline:\n                            part_pattern = re.compile(r'\\s*'.join(re.escape(c) for c in part))\n                            match = part_pattern.search(text, char_offset)\n                            st0 = match.start(0) - char_offset\n                            partlen = match.end(0) - match.start(0)\n                            lstripped = match.group(0).lstrip()\n                        else:\n                            try:\n                                st0 = text.index(part, char_offset) - char_offset\n                            except ValueError as e:\n                                sub_start = max(0, char_offset - 20)\n                                sub_end = min(len(text), char_offset + 20)\n                                sub = text[sub_start:sub_end]\n                                raise ValueError(\"Could not find |%s| starting from char_offset %d.  
Surrounding text: |%s|\" % (part, char_offset, sub)) from e\n                            partlen = len(part)\n                            lstripped = part.lstrip()\n                        if st < 0:\n                            st = char_offset + st0 + (partlen - len(lstripped))\n                        char_offset += st0 + partlen\n                    position_info = (st, char_offset)\n                else:\n                    position_info = None\n                current_sent.append((tok, p, position_info))\n                current_tok = ''\n                if (p == 2 or p == 4) and not no_ssplit:\n                    doc.append(process_sentence(current_sent, mwt_dict))\n                    current_sent = []\n\n        if len(current_tok) > 0:\n            raise ValueError(\"Finished processing tokens, but there is still text left!\")\n        if len(current_sent):\n            doc.append(process_sentence(current_sent, mwt_dict))\n\n    return oov_count, offset, doc\n\ndef match_tokens_with_text(sentences, orig_text):\n    \"\"\"\n    Turns pretokenized text and the original text into a Doc object\n\n    sentences: list of list of string\n    orig_text: string, where the text must be exactly the sentences\n      concatenated with 0 or more whitespace characters\n\n    if orig_text deviates in any way, a ValueError will be thrown\n    \"\"\"\n    text = \"\".join([\"\".join(x) for x in sentences])\n    all_raw = list(text)\n    all_preds = [0] * len(all_raw)\n    offset = 0\n    for sentence in sentences:\n        for word in sentence:\n            offset += len(word)\n            all_preds[offset-1] = 1\n        all_preds[offset-1] = 2\n    _, _, doc = decode_predictions(None, None, orig_text, [all_raw], [all_preds], False, False, False)\n    doc = Document(doc, orig_text)\n\n    # check that all the orig_text was used up by the tokens\n    offset = doc.sentences[-1].tokens[-1].end_char\n    remainder = orig_text[offset:].strip()\n    if len(remainder) > 0:\n   
     raise ValueError(\"Finished processing tokens, but there is still text left!\")\n\n    return doc\n\n\ndef eval_model(args, trainer, batches, vocab, mwt_dict):\n    oov_count, N, all_preds, doc = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])\n\n    all_preds = np.concatenate(all_preds, 0)\n    labels = np.concatenate(batches.labels())\n    counter = Counter(zip(all_preds, labels))\n\n    def f1(pred, gold, mapping):\n        pred = [mapping[p] for p in pred]\n        gold = [mapping[g] for g in gold]\n\n        lastp = -1; lastg = -1\n        tp = 0; fp = 0; fn = 0\n        for i, (p, g) in enumerate(zip(pred, gold)):\n            if p == g > 0 and lastp == lastg:\n                lastp = i\n                lastg = i\n                tp += 1\n            elif p > 0 and g > 0:\n                lastp = i\n                lastg = i\n                fp += 1\n                fn += 1\n            elif p > 0:\n                # and g == 0\n                lastp = i\n                fp += 1\n            elif g > 0:\n                lastg = i\n                fn += 1\n\n        if tp == 0:\n            return 0\n        else:\n            return 2 * tp / (2 * tp + fp + fn)\n\n    f1tok = f1(all_preds, labels, {0:0, 1:1, 2:1, 3:1, 4:1})\n    f1sent = f1(all_preds, labels, {0:0, 1:0, 2:1, 3:0, 4:1})\n    f1mwt = f1(all_preds, labels, {0:0, 1:1, 2:1, 3:2, 4:2})\n    logger.info(f\"{args['shorthand']}: token F1 = {f1tok*100:.2f}, sentence F1 = {f1sent*100:.2f}, mwt F1 = {f1mwt*100:.2f}\")\n    return harmonic_mean([f1tok, f1sent, f1mwt], [1, 1, .01])\n\n"
  },
  {
    "path": "stanza/models/tokenization/vocab.py",
    "content": "from collections import Counter\nimport re\n\nfrom stanza.models.common.vocab import BaseVocab\nfrom stanza.models.common.vocab import UNK, PAD\n\nSPACE_RE = re.compile(r'\\s')\n\nclass Vocab(BaseVocab):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.lang_replaces_spaces = any([self.lang.startswith(x) for x in ['zh', 'ja', 'ko']])\n\n    def build_vocab(self):\n        paras = self.data\n        counter = Counter()\n        for para in paras:\n            for unit in para:\n                normalized = self.normalize_unit(unit[0])\n                counter[normalized] += 1\n\n        self._id2unit = [PAD, UNK] + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))\n        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}\n\n    def append(self, unit):\n        self._id2unit.append(unit)\n        idx = len(self._id2unit) - 1\n        self._unit2id[unit] = idx\n\n    def normalize_unit(self, unit):\n        # Normalize minimal units used by the tokenizer\n        return unit\n\n    def normalize_token(self, token):\n        token = SPACE_RE.sub(' ', token.lstrip())\n\n        if self.lang_replaces_spaces:\n            token = token.replace(' ', '')\n\n        return token\n"
  },
  {
    "path": "stanza/models/tokenizer.py",
    "content": "\"\"\"\nEntry point for training and evaluating a neural tokenizer.\n\nThis tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of\nrecurrent and convolutional architectures.\nFor details please refer to the paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.\n\nUpdated: This new version of the tokenizer model incorporates a dictionary feature, which is especially useful for languages that\nhave multi-syllable words such as Vietnamese, Chinese or Thai. In summary, a lexicon containing all unique words found in the\ntraining dataset and an external lexicon (if any) is created during training and saved alongside the model after training.\nUsing this lexicon, a dictionary is created which includes \"words\", \"prefixes\" and \"suffixes\" sets. During data preparation,\ndictionary features are extracted at each character position, to \"look ahead\" and \"look backward\" to see if any of the words formed are\nfound in the dictionary. The window size (or the dictionary feature length) is set to the 95th percentile of word length among all the existing\nwords in the lexicon; this eliminates the less frequent but long words (and avoids a high-dimensional feature vector). Prefixes\nand suffixes are used to stop early during the window-dictionary checking process.  
\n\"\"\"\n\nimport argparse\nfrom copy import copy\nimport logging\nimport random\nimport numpy as np\nimport os\nimport torch\nimport json\nfrom stanza.models.common import utils\nfrom stanza.models.tokenization.trainer import Trainer\nfrom stanza.models.tokenization.data import DataLoader, TokenizationDataset\nfrom stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary\nfrom stanza.models import _training_logging\n\nlogger = logging.getLogger('stanza')\n\ndef build_argparse():\n    \"\"\"\n    If args == None, the system args are used.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--txt_file', type=str, help=\"Input plaintext file\")\n    parser.add_argument('--label_file', type=str, default=None, help=\"Character-level label file\")\n    parser.add_argument('--mwt_json_file', type=str, default=None, help=\"JSON file for MWT expansions\")\n    parser.add_argument('--conll_file', type=str, default=None, help=\"CoNLL file for output\")\n    parser.add_argument('--dev_txt_file', type=str, help=\"(Train only) Input plaintext file for the dev set\")\n    parser.add_argument('--dev_label_file', type=str, default=None, help=\"(Train only) Character-level label file for the dev set\")\n    parser.add_argument('--dev_conll_gold', type=str, default=None, help=\"(Train only) CoNLL-U file for the dev set for early stopping\")\n    parser.add_argument('--lang', type=str, help=\"Language\")\n    parser.add_argument('--shorthand', type=str, help=\"UD treebank shorthand\")\n\n    parser.add_argument('--mode', default='train', choices=['train', 'predict'])\n    parser.add_argument('--skip_newline', action='store_true', help=\"Whether to skip newline characters in input. 
Particularly useful for languages like Chinese.\")\n\n    parser.add_argument('--emb_dim', type=int, default=32, help=\"Dimension of unit embeddings\")\n    parser.add_argument('--hidden_dim', type=int, default=64, help=\"Dimension of hidden units\")\n    parser.add_argument('--conv_filters', type=str, default=\"1,9\", help=\"Configuration of conv filters. ,, separates layers and , separates filter sizes in the same layer.\")\n    parser.add_argument('--no-residual', dest='residual', action='store_false', help=\"Turn off linear residual connections\")\n    parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help=\"Turn off the \\\"hierarchical\\\" RNN tokenizer\")\n    parser.add_argument('--hier_invtemp', type=float, default=0.5, help=\"Inverse temperature used in propagating tokenization predictions between RNN layers\")\n    parser.add_argument('--input_dropout', action='store_true', help=\"Dropout input embeddings as well\")\n    parser.add_argument('--conv_res', type=str, default=None, help=\"Convolutional residual layers for the RNN\")\n    parser.add_argument('--rnn_layers', type=int, default=1, help=\"Layers of RNN in the tokenizer\")\n    parser.add_argument('--use_dictionary', action='store_true', help=\"Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. 
For example, data/tokenize/zh_gsdsimp-externaldict.txt\")\n\n    parser.add_argument('--max_grad_norm', type=float, default=1.0, help=\"Maximum gradient norm to clip to\")\n    parser.add_argument('--anneal', type=float, default=.999, help=\"Anneal the learning rate by this amount when dev performance deteriorates\")\n    parser.add_argument('--anneal_after', type=int, default=2000, help=\"Anneal the learning rate no earlier than this step\")\n    parser.add_argument('--lr0', type=float, default=2e-3, help=\"Initial learning rate\")\n    parser.add_argument('--dropout', type=float, default=0.33, help=\"Dropout probability\")\n    parser.add_argument('--unit_dropout', type=float, default=0.33, help=\"Unit dropout probability\")\n    parser.add_argument('--feat_dropout', type=float, default=0.05, help=\"Features dropout probability for each element in feature vector\")\n    parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help=\"The whole feature of units dropout probability\")\n    parser.add_argument('--tok_noise', type=float, default=0.02, help=\"Probability to induce noise to the input of the higher RNN\")\n    parser.add_argument('--sent_drop_prob', type=float, default=0.2, help=\"Probability to drop sentences at the end of batches during training uniformly at random.  Idea is to fake paragraph endings.\")\n    parser.add_argument('--last_char_drop_prob', type=float, default=0.2, help=\"Probability to drop the last char of a block of text during training, uniformly at random.  Idea is to fake a document ending w/o sentence final punctuation, hopefully to avoid the tokenizer learning to always tokenize the last character as a period\")\n    parser.add_argument('--last_char_move_prob', type=float, default=0.02, help=\"Probability to move the sentence final punctuation of a sentence during training, uniformly at random.  
Idea is to teach the tokenizer that a space separated sentence final punct still ends the sentence\")\n    parser.add_argument('--punct_move_back_prob', type=float, default=0.02, help=\"Probability to move a comma in the sentence one over, removing the previous space, during training.  Idea is to teach the tokenizer that commas can appear next to words even in languages where the dataset doesn't allow it, such as Vietnamese\")\n    parser.add_argument('--split_mwt_prob', type=float, default=0.01, help=\"Probability to split an MWT into its component pieces and turn it into separate words\")\n    parser.add_argument('--augment_final_punct_prob', type=float, default=0.05, help=\"Probability to replace a ? with a ？ or other similar augmentations\")\n    parser.add_argument('--weight_decay', type=float, default=0.0, help=\"Weight decay\")\n    parser.add_argument('--max_seqlen', type=int, default=100, help=\"Maximum sequence length to consider at a time\")\n    parser.add_argument('--batch_size', type=int, default=32, help=\"Batch size to use\")\n    parser.add_argument('--epochs', type=int, default=10, help=\"Total epochs to train the model for\")\n    parser.add_argument('--steps', type=int, default=50000, help=\"Steps to train the model for, if unspecified use epochs\")\n    parser.add_argument('--report_steps', type=int, default=20, help=\"Update step interval to report loss\")\n    parser.add_argument('--shuffle_steps', type=int, default=100, help=\"Step interval to shuffle each paragraph in the generator\")\n    parser.add_argument('--eval_steps', type=int, default=200, help=\"Step interval to evaluate the model on the dev set for early stopping\")\n    parser.add_argument('--max_steps_before_stop', type=int, default=5000, help='Terminate early after this many steps if the dev scores are not improving')\n    parser.add_argument('--save_name', type=str, default=\"{shorthand}_{embedding}_tokenizer.pt\", help=\"File name to save the model\")\n    
parser.add_argument('--load_name', type=str, default=None, help=\"File name to load a saved model\")\n    parser.add_argument('--save_dir', type=str, default='saved_models/tokenize', help=\"Directory to save models in\")\n    utils.add_device_args(parser)\n    parser.add_argument('--seed', type=int, default=1234)\n\n    parser.add_argument('--charlm', action='store_true', help=\"Turn on contextualized char embedding using pretrained character-level language model.\")\n    parser.add_argument('--charlm_shorthand', type=str, default=None, help=\"Shorthand for character-level language model training corpus.\")\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n\n    parser.add_argument('--use_mwt', dest='use_mwt', default=None, action='store_true', help='Whether or not to include mwt output layers.  If set to None, this will be determined by examining the training data for MWTs')\n    parser.add_argument('--no_use_mwt', dest='use_mwt', action='store_false', help='Whether or not to include mwt output layers')\n\n    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')\n    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  
Will default to the dataset short name')\n    return parser\n\ndef parse_args(args=None):\n    parser = build_argparse()\n    args = parser.parse_args(args=args)\n\n    if args.wandb_name:\n        args.wandb = True\n\n    args = vars(args)\n    return args\n\ndef model_file_name(args):\n    embedding = \"nocharlm\"\n    if args['charlm'] and args['charlm_forward_file']:\n        embedding = \"charlm\"\n    save_name = args['save_name'].format(shorthand=args['shorthand'],\n                                         embedding=embedding)\n\n    logger.info(\"Saving to: %s\", save_name)\n    if not os.path.exists(os.path.join(args['save_dir'], save_name)) and os.path.exists(save_name):\n        return save_name\n    return os.path.join(args['save_dir'], save_name)\n\ndef main(args=None):\n    args = parse_args(args=args)\n\n    utils.set_random_seed(args['seed'])\n\n    logger.info(\"Running tokenizer in {} mode\".format(args['mode']))\n\n    args['feat_funcs'] = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para']\n    args['feat_dim'] = len(args['feat_funcs'])\n    args['save_name'] = model_file_name(args)\n    utils.ensure_dir(os.path.split(args['save_name'])[0])\n\n    if args['mode'] == 'train':\n        return train(args)\n    else:\n        return evaluate(args)\n\ndef train(args):\n    if args['use_dictionary']:\n        #load lexicon\n        lexicon, args['num_dict_feat'] = load_lexicon(args)\n        #create the dictionary\n        dictionary = create_dictionary(lexicon)\n        #adjust the feat_dim\n        args['feat_dim'] += args['num_dict_feat']*2\n    else:\n        args['num_dict_feat'] = 0\n        lexicon=None\n        dictionary=None\n\n    mwt_dict = load_mwt_dict(args['mwt_json_file'])\n    mwt_expansions = {x: y[0] for x, y in mwt_dict.items()}\n\n    train_input_files = {\n        'txt': args['txt_file'],\n        'label': args['label_file']\n    }\n    train_batches = DataLoader(args, input_files=train_input_files, 
dictionary=dictionary, mwt_expansions=mwt_expansions)\n    vocab = train_batches.vocab\n\n    args['vocab_size'] = len(vocab)\n\n    dev_input_files = {\n            'txt': args['dev_txt_file'],\n            'label': args['dev_label_file']\n            }\n    dev_batches = TokenizationDataset(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary)\n\n    if args['use_mwt'] is None:\n        args['use_mwt'] = train_batches.has_mwt()\n        logger.info(\"Found {}mwts in the training data.  Setting use_mwt to {}\".format((\"\" if args['use_mwt'] else \"no \"), args['use_mwt']))\n\n    trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], foundation_cache=None)\n\n    if args['load_name'] is not None:\n        load_name = os.path.join(args['save_dir'], args['load_name'])\n        trainer.load(load_name)\n    trainer.change_lr(args['lr0'])\n\n    N = len(train_batches)\n    steps = args['steps'] if args['steps'] is not None else int(N * args['epochs'] / args['batch_size'] + .5)\n    lr = args['lr0']\n\n    prev_dev_score = -1\n    best_dev_score = -1\n    best_dev_step = -1\n\n    if args['wandb']:\n        import wandb\n        wandb_name = args['wandb_name'] if args['wandb_name'] else \"%s_tokenizer\" % args['shorthand']\n        wandb.init(name=wandb_name, config=args)\n        wandb.run.define_metric('train_loss', summary='min')\n        wandb.run.define_metric('dev_score', summary='max')\n\n\n    for step in range(1, steps+1):\n        batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout'])\n\n        loss = trainer.update(batch)\n        if step % args['report_steps'] == 0:\n            logger.info(\"Step {:6d}/{:6d} Loss: {:.3f}\".format(step, steps, loss))\n            if args['wandb']:\n                wandb.log({'train_loss': loss}, step=step)\n\n        if args['shuffle_steps'] > 0 and step % args['shuffle_steps'] == 0:\n   
         train_batches.shuffle()\n\n        if step % args['eval_steps'] == 0:\n            dev_score = eval_model(args, trainer, dev_batches, vocab, mwt_dict)\n            if args['wandb']:\n                wandb.log({'dev_score': dev_score}, step=step)\n            reports = ['Dev score: {:6.3f}'.format(dev_score * 100)]\n            if step >= args['anneal_after'] and dev_score < prev_dev_score:\n                reports += ['lr: {:.6f} -> {:.6f}'.format(lr, lr * args['anneal'])]\n                lr *= args['anneal']\n                trainer.change_lr(lr)\n\n            prev_dev_score = dev_score\n\n            if dev_score > best_dev_score:\n                reports += ['New best dev score!']\n                best_dev_score = dev_score\n                best_dev_step = step\n                trainer.save(args['save_name'])\n            elif best_dev_step > 0 and step - best_dev_step > args['max_steps_before_stop']:\n                reports += ['Stopping training after {} steps with no improvement'.format(step - best_dev_step)]\n                logger.info('\\t'.join(reports))\n                break\n\n            logger.info('\\t'.join(reports))\n\n    if args['wandb']:\n        wandb.finish()\n\n    if best_dev_step > -1:\n        logger.info('Best dev score={} at step {}'.format(best_dev_score, best_dev_step))\n    else:\n        logger.info('Dev set never evaluated.  
Saving final model')\n        trainer.save(args['save_name'])\n\n    return trainer, None\n\ndef evaluate(args):\n    mwt_dict = load_mwt_dict(args['mwt_json_file'])\n    trainer = Trainer(args=args, model_file=args['load_name'] or args['save_name'], device=args['device'], foundation_cache=None)\n    loaded_args, vocab = trainer.args, trainer.vocab\n\n    for k in loaded_args:\n        if not k.endswith('_file') and k not in ['device', 'mode', 'save_dir', 'load_name', 'save_name']:\n            args[k] = loaded_args[k]\n    \n    eval_input_files = {\n            'txt': args['txt_file'],\n            'label': args['label_file']\n            }\n\n\n    batches = TokenizationDataset(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary)\n\n    oov_count, N, _, doc = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])\n\n    logger.info(\"OOV rate: {:6.3f}% ({:6d}/{:6d})\".format(oov_count / N * 100, oov_count, N))\n\n    return trainer, doc\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/models/wl_coref.py",
    "content": "\"\"\"\nRuns experiments with CorefModel.\n\nTry 'python wl_coref.py -h' for more details.\n\nCode based on\n\nhttps://github.com/KarelDO/wl-coref/tree/master\nhttps://arxiv.org/abs/2310.06165\n\nThis was a fork of\n\nhttps://github.com/vdobrovolskii/wl-coref\nhttps://aclanthology.org/2021.emnlp-main.605/\n\nIf you use Stanza's coref module in your work, please cite the following:\n\n@misc{doosterlinck2023cawcoref,\n  title={CAW-coref: Conjunction-Aware Word-level Coreference Resolution},\n  author={Karel D'Oosterlinck and Semere Kiros Bitew and Brandon Papineau and Christopher Potts and Thomas Demeester and Chris Develder},\n  year={2023},\n  eprint={2310.06165},\n  archivePrefix={arXiv},\n  primaryClass={cs.CL},\n  url = \"https://arxiv.org/abs/2310.06165\",\n}\n\n@inproceedings{dobrovolskii-2021-word,\n  title = \"Word-Level Coreference Resolution\",\n  author = \"Dobrovolskii, Vladimir\",\n  booktitle = \"Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing\",\n  month = nov,\n  year = \"2021\",\n  address = \"Online and Punta Cana, Dominican Republic\",\n  publisher = \"Association for Computational Linguistics\",\n  url = \"https://aclanthology.org/2021.emnlp-main.605\",\n  pages = \"7670--7675\"\n}\n\"\"\"\n\nimport argparse\nfrom contextlib import contextmanager\nimport datetime\nimport logging\nimport os\nimport random\nimport sys\nimport dataclasses\nimport time\n\n\nimport numpy as np  # type: ignore\nimport torch        # type: ignore\n\nfrom stanza.models.common.utils import set_random_seed\nfrom stanza.models.coref.model import CorefModel\n\n\nlogger = logging.getLogger('stanza')\n\n@contextmanager\ndef output_running_time():\n    \"\"\" Prints the time elapsed in the context \"\"\"\n    start = int(time.time())\n    try:\n        yield\n    finally:\n        end = int(time.time())\n        delta = datetime.timedelta(seconds=end - start)\n        logger.info(f\"Total running time: {delta}\")\n\n\ndef 
deterministic() -> None:\n    torch.backends.cudnn.deterministic = True   # type: ignore\n    torch.backends.cudnn.benchmark = False      # type: ignore\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\"mode\", choices=(\"train\", \"eval\"))\n    argparser.add_argument(\"experiment\")\n    argparser.add_argument(\"--config_file\", default=\"config.toml\")\n    argparser.add_argument(\"--data_split\", choices=(\"train\", \"dev\", \"test\"),\n                           default=\"test\",\n                           help=\"Data split to be used for evaluation.\"\n                                \" Defaults to 'test'.\"\n                                \" Ignored in 'train' mode.\")\n    argparser.add_argument(\"--batch_size\", type=int,\n                           help=\"Adjust to override the config value of anaphoricity \"\n                                \"batch size if you are experiencing out-of-memory \"\n                                \"issues\")\n    argparser.add_argument(\"--disable_singletons\", action=\"store_true\",\n                           help=\"don't predict singletons\")\n    argparser.add_argument(\"--full_pairwise\", action=\"store_true\",\n                           help=\"use speaker and document embeddings\")\n    argparser.add_argument(\"--hidden_size\", type=int,\n                           help=\"Adjust the anaphoricity scorer hidden size\")\n    argparser.add_argument(\"--rough_k\", type=int,\n                           help=\"Adjust the number of dummies to keep\")\n    argparser.add_argument(\"--n_hidden_layers\", type=int,\n                           help=\"Adjust the anaphoricity scorer hidden layers\")\n    argparser.add_argument(\"--dummy_mix\", type=float,\n                           help=\"Adjust the dummy mix\")\n    argparser.add_argument(\"--bert_finetune_begin_epoch\", type=float,\n                           help=\"Adjust the bert finetune begin epoch\")\n    
argparser.add_argument(\"--bert_model\", type=str,\n                           help=\"Use this transformer for the given experiment\")\n    argparser.add_argument(\"--warm_start\", action=\"store_true\",\n                           help=\"If set, the training will resume from the\"\n                                \" last checkpoint saved if any. Ignored in\"\n                                \" evaluation modes.\"\n                                \" Incompatible with '--weights'.\")\n    argparser.add_argument(\"--weights\",\n                           help=\"Path to file with weights to load.\"\n                                \" If not supplied, in 'eval' mode the latest\"\n                                \" weights of the experiment will be loaded;\"\n                                \" in 'train' mode no weights will be loaded.\")\n    argparser.add_argument(\"--word_level\", action=\"store_true\",\n                           help=\"If set, output word-level conll-formatted\"\n                                \" files in evaluation modes. 
Ignored in\"\n                                \" 'train' mode.\")\n    argparser.add_argument(\"--learning_rate\", default=None, type=float,\n                           help=\"If set, update the learning rate for the model\")\n    argparser.add_argument(\"--bert_learning_rate\", default=None, type=float,\n                           help=\"If set, update the learning rate for the transformer\")\n    argparser.add_argument(\"--save_dir\", default=None,\n                           help=\"If set, update the save directory for writing models\")\n    argparser.add_argument(\"--save_name\", default=None,\n                           help=\"If set, update the save name for writing models (otherwise, section name)\")\n    argparser.add_argument(\"--score_lang\", default=None,\n                           help=\"only score a particular language for eval\")\n    argparser.add_argument(\"--log_norms\", action=\"store_true\", default=None,\n                           help=\"If set, log all of the trainable norms each epoch.  
Very noisy!\")\n    argparser.add_argument(\"--seed\", type=int, default=2020,\n                           help=\"Random seed to set\")\n\n    argparser.add_argument(\"--lang_lr_attenuation\", type=str, default=None,\n                           help=\"A comma-separated list of languages where the LR will be scaled by 1/epoch, such as --lang_lr_attenuation=es,en,de,...\")\n    argparser.add_argument(\"--lang_lr_weights\", type=str, default=None,\n                           help=\"A comma-separated list of languages and their weights of LR scaling for different languages, such as es=0.5,en=1.0,...\")\n\n    argparser.add_argument(\"--max_train_len\", type=int, default=5000,\n                           help=\"Skip any documents longer than this maximum length\")\n    argparser.add_argument(\"--no_max_train_len\", action=\"store_const\", const=float(\"inf\"), dest=\"max_train_len\",\n                           help=\"Do not skip any documents for being too long\")\n\n    argparser.add_argument(\"--train_epochs\", type=int, default=None,\n                           help=\"Train this many epochs\")\n    argparser.add_argument(\"--plateau_epochs\", type=int, default=None,\n                           help=\"Stop training if plateaued for this many epochs (only applies if positive)\")\n\n    argparser.add_argument(\"--train_data\", default=None, help=\"File to use for train data\")\n    argparser.add_argument(\"--dev_data\", default=None, help=\"File to use for dev data\")\n    argparser.add_argument(\"--test_data\", default=None, help=\"File to use for test data\")\n\n    argparser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name', default=False)\n    argparser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  
Will default to the dataset short name')\n\n    args = argparser.parse_args()\n\n    if args.warm_start and args.weights is not None:\n        raise ValueError(\"The following options are incompatible: '--warm_start' and '--weights'\")\n\n    set_random_seed(args.seed)\n    deterministic()\n    config = CorefModel._load_config(args.config_file, args.experiment)\n    if args.batch_size:\n        config.a_scoring_batch_size = args.batch_size\n    if args.hidden_size:\n        config.hidden_size = args.hidden_size\n    if args.n_hidden_layers:\n        config.n_hidden_layers = args.n_hidden_layers\n    if args.learning_rate is not None:\n        config.learning_rate = args.learning_rate\n    if args.bert_model is not None:\n        config.bert_model = args.bert_model\n    if args.bert_learning_rate is not None:\n        config.bert_learning_rate = args.bert_learning_rate\n    if args.bert_finetune_begin_epoch is not None:\n        config.bert_finetune_begin_epoch = args.bert_finetune_begin_epoch\n    if args.dummy_mix is not None:\n        config.dummy_mix = args.dummy_mix\n\n    if args.save_dir is not None:\n        config.save_dir = args.save_dir\n    if args.save_name:\n        config.save_name = args.save_name\n    else:\n        config.save_name = args.experiment\n\n    if args.rough_k is not None:\n        config.rough_k = args.rough_k\n    if args.log_norms is not None:\n        config.log_norms = args.log_norms\n    if args.full_pairwise:\n        config.full_pairwise = args.full_pairwise\n    if args.disable_singletons:\n        config.singletons = False\n    if args.train_data:\n        config.train_data = args.train_data\n    if args.dev_data:\n        config.dev_data = args.dev_data\n    if args.test_data:\n        config.test_data = args.test_data\n\n    if args.max_train_len:\n        config.max_train_len = args.max_train_len\n\n    if args.train_epochs:\n        config.train_epochs = args.train_epochs\n    if args.plateau_epochs:\n        
config.plateau_epochs = args.plateau_epochs\n\n    if args.lang_lr_attenuation:\n        config.lang_lr_attenuation = args.lang_lr_attenuation\n    if args.lang_lr_weights:\n        config.lang_lr_weights = args.lang_lr_weights\n\n    # if wandb, generate wandb configuration \n    if args.mode == \"train\":\n        if args.wandb:\n            import wandb\n            wandb_name = args.wandb_name if args.wandb_name else f\"wl_coref_{args.experiment}\"\n            wandb.init(name=wandb_name, config=dataclasses.asdict(config), project=\"stanza\")\n            wandb.run.define_metric('train_c_loss', summary='min')\n            wandb.run.define_metric('train_s_loss', summary='min')\n            wandb.run.define_metric('dev_score', summary='max')\n\n        model = CorefModel(config=config)\n        if args.weights is not None or args.warm_start:\n            model.load_weights(path=args.weights, map_location=\"cpu\",\n                               noexception=args.warm_start)\n        with output_running_time():\n            model.train(args.wandb)\n    else:\n        config_update = {\n            'log_norms': args.log_norms if args.log_norms is not None else False\n        }\n        if args.test_data:\n            config_update['test_data'] = args.test_data\n\n        if args.weights is None and config.save_name is not None:\n            args.weights = config.save_name\n        if not os.path.exists(args.weights) and os.path.exists(args.weights + \".pt\"):\n            args.weights = args.weights + \".pt\"\n        elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights)):\n            args.weights = os.path.join(config.save_dir, args.weights)\n        elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights + \".pt\")):\n            args.weights = os.path.join(config.save_dir, args.weights + \".pt\")\n        model = 
CorefModel.load_model(path=args.weights, map_location=\"cpu\",\n                                      ignore={\"bert_optimizer\", \"general_optimizer\",\n                                              \"bert_scheduler\", \"general_scheduler\"},\n                                      config_update=config_update)\n        results = model.evaluate(data_split=args.data_split,\n                                 word_level_conll=args.word_level,\n                                 eval_lang=args.score_lang)\n        print(\"\\t\".join([str(round(i, 3)) for i in results]))\n"
  },
  {
    "path": "stanza/pipeline/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/pipeline/_constants.py",
    "content": "\"\"\" Module defining constants \"\"\"\n\n# string constants for processor names\nLANGID = 'langid'\nTOKENIZE = 'tokenize'\nMWT = 'mwt'\nPOS = 'pos'\nLEMMA = 'lemma'\nDEPPARSE = 'depparse'\nNER = 'ner'\nSENTIMENT = 'sentiment'\nCONSTITUENCY = 'constituency'\nCOREF = 'coref'\nMORPHSEG = 'morphseg'\n"
  },
  {
    "path": "stanza/pipeline/constituency_processor.py",
    "content": "\"\"\"\nProcessor that attaches a constituency tree to a sentence\n\"\"\"\n\nfrom stanza.models.constituency.trainer import Trainer\n\nfrom stanza.models.common import doc\nfrom stanza.models.common.utils import sort_with_indices, unsort\nfrom stanza.utils.get_tqdm import get_tqdm\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\ntqdm = get_tqdm()\n\n@register_processor(CONSTITUENCY)\nclass ConstituencyProcessor(UDProcessor):\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([CONSTITUENCY])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE, POS])\n\n    # default batch size, measured in sentences\n    DEFAULT_BATCH_SIZE = 50\n\n    def _set_up_requires(self):\n        self._pretagged = self._config.get('pretagged')\n        if self._pretagged:\n            self._requires = set()\n        else:\n            self._requires = self.__class__.REQUIRES_DEFAULT\n\n    def _set_up_model(self, config, pipeline, device):\n        # set up model\n        # pretrain and charlm paths are args from the config\n        # bert (if used) will be chosen from the model save file\n        args = {\n            \"wordvec_pretrain_file\": config.get('pretrain_path', None),\n            \"charlm_forward_file\": config.get('forward_charlm_path', None),\n            \"charlm_backward_file\": config.get('backward_charlm_path', None),\n            \"device\": device,\n        }\n        trainer = Trainer.load(filename=config['model_path'],\n                               args=args,\n                               foundation_cache=pipeline.foundation_cache)\n        self._trainer = trainer\n        self._model = trainer.model\n        self._model.eval()\n        # batch size counted as sentences\n        self._batch_size = int(config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE))\n        self._tqdm = 'tqdm' in 
config and config['tqdm']\n\n    def _set_up_final_config(self, config):\n        loaded_args = self._model.args\n        loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}\n        loaded_args.update(config)\n        self._config = loaded_args\n\n    def process(self, document):\n        sentences = document.sentences\n\n        if self._model.uses_xpos():\n            words = [[(w.text, w.xpos) for w in s.words] for s in sentences]\n        else:\n            words = [[(w.text, w.upos) for w in s.words] for s in sentences]\n        words, original_indices = sort_with_indices(words, key=len, reverse=True)\n        if self._tqdm:\n            words = tqdm(words)\n\n        trees = self._model.parse_tagged_words(words, self._batch_size)\n        trees = unsort(trees, original_indices)\n        document.set(CONSTITUENCY, trees, to_sentence=True)\n        return document\n\n    def get_constituents(self):\n        \"\"\"\n        Return a set of the constituents known by this model\n\n        For a pipeline, this can be queried with\n          pipeline.processors[\"constituency\"].get_constituents()\n        \"\"\"\n        return set(self._model.constituents)\n"
  },
  {
    "path": "stanza/pipeline/core.py",
    "content": "\"\"\"\nPipeline that runs tokenize,mwt,pos,lemma,depparse\n\"\"\"\n\nimport argparse\nimport collections\nfrom enum import Enum\nimport io\nimport itertools\nimport sys\nimport logging\nimport json\nimport os\n\nfrom stanza.pipeline._constants import *\nfrom stanza.models.common.constant import langcode_to_lang\nfrom stanza.models.common.doc import Document\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.common.utils import default_device\nfrom stanza.pipeline.processor import Processor, ProcessorRequirementsException\nfrom stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS\nfrom stanza.pipeline.langid_processor import LangIDProcessor\nfrom stanza.pipeline.tokenize_processor import TokenizeProcessor\nfrom stanza.pipeline.mwt_processor import MWTProcessor\nfrom stanza.pipeline.pos_processor import POSProcessor\nfrom stanza.pipeline.lemma_processor import LemmaProcessor\nfrom stanza.pipeline.constituency_processor import ConstituencyProcessor\nfrom stanza.pipeline.coref_processor import CorefProcessor\nfrom stanza.pipeline.depparse_processor import DepparseProcessor\nfrom stanza.pipeline.sentiment_processor import SentimentProcessor\nfrom stanza.pipeline.ner_processor import NERProcessor\nfrom stanza.resources.common import DEFAULT_MODEL_DIR, DEFAULT_RESOURCES_URL, DEFAULT_RESOURCES_VERSION, ModelSpecification, add_dependencies, add_mwt, download_models, download_resources_json, flatten_processor_list, load_resources_json, maintain_processor_list, process_pipeline_parameters, set_logging_level, sort_processors\nfrom stanza.resources.default_packages import PACKAGES\nfrom stanza.utils.conll import CoNLL, CoNLLError\nfrom stanza.utils.helper_func import make_table\n\nlogger = logging.getLogger('stanza')\n\nclass DownloadMethod(Enum):\n    \"\"\"\n    Determines a couple options on how to download resources for the pipeline.\n\n    NONE will not download anything, including 
HF transformers, probably resulting in failure if the resources aren't already in place.\n    REUSE_RESOURCES will reuse the existing resources.json and models, but will download any missing models.\n    DOWNLOAD_RESOURCES will download a new resources.json and will overwrite any out of date models.\n    \"\"\"\n    NONE = 1\n    REUSE_RESOURCES = 2\n    DOWNLOAD_RESOURCES = 3\n\nclass LanguageNotDownloadedError(FileNotFoundError):\n    def __init__(self, lang, lang_dir, model_path):\n        super().__init__(f'Could not find the model file {model_path}.  The expected model directory {lang_dir} is missing.  Perhaps you need to run stanza.download(\"{lang}\")')\n        self.lang = lang\n        self.lang_dir = lang_dir\n        self.model_path = model_path\n\nclass UnsupportedProcessorError(FileNotFoundError):\n    def __init__(self, processor, lang):\n        super().__init__(f'Processor {processor} is not known for language {lang}.  If you have created your own model, please specify the {processor}_model_path parameter when creating the pipeline.')\n        self.processor = processor\n        self.lang = lang\n\nclass IllegalPackageError(ValueError):\n    def __init__(self, msg):\n        super().__init__(msg)\n\nclass PipelineRequirementsException(Exception):\n    \"\"\"\n    Exception indicating one or more requirements failures while attempting to build a pipeline.\n    Contains a ProcessorRequirementsException list.\n    \"\"\"\n\n    def __init__(self, processor_req_fails):\n        self._processor_req_fails = processor_req_fails\n        self.build_message()\n\n    @property\n    def processor_req_fails(self):\n        return self._processor_req_fails\n\n    def build_message(self):\n        err_msg = io.StringIO()\n        print(*[req_fail.message for req_fail in self.processor_req_fails], sep='\\n', file=err_msg)\n        self.message = '\\n\\n' + err_msg.getvalue()\n\n    def __str__(self):\n        return self.message\n\ndef 
build_default_config_option(model_specs):\n    \"\"\"\n    Build a config option for a couple situations: lemma=identity, processor is a variant\n\n    Returns the option name and value\n\n    Refactored from build_default_config so that we can reuse it when\n    downloading all models\n    \"\"\"\n    # handle case when processor variants are used\n    if any(model_spec.package in PROCESSOR_VARIANTS[model_spec.processor] for model_spec in model_specs):\n        if len(model_specs) > 1:\n            # model_spec is scoped to the generator expression above, so refer to the list directly\n            raise IllegalPackageError(\"Variant processor selected for {}, but multiple packages requested\".format(model_specs[0].processor))\n        return f\"{model_specs[0].processor}_with_{model_specs[0].package}\", True\n    # handle case when identity is specified as lemmatizer\n    elif any(model_spec.processor == LEMMA and model_spec.package == 'identity' for model_spec in model_specs):\n        if len(model_specs) > 1:\n            raise IllegalPackageError(\"Identity processor selected for lemma, but multiple packages requested\")\n        return f\"{LEMMA}_use_identity\", True\n    return None\n\ndef filter_variants(model_specs):\n    return [(key, value) for (key, value) in model_specs if build_default_config_option(value) is None]\n\n# given a language and models path, build a default configuration\ndef build_default_config(resources, lang, model_dir, load_list):\n    default_config = {}\n    for processor, model_specs in load_list:\n        option = build_default_config_option(model_specs)\n        if option is not None:\n            # if an option is set for the model_specs, keep that option and ignore\n            # the rest of the model spec\n            default_config[option[0]] = option[1]\n            continue\n\n        model_paths = [os.path.join(model_dir, lang, processor, model_spec.package + '.pt') for model_spec in model_specs]\n        dependencies = [model_spec.dependencies for model_spec in model_specs]\n\n        # Special case for NER: load multiple 
models at once\n        # The pattern will be:\n        #   a list of ner_model_path\n        #   a list of ner_dependencies\n        #     where each item in ner_dependencies is a map\n        #     the map may contain forward_charlm_path, backward_charlm_path, or any other deps\n        # The user will be able to override the defaults using a semicolon separated string\n        # TODO: at least use the same config pattern for all other models\n        if processor == NER:\n            default_config[f\"{processor}_model_path\"] = model_paths\n            dependency_paths = []\n            for dependency_block in dependencies:\n                if not dependency_block:\n                    dependency_paths.append({})\n                    continue\n                dependency_paths.append({})\n                for dependency in dependency_block:\n                    dep_processor, dep_model = dependency\n                    dependency_paths[-1][f\"{dep_processor}_path\"] = os.path.join(model_dir, lang, dep_processor, dep_model + '.pt')\n            default_config[f\"{processor}_dependencies\"] = dependency_paths\n            continue\n\n        if len(model_specs) > 1:\n            raise IllegalPackageError(\"Specified multiple packages for {}, which currently only handles one package\".format(processor))\n\n        default_config[f\"{processor}_model_path\"] = model_paths[0]\n        if not dependencies[0]: continue\n        for dependency in dependencies[0]:\n            dep_processor, dep_model = dependency\n            default_config[f\"{processor}_{dep_processor}_path\"] = os.path.join(\n                model_dir, lang, dep_processor, dep_model + '.pt'\n            )\n\n    return default_config\n\ndef normalize_download_method(download_method):\n    \"\"\"\n    Turn None -> DownloadMethod.NONE, strings to the corresponding enum\n    \"\"\"\n    if download_method is None:\n        return DownloadMethod.NONE\n    elif isinstance(download_method, str):\n        
try:\n            return DownloadMethod[download_method.upper()]\n        except KeyError as e:\n            raise ValueError(\"Unknown download method %s\" % download_method) from e\n    return download_method\n\nclass Pipeline:\n\n    def __init__(self,\n                 lang='en',\n                 dir=DEFAULT_MODEL_DIR,\n                 package='default',\n                 processors={},\n                 logging_level=None,\n                 verbose=None,\n                 use_gpu=None,\n                 model_dir=None,\n                 download_method=DownloadMethod.DOWNLOAD_RESOURCES,\n                 resources_url=DEFAULT_RESOURCES_URL,\n                 resources_branch=None,\n                 resources_version=DEFAULT_RESOURCES_VERSION,\n                 resources_filepath=None,\n                 proxies=None,\n                 foundation_cache=None,\n                 device=None,\n                 allow_unknown_language=False,\n                 **kwargs):\n        self.lang, self.dir, self.kwargs = lang, dir, kwargs\n        if model_dir is not None and dir == DEFAULT_MODEL_DIR:\n            self.dir = model_dir\n\n        # set global logging level\n        set_logging_level(logging_level, verbose)\n\n        self.download_method = normalize_download_method(download_method)\n        if (self.download_method is DownloadMethod.DOWNLOAD_RESOURCES or\n            (self.download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, \"resources.json\")))):\n            logger.info(\"Checking for updates to resources.json in case models have been updated.  
Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES\")\n            download_resources_json(self.dir,\n                                    resources_url=resources_url,\n                                    resources_branch=resources_branch,\n                                    resources_version=resources_version,\n                                    resources_filepath=resources_filepath,\n                                    proxies=proxies)\n\n        # processors can use this to save on the effort of loading\n        # large sub-models, such as pretrained embeddings, bert, etc\n        if foundation_cache is None:\n            self.foundation_cache = FoundationCache(local_files_only=(self.download_method is DownloadMethod.NONE))\n        else:\n            self.foundation_cache = FoundationCache(foundation_cache, local_files_only=(self.download_method is DownloadMethod.NONE))\n\n        # process different pipeline parameters\n        lang, self.dir, package, processors = process_pipeline_parameters(lang, self.dir, package, processors)\n\n        # Load resources.json to obtain latest packages.\n        logger.debug('Loading resource file...')\n        resources = load_resources_json(self.dir, resources_filepath)\n        if lang in resources:\n            if 'alias' in resources[lang]:\n                logger.info(f'\"{lang}\" is an alias for \"{resources[lang][\"alias\"]}\"')\n                lang = resources[lang]['alias']\n            lang_name = resources[lang]['lang_name'] if 'lang_name' in resources[lang] else ''\n        elif allow_unknown_language:\n            logger.warning(\"Trying to create pipeline for unsupported language: %s\", lang)\n            lang_name = langcode_to_lang(lang)\n        else:\n            logger.warning(\"Unsupported language: %s  If trying to add a new language, consider using allow_unknown_language=True\", lang)\n            lang_name = langcode_to_lang(lang)\n\n     
   # Maintain load list\n        if lang in resources:\n            self.load_list = maintain_processor_list(resources, lang, package, processors, maybe_add_mwt=(not kwargs.get(\"tokenize_pretokenized\")))\n            self.load_list = add_dependencies(resources, lang, self.load_list)\n            if self.download_method is not DownloadMethod.NONE:\n                # skip processors which aren't downloaded from our collection\n                download_list = [x for x in self.load_list if x[0] in resources.get(lang, {})]\n                # skip variants\n                download_list = filter_variants(download_list)\n                # gather up the model list...\n                download_list = flatten_processor_list(download_list)\n                # download_models will skip models we already have\n                download_models(download_list,\n                                resources=resources,\n                                lang=lang,\n                                model_dir=self.dir,\n                                resources_version=resources_version,\n                                proxies=proxies,\n                                log_info=False)\n        elif allow_unknown_language:\n            self.load_list = [(proc, [ModelSpecification(processor=proc, package='default', dependencies=None)])\n                              for proc in list(processors.keys())]\n        else:\n            self.load_list = []\n        self.load_list = self.update_kwargs(kwargs, self.load_list)\n        if len(self.load_list) == 0:\n            if lang not in resources or PACKAGES not in resources[lang]:\n                raise ValueError(f'No processors to load for language {lang}.  Language {lang} is currently unsupported')\n            else:\n                raise ValueError('No processors to load for language {}.  
Please check if your language or package is correctly set.'.format(lang))\n        load_table = make_table(['Processor', 'Package'], [(row[0], \";\".join(model_spec.package for model_spec in row[1])) for row in self.load_list])\n        logger.info(f'Loading these models for language: {lang} ({lang_name}):\\n{load_table}')\n\n        self.config = build_default_config(resources, lang, self.dir, self.load_list)\n        self.config.update(kwargs)\n\n        # Load processors\n        self.processors = {}\n\n        # configs that are the same for all processors\n        pipeline_level_configs = {'lang': lang, 'mode': 'predict'}\n\n        if device is None:\n            if use_gpu is None or use_gpu == True:\n                device = default_device()\n            else:\n                device = 'cpu'\n            if use_gpu == True and device == 'cpu':\n                logger.warning(\"GPU requested, but is not available!\")\n        self.device = device\n        logger.info(\"Using device: {}\".format(self.device))\n\n        # set up processors\n        pipeline_reqs_exceptions = []\n        for item in self.load_list:\n            processor_name, _ = item\n            logger.info('Loading: ' + processor_name)\n            curr_processor_config = self.filter_config(processor_name, self.config)\n            curr_processor_config.update(pipeline_level_configs)\n            # TODO: this is obviously a hack\n            # a better solution overall would be to make a pretagged version of the pos annotator\n            # and then subsequent modules can use those tags without knowing where those tags came from\n            if \"pretagged\" in self.config and \"pretagged\" not in curr_processor_config:\n                curr_processor_config[\"pretagged\"] = self.config[\"pretagged\"]\n            logger.debug('With settings: ')\n            logger.debug(curr_processor_config)\n            try:\n                # try to build processor, throw an exception if there is a 
requirements issue\n                self.processors[processor_name] = NAME_TO_PROCESSOR_CLASS[processor_name](config=curr_processor_config,\n                                                                                          pipeline=self,\n                                                                                          device=self.device)\n            except ProcessorRequirementsException as e:\n                # if there was a requirements issue, add it to list which will be printed at end\n                pipeline_reqs_exceptions.append(e)\n                # add the broken processor to the loaded processors for the sake of analyzing the validity of the\n                # entire proposed pipeline, but at this point the pipeline will not be built successfully\n                self.processors[processor_name] = e.err_processor\n            except FileNotFoundError as e:\n                # For a FileNotFoundError, we try to guess if there's\n                # a missing model directory or file.  
If so, we\n                # suggest the user try to download the models\n                if 'model_path' in curr_processor_config:\n                    model_path = curr_processor_config['model_path']\n                    if e.filename == model_path or (isinstance(model_path, (tuple, list)) and e.filename in model_path):\n                        model_path = e.filename\n                    model_dir, model_name = os.path.split(model_path)\n                    lang_dir = os.path.dirname(model_dir)\n                    if lang_dir and not os.path.exists(lang_dir):\n                        # model files for this language can't be found in the expected directory\n                        raise LanguageNotDownloadedError(lang, lang_dir, model_path) from e\n                    if processor_name not in resources[lang]:\n                        # user asked for a model which doesn't exist for this language?\n                        raise UnsupportedProcessorError(processor_name, lang) from e\n                    if not os.path.exists(model_path):\n                        model_name, _ = os.path.splitext(model_name)\n                        # TODO: before recommending this, check that such a thing exists in resources.json.\n                        # currently that case is handled by ignoring the model, anyway\n                        raise FileNotFoundError('Could not find model file %s, although there are other models downloaded for language %s.  Perhaps you need to download a specific model.  
Try: stanza.download(lang=\"%s\",package=None,processors={\"%s\":\"%s\"})' % (model_path, lang, lang, processor_name, model_name)) from e\n\n                # if we couldn't find a more suitable description of the\n                # FileNotFoundError, just raise the old error\n                raise\n\n        # if there are any processor exceptions, throw an exception to indicate pipeline build failure\n        if pipeline_reqs_exceptions:\n            logger.info('\\n')\n            raise PipelineRequirementsException(pipeline_reqs_exceptions)\n\n        logger.info(\"Done loading processors!\")\n\n    @staticmethod\n    def update_kwargs(kwargs, processor_list):\n        processor_dict = {processor: [{'package': model_spec.package, 'dependencies': model_spec.dependencies} for model_spec in model_specs]\n                          for (processor, model_specs) in processor_list}\n        for key, value in kwargs.items():\n            pieces = key.split('_', 1)\n            if len(pieces) == 1:\n                continue\n            k, v = pieces\n            if v == 'model_path':\n                package = value if len(value) < 25 else value[:10]+ '...' 
+ value[-10:]\n                original_spec = processor_dict.get(k, [])\n                if len(original_spec) > 0:\n                    dependencies = original_spec[0].get('dependencies')\n                else:\n                    dependencies = None\n                processor_dict[k] = [{'package': package, 'dependencies': dependencies}]\n        processor_list = [(processor, [ModelSpecification(processor=processor, package=model_spec['package'], dependencies=model_spec['dependencies']) for model_spec in processor_dict[processor]]) for processor in processor_dict]\n        processor_list = sort_processors(processor_list)\n        return processor_list\n\n    @staticmethod\n    def filter_config(prefix, config_dict):\n        filtered_dict = {}\n        for key in config_dict.keys():\n            pieces = key.split('_', 1)  # split tokenize_pretokenize to tokenize+pretokenize\n            if len(pieces) == 1:\n                continue\n            k, v = pieces\n            if k == prefix:\n                filtered_dict[v] = config_dict[key]\n        return filtered_dict\n\n    @property\n    def loaded_processors(self):\n        \"\"\"\n        Return all currently loaded processors in execution order.\n        :return: list of Processor instances\n        \"\"\"\n        return [self.processors[processor_name] for processor_name in PIPELINE_NAMES if self.processors.get(processor_name)]\n\n    def process(self, doc, processors=None):\n        \"\"\"\n        Run the pipeline\n\n        processors: allow for a list of processors used by this pipeline action\n          can be list, tuple, set, or comma separated string\n          if None, use all the processors this pipeline knows about\n          MWT is added if necessary\n          otherwise, no care is taken to make sure prerequisites are followed...\n            some of the annotators, such as depparse, will check, but others\n            will fail in some unusual manner or just have really bad results\n      
  \"\"\"\n        assert any([isinstance(doc, str), isinstance(doc, list),\n                    isinstance(doc, Document)]), 'input should be either str, list or Document'\n\n        # empty bulk process\n        if isinstance(doc, list) and len(doc) == 0:\n            return []\n\n        # determine whether we are in bulk processing mode for multiple documents\n        bulk=(isinstance(doc, list) and len(doc) > 0 and isinstance(doc[0], Document))\n\n        # various options to limit the processors used by this pipeline action\n        if processors is None:\n            processors = PIPELINE_NAMES\n        elif not isinstance(processors, (str, list, tuple, set)):\n            raise ValueError(\"Cannot process {} as a list of processors to run\".format(type(processors)))\n        else:\n            if isinstance(processors, str):\n                processors = {x for x in processors.split(\",\")}\n            else:\n                processors = set(processors)\n            if TOKENIZE in processors and MWT in self.processors and MWT not in processors:\n                logger.debug(\"Requested processors for pipeline did not have mwt, but pipeline needs mwt, so mwt is added\")\n                processors.add(MWT)\n            processors = [x for x in PIPELINE_NAMES if x in processors]\n\n        for processor_name in processors:\n            if self.processors.get(processor_name):\n                process = self.processors[processor_name].bulk_process if bulk else self.processors[processor_name].process\n                doc = process(doc)\n        return doc\n\n    def process_conllu(self, doc, ignore_gapping=True, processors=None):\n        \"\"\" Convenience method: treat the doc as a conllu text, convert it, and process it accordingly \"\"\"\n        if processors is None:\n            processors = set(self.processors.keys())\n            if TOKENIZE in processors:\n                processors.remove(TOKENIZE)\n            if MWT in processors:\n                
processors.remove(MWT)\n        doc = CoNLL.conll2doc(input_str=doc, ignore_gapping=ignore_gapping)\n        return self.process(doc, processors=processors)\n\n    def bulk_process(self, docs, *args, **kwargs):\n        \"\"\"\n        Run the pipeline in bulk processing mode\n\n        Expects a list of str or a list of Docs\n        \"\"\"\n        # Wrap each text as a Document unless it is already such a document\n        docs = [doc if isinstance(doc, Document) else Document([], text=doc) for doc in docs]\n        return self.process(docs, *args, **kwargs)\n\n    def stream(self, docs, batch_size=50, *args, **kwargs):\n        \"\"\"\n        Go through an iterator of documents in batches, yield processed documents\n\n        sentence indices will be counted across the entire iterator\n        \"\"\"\n        if not isinstance(docs, collections.abc.Iterator):\n            docs = iter(docs)\n        def next_batch():\n            batch = []\n            for _ in range(batch_size):\n                try:\n                    next_doc = next(docs)\n                    batch.append(next_doc)\n                except StopIteration:\n                    return batch\n            return batch\n\n        sentence_start_index = 0\n        batch = next_batch()\n        while batch:\n            batch = self.bulk_process(batch, *args, **kwargs)\n            for doc in batch:\n                doc.reindex_sentences(sentence_start_index)\n                sentence_start_index += len(doc.sentences)\n                yield doc\n            batch = next_batch()\n\n    def __str__(self):\n        \"\"\"\n        Assemble the processors in order to make a simple description of the pipeline\n        \"\"\"\n        processors = [\"%s=%s\" % (x, str(self.processors[x])) for x in PIPELINE_NAMES if x in self.processors]\n        return \"<Pipeline: %s>\" % \", \".join(processors)\n\n    def __call__(self, doc, processors=None):\n        return self.process(doc, processors)\n\ndef 
main():\n    # TODO: can add a bunch more arguments\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--lang', type=str, default='en', help='Language of the pipeline to use')\n    parser.add_argument('--input_file', type=str, required=True, help='Input file to read')\n    parser.add_argument('--processors', type=str, default='tokenize,pos,lemma,depparse', help='Processors to use')\n    parser.add_argument('--package', type=str, default='default', help='Which package to use')\n    parser.add_argument('--tokenize_no_ssplit', default=False, action='store_true', help=\"Don't ssplit\")\n    parser.add_argument('--tokenize_pretokenized', default=False, action='store_true', help=\"Text is pretokenized\")\n    args, extra_args = parser.parse_known_args()\n\n    try:\n        doc = CoNLL.conll2doc(args.input_file)\n        extra_args = {\n            \"tokenize_pretokenized\": True\n        }\n    except CoNLLError:\n        logger.debug(\"Input file %s does not appear to be a conllu file.  Will read it as a text file\", args.input_file)\n        with open(args.input_file, encoding=\"utf-8\") as fin:\n            doc = fin.read()\n        extra_args = {}\n    extra_args['package'] = args.package\n    if args.tokenize_no_ssplit:\n        extra_args['tokenize_no_ssplit'] = True\n    if args.tokenize_pretokenized:\n        extra_args['tokenize_pretokenized'] = True\n\n    pipe = Pipeline(args.lang, processors=args.processors, **extra_args)\n\n    doc = pipe(doc)\n\n    print(\"{:C}\".format(doc))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/pipeline/coref_processor.py",
    "content": "\"\"\"\nProcessor that attaches coref annotations to a document\n\"\"\"\n\nfrom stanza.models.common.utils import misc_to_space_after\nfrom stanza.models.coref.coref_chain import CorefMention, CorefChain\nfrom stanza.models.common.doc import Word\n\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\nimport torch\n\ndef extract_text(document, sent_id, start_word, end_word):\n    sentence = document.sentences[sent_id]\n    tokens = []\n\n    # the coref model indexes the words from 0,\n    # whereas the ids we are looking at on the tokens start from 1\n    # here we will switch to ID space\n    start_word = start_word + 1\n    end_word = end_word + 1\n\n    # For each position between start and end word:\n    # If a word is part of an MWT, and the entire token\n    # is inside the range, we use that Token's text for that span\n    # This will let us easily handle words which are split into pieces\n    # Otherwise, we only take the text of the word itself\n    next_idx = start_word\n    while next_idx < end_word:\n        word = sentence.words[next_idx-1]\n        parent_token = word.parent\n        if isinstance(parent_token.id, int) or len(parent_token.id) == 1:\n            tokens.append(parent_token)\n            next_idx += 1\n        elif parent_token.id[0] >= start_word and parent_token.id[1] < end_word:\n            tokens.append(parent_token)\n            next_idx = parent_token.id[1] + 1\n        else:\n            tokens.append(word)\n            next_idx += 1\n\n    # We use the SpaceAfter or SpacesAfter attribute on each Word or Token\n    # we chose in the above loop to separate the text pieces\n    text = []\n    for token in tokens:\n        text.append(token.text)\n        text.append(misc_to_space_after(token.misc))\n    # the last token space_after will be discarded\n    # so that we don't have stray WS at the end of the mention text\n    text = text[:-1]\n    return 
\"\".join(text)\n\n\n@register_processor(COREF)\nclass CorefProcessor(UDProcessor):\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([COREF])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE])\n\n    def _set_up_model(self, config, pipeline, device):\n        try:\n            from stanza.models.coref.model import CorefModel\n        except ImportError:\n            raise ImportError(\"Please install the transformers and peft libraries before using coref! Try `pip install -e .[transformers]`.\")\n\n        # set up model\n        # currently, the model has everything packaged in it\n        # (except its config)\n        # TODO: separate any pretrains if possible\n        # TODO: add device parameter to the load mechanism\n        config_update = {'log_norms': config.get('log_norms', False),\n                         'device': device}\n        model = CorefModel.load_model(path=config['model_path'],\n                                      ignore={\"bert_optimizer\", \"general_optimizer\",\n                                              \"bert_scheduler\", \"general_scheduler\"},\n                                      config_update=config_update,\n                                      foundation_cache=pipeline.foundation_cache)\n        if config.get('batch_size', None):\n            model.config.a_scoring_batch_size = int(config['batch_size'])\n        model.training = False\n\n        self._model = model\n\n        # coref_use_zeros=False will turn off creating new nodes and attaching mentions to those zero nodes\n        self._use_zeros = config.get('use_zeros', True)\n        if isinstance(self._use_zeros, str):\n            self._use_zeros = self._use_zeros.lower() != 'false'\n\n    def process(self, document):\n        sentences = document.sentences\n\n        cased_words = []\n        sent_ids = []\n        word_pos = []\n        speaker = []\n        for sent_idx, sentence in 
enumerate(sentences):\n            for word_idx, word in enumerate(sentence.words):\n                cased_words.append(word.text)\n                sent_ids.append(sent_idx)\n                word_pos.append(word_idx)\n                if sentence.speaker:\n                    speaker.append(sentence.speaker)\n                else:\n                    speaker.append(\"_\")\n\n        coref_input = {\n            \"document_id\": \"wb_doc_1\",\n            \"cased_words\": cased_words,\n            \"sent_id\": sent_ids,\n            \"speaker\": speaker,\n        }\n        coref_input = self._model.build_doc(coref_input)\n        results = self._model.run(coref_input)\n\n        \n        # Handle zero anaphora - zero_scores is always predicted\n        zero_nodes_created = self._handle_zero_anaphora(document, results, sent_ids, word_pos)\n        \n        clusters = []\n        for cluster_idx, span_cluster in enumerate(results.span_clusters):\n            if len(span_cluster) == 0:\n                continue\n            span_cluster = sorted(span_cluster)\n\n            for span in span_cluster:\n                # check there are no sentence crossings before\n                # manipulating the spans, since we will expect it to\n                # be this way for multiple usages of the spans\n                sent_id = sent_ids[span[0]]\n                if sent_ids[span[1]-1] != sent_id:\n                    raise ValueError(\"The coref model predicted a span that crossed two sentences!  
Please send this example to us on our github\")\n\n            # treat the longest span as the representative\n            # break ties using the first one\n            # IF there is the POS processor, and it adds upos tags\n            # to the sentence, ties are broken first by maximum\n            # number of UPOS and then earliest in the document\n            max_len = 0\n            best_span = None\n            max_propn = 0\n            for span_idx, span in enumerate(span_cluster):\n                word_idx = results.word_clusters[cluster_idx][span_idx]\n                is_zero = zero_nodes_created.get((cluster_idx, word_idx))\n                if is_zero:\n                    continue\n\n                sent_id = sent_ids[span[0]]\n                sentence = sentences[sent_id]\n                start_word = word_pos[span[0]]\n                # fiddle -1 / +1 so as to avoid problems with coref\n                # clusters that end at exactly the end of a document\n                end_word = word_pos[span[1]-1] + 1\n                # very UD specific test for most number of proper nouns in a mention\n                # will do nothing if POS is not active (they will all be None)\n                num_propn = sum(word.pos == 'PROPN' for word in sentence.words[start_word:end_word])\n\n                if ((span[1] - span[0] > max_len) or\n                    span[1] - span[0] == max_len and num_propn > max_propn):\n                    max_len = span[1] - span[0]\n                    best_span = span_idx\n                    max_propn = num_propn\n\n            mentions = []\n            for span_idx, span in enumerate(span_cluster):\n                word_idx = results.word_clusters[cluster_idx][span_idx]\n                is_zero = zero_nodes_created.get((cluster_idx, word_idx))\n                if is_zero:\n                    (sent_id, zero_word_id) = is_zero\n                    # if the word id is a tuple, it will be attached\n                    # to the zero\n  
                  mentions.append(\n                        CorefMention(\n                            sent_id, \n                             zero_word_id, \n                             zero_word_id\n                        )\n                    )\n                else:\n                    sent_id = sent_ids[span[0]]\n                    start_word = word_pos[span[0]]\n                    end_word = word_pos[span[1]-1] + 1\n                    mentions.append(CorefMention(sent_id, start_word, end_word))\n                \n            # if we ended up with no best span, then our \"representative text\"\n            # is just underscore\n            if best_span is not None:\n                representative = mentions[best_span]\n                representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)\n            else:\n                representative_text = \"_\"\n\n            chain = CorefChain(len(clusters), mentions, representative_text, best_span)\n            clusters.append(chain)\n\n        document.coref = clusters\n        return document\n\n    def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):\n        \"\"\"Handle zero anaphora by creating zero nodes and updating coreference clusters.\"\"\"\n        if results.zero_scores is None or results.word_clusters is None:\n            return {}\n        if not self._use_zeros:\n            return {}\n\n        zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores\n        is_zero = []\n        \n        # Flatten word_clusters to get the word indices that correspond to zero_scores\n        cluster_word_ids = []\n        cluster_mapping = {}\n        counter = 0\n        for indx, cluster in enumerate(results.word_clusters):\n            for _ in range(len(cluster)):\n                cluster_mapping[counter] = indx\n                counter += 1\n            
cluster_word_ids.extend(cluster)\n        \n        # Find indices where zero_scores > 0\n        zero_indices = (zero_scores > 0.0).nonzero()\n\n        # this dict maps (cluster_idx, word_idx) to (sent_id, zero_word_id),\n        # which overrides the span from span_clusters for that mention\n        zero_to_coref = {}\n\n        for zero_idx in zero_indices:\n            zero_idx = zero_idx.item()\n            if zero_idx >= len(cluster_word_ids):\n                continue\n                \n            word_idx = cluster_word_ids[zero_idx]\n            sent_id = sent_ids[word_idx]\n            word_id = word_pos[word_idx]\n            \n            # Create zero node - attach BEFORE the current word\n            # This means the zero node comes after word_id-1 but before word_id\n            zero_word_id = (\n                word_id, \n                len(document.sentences[sent_id]._empty_words)+1\n            )  # attach after word_id-1, before word_id\n            zero_word = Word(document.sentences[sent_id], {\n                \"text\": \"_\", \n                \"lemma\": \"_\", \n                \"id\": zero_word_id\n            })\n            document.sentences[sent_id]._empty_words.append(zero_word)\n            \n            # Track this zero node for adding to coreference clusters\n            cluster_idx = cluster_mapping[zero_idx]\n            zero_to_coref[(cluster_idx, word_idx)] = (\n                sent_id, zero_word_id\n            )\n\n        return zero_to_coref\n"
  },
  {
    "path": "stanza/pipeline/demo/README.md",
    "content": "## Interactive Demo for Stanza\n\n### Requirements\n\nstanza, flask\n\n### Run the demo locally\n\n1. Make sure you know how to disable your browser's CORS rule. For Chrome, [this extension](https://mybrowseraddon.com/access-control-allow-origin.html) works pretty well.\n2. From this directory, start the Stanza demo server\n\n```bash\nexport FLASK_APP=demo_server.py\nflask run\n```\n\n3. In `stanza-brat.js`, uncomment the line at the top that declares `serverAddress` and point it to where your flask is serving the demo server (usually `http://localhost:5000`)\n\n4. Open `stanza-brat.html` in your browser (with CORS disabled) and enjoy!\n\n### Common issues\n\nMake sure you have the models corresponding to the language you want to test out locally before submitting requests to the server! (Models can be obtained by `import stanza; stanza.download(<language_code>)`.\n"
  },
  {
    "path": "stanza/pipeline/demo/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/pipeline/demo/demo_server.py",
    "content": "from flask import Flask, request, abort\nimport json\nimport stanza\nimport os\napp = Flask(__name__, static_url_path='', static_folder=os.path.abspath(os.path.dirname(__file__)))\n\npipelineCache = dict()\n\ndef get_file(path):\n    res = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)\n    print(res)\n    return res\n\n@app.route('/<path:path>')\n@app.route('/static/fonts/<path:path>')\ndef static_file(path):\n    if path in ['stanza-brat.css', 'stanza-brat.js', 'stanza-parseviewer.js', 'loading.gif',\n                'favicon.png', 'stanza-logo.png',\n                'Astloch-Bold.ttf', 'Liberation_Sans-Regular.ttf', 'PT_Sans-Caption-Web-Regular.ttf']:\n        return app.send_static_file(path)\n    elif path in 'index.html':\n        return app.send_static_file('stanza-brat.html')\n    else:\n        abort(403)\n\n@app.route('/', methods=['GET'])\ndef index():\n    return static_file('index.html')\n\n@app.route('/', methods=['POST'])\ndef annotate():\n    global pipelineCache\n\n    properties = request.args.get('properties', '')\n    lang = request.args.get('pipelineLanguage', '')\n    text = list(request.form.keys())[0]\n\n    if lang not in pipelineCache:\n        pipelineCache[lang] = stanza.Pipeline(lang=lang, use_gpu=False)\n\n    res = pipelineCache[lang](text)\n\n    annotated_sentences = []\n    for sentence in res.sentences:\n        tokens = []\n        deps = []\n        for word in sentence.words:\n            tokens.append({'index': word.id, 'word': word.text, 'lemma': word.lemma, 'pos': word.xpos, 'upos': word.upos, 'feats': word.feats, 'ner': word.parent.ner if word.parent.ner is None or word.parent.ner == 'O' else word.parent.ner[2:]})\n            deps.append({'dep': word.deprel, 'governor': word.head, 'governorGloss': sentence.words[word.head-1].text,\n                'dependent': word.id, 'dependentGloss': word.text})\n        annotated_sentences.append({'basicDependencies': deps, 'tokens': tokens})\n        if 
hasattr(sentence, 'constituency') and sentence.constituency is not None:\n            annotated_sentences[-1]['parse'] = str(sentence.constituency)\n\n    return json.dumps({'sentences': annotated_sentences})\n\ndef create_app():\n    return app\n\nif __name__ == \"__main__\":\n    app.run(host='0.0.0.0', port=8080)\n"
  },
  {
    "path": "stanza/pipeline/demo/stanza-brat.css",
    "content": "\n.red {\n  color:#990000\n}\n\n#wrap {\n  min-height: 100%;\n  height: auto;\n  margin: 0 auto -6ex;\n  padding: 0 0 6ex;\n}\n\n.pattern_tab {\n  margin: 1ex;\n}\n\n.pattern_brat {\n  margin-top: 1ex;\n}\n\n.label {\n  color: #777777;\n  font-size: small;\n}\n\n.footer {\n  bottom: 0;\n  width: 100%;\n  /* Set the fixed height of the footer here */\n  height: 5ex;\n  padding-top: 1ex;\n  margin-top: 1ex;\n  background-color: #f5f5f5;\n}\n\n.corenlp_error {\n  margin-top: 2ex;\n}\n\n/* Styling for parse graph */\n.node rect {\n  stroke: #333;\n  fill: #fff;\n}\n\n.parse-RULE rect {\n  fill: #C0D9AF;\n}\n\n.parse-TERMINAL rect {\n  stroke: #333;\n  fill: #EEE8AA;\n}\n\n.node.highlighted {\n  stroke: #ffff00;\n}\n\n.edgePath path {\n  stroke: #333;\n  fill: #333;\n  stroke-width: 1.5px;\n}\n\n.parse-EDGE path {\n  stroke: DarkGray;\n  fill: DarkGray;\n  stroke-width: 1.5px;\n}\n\n.logo {\n    font-family: \"Lato\", \"Gill Sans MT\", \"Gill Sans\", \"Helvetica\", \"Arial\", sans-serif;\n    font-style: italic;\n}\n"
  },
  {
    "path": "stanza/pipeline/demo/stanza-brat.html",
    "content": "<html>\n<head profile=\"https://www.w3.org/2005/10/profile\">\n  <link rel='icon' href='favicon.png' type='image/png'/ >\n  <!-- JQuery -->\n  <script src=\"https://code.jquery.com/jquery-2.1.4.min.js\"></script>\n  <!-- Bootstrap -->\n  <link rel=\"stylesheet\" href=\"https://maxcdn.bootstrapcdn.com/bootstrap/3.3.1/css/bootstrap.min.css\"/>\n  <link rel=\"stylesheet\" href=\"https://maxcdn.bootstrapcdn.com/bootstrap/3.3.1/css/bootstrap-theme.min.css\"/>\n  <script src=\"https://maxcdn.bootstrapcdn.com/bootstrap/3.3.1/js/bootstrap.min.js\"></script>\n  <!-- Chosen Dropdown Library -->\n  <link rel=\"stylesheet\" href=\"https://cdnjs.cloudflare.com/ajax/libs/chosen/1.4.2/chosen.css\"/>\n  <script src=\"https://cdnjs.cloudflare.com/ajax/libs/chosen/1.4.2/chosen.jquery.min.js\"></script>\n  <!-- Brat -->\n  <link rel=\"stylesheet\" type=\"text/css\" href=\"https://nlp.stanford.edu/js/brat/style-vis.css\"/>\n  <script type=\"text/javascript\" src=\"https://nlp.stanford.edu/js/brat/client/lib/head.load.min.js\"></script>\n  <!-- d3 -->\n  <script type=\"text/javascript\" src=\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js\"></script>\n  <script type=\"text/javascript\" src=\"https://cdnjs.cloudflare.com/ajax/libs/dagre-d3/0.4.17/dagre-d3.min.js\"></script>\n\n    <!-- CoreNLP -->\n  <link rel=\"stylesheet\" type=\"text/css\" href=\"stanza-brat.css\"/>\n  <script type=\"text/javascript\" src=\"stanza-brat.js\"></script>\n\n  <meta charset=\"UTF-8\">\n</head>\n\n<body>\n<div id=\"wrap\">\n<!-- A header bar -->\n<nav class=\"navbar navbar-default navbar-static-top\">\n  <div class=\"container\">\n    <div class=\"navbar-header\">\n        <a class=\"navbar-brand\" style=\"height:70px;font-size:20px\" href=\"https://stanfordnlp.github.io/stanza/\"><img src=\"stanza-logo.png\" height=\"30px\" style=\"display:inline-block; margin-bottom:8px\"/> 1.11.0 (updated October 2025)</a>\n    </div>\n  </div>\n</nav>\n\n<!-- The main content of the page 
-->\n<div class=\"container\">\n  <div class=\"row\">\n\n    <!-- Text area input -->\n    <form id=\"form_annotate\" accept-charset=\"UTF-8\" onsubmit=\"return false;\">\n    <div class=\"col-sm-12\" style=\"margin-bottom: 5px;\">\n      <label for=\"text\" class=\"label\">&mdash; Text to annotate &mdash;</label>\n      <textarea class=\"form-control\" rows=\"2\" id=\"text\" placeholder=\"e.g., The quick brown fox jumped over the lazy dog.\" autofocus maxlength=\"10000\"></textarea>\n    </div>\n\n    <!-- Annotators select -->\n    <div class=\"col-sm-8\">\n      <label for=\"annotators\" class=\"label\">&mdash; Annotations &mdash;</label>\n      <select id=\"annotators\" data-placeholder=\"CoreNLP annotators\"\n              multiple class=\"chosen-select\" title=\"Select CoreNLP annotators\">\n        <option value=\"pos\"            selected > parts-of-speech           </option>\n        <option value=\"upos\"                    > universal parts-of-speech </option>\n        <option value=\"ner\"            selected > named entities            </option>\n        <option value=\"lemma\"          selected > lemmas                    </option>\n        <option value=\"depparse\"       selected > dependency parse          </option>\n        <option value=\"parse\"          selected > constituency parse        </option>\n      </select>\n    </div>\n\n    <div class=\"col-sm-2\">\n        <label for=\"language\" class=\"label\">&mdash; Language &mdash;</label>\n        <select id=\"language\" data-placeholder=\"Language\"\n                class=\"chosen-select\" title=\"Language\">\n                    <option value=\"af\">Afrikaans</option>\n                    <option value=\"grc\">Ancient Greek</option>\n                    <option value=\"ar\">Arabic</option>\n                    <option value=\"hy\">Armenian</option>\n                    <option value=\"eu\">Basque</option>\n                    <option value=\"be\">Belarusian</option>\n                    
<option value=\"bg\">Bulgarian</option>\n                    <option value=\"bxr\">Buryat</option>\n                    <option value=\"ca\">Catalan</option>\n                    <option value=\"zh\">Chinese (simplified)</option>\n                    <option value=\"zh-Hant\">Chinese (traditional)</option>\n                    <option value=\"lzh\">Classical Chinese</option>\n                    <option value=\"cop\">Coptic</option>\n                    <option value=\"hr\">Croatian</option>\n                    <option value=\"cs\">Czech</option>\n                    <option value=\"da\">Danish</option>\n                    <option value=\"nl\">Dutch</option>\n                    <option value=\"en\" selected>English</option>\n                    <option value=\"et\">Estonian</option>\n                    <option value=\"fi\">Finnish</option>\n                    <option value=\"fr\">French</option>\n                    <option value=\"gl\">Galician</option>\n                    <option value=\"de\">German</option>\n                    <option value=\"got\">Gothic</option>\n                    <option value=\"el\">Greek</option>\n                    <option value=\"he\">Hebrew</option>\n                    <option value=\"hi\">Hindi</option>\n                    <option value=\"hu\">Hungarian</option>\n                    <option value=\"id\">Indonesian</option>\n                    <option value=\"ga\">Irish</option>\n                    <option value=\"it\">Italian</option>\n                    <option value=\"ja\">Japanese</option>\n                    <option value=\"kk\">Kazakh</option>\n                    <option value=\"ko\">Korean</option>\n                    <option value=\"kmr\">Kurmanji</option>\n                    <option value=\"la\">Latin</option>\n                    <option value=\"lv\">Latvian</option>\n                    <option value=\"lt\">Lithuanian</option>\n                    <option value=\"olo\">Livvi</option>\n                    
<option value=\"mt\">Maltese</option>\n                    <option value=\"mr\">Marathi</option>\n                    <option value=\"sme\">North Sami</option>\n                    <option value=\"no\">Norwegian (Bokmål)</option>\n                    <option value=\"nn\">Norwegian (Nynorsk)</option>\n                    <option value=\"cu\">Old Church Slavonic</option>\n                    <option value=\"fro\">Old French</option>\n                    <option value=\"orv\">Old Russian</option>\n                    <option value=\"fa\">Persian</option>\n                    <option value=\"pl\">Polish</option>\n                    <option value=\"pt\">Portuguese</option>\n                    <option value=\"ro\">Romanian</option>\n                    <option value=\"ru\">Russian</option>\n                    <option value=\"gd\">Scottish Gaelic</option>\n                    <option value=\"sr\">Serbian</option>\n                    <option value=\"sk\">Slovak</option>\n                    <option value=\"sl\">Slovenian</option>\n                    <option value=\"es\">Spanish</option>\n                    <option value=\"sv\">Swedish</option>\n                    <option value=\"swl\">Swedish Sign Language</option>\n                    <option value=\"ta\">Tamil</option>\n                    <option value=\"te\">Telugu</option>\n                    <option value=\"tr\">Turkish</option>\n                    <option value=\"uk\">Ukrainian</option>\n                    <option value=\"hsb\">Upper Sorbian</option>\n                    <option value=\"ur\">Urdu</option>\n                    <option value=\"ug\">Uyghur</option>\n                    <option value=\"vi\">Vietnamese</option>\n                    <option value=\"wo\">Wolof</option>\n\n        </select>\n    </div>\n\n    <!-- Submit button -->\n    <div class=\"col-sm-2\" style=\"text-align: center; margin-top: 7px; \">\n        <button id=\"submit\" class=\"btn btn-block btn-default\">Submit</button>\n    
</div>\n    </form>\n\n  </div>\n  <div class=\"row\">\n    <!-- A panel for errors to show up in -->\n    <div id=\"errors\" class=\"row\">\n    </div>\n\n    <!-- Loading gif -->\n    <div id=\"loading\" class=\"row\" style=\"display:none\">\n      <img src=\"loading.gif\" height=\"200px\" style=\"margin-left: 200px\"/>\n    </div>\n\n    <!-- Annotation population area -->\n    <div id=\"annotations\" class=\"row\" style=\"display:none\">\n    </div>\n  </div>\n\n\n</div>\n</div>\n\n<!-- The footer of the page -->\n<footer id=\"footer\" class=\"footer\">\n  <div class=\"container\">\n    <p class=\"text-muted\">\n      Visualisation provided using the <a href=\"http://brat.nlplab.org/\">brat visualisation/annotation software</a>.\n    </p>\n  </div>\n</footer>\n\n</body>\n</html>\n"
  },
  {
    "path": "stanza/pipeline/demo/stanza-brat.js",
    "content": "// Takes Stanford CoreNLP JSON output (var data = ... in data.js)\n// and uses brat to render everything.\n\n//var serverAddress = 'http://localhost:5000';\n\n// Load Brat libraries\nvar bratLocation = 'https://nlp.stanford.edu/js/brat/';\nhead.js(\n  // External libraries\n  bratLocation + '/client/lib/jquery.svg.min.js',\n  bratLocation + '/client/lib/jquery.svgdom.min.js',\n\n  // brat helper modules\n  bratLocation + '/client/src/configuration.js',\n  bratLocation + '/client/src/util.js',\n  bratLocation + '/client/src/annotation_log.js',\n  bratLocation + '/client/lib/webfont.js',\n\n  // brat modules\n  bratLocation + '/client/src/dispatcher.js',\n  bratLocation + '/client/src/url_monitor.js',\n  bratLocation + '/client/src/visualizer.js',\n\n  // parse viewer\n  './stanza-parseviewer.js'\n);\n\n// Uses Dagre (https://github.com/cpettitt/dagre) for constinuency parse\n// visualization. It works better than the brat visualization.\nvar useDagre = true;\nvar currentQuery = 'The quick brown fox jumped over the lazy dog.';\nvar currentSentences = '';\nvar currentText = '';\n\n// ----------------------------------------------------------------------------\n// HELPERS\n// ----------------------------------------------------------------------------\n\n/**\n * Add the startsWith function to the String class\n */\nif (typeof String.prototype.startsWith !== 'function') {\n  // see below for better implementation!\n  String.prototype.startsWith = function (str){\n    return this.indexOf(str) === 0;\n  };\n}\n\nfunction isInt(value) {\n  return !isNaN(value) && (function(x) { return (x | 0) === x; })(parseFloat(value))\n}\n\n/**\n * A reverse map of PTB tokens to their original gloss\n */\nvar tokensMap = {\n  '-LRB-': '(',\n  '-RRB-': ')',\n  '-LSB-': '[',\n  '-RSB-': ']',\n  '-LCB-': '{',\n  '-RCB-': '}',\n  '``': '\"',\n  '\\'\\'': '\"',\n};\n\n/**\n * A mapping from part of speech tag to the associated\n * visualization color\n */\nfunction 
posColor(posTag) {\n  if (posTag === null) {\n    return '#E3E3E3';\n  } else if (posTag.startsWith('N')) {\n    return '#A4BCED';\n  } else if (posTag.startsWith('V') || posTag.startsWith('M')) {\n    return '#ADF6A2';\n  } else if (posTag.startsWith('P')) {\n    return '#CCDAF6';\n  } else if (posTag.startsWith('I')) {\n    return '#FFE8BE';\n  } else if (posTag.startsWith('R') || posTag.startsWith('W')) {\n    return '#FFFDA8';\n  } else if (posTag.startsWith('D') || posTag === 'CD') {\n    return '#CCADF6';\n  } else if (posTag.startsWith('J')) {\n    return '#FFFDA8';\n  } else if (posTag.startsWith('T')) {\n    return '#FFE8BE';\n  } else if (posTag.startsWith('E') || posTag.startsWith('S')) {\n    return '#E4CBF6';\n  } else if (posTag.startsWith('CC')) {\n    return '#FFFFFF';\n  } else if (posTag === 'LS' || posTag === 'FW') {\n    return '#FFFFFF';\n  } else {\n    return '#E3E3E3';\n  }\n}\n\n/**\n * A mapping from part of speech tag to the associated\n * visualization color\n */\nfunction uposColor(posTag) {\n  if (posTag === null) {\n    return '#E3E3E3';\n  } else if (posTag === 'NOUN' || posTag === 'PROPN') {\n    return '#A4BCED';\n  } else if (posTag.startsWith('V') || posTag === 'AUX') {\n    return '#ADF6A2';\n  } else if (posTag === 'PART') {\n    return '#CCDAF6';\n  } else if (posTag === 'ADP') {\n    return '#FFE8BE';\n  } else if (posTag === 'ADV' || posTag.startsWith('PRON')) {\n    return '#FFFDA8';\n  } else if (posTag === 'NUM' || posTag === 'DET') {\n    return '#CCADF6';\n  } else if (posTag === 'ADJ') {\n    return '#FFFDA8';\n  } else if (posTag.startsWith('E') || posTag.startsWith('S')) {\n    return '#E4CBF6';\n  } else if (posTag.startsWith('CC')) {\n    return '#FFFFFF';\n  } else if (posTag === 'X' || posTag === 'FW') {\n    return '#FFFFFF';\n  } else {\n    return '#E3E3E3';\n  }\n}\n\n/**\n * A mapping from named entity tag to the associated\n * visualization color\n */\nfunction nerColor(nerTag) {\n  if (nerTag === null) {\n 
   return '#E3E3E3';\n  } else if (nerTag === 'PERSON' || nerTag === 'PER') {\n    return '#FFCCAA';\n  } else if (nerTag === 'ORGANIZATION' || nerTag === 'ORG') {\n    return '#8FB2FF';\n  } else if (nerTag === 'MISC') {\n    return '#F1F447';\n  } else if (nerTag === 'LOCATION' || nerTag == 'LOC') {\n    return '#95DFFF';\n  } else if (nerTag === 'DATE' || nerTag === 'TIME' || nerTag === 'SET') {\n    return '#9AFFE6';\n  } else if (nerTag === 'MONEY') {\n    return '#FFFFFF';\n  } else if (nerTag === 'PERCENT') {\n    return '#FFA22B';\n  } else {\n    return '#E3E3E3';\n  }\n}\n\n\n/**\n * A mapping from sentiment value to the associated\n * visualization color\n */\nfunction sentimentColor(sentiment) {\n  if (sentiment === \"VERY POSITIVE\") {\n    return '#00FF00';\n  } else if (sentiment === \"POSITIVE\") {\n    return '#7FFF00';\n  } else if (sentiment === \"NEUTRAL\") {\n    return '#FFFF00';\n  } else if (sentiment === \"NEGATIVE\") {\n    return '#FF7F00';\n  } else if (sentiment === \"VERY NEGATIVE\") {\n    return '#FF0000';\n  } else {\n    return '#E3E3E3';\n  }\n}\n\n\n/**\n * Get a list of annotators, from the annotator option input.\n */\nfunction annotators() {\n  var annotators = \"tokenize,ssplit\";\n  $('#annotators').find('option:selected').each(function () {\n    annotators += \",\" + $(this).val();\n  });\n  return annotators;\n}\n\n/**\n * Get the input date\n */\nfunction date() {\n  function f(n) {\n    return n < 10 ? 
'0' + n : n;\n  }\n  var date = new Date();\n  var M = date.getMonth() + 1;\n  var D = date.getDate();\n  var Y = date.getFullYear();\n  var h = date.getHours();\n  var m = date.getMinutes();\n  var s = date.getSeconds();\n  return \"\" + Y + \"-\" + f(M) + \"-\" + f(D) + \"T\" + f(h) + ':' + f(m) + ':' + f(s);\n}\n\n\n//-----------------------------------------------------------------------------\n// Constituency parser\n//-----------------------------------------------------------------------------\nfunction ConstituencyParseProcessor() {\n  var parenthesize = function (input, list) {\n    if (list === undefined) {\n      return parenthesize(input, []);\n    } else {\n      var token = input.shift();\n      if (token === undefined) {\n        return list.pop();\n      } else if (token === \"(\") {\n        list.push(parenthesize(input, []));\n        return parenthesize(input, list);\n      } else if (token === \")\") {\n        return list;\n      } else {\n        return parenthesize(input, list.concat(token));\n      }\n    }\n  };\n\n  var toTree = function (list) {\n    if (list.length === 2 && typeof list[1] === 'string') {\n      return {label: list[0], text: list[1], isTerminal: true};\n    } else if (list.length >= 2) {\n      var label = list.shift();\n      var node = {label: label};\n      var rest = list.map(function (x) {\n        var t = toTree(x);\n        if (typeof t === 'object') {\n          t.parent = node;\n        }\n        return t;\n      });\n      node.children = rest;\n      return node;\n    } else {\n      return list;\n    }\n  };\n\n  var indexTree = function (tree, tokens, index) {\n    index = index || 0;\n    if (tree.isTerminal) {\n      tree.token = tokens[index];\n      tree.tokenIndex = index;\n      tree.tokenStart = index;\n      tree.tokenEnd = index + 1;\n      return index + 1;\n    } else if (tree.children) {\n      tree.tokenStart = index;\n      for (var i = 0; i < tree.children.length; i++) {\n        var child = 
tree.children[i];\n        index = indexTree(child, tokens, index);\n      }\n      tree.tokenEnd = index;\n    }\n    return index;\n  };\n\n  var tokenize = function (input) {\n    return input.split('\"')\n      .map(function (x, i) {\n        if (i % 2 === 0) { // not in string\n          return x.replace(/\\(/g, ' ( ')\n            .replace(/\\)/g, ' ) ');\n        } else { // in string\n          return x.replace(/ /g, \"!whitespace!\");\n        }\n      })\n      .join('\"')\n      .trim()\n      .split(/\\s+/)\n      .map(function (x) {\n        return x.replace(/!whitespace!/g, \" \");\n      });\n  };\n\n  var convertParseStringToTree = function (input, tokens) {\n    var p = parenthesize(tokenize(input));\n    if (Array.isArray(p)) {\n      var tree = toTree(p);\n      // Correlate tree with tokens\n      indexTree(tree, tokens);\n      return tree;\n    }\n  };\n\n  this.process = function(annotation) {\n    for (var i = 0; i < annotation.sentences.length; i++) {\n      var s = annotation.sentences[i];\n      if (s.parse) {\n        s.parseTree = convertParseStringToTree(s.parse, s.tokens);\n      }\n    }\n  }\n}\n\n// ----------------------------------------------------------------------------\n// RENDER\n// ----------------------------------------------------------------------------\n\n/**\n * Render a given JSON data structure\n */\nfunction render(data, reverse) {\n  // Tweak arguments\n  if (typeof reverse !== 'boolean') {\n    reverse = false;\n  }\n\n  // Error checks\n  if (typeof data.sentences === 'undefined') { return; }\n\n  /**\n   * Register an entity type (a tag) for Brat\n   */\n  var entityTypesSet = {};\n  var entityTypes = [];\n  function addEntityType(name, type, coarseType) {\n    if (typeof coarseType === \"undefined\") {\n      coarseType = type;\n    }\n    // Don't add duplicates\n    if (entityTypesSet[type]) return;\n    entityTypesSet[type] = true;\n    // Get the color of the entity type\n    color = '#ffccaa';\n    if 
(name === 'POS') {\n      color = posColor(type);\n    } else if (name === 'UPOS') {\n      color = uposColor(type);\n    } else if (name === 'NER') {\n      color = nerColor(coarseType);\n    } else if (name === 'NNER') {\n      color = nerColor(coarseType);\n    } else if (name === 'COREF') {\n      color = '#FFE000';\n    } else if (name === 'ENTITY') {\n      color = posColor('NN');\n    } else if (name === 'RELATION') {\n      color = posColor('VB');\n    } else if (name === 'LEMMA') {\n      color = '#FFFFFF';\n    } else if (name === 'SENTIMENT') {\n      color = sentimentColor(type);\n    } else if (name === 'LINK') {\n      color = '#FFFFFF';\n    } else if (name === 'KBP_ENTITY') {\n      color = '#FFFFFF';\n    }\n    // Register the type\n    entityTypes.push({\n      type: type,\n      labels : [type],\n      bgColor: color,\n      borderColor: 'darken'\n    });\n  }\n\n  /**\n   * Register a relation type (an arc) for Brat\n   */\n  var relationTypesSet = {};\n  var relationTypes = [];\n  function addRelationType(type, symmetricEdge) {\n    // Prevent adding duplicates\n    if (relationTypesSet[type]) return;\n    relationTypesSet[type] = true;\n    // Default arguments\n    if (typeof symmetricEdge === 'undefined') { symmetricEdge = false; }\n    // Add the type\n    relationTypes.push({\n      type: type,\n      labels: [type],\n      dashArray: (symmetricEdge ? '3,3' : undefined),\n      arrowHead: (symmetricEdge ? 
'none' : undefined),\n    });\n  }\n\n  //\n  // Construct text of annotation\n  //\n  currentText = [];  // GLOBAL\n  currentSentences = data.sentences;  // GLOBAL\n  data.sentences.forEach(function(sentence) {\n    for (var i = 0; i < sentence.tokens.length; ++i) {\n      var token = sentence.tokens[i];\n      var word = token.word;\n      if (!(typeof tokensMap[word] === \"undefined\")) {\n        word = tokensMap[word];\n      }\n      if (i > 0) { currentText.push(' '); }\n      token.characterOffsetBegin = currentText.length;\n      for (var j = 0; j < word.length; ++j) {\n        currentText.push(word[j]);\n      }\n      token.characterOffsetEnd = currentText.length;\n    }\n    currentText.push('\\n');\n  });\n  currentText = currentText.join('');\n\n  //\n  // Shared variables\n  // These are what we'll render in BRAT\n  //\n  // (pos)\n  var posEntities = [];\n  // (upos)\n  var uposEntities = [];\n  // (lemma)\n  var lemmaEntities = [];\n  // (ner)\n  var nerEntities = [];\n  var nerEntitiesNormalized = [];\n  // (sentiment)\n  var sentimentEntities = [];\n  // (entitylinking)\n  var linkEntities = [];\n  // (dependencies)\n  var depsRelations = [];\n  var deps2Relations = [];\n  // (openie)\n  var openieEntities = [];\n  var openieEntitiesSet = {};\n  var openieRelations = [];\n  var openieRelationsSet = {};\n  // (kbp)\n  var kbpEntities = [];\n  var kbpEntitiesSet = [];\n  var kbpRelations = [];\n  var kbpRelationsSet = [];\n\n  var cparseEntities = [];\n  var cparseRelations = [];\n\n  //\n  // Loop over sentences.\n  // This fills in the variables above.\n  //\n  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {\n    var sentence = data.sentences[sentI];\n    var index = sentence.index;\n    var tokens = sentence.tokens;\n    var deps = sentence['basicDependencies'];\n    var deps2 = sentence['enhancedPlusPlusDependencies'];\n    var parseTree = sentence['parseTree'];\n\n    // POS tags\n    /**\n     * Generate a POS tagged token id\n  
   */\n    function posID(i) {\n      return 'POS_' + sentI + '_' + i;\n    }\n    var noXPOS = true;\n    if (tokens.length > 0 && typeof tokens[0].pos !== 'undefined' && tokens[0].pos !== null) {\n      noXPOS = false;\n      for (var i = 0; i < tokens.length; i++) {\n        var token = tokens[i];\n        var pos = token.pos;\n        var begin = parseInt(token.characterOffsetBegin);\n        var end = parseInt(token.characterOffsetEnd);\n        addEntityType('POS', pos);\n        posEntities.push([posID(i), pos, [[begin, end]]]);\n      }\n    }\n\n    // Universal POS tags\n    /**\n     * Generate a POS tagged token id\n     */\n    function uposID(i) {\n      return 'UPOS_' + sentI + '_' + i;\n    }\n    if (tokens.length > 0 && typeof tokens[0].upos !== 'undefined') {\n      for (var i = 0; i < tokens.length; i++) {\n        var token = tokens[i];\n        var upos = token.upos;\n        var begin = parseInt(token.characterOffsetBegin);\n        var end = parseInt(token.characterOffsetEnd);\n        addEntityType('UPOS', upos);\n        uposEntities.push([uposID(i), upos, [[begin, end]]]);\n      }\n    }\n\n    // Constituency parse\n    // Carries the same assumption as NER\n    if (parseTree && !useDagre) {\n      var parseEntities = [];\n      var parseRels = [];\n      function processParseTree(tree, index) {\n        tree.visitIndex = index;\n        index++;\n        if (tree.isTerminal) {\n          parseEntities[tree.visitIndex] = uposEntities[tree.tokenIndex];\n          return index;\n        } else if (tree.children) {\n          addEntityType('PARSENODE', tree.label);\n          parseEntities[tree.visitIndex] =\n            ['PARSENODE_' + sentI + '_' + tree.visitIndex, tree.label,\n              [[tokens[tree.tokenStart].characterOffsetBegin, tokens[tree.tokenEnd-1].characterOffsetEnd]]];\n          var parentEnt = parseEntities[tree.visitIndex];\n          for (var i = 0; i < tree.children.length; i++) {\n            var child = 
tree.children[i];\n            index = processParseTree(child, index);\n            var childEnt = parseEntities[child.visitIndex];\n            addRelationType('pc');\n            parseRels.push(['PARSEEDGE_' + sentI + '_' + parseRels.length, 'pc', [['parent', parentEnt[0]], ['child', childEnt[0]]]]);\n          }\n        }\n        return index;\n      }\n      processParseTree(parseTree, 0);\n      cparseEntities = cparseEntities.concat(parseEntities);\n      cparseRelations = cparseRelations.concat(parseRels);\n    }\n\n    // Dependency parsing\n    /**\n     * Process a dependency tree from JSON to Brat relations\n     */\n    function processDeps(name, deps) {\n      var relations = [];\n      // Format: [${ID}, ${TYPE}, [[${ARGNAME}, ${TARGET}], [${ARGNAME}, ${TARGET}]]]\n      for (var i = 0; i < deps.length; i++) {\n        var dep = deps[i];\n        var governor = dep.governor - 1;\n        var dependent = dep.dependent - 1;\n        if (governor == -1) continue;\n        addRelationType(dep.dep);\n        relations.push([name + '_' + sentI + '_' + i, dep.dep, [['governor', uposID(governor)], ['dependent', uposID(dependent)]]]);\n      }\n      return relations;\n    }\n    // Actually add the dependencies\n    if (typeof deps !== 'undefined') {\n      depsRelations = depsRelations.concat(processDeps('dep', deps));\n    }\n    if (typeof deps2 !== 'undefined') {\n      deps2Relations = deps2Relations.concat(processDeps('dep2', deps2));\n    }\n\n    // Lemmas\n    if (tokens.length > 0 && typeof tokens[0].lemma !== 'undefined') {\n      for (var i = 0; i < tokens.length; i++) {\n        var token = tokens[i];\n        var lemma = token.lemma;\n        var begin = parseInt(token.characterOffsetBegin);\n        var end = parseInt(token.characterOffsetEnd);\n        addEntityType('LEMMA', lemma);\n        lemmaEntities.push(['LEMMA_' + sentI + '_' + i, lemma, [[begin, end]]]);\n      }\n    }\n\n    // NER tags\n    // Assumption: 
contiguous occurrence of one non-O is a single entity\n    var noNER = true;\n    if (tokens.some(function(token) { return token.ner; })) {\n      noNER = false;\n      for (var i = 0; i < tokens.length; i++) {\n        var ner = tokens[i].ner || 'O';\n        var normalizedNER = tokens[i].normalizedNER;\n        if (typeof normalizedNER === \"undefined\") {\n          normalizedNER = ner;\n        }\n        if (ner == 'O') continue;\n        var j = i;\n        while (j < tokens.length - 1 && tokens[j+1].ner == ner) j++;\n        addEntityType('NER', ner, ner);\n        nerEntities.push(['NER_' + sentI + '_' + i, ner, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);\n        if (ner != normalizedNER) {\n          addEntityType('NNER', normalizedNER, ner);\n          nerEntities.push(['NNER_' + sentI + '_' + i, normalizedNER, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);\n\n        }\n        i = j;\n      }\n    }\n\n    // Sentiment\n    if (typeof sentence.sentiment !== \"undefined\") {\n      var sentiment = sentence.sentiment.toUpperCase().replace(\"VERY\", \"VERY \");\n      addEntityType('SENTIMENT', sentiment);\n      sentimentEntities.push(['SENTIMENT_' + sentI, sentiment,\n        [[tokens[0].characterOffsetBegin, tokens[tokens.length - 1].characterOffsetEnd]]]);\n    }\n\n    // Entity Links\n    // Carries the same assumption as NER\n    if (tokens.length > 0) {\n      for (var i = 0; i < tokens.length; i++) {\n        var link = tokens[i].entitylink;\n        if (link == 'O' || typeof link === 'undefined') continue;\n        var j = i;\n        while (j < tokens.length - 1 && tokens[j+1].entitylink == link) j++;\n        addEntityType('LINK', link);\n        linkEntities.push(['LINK_' + sentI + '_' + i, link, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);\n        i = j;\n      }\n    }\n\n    // Open IE\n    // Helper Functions\n    function openieID(span) {\n      return 'OPENIEENTITY' + 
'_' + sentI + '_' + span[0] + '_' + span[1];\n    }\n    function addEntity(span, role) {\n      // Don't add duplicate entities\n      if (openieEntitiesSet[[sentI, span, role]]) return;\n      openieEntitiesSet[[sentI, span, role]] = true;\n      // Add the entity\n      openieEntities.push([openieID(span), role,\n        [[tokens[span[0]].characterOffsetBegin,\n          tokens[span[1] - 1].characterOffsetEnd ]] ]);\n    }\n    function addRelation(gov, dep, role) {\n      // Don't add duplicate relations\n      if (openieRelationsSet[[sentI, gov, dep, role]]) return;\n      openieRelationsSet[[sentI, gov, dep, role]] = true;\n      // Add the relation\n      openieRelations.push(['OPENIESUBJREL_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],\n                           role,\n                           [['governor',  openieID(gov)],\n                            ['dependent', openieID(dep)]  ] ]);\n    }\n    // Render OpenIE\n    if (typeof sentence.openie !== 'undefined') {\n      // Register the entities + relations we'll need\n      addEntityType('ENTITY',  'Entity');\n      addEntityType('RELATION', 'Relation');\n      addRelationType('subject');\n      addRelationType('object');\n      // Loop over triples\n      for (var i = 0; i < sentence.openie.length; ++i) {\n        var subjectSpan = sentence.openie[i].subjectSpan;\n        var relationSpan = sentence.openie[i].relationSpan;\n        var objectSpan = sentence.openie[i].objectSpan;\n        if (parseInt(relationSpan[0]) < 0  || parseInt(relationSpan[1]) < 0) {\n          continue;  // This is a phantom relation\n        }\n        // Add the entities\n        addEntity(subjectSpan, 'Entity');\n        addEntity(relationSpan, 'Relation');\n        addEntity(objectSpan, 'Entity');\n        // Add the relations\n        addRelation(relationSpan, subjectSpan, 'subject');\n        addRelation(relationSpan, objectSpan, 
'object');\n      }\n    }  // End OpenIE block\n\n\n    //\n    // KBP\n    //\n    // Helper Functions\n    function kbpEntity(span) {\n      return 'KBPENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];\n    }\n    function addKBPEntity(span, role) {\n      // Don't add duplicate entities\n      if (kbpEntitiesSet[[sentI, span, role]]) return;\n      kbpEntitiesSet[[sentI, span, role]] = true;\n      // Add the entity\n      kbpEntities.push([kbpEntity(span), role,\n        [[tokens[span[0]].characterOffsetBegin,\n          tokens[span[1] - 1].characterOffsetEnd ]] ]);\n    }\n    function addKBPRelation(gov, dep, role) {\n      // Don't add duplicate relations\n      if (kbpRelationsSet[[sentI, gov, dep, role]]) return;\n      kbpRelationsSet[[sentI, gov, dep, role]] = true;\n      // Add the relation\n      kbpRelations.push(['KBPRELATION_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],\n                           role,\n                           [['governor',  kbpEntity(gov)],\n                            ['dependent', kbpEntity(dep)]  ] ]);\n    }\n    if (typeof sentence.kbp !== 'undefined') {\n      // Register the entities + relations we'll need\n      addRelationType('subject');\n      addRelationType('object');\n      // Loop over triples\n      for (var i = 0; i < sentence.kbp.length; ++i) {\n        var subjectSpan = sentence.kbp[i].subjectSpan;\n        var subjectLink = 'Entity';\n        for (var k = subjectSpan[0]; k < subjectSpan[1]; ++k) {\n          if (subjectLink == 'Entity' &&\n              typeof tokens[k] !== 'undefined' &&\n              tokens[k].entitylink != 'O' &&\n              typeof tokens[k].entitylink !== 'undefined') {\n            subjectLink = tokens[k].entitylink\n          }\n        }\n        addEntityType('KBP_ENTITY',  subjectLink);\n        var objectSpan = sentence.kbp[i].objectSpan;\n        var objectLink = 'Entity';\n        for (var k = objectSpan[0]; k < objectSpan[1]; ++k) {\n          
if (objectLink == 'Entity' &&\n              typeof tokens[k] !== 'undefined' &&\n              tokens[k].entitylink != 'O' &&\n              typeof tokens[k].entitylink !== 'undefined') {\n            objectLink = tokens[k].entitylink\n          }\n        }\n        addEntityType('KBP_ENTITY',  objectLink);\n        var relation = sentence.kbp[i].relation;\n        // Add the entities\n        addKBPEntity(subjectSpan, subjectLink);\n        addKBPEntity(objectSpan, objectLink);\n        // Add the relations\n        addKBPRelation(subjectSpan, objectSpan, relation);\n      }\n    }  // End KBP block\n\n  }  // End sentence loop\n\n  //\n  // Coreference\n  //\n  var corefEntities = [];\n  var corefRelations = [];\n  if (typeof data.corefs !== 'undefined') {\n    addRelationType('coref', true);\n    addEntityType('COREF', 'Mention');\n    var clusters = Object.keys(data.corefs);\n    clusters.forEach( function (clusterId) {\n      var chain = data.corefs[clusterId];\n      if (chain.length > 1) {\n        for (var i = 0; i < chain.length; ++i) {\n          var mention = chain[i];\n          var id = 'COREF' + mention.id;\n          var tokens = data.sentences[mention.sentNum - 1].tokens;\n          corefEntities.push([id, 'Mention',\n            [[tokens[mention.startIndex - 1].characterOffsetBegin,\n              tokens[mention.endIndex - 2].characterOffsetEnd      ]] ]);\n          if (i > 0) {\n            var lastId = 'COREF' + chain[i - 1].id;\n            corefRelations.push(['COREF' + chain[i-1].id + '_' + chain[i].id,\n                                 'coref',\n                                 [['governor', lastId],\n                                  ['dependent', id]    ] ]);\n          }\n        }\n      }\n    });\n  }  // End coreference block\n\n  //\n  // Actually render the elements\n  //\n\n  /**\n   * Helper function to render a given set of entities / relations\n   * to a Div, if it 
exists.\n   */\n  function embed(container, entities, relations, reverse) {\n    var text = currentText;\n    if (reverse) {\n      var length = currentText.length;\n      for (var i = 0; i < entities.length; ++i) {\n        var offsets = entities[i][2][0];\n        var tmp = length - offsets[0];\n        offsets[0] = length - offsets[1];\n        offsets[1] = tmp;\n      }\n      text = text.split(\"\").reverse().join(\"\");\n    }\n    if ($('#' + container).length > 0) {\n      Util.embed(container,\n                 {entity_types: entityTypes, relation_types: relationTypes},\n                 {text: text, entities: entities, relations: relations}\n                );\n    }\n  }\n\n  function reportna(container, text) {\n    $('#' + container).text(text);\n  }\n\n  // Render each annotation\n  head.ready(function() {\n    if (!noXPOS) {\n      embed('pos', posEntities);\n    } else {\n      reportna('pos', 'XPOS is not available for this language at this time.')\n    }\n    embed('upos', uposEntities);\n    embed('lemma', lemmaEntities);\n    if (!noNER) {\n      embed('ner', nerEntities);\n    } else {\n      reportna('ner', 'NER is not available for this language at this time.')\n    }\n    embed('entities', linkEntities);\n    if (!useDagre) {\n      embed('parse', cparseEntities, cparseRelations);\n    }\n    embed('deps', uposEntities, depsRelations);\n    embed('deps2', posEntities, deps2Relations);\n    embed('coref', corefEntities, corefRelations);\n    embed('openie', openieEntities, openieRelations);\n    embed('kbp',    kbpEntities, kbpRelations);\n    embed('sentiment', sentimentEntities);\n\n    // Constituency parse\n    // Uses d3 and dagre-d3 (not brat)\n    if ($('#parse').length > 0 && useDagre) {\n      var parseViewer = new ParseViewer({ selector: '#parse' });\n      parseViewer.showAnnotation(data);\n      $('#parse').addClass('svg').css('display', 'block');\n    }\n  });\n\n}  // End render function\n\n\n/**\n * Render a TokensRegex 
response\n */\nfunction renderTokensregex(data) {\n  /**\n   * Register an entity type (a tag) for Brat\n   */\n  var entityTypesSet = {};\n  var entityTypes = [];\n  function addEntityType(type, color) {\n    // Don't add duplicates\n    if (entityTypesSet[type]) return;\n    entityTypesSet[type] = true;\n    // Set the color\n    if (typeof color === 'undefined') {\n      color = '#ADF6A2';\n    }\n    // Register the type\n    entityTypes.push({\n      type: type,\n      labels : [type],\n      bgColor: color,\n      borderColor: 'darken'\n    });\n  }\n\n  var entities = [];\n  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {\n    var tokens = currentSentences[sentI].tokens;\n    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {\n      var match = data.sentences[sentI][matchI];\n      // Add groups\n      for (groupName in match) {\n        if (groupName.startsWith(\"$\") || isInt(groupName)) {\n          addEntityType(groupName, '#FFFDA8');\n          var begin = parseInt(tokens[match[groupName].begin].characterOffsetBegin);\n          var end = parseInt(tokens[match[groupName].end - 1].characterOffsetEnd);\n          entities.push(['TOK_' + sentI + '_' + matchI + '_' + groupName,\n                              groupName,\n                              [[begin, end]]]);\n        }\n      }\n      // Add match\n      addEntityType('match', '#ADF6A2');\n      var begin = parseInt(tokens[match.begin].characterOffsetBegin);\n      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);\n      entities.push(['TOK_' + sentI + '_' + matchI + '_match',\n                          'match',\n                          [[begin, end]]]);\n    }\n  }\n\n  Util.embed('tokensregex',\n         {entity_types: entityTypes, relation_types: []},\n         {text: currentText, entities: entities, relations: []}\n        );\n}  // END renderTokensregex()\n\n\n/**\n * Render a Semgrex response\n */\nfunction renderSemgrex(data) {\n  /**\n   * 
Register an entity type (a tag) for Brat\n   */\n  var entityTypesSet = {};\n  var entityTypes = [];\n  function addEntityType(type, color) {\n    // Don't add duplicates\n    if (entityTypesSet[type]) return;\n    entityTypesSet[type] = true;\n    // Set the color\n    if (typeof color === 'undefined') {\n      color = '#ADF6A2';\n    }\n    // Register the type\n    entityTypes.push({\n      type: type,\n      labels : [type],\n      bgColor: color,\n      borderColor: 'darken'\n    });\n  }\n\n\n  relationTypes = [{\n    type: 'semgrex',\n    labels: ['-'],\n    dashArray: '3,3',\n    arrowHead: 'none',\n  }];\n\n  var entities = [];\n  var relations = [];\n\n  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {\n    var tokens = currentSentences[sentI].tokens;\n    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {\n      var match = data.sentences[sentI][matchI];\n      // Add match\n      addEntityType('match', '#ADF6A2');\n      var begin = parseInt(tokens[match.begin].characterOffsetBegin);\n      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);\n      entities.push(['SEM_' + sentI + '_' + matchI + '_match',\n                          'match',\n                          [[begin, end]]]);\n\n      // Add groups\n      for (groupName in match) {\n        if (groupName.startsWith(\"$\") || isInt(groupName)) {\n          // (add node)\n          group = match[groupName];\n          groupName = groupName.substring(1);\n          addEntityType(groupName, '#FFFDA8');\n          var begin = parseInt(tokens[group.begin].characterOffsetBegin);\n          var end = parseInt(tokens[group.end - 1].characterOffsetEnd);\n          entities.push(['SEM_' + sentI + '_' + matchI + '_' + groupName,\n                              groupName,\n                              [[begin, end]]]);\n\n          // (add relation)\n          relations.push(['SEMGREX_' + sentI + '_' + matchI + '_' + groupName,\n                          
'semgrex',\n                          [['governor', 'SEM_' + sentI + '_' + matchI + '_match'],\n                           ['dependent', 'SEM_' + sentI + '_' + matchI + '_' + groupName] ] ]);\n        }\n      }\n    }\n  }\n\n  Util.embed('semgrex',\n         {entity_types: entityTypes, relation_types: relationTypes},\n         {text: currentText, entities: entities, relations: relations}\n        );\n}  // END renderSemgrex\n\n/**\n * Render a Tregex response\n */\nfunction renderTregex(data) {\n  $('#tregex').empty();\n  $('#tregex').append('<pre>' + JSON.stringify(data, null, 4) + '</pre>');\n}  // END renderTregex\n\n// ----------------------------------------------------------------------------\n// MAIN\n// ----------------------------------------------------------------------------\n\n/**\n * MAIN()\n *\n * The entry point of the page\n */\n$(document).ready(function() {\n  // Some initial styling\n  $('.chosen-select').chosen();\n  $('.chosen-container').css('width', '100%');\n\n\n  // Language-specific changes\n  $('#language').on('change', function() {\n    $('#text').attr('dir', '');\n    if ($('#language').val() === 'ar' ||\n        $('#language').val() === 'fa' ||\n        $('#language').val() === 'he' ||\n        $('#language').val() === 'ur') {\n      $('#text').attr('dir', 'rtl');\n    }\n    if ($('#language').val() === 'ar') {\n      $('#text').attr('placeholder', 'على سبيل المثال، قفز الثعلب البني السريع فوق الكلب الكسول.');\n    } else if ($('#language').val() === 'en') {\n      $('#text').attr('placeholder', 'e.g., The quick brown fox jumped over the lazy dog.');\n    } else if ($('#language').val() === 'zh') {\n      $('#text').attr('placeholder', '例如，快速的棕色狐狸跳过了懒惰的狗。');\n    } else if ($('#language').val() === 'zh-Hant') {\n      $('#text').attr('placeholder', '例如，快速的棕色狐狸跳過了懶惰的狗。');\n    } else if ($('#language').val() === 'fr') {\n      $('#text').attr('placeholder', 'Par exemple, le renard brun rapide a sauté sur le chien paresseux.');\n    
} else if ($('#language').val() === 'de') {\n      $('#text').attr('placeholder', 'Z. B. sprang der schnelle braune Fuchs über den faulen Hund.');\n    } else if ($('#language').val() === 'es') {\n      $('#text').attr('placeholder', 'Por ejemplo, el rápido zorro marrón saltó sobre el perro perezoso.');\n    } else if ($('#language').val() === 'ur') {\n      $('#text').attr('placeholder', 'میرا نام علی ہے');\n    } else {\n      $('#text').attr('placeholder', 'Unknown language for placeholder query: ' + $('#language').val());\n    }\n  });\n\n  // Submit on shift-enter\n  $('#text').keydown(function (event) {\n    if (event.keyCode == 13) {\n      if(event.shiftKey){\n        event.preventDefault();  // don't register the enter key when pressed\n        return false;\n      }\n    }\n  });\n  $('#text').keyup(function (event) {\n    if (event.keyCode == 13) {\n      if(event.shiftKey){\n        $('#submit').click();  // submit the form when the enter key is released\n        event.stopPropagation();\n        return false;\n      }\n    }\n  });\n\n  // Submit on clicking the 'submit' button\n  $('#submit').click(function() {\n    // Get the text to annotate\n    currentQuery = $('#text').val();\n    if (currentQuery.trim() == '') {\n      if ($('#language').val() === 'ar') {\n        currentQuery = 'قفز الثعلب البني السريع فوق الكلب الكسول.';\n      } else if ($('#language').val() === 'en') {\n        currentQuery = 'The quick brown fox jumped over the lazy dog.';\n      } else if ($('#language').val() === 'zh') {\n        currentQuery = '快速的棕色狐狸跳过了懒惰的狗。';\n      } else if ($('#language').val() === 'zh-Hant') {\n        currentQuery = '快速的棕色狐狸跳過了懶惰的狗。';\n      } else if ($('#language').val() === 'fr') {\n        currentQuery = 'Le renard brun rapide a sauté sur le chien paresseux.';\n      } else if ($('#language').val() === 'de') {\n        currentQuery = 'Sprang der schnelle braune Fuchs über den faulen Hund.';\n      } else if ($('#language').val() === 'es') {\n 
       currentQuery = 'El rápido zorro marrón saltó sobre el perro perezoso.';\n      } else if ($('#language').val() === 'ur') {\n        currentQuery = 'میرا نام علی ہے';\n      } else {\n        currentQuery = 'Unknown language for default query: ' + $('#language').val();\n      }\n      $('#text').val(currentQuery);\n    }\n    // Update the UI\n    $('#submit').prop('disabled', true);\n    $('#annotations').hide();\n    $('#patterns_row').hide();\n    $('#loading').show();\n\n    // Run query\n    $.ajax({\n      type: 'POST',\n      url: serverAddress + '?properties=' + encodeURIComponent(\n        '{\"annotators\": \"' + annotators() + '\", \"date\": \"' + date() + '\"}') +\n        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),\n      data: encodeURIComponent(currentQuery), //jQuery doesn't automatically URI encode strings\n      dataType: 'json',\n      contentType: \"application/x-www-form-urlencoded;charset=UTF-8\",\n      responseType: \"application/json\",\n      success: function(data) {\n        $('#submit').prop('disabled', false);\n        if (typeof data === 'undefined' || data.sentences == undefined) {\n          alert(\"Failed to reach server!\");\n        } else {\n          // Process constituency parse\n          var constituencyParseProcessor = new ConstituencyParseProcessor();\n          constituencyParseProcessor.process(data);\n          // Empty divs\n          $('#annotations').empty();\n          // Re-render divs\n          function createAnnotationDiv(id, annotator, selector, label) {\n            // (make sure we requested that element)\n            if (annotators().split(\",\").indexOf(annotator) < 0) {\n              return;\n            }\n            // (make sure the data contains that element)\n            ok = false;\n            if (typeof data[selector] !== 'undefined') {\n              ok = true;\n            } else if (typeof data.sentences !== 'undefined' && data.sentences.length > 0) {\n              
if (typeof data.sentences[0][selector] !== 'undefined') {\n                ok = true;\n              } else if (typeof data.sentences[0].tokens != 'undefined' && data.sentences[0].tokens.length > 0) {\n                // (make sure the annotator select is in at least one of the tokens of any sentence)\n                ok = data.sentences.some(function(sentence) {\n                  return sentence.tokens.some(function(token) {\n                    return typeof token[selector] !== 'undefined';\n                  });\n                });\n              }\n            }\n            // (render the element)\n            if (ok) {\n              $('#annotations').append('<h4 class=\"red\">' + label + ':</h4> <div id=\"' + id + '\"></div>');\n            }\n          }\n          // (create the divs)\n          //                  div id      annotator     field_in_data                          label\n          createAnnotationDiv('pos',      'pos',        'pos',                                 'Part-of-Speech (XPOS)'          );\n          createAnnotationDiv('upos',     'upos',       'upos',                                'Universal Part-of-Speech');\n          createAnnotationDiv('lemma',    'lemma',      'lemma',                               'Lemmas'                  );\n          createAnnotationDiv('ner',      'ner',        'ner',                                 'Named Entity Recognition');\n          createAnnotationDiv('deps',     'depparse',   'basicDependencies',                   'Universal Dependencies'      );\n          createAnnotationDiv('parse',    'parse',      'parseTree',                           'Constituency Parse'      );\n          //createAnnotationDiv('deps2',    'depparse',   'enhancedPlusPlusDependencies',        'Enhanced++ Dependencies' );\n          //createAnnotationDiv('openie',   'openie',     'openie',                              'Open IE'                 );\n          //createAnnotationDiv('coref',    'coref',      'corefs',        
                      'Coreference'             );\n          //createAnnotationDiv('entities', 'entitylink', 'entitylink',                          'Wikidict Entities'       );\n          //createAnnotationDiv('kbp',      'kbp',        'kbp',                                 'KBP Relations'           );\n          //createAnnotationDiv('sentiment','sentiment',  'sentiment',                           'Sentiment'               );\n          // Update UI\n          $('#loading').hide();\n          $('.corenlp_error').remove();  // Clear error messages\n          $('#annotations').show();\n          // Render\n          var reverse = ($('#language').val() === 'ar' || $('#language').val() === 'fa' || $('#language').val() === 'he' || $('#language').val() === 'ur');\n          render(data, reverse);\n          // Render patterns\n          //$('#annotations').append('<h4 class=\"red\" style=\"margin-top: 4ex;\">CoreNLP Tools:</h4>');  // TODO(gabor) a strange place to add this header to\n          //$('#patterns_row').show();\n        }\n      },\n      error: function(data) {\n        DATA = data;\n        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('corenlp_error').attr('role', 'alert')\n        var button = $('<button type=\"button\" class=\"close\" data-dismiss=\"alert\" aria-label=\"Close\"><span aria-hidden=\"true\">&times;</span></button>');\n        var message = $('<span/>').text(data.responseText);\n        button.appendTo(alertDiv);\n        message.appendTo(alertDiv);\n        $('#loading').hide();\n        alertDiv.appendTo($('#errors'));\n        $('#submit').prop('disabled', false);\n      }\n    });\n    event.preventDefault();\n    event.stopPropagation();\n    return false;\n  });\n\n\n  // Support passing parameters on page launch, via window.location.hash parameters.\n  // Example: http://localhost:9000/#text=foo%20bar&annotators=pos,lemma,ner\n  (function() {\n    var rawParams = 
window.location.hash.slice(1).split(\"&\");\n    var params = {};\n    rawParams.forEach(function(paramKV) {\n      paramKV = paramKV.split(\"=\");\n      if (paramKV.length === 2) {\n        var key   = paramKV[0];\n        var value = paramKV[1];\n        params[key] = value;\n      }\n    });\n    if (params.text) {\n      var text = decodeURIComponent(params.text);\n      $('#text').val(text);\n    }\n    if (params.annotators) {\n      var annotators = params.annotators.split(\",\");\n      // De-select everything\n      $('#annotators').find('option').each(function() {\n        $(this).prop('selected', false);\n      });\n      // Select the specified ones.\n      annotators.forEach(function(a) {\n        $('#annotators').find('option[value=\"'+a+'\"]').prop('selected', true);\n      });\n      // Refresh Chosen\n      $('#annotators').trigger('chosen:updated');\n    }\n    if (params.text || params.annotators) {\n      // Finally, let's auto-submit.\n      $('#submit').click();\n    }\n  })();\n\n\n  $('#form_tokensregex').submit( function (e) {\n    // Don't actually submit the form\n    e.preventDefault();\n    // Get text\n    if ($('#tokensregex_search').val().trim() == '') {\n      $('#tokensregex_search').val('(?$foxtype [{pos:JJ}]+ ) fox');\n    }\n    var pattern = $('#tokensregex_search').val();\n    // Remove existing annotation\n    $('#tokensregex').remove();\n    // Make ajax call\n    $.ajax({\n      type: 'POST',\n      url: serverAddress + '/tokensregex?pattern=' + encodeURIComponent(\n        pattern.replace(\"&\", \"\\\\&\").replace('+', '\\\\+')) +\n        '&properties=' + encodeURIComponent(\n        '{\"annotators\": \"' + annotators() + '\", \"date\": \"' + date() + '\"}') +\n        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),\n      data: encodeURIComponent(currentQuery),\n      success: function(data) {\n        $('.tokensregex_error').remove();  // Clear error messages\n        $('<div id=\"tokensregex\" 
class=\"pattern_brat\"/>').appendTo($('#div_tokensregex'));\n        renderTokensregex(data);\n      },\n      error: function(data) {\n        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tokensregex_error').attr('role', 'alert')\n        var button = $('<button type=\"button\" class=\"close\" data-dismiss=\"alert\" aria-label=\"Close\"><span aria-hidden=\"true\">&times;</span></button>');\n        var message = $('<span/>').text(data.responseText);\n        button.appendTo(alertDiv);\n        message.appendTo(alertDiv);\n        alertDiv.appendTo($('#div_tokensregex'));\n      }\n    });\n  });\n\n\n  $('#form_semgrex').submit( function (e) {\n    // Don't actually submit the form\n    e.preventDefault();\n    // Get text\n    if ($('#semgrex_search').val().trim() == '') {\n      $('#semgrex_search').val('{pos:/VB.*/} >nsubj {}=subject >/nmod:.*/ {}=prep_phrase');\n    }\n    var pattern = $('#semgrex_search').val();\n    // Remove existing annotation\n    $('#semgrex').remove();\n    // Add missing required annotators\n    var requiredAnnotators = annotators().split(',');\n    if (requiredAnnotators.indexOf('depparse') < 0) {\n      requiredAnnotators.push('depparse');\n    }\n    // Make ajax call\n    $.ajax({\n      type: 'POST',\n      url: serverAddress + '/semgrex?pattern=' + encodeURIComponent(\n        pattern.replace(\"&\", \"\\\\&\").replace('+', '\\\\+')) +\n        '&properties=' + encodeURIComponent(\n        '{\"annotators\": \"' + requiredAnnotators.join(',') + '\", \"date\": \"' + date() + '\"}') +\n        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),\n      data: encodeURIComponent(currentQuery),\n      success: function(data) {\n        $('.semgrex_error').remove();  // Clear error messages\n        $('<div id=\"semgrex\" class=\"pattern_brat\"/>').appendTo($('#div_semgrex'));\n        renderSemgrex(data);\n      },\n      error: function(data) {\n        var 
alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('semgrex_error').attr('role', 'alert')\n        var button = $('<button type=\"button\" class=\"close\" data-dismiss=\"alert\" aria-label=\"Close\"><span aria-hidden=\"true\">&times;</span></button>');\n        var message = $('<span/>').text(data.responseText);\n        button.appendTo(alertDiv);\n        message.appendTo(alertDiv);\n        alertDiv.appendTo($('#div_semgrex'));\n      }\n    });\n  });\n\n  $('#form_tregex').submit( function (e) {\n    // Don't actually submit the form\n    e.preventDefault();\n    // Get text\n    if ($('#tregex_search').val().trim() == '') {\n      $('#tregex_search').val('NP < NN=animal');\n    }\n    var pattern = $('#tregex_search').val();\n    // Remove existing annotation\n    $('#tregex').remove();\n    // Add missing required annotators\n    var requiredAnnotators = annotators().split(',');\n    if (requiredAnnotators.indexOf('parse') < 0) {\n      requiredAnnotators.push('parse');\n    }\n    // Make ajax call\n    $.ajax({\n      type: 'POST',\n      url: serverAddress + '/tregex?pattern=' + encodeURIComponent(\n        pattern.replace(\"&\", \"\\\\&\").replace('+', '\\\\+')) +\n        '&properties=' + encodeURIComponent(\n        '{\"annotators\": \"' + requiredAnnotators.join(',') + '\", \"date\": \"' + date() + '\"}') +\n        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),\n      data: encodeURIComponent(currentQuery),\n      success: function(data) {\n        $('.tregex_error').remove();  // Clear error messages\n        $('<div id=\"tregex\" class=\"pattern_brat\"/>').appendTo($('#div_tregex'));\n        renderTregex(data);\n      },\n      error: function(data) {\n        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tregex_error').attr('role', 'alert')\n        var button = $('<button type=\"button\" class=\"close\" 
data-dismiss=\"alert\" aria-label=\"Close\"><span aria-hidden=\"true\">&times;</span></button>');\n        var message = $('<span/>').text(data.responseText);\n        button.appendTo(alertDiv);\n        message.appendTo(alertDiv);\n        alertDiv.appendTo($('#div_tregex'));\n      }\n    });\n  });\n\n});\n"
  },
  {
    "path": "stanza/pipeline/demo/stanza-parseviewer.js",
    "content": "//'use strict';\n\n//d3 || require('d3');\n//var dagreD3 = require('dagre-d3');\n//var jquery = require('jquery');\n//var $ = jquery;\n\nvar ParseViewer = function(params) {\n  // Container in which the scene template is displayed\n  this.selector = params.selector;\n  this.container = $(this.selector);\n  this.fitToGraph = true;\n  this.onClickNodeCallback = params.onClickNodeCallback;\n  this.onHoverNodeCallback = params.onHoverNodeCallback;\n  this.init();\n  return this;\n};\n\nParseViewer.MIN_WIDTH = 100;\nParseViewer.MIN_HEIGHT = 100;\n\nParseViewer.prototype.constructor = ParseViewer;\n\nParseViewer.prototype.getAutoWidth = function () {\n  return Math.max(ParseViewer.MIN_WIDTH, this.container.width());\n};\n\nParseViewer.prototype.getAutoHeight = function () {\n  return Math.max(ParseViewer.MIN_HEIGHT, this.container.height() - 20);\n};\n\nParseViewer.prototype.init = function () {\n  var canvasWidth = this.getAutoWidth();\n  var canvasHeight = this.getAutoHeight();\n  this.parseElem = d3.select(this.selector)\n    .append('svg')\n    .attr({'width': canvasWidth, 'height': canvasHeight})\n    .style({'width': canvasWidth, 'height': canvasHeight});\n  console.log(this.parseElem);\n  this.graph = null;\n  this.graphRendered = false;\n\n  this.controls = $('<div class=\"text\"></div>');\n  this.container.append(this.controls);\n};\n\nvar GraphBuilder = function(roots) {\n  // Create the input graph\n  this.graph = new dagreD3.graphlib.Graph()\n    .setGraph({})\n    .setDefaultEdgeLabel(function () {\n      return {};\n    });\n  this.visitIndex = 0;\n  //console.log('building graph', roots);\n  for (var i = 0; i < roots.length; i++) {\n    this.build(roots[i]);\n  }\n};\n\nGraphBuilder.prototype.build = function(node) {\n  console.log(node);\n  // Track my visit index\n  this.visitIndex++;\n  node.visitIndex = this.visitIndex;\n\n  // Add a node\n  var nodeData = node;  // TODO: replace with semantic data\n  var nodeLabel = node.label;\n  var 
nodeIndex = node.visitIndex;\n  var nodeClass = 'parse-RULE';\n\n  this.graph.setNode(nodeIndex, { label: nodeLabel, class: nodeClass, data: nodeData });\n  if (node.parent) {\n    this.graph.setEdge(node.parent.visitIndex, nodeIndex, {\n      class: 'parse-EDGE'\n    });\n  }\n\n  if (node.isTerminal) {\n    this.visitIndex++;\n    nodeIndex = this.visitIndex;\n    nodeLabel = node.text;\n    nodeClass = 'parse-TERMINAL';\n\n    this.graph.setNode(nodeIndex, { label: nodeLabel, class: nodeClass, data: nodeData });\n    this.graph.setEdge(node.visitIndex, nodeIndex, {\n      class: 'parse-EDGE'\n    });\n  } else if (node.children) {\n    for (var i = 0; i < node.children.length; i++) {\n      this.build(node.children[i]);\n    }\n  }\n};\n\nParseViewer.prototype.updateGraphPosition = function (svg, g, minWidth, minHeight) {\n  if (this.fitToGraph) {\n    minWidth = g.graph().width;\n    minHeight = this.getAutoHeight();\n  }\n  adjustGraphPositioning(svg, g, minWidth, minHeight);\n};\n\nfunction adjustGraphPositioning(svg, g, minWidth, minHeight) {\n  // Resize svg\n  var newWidth = Math.max(minWidth, g.graph().width);\n  var newHeight = Math.max(minHeight, g.graph().height + 40);\n  svg.attr({'width': newWidth, 'height': newHeight});\n  svg.style({'width': newWidth, 'height': newHeight});\n  // Center the graph\n  var svgGroup = svg.select('g');\n  var xCenterOffset = (svg.attr('width') - g.graph().width) / 2;\n  svgGroup.attr('transform', 'translate(' + xCenterOffset + ', 20)');\n  svg.attr('height', g.graph().height + 40);\n  svg.style('height', g.graph().height + 40);\n}\n\nParseViewer.prototype.renderGraph = function (svg, g, parse) {\n  // Create the renderer\n  var render = new dagreD3.render();\n  // Run the renderer. 
This is what draws the final graph.\n  var svgGroup = svg.select('g');\n  render(svgGroup, g);\n\n  var scope = this;\n  var nodes = svgGroup.selectAll('g.node');\n  nodes.on('click',\n    function (d) {\n      var v = d;\n      var node = g.node(v);\n      if (scope.onClickNodeCallback) {\n        scope.onClickNodeCallback(node.data);\n      }\n      console.log(g.node(v));\n    }\n  );\n\n  nodes.on('mouseover',\n    function (d) {\n      var v = d;\n      var node = g.node(v);\n      if (scope.onHoverNodeCallback) {\n        scope.onHoverNodeCallback(node.data);\n      }\n    }\n  );\n\n  this.updateGraphPosition(svg, g, svg.attr('width'), svg.attr('height'));\n  this.graphRendered = true;\n};\n\nParseViewer.prototype.showParse = function (root) {\n  this.showParses([root]);\n};\n\nParseViewer.prototype.showParses = function (roots) {\n  // Take parse and create a graph\n  var gb = new GraphBuilder(roots);\n  var g = gb.graph;\n\n  g.nodes().forEach(function (v) {\n    var node = g.node(v);\n    // Round the corners of the nodes\n    node.rx = node.ry = 5;\n  });\n\n  var svg = this.parseElem;\n  svg.selectAll('*').remove();\n  var svgGroup = svg.append('g');\n  this.graph = g;\n  this.parse = roots;\n  if (this.container.is(':visible')) {\n    if (roots.length > 0) {\n      this.renderGraph(svg, this.graph, this.parse);\n    }\n  } else {\n    this.graphRendered = false;\n  }\n};\n\nParseViewer.prototype.showAnnotation = function (annotation) {\n  var parses = [];\n  for (var i = 0; i < annotation.sentences.length; i++) {\n    var s = annotation.sentences[i];\n    if (s && s.parseTree) {\n      parses.push(s.parseTree);\n    }\n  }\n  this.showParses(parses);\n};\n\nParseViewer.prototype.onResize = function () {\n  var canvasWidth = this.getAutoWidth();\n  var canvasHeight = this.getAutoHeight();\n  var svg = this.parseElem;\n\n  // Center the graph\n  var svgGroup = svg.select('g');\n  if (svgGroup && this.graph) {\n    if (!this.graphRendered) {\n      
svg.attr({'width': canvasWidth, 'height': canvasHeight});\n      svg.style({'width': canvasWidth, 'height': canvasHeight});\n      this.renderGraph(svg, this.graph, this.parse);\n    } else {\n      this.updateGraphPosition(svg, this.graph, canvasWidth, canvasHeight);\n    }\n  } else {\n    svg.attr({'width': canvasWidth, 'height': canvasHeight});\n    svg.style({'width': canvasWidth, 'height': canvasHeight});\n  }\n};\n\n// Exports\n//module.exports = ParseViewer;\n"
  },
  {
    "path": "stanza/pipeline/depparse_processor.py",
    "content": "\"\"\"\nProcessor for performing dependency parsing\n\"\"\"\n\nimport torch\n\nfrom stanza.models.common import doc\nfrom stanza.models.common.utils import unsort\nfrom stanza.models.common.vocab import VOCAB_PREFIX\nfrom stanza.models.depparse.data import DataLoader\nfrom stanza.models.depparse.trainer import Trainer\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\n# these imports trigger the \"register_variant\" decorations\nfrom stanza.pipeline.external.corenlp_converter_depparse import ConverterDepparse\n\nDEFAULT_SEPARATE_BATCH=150\n\n@register_processor(name=DEPPARSE)\nclass DepparseProcessor(UDProcessor):\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([DEPPARSE])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE, POS, LEMMA])\n\n    def __init__(self, config, pipeline, device):\n        self._pretagged = None\n        super().__init__(config, pipeline, device)\n\n    def _set_up_requires(self):\n        self._pretagged = self._config.get('pretagged')\n        if self._pretagged:\n            self._requires = set()\n        else:\n            self._requires = self.__class__.REQUIRES_DEFAULT\n\n    def _set_up_model(self, config, pipeline, device):\n        self._trainer = config.get('trainer')\n        if self._trainer is not None:\n            return\n\n        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None\n        args = {'charlm_forward_file': config.get('forward_charlm_path', None),\n                'charlm_backward_file': config.get('backward_charlm_path', None)}\n        self._trainer = Trainer(args=args, pretrain=self.pretrain, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)\n\n    def get_known_relations(self):\n        \"\"\"\n        Return a list of relations 
which this processor can produce\n        \"\"\"\n        keys = [k for k in self.vocab['deprel']._unit2id.keys() if k not in VOCAB_PREFIX]\n        return keys\n\n    def process(self, document):\n        if hasattr(self, '_variant'):\n            return self._variant.process(document)\n\n        if any(word.upos is None and word.xpos is None for sentence in document.sentences for word in sentence.words):\n            raise ValueError(\"POS not run before depparse!\")\n        try:\n            batch = DataLoader(document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True,\n                               sort_during_eval=self.config.get('sort_during_eval', True),\n                               min_length_to_batch_separately=self.config.get('min_length_to_batch_separately', DEFAULT_SEPARATE_BATCH))\n            with torch.no_grad():\n                preds = []\n                for i, b in enumerate(batch):\n                    preds += self.trainer.predict(b)\n            if batch.data_orig_idx is not None:\n                preds = unsort(preds, batch.data_orig_idx)\n            batch.doc.set((doc.HEAD, doc.DEPREL), [y for x in preds for y in x])\n            # build dependencies based on predictions\n            for sentence in batch.doc.sentences:\n                sentence.build_dependencies()\n            return batch.doc\n        except RuntimeError as e:\n            if str(e).startswith(\"CUDA out of memory. Tried to allocate\"):\n                new_message = str(e) + \" ... You may be able to compensate for this by separating long sentences into their own batch with a parameter such as depparse_min_length_to_batch_separately=150 or by limiting the overall batch size with depparse_batch_size=400.\"\n                raise RuntimeError(new_message) from e\n            else:\n                raise\n"
  },
  {
    "path": "stanza/pipeline/external/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/pipeline/external/corenlp_converter_depparse.py",
    "content": "\"\"\"\nA depparse processor which converts constituency trees using CoreNLP\n\"\"\"\n\nfrom stanza.pipeline._constants import TOKENIZE, CONSTITUENCY, DEPPARSE\nfrom stanza.pipeline.processor import ProcessorVariant, register_processor_variant\nfrom stanza.server.dependency_converter import DependencyConverter\n\n@register_processor_variant(DEPPARSE, 'converter')\nclass ConverterDepparse(ProcessorVariant):\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE, CONSTITUENCY])\n\n    def __init__(self, config):\n        if config['lang'] != 'en':\n            raise ValueError(\"Constituency to dependency converter only works for English\")\n\n        # TODO: get classpath from config\n        # TODO: close this when finished?\n        #   a more involved approach would be to turn the Pipeline into\n        #   a context with __enter__ and __exit__\n        #   __exit__ would try to free all resources, although some\n        #   might linger such as GPU allocations\n        #   maybe it isn't worth even trying to clean things up on account of that\n        self.converter = DependencyConverter(classpath=\"$CLASSPATH\")\n        self.converter.open_pipe()\n\n    def process(self, document):\n        return self.converter.process(document)\n"
  },
  {
    "path": "stanza/pipeline/external/jieba.py",
    "content": "\"\"\"\nProcessors related to Jieba in the pipeline.\n\"\"\"\n\nimport re\nimport warnings\n\nfrom stanza.models.common import doc\nfrom stanza.pipeline._constants import TOKENIZE\nfrom stanza.pipeline.processor import ProcessorVariant, register_processor_variant\n\ndef check_jieba():\n    \"\"\"\n    Import necessary components from Jieba to perform tokenization.\n    \"\"\"\n    try:\n        import jieba\n    except ImportError:\n        raise ImportError(\n            \"Jieba is used but not installed on your machine. Go to https://pypi.org/project/jieba/ for installation instructions.\"\n        )\n    return True\n\n@register_processor_variant(TOKENIZE, 'jieba')\nclass JiebaTokenizer(ProcessorVariant):\n    def __init__(self, config):\n        \"\"\" Construct a Jieba-based tokenizer by loading the Jieba pipeline.\n\n        Note that this tokenizer uses regex for sentence segmentation.\n        \"\"\"\n        if config['lang'] not in ['zh', 'zh-hans', 'zh-hant']:\n            raise Exception(\"Jieba tokenizer is currently only allowed in Chinese (simplified or traditional) pipelines.\")\n\n        # Surpress a DeprecationWarning about pkg_resource from jieba.\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", category=DeprecationWarning, module=\"jieba\")\n            check_jieba()\n            import jieba\n\n        self.nlp = jieba\n        self.no_ssplit = config.get('no_ssplit', False)\n\n    def process(self, document):\n        \"\"\" Tokenize a document with the Jieba tokenizer and wrap the results into a Doc object.\n        \"\"\"\n        if isinstance(document, doc.Document):\n            text = document.text\n        else:\n            text = document\n        if not isinstance(text, str):\n            raise Exception(\"Must supply a string or Stanza Document object to the Jieba tokenizer.\")\n        tokens = self.nlp.cut(text, cut_all=False)\n\n        sentences = []\n        
current_sentence = []\n        offset = 0\n        for token in tokens:\n            if re.match(r'\\s+', token):\n                offset += len(token)\n                continue\n\n            token_entry = {\n                doc.TEXT: token,\n                doc.MISC: f\"{doc.START_CHAR}={offset}|{doc.END_CHAR}={offset+len(token)}\"\n            }\n            current_sentence.append(token_entry)\n            offset += len(token)\n\n            if not self.no_ssplit and token in ['。', '！', '？', '!', '?']:\n                sentences.append(current_sentence)\n                current_sentence = []\n\n        if len(current_sentence) > 0:\n            sentences.append(current_sentence)\n\n        return doc.Document(sentences, text)\n"
  },
  {
    "path": "stanza/pipeline/external/pythainlp.py",
    "content": "\"\"\"\nProcessors related to PyThaiNLP in the pipeline.\n\nGitHub Home: https://github.com/PyThaiNLP/pythainlp\n\"\"\"\n\nfrom stanza.models.common import doc\nfrom stanza.pipeline._constants import TOKENIZE\nfrom stanza.pipeline.processor import ProcessorVariant, register_processor_variant\n\ndef check_pythainlp():\n    \"\"\"\n    Import necessary components from pythainlp to perform tokenization.\n    \"\"\"\n    try:\n        import pythainlp\n    except ImportError:\n        raise ImportError(\n            \"The pythainlp library is required. \"\n            \"Try to install it with `pip install pythainlp`. \"\n            \"Go to https://github.com/PyThaiNLP/pythainlp for more information.\"\n        )\n    return True\n\n@register_processor_variant(TOKENIZE, 'pythainlp')\nclass PyThaiNLPTokenizer(ProcessorVariant):\n    def __init__(self, config):\n        \"\"\" Construct a PyThaiNLP-based tokenizer.\n\n        Note that we always uses the default tokenizer of PyThaiNLP for sentence and word segmentation.\n        Currently this is a CRF model for sentence segmentation and a dictionary-based model (newmm) for word segmentation.\n        \"\"\"\n        if config['lang'] != 'th':\n            raise Exception(\"PyThaiNLP tokenizer is only allowed in Thai pipeline.\")\n\n        check_pythainlp()\n        from pythainlp.tokenize import sent_tokenize as pythai_sent_tokenize\n        from pythainlp.tokenize import word_tokenize as pythai_word_tokenize\n\n        self.pythai_sent_tokenize = pythai_sent_tokenize\n        self.pythai_word_tokenize = pythai_word_tokenize\n        self.no_ssplit = config.get('no_ssplit', False)\n    \n    def process(self, document):\n        \"\"\" Tokenize a document with the PyThaiNLP tokenizer and wrap the results into a Doc object.\n        \"\"\"\n        if isinstance(document, doc.Document):\n            text = document.text\n        else:\n            text = document\n        if not isinstance(text, str):\n  
          raise Exception(\"Must supply a string or Stanza Document object to the PyThaiNLP tokenizer.\")\n\n        sentences = []\n        current_sentence = []\n        offset = 0\n\n        if self.no_ssplit:\n            # skip sentence segmentation\n            sent_strs = [text]\n        else:\n            sent_strs = self.pythai_sent_tokenize(text, engine='crfcut')\n        for sent_str in sent_strs:\n            for token_str in self.pythai_word_tokenize(sent_str, engine='newmm'):\n                # by default pythainlp will output whitespace as a token\n                # we need to skip these tokens to be consistent with other tokenizers\n                if token_str.isspace():\n                    offset += len(token_str)\n                    continue\n                \n                # create token entry\n                token_entry = {\n                    doc.TEXT: token_str,\n                    doc.MISC: f\"{doc.START_CHAR}={offset}|{doc.END_CHAR}={offset+len(token_str)}\"\n                }\n                current_sentence.append(token_entry)\n                offset += len(token_str)\n            \n            # finish sentence\n            sentences.append(current_sentence)\n            current_sentence = []\n\n        if len(current_sentence) > 0:\n            sentences.append(current_sentence)\n\n        return doc.Document(sentences, text)"
  },
  {
    "path": "stanza/pipeline/external/spacy.py",
    "content": "\"\"\"\nProcessors related to spaCy in the pipeline.\n\"\"\"\n\nfrom stanza.models.common import doc\nfrom stanza.pipeline._constants import TOKENIZE\nfrom stanza.pipeline.processor import ProcessorVariant, register_processor_variant\n\ndef check_spacy():\n    \"\"\"\n    Import necessary components from spaCy to perform tokenization.\n    \"\"\"\n    try:\n        import spacy\n    except ImportError:\n        raise ImportError(\n            \"spaCy is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions.\"\n        )\n    return True\n\n@register_processor_variant(TOKENIZE, 'spacy')\nclass SpacyTokenizer(ProcessorVariant):\n    def __init__(self, config):\n        \"\"\" Construct a spaCy-based tokenizer by loading the spaCy pipeline.\n        \"\"\"\n        if config['lang'] != 'en':\n            raise Exception(\"spaCy tokenizer is currently only allowed in English pipeline.\")\n\n        try:\n            import spacy\n            from spacy.lang.en import English\n        except ImportError:\n            raise ImportError(\n                \"spaCy 2.0+ is used but not installed on your machine. 
Go to https://spacy.io/usage for installation instructions.\"\n            )\n\n        # Create a Tokenizer with the default settings for English\n        # including punctuation rules and exceptions\n        self.nlp = English()\n        # by default spacy uses dependency parser to do ssplit\n        # we need to add a sentencizer for fast rule-based ssplit\n        if spacy.__version__.startswith(\"2.\"):\n            self.nlp.add_pipe(self.nlp.create_pipe(\"sentencizer\"))\n        else:\n            self.nlp.add_pipe(\"sentencizer\")\n        self.no_ssplit = config.get('no_ssplit', False)\n\n    def process(self, document):\n        \"\"\" Tokenize a document with the spaCy tokenizer and wrap the results into a Doc object.\n        \"\"\"\n        if isinstance(document, doc.Document):\n            text = document.text\n        else:\n            text = document\n        if not isinstance(text, str):\n            raise Exception(\"Must supply a string or Stanza Document object to the spaCy tokenizer.\")\n        spacy_doc = self.nlp(text)\n\n        sentences = []\n        for sent in spacy_doc.sents:\n            tokens = []\n            for tok in sent:\n                token_entry = {\n                    doc.TEXT: tok.text,\n                    doc.MISC: f\"{doc.START_CHAR}={tok.idx}|{doc.END_CHAR}={tok.idx+len(tok.text)}\"\n                }\n                tokens.append(token_entry)\n            sentences.append(tokens)\n\n        # if no_ssplit is set, flatten all the sentences into one sentence\n        if self.no_ssplit:\n            sentences = [[t for s in sentences for t in s]]\n\n        return doc.Document(sentences, text)\n"
  },
  {
    "path": "stanza/pipeline/external/sudachipy.py",
    "content": "\"\"\"\nProcessors related to SudachiPy in the pipeline.\n\nGitHub Home: https://github.com/WorksApplications/SudachiPy\n\"\"\"\n\nimport re\n\nfrom stanza.models.common import doc\nfrom stanza.pipeline._constants import TOKENIZE\nfrom stanza.pipeline.processor import ProcessorVariant, register_processor_variant\n\ndef check_sudachipy():\n    \"\"\"\n    Import necessary components from SudachiPy to perform tokenization.\n    \"\"\"\n    try:\n        import sudachipy\n        import sudachidict_core\n    except ImportError:\n        raise ImportError(\n            \"Both sudachipy and sudachidict_core libraries are required. \"\n            \"Try install them with `pip install sudachipy sudachidict_core`. \"\n            \"Go to https://github.com/WorksApplications/SudachiPy for more information.\"\n        )\n    return True\n\n@register_processor_variant(TOKENIZE, 'sudachipy')\nclass SudachiPyTokenizer(ProcessorVariant):\n    def __init__(self, config):\n        \"\"\" Construct a SudachiPy-based tokenizer.\n\n        Note that this tokenizer uses regex for sentence segmentation.\n        \"\"\"\n        if config['lang'] != 'ja':\n            raise Exception(\"SudachiPy tokenizer is only allowed in Japanese pipelines.\")\n\n        check_sudachipy()\n        from sudachipy import tokenizer\n        from sudachipy import dictionary\n\n        self.tokenizer = dictionary.Dictionary().create()\n        self.no_ssplit = config.get('no_ssplit', False)\n\n    def process(self, document):\n        \"\"\" Tokenize a document with the SudachiPy tokenizer and wrap the results into a Doc object.\n        \"\"\"\n        if isinstance(document, doc.Document):\n            text = document.text\n        else:\n            text = document\n        if not isinstance(text, str):\n            raise Exception(\"Must supply a string or Stanza Document object to the SudachiPy tokenizer.\")\n\n        # we use the default sudachipy tokenization mode (i.e., mode C)\n  
      # more config needs to be added to support other modes\n\n        tokens = self.tokenizer.tokenize(text)\n\n        sentences = []\n        current_sentence = []\n        for token in tokens:\n            token_text = token.surface()\n            # by default sudachipy will output whitespace as a token\n            # we need to skip these tokens to be consistent with other tokenizers\n            if token_text.isspace():\n                continue\n            start = token.begin()\n            end = token.end()\n\n            token_entry = {\n                doc.TEXT: token_text,\n                doc.MISC: f\"{doc.START_CHAR}={start}|{doc.END_CHAR}={end}\"\n            }\n            current_sentence.append(token_entry)\n\n            if not self.no_ssplit and token_text in ['。', '！', '？', '!', '?']:\n                sentences.append(current_sentence)\n                current_sentence = []\n\n        if len(current_sentence) > 0:\n            sentences.append(current_sentence)\n\n        return doc.Document(sentences, text)\n"
  },
  {
    "path": "stanza/pipeline/langid_processor.py",
    "content": "\"\"\"\nProcessor for determining language of text.\n\"\"\"\n\nimport emoji\nimport re\nimport stanza\nimport torch\n\nfrom stanza.models.common.doc import Document\nfrom stanza.models.langid.model import LangIDBiLSTM\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\n\n@register_processor(name=LANGID)\nclass LangIDProcessor(UDProcessor):\n    \"\"\"\n    Class for detecting language of text.\n    \"\"\"\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([LANGID])\n\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([])\n\n    # default max sequence length\n    MAX_SEQ_LENGTH_DEFAULT = 1000\n\n    def _set_up_model(self, config, pipeline, device):\n        batch_size = config.get(\"batch_size\", 64)\n        self._model = LangIDBiLSTM.load(path=config[\"model_path\"], device=device,\n                                        batch_size=batch_size, lang_subset=config.get(\"lang_subset\"))\n        self._char_index = self._model.char_to_idx\n        self._clean_text = config.get(\"clean_text\")\n\n    def _text_to_tensor(self, docs):\n        \"\"\"\n        Map list of strings to batch tensor. 
Assumes all docs are the same length.\n        \"\"\"\n\n        device = next(self._model.parameters()).device\n        all_docs = []\n        for doc in docs:\n            doc_chars = [self._char_index.get(c, self._char_index[\"UNK\"]) for c in list(doc)]\n            all_docs.append(doc_chars)\n        return torch.tensor(all_docs, device=device, dtype=torch.long)\n\n    def _id_langs(self, batch_tensor):\n        \"\"\"\n        Identify languages for each sequence in a batch tensor\n        \"\"\"\n        predictions = self._model.prediction_scores(batch_tensor)\n        prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions]\n\n        return prediction_labels\n\n    # regexes for cleaning text\n    http_regex = re.compile(r\"https?:\\/\\/t\\.co/[a-zA-Z0-9]+\")\n    handle_regex = re.compile(\"@[a-zA-Z0-9_]+\")\n    hashtag_regex = re.compile(\"#[a-zA-Z]+\")\n    punctuation_regex = re.compile(\"[!.]+\")\n    all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex]\n\n    @staticmethod\n    def clean_text(text):\n        \"\"\"\n        Process text to improve language id performance. 
The main emphasis is on tweets: this method removes shortened\n        urls, hashtags, handles, punctuation, and emoji.\n        \"\"\"\n\n        for regex in LangIDProcessor.all_regexes:\n            text = regex.sub(\" \", text)\n\n        text = emoji.emojize(text)\n        text = emoji.replace_emoji(text, replace=' ')\n\n        if text.strip():\n            text = text.strip()\n\n        return text\n\n    def _process_list(self, docs):\n        \"\"\"\n        Identify language of list of strings or Documents\n        \"\"\"\n\n        if len(docs) == 0:\n            # TO DO: what standard do we want for bad input, such as empty list?\n            # TO DO: more handling of bad input\n            return\n\n        if isinstance(docs[0], str):\n            docs = [Document([], text) for text in docs]\n\n        docs_by_length = {}\n        for doc in docs:\n            text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text\n            doc_length = len(text)\n            if doc_length not in docs_by_length:\n                docs_by_length[doc_length] = []\n            docs_by_length[doc_length].append((doc, text))\n\n        for doc_length in docs_by_length:\n            inputs = [doc[1] for doc in docs_by_length[doc_length]]\n            predictions = self._id_langs(self._text_to_tensor(inputs))\n            for doc, lang in zip(docs_by_length[doc_length], predictions):\n                doc[0].lang = lang\n\n        return docs\n\n    def process(self, doc):\n        \"\"\"\n        Handle single str or Document\n        \"\"\"\n\n        wrapped_doc = [doc]\n        return self._process_list(wrapped_doc)[0]\n\n    def bulk_process(self, docs):\n        \"\"\"\n        Handle list of strings or Documents\n        \"\"\"\n\n        return self._process_list(docs)\n\n"
  },
  {
    "path": "stanza/pipeline/lemma_processor.py",
    "content": "\"\"\"\nProcessor for performing lemmatization\n\"\"\"\n\nfrom itertools import compress\n\nimport torch\n\nfrom stanza.models.common import doc\nfrom stanza.models.lemma.data import DataLoader\nfrom stanza.models.lemma.trainer import Trainer\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\nWORD_TAGS = [doc.TEXT, doc.UPOS]\n\n@register_processor(name=LEMMA)\nclass LemmaProcessor(UDProcessor):\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([LEMMA])\n    # set of processor requirements for this processor\n    # pos will be added later for non-identity lemmatizerx\n    REQUIRES_DEFAULT = set([TOKENIZE])\n    # default batch size\n    DEFAULT_BATCH_SIZE = 5000\n\n    def __init__(self, config, pipeline, device):\n        # run lemmatizer in identity mode\n        self._use_identity = None\n        self._pretagged = None\n        super().__init__(config, pipeline, device)\n\n    @property\n    def use_identity(self):\n        return self._use_identity\n\n    def _set_up_model(self, config, pipeline, device):\n        if config.get('use_identity') in ['True', True]:\n            self._use_identity = True\n            self._config = config\n            self.config['batch_size'] = LemmaProcessor.DEFAULT_BATCH_SIZE\n        else:\n            # the lemmatizer only looks at one word when making\n            # decisions, not the surrounding context\n            # therefore, we can save some time by remembering what\n            # we did the last time we saw any given word,pos\n            # since a long running program will remember everything\n            # (unless we go back and make it smarter)\n            # we make this an option, not the default\n            # TODO: need to update the cache to skip the contextual lemmatizer\n            self.store_results = config.get('store_results', False)\n            self._use_identity = False\n            
args = {'charlm_forward_file': config.get('forward_charlm_path', None),\n                    'charlm_backward_file': config.get('backward_charlm_path', None)}\n            lemma_classifier_args = dict(args)\n            lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)\n            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)\n\n    def _set_up_requires(self):\n        self._pretagged = self._config.get('pretagged', None)\n        if self._pretagged:\n            self._requires = set()\n        elif self.config.get('pos') and not self.use_identity:\n            self._requires = LemmaProcessor.REQUIRES_DEFAULT.union(set([POS]))\n        else:\n            self._requires = LemmaProcessor.REQUIRES_DEFAULT\n\n    def process(self, document):\n        if not self.use_identity:\n            batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)\n        else:\n            batch = DataLoader(document, self.config['batch_size'], self.config, evaluation=True, conll_only=True)\n        if self.use_identity:\n            preds = [word.text for sent in batch.doc.sentences for word in sent.words]\n        elif self.config.get('dict_only', False):\n            preds = self.trainer.predict_dict(batch.doc.get([doc.TEXT, doc.UPOS]))\n        else:\n            if self.config.get('ensemble_dict', False):\n                # skip the seq2seq model when we can\n                skip = self.trainer.skip_seq2seq(batch.doc.get([doc.TEXT, doc.UPOS]))\n                # although there is no explicit use of caseless or lemma_caseless in this processor,\n                # it shows up in the config which gets passed to the DataLoader,\n                # possibly affecting its results\n                seq2seq_batch = DataLoader(document, self.config['batch_size'], 
self.config, vocab=self.vocab,\n                                           evaluation=True, skip=skip, expand_unk_vocab=True)\n            else:\n                seq2seq_batch = batch\n\n            with torch.no_grad():\n                preds = []\n                edits = []\n                for i, b in enumerate(seq2seq_batch):\n                    ps, es = self.trainer.predict(b, self.config['beam_size'], seq2seq_batch.vocab)\n                    preds += ps\n                    if es is not None:\n                        edits += es\n\n            if self.config.get('ensemble_dict', False):\n                word_tags = batch.doc.get(WORD_TAGS)\n                words = [x[0] for x in word_tags]\n                preds = self.trainer.postprocess([x for x, y in zip(words, skip) if not y], preds, edits=edits)\n                if self.store_results:\n                    new_word_tags = compress(word_tags, map(lambda x: not x, skip))\n                    new_predictions = [(x[0], x[1], y) for x, y in zip(new_word_tags, preds)]\n                    self.trainer.train_dict(new_predictions, update_word_dict=False)\n                # expand seq2seq predictions to the same size as all words\n                i = 0\n                preds1 = []\n                for s in skip:\n                    if s:\n                        preds1.append('')\n                    else:\n                        preds1.append(preds[i])\n                        i += 1\n                preds = self.trainer.ensemble(word_tags, preds1)\n            else:\n                preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)\n\n            if self.trainer.has_contextual_lemmatizers():\n                preds = self.trainer.update_contextual_preds(batch.doc, preds)\n\n        # map empty string lemmas to '_'\n        preds = [max([(len(x), x), (0, '_')])[1] for x in preds]\n        batch.doc.set([doc.LEMMA], preds)\n        return batch.doc\n"
  },
  {
    "path": "stanza/pipeline/morphseg_processor.py",
    "content": "from stanza.pipeline.core import UnsupportedProcessorError\nfrom stanza.pipeline.processor import UDProcessor, register_processor\nfrom stanza.pipeline._constants import MORPHSEG, TOKENIZE\n\n\n@register_processor(name=MORPHSEG)\nclass MorphSegProcessor(UDProcessor):\n    PROVIDES_DEFAULT = {MORPHSEG}\n    REQUIRES_DEFAULT = {TOKENIZE}\n\n    def __init__(self, config, pipeline, device):\n        self._config = config\n        self._pipeline = pipeline\n        self._set_up_requires()\n        self._set_up_provides()\n        self._set_up_model(config, pipeline, device)\n\n    def _set_up_model(self, config, pipeline, device):\n        try:\n            from morphseg import MorphemeSegmenter\n        except ImportError:\n            raise ImportError(\n                \"morphseg is required for morpheme segmentation. \"\n                \"Install it with: pip install morphseg\"\n            )\n\n        lang = config.get('lang', 'en')\n        model_path = config.get('morphseg_model_path', None)\n\n        if model_path:\n            self._segmenter = MorphemeSegmenter(\n                lang=lang,\n                load_pretrained=False,\n                model_filepath=model_path,\n                is_local=True\n            )\n        else:\n            self._segmenter = MorphemeSegmenter(\n                lang=lang,\n                load_pretrained=True\n            )\n        if self._segmenter.sequence_labeller is None:\n            raise UnsupportedProcessorError(\"morphseg\", lang)\n\n    def process(self, document):\n        # Collect all words from all sentences\n        all_words = []\n        word_mapping = []  # Track which sentence and word index each prediction belongs to\n\n        for sent_idx, sent in enumerate(document.sentences):\n            if not sent.words:\n                continue\n            for word_idx, word in enumerate(sent.words):\n                all_words.append(word.text)\n                
word_mapping.append((sent_idx, word_idx))\n\n        if not all_words:\n            return document\n\n        # Prepare input for morphseg (it expects normalized, lowercased character lists)\n        word_char_lists = [\n            list(self._segmenter.normalize_for_morphology(word))\n            for word in all_words\n        ]\n\n        # Batch predict using the internal sequence_labeller\n        predictions = self._segmenter.sequence_labeller.predict(sources=word_char_lists)\n\n        # Extract segmentations from predictions\n        from morphseg.training.oracle import rules2sent\n        segmentations = [\n            rules2sent(\n                source=[align_pos.symbol for align_pos in pred.alignment],\n                actions=pred.prediction\n            ).split(' @@')  # Split by morphseg's default delimiter\n            for pred in predictions\n        ]\n\n        # Assign segmentations back to words\n        for (sent_idx, word_idx), seg in zip(word_mapping, segmentations):\n            document.sentences[sent_idx].words[word_idx].morphemes = seg\n\n        return document\n"
  },
  {
    "path": "stanza/pipeline/multilingual.py",
    "content": "\"\"\"\nClass for running multilingual pipelines\n\"\"\"\n\nfrom collections import OrderedDict\nimport copy\nimport logging\nfrom typing import Union\n\nfrom stanza.models.common.doc import Document\nfrom stanza.models.common.utils import default_device\nfrom stanza.pipeline.core import Pipeline, DownloadMethod\nfrom stanza.pipeline._constants import *\nfrom stanza.resources.common import DEFAULT_MODEL_DIR, get_language_resources, load_resources_json\n\nlogger = logging.getLogger('stanza')\n\nclass MultilingualPipeline:\n    \"\"\"\n    Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that\n    language.\n\n    You can specify options to individual language pipelines with the lang_configs field.\n    For example, if you want English pipelines to have NER, but want to turn that off for French, you can do:\n        lang_configs = {\"en\": {\"processors\": \"tokenize,pos,lemma,depparse,ner\"},\n                        \"fr\": {\"processors\": \"tokenize,pos,lemma,depparse\"}}\n        pipeline = MultilingualPipeline(lang_configs=lang_configs)\n\n    You can also pass in a defaultdict created in such a way that it provides default parameters for each language.\n    For example, in order to only get tokenization for each language:\n    (remembering that the Pipeline will automagically add MWT to a language which uses MWT):\n        from collections import defaultdict\n        lang_configs = defaultdict(lambda: dict(processors=\"tokenize\"))\n        pipeline = MultilingualPipeline(lang_configs=lang_configs)\n\n    download_method can be set as in Pipeline to turn off downloading\n      of the .json config or turn off downloading of everything\n    \"\"\"\n\n    def __init__(self,\n                 model_dir: str = DEFAULT_MODEL_DIR,\n                 lang_id_config: dict = None,\n                 lang_configs: dict = None,\n                 ld_batch_size: int = 64,\n                 
max_cache_size: int = 10,\n                 use_gpu: bool = None,\n                 restrict: bool = False,\n                 device: str = None,\n                 download_method: DownloadMethod = DownloadMethod.DOWNLOAD_RESOURCES,\n                 # python 3.6 compatibility - maybe want to update to 3.7 at some point\n                 processors: Union[str, list] = None,\n    ):\n        # set up configs and cache for various language pipelines\n        self.model_dir = model_dir\n        self.lang_id_config = {} if lang_id_config is None else copy.deepcopy(lang_id_config)\n        self.lang_configs = {} if lang_configs is None else copy.deepcopy(lang_configs)\n        self.max_cache_size = max_cache_size\n        # OrderedDict so we can use it as a LRU cache\n        # most recent Pipeline goes to the end, pop the oldest one\n        # when we run out of space\n        self.pipeline_cache = OrderedDict()\n        if processors is None:\n            self.default_processors = None\n        elif isinstance(processors, str):\n            self.default_processors = [x.strip() for x in processors.split(\",\")]\n        else:\n            self.default_processors = list(processors)\n\n        self.download_method = download_method\n        if 'download_method' not in self.lang_id_config:\n            self.lang_id_config['download_method'] = self.download_method\n\n        # if lang is not in any of the lang_configs, update them to\n        # include the lang parameter.  otherwise, the default language\n        # will always be used...\n        for lang in self.lang_configs:\n            if 'lang' not in self.lang_configs[lang]:\n                self.lang_configs[lang]['lang'] = lang\n\n        if restrict and 'langid_lang_subset' not in self.lang_id_config:\n            known_langs = sorted(self.lang_configs.keys())\n            # known_langs is a list, so test its length rather than comparing it to 0\n            if len(known_langs) == 0:\n                logger.warning(\"MultilingualPipeline asked to restrict to lang_configs, but lang_configs was empty.  
Ignoring...\")\n            else:\n                logger.debug(\"Restricting MultilingualPipeline to %s\", known_langs)\n                self.lang_id_config['langid_lang_subset'] = known_langs\n\n        # set use_gpu\n        if device is None:\n            if use_gpu is None or use_gpu == True:\n                device = default_device()\n            else:\n                device = 'cpu'\n        self.device = device\n        \n        # build language id pipeline\n        self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors=\"langid\", \n                                         device=self.device, **self.lang_id_config)\n        # load the resources so that we can refer to it later when building a new pipeline\n        # note that it was either downloaded or not based on download_method when building the lang_id_pipeline\n        self.resources = load_resources_json(self.model_dir)\n\n    def _update_pipeline_cache(self, lang):\n        \"\"\"\n        Do any necessary updates to the pipeline cache for this language. 
This includes building a new\n        pipeline for the lang, and possibly evicting the least recently used language from the cache.\n        \"\"\"\n\n        # update request history\n        if lang in self.pipeline_cache:\n            self.pipeline_cache.move_to_end(lang, last=True)\n\n        # update language configs\n        # try/except to allow for a defaultdict\n        try:\n            lang_config = self.lang_configs[lang]\n        except KeyError:\n            lang_config = {'lang': lang}\n            self.lang_configs[lang] = lang_config\n\n        # if a defaultdict is passed in, the defaultdict might not contain 'lang'\n        # so even though we tried adding 'lang' in the constructor, we'll check again here\n        if 'lang' not in lang_config:\n            lang_config['lang'] = lang\n\n        if 'download_method' not in lang_config:\n            lang_config['download_method'] = self.download_method\n\n        if 'processors' not in lang_config:\n            if self.default_processors:\n                lang_resources = get_language_resources(self.resources, lang)\n                lang_processors = [x for x in self.default_processors if x in lang_resources]\n                if lang_processors != self.default_processors:\n                    logger.info(\"Not all requested processors %s available for %s.  
Loading %s instead\", self.default_processors, lang, lang_processors)\n                lang_config['processors'] = \",\".join(lang_processors)\n\n        if 'device' not in lang_config:\n            lang_config['device'] = self.device\n\n        # update pipeline cache\n        if lang not in self.pipeline_cache:\n            logger.debug(\"Loading unknown language in MultilingualPipeline: %s\", lang)\n            # clear least recently used lang from pipeline cache\n            if len(self.pipeline_cache) == self.max_cache_size:\n                self.pipeline_cache.popitem(last=False)\n            self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang])\n\n    def process(self, doc):\n        \"\"\"\n        Run language detection on a string, a Document, or a list of either, route to language specific pipeline\n        \"\"\"\n\n        # only return a list if given a list\n        singleton_input = not isinstance(doc, list)\n        if singleton_input:\n            docs = [doc]\n        else:\n            docs = doc\n\n        if docs and isinstance(docs[0], str):\n            docs = [Document([], text=text) for text in docs]\n\n        # run language identification\n        docs_w_langid = self.lang_id_pipeline.process(docs)\n\n        # create language specific batches, store global idx with each doc\n        lang_batches = {}\n        for doc_idx, doc in enumerate(docs_w_langid):\n            logger.debug(\"Language for document %d: %s\", doc_idx, doc.lang)\n            if doc.lang not in lang_batches:\n                lang_batches[doc.lang] = []\n            lang_batches[doc.lang].append(doc)\n\n        # run through each language, submit a batch to the language specific pipeline\n        for lang in lang_batches.keys():\n            self._update_pipeline_cache(lang)\n            self.pipeline_cache[lang](lang_batches[lang])\n\n        # only return a list if given a list\n        if singleton_input:\n            return 
docs_w_langid[0]\n        else:\n            return docs_w_langid\n\n    def __call__(self, doc):\n        doc = self.process(doc)\n        return doc\n\n"
  },
  {
    "path": "stanza/pipeline/mwt_processor.py",
    "content": "\"\"\"\nProcessor for performing multi-word-token expansion\n\"\"\"\n\nimport io\n\nimport torch\n\nfrom stanza.models.mwt.data import DataLoader\nfrom stanza.models.mwt.trainer import Trainer\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\n@register_processor(MWT)\nclass MWTProcessor(UDProcessor):\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([MWT])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE])\n\n    def _set_up_model(self, config, pipeline, device):\n        self._trainer = Trainer(model_file=config['model_path'], device=device)\n\n    def build_batch(self, document):\n        return DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)\n\n    def process(self, document):\n        batch = self.build_batch(document)\n\n        # process the rest\n        expansions = batch.doc.get_mwt_expansions(evaluation=True)\n        if len(batch) > 0:\n            # decide trainer type and run eval\n            if self.config['dict_only']:\n                preds = self.trainer.predict_dict(expansions)\n            else:\n                with torch.no_grad():\n                    preds = []\n                    for i, b in enumerate(batch.to_loader()):\n                        preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab)\n\n                if self.config.get('ensemble_dict', False):\n                    preds = self.trainer.ensemble(expansions, preds)\n        else:\n            # skip eval if dev data does not exist\n            preds = []\n\n        batch.doc.set_mwt_expansions(preds, process_manual_expanded=False)\n        return batch.doc\n\n    def bulk_process(self, docs):\n        \"\"\"\n        MWT processor counts some statistics on the individual docs, so we need to separately redo those stats\n  
      \"\"\"\n        docs = super().bulk_process(docs)\n        for doc in docs:\n            doc._count_words()\n        return docs\n"
  },
  {
    "path": "stanza/pipeline/ner_processor.py",
    "content": "\"\"\"\nProcessor for performing named entity tagging.\n\"\"\"\n\nimport torch\n\nimport logging\n\nfrom stanza.models.common import doc\nfrom stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError\nfrom stanza.models.common.utils import unsort\nfrom stanza.models.ner.data import DataLoader\nfrom stanza.models.ner.trainer import Trainer\nfrom stanza.models.ner.utils import merge_tags\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\nlogger = logging.getLogger('stanza')\n\n@register_processor(name=NER)\nclass NERProcessor(UDProcessor):\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([NER])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE])\n\n    def _get_dependencies(self, config, dep_name):\n        dependencies = config.get(dep_name, None)\n        if dependencies is not None:\n            dependencies = dependencies.split(\";\")\n            dependencies = [x if x else None for x in dependencies]\n        else:\n            dependencies = [x.get(dep_name) for x in config.get('dependencies', [])]\n        return dependencies\n\n    def _set_up_model(self, config, pipeline, device):\n        # set up trainer\n        model_paths = config.get('model_path')\n        if isinstance(model_paths, str):\n            model_paths = model_paths.split(\";\")\n\n        charlm_forward_files = self._get_dependencies(config, 'forward_charlm_path')\n        charlm_backward_files = self._get_dependencies(config, 'backward_charlm_path')\n        pretrain_files = self._get_dependencies(config, 'pretrain_path')\n\n        # allow predict_tagset to be specified as an int\n        # (which only applies to the first model)\n        # or as a string \";\" separated list of ints\n        self._predict_tagset = {}\n        predict_tagset = config.get('predict_tagset', None)\n        
if predict_tagset:\n            if isinstance(predict_tagset, int):\n                self._predict_tagset[0] = predict_tagset\n            else:\n                predict_tagset = predict_tagset.split(\";\")\n                for piece_idx, piece in enumerate(predict_tagset):\n                    if piece:\n                        self._predict_tagset[piece_idx] = int(piece)\n\n        self.trainers = []\n        for (model_path, pretrain_path, charlm_forward, charlm_backward) in zip(model_paths, pretrain_files, charlm_forward_files, charlm_backward_files):\n            logger.debug(\"Loading %s with pretrain %s, forward charlm %s, backward charlm %s\", model_path, pretrain_path, charlm_forward, charlm_backward)\n            pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None\n            args = {'charlm_forward_file': charlm_forward,\n                    'charlm_backward_file': charlm_backward}\n\n            predict_tagset = self._predict_tagset.get(len(self.trainers), None)\n            if predict_tagset is not None:\n                args['predict_tagset'] = predict_tagset\n\n            try:\n                trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache)\n            except ForwardCharlmNotFoundError as e:\n                raise ForwardCharlmNotFoundError(\"Could not find the forward charlm %s.  Please specify the correct path with ner_forward_charlm_path\" % e.filename, e.filename) from None\n            except BackwardCharlmNotFoundError as e:\n                raise BackwardCharlmNotFoundError(\"Could not find the backward charlm %s.  
Please specify the correct path with ner_backward_charlm_path\" % e.filename, e.filename) from None\n            self.trainers.append(trainer)\n\n        self._trainer = self.trainers[0]\n        self.model_paths = model_paths\n\n    def _set_up_final_config(self, config):\n        \"\"\" Finalize the configurations for this processor, based off of values from a UD model. \"\"\"\n        # set configurations from loaded model\n        if len(self.trainers) == 0:\n            raise RuntimeError(\"Somehow there are no models loaded!\")\n        self._vocab = self.trainers[0].vocab\n        self.configs = []\n        for trainer in self.trainers:\n            loaded_args = trainer.args\n            # filter out unneeded args from model\n            loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}\n            loaded_args.update(config)\n            self.configs.append(loaded_args)\n        self._config = self.configs[0]\n\n    def __str__(self):\n        return \"NERProcessor(%s)\" % \";\".join(self.model_paths)\n\n    def mark_inactive(self):\n        \"\"\" Drop memory intensive resources if keeping this processor around for reasons other than running it. 
\"\"\"\n        super().mark_inactive()\n        self.trainers = None\n\n    def process(self, document):\n        with torch.no_grad():\n            all_preds = []\n            for trainer, config in zip(self.trainers, self.configs):\n                # set up a eval-only data loader and skip tag preprocessing\n                batch = DataLoader(document, config['batch_size'], config, vocab=trainer.vocab, evaluation=True, preprocess_tags=False, bert_tokenizer=trainer.model.bert_tokenizer)\n                preds = []\n                for i, b in enumerate(batch):\n                    preds += trainer.predict(b)\n                all_preds.append(preds)\n        # for each sentence, gather a list of predictions\n        # merge those predictions into a single list\n        # earlier models will have precedence\n        preds = [merge_tags(*x) for x in zip(*all_preds)]\n        batch.doc.set([doc.NER], [y for x in preds for y in x], to_token=True)\n        batch.doc.set([doc.MULTI_NER], [tuple(y) for x in zip(*all_preds) for y in zip(*x)], to_token=True)\n        # collect entities into document attribute\n        total = len(batch.doc.build_ents())\n        logger.debug(f'{total} entities found in document.')\n        return batch.doc\n\n    def bulk_process(self, docs):\n        \"\"\"\n        NER processor has a collation step after running inference\n        \"\"\"\n        docs = super().bulk_process(docs)\n        for doc in docs:\n            doc.build_ents()\n        return docs\n\n    def get_known_tags(self, model_idx=0):\n        \"\"\"\n        Return the tags known by this model\n\n        Removes the S-, B-, etc, and does not include O\n        Specify model_idx if the processor  has more than one model\n        \"\"\"        \n        return self.trainers[model_idx].get_known_tags()\n"
  },
  {
    "path": "stanza/pipeline/pos_processor.py",
    "content": "\"\"\"\nProcessor for performing part-of-speech tagging\n\"\"\"\n\nimport torch\n\nfrom stanza.models.common import doc\nfrom stanza.models.common.utils import unsort\nfrom stanza.models.common.vocab import VOCAB_PREFIX, CompositeVocab\nfrom stanza.models.pos.data import Dataset\nfrom stanza.models.pos.trainer import Trainer\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\n@register_processor(name=POS)\nclass POSProcessor(UDProcessor):\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([POS])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([TOKENIZE])\n\n    def _set_up_model(self, config, pipeline, device):\n        # get pretrained word vectors\n        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None\n        args = {'charlm_forward_file': config.get('forward_charlm_path', None),\n                'charlm_backward_file': config.get('backward_charlm_path', None)}\n        # set up trainer\n        self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], device=device, args=args, foundation_cache=pipeline.foundation_cache)\n        self._tqdm = 'tqdm' in config and config['tqdm']\n\n    def __str__(self):\n        return \"POSProcessor(%s)\" % self.config['model_path']\n\n    def get_known_xpos(self):\n        \"\"\"\n        Returns the xpos tags known by this model\n        \"\"\"\n        if isinstance(self.vocab['xpos'], CompositeVocab):\n            if len(self.vocab['xpos']) == 1:\n                return [k for k in self.vocab['xpos'][0]._unit2id.keys() if k not in VOCAB_PREFIX]\n            else:\n                return {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['xpos']._unit2id.items()}\n        return [k for k in 
self.vocab['xpos']._unit2id.keys() if k not in VOCAB_PREFIX]\n\n    def is_composite_xpos(self):\n        \"\"\"\n        Returns if the xpos tags are part of a composite vocab\n        \"\"\"\n        return isinstance(self.vocab['xpos'], CompositeVocab)\n\n    def get_known_upos(self):\n        \"\"\"\n        Returns the upos tags known by this model\n        \"\"\"\n        keys = [k for k in self.vocab['upos']._unit2id.keys() if k not in VOCAB_PREFIX]\n        return keys\n\n    def get_known_feats(self):\n        \"\"\"\n        Returns the features known by this model\n        \"\"\"\n        values = {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['feats']._unit2id.items()}\n        return values\n\n    def process(self, document):\n        # currently, POS models are saved w/o the batch_maximum_tokens flag\n        maximum_tokens = self.config.get('batch_maximum_tokens', 5000)\n\n        dataset = Dataset(\n            document, self.config, self.pretrain, vocab=self.vocab, evaluation=True,\n            sort_during_eval=True)\n        batch = iter(dataset.to_length_limited_loader(batch_size=self.config['batch_size'], maximum_tokens=maximum_tokens))\n        preds = []\n\n        idx = []\n        with torch.no_grad():\n            if self._tqdm:\n                batch = tqdm(batch)\n            for i, b in enumerate(batch):\n                idx.extend(b[-1])\n                preds += self.trainer.predict(b)\n\n        preds = unsort(preds, idx)\n        dataset.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x])\n        return dataset.doc\n"
  },
  {
    "path": "stanza/pipeline/processor.py",
    "content": "\"\"\"\nBase classes for processors\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom stanza.models.common.doc import Document\nfrom stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS\n\nclass ProcessorRequirementsException(Exception):\n    \"\"\" Exception indicating a processor's requirements will not be met \"\"\"\n\n    def __init__(self, processors_list, err_processor, provided_reqs):\n        self._err_processor = err_processor\n        # mark the broken processor as inactive, drop resources\n        self.err_processor.mark_inactive()\n        self._processors_list = processors_list\n        self._provided_reqs = provided_reqs\n        self.build_message()\n\n    @property\n    def err_processor(self):\n        \"\"\" The processor that raised the exception \"\"\"\n        return self._err_processor\n\n    @property\n    def processor_type(self):\n        return type(self.err_processor).__name__\n\n    @property\n    def processors_list(self):\n        return self._processors_list\n\n    @property\n    def provided_reqs(self):\n        return self._provided_reqs\n\n    def build_message(self):\n        self.message = (f\"---\\nPipeline Requirements Error!\\n\"\n                        f\"\\tProcessor: {self.processor_type}\\n\"\n                        f\"\\tPipeline processors list: {','.join(self.processors_list)}\\n\"\n                        f\"\\tProcessor Requirements: {self.err_processor.requires}\\n\"\n                        f\"\\t\\t- fulfilled: {self.err_processor.requires.intersection(self.provided_reqs)}\\n\"\n                        f\"\\t\\t- missing: {self.err_processor.requires - self.provided_reqs}\\n\"\n                        f\"\\nThe processors list provided for this pipeline is invalid.  
Please make sure all \"\n                        f\"prerequisites are met for every processor.\\n\\n\")\n\n    def __str__(self):\n        return self.message\n\n\nclass Processor(ABC):\n    \"\"\" Base class for all processors \"\"\"\n\n    def __init__(self, config, pipeline, device):\n        # overall config for the processor\n        self._config = config\n        # pipeline building this processor (presently processors are only meant to exist in one pipeline)\n        self._pipeline = pipeline\n        self._set_up_variants(config, device)\n        # run set up process\n        # set up what annotations are required based on config\n        if not self._set_up_variant_requires():\n            self._set_up_requires()\n        # set up what annotations are provided based on config\n        self._set_up_provides()\n        # given pipeline constructing this processor, check if requirements are met, throw exception if not\n        self._check_requirements()\n\n        if hasattr(self, '_variant') and self._variant.OVERRIDE:\n            self.process = self._variant.process\n\n    def __str__(self):\n        \"\"\"\n        Simple description of the processor: name(model)\n        \"\"\"\n        name = self.__class__.__name__\n        model = None\n        if self._config is not None:\n            model = self._config.get('model_path')\n        if model is None:\n            return name\n        else:\n            return \"{}({})\".format(name, model)\n\n\n    @abstractmethod\n    def process(self, doc):\n        \"\"\" Process a Document.  This is the main method of a processor. \"\"\"\n        pass\n\n    def bulk_process(self, docs):\n        \"\"\" Process a list of Documents. This should be replaced with a more efficient implementation if possible. 
\"\"\"\n\n        if hasattr(self, '_variant'):\n            return self._variant.bulk_process(docs)\n\n        return [self.process(doc) for doc in docs]\n\n    def _set_up_provides(self):\n        \"\"\" Set up what processor requirements this processor fulfills.  Default is to use a class defined list. \"\"\"\n        self._provides = self.__class__.PROVIDES_DEFAULT\n\n    def _set_up_requires(self):\n        \"\"\" Set up requirements for this processor.  Default is to use a class defined list. \"\"\"\n        self._requires = self.__class__.REQUIRES_DEFAULT\n\n    def _set_up_variant_requires(self):\n        \"\"\"\n        If this has a variant with its own requirements, use those instead\n\n        Returns True iff the _requires is set from the _variant\n        \"\"\"\n        if not hasattr(self, '_variant'):\n            return False\n        if hasattr(self._variant, '_set_up_requires'):\n            self._variant._set_up_requires()\n            self._requires = self._variant._requires\n            return True\n        if hasattr(self._variant.__class__, 'REQUIRES_DEFAULT'):\n            self._requires = self._variant.__class__.REQUIRES_DEFAULT\n            return True\n        return False\n\n    def _set_up_variants(self, config, device):\n        processor_name = list(self.__class__.PROVIDES_DEFAULT)[0]\n        if any(config.get(f'with_{variant}', False) for variant in PROCESSOR_VARIANTS[processor_name]):\n            self._trainer = None\n            variant_name = [variant for variant in PROCESSOR_VARIANTS[processor_name] if config.get(f'with_{variant}', False)][0]\n            self._variant = PROCESSOR_VARIANTS[processor_name][variant_name](config)\n\n    @property\n    def config(self):\n        \"\"\" Configurations for the processor \"\"\"\n        return self._config\n\n    @property\n    def pipeline(self):\n        \"\"\" The pipeline that this processor belongs to \"\"\"\n        return self._pipeline\n\n    @property\n    def 
provides(self):\n        return self._provides\n\n    @property\n    def requires(self):\n        return self._requires\n\n    def _check_requirements(self):\n        \"\"\" Given a list of fulfilled requirements, check if all of this processor's requirements are met or not. \"\"\"\n        if not self.config.get(\"check_requirements\", True):\n            return\n        provided_reqs = set.union(*[processor.provides for processor in self.pipeline.loaded_processors]+[set([])])\n        if self.requires - provided_reqs:\n            load_names = [item[0] for item in self.pipeline.load_list]\n            raise ProcessorRequirementsException(load_names, self, provided_reqs)\n\n\nclass ProcessorVariant(ABC):\n    \"\"\" Base class for all processor variants \"\"\"\n\n    OVERRIDE = False # Set to true to override all the processing from the processor\n\n    @abstractmethod\n    def process(self, doc):\n        \"\"\"\n        Process a document that is potentially preprocessed by the processor.\n        This is the main method of a processor variant.\n\n        If `OVERRIDE` is set to True, all preprocessing by the processor would be bypassed, and the processor variant\n        would serve as a drop-in replacement of the entire processor, and has to be able to interpret all the configs\n        that are typically handled by the processor it replaces.\n        \"\"\"\n        pass\n\n    def bulk_process(self, docs):\n        \"\"\" Process a list of Documents. This should be replaced with a more efficient implementation if possible. 
\"\"\"\n\n        return [self.process(doc) for doc in docs]\n\nclass UDProcessor(Processor):\n    \"\"\" Base class for the neural UD Processors (tokenize,mwt,pos,lemma,depparse,sentiment,constituency) \"\"\"\n\n    def __init__(self, config, pipeline, device):\n        super().__init__(config, pipeline, device)\n\n        # UD model resources, set up is processor specific\n        self._pretrain = None\n        self._trainer = None\n        self._vocab = None\n        if not hasattr(self, '_variant'):\n            self._set_up_model(config, pipeline, device)\n\n        # build the final config for the processor\n        self._set_up_final_config(config)\n\n    @abstractmethod\n    def _set_up_model(self, config, pipeline, device):\n        pass\n\n    def _set_up_final_config(self, config):\n        \"\"\" Finalize the configurations for this processor, based off of values from a UD model. \"\"\"\n        # set configurations from loaded model\n        if self._trainer is not None:\n            loaded_args, self._vocab = self._trainer.args, self._trainer.vocab\n            # filter out unneeded args from model\n            loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}\n        else:\n            loaded_args = {}\n        loaded_args.update(config)\n        self._config = loaded_args\n\n    def mark_inactive(self):\n        \"\"\" Drop memory intensive resources if keeping this processor around for reasons other than running it. 
\"\"\"\n        self._trainer = None\n        self._vocab = None\n\n    @property\n    def pretrain(self):\n        return self._pretrain\n\n    @property\n    def trainer(self):\n        return self._trainer\n\n    @property\n    def vocab(self):\n        return self._vocab\n\n    @staticmethod\n    def filter_out_option(option):\n        \"\"\" Filter out non-processor configurations \"\"\"\n        options_to_filter = ['device', 'cpu', 'cuda', 'dev_conll_gold', 'epochs', 'lang', 'mode', 'save_name', 'shorthand']\n        if option.endswith('_file') or option.endswith('_dir'):\n            return True\n        elif option in options_to_filter:\n            return True\n        else:\n            return False\n\n    def bulk_process(self, docs):\n        \"\"\"\n        Most processors operate on the sentence level, where each sentence is processed independently and processors can benefit\n        a lot from the ability to combine sentences from multiple documents for faster batched processing. 
This is a transparent\n        implementation that allows these processors to batch process a list of Documents as if they were from a single Document.\n        \"\"\"\n\n        if hasattr(self, '_variant'):\n            return self._variant.bulk_process(docs)\n\n        combined_sents = [sent for doc in docs for sent in doc.sentences]\n        combined_doc = Document([])\n        combined_doc.sentences = combined_sents\n        combined_doc.num_tokens = sum(doc.num_tokens for doc in docs)\n        combined_doc.num_words = sum(doc.num_words for doc in docs)\n\n        self.process(combined_doc) # annotations are attached to sentence objects\n\n        return docs\n\nclass ProcessorRegisterException(Exception):\n    \"\"\" Exception indicating processor or processor registration failure \"\"\"\n\n    def __init__(self, processor_class, expected_parent):\n        self._processor_class = processor_class\n        self._expected_parent = expected_parent\n        self.build_message()\n\n    def build_message(self):\n        self.message = f\"Failed to register '{self._processor_class}'. It must be a subclass of '{self._expected_parent}'.\"\n\n    def __str__(self):\n        return self.message\n\ndef register_processor(name):\n    def wrapper(Cls):\n        if not issubclass(Cls, Processor):\n            raise ProcessorRegisterException(Cls, Processor)\n\n        NAME_TO_PROCESSOR_CLASS[name] = Cls\n        PIPELINE_NAMES.append(name)\n        return Cls\n    return wrapper\n\ndef register_processor_variant(name, variant):\n    def wrapper(Cls):\n        if not issubclass(Cls, ProcessorVariant):\n            raise ProcessorRegisterException(Cls, ProcessorVariant)\n\n        PROCESSOR_VARIANTS[name][variant] = Cls\n        return Cls\n    return wrapper\n"
  },
  {
    "path": "stanza/pipeline/registry.py",
    "content": "from collections import defaultdict\n\n# these two get filled by register_processor\nNAME_TO_PROCESSOR_CLASS = dict()\nPIPELINE_NAMES = []\n\n# this gets filled by register_processor_variant\nPROCESSOR_VARIANTS = defaultdict(dict)\n"
  },
  {
    "path": "stanza/pipeline/sentiment_processor.py",
    "content": "\"\"\"Processor that attaches a sentiment score to a sentence\n\nThe model used is generally a model trained on the Stanford\nSentiment Treebank or some similar dataset.  When run, this processor\nattaches a score in the form of a string to each sentence in the\ndocument.\n\nTODO: a possible way to generalize this would be to make it a\nClassifierProcessor and have \"sentiment\" be an option.\n\"\"\"\n\nimport dataclasses\nimport torch\n\nfrom types import SimpleNamespace\n\nfrom stanza.models.classifiers.trainer import Trainer\n\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\n\n@register_processor(SENTIMENT)\nclass SentimentProcessor(UDProcessor):\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([SENTIMENT])\n    # set of processor requirements for this processor\n    # TODO: a constituency based model needs CONSTITUENCY as well\n    # issue: by the time we load the model in Processor.__init__,\n    # the requirements are already prepared\n    REQUIRES_DEFAULT = set([TOKENIZE])\n\n    # default batch size, measured in words per batch\n    DEFAULT_BATCH_SIZE = 5000\n\n    def _set_up_model(self, config, pipeline, device):\n        # get pretrained word vectors\n        pretrain_path = config.get('pretrain_path', None)\n        forward_charlm_path = config.get('forward_charlm_path', None)\n        backward_charlm_path = config.get('backward_charlm_path', None)\n        # elmo does not have a convenient way to download intermediate\n        # models the way stanza downloads charlms & pretrains or\n        # transformers downloads bert etc\n        # however, elmo in general is not as good as using a\n        # transformer, so it is unlikely we will ever fix this\n        args = SimpleNamespace(device = device,\n                               charlm_forward_file = forward_charlm_path,\n                               charlm_backward_file = 
backward_charlm_path,\n                               wordvec_pretrain_file = pretrain_path,\n                               elmo_model = None,\n                               use_elmo = False,\n                               save_dir = None)\n        filename = config['model_path']\n        if filename is None:\n            raise FileNotFoundError(\"No model specified for the sentiment processor.  Perhaps it is not supported for the language.  {}\".format(config))\n        # set up model\n        trainer = Trainer.load(filename=filename,\n                               args=args,\n                               foundation_cache=pipeline.foundation_cache)\n        self._trainer = trainer\n        self._model = trainer.model\n        self._model_type = self._model.config.model_type\n        # batch size counted as words\n        self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)\n\n    def _set_up_final_config(self, config):\n        loaded_args = dataclasses.asdict(self._model.config)\n        loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}\n        loaded_args.update(config)\n        self._config = loaded_args\n\n\n    def process(self, document):\n        sentences = self._model.extract_sentences(document)\n        with torch.no_grad():\n            labels = self._model.label_sentences(sentences, batch_size=self._batch_size)\n        # TODO: allow a classifier processor for any attribute, not just sentiment\n        document.set(SENTIMENT, labels, to_sentence=True)\n        return document\n"
  },
  {
    "path": "stanza/pipeline/tokenize_processor.py",
    "content": "\"\"\"\nProcessor for performing tokenization\n\"\"\"\n\nimport copy\nimport io\nimport logging\n\nimport torch\n\nfrom stanza.models.tokenization.data import TokenizationDataset\nfrom stanza.models.tokenization.trainer import Trainer\nfrom stanza.models.tokenization.utils import output_predictions\nfrom stanza.pipeline._constants import *\nfrom stanza.pipeline.processor import UDProcessor, register_processor\nfrom stanza.pipeline.registry import PROCESSOR_VARIANTS\nfrom stanza.models.common import doc\n\n# these imports trigger the \"register_variant\" decorations\nfrom stanza.pipeline.external.jieba import JiebaTokenizer\nfrom stanza.pipeline.external.spacy import SpacyTokenizer\nfrom stanza.pipeline.external.sudachipy import SudachiPyTokenizer\nfrom stanza.pipeline.external.pythainlp import PyThaiNLPTokenizer\n\nlogger = logging.getLogger('stanza')\n\nTOKEN_TOO_LONG_REPLACEMENT = \"<UNK>\"\n\n# class for running the tokenizer\n@register_processor(name=TOKENIZE)\nclass TokenizeProcessor(UDProcessor):\n\n    # set of processor requirements this processor fulfills\n    PROVIDES_DEFAULT = set([TOKENIZE])\n    # set of processor requirements for this processor\n    REQUIRES_DEFAULT = set([])\n    # default max sequence length\n    MAX_SEQ_LENGTH_DEFAULT = 1000\n\n    def _set_up_model(self, config, pipeline, device):\n        # set up trainer\n        if config.get('pretokenized'):\n            self._trainer = None\n        else:\n            args = {'charlm_forward_file': config.get('forward_charlm_path', None)}\n            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)\n\n        # get and typecheck the postprocessor\n        postprocessor = config.get('postprocessor')\n        if postprocessor and callable(postprocessor):\n            self._postprocessor = postprocessor\n        elif not postprocessor:\n            self._postprocessor = None\n        else:\n            
raise ValueError(\"Tokenizer received 'postprocessor' option of unrecognized type; postprocessor must be callable. Got %s\" % postprocessor)\n\n    def process_pre_tokenized_text(self, input_src):\n        \"\"\"\n        Pretokenized text can be provided in 2 manners:\n\n        1.) str, tokenized by whitespace, sentence split by newline\n        2.) list of token lists, each token list represents a sentence\n\n        generate dictionary data structure\n        \"\"\"\n\n        document = []\n        if isinstance(input_src, str):\n            sentences = [sent.strip().split() for sent in input_src.strip().split('\\n') if len(sent.strip()) > 0]\n        elif isinstance(input_src, list):\n            sentences = input_src\n        idx = 0\n        for sentence in sentences:\n            sent = []\n            for token_id, token in enumerate(sentence):\n                sent.append({doc.ID: (token_id + 1, ), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'})\n                idx += len(token) + 1\n            document.append(sent)\n        raw_text = ' '.join([' '.join(sentence) for sentence in sentences])\n        return raw_text, document\n\n    def process(self, document):\n        if not (isinstance(document, str) or isinstance(document, doc.Document) or (self.config.get('pretokenized') or self.config.get('no_ssplit', False))):\n            raise ValueError(\"If neither 'pretokenized' or 'no_ssplit' option is enabled, the input to the TokenizerProcessor must be a string or a Document object.  
Got %s\" % str(type(document)))\n\n        if isinstance(document, doc.Document):\n            if self.config.get('pretokenized'):\n                return document\n            document = document.text\n\n        if self.config.get('pretokenized'):\n            raw_text, document = self.process_pre_tokenized_text(document)\n            return doc.Document(document, raw_text)\n\n        if hasattr(self, '_variant'):\n            return self._variant.process(document)\n\n        raw_text = '\\n\\n'.join(document) if isinstance(document, list) else document\n\n        max_seq_len = self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT)\n\n        # set up batches\n        batches = TokenizationDataset(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary)\n        # get dict data\n        with torch.no_grad():\n            _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,\n                                                   max_seq_len,\n                                                   orig_text=raw_text,\n                                                   no_ssplit=self.config.get('no_ssplit', False),\n                                                   num_workers = self.config.get('num_workers', 0),\n                                                   postprocessor = self._postprocessor)\n\n        # replace excessively long tokens with <UNK> to avoid downstream GPU memory issues in POS\n        for sentence in document:\n            for token in sentence:\n                if len(token['text']) > max_seq_len:\n                    token['text'] = TOKEN_TOO_LONG_REPLACEMENT\n\n        return doc.Document(document, raw_text)\n\n    def bulk_process(self, docs):\n        \"\"\"\n        The tokenizer cannot use UDProcessor's sentence-level cross-document batching interface, and requires special handling.\n        Essentially, this method concatenates the text of 
multiple documents with \"\\n\\n\", tokenizes it with the neural tokenizer,\n        then splits the result into the original Documents and recovers the original character offsets.\n        \"\"\"\n        if hasattr(self, '_variant'):\n            return self._variant.bulk_process(docs)\n\n        if self.config.get('pretokenized'):\n            res = []\n            for document in docs:\n                if len(document.sentences) > 0:\n                    # perhaps this is a document already tokenized,\n                    # being sent back in for more analysis / reparsing / etc?\n                    # in that case, no need to try to tokenize it\n                    # based on whitespace tokenizing the document text\n                    # which, interestingly, may not even exist depending on\n                    # how the document was created)\n                    # by making a whole deepcopy, the original Document is unchanged\n                    res.append(copy.deepcopy(document))\n                else:\n                    raw_text, document = self.process_pre_tokenized_text(document.text)\n                    res.append(doc.Document(document, raw_text))\n            return res\n\n        combined_text = '\\n\\n'.join([thisdoc.text for thisdoc in docs])\n        processed_combined = self.process(doc.Document([], text=combined_text))\n\n        # postprocess sentences and tokens to reset back pointers and char offsets\n        charoffset = 0\n        sentst = senten = 0\n        for thisdoc in docs:\n            while senten < len(processed_combined.sentences) and processed_combined.sentences[senten].tokens[-1].end_char - charoffset <= len(thisdoc.text):\n                senten += 1\n\n            sentences = processed_combined.sentences[sentst:senten]\n            thisdoc.sentences = sentences\n            for sent in sentences:\n                # fix doc back pointers for sentences\n                sent._doc = thisdoc\n\n                # fix char offsets 
for tokens and words\n                for token in sent.tokens:\n                    token._start_char -= charoffset\n                    token._end_char -= charoffset\n                    if token.words:  # not-yet-processed MWT can leave empty tokens\n                        for word in token.words:\n                            word._start_char -= charoffset\n                            word._end_char -= charoffset\n\n            # Here we need to fix up the SpacesAfter for the very last token\n            # and the SpacesBefore for the first token of the next doc\n            # After all, we had connected the text with \\n\\n\n            # Need to be careful about this - in a case such as\n            #   \" -text one- \"\n            #   \" -text two- \"\n            # We want the SpacesBefore for the second document to reflect\n            # the extra space at the start of its text\n            # and the SpacesAfter for the first document to reflect\n            # the whitespace after its text\n            if len(sentences) > 0:\n                last_token = sentences[-1].tokens[-1]\n                last_whitespace = thisdoc.text[last_token.end_char:]\n                last_token.spaces_after = last_whitespace\n\n                first_token = sentences[0].tokens[0]\n                first_whitespace = thisdoc.text[:first_token.start_char]\n                first_token.spaces_before = first_whitespace\n\n            thisdoc.num_tokens = sum(len(sent.tokens) for sent in sentences)\n            thisdoc.num_words = sum(len(sent.words) for sent in sentences)\n            sentst = senten\n\n            charoffset += len(thisdoc.text) + 2\n\n        return docs\n"
  },
  {
    "path": "stanza/protobuf/CoreNLP_pb2.py",
    "content": "# -*- coding: utf-8 -*-\n# Generated by the protocol buffer compiler.  DO NOT EDIT!\n# source: CoreNLP.proto\n# Protobuf Python Version: 4.25.5\n\"\"\"Generated protocol buffer code.\"\"\"\nfrom google.protobuf import descriptor as _descriptor\nfrom google.protobuf import descriptor_pool as _descriptor_pool\nfrom google.protobuf import symbol_database as _symbol_database\nfrom google.protobuf.internal import builder as _builder\n# @@protoc_insertion_point(imports)\n\n_sym_db = _symbol_database.Default()\n\n\n\n\nDESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\\n\\rCoreNLP.proto\\x12\\x19\\x65\\x64u.stanford.nlp.pipeline\\\"\\xe1\\x05\\n\\x08\\x44ocument\\x12\\x0c\\n\\x04text\\x18\\x01 \\x02(\\t\\x12\\x35\\n\\x08sentence\\x18\\x02 \\x03(\\x0b\\x32#.edu.stanford.nlp.pipeline.Sentence\\x12\\x39\\n\\ncorefChain\\x18\\x03 \\x03(\\x0b\\x32%.edu.stanford.nlp.pipeline.CorefChain\\x12\\r\\n\\x05\\x64ocID\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07\\x64ocDate\\x18\\x07 \\x01(\\t\\x12\\x10\\n\\x08\\x63\\x61lendar\\x18\\x08 \\x01(\\x04\\x12;\\n\\x11sentencelessToken\\x18\\x05 \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Token\\x12\\x33\\n\\tcharacter\\x18\\n \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Token\\x12/\\n\\x05quote\\x18\\x06 \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Quote\\x12\\x37\\n\\x08mentions\\x18\\t \\x03(\\x0b\\x32%.edu.stanford.nlp.pipeline.NERMention\\x12#\\n\\x1bhasEntityMentionsAnnotation\\x18\\r \\x01(\\x08\\x12\\x0e\\n\\x06xmlDoc\\x18\\x0b \\x01(\\x08\\x12\\x34\\n\\x08sections\\x18\\x0c \\x03(\\x0b\\x32\\\".edu.stanford.nlp.pipeline.Section\\x12<\\n\\x10mentionsForCoref\\x18\\x0e \\x03(\\x0b\\x32\\\".edu.stanford.nlp.pipeline.Mention\\x12!\\n\\x19hasCorefMentionAnnotation\\x18\\x0f \\x01(\\x08\\x12\\x1a\\n\\x12hasCorefAnnotation\\x18\\x10 \\x01(\\x08\\x12+\\n#corefMentionToEntityMentionMappings\\x18\\x11 \\x03(\\x05\\x12+\\n#entityMentionToCorefMentionMappings\\x18\\x12 
\\x03(\\x05*\\x05\\x08\\x64\\x10\\x80\\x02\\\"\\xf3\\x0f\\n\\x08Sentence\\x12/\\n\\x05token\\x18\\x01 \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Token\\x12\\x18\\n\\x10tokenOffsetBegin\\x18\\x02 \\x02(\\r\\x12\\x16\\n\\x0etokenOffsetEnd\\x18\\x03 \\x02(\\r\\x12\\x15\\n\\rsentenceIndex\\x18\\x04 \\x01(\\r\\x12\\x1c\\n\\x14\\x63haracterOffsetBegin\\x18\\x05 \\x01(\\r\\x12\\x1a\\n\\x12\\x63haracterOffsetEnd\\x18\\x06 \\x01(\\r\\x12\\x37\\n\\tparseTree\\x18\\x07 \\x01(\\x0b\\x32$.edu.stanford.nlp.pipeline.ParseTree\\x12@\\n\\x12\\x62inarizedParseTree\\x18\\x1f \\x01(\\x0b\\x32$.edu.stanford.nlp.pipeline.ParseTree\\x12@\\n\\x12\\x61nnotatedParseTree\\x18  \\x01(\\x0b\\x32$.edu.stanford.nlp.pipeline.ParseTree\\x12\\x11\\n\\tsentiment\\x18! \\x01(\\t\\x12=\\n\\x0fkBestParseTrees\\x18\\\" \\x03(\\x0b\\x32$.edu.stanford.nlp.pipeline.ParseTree\\x12\\x45\\n\\x11\\x62\\x61sicDependencies\\x18\\x08 \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12I\\n\\x15\\x63ollapsedDependencies\\x18\\t \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12T\\n collapsedCCProcessedDependencies\\x18\\n \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12K\\n\\x17\\x61lternativeDependencies\\x18\\r \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12?\\n\\x0copenieTriple\\x18\\x0e \\x03(\\x0b\\x32).edu.stanford.nlp.pipeline.RelationTriple\\x12<\\n\\tkbpTriple\\x18\\x10 \\x03(\\x0b\\x32).edu.stanford.nlp.pipeline.RelationTriple\\x12\\x45\\n\\x10\\x65ntailedSentence\\x18\\x0f \\x03(\\x0b\\x32+.edu.stanford.nlp.pipeline.SentenceFragment\\x12\\x43\\n\\x0e\\x65ntailedClause\\x18# \\x03(\\x0b\\x32+.edu.stanford.nlp.pipeline.SentenceFragment\\x12H\\n\\x14\\x65nhancedDependencies\\x18\\x11 \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12P\\n\\x1c\\x65nhancedPlusPlusDependencies\\x18\\x12 \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12\\x33\\n\\tcharacter\\x18\\x13 \\x03(\\x0b\\x32 
.edu.stanford.nlp.pipeline.Token\\x12\\x11\\n\\tparagraph\\x18\\x0b \\x01(\\r\\x12\\x0c\\n\\x04text\\x18\\x0c \\x01(\\t\\x12\\x12\\n\\nlineNumber\\x18\\x14 \\x01(\\r\\x12\\x1e\\n\\x16hasRelationAnnotations\\x18\\x33 \\x01(\\x08\\x12\\x31\\n\\x06\\x65ntity\\x18\\x34 \\x03(\\x0b\\x32!.edu.stanford.nlp.pipeline.Entity\\x12\\x35\\n\\x08relation\\x18\\x35 \\x03(\\x0b\\x32#.edu.stanford.nlp.pipeline.Relation\\x12$\\n\\x1chasNumerizedTokensAnnotation\\x18\\x36 \\x01(\\x08\\x12\\x37\\n\\x08mentions\\x18\\x37 \\x03(\\x0b\\x32%.edu.stanford.nlp.pipeline.NERMention\\x12<\\n\\x10mentionsForCoref\\x18\\x38 \\x03(\\x0b\\x32\\\".edu.stanford.nlp.pipeline.Mention\\x12\\\"\\n\\x1ahasCorefMentionsAnnotation\\x18\\x39 \\x01(\\x08\\x12\\x12\\n\\nsentenceID\\x18: \\x01(\\t\\x12\\x13\\n\\x0bsectionDate\\x18; \\x01(\\t\\x12\\x14\\n\\x0csectionIndex\\x18< \\x01(\\r\\x12\\x13\\n\\x0bsectionName\\x18= \\x01(\\t\\x12\\x15\\n\\rsectionAuthor\\x18> \\x01(\\t\\x12\\r\\n\\x05\\x64ocID\\x18? \\x01(\\t\\x12\\x15\\n\\rsectionQuoted\\x18@ \\x01(\\x08\\x12#\\n\\x1bhasEntityMentionsAnnotation\\x18\\x41 \\x01(\\x08\\x12\\x1f\\n\\x17hasKBPTriplesAnnotation\\x18\\x44 \\x01(\\x08\\x12\\\"\\n\\x1ahasOpenieTriplesAnnotation\\x18\\x45 \\x01(\\x08\\x12\\x14\\n\\x0c\\x63hapterIndex\\x18\\x42 \\x01(\\r\\x12\\x16\\n\\x0eparagraphIndex\\x18\\x43 \\x01(\\r\\x12=\\n\\x10\\x65nhancedSentence\\x18\\x46 \\x01(\\x0b\\x32#.edu.stanford.nlp.pipeline.Sentence\\x12\\x0f\\n\\x07speaker\\x18G \\x01(\\t\\x12\\x13\\n\\x0bspeakerType\\x18H \\x01(\\t*\\x05\\x08\\x64\\x10\\x80\\x02\\\"\\xf6\\x0c\\n\\x05Token\\x12\\x0c\\n\\x04word\\x18\\x01 \\x01(\\t\\x12\\x0b\\n\\x03pos\\x18\\x02 \\x01(\\t\\x12\\r\\n\\x05value\\x18\\x03 \\x01(\\t\\x12\\x10\\n\\x08\\x63\\x61tegory\\x18\\x04 \\x01(\\t\\x12\\x0e\\n\\x06\\x62\\x65\\x66ore\\x18\\x05 \\x01(\\t\\x12\\r\\n\\x05\\x61\\x66ter\\x18\\x06 \\x01(\\t\\x12\\x14\\n\\x0coriginalText\\x18\\x07 \\x01(\\t\\x12\\x0b\\n\\x03ner\\x18\\x08 \\x01(\\t\\x12\\x11\\n\\tcoarseNER\\x18> 
\\x01(\\t\\x12\\x16\\n\\x0e\\x66ineGrainedNER\\x18? \\x01(\\t\\x12\\x15\\n\\rnerLabelProbs\\x18\\x42 \\x03(\\t\\x12\\x15\\n\\rnormalizedNER\\x18\\t \\x01(\\t\\x12\\r\\n\\x05lemma\\x18\\n \\x01(\\t\\x12\\x11\\n\\tbeginChar\\x18\\x0b \\x01(\\r\\x12\\x0f\\n\\x07\\x65ndChar\\x18\\x0c \\x01(\\r\\x12\\x11\\n\\tutterance\\x18\\r \\x01(\\r\\x12\\x0f\\n\\x07speaker\\x18\\x0e \\x01(\\t\\x12\\x13\\n\\x0bspeakerType\\x18M \\x01(\\t\\x12\\x12\\n\\nbeginIndex\\x18\\x0f \\x01(\\r\\x12\\x10\\n\\x08\\x65ndIndex\\x18\\x10 \\x01(\\r\\x12\\x17\\n\\x0ftokenBeginIndex\\x18\\x11 \\x01(\\r\\x12\\x15\\n\\rtokenEndIndex\\x18\\x12 \\x01(\\r\\x12\\x34\\n\\ntimexValue\\x18\\x13 \\x01(\\x0b\\x32 .edu.stanford.nlp.pipeline.Timex\\x12\\x15\\n\\rhasXmlContext\\x18\\x15 \\x01(\\x08\\x12\\x12\\n\\nxmlContext\\x18\\x16 \\x03(\\t\\x12\\x16\\n\\x0e\\x63orefClusterID\\x18\\x17 \\x01(\\r\\x12\\x0e\\n\\x06\\x61nswer\\x18\\x18 \\x01(\\t\\x12\\x15\\n\\rheadWordIndex\\x18\\x1a \\x01(\\r\\x12\\x35\\n\\x08operator\\x18\\x1b \\x01(\\x0b\\x32#.edu.stanford.nlp.pipeline.Operator\\x12\\x35\\n\\x08polarity\\x18\\x1c \\x01(\\x0b\\x32#.edu.stanford.nlp.pipeline.Polarity\\x12\\x14\\n\\x0cpolarity_dir\\x18\\' \\x01(\\t\\x12-\\n\\x04span\\x18\\x1d \\x01(\\x0b\\x32\\x1f.edu.stanford.nlp.pipeline.Span\\x12\\x11\\n\\tsentiment\\x18\\x1e \\x01(\\t\\x12\\x16\\n\\x0equotationIndex\\x18\\x1f \\x01(\\x05\\x12\\x42\\n\\x0e\\x63onllUFeatures\\x18  \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.MapStringString\\x12\\x11\\n\\tcoarseTag\\x18! 
\\x01(\\t\\x12\\x38\\n\\x0f\\x63onllUTokenSpan\\x18\\\" \\x01(\\x0b\\x32\\x1f.edu.stanford.nlp.pipeline.Span\\x12\\x12\\n\\nconllUMisc\\x18# \\x01(\\t\\x12G\\n\\x13\\x63onllUSecondaryDeps\\x18$ \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.MapStringString\\x12\\x17\\n\\x0fwikipediaEntity\\x18% \\x01(\\t\\x12\\x11\\n\\tisNewline\\x18& \\x01(\\x08\\x12\\x0e\\n\\x06gender\\x18\\x33 \\x01(\\t\\x12\\x10\\n\\x08trueCase\\x18\\x34 \\x01(\\t\\x12\\x14\\n\\x0ctrueCaseText\\x18\\x35 \\x01(\\t\\x12\\x13\\n\\x0b\\x63hineseChar\\x18\\x36 \\x01(\\t\\x12\\x12\\n\\nchineseSeg\\x18\\x37 \\x01(\\t\\x12\\x16\\n\\x0e\\x63hineseXMLChar\\x18< \\x01(\\t\\x12\\x11\\n\\tarabicSeg\\x18L \\x01(\\t\\x12\\x13\\n\\x0bsectionName\\x18\\x38 \\x01(\\t\\x12\\x15\\n\\rsectionAuthor\\x18\\x39 \\x01(\\t\\x12\\x13\\n\\x0bsectionDate\\x18: \\x01(\\t\\x12\\x17\\n\\x0fsectionEndLabel\\x18; \\x01(\\t\\x12\\x0e\\n\\x06parent\\x18= \\x01(\\t\\x12\\x19\\n\\x11\\x63orefMentionIndex\\x18@ \\x03(\\r\\x12\\x1a\\n\\x12\\x65ntityMentionIndex\\x18\\x41 \\x01(\\r\\x12\\r\\n\\x05isMWT\\x18\\x43 \\x01(\\x08\\x12\\x12\\n\\nisFirstMWT\\x18\\x44 \\x01(\\x08\\x12\\x0f\\n\\x07mwtText\\x18\\x45 \\x01(\\t\\x12\\x0f\\n\\x07mwtMisc\\x18N \\x01(\\t\\x12\\x14\\n\\x0cnumericValue\\x18\\x46 \\x01(\\x04\\x12\\x13\\n\\x0bnumericType\\x18G \\x01(\\t\\x12\\x1d\\n\\x15numericCompositeValue\\x18H \\x01(\\x04\\x12\\x1c\\n\\x14numericCompositeType\\x18I \\x01(\\t\\x12\\x1c\\n\\x14\\x63odepointOffsetBegin\\x18J \\x01(\\r\\x12\\x1a\\n\\x12\\x63odepointOffsetEnd\\x18K \\x01(\\r\\x12\\r\\n\\x05index\\x18O \\x01(\\r\\x12\\x12\\n\\nemptyIndex\\x18P \\x01(\\r*\\x05\\x08\\x64\\x10\\x80\\x02\\\"\\xe4\\x03\\n\\x05Quote\\x12\\x0c\\n\\x04text\\x18\\x01 \\x01(\\t\\x12\\r\\n\\x05\\x62\\x65gin\\x18\\x02 \\x01(\\r\\x12\\x0b\\n\\x03\\x65nd\\x18\\x03 \\x01(\\r\\x12\\x15\\n\\rsentenceBegin\\x18\\x05 \\x01(\\r\\x12\\x13\\n\\x0bsentenceEnd\\x18\\x06 \\x01(\\r\\x12\\x12\\n\\ntokenBegin\\x18\\x07 \\x01(\\r\\x12\\x10\\n\\x08tokenEnd\\x18\\x08 
\\x01(\\r\\x12\\r\\n\\x05\\x64ocid\\x18\\t \\x01(\\t\\x12\\r\\n\\x05index\\x18\\n \\x01(\\r\\x12\\x0e\\n\\x06\\x61uthor\\x18\\x0b \\x01(\\t\\x12\\x0f\\n\\x07mention\\x18\\x0c \\x01(\\t\\x12\\x14\\n\\x0cmentionBegin\\x18\\r \\x01(\\r\\x12\\x12\\n\\nmentionEnd\\x18\\x0e \\x01(\\r\\x12\\x13\\n\\x0bmentionType\\x18\\x0f \\x01(\\t\\x12\\x14\\n\\x0cmentionSieve\\x18\\x10 \\x01(\\t\\x12\\x0f\\n\\x07speaker\\x18\\x11 \\x01(\\t\\x12\\x14\\n\\x0cspeakerSieve\\x18\\x12 \\x01(\\t\\x12\\x18\\n\\x10\\x63\\x61nonicalMention\\x18\\x13 \\x01(\\t\\x12\\x1d\\n\\x15\\x63\\x61nonicalMentionBegin\\x18\\x14 \\x01(\\r\\x12\\x1b\\n\\x13\\x63\\x61nonicalMentionEnd\\x18\\x15 \\x01(\\r\\x12N\\n\\x1a\\x61ttributionDependencyGraph\\x18\\x16 \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\\"\\xc7\\x01\\n\\tParseTree\\x12\\x33\\n\\x05\\x63hild\\x18\\x01 \\x03(\\x0b\\x32$.edu.stanford.nlp.pipeline.ParseTree\\x12\\r\\n\\x05value\\x18\\x02 \\x01(\\t\\x12\\x17\\n\\x0fyieldBeginIndex\\x18\\x03 \\x01(\\r\\x12\\x15\\n\\ryieldEndIndex\\x18\\x04 \\x01(\\r\\x12\\r\\n\\x05score\\x18\\x05 \\x01(\\x01\\x12\\x37\\n\\tsentiment\\x18\\x06 \\x01(\\x0e\\x32$.edu.stanford.nlp.pipeline.Sentiment\\\"\\x9b\\x04\\n\\x0f\\x44\\x65pendencyGraph\\x12=\\n\\x04node\\x18\\x01 \\x03(\\x0b\\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\\x12=\\n\\x04\\x65\\x64ge\\x18\\x02 \\x03(\\x0b\\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\\x12\\x10\\n\\x04root\\x18\\x03 \\x03(\\rB\\x02\\x10\\x01\\x12/\\n\\x05token\\x18\\x04 \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Token\\x12\\x14\\n\\x08rootNode\\x18\\x05 \\x03(\\rB\\x02\\x10\\x01\\x1aX\\n\\x04Node\\x12\\x15\\n\\rsentenceIndex\\x18\\x01 \\x02(\\r\\x12\\r\\n\\x05index\\x18\\x02 \\x02(\\r\\x12\\x16\\n\\x0e\\x63opyAnnotation\\x18\\x03 \\x01(\\r\\x12\\x12\\n\\nemptyIndex\\x18\\x04 \\x01(\\r\\x1a\\xd6\\x01\\n\\x04\\x45\\x64ge\\x12\\x0e\\n\\x06source\\x18\\x01 \\x02(\\r\\x12\\x0e\\n\\x06target\\x18\\x02 \\x02(\\r\\x12\\x0b\\n\\x03\\x64\\x65p\\x18\\x03 
\\x01(\\t\\x12\\x0f\\n\\x07isExtra\\x18\\x04 \\x01(\\x08\\x12\\x12\\n\\nsourceCopy\\x18\\x05 \\x01(\\r\\x12\\x12\\n\\ntargetCopy\\x18\\x06 \\x01(\\r\\x12\\x13\\n\\x0bsourceEmpty\\x18\\x08 \\x01(\\r\\x12\\x13\\n\\x0btargetEmpty\\x18\\t \\x01(\\r\\x12>\\n\\x08language\\x18\\x07 \\x01(\\x0e\\x32#.edu.stanford.nlp.pipeline.Language:\\x07Unknown\\\"\\xc6\\x02\\n\\nCorefChain\\x12\\x0f\\n\\x07\\x63hainID\\x18\\x01 \\x02(\\x05\\x12\\x43\\n\\x07mention\\x18\\x02 \\x03(\\x0b\\x32\\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\\x12\\x16\\n\\x0erepresentative\\x18\\x03 \\x02(\\r\\x1a\\xc9\\x01\\n\\x0c\\x43orefMention\\x12\\x11\\n\\tmentionID\\x18\\x01 \\x01(\\x05\\x12\\x13\\n\\x0bmentionType\\x18\\x02 \\x01(\\t\\x12\\x0e\\n\\x06number\\x18\\x03 \\x01(\\t\\x12\\x0e\\n\\x06gender\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07\\x61nimacy\\x18\\x05 \\x01(\\t\\x12\\x12\\n\\nbeginIndex\\x18\\x06 \\x01(\\r\\x12\\x10\\n\\x08\\x65ndIndex\\x18\\x07 \\x01(\\r\\x12\\x11\\n\\theadIndex\\x18\\t \\x01(\\r\\x12\\x15\\n\\rsentenceIndex\\x18\\n \\x01(\\r\\x12\\x10\\n\\x08position\\x18\\x0b \\x01(\\r\\\"\\xef\\x08\\n\\x07Mention\\x12\\x11\\n\\tmentionID\\x18\\x01 \\x01(\\x05\\x12\\x13\\n\\x0bmentionType\\x18\\x02 \\x01(\\t\\x12\\x0e\\n\\x06number\\x18\\x03 \\x01(\\t\\x12\\x0e\\n\\x06gender\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07\\x61nimacy\\x18\\x05 \\x01(\\t\\x12\\x0e\\n\\x06person\\x18\\x06 \\x01(\\t\\x12\\x12\\n\\nstartIndex\\x18\\x07 \\x01(\\r\\x12\\x10\\n\\x08\\x65ndIndex\\x18\\t \\x01(\\r\\x12\\x11\\n\\theadIndex\\x18\\n \\x01(\\x05\\x12\\x12\\n\\nheadString\\x18\\x0b \\x01(\\t\\x12\\x11\\n\\tnerString\\x18\\x0c \\x01(\\t\\x12\\x13\\n\\x0boriginalRef\\x18\\r \\x01(\\x05\\x12\\x1a\\n\\x12goldCorefClusterID\\x18\\x0e \\x01(\\x05\\x12\\x16\\n\\x0e\\x63orefClusterID\\x18\\x0f \\x01(\\x05\\x12\\x12\\n\\nmentionNum\\x18\\x10 \\x01(\\x05\\x12\\x0f\\n\\x07sentNum\\x18\\x11 \\x01(\\x05\\x12\\r\\n\\x05utter\\x18\\x12 \\x01(\\x05\\x12\\x11\\n\\tparagraph\\x18\\x13 
\\x01(\\x05\\x12\\x11\\n\\tisSubject\\x18\\x14 \\x01(\\x08\\x12\\x16\\n\\x0eisDirectObject\\x18\\x15 \\x01(\\x08\\x12\\x18\\n\\x10isIndirectObject\\x18\\x16 \\x01(\\x08\\x12\\x1b\\n\\x13isPrepositionObject\\x18\\x17 \\x01(\\x08\\x12\\x0f\\n\\x07hasTwin\\x18\\x18 \\x01(\\x08\\x12\\x0f\\n\\x07generic\\x18\\x19 \\x01(\\x08\\x12\\x13\\n\\x0bisSingleton\\x18\\x1a \\x01(\\x08\\x12\\x1a\\n\\x12hasBasicDependency\\x18\\x1b \\x01(\\x08\\x12\\x1d\\n\\x15hasEnhancedDependency\\x18\\x1c \\x01(\\x08\\x12\\x1b\\n\\x13hasContextParseTree\\x18\\x1d \\x01(\\x08\\x12?\\n\\x0fheadIndexedWord\\x18\\x1e \\x01(\\x0b\\x32&.edu.stanford.nlp.pipeline.IndexedWord\\x12=\\n\\rdependingVerb\\x18\\x1f \\x01(\\x0b\\x32&.edu.stanford.nlp.pipeline.IndexedWord\\x12\\x38\\n\\x08headWord\\x18  \\x01(\\x0b\\x32&.edu.stanford.nlp.pipeline.IndexedWord\\x12;\\n\\x0bspeakerInfo\\x18! \\x01(\\x0b\\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\\x12=\\n\\rsentenceWords\\x18\\x32 \\x03(\\x0b\\x32&.edu.stanford.nlp.pipeline.IndexedWord\\x12<\\n\\x0coriginalSpan\\x18\\x33 \\x03(\\x0b\\x32&.edu.stanford.nlp.pipeline.IndexedWord\\x12\\x12\\n\\ndependents\\x18\\x34 \\x03(\\t\\x12\\x19\\n\\x11preprocessedTerms\\x18\\x35 \\x03(\\t\\x12\\x13\\n\\x0b\\x61ppositions\\x18\\x36 \\x03(\\x05\\x12\\x1c\\n\\x14predicateNominatives\\x18\\x37 \\x03(\\x05\\x12\\x18\\n\\x10relativePronouns\\x18\\x38 \\x03(\\x05\\x12\\x13\\n\\x0blistMembers\\x18\\x39 \\x03(\\x05\\x12\\x15\\n\\rbelongToLists\\x18: \\x03(\\x05\\\"X\\n\\x0bIndexedWord\\x12\\x13\\n\\x0bsentenceNum\\x18\\x01 \\x01(\\x05\\x12\\x12\\n\\ntokenIndex\\x18\\x02 \\x01(\\x05\\x12\\r\\n\\x05\\x64ocID\\x18\\x03 \\x01(\\x05\\x12\\x11\\n\\tcopyCount\\x18\\x04 \\x01(\\r\\\"4\\n\\x0bSpeakerInfo\\x12\\x13\\n\\x0bspeakerName\\x18\\x01 \\x01(\\t\\x12\\x10\\n\\x08mentions\\x18\\x02 \\x03(\\x05\\\"\\\"\\n\\x04Span\\x12\\r\\n\\x05\\x62\\x65gin\\x18\\x01 \\x02(\\r\\x12\\x0b\\n\\x03\\x65nd\\x18\\x02 \\x02(\\r\\\"w\\n\\x05Timex\\x12\\r\\n\\x05value\\x18\\x01 
\\x01(\\t\\x12\\x10\\n\\x08\\x61ltValue\\x18\\x02 \\x01(\\t\\x12\\x0c\\n\\x04text\\x18\\x03 \\x01(\\t\\x12\\x0c\\n\\x04type\\x18\\x04 \\x01(\\t\\x12\\x0b\\n\\x03tid\\x18\\x05 \\x01(\\t\\x12\\x12\\n\\nbeginPoint\\x18\\x06 \\x01(\\r\\x12\\x10\\n\\x08\\x65ndPoint\\x18\\x07 \\x01(\\r\\\"\\xdb\\x01\\n\\x06\\x45ntity\\x12\\x11\\n\\theadStart\\x18\\x06 \\x01(\\r\\x12\\x0f\\n\\x07headEnd\\x18\\x07 \\x01(\\r\\x12\\x13\\n\\x0bmentionType\\x18\\x08 \\x01(\\t\\x12\\x16\\n\\x0enormalizedName\\x18\\t \\x01(\\t\\x12\\x16\\n\\x0eheadTokenIndex\\x18\\n \\x01(\\r\\x12\\x0f\\n\\x07\\x63orefID\\x18\\x0b \\x01(\\t\\x12\\x10\\n\\x08objectID\\x18\\x01 \\x01(\\t\\x12\\x13\\n\\x0b\\x65xtentStart\\x18\\x02 \\x01(\\r\\x12\\x11\\n\\textentEnd\\x18\\x03 \\x01(\\r\\x12\\x0c\\n\\x04type\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07subtype\\x18\\x05 \\x01(\\t\\\"\\xb7\\x01\\n\\x08Relation\\x12\\x0f\\n\\x07\\x61rgName\\x18\\x06 \\x03(\\t\\x12.\\n\\x03\\x61rg\\x18\\x07 \\x03(\\x0b\\x32!.edu.stanford.nlp.pipeline.Entity\\x12\\x11\\n\\tsignature\\x18\\x08 \\x01(\\t\\x12\\x10\\n\\x08objectID\\x18\\x01 \\x01(\\t\\x12\\x13\\n\\x0b\\x65xtentStart\\x18\\x02 \\x01(\\r\\x12\\x11\\n\\textentEnd\\x18\\x03 \\x01(\\r\\x12\\x0c\\n\\x04type\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07subtype\\x18\\x05 \\x01(\\t\\\"\\xb2\\x01\\n\\x08Operator\\x12\\x0c\\n\\x04name\\x18\\x01 \\x02(\\t\\x12\\x1b\\n\\x13quantifierSpanBegin\\x18\\x02 \\x02(\\x05\\x12\\x19\\n\\x11quantifierSpanEnd\\x18\\x03 \\x02(\\x05\\x12\\x18\\n\\x10subjectSpanBegin\\x18\\x04 \\x02(\\x05\\x12\\x16\\n\\x0esubjectSpanEnd\\x18\\x05 \\x02(\\x05\\x12\\x17\\n\\x0fobjectSpanBegin\\x18\\x06 \\x02(\\x05\\x12\\x15\\n\\robjectSpanEnd\\x18\\x07 \\x02(\\x05\\\"\\xa9\\x04\\n\\x08Polarity\\x12K\\n\\x12projectEquivalence\\x18\\x01 \\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\x12Q\\n\\x18projectForwardEntailment\\x18\\x02 \\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\x12Q\\n\\x18projectReverseEntailment\\x18\\x03 
\\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\x12H\\n\\x0fprojectNegation\\x18\\x04 \\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\x12K\\n\\x12projectAlternation\\x18\\x05 \\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\x12\\x45\\n\\x0cprojectCover\\x18\\x06 \\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\x12L\\n\\x13projectIndependence\\x18\\x07 \\x02(\\x0e\\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\\\"\\xdd\\x02\\n\\nNERMention\\x12\\x15\\n\\rsentenceIndex\\x18\\x01 \\x01(\\r\\x12%\\n\\x1dtokenStartInSentenceInclusive\\x18\\x02 \\x02(\\r\\x12#\\n\\x1btokenEndInSentenceExclusive\\x18\\x03 \\x02(\\r\\x12\\x0b\\n\\x03ner\\x18\\x04 \\x02(\\t\\x12\\x15\\n\\rnormalizedNER\\x18\\x05 \\x01(\\t\\x12\\x12\\n\\nentityType\\x18\\x06 \\x01(\\t\\x12/\\n\\x05timex\\x18\\x07 \\x01(\\x0b\\x32 .edu.stanford.nlp.pipeline.Timex\\x12\\x17\\n\\x0fwikipediaEntity\\x18\\x08 \\x01(\\t\\x12\\x0e\\n\\x06gender\\x18\\t \\x01(\\t\\x12\\x1a\\n\\x12\\x65ntityMentionIndex\\x18\\n \\x01(\\r\\x12#\\n\\x1b\\x63\\x61nonicalEntityMentionIndex\\x18\\x0b \\x01(\\r\\x12\\x19\\n\\x11\\x65ntityMentionText\\x18\\x0c \\x01(\\t\\\"Y\\n\\x10SentenceFragment\\x12\\x12\\n\\ntokenIndex\\x18\\x01 \\x03(\\r\\x12\\x0c\\n\\x04root\\x18\\x02 \\x01(\\r\\x12\\x14\\n\\x0c\\x61ssumedTruth\\x18\\x03 \\x01(\\x08\\x12\\r\\n\\x05score\\x18\\x04 \\x01(\\x01\\\":\\n\\rTokenLocation\\x12\\x15\\n\\rsentenceIndex\\x18\\x01 \\x01(\\r\\x12\\x12\\n\\ntokenIndex\\x18\\x02 \\x01(\\r\\\"\\x9a\\x03\\n\\x0eRelationTriple\\x12\\x0f\\n\\x07subject\\x18\\x01 \\x01(\\t\\x12\\x10\\n\\x08relation\\x18\\x02 \\x01(\\t\\x12\\x0e\\n\\x06object\\x18\\x03 \\x01(\\t\\x12\\x12\\n\\nconfidence\\x18\\x04 \\x01(\\x01\\x12?\\n\\rsubjectTokens\\x18\\r \\x03(\\x0b\\x32(.edu.stanford.nlp.pipeline.TokenLocation\\x12@\\n\\x0erelationTokens\\x18\\x0e \\x03(\\x0b\\x32(.edu.stanford.nlp.pipeline.TokenLocation\\x12>\\n\\x0cobjectTokens\\x18\\x0f 
\\x03(\\x0b\\x32(.edu.stanford.nlp.pipeline.TokenLocation\\x12\\x38\\n\\x04tree\\x18\\x08 \\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12\\x0e\\n\\x06istmod\\x18\\t \\x01(\\x08\\x12\\x10\\n\\x08prefixBe\\x18\\n \\x01(\\x08\\x12\\x10\\n\\x08suffixBe\\x18\\x0b \\x01(\\x08\\x12\\x10\\n\\x08suffixOf\\x18\\x0c \\x01(\\x08\\\"-\\n\\x0fMapStringString\\x12\\x0b\\n\\x03key\\x18\\x01 \\x03(\\t\\x12\\r\\n\\x05value\\x18\\x02 \\x03(\\t\\\"*\\n\\x0cMapIntString\\x12\\x0b\\n\\x03key\\x18\\x01 \\x03(\\r\\x12\\r\\n\\x05value\\x18\\x02 \\x03(\\t\\\"\\xfc\\x01\\n\\x07Section\\x12\\x11\\n\\tcharBegin\\x18\\x01 \\x02(\\r\\x12\\x0f\\n\\x07\\x63harEnd\\x18\\x02 \\x02(\\r\\x12\\x0e\\n\\x06\\x61uthor\\x18\\x03 \\x01(\\t\\x12\\x17\\n\\x0fsentenceIndexes\\x18\\x04 \\x03(\\r\\x12\\x10\\n\\x08\\x64\\x61tetime\\x18\\x05 \\x01(\\t\\x12\\x30\\n\\x06quotes\\x18\\x06 \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Quote\\x12\\x17\\n\\x0f\\x61uthorCharBegin\\x18\\x07 \\x01(\\r\\x12\\x15\\n\\rauthorCharEnd\\x18\\x08 \\x01(\\r\\x12\\x30\\n\\x06xmlTag\\x18\\t \\x02(\\x0b\\x32 .edu.stanford.nlp.pipeline.Token\\\"\\xe4\\x01\\n\\x0eSemgrexRequest\\x12\\x0f\\n\\x07semgrex\\x18\\x01 \\x03(\\t\\x12\\x45\\n\\x05query\\x18\\x02 \\x03(\\x0b\\x32\\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\\x1az\\n\\x0c\\x44\\x65pendencies\\x12/\\n\\x05token\\x18\\x01 \\x03(\\x0b\\x32 .edu.stanford.nlp.pipeline.Token\\x12\\x39\\n\\x05graph\\x18\\x02 \\x02(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\\"\\xfb\\x06\\n\\x0fSemgrexResponse\\x12\\x46\\n\\x06result\\x18\\x01 \\x03(\\x0b\\x32\\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\\x1a-\\n\\tNamedNode\\x12\\x0c\\n\\x04name\\x18\\x01 \\x02(\\t\\x12\\x12\\n\\nmatchIndex\\x18\\x02 \\x02(\\x05\\x1a+\\n\\rNamedRelation\\x12\\x0c\\n\\x04name\\x18\\x01 \\x02(\\t\\x12\\x0c\\n\\x04reln\\x18\\x02 \\x02(\\t\\x1a\\x80\\x01\\n\\tNamedEdge\\x12\\x0c\\n\\x04name\\x18\\x01 \\x02(\\t\\x12\\x0e\\n\\x06source\\x18\\x02 
\\x02(\\x05\\x12\\x0e\\n\\x06target\\x18\\x03 \\x02(\\x05\\x12\\x0c\\n\\x04reln\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07isExtra\\x18\\x05 \\x01(\\x08\\x12\\x12\\n\\nsourceCopy\\x18\\x06 \\x01(\\r\\x12\\x12\\n\\ntargetCopy\\x18\\x07 \\x01(\\r\\x1a-\\n\\x0eVariableString\\x12\\x0c\\n\\x04name\\x18\\x01 \\x02(\\t\\x12\\r\\n\\x05value\\x18\\x02 \\x02(\\t\\x1a\\xe6\\x02\\n\\x05Match\\x12\\x12\\n\\nmatchIndex\\x18\\x01 \\x02(\\x05\\x12\\x42\\n\\x04node\\x18\\x02 \\x03(\\x0b\\x32\\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\\x12\\x46\\n\\x04reln\\x18\\x03 \\x03(\\x0b\\x32\\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\\x12\\x42\\n\\x04\\x65\\x64ge\\x18\\x06 \\x03(\\x0b\\x32\\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge\\x12L\\n\\tvarstring\\x18\\x07 \\x03(\\x0b\\x32\\x39.edu.stanford.nlp.pipeline.SemgrexResponse.VariableString\\x12\\x15\\n\\rsentenceIndex\\x18\\x04 \\x01(\\x05\\x12\\x14\\n\\x0csemgrexIndex\\x18\\x05 \\x01(\\x05\\x1aP\\n\\rSemgrexResult\\x12?\\n\\x05match\\x18\\x01 \\x03(\\x0b\\x32\\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\\x1aW\\n\\x0bGraphResult\\x12H\\n\\x06result\\x18\\x01 \\x03(\\x0b\\x32\\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\\\"\\xf0\\x01\\n\\x0fSsurgeonRequest\\x12\\x45\\n\\x08ssurgeon\\x18\\x01 \\x03(\\x0b\\x32\\x33.edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon\\x12\\x39\\n\\x05graph\\x18\\x02 \\x03(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x1a[\\n\\x08Ssurgeon\\x12\\x0f\\n\\x07semgrex\\x18\\x01 \\x01(\\t\\x12\\x11\\n\\toperation\\x18\\x02 \\x03(\\t\\x12\\n\\n\\x02id\\x18\\x03 \\x01(\\t\\x12\\r\\n\\x05notes\\x18\\x04 \\x01(\\t\\x12\\x10\\n\\x08language\\x18\\x05 \\x01(\\t\\\"\\xbc\\x01\\n\\x10SsurgeonResponse\\x12J\\n\\x06result\\x18\\x01 \\x03(\\x0b\\x32:.edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult\\x1a\\\\\\n\\x0eSsurgeonResult\\x12\\x39\\n\\x05graph\\x18\\x01 
\\x01(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12\\x0f\\n\\x07\\x63hanged\\x18\\x02 \\x01(\\x08\\\"W\\n\\x12TokensRegexRequest\\x12\\x30\\n\\x03\\x64oc\\x18\\x01 \\x02(\\x0b\\x32#.edu.stanford.nlp.pipeline.Document\\x12\\x0f\\n\\x07pattern\\x18\\x02 \\x03(\\t\\\"\\xa7\\x03\\n\\x13TokensRegexResponse\\x12J\\n\\x05match\\x18\\x01 \\x03(\\x0b\\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\\x1a\\x39\\n\\rMatchLocation\\x12\\x0c\\n\\x04text\\x18\\x01 \\x01(\\t\\x12\\r\\n\\x05\\x62\\x65gin\\x18\\x02 \\x01(\\x05\\x12\\x0b\\n\\x03\\x65nd\\x18\\x03 \\x01(\\x05\\x1a\\xb3\\x01\\n\\x05Match\\x12\\x10\\n\\x08sentence\\x18\\x01 \\x02(\\x05\\x12K\\n\\x05match\\x18\\x02 \\x02(\\x0b\\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\\x12K\\n\\x05group\\x18\\x03 \\x03(\\x0b\\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\\x1aS\\n\\x0cPatternMatch\\x12\\x43\\n\\x05match\\x18\\x01 \\x03(\\x0b\\x32\\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\\\"\\xae\\x01\\n\\x19\\x44\\x65pendencyEnhancerRequest\\x12\\x35\\n\\x08\\x64ocument\\x18\\x01 \\x02(\\x0b\\x32#.edu.stanford.nlp.pipeline.Document\\x12\\x37\\n\\x08language\\x18\\x02 \\x01(\\x0e\\x32#.edu.stanford.nlp.pipeline.LanguageH\\x00\\x12\\x1a\\n\\x10relativePronouns\\x18\\x03 \\x01(\\tH\\x00\\x42\\x05\\n\\x03ref\\\"\\xb4\\x01\\n\\x12\\x46lattenedParseTree\\x12\\x41\\n\\x05nodes\\x18\\x01 \\x03(\\x0b\\x32\\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\\x1a[\\n\\x04Node\\x12\\x12\\n\\x08openNode\\x18\\x01 \\x01(\\x08H\\x00\\x12\\x13\\n\\tcloseNode\\x18\\x02 \\x01(\\x08H\\x00\\x12\\x0f\\n\\x05value\\x18\\x03 \\x01(\\tH\\x00\\x12\\r\\n\\x05score\\x18\\x04 \\x01(\\x01\\x42\\n\\n\\x08\\x63ontents\\\"\\xf6\\x01\\n\\x15\\x45valuateParserRequest\\x12N\\n\\x08treebank\\x18\\x01 \\x03(\\x0b\\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\\x1a\\x8c\\x01\\n\\x0bParseResult\\x12;\\n\\x04gold\\x18\\x01 
\\x02(\\x0b\\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\\x12@\\n\\tpredicted\\x18\\x02 \\x03(\\x0b\\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\\\"E\\n\\x16\\x45valuateParserResponse\\x12\\n\\n\\x02\\x66\\x31\\x18\\x01 \\x02(\\x01\\x12\\x0f\\n\\x07kbestF1\\x18\\x02 \\x01(\\x01\\x12\\x0e\\n\\x06treeF1\\x18\\x03 \\x03(\\x01\\\"\\xc8\\x01\\n\\x0fTsurgeonRequest\\x12H\\n\\noperations\\x18\\x01 \\x03(\\x0b\\x32\\x34.edu.stanford.nlp.pipeline.TsurgeonRequest.Operation\\x12<\\n\\x05trees\\x18\\x02 \\x03(\\x0b\\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\\x1a-\\n\\tOperation\\x12\\x0e\\n\\x06tregex\\x18\\x01 \\x02(\\t\\x12\\x10\\n\\x08tsurgeon\\x18\\x02 \\x03(\\t\\\"P\\n\\x10TsurgeonResponse\\x12<\\n\\x05trees\\x18\\x01 \\x03(\\x0b\\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\\\"\\x85\\x01\\n\\x11MorphologyRequest\\x12\\x46\\n\\x05words\\x18\\x01 \\x03(\\x0b\\x32\\x37.edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord\\x1a(\\n\\nTaggedWord\\x12\\x0c\\n\\x04word\\x18\\x01 \\x02(\\t\\x12\\x0c\\n\\x04xpos\\x18\\x02 \\x01(\\t\\\"\\x9a\\x01\\n\\x12MorphologyResponse\\x12I\\n\\x05words\\x18\\x01 \\x03(\\x0b\\x32:.edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma\\x1a\\x39\\n\\x0cWordTagLemma\\x12\\x0c\\n\\x04word\\x18\\x01 \\x02(\\t\\x12\\x0c\\n\\x04xpos\\x18\\x02 \\x01(\\t\\x12\\r\\n\\x05lemma\\x18\\x03 \\x02(\\t\\\"Z\\n\\x1a\\x44\\x65pendencyConverterRequest\\x12<\\n\\x05trees\\x18\\x01 \\x03(\\x0b\\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\\\"\\x90\\x02\\n\\x1b\\x44\\x65pendencyConverterResponse\\x12`\\n\\x0b\\x63onversions\\x18\\x01 \\x03(\\x0b\\x32K.edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion\\x1a\\x8e\\x01\\n\\x14\\x44\\x65pendencyConversion\\x12\\x39\\n\\x05graph\\x18\\x01 \\x02(\\x0b\\x32*.edu.stanford.nlp.pipeline.DependencyGraph\\x12;\\n\\x04tree\\x18\\x02 
\\x01(\\x0b\\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree*\\xa3\\x01\\n\\x08Language\\x12\\x0b\\n\\x07Unknown\\x10\\x00\\x12\\x07\\n\\x03\\x41ny\\x10\\x01\\x12\\n\\n\\x06\\x41rabic\\x10\\x02\\x12\\x0b\\n\\x07\\x43hinese\\x10\\x03\\x12\\x0b\\n\\x07\\x45nglish\\x10\\x04\\x12\\n\\n\\x06German\\x10\\x05\\x12\\n\\n\\x06\\x46rench\\x10\\x06\\x12\\n\\n\\x06Hebrew\\x10\\x07\\x12\\x0b\\n\\x07Spanish\\x10\\x08\\x12\\x14\\n\\x10UniversalEnglish\\x10\\t\\x12\\x14\\n\\x10UniversalChinese\\x10\\n*h\\n\\tSentiment\\x12\\x13\\n\\x0fSTRONG_NEGATIVE\\x10\\x00\\x12\\x11\\n\\rWEAK_NEGATIVE\\x10\\x01\\x12\\x0b\\n\\x07NEUTRAL\\x10\\x02\\x12\\x11\\n\\rWEAK_POSITIVE\\x10\\x03\\x12\\x13\\n\\x0fSTRONG_POSITIVE\\x10\\x04*\\x93\\x01\\n\\x14NaturalLogicRelation\\x12\\x0f\\n\\x0b\\x45QUIVALENCE\\x10\\x00\\x12\\x16\\n\\x12\\x46ORWARD_ENTAILMENT\\x10\\x01\\x12\\x16\\n\\x12REVERSE_ENTAILMENT\\x10\\x02\\x12\\x0c\\n\\x08NEGATION\\x10\\x03\\x12\\x0f\\n\\x0b\\x41LTERNATION\\x10\\x04\\x12\\t\\n\\x05\\x43OVER\\x10\\x05\\x12\\x10\\n\\x0cINDEPENDENCE\\x10\\x06\\x42*\\n\\x19\\x65\\x64u.stanford.nlp.pipelineB\\rCoreNLPProtos')\n\n_globals = globals()\n_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)\n_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'CoreNLP_pb2', _globals)\nif _descriptor._USE_C_DESCRIPTORS == False:\n  _globals['DESCRIPTOR']._options = None\n  _globals['DESCRIPTOR']._serialized_options = b'\\n\\031edu.stanford.nlp.pipelineB\\rCoreNLPProtos'\n  _globals['_DEPENDENCYGRAPH'].fields_by_name['root']._options = None\n  _globals['_DEPENDENCYGRAPH'].fields_by_name['root']._serialized_options = b'\\020\\001'\n  _globals['_DEPENDENCYGRAPH'].fields_by_name['rootNode']._options = None\n  _globals['_DEPENDENCYGRAPH'].fields_by_name['rootNode']._serialized_options = b'\\020\\001'\n  _globals['_LANGUAGE']._serialized_start=13585\n  _globals['_LANGUAGE']._serialized_end=13748\n  _globals['_SENTIMENT']._serialized_start=13750\n  _globals['_SENTIMENT']._serialized_end=13854\n  
_globals['_NATURALLOGICRELATION']._serialized_start=13857\n  _globals['_NATURALLOGICRELATION']._serialized_end=14004\n  _globals['_DOCUMENT']._serialized_start=45\n  _globals['_DOCUMENT']._serialized_end=782\n  _globals['_SENTENCE']._serialized_start=785\n  _globals['_SENTENCE']._serialized_end=2820\n  _globals['_TOKEN']._serialized_start=2823\n  _globals['_TOKEN']._serialized_end=4477\n  _globals['_QUOTE']._serialized_start=4480\n  _globals['_QUOTE']._serialized_end=4964\n  _globals['_PARSETREE']._serialized_start=4967\n  _globals['_PARSETREE']._serialized_end=5166\n  _globals['_DEPENDENCYGRAPH']._serialized_start=5169\n  _globals['_DEPENDENCYGRAPH']._serialized_end=5708\n  _globals['_DEPENDENCYGRAPH_NODE']._serialized_start=5403\n  _globals['_DEPENDENCYGRAPH_NODE']._serialized_end=5491\n  _globals['_DEPENDENCYGRAPH_EDGE']._serialized_start=5494\n  _globals['_DEPENDENCYGRAPH_EDGE']._serialized_end=5708\n  _globals['_COREFCHAIN']._serialized_start=5711\n  _globals['_COREFCHAIN']._serialized_end=6037\n  _globals['_COREFCHAIN_COREFMENTION']._serialized_start=5836\n  _globals['_COREFCHAIN_COREFMENTION']._serialized_end=6037\n  _globals['_MENTION']._serialized_start=6040\n  _globals['_MENTION']._serialized_end=7175\n  _globals['_INDEXEDWORD']._serialized_start=7177\n  _globals['_INDEXEDWORD']._serialized_end=7265\n  _globals['_SPEAKERINFO']._serialized_start=7267\n  _globals['_SPEAKERINFO']._serialized_end=7319\n  _globals['_SPAN']._serialized_start=7321\n  _globals['_SPAN']._serialized_end=7355\n  _globals['_TIMEX']._serialized_start=7357\n  _globals['_TIMEX']._serialized_end=7476\n  _globals['_ENTITY']._serialized_start=7479\n  _globals['_ENTITY']._serialized_end=7698\n  _globals['_RELATION']._serialized_start=7701\n  _globals['_RELATION']._serialized_end=7884\n  _globals['_OPERATOR']._serialized_start=7887\n  _globals['_OPERATOR']._serialized_end=8065\n  _globals['_POLARITY']._serialized_start=8068\n  _globals['_POLARITY']._serialized_end=8621\n  
_globals['_NERMENTION']._serialized_start=8624\n  _globals['_NERMENTION']._serialized_end=8973\n  _globals['_SENTENCEFRAGMENT']._serialized_start=8975\n  _globals['_SENTENCEFRAGMENT']._serialized_end=9064\n  _globals['_TOKENLOCATION']._serialized_start=9066\n  _globals['_TOKENLOCATION']._serialized_end=9124\n  _globals['_RELATIONTRIPLE']._serialized_start=9127\n  _globals['_RELATIONTRIPLE']._serialized_end=9537\n  _globals['_MAPSTRINGSTRING']._serialized_start=9539\n  _globals['_MAPSTRINGSTRING']._serialized_end=9584\n  _globals['_MAPINTSTRING']._serialized_start=9586\n  _globals['_MAPINTSTRING']._serialized_end=9628\n  _globals['_SECTION']._serialized_start=9631\n  _globals['_SECTION']._serialized_end=9883\n  _globals['_SEMGREXREQUEST']._serialized_start=9886\n  _globals['_SEMGREXREQUEST']._serialized_end=10114\n  _globals['_SEMGREXREQUEST_DEPENDENCIES']._serialized_start=9992\n  _globals['_SEMGREXREQUEST_DEPENDENCIES']._serialized_end=10114\n  _globals['_SEMGREXRESPONSE']._serialized_start=10117\n  _globals['_SEMGREXRESPONSE']._serialized_end=11008\n  _globals['_SEMGREXRESPONSE_NAMEDNODE']._serialized_start=10208\n  _globals['_SEMGREXRESPONSE_NAMEDNODE']._serialized_end=10253\n  _globals['_SEMGREXRESPONSE_NAMEDRELATION']._serialized_start=10255\n  _globals['_SEMGREXRESPONSE_NAMEDRELATION']._serialized_end=10298\n  _globals['_SEMGREXRESPONSE_NAMEDEDGE']._serialized_start=10301\n  _globals['_SEMGREXRESPONSE_NAMEDEDGE']._serialized_end=10429\n  _globals['_SEMGREXRESPONSE_VARIABLESTRING']._serialized_start=10431\n  _globals['_SEMGREXRESPONSE_VARIABLESTRING']._serialized_end=10476\n  _globals['_SEMGREXRESPONSE_MATCH']._serialized_start=10479\n  _globals['_SEMGREXRESPONSE_MATCH']._serialized_end=10837\n  _globals['_SEMGREXRESPONSE_SEMGREXRESULT']._serialized_start=10839\n  _globals['_SEMGREXRESPONSE_SEMGREXRESULT']._serialized_end=10919\n  _globals['_SEMGREXRESPONSE_GRAPHRESULT']._serialized_start=10921\n  
_globals['_SEMGREXRESPONSE_GRAPHRESULT']._serialized_end=11008\n  _globals['_SSURGEONREQUEST']._serialized_start=11011\n  _globals['_SSURGEONREQUEST']._serialized_end=11251\n  _globals['_SSURGEONREQUEST_SSURGEON']._serialized_start=11160\n  _globals['_SSURGEONREQUEST_SSURGEON']._serialized_end=11251\n  _globals['_SSURGEONRESPONSE']._serialized_start=11254\n  _globals['_SSURGEONRESPONSE']._serialized_end=11442\n  _globals['_SSURGEONRESPONSE_SSURGEONRESULT']._serialized_start=11350\n  _globals['_SSURGEONRESPONSE_SSURGEONRESULT']._serialized_end=11442\n  _globals['_TOKENSREGEXREQUEST']._serialized_start=11444\n  _globals['_TOKENSREGEXREQUEST']._serialized_end=11531\n  _globals['_TOKENSREGEXRESPONSE']._serialized_start=11534\n  _globals['_TOKENSREGEXRESPONSE']._serialized_end=11957\n  _globals['_TOKENSREGEXRESPONSE_MATCHLOCATION']._serialized_start=11633\n  _globals['_TOKENSREGEXRESPONSE_MATCHLOCATION']._serialized_end=11690\n  _globals['_TOKENSREGEXRESPONSE_MATCH']._serialized_start=11693\n  _globals['_TOKENSREGEXRESPONSE_MATCH']._serialized_end=11872\n  _globals['_TOKENSREGEXRESPONSE_PATTERNMATCH']._serialized_start=11874\n  _globals['_TOKENSREGEXRESPONSE_PATTERNMATCH']._serialized_end=11957\n  _globals['_DEPENDENCYENHANCERREQUEST']._serialized_start=11960\n  _globals['_DEPENDENCYENHANCERREQUEST']._serialized_end=12134\n  _globals['_FLATTENEDPARSETREE']._serialized_start=12137\n  _globals['_FLATTENEDPARSETREE']._serialized_end=12317\n  _globals['_FLATTENEDPARSETREE_NODE']._serialized_start=12226\n  _globals['_FLATTENEDPARSETREE_NODE']._serialized_end=12317\n  _globals['_EVALUATEPARSERREQUEST']._serialized_start=12320\n  _globals['_EVALUATEPARSERREQUEST']._serialized_end=12566\n  _globals['_EVALUATEPARSERREQUEST_PARSERESULT']._serialized_start=12426\n  _globals['_EVALUATEPARSERREQUEST_PARSERESULT']._serialized_end=12566\n  _globals['_EVALUATEPARSERRESPONSE']._serialized_start=12568\n  _globals['_EVALUATEPARSERRESPONSE']._serialized_end=12637\n  
_globals['_TSURGEONREQUEST']._serialized_start=12640\n  _globals['_TSURGEONREQUEST']._serialized_end=12840\n  _globals['_TSURGEONREQUEST_OPERATION']._serialized_start=12795\n  _globals['_TSURGEONREQUEST_OPERATION']._serialized_end=12840\n  _globals['_TSURGEONRESPONSE']._serialized_start=12842\n  _globals['_TSURGEONRESPONSE']._serialized_end=12922\n  _globals['_MORPHOLOGYREQUEST']._serialized_start=12925\n  _globals['_MORPHOLOGYREQUEST']._serialized_end=13058\n  _globals['_MORPHOLOGYREQUEST_TAGGEDWORD']._serialized_start=13018\n  _globals['_MORPHOLOGYREQUEST_TAGGEDWORD']._serialized_end=13058\n  _globals['_MORPHOLOGYRESPONSE']._serialized_start=13061\n  _globals['_MORPHOLOGYRESPONSE']._serialized_end=13215\n  _globals['_MORPHOLOGYRESPONSE_WORDTAGLEMMA']._serialized_start=13158\n  _globals['_MORPHOLOGYRESPONSE_WORDTAGLEMMA']._serialized_end=13215\n  _globals['_DEPENDENCYCONVERTERREQUEST']._serialized_start=13217\n  _globals['_DEPENDENCYCONVERTERREQUEST']._serialized_end=13307\n  _globals['_DEPENDENCYCONVERTERRESPONSE']._serialized_start=13310\n  _globals['_DEPENDENCYCONVERTERRESPONSE']._serialized_end=13582\n  _globals['_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION']._serialized_start=13440\n  _globals['_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION']._serialized_end=13582\n# @@protoc_insertion_point(module_scope)\n"
  },
  {
    "path": "stanza/protobuf/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom io import BytesIO\nimport warnings\n\nfrom google.protobuf.internal.encoder import _EncodeVarint\nfrom google.protobuf.internal.decoder import _DecodeVarint\nfrom google.protobuf.message import DecodeError\nfrom .CoreNLP_pb2 import *\n\ndef parseFromDelimitedString(obj, buf, offset=0):\n    \"\"\"\n    Stanford CoreNLP uses the Java \"writeDelimitedTo\" function, which\n    writes the size (and offset) of the buffer before writing the object.\n    This function handles parsing this message starting from offset 0.\n\n    @returns how many bytes of @buf were consumed.\n    \"\"\"\n    size, pos = _DecodeVarint(buf, offset)\n    try:\n        obj.ParseFromString(buf[offset+pos:offset+pos+size])\n    except DecodeError as e:\n        warnings.warn(\"Failed to decode a serialized output from CoreNLP server. An incomplete or empty object will be returned.\", \\\n            RuntimeWarning)\n    return pos+size\n\ndef writeToDelimitedString(obj, stream=None):\n    \"\"\"\n    Stanford CoreNLP uses the Java \"writeDelimitedTo\" function, which\n    writes the size (and offset) of the buffer before writing the object.\n    This function handles parsing this message starting from offset 0.\n\n    @returns how many bytes of @buf were consumed.\n    \"\"\"\n    if stream is None:\n        stream = BytesIO()\n\n    _EncodeVarint(stream.write, obj.ByteSize(), True)\n    stream.write(obj.SerializeToString())\n    return stream\n\ndef to_text(sentence):\n    \"\"\"\n    Helper routine that converts a Sentence protobuf to a string from\n    its tokens.\n    \"\"\"\n    text = \"\"\n    for i, tok in enumerate(sentence.token):\n        if i != 0:\n            text += tok.before\n        text += tok.word\n    return text\n"
  },
  {
    "path": "stanza/resources/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/resources/common.py",
    "content": "\"\"\"\nCommon utilities for Stanza resources.\n\"\"\"\n\nfrom collections import defaultdict, namedtuple\nimport errno\nimport hashlib\nimport json\nimport logging\nimport os\nfrom pathlib import Path\nimport requests\nimport shutil\nimport tempfile\nimport zipfile\n\nfrom platformdirs import user_cache_dir\nfrom tqdm.auto import tqdm\n\nfrom stanza.utils.helper_func import make_table\nfrom stanza.pipeline._constants import TOKENIZE, MWT, POS, LEMMA, DEPPARSE, NER, SENTIMENT\nfrom stanza.pipeline.registry import PIPELINE_NAMES, PROCESSOR_VARIANTS\nfrom stanza.resources.default_packages import PACKAGES\nfrom stanza._version import __resources_version__\n\nlogger = logging.getLogger('stanza')\n\n# set home dir for default\nUSER_CACHE_DIR = user_cache_dir('stanza', 'StanfordNLP', __resources_version__)\nSTANFORDNLP_RESOURCES_URL = 'https://nlp.stanford.edu/software/stanza/stanza-resources/'\nSTANZA_RESOURCES_GITHUB = 'https://raw.githubusercontent.com/stanfordnlp/stanza-resources/'\nDEFAULT_RESOURCES_URL = os.getenv('STANZA_RESOURCES_URL', STANZA_RESOURCES_GITHUB + 'main')\nDEFAULT_RESOURCES_VERSION = os.getenv(\n    'STANZA_RESOURCES_VERSION',\n    __resources_version__\n)\nDEFAULT_MODEL_URL = os.getenv('STANZA_MODEL_URL', 'default')\nDEFAULT_MODEL_DIR = os.getenv(\n    'STANZA_RESOURCES_DIR',\n    os.path.join(USER_CACHE_DIR, 'resources')\n)\n\nPRETRAIN_NAMES = (\"pretrain\", \"forward_charlm\", \"backward_charlm\")\n\nclass ResourcesFileNotFoundError(FileNotFoundError):\n    def __init__(self, resources_filepath):\n        super().__init__(f\"Resources file not found at: {resources_filepath}  Try to download the model again.\")\n        self.resources_filepath = resources_filepath\n\nclass UnknownLanguageError(ValueError):\n    def __init__(self, unknown):\n        super().__init__(f\"Unknown language requested: {unknown}\")\n        self.unknown_language = unknown\n\nclass UnknownProcessorError(ValueError):\n    def __init__(self, unknown):\n      
  super().__init__(f\"Unknown processor type requested: {unknown}\")\n        self.unknown_processor = unknown\n\nModelSpecification = namedtuple('ModelSpecification', ['processor', 'package', 'dependencies'])\n\ndef ensure_dir(path):\n    \"\"\"\n    Create the directory if it does not exist.\n    \"\"\"\n    Path(path).mkdir(parents=True, exist_ok=True)\n\ndef get_md5(path):\n    \"\"\"\n    Get the MD5 hash of the file at `path`.\n    \"\"\"\n    try:\n        with open(path, 'rb') as fin:\n            data = fin.read()\n    except OSError as e:\n        if not e.filename:\n            e.filename = path\n        raise\n    return hashlib.md5(data).hexdigest()\n\ndef unzip(path, filename):\n    \"\"\"\n    Fully unzip a file `filename` that's in a directory `path`.\n    \"\"\"\n    logger.debug(f'Unzip: {path}/{filename}...')\n    with zipfile.ZipFile(os.path.join(path, filename)) as f:\n        f.extractall(path)\n\ndef get_root_from_zipfile(filename):\n    \"\"\"\n    Get the root directory from an archived zip file.\n    \"\"\"\n    zf = zipfile.ZipFile(filename, \"r\")\n    assert len(zf.filelist) > 0, \\\n        f\"Zip file at {filename} seems to be corrupted. 
Please check it.\"\n    return os.path.dirname(zf.filelist[0].filename)\n\ndef file_exists(path, md5):\n    \"\"\"\n    Check if the file at `path` exists and matches the provided md5 value.\n    \"\"\"\n    return os.path.exists(path) and get_md5(path) == md5\n\ndef assert_file_exists(path, md5=None, alternate_md5=None):\n    if not os.path.exists(path):\n        raise FileNotFoundError(errno.ENOENT, \"Cannot find expected file\", path)\n    if md5:\n        file_md5 = get_md5(path)\n        if file_md5 != md5:\n            if file_md5 == alternate_md5:\n                logger.debug(\"Found a possibly older version of file %s, md5 %s instead of %s\", path, alternate_md5, md5)\n            else:\n                raise ValueError(\"md5 for %s is %s, expected %s\" % (path, file_md5, md5))\n\ndef download_file(url, path, proxies, raise_for_status=False):\n    \"\"\"\n    Download a URL into a file as specified by `path`.\n    \"\"\"\n    verbose = logger.level in [0, 10, 20]\n    r = requests.get(url, stream=True, proxies=proxies)\n    if raise_for_status:\n        r.raise_for_status()\n    with open(path, 'wb') as f:\n        file_size = r.headers.get('content-length', None)\n        if file_size:\n            file_size = int(file_size)\n        default_chunk_size = 131072\n        desc = 'Downloading ' + url\n        with tqdm(total=file_size, unit='B', unit_scale=True, \\\n            disable=not verbose, desc=desc) as pbar:\n            for chunk in r.iter_content(chunk_size=default_chunk_size):\n                if chunk:\n                    f.write(chunk)\n                    f.flush()\n                    pbar.update(len(chunk))\n    return r.status_code\n\ndef request_file(url, path, proxies=None, md5=None, raise_for_status=False, log_info=True, alternate_md5=None):\n    \"\"\"\n    A complete wrapper over download_file() that also makes sure the directory of\n    `path` exists, and skips the download if a file matching the md5 value\n    already exists.\n\n    alternate_md5 allows 
for an alternate md5 that is acceptable (such as if an older version of a file is okay)\n    \"\"\"\n    basedir = Path(path).parent\n    ensure_dir(basedir)\n    if file_exists(path, md5):\n        if log_info:\n            logger.info(f'File exists: {path}')\n        else:\n            logger.debug(f'File exists: {path}')\n        return\n    # We write data first to a temporary directory,\n    # then use os.replace() so that multiple processes\n    # running at the same time don't clobber each other\n    # with partially downloaded files\n    # This was especially common with resources.json\n    with tempfile.TemporaryDirectory(dir=basedir) as temp:\n        temppath = os.path.join(temp, os.path.split(path)[-1])\n        download_file(url, temppath, proxies, raise_for_status)\n        os.replace(temppath, path)\n    assert_file_exists(path, md5, alternate_md5)\n    if log_info:\n        logger.info(f'Downloaded file to {path}')\n    else:\n        logger.debug(f'Downloaded file to {path}')\n\ndef sort_processors(processor_list):\n    sorted_list = []\n    for processor in PIPELINE_NAMES:\n        for item in processor_list:\n            if item[0] == processor:\n                sorted_list.append(item)\n    # going just by processors in PIPELINE_NAMES, this drops any names\n    # which are not an official processor but might still be useful\n    # check the list and append them to the end\n    # this is especially useful when downloading pretrain or charlm models\n    for processor in processor_list:\n        for item in sorted_list:\n            if processor[0] == item[0]:\n                break\n        else:\n            sorted_list.append(processor)\n    return sorted_list\n\ndef add_mwt(processors, resources, lang):\n    \"\"\"Add mwt if tokenize is passed without mwt.\n\n    If tokenize is in the list, but mwt is not, and there is a corresponding\n    tokenize and mwt pair in the resources file, mwt is added so no missing\n    mwt errors are raised.\n    
\"\"\"\n    value = processors[TOKENIZE]\n    if value in resources[lang][PACKAGES] and MWT in resources[lang][PACKAGES][value]:\n        logger.warning(\"Language %s package %s expects mwt, which has been added\", lang, value)\n        processors[MWT] = value\n    elif (value in resources[lang][TOKENIZE] and MWT in resources[lang] and value in resources[lang][MWT]):\n        logger.warning(\"Language %s package %s expects mwt, which has been added\", lang, value)\n        processors[MWT] = value\n\ndef maintain_processor_list(resources, lang, package, processors, allow_pretrain=False, maybe_add_mwt=True):\n    \"\"\"\n    Given a parsed resources file, language, and possible package\n    and/or processors, expands the package to the list of processors\n\n    Returns a list of processors\n    Each item in the list of processors is a pair:\n      name, then a list of ModelSpecification\n    so, for example:\n      [['pos', [ModelSpecification(processor='pos', package='gsd', dependencies=None)]],\n       ['depparse', [ModelSpecification(processor='depparse', package='gsd', dependencies=None)]]]\n    \"\"\"\n    processor_list = defaultdict(list)\n    # resolve processor models\n    if processors:\n        logger.debug(f'Processing parameter \"processors\"...')\n        if maybe_add_mwt and TOKENIZE in processors and MWT not in processors:\n            add_mwt(processors, resources, lang)\n        for key, plist in processors.items():\n            if not isinstance(key, str):\n                raise ValueError(\"Processor names must be strings\")\n            if not isinstance(plist, (tuple, list, str)):\n                raise ValueError(\"Processor values must be strings\")\n            if isinstance(plist, str):\n                plist = [plist]\n            if key not in PIPELINE_NAMES:\n                if not allow_pretrain or key not in PRETRAIN_NAMES:\n                    raise UnknownProcessorError(key)\n            for value in plist:\n                # check if 
keys and values can be found\n                if key in resources[lang] and value in resources[lang][key]:\n                    logger.debug(f'Found {key}: {value}.')\n                    processor_list[key].append(value)\n                # allow values to be default in some cases\n                elif value in resources[lang][PACKAGES] and key in resources[lang][PACKAGES][value]:\n                    logger.debug(\n                        f'Found {key}: {resources[lang][PACKAGES][value][key]}.'\n                    )\n                    processor_list[key].append(resources[lang][PACKAGES][value][key])\n                # optional defaults will be activated if specifically turned on\n                elif value in resources[lang][PACKAGES] and 'optional' in resources[lang][PACKAGES][value] and key in resources[lang][PACKAGES][value]['optional']:\n                    logger.debug(\n                        f\"Found {key}: {resources[lang][PACKAGES][value]['optional'][key]}.\"\n                    )\n                    processor_list[key].append(resources[lang][PACKAGES][value]['optional'][key])\n                # allow processors to be set to variants that we didn't implement\n                elif value in PROCESSOR_VARIANTS[key]:\n                    logger.debug(\n                        f'Found {key}: {value}. '\n                        f'Using external {value} variant for the {key} processor.'\n                    )\n                    processor_list[key].append(value)\n                # allow lemma to be set to \"identity\"\n                elif key == LEMMA and value == 'identity':\n                    logger.debug(\n                        f'Found {key}: {value}. 
Using identity lemmatizer.'\n                    )\n                    processor_list[key].append(value)\n                # not a processor in the officially supported processor list\n                elif key not in resources[lang]:\n                    logger.debug(\n                        f'{key}: {value} is not officially supported by Stanza, '\n                        f'loading it anyway.'\n                    )\n                    processor_list[key].append(value)\n                # cannot find the package for a processor and warn user\n                else:\n                    logger.warning(\n                        f'Can not find {key}: {value} from official model list. '\n                        f'Ignoring it.'\n                    )\n    # resolve package\n    if package:\n        logger.debug(f'Processing parameter \"package\"...')\n        if PACKAGES in resources[lang] and package in resources[lang][PACKAGES]:\n            for key, value in resources[lang][PACKAGES][package].items():\n                if key != 'optional' and key not in processor_list:\n                    logger.debug(f'Found {key}: {value}.')\n                    processor_list[key].append(value)\n        else:\n            flag = False\n            for key in PIPELINE_NAMES:\n                if key not in resources[lang]: continue\n                if package in resources[lang][key]:\n                    flag = True\n                    if key not in processor_list:\n                        logger.debug(f'Found {key}: {package}.')\n                        processor_list[key].append(package)\n                    else:\n                        logger.debug(\n                            f'{key}: {package} is overwritten by '\n                            f'{key}: {processors[key]}.'\n                        )\n            if not flag: logger.warning((f'Can not find package: {package}.'))\n    processor_list = [[key, [ModelSpecification(processor=key, package=value, dependencies=None) 
for value in plist]] for key, plist in processor_list.items()]\n    processor_list = sort_processors(processor_list)\n    return processor_list\n\ndef add_dependencies(resources, lang, processor_list):\n    \"\"\"\n    Expand the processor_list as given in maintain_processor_list to have the dependencies\n\n    Still a list of model types to ModelSpecifications\n    the dependencies are tuples: name and package\n    for example:\n    [['pos', (ModelSpecification(processor='pos', package='gsd', dependencies=(('pretrain', 'gsd'),)),)],\n     ['depparse', (ModelSpecification(processor='depparse', package='gsd', dependencies=(('pretrain', 'gsd'),)),)]]\n    \"\"\"\n    lang_resources = resources[lang]\n    for item in processor_list:\n        processor, model_specs = item\n        new_model_specs = []\n        for model_spec in model_specs:\n            # skip dependency checking for external variants of processors and identity lemmatizer\n            if not any([\n                    model_spec.package in PROCESSOR_VARIANTS[processor],\n                    processor == LEMMA and model_spec.package == 'identity'\n                ]):\n                dependencies = lang_resources.get(processor, {}).get(model_spec.package, {}).get('dependencies', [])\n                dependencies = [(dependency['model'], dependency['package']) for dependency in dependencies]\n                model_spec = model_spec._replace(dependencies=tuple(dependencies))\n                logger.debug(\"Found dependencies %s for processor %s model %s\", dependencies, processor, model_spec.package)\n            new_model_specs.append(model_spec)\n        item[1] = tuple(new_model_specs)\n    return processor_list\n\ndef flatten_processor_list(processor_list):\n    \"\"\"\n    The flattened processor list is just a list of types & packages\n\n    For example:\n      [['pos', 'gsd'], ['depparse', 'gsd'], ['pretrain', 'gsd']]\n    \"\"\"\n    flattened_processor_list = []\n    dependencies_list = []\n    
for item in processor_list:\n        processor, model_specs = item\n        for model_spec in model_specs:\n            package = model_spec.package\n            dependencies = model_spec.dependencies\n            flattened_processor_list.append([processor, package])\n            if dependencies:\n                dependencies_list += [tuple(dependency) for dependency in dependencies]\n    dependencies_list = [list(item) for item in set(dependencies_list)]\n    for processor, package in dependencies_list:\n        logger.debug(f'Find dependency {processor}: {package}.')\n    flattened_processor_list += dependencies_list\n    return flattened_processor_list\n\ndef set_logging_level(logging_level, verbose):\n    # Check verbose for easy logging control\n    if verbose == False:\n        logging_level = 'ERROR'\n    elif verbose == True:\n        logging_level = 'INFO'\n\n    if logging_level is None:\n        # default logging level of INFO is set in stanza.__init__\n        # but the user may have set it via the logging API\n        # it should NOT be 0, but let's check to be sure...\n        if logger.level == 0:\n            logger.setLevel('INFO')\n        return logger.level\n\n    # Set logging level\n    logging_level = logging_level.upper()\n    all_levels = ['DEBUG', 'INFO', 'WARNING', 'WARN', 'ERROR', 'CRITICAL', 'FATAL']\n    if logging_level not in all_levels:\n        raise ValueError(\n            f\"Unrecognized logging level for pipeline: \"\n            f\"{logging_level}. 
Must be one of {', '.join(all_levels)}.\"\n        )\n    logger.setLevel(logging_level)\n    return logger.level\n\ndef process_pipeline_parameters(lang, model_dir, package, processors):\n    # Check parameter types and convert values to lower case\n    if isinstance(lang, str):\n        lang = lang.strip().lower()\n    elif lang is not None:\n        raise TypeError(\n            f\"The parameter 'lang' should be str, \"\n            f\"but got {type(lang).__name__} instead.\"\n        )\n\n    if isinstance(model_dir, str):\n        model_dir = model_dir.strip()\n    elif model_dir is not None:\n        raise TypeError(\n            f\"The parameter 'model_dir' should be str, \"\n            f\"but got {type(model_dir).__name__} instead.\"\n        )\n\n    if isinstance(processors, (str, list, tuple)):\n        # Special case: processors is str, compatible with older version\n        # also allow for setting alternate packages for these processors\n        # via the package argument\n        if package is None:\n            # each processor will be 'default' for this language\n            package = defaultdict(lambda: 'default')\n        elif isinstance(package, str):\n            # same, but now the named package will be the default instead\n            default = package\n            package = defaultdict(lambda: default)\n        elif isinstance(package, dict):\n            # the dictionary of packages will be used to build the processors dict\n            # any processor not specified in package will be 'default'\n            package = defaultdict(lambda: 'default', package)\n        else:\n            raise TypeError(\n                f\"The parameter 'package' should be None, str, or dict, \"\n                f\"but got {type(package).__name__} instead.\"\n            )\n        if isinstance(processors, str):\n            processors = [x.strip().lower() for x in processors.split(\",\")]\n        processors = {\n            processor: package[processor] 
for processor in processors\n        }\n        package = None\n    elif isinstance(processors, dict):\n        processors = {\n            k.strip().lower(): ([v_i.strip().lower() for v_i in v] if isinstance(v, (tuple, list)) else v.strip().lower())\n            for k, v in processors.items()\n        }\n    elif processors is not None:\n        raise TypeError(\n            f\"The parameter 'processors' should be dict, str, list, or tuple, \"\n            f\"but got {type(processors).__name__} instead.\"\n        )\n\n    if isinstance(package, str):\n        package = package.strip().lower()\n    elif package is not None:\n        raise TypeError(\n            f\"The parameter 'package' should be str, or a dict if 'processors' is a str, \"\n            f\"but got {type(package).__name__} instead.\"\n        )\n\n    return lang, model_dir, package, processors\n\ndef download_resources_json(model_dir=DEFAULT_MODEL_DIR,\n                            resources_url=DEFAULT_RESOURCES_URL,\n                            resources_branch=None,\n                            resources_version=DEFAULT_RESOURCES_VERSION,\n                            resources_filepath=None,\n                            proxies=None):\n    \"\"\"\n    Downloads resources.json to obtain latest packages.\n    \"\"\"\n    if resources_url == DEFAULT_RESOURCES_URL and resources_branch is not None:\n        resources_url = STANZA_RESOURCES_GITHUB + resources_branch\n    # handle short name for resources urls; otherwise treat it as url\n    if resources_url.lower() in ('stanford', 'stanfordnlp'):\n        resources_url = STANFORDNLP_RESOURCES_URL\n    resources_url = f'{resources_url}/resources_{resources_version}.json'\n    logger.debug('Downloading resource file from %s', resources_url)\n    if resources_filepath is None:\n        resources_filepath = os.path.join(model_dir, 'resources.json')\n    # make request\n    request_file(\n        resources_url,\n        resources_filepath,\n        proxies,\n        
raise_for_status=True\n    )\n\n\ndef load_resources_json(model_dir=DEFAULT_MODEL_DIR, resources_filepath=None):\n    \"\"\"\n    Unpack the resources json file from the given model_dir\n    \"\"\"\n    if resources_filepath is None:\n        resources_filepath = os.path.join(model_dir, 'resources.json')\n    if not os.path.exists(resources_filepath):\n        raise ResourcesFileNotFoundError(resources_filepath)\n    with open(resources_filepath, encoding=\"utf-8\") as fin:\n        resources = json.load(fin)\n    return resources\n\ndef get_language_resources(resources, lang):\n    \"\"\"\n    Get the resources for a lang from an already loaded resources json, following 'alias' if needed\n    \"\"\"\n    if lang not in resources:\n        return None\n\n    lang_resources = resources[lang]\n    while 'alias' in lang_resources:\n        lang = lang_resources['alias']\n        lang_resources = resources[lang]\n\n    return lang_resources\n\ndef list_available_languages(model_dir=DEFAULT_MODEL_DIR,\n                             resources_url=DEFAULT_RESOURCES_URL,\n                             resources_branch=None,\n                             resources_version=DEFAULT_RESOURCES_VERSION,\n                             proxies=None):\n    \"\"\"\n    List the non-alias languages in the resources file\n    \"\"\"\n    download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)\n    resources = load_resources_json(model_dir)\n    # isinstance(str) is because of fields such as \"url\"\n    # 'alias' is because we want to skip German, alias of de, for example\n    languages = [lang for lang in resources\n                 if not isinstance(resources[lang], str) and 'alias' not in resources[lang]]\n    languages = sorted(languages)\n    return languages\n\ndef expand_model_url(resources, model_url):\n    \"\"\"\n    Returns the url in the resources dict if model_url is default, or returns the model_url\n  
  \"\"\"\n    return resources['url'] if model_url.lower() == 'default' else model_url\n\ndef download_models(download_list,\n                    resources,\n                    lang,\n                    model_dir=DEFAULT_MODEL_DIR,\n                    resources_version=DEFAULT_RESOURCES_VERSION,\n                    model_url=DEFAULT_MODEL_URL,\n                    proxies=None,\n                    log_info=True):\n    lang_name = resources.get(lang, {}).get('lang_name', lang)\n    download_table = make_table(['Processor', 'Package'], download_list)\n    if log_info:\n        log_msg = logger.info\n    else:\n        log_msg = logger.debug\n    log_msg(\n        f'Downloading these customized packages for language: '\n        f'{lang} ({lang_name})...\\n{download_table}'\n    )\n\n    url = expand_model_url(resources, model_url)\n\n    # Download packages\n    for key, value in download_list:\n        try:\n            request_file(\n                url.format(resources_version=resources_version, lang=lang, filename=f\"{key}/{value}.pt\"),\n                os.path.join(model_dir, lang, key, f'{value}.pt'),\n                proxies,\n                md5=resources[lang][key][value]['md5'],\n                log_info=log_info,\n                alternate_md5=resources[lang][key][value].get('alternate_md5', None)\n            )\n        except KeyError as e:\n            raise ValueError(\n                f'Cannot find the following processor and model name combination: '\n                f'{key}, {value}. 
Please check if you have provided the correct model name.'\n            ) from e\n\n# main download function\ndef download(\n        lang='en',\n        model_dir=DEFAULT_MODEL_DIR,\n        package='default',\n        processors={},\n        logging_level=None,\n        verbose=None,\n        resources_url=DEFAULT_RESOURCES_URL,\n        resources_branch=None,\n        resources_version=DEFAULT_RESOURCES_VERSION,\n        model_url=DEFAULT_MODEL_URL,\n        proxies=None,\n        download_json=True\n    ):\n    # set global logging level\n    set_logging_level(logging_level, verbose)\n    # process different pipeline parameters\n    lang, model_dir, package, processors = process_pipeline_parameters(\n        lang, model_dir, package, processors\n    )\n\n    if download_json or not os.path.exists(os.path.join(model_dir, 'resources.json')):\n        if not download_json:\n            logger.warning(\"Asked to skip downloading resources.json, but the file does not exist.  Downloading anyway\")\n        download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)\n\n    resources = load_resources_json(model_dir)\n    if lang not in resources:\n        raise UnknownLanguageError(lang)\n    if 'alias' in resources[lang]:\n        logger.info(f'\"{lang}\" is an alias for \"{resources[lang][\"alias\"]}\"')\n        lang = resources[lang]['alias']\n    lang_name = resources.get(lang, {}).get('lang_name', lang)\n    url = expand_model_url(resources, model_url)\n\n    # Default: download zipfile and unzip\n    if package == 'default' and (processors is None or len(processors) == 0):\n        logger.info(\n            f'Downloading default packages for language: {lang} ({lang_name}) ...'\n        )\n        # want the URL to become, for example:\n        # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip\n        # so we hopefully start from\n        # 
https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}\n        request_file(\n            url.format(resources_version=resources_version, lang=lang, filename=\"default.zip\"),\n            os.path.join(model_dir, lang, f'default.zip'),\n            proxies,\n            md5=resources[lang]['default_md5'],\n        )\n        unzip(os.path.join(model_dir, lang), 'default.zip')\n        download_list = [['zip', 'default.zip']]\n    # Customize: maintain download list\n    else:\n        download_list = maintain_processor_list(resources, lang, package, processors, allow_pretrain=True)\n        download_list = add_dependencies(resources, lang, download_list)\n        download_list = flatten_processor_list(download_list)\n        download_models(download_list=download_list,\n                        resources=resources,\n                        lang=lang,\n                        model_dir=model_dir,\n                        resources_version=resources_version,\n                        model_url=model_url,\n                        proxies=proxies,\n                        log_info=True)\n    logger.info(f'Finished downloading models and saved to {model_dir}')\n    return download_list\n"
  },
  {
    "path": "stanza/resources/default_packages.py",
    "content": "\"\"\"\nConstants for default packages, default pretrains, charlms, etc\n\nSeparated from prepare_resources.py so that other modules can use the\nsame lists / maps without importing the resources script and possibly\ncausing a circular import\n\"\"\"\n\nimport copy\n\n# all languages will have a map which represents the available packages\nPACKAGES = \"packages\"\n\n# default treebank for languages\ndefault_treebanks = {\n    \"ab\":      \"abnc\",\n    \"af\":      \"afribooms\",\n    # currently not publicly released!  sent to us from the group developing this resource\n    \"ang\":     \"nerthus\",\n    \"ar\":      \"padt\",\n    \"be\":      \"hse\",\n    \"bg\":      \"btb\",\n    \"bxr\":     \"bdt\",\n    \"ca\":      \"ancora\",\n    \"cop\":     \"scriptorium\",\n    \"cs\":      \"pdt\",\n    \"cu\":      \"proiel\",\n    \"cy\":      \"ccg\",\n    \"da\":      \"ddt\",\n    \"de\":      \"combined\",\n    \"el\":      \"gdt\",\n    \"en\":      \"combined\",\n    \"es\":      \"combined\",\n    \"et\":      \"edt\",\n    \"eu\":      \"bdt\",\n    \"fa\":      \"perdt\",\n    \"fi\":      \"tdt\",\n    \"fo\":      \"farpahc\",\n    \"fr\":      \"combined\",\n    \"fro\":     \"profiterole\",\n    \"ga\":      \"idt\",\n    \"gd\":      \"arcosg\",\n    \"gl\":      \"ctg\",\n    \"got\":     \"proiel\",\n    \"grc\":     \"perseus\",\n    \"gv\":      \"cadhan\",\n    \"hbo\":     \"ptnk\",\n    \"he\":      \"combined\",\n    \"hi\":      \"hdtb\",\n    \"hr\":      \"set\",\n    \"hsb\":     \"ufal\",\n    \"hu\":      \"szeged\",\n    \"hy\":      \"armtdp\",\n    \"hyw\":     \"armtdp\",\n    \"id\":      \"gsd\",\n    \"is\":      \"icepahc\",\n    \"it\":      \"combined\",\n    \"ja\":      \"combined\",\n    \"ka\":      \"glc\",\n    \"kk\":      \"ktb\",\n    \"kmr\":     \"mg\",\n    \"ko\":      \"kaist\",\n    \"kpv\":     \"lattice\",\n    \"ky\":      \"ktmu\",\n    \"la\":      \"ittb\",\n    \"lij\":     \"glt\",\n    
\"lt\":      \"alksnis\",\n    \"lv\":      \"lvtb\",\n    \"lzh\":     \"kyoto\",\n    \"mr\":      \"ufal\",\n    \"mt\":      \"mudt\",\n    \"my\":      \"ucsy\",\n    \"myv\":     \"jr\",\n    \"nb\":      \"bokmaal\",\n    \"nds\":     \"lsdc\",\n    \"nl\":      \"alpino\",\n    \"nn\":      \"nynorsk\",\n    \"olo\":     \"kkpp\",\n    \"orv\":     \"torot\",\n    \"ota\":     \"boun\",\n    \"pcm\":     \"nsc\",\n    \"pl\":      \"pdb\",\n    \"pt\":      \"bosque\",\n    \"qaf\":     \"arabizi\",\n    \"qpm\":     \"philotis\",\n    \"qtd\":     \"sagt\",\n    \"ro\":      \"rrt\",\n    \"ru\":      \"syntagrus\",\n    \"sa\":      \"vedic\",\n    \"sd\":      \"isra\",\n    \"sk\":      \"snk\",\n    \"sl\":      \"ssj\",\n    \"sme\":     \"giella\",\n    \"sq\":      \"combined\",\n    \"sr\":      \"set\",\n    \"sv\":      \"talbanken\",\n    \"swl\":     \"sslc\",\n    \"ta\":      \"ttb\",\n    \"te\":      \"mtg\",\n    \"th\":      \"tud\",\n    \"tr\":      \"imst\",\n    \"ug\":      \"udt\",\n    \"uk\":      \"iu\",\n    \"ur\":      \"udtb\",\n    \"vi\":      \"vtb\",\n    \"wo\":      \"wtb\",\n    \"xcl\":     \"caval\",\n    \"zh-hans\": \"gsdsimp\",\n    \"zh-hant\": \"gsd\",\n    \"multilingual\": \"ud\"\n}\n\nno_pretrain_languages = set([\n    \"cop\",\n    \"olo\",\n    \"orv\",\n    \"pcm\",\n    \"qaf\",   # the QAF treebank is code-switched and Romanized, so not easy to reuse existing resources\n    \"qpm\",   # have talked about deriving this from a language neighboring Pomak, but that hasn't happened yet\n    \"qtd\",\n    \"swl\",\n\n    \"multilingual\", # special case so that all languages with a default treebank are represented somewhere\n])\n\n\n# in some cases, we give the pretrain a name other than the original\n# name for the UD dataset\n# we will eventually do this for all of the pretrains\nspecific_default_pretrains = {\n    \"ab\":      \"fasttextwiki\",\n    \"af\":      \"fasttextwiki\",\n    \"ang\":     
\"nerthus\",\n    \"ar\":      \"conll17\",\n    \"be\":      \"fasttextwiki\",\n    \"bg\":      \"conll17\",\n    \"bxr\":     \"fasttextwiki\",\n    \"ca\":      \"conll17\",\n    \"cs\":      \"conll17\",\n    \"cu\":      \"conll17\",\n    \"cy\":      \"fasttext157\",\n    \"da\":      \"conll17\",\n    \"de\":      \"conll17\",\n    \"el\":      \"conll17\",\n    \"en\":      \"conll17\",\n    \"es\":      \"conll17\",\n    \"et\":      \"conll17\",\n    \"eu\":      \"conll17\",\n    \"fa\":      \"conll17\",\n    \"fi\":      \"conll17\",\n    \"fo\":      \"fasttextwiki\",\n    \"fr\":      \"conll17\",\n    \"fro\":     \"conll17\",\n    \"ga\":      \"conll17\",\n    \"gd\":      \"fasttextwiki\",\n    \"gl\":      \"conll17\",\n    \"got\":     \"fasttextwiki\",\n    \"grc\":     \"conll17\",\n    \"gv\":      \"fasttext157\",\n    \"hbo\":     \"utah\",\n    \"he\":      \"conll17\",\n    \"hi\":      \"conll17\",\n    \"hr\":      \"conll17\",\n    \"hsb\":     \"fasttextwiki\",\n    \"hu\":      \"conll17\",\n    \"hy\":      \"isprasglove\",\n    \"hyw\":     \"isprasglove\",\n    \"id\":      \"conll17\",\n    \"is\":      \"fasttext157\",\n    \"it\":      \"conll17\",\n    \"ja\":      \"conll17\",\n    \"ka\":      \"fasttext157\",\n    \"kk\":      \"fasttext157\",\n    \"kmr\":     \"fasttextwiki\",\n    \"ko\":      \"conll17\",\n    \"kpv\":     \"fasttextwiki\",\n    \"ky\":      \"fasttext157\",\n    \"la\":      \"conll17\",\n    \"lij\":     \"fasttextwiki\",\n    \"lt\":      \"fasttextwiki\",\n    \"lv\":      \"conll17\",\n    \"lzh\":     \"fasttextwiki\",\n    \"mr\":      \"fasttextwiki\",\n    \"mt\":      \"fasttextwiki\",\n    \"my\":      \"ucsy\",\n    \"myv\":     \"mokha\",\n    \"nb\":      \"conll17\",\n    \"nds\":     \"fasttext157\",\n    \"nl\":      \"conll17\",\n    \"nn\":      \"conll17\",\n    \"or\":      \"fasttext157\",\n    \"ota\":     \"conll17\",\n    \"pl\":      \"conll17\",\n    \"pt\":      
\"conll17\",\n    \"ro\":      \"conll17\",\n    \"ru\":      \"conll17\",\n    \"sa\":      \"fasttext157\",\n    \"sd\":      \"isra\",\n    \"sk\":      \"conll17\",\n    \"sl\":      \"conll17\",\n    \"sme\":     \"fasttextwiki\",\n    \"sq\":      \"fasttext157\",\n    \"sr\":      \"fasttextwiki\",\n    \"sv\":      \"conll17\",\n    \"ta\":      \"fasttextwiki\",\n    \"te\":      \"fasttextwiki\",\n    \"th\":      \"fasttext157\",\n    \"tr\":      \"conll17\",\n    \"ug\":      \"conll17\",\n    \"uk\":      \"conll17\",\n    \"ur\":      \"conll17\",\n    \"vi\":      \"conll17\",\n    \"wo\":      \"fasttextwiki\",\n    \"xcl\":     \"caval\",\n    \"zh-hans\": \"fasttext157\",\n    \"zh-hant\": \"conll17\",\n}\n\n\ndef build_default_pretrains(default_treebanks):\n    default_pretrains = dict(default_treebanks)\n    for lang in no_pretrain_languages:\n        default_pretrains.pop(lang, None)\n    for lang in specific_default_pretrains.keys():\n        default_pretrains[lang] = specific_default_pretrains[lang]\n    return default_pretrains\n\ndefault_pretrains = build_default_pretrains(default_treebanks)\n\npos_pretrains = {\n    \"en\": {\n        \"craft\":            \"biomed\",\n        \"genia\":            \"biomed\",\n        \"mimic\":            \"mimic\",\n    },\n}\n\ndepparse_pretrains = pos_pretrains\n\nner_pretrains = {\n    \"ar\": {\n        \"aqmar\": \"fasttextwiki\",\n    },\n    \"de\": {\n        \"conll03\":      \"fasttextwiki\",\n        # the bert version of germeval uses the smaller vector file\n        \"germeval2014\": \"fasttextwiki\",\n    },\n    \"en\": {\n        \"anatem\":       \"biomed\",\n        \"bc4chemd\":     \"biomed\",\n        \"bc5cdr\":       \"biomed\",\n        \"bionlp13cg\":   \"biomed\",\n        \"jnlpba\":       \"biomed\",\n        \"linnaeus\":     \"biomed\",\n        \"ncbi_disease\": \"biomed\",\n        \"s800\":         \"biomed\",\n\n        \"ontonotes\":    \"fasttextcrawl\",\n        # 
the stanza-train sample NER model should use the default NER pretrain\n        # for English, that is the same as ontonotes\n        \"sample\":       \"fasttextcrawl\",\n\n        \"conll03\":      \"glove\",\n\n        \"i2b2\":         \"mimic\",\n        \"radiology\":    \"mimic\",\n    },\n    \"es\": {\n        \"ancora\":  \"fasttextwiki\",\n        \"conll02\": \"fasttextwiki\",\n    },\n    \"nl\": {\n        \"conll02\": \"fasttextwiki\",\n        \"wikiner\": \"fasttextwiki\",\n    },\n    \"ru\": {\n        \"wikiner\": \"fasttextwiki\",\n    },\n    \"th\": {\n        \"lst20\": \"fasttext157\",\n    },\n}\n\n\n# default charlms for languages\ndefault_charlms = {\n    \"af\": \"oscar\",\n    \"ang\": \"nerthus1024\",\n    \"ar\": \"ccwiki\",\n    \"bg\": \"conll17\",\n    \"da\": \"oscar\",\n    \"de\": \"newswiki\",\n    \"en\": \"1billion\",\n    \"es\": \"newswiki\",\n    \"fa\": \"conll17\",\n    \"fi\": \"conll17\",\n    \"fr\": \"newswiki\",\n    \"he\": \"oscar\",\n    \"hi\": \"oscar\",\n    \"id\": \"oscar2023\",\n    \"it\": \"conll17\",\n    \"ja\": \"conll17\",\n    \"kk\": \"oscar\",\n    \"mr\": \"l3cube\",\n    \"my\": \"oscar\",\n    \"nb\": \"conll17\",\n    \"nl\": \"ccwiki\",\n    \"pl\": \"oscar\",\n    \"pt\": \"oscar2023\",\n    \"ru\": \"newswiki\",\n    \"sd\": \"isra\",\n    \"sv\": \"conll17\",\n    \"te\": \"oscar2022\",\n    \"th\": \"oscar\",\n    \"tr\": \"conll17\",\n    \"uk\": \"conll17\",\n    \"vi\": \"conll17\",\n    \"zh-hans\": \"gigaword\"\n}\n\npos_charlms = {\n    \"en\": {\n        # none of the English charlms help with craft or genia\n        \"craft\": None,\n        \"genia\": None,\n        \"mimic\": \"mimic\",\n    },\n    \"tr\": {   # no idea why, but this particular one goes down in dev score\n        \"boun\": None,\n    },\n}\n\ndepparse_charlms = copy.deepcopy(pos_charlms)\n\nlemma_charlms = copy.deepcopy(pos_charlms)\n\ntokenizer_charlms = copy.deepcopy(pos_charlms)\n\nner_charlms = {\n    
\"en\": {\n        \"conll03\": \"1billion\",\n        \"ontonotes\": \"1billion\",\n        \"anatem\": \"pubmed\",\n        \"bc4chemd\": \"pubmed\",\n        \"bc5cdr\": \"pubmed\",\n        \"bionlp13cg\": \"pubmed\",\n        \"i2b2\": \"mimic\",\n        \"jnlpba\": \"pubmed\",\n        \"linnaeus\": \"pubmed\",\n        \"ncbi_disease\": \"pubmed\",\n        \"radiology\": \"mimic\",\n        \"s800\": \"pubmed\",\n    },\n    \"hu\": {\n        \"combined\": None,\n    },\n    \"nn\": {\n        \"norne\": None,\n    },\n}\n\n# default ner for languages\ndefault_ners = {\n    \"af\": \"nchlt\",\n    \"ang\": \"oedt_charlm\",\n    \"ar\": \"aqmar_charlm\",\n    \"bg\": \"bsnlp19\",\n    \"da\": \"ddt\",\n    \"de\": \"germeval2014\",\n    \"en\": \"ontonotes-ww-multi_charlm\",\n    \"es\": \"conll02\",\n    \"fa\": \"arman\",\n    \"fi\": \"turku\",\n    \"fr\": \"wikinergold_charlm\",\n    \"he\": \"iahlt_charlm\",\n    \"hi\": \"ilner_charlm\",\n    \"hu\": \"combined\",\n    \"hy\": \"armtdp\",\n    \"it\": \"fbk\",\n    \"ja\": \"gsd\",\n    \"kk\": \"kazNERD\",\n    \"mr\": \"l3cube\",\n    \"my\": \"ucsy\",\n    \"nb\": \"norne\",\n    \"nl\": \"conll02\",\n    \"nn\": \"norne\",\n    \"pl\": \"nkjp\",\n    \"ru\": \"wikiner\",\n    \"sd\": \"siner\",\n    \"sv\": \"suc3shuffle\",\n    \"te\": \"ilner_charlm\",\n    \"th\": \"lst20\",\n    \"tr\": \"starlang\",\n    \"uk\": \"languk\",\n    \"ur\": \"ilner_nocharlm\",\n    \"vi\": \"vlsp\",\n    \"zh-hans\": \"ontonotes\",\n}\n\n# a few languages have sentiment classifier models\ndefault_sentiment = {\n    \"en\": \"sstplus_charlm\",\n    \"de\": \"sb10k_charlm\",\n    \"es\": \"tass2020_charlm\",\n    \"mr\": \"l3cube_charlm\",\n    \"vi\": \"vsfc_charlm\",\n    \"zh-hans\": \"ren_charlm\",\n}\n\n# also, a few languages (very few, currently) have constituency parser models\ndefault_constituency = {\n    \"da\": \"arboretum_charlm\",\n    \"de\": \"spmrl_charlm\",\n    \"en\": 
\"ptb3-revised_charlm\",\n    \"es\": \"combined_charlm\",\n    \"id\": \"icon_charlm\",\n    \"it\": \"vit_charlm\",\n    \"ja\": \"alt_charlm\",\n    \"pt\": \"cintil_charlm\",\n    #\"tr\": \"starlang_charlm\",\n    \"vi\": \"vlsp22_charlm\",\n    \"zh-hans\": \"ctb-51_charlm\",\n}\n\noptional_constituency = {\n    \"tr\": \"starlang_charlm\",\n}\n\n# an alternate tokenizer for languages which aren't trained from a base UD source\ndefault_tokenizer = {\n    \"my\": \"alt\",\n}\n\n# ideally we would have a less expensive model as the base model\n#default_coref = {\n#    \"en\": \"ontonotes_roberta-large_finetuned\",\n#}\n\noptional_coref = {\n    \"ca\": \"udcoref_xlm-roberta-lora\",\n    \"cs\": \"udcoref_xlm-roberta-lora\",\n    \"de\": \"udcoref_xlm-roberta-lora\",\n    \"en\": \"udcoref_xlm-roberta-lora\",\n    \"es\": \"udcoref_xlm-roberta-lora\",\n    \"fr\": \"udcoref_xlm-roberta-lora\",\n    \"he\": \"iahlt_xlm-roberta-lora\",\n    \"hi\": \"deeph_muril-large-cased-lora\",\n    # UD Coref has both nb and nn datasets for Norwegian\n    \"nb\": \"udcoref_xlm-roberta-lora\",\n    \"nn\": \"udcoref_xlm-roberta-lora\",\n    \"pl\": \"udcoref_xlm-roberta-lora\",\n    \"ru\": \"udcoref_xlm-roberta-lora\",\n    \"ta\": \"kbc_muril-large-cased-lora\",\n}\n\n\"\"\"\ndefault transformers to use for various languages\n\nwe try to document why we choose a particular model in each case\n\"\"\"\nTRANSFORMERS = {\n    # We tested three candidate AR models on POS, Depparse, and NER\n    #\n    # POS: padt dev set scores, AllTags\n    # depparse: padt dev set scores, LAS\n    # NER: dev scores on a random split of AQMAR, entity scores\n    #\n    #                                             pos   depparse  ner\n    # none (pt & charlm only)                    94.08    83.49  84.19\n    # asafaya/bert-base-arabic                   95.10    84.96  85.98\n    # aubmindlab/bert-base-arabertv2             95.33    85.28  84.93\n    # aubmindlab/araelectra-base-discriminator   
95.66    85.83  86.10\n    \"ar\": \"aubmindlab/araelectra-base-discriminator\",\n\n    # https://huggingface.co/Maltehb/danish-bert-botxo\n    # contrary to normal expectations, this hurts the score\n    # on a dev split by about 1 F1\n    # \"da\": \"Maltehb/danish-bert-botxo\",\n    #\n    # the multilingual bert is a marginal improvement for conparse\n    #\n    # December 2022 update:\n    # there are quite a few Danish transformers available on HuggingFace\n    # here are the results of training a constituency parser with adadelta/adamw\n    # on each of them:\n    #\n    # no bert                              0.8245    0.8230\n    # alexanderfalk/danbert-small-cased    0.8236    0.8286\n    # Geotrend/distilbert-base-da-cased    0.8268    0.8306\n    # sarnikowski/convbert-small-da-cased  0.8322    0.8341\n    # bert-base-multilingual-cased         0.8341    0.8342\n    # vesteinn/ScandiBERT-no-faroese       0.8373    0.8408\n    # Maltehb/danish-bert-botxo            0.8383    0.8408\n    # vesteinn/ScandiBERT                  0.8421    0.8475\n    #\n    # Also, two models have token windows too short for use with the\n    # Danish dataset:\n    #  jonfd/electra-small-nordic\n    #  Maltehb/aelaectra-danish-electra-small-cased\n    #\n    \"da\": \"vesteinn/ScandiBERT\",\n\n    # As of April 2022, the bert models available have a weird\n    # tokenizer issue where soft hyphen causes it to crash.\n    # We attempt to compensate for that in the dev branch\n    #\n    # NER scores\n    #     model                                       dev      test\n    # xlm-roberta-large                              86.56    85.23\n    # bert-base-german-cased                         87.59    86.95\n    # dbmdz/bert-base-german-cased                   88.27    87.47\n    # german-nlp-group/electra-base-german-uncased   88.60    87.09\n    #\n    # constituency scores w/ peft, March 2024 model, in-order\n    #    model             dev     test\n    #   xlm-roberta-base  95.17   
93.34\n    #   xlm-roberta-large 95.86   94.46    (!!!)\n    #   bert-base         95.24   93.24\n    #   dbmdz/bert        95.32   93.33\n    #   german/electra    95.72   94.05\n    #\n    # POS scores\n    #    model             dev     test\n    #   None              88.65   87.28\n    #   xlm-roberta-large 89.21   88.11\n    #   bert-base         89.52   88.42\n    #   dbmdz/bert        89.67   88.54\n    #   german/electra    89.98   88.66\n    #\n    # depparse scores, LAS\n    #    model             dev     test\n    #   None              87.76   84.37\n    #   xlm-roberta-large 89.00   85.79\n    #   bert-base         88.72   85.40\n    #   dbmdz/bert        88.70   85.14\n    #   german/electra    89.21   86.06\n    \"de\": \"german-nlp-group/electra-base-german-uncased\",\n\n    # experiments on various forms of roberta & electra\n    #  https://huggingface.co/roberta-base\n    #  https://huggingface.co/roberta-large\n    #  https://huggingface.co/google/electra-small-discriminator\n    #  https://huggingface.co/google/electra-base-discriminator\n    #  https://huggingface.co/google/electra-large-discriminator\n    #\n    # experiments using the different models for POS tagging,\n    # dev set, including WV and charlm, AllTags score:\n    #  roberta-base:   95.67\n    #  roberta-large:  95.98\n    #  electra-small:  95.31\n    #  electra-base:   95.90\n    #  electra-large:  96.01\n    #\n    # depparse scores, dev set, no finetuning, with WV and charlm\n    #                   UAS   LAS  CLAS  MLAS  BLEX\n    #  roberta-base:   93.16 91.20 89.87 89.38 89.87\n    #  roberta-large:  93.47 91.56 90.13 89.71 90.13\n    #  electra-small:  92.17 90.02 88.25 87.66 88.25\n    #  electra-base:   93.42 91.44 90.10 89.67 90.10\n    #  electra-large:  94.07 92.17 90.99 90.53 90.99\n    #\n    # conparse scores, dev & test set, with WV and charlm\n    #  roberta_base:   96.05 95.60\n    #  roberta_large:  95.95 95.60\n    #  electra-small:  95.33 95.04\n    #  
electra-base:   96.09 95.98\n    #  electra-large:  96.25 96.14\n    #\n    # conparse scores w/ finetune, dev & test set, with WV and charlm\n    #  roberta_base:   96.07 95.81\n    #  roberta_large:  96.37 96.41   (!!!)\n    #  electra-small:  95.62 95.36\n    #  electra-base:   96.21 95.94\n    #  electra-large:  96.40 96.32\n    #\n    \"en\": \"google/electra-large-discriminator\",\n\n    # TODO need to test, possibly compare with others\n    \"es\": \"bertin-project/bertin-roberta-base-spanish\",\n\n    # NER scores for a couple Persian options:\n    # none:\n    # dev:  2022-04-23 01:44:53 INFO: fa_arman 79.46\n    # test: 2022-04-23 01:45:03 INFO: fa_arman 80.06\n    #\n    # HooshvareLab/bert-fa-zwnj-base\n    # dev:  2022-04-23 02:43:44 INFO: fa_arman 80.87\n    # test: 2022-04-23 02:44:07 INFO: fa_arman 80.81\n    #\n    # HooshvareLab/roberta-fa-zwnj-base\n    # dev:  2022-04-23 16:23:25 INFO: fa_arman 81.23\n    # test: 2022-04-23 16:23:48 INFO: fa_arman 81.11\n    #\n    # HooshvareLab/bert-base-parsbert-uncased\n    # dev:  2022-04-26 10:42:09 INFO: fa_arman 82.49\n    # test: 2022-04-26 10:42:31 INFO: fa_arman 83.16\n    \"fa\": 'HooshvareLab/bert-base-parsbert-uncased',\n\n    # NER scores for a couple options:\n    # none:\n    # dev:  2022-03-04 INFO: fi_turku 83.45\n    # test: 2022-03-04 INFO: fi_turku 86.25\n    #\n    # bert-base-multilingual-cased\n    # dev:  2022-03-04 INFO: fi_turku 85.23\n    # test: 2022-03-04 INFO: fi_turku 89.00\n    #\n    # TurkuNLP/bert-base-finnish-cased-v1:\n    # dev:  2022-03-04 INFO: fi_turku 88.41\n    # test: 2022-03-04 INFO: fi_turku 91.36\n    \"fi\": \"TurkuNLP/bert-base-finnish-cased-v1\",\n\n    # POS dev set tagging results for French:\n    #  No bert:\n    #    98.60  100.00   98.55   98.04\n    #  dbmdz/electra-base-french-europeana-cased-discriminator\n    #    98.70  100.00   98.69   98.24\n    #  benjamin/roberta-base-wechsel-french\n    #    98.71  100.00   98.75   98.26\n    #  
camembert/camembert-large\n    #    98.75  100.00   98.75   98.30\n    #  camembert-base\n    #    98.78  100.00   98.77   98.33\n    #\n    # GSD depparse dev set results for French:\n    #  No bert:\n    #    95.83 94.52 91.34 91.10 91.34\n    #  camembert/camembert-large\n    #    96.80 95.71 93.37 93.13 93.37\n    #  TODO: the rest of the chart\n    \"fr\": \"camembert/camembert-large\",\n\n    # Ancient Greek has a surprising number of transformers, considering\n    #    Model           POS        Depparse LAS\n    # None              0.8812       0.7684\n    # Microbert M       0.8883       0.7706\n    # Microbert MX      0.8910       0.7755\n    # Microbert MXP     0.8916       0.7742\n    # Pranaydeeps Bert  0.9139       0.7987\n    \"grc\": \"pranaydeeps/Ancient-Greek-BERT\",\n\n    # a couple possibilities to experiment with for Hebrew\n    # dev scores for POS and depparse\n    # https://huggingface.co/imvladikon/alephbertgimmel-base-512\n    #   UPOS    XPOS  UFeats AllTags\n    #  97.25   97.25   92.84   91.81\n    #   UAS   LAS  CLAS  MLAS  BLEX\n    #  94.42 92.47 89.49 88.82 89.49\n    #\n    # https://huggingface.co/onlplab/alephbert-base\n    #   UPOS    XPOS  UFeats AllTags\n    #  97.37   97.37   92.50   91.55\n    #   UAS   LAS  CLAS  MLAS  BLEX\n    #  94.06 92.12 88.80 88.13 88.80\n    #\n    # https://huggingface.co/avichr/heBERT\n    #   UPOS    XPOS  UFeats AllTags\n    #  97.09   97.09   92.36   91.28\n    #   UAS   LAS  CLAS  MLAS  BLEX\n    #  94.29 92.30 88.99 88.38 88.99\n    \"he\": \"imvladikon/alephbertgimmel-base-512\",\n\n    # can also experiment with xlm-roberta\n    # on a coref dataset from IITH, span F1:\n    #                         dev      test\n    #  xlm-roberta-large   0.63635   0.66579\n    #  muril-large         0.65369   0.68290\n    \"hi\": \"google/muril-large-cased\",\n\n    # https://huggingface.co/xlm-roberta-base\n    # Scores by entity for armtdp NER on 18 labels:\n    # no bert : 86.68\n    # 
xlm-roberta-base : 89.31\n    \"hy\": \"xlm-roberta-base\",\n\n    # Indonesian POS experiments: dev set of GSD\n    # python3 stanza/utils/training/run_pos.py id_gsd --no_bert\n    # python3 stanza/utils/training/run_pos.py id_gsd --bert_model ...\n    # also ran on the ICON constituency dataset\n    #  model                                      POS       CON\n    # no_bert                                    89.95     84.74\n    # flax-community/indonesian-roberta-large    89.78 (!)  xxx\n    # flax-community/indonesian-roberta-base     90.14      xxx\n    # indobenchmark/indobert-base-p2             90.09\n    # indobenchmark/indobert-base-p1             90.14\n    # indobenchmark/indobert-large-p1            90.19\n    # indolem/indobert-base-uncased              90.21     88.60\n    # cahya/bert-base-indonesian-1.5G            90.32     88.15\n    # cahya/roberta-base-indonesian-1.5G         90.40     87.27\n    \"id\": \"indolem/indobert-base-uncased\",\n\n    # from https://github.com/idb-ita/GilBERTo\n    # annoyingly, it doesn't handle cased text\n    # supposedly there is an argument \"do_lower_case\"\n    # but that still leaves a lot of unk tokens\n    # \"it\": \"idb-ita/gilberto-uncased-from-camembert\",\n    #\n    # from https://github.com/musixmatchresearch/umberto\n    # on NER, this gets 88.37 dev and 91.02 test\n    # another option is dbmdz/bert-base-italian-cased,\n    # which gets 87.27 dev and 90.32 test\n    #\n    #  in-order constituency parser on the VIT dev set:\n    # dbmdz/bert-base-italian-cased                       0.8079\n    # dbmdz/bert-base-italian-xxl-cased:                  0.8195\n    # Musixmatch/umberto-commoncrawl-cased-v1:            0.8256\n    # dbmdz/electra-base-italian-xxl-cased-discriminator: 0.8314\n    #\n    #  FBK NER dev set:\n    # dbmdz/bert-base-italian-cased:                      87.76\n    # Musixmatch/umberto-commoncrawl-cased-v1:            88.62\n    # dbmdz/bert-base-italian-xxl-cased:                  
88.84\n    # dbmdz/electra-base-italian-xxl-cased-discriminator: 89.91\n    #\n    #  combined UD POS dev set:                             UPOS    XPOS  UFeats AllTags\n    # dbmdz/bert-base-italian-cased:                       98.62   98.53   98.06   97.49\n    # dbmdz/bert-base-italian-xxl-cased:                   98.61   98.54   98.07   97.58\n    # dbmdz/electra-base-italian-xxl-cased-discriminator:  98.64   98.54   98.14   97.61\n    # Musixmatch/umberto-commoncrawl-cased-v1:             98.56   98.45   98.13   97.62\n    \"it\": \"dbmdz/electra-base-italian-xxl-cased-discriminator\",\n\n    # for Japanese\n    # there are others that would also work,\n    # but they require different tokenizers instead of being\n    # plug & play\n    #\n    # Constituency scores on ALT (in-order)\n    # no bert: 90.68 dev, 91.40 test\n    # rinna:   91.54 dev, 91.89 test\n    \"ja\": \"rinna/japanese-roberta-base\",\n\n    # could also try:\n    # l3cube-pune/marathi-bert-v2\n    #  or\n    # https://huggingface.co/l3cube-pune/hindi-marathi-dev-roberta\n    # l3cube-pune/hindi-marathi-dev-roberta\n    #\n    # depparse ufal dev scores:\n    #  no transformer              74.89 63.70 57.43 53.01 57.43\n    #  l3cube-pune/marathi-roberta 76.48 66.21 61.20 57.60 61.20\n    \"mr\": \"l3cube-pune/marathi-roberta\",\n\n    \"or\": \"google/muril-large-cased\",\n\n    # https://huggingface.co/allegro/herbert-base-cased\n    # Scores by entity on the NKJP NER task:\n    # no bert (dev/test): 88.64/88.75\n    # herbert-base-cased (dev/test): 91.48/91.02\n    # herbert-large-cased (dev/test): 92.25/91.62\n    # sdadas/polish-roberta-large-v2 (dev/test): 92.66/91.22\n    \"pl\": \"allegro/herbert-base-cased\",\n\n    # experiments on the cintil conparse dataset\n    # ran a variety of transformer settings\n    # found the following dev set scores after 400 iterations:\n    # Geotrend/distilbert-base-pt-cased : not plug & play\n    # no bert: 0.9082\n    # xlm-roberta-base: 0.9109\n   
 # xlm-roberta-large: 0.9254\n    # adalbertojunior/distilbert-portuguese-cased: 0.9300\n    # neuralmind/bert-base-portuguese-cased: 0.9307\n    # neuralmind/bert-large-portuguese-cased: 0.9343\n    \"pt\": \"neuralmind/bert-large-portuguese-cased\",\n\n    # hope is actually to build our own using a large text collection\n    \"sd\": \"google/muril-large-cased\",\n\n    # Tamil options: quite a few, need to run a bunch of experiments\n    #                               dev pos    dev depparse las\n    # no transformer                 82.82        69.12\n    # ai4bharat/indic-bert           82.98        70.47\n    # lgessler/microbert-tamil-mxp   83.21        69.28\n    # monsoon-nlp/tamillion          83.37        69.28\n    # l3cube-pune/tamil-bert         85.27        72.53\n    # d42kw01f/Tamil-RoBERTa         85.59        70.55\n    # google/muril-base-cased        85.67        72.68\n    # google/muril-large-cased       86.30        72.45\n    #\n    # should also consider xlm-roberta-large\n    # updated on UD 2.16 data:      dev pos      ner\n    # google/muril-large-cased       86.86      65.08\n    # xlm-roberta-large                         66.28\n    \"ta\": \"google/muril-large-cased\",\n\n    \"te\": \"google/muril-large-cased\",\n\n    # https://huggingface.co/airesearch/wangchanberta-base-att-spm-uncased\n    # this is clearly better than no transformer on a couple datasets:\n    #\n    #                    TUD dev upos   TUD dev depparse LAS\n    # no transformer       91.26             73.57\n    # wangchanberta        92.21             76.65\n    \"th\": \"airesearch/wangchanberta-base-att-spm-uncased\",\n\n    # https://huggingface.co/dbmdz/bert-base-turkish-128k-cased\n    # helps the Turkish model quite a bit\n    \"tr\": \"dbmdz/bert-base-turkish-128k-cased\",\n\n    \"ur\": \"google/muril-large-cased\",\n\n    # from https://github.com/VinAIResearch/PhoBERT\n    # \"vi\": \"vinai/phobert-base\",\n    # using 6 or 7 layers of phobert-large 
is slightly\n    # more effective for constituency parsing than\n    # using 4 layers of phobert-base\n    # ... going beyond 4 layers of phobert-base\n    # does not help the scores\n    \"vi\": \"vinai/phobert-large\",\n\n    # https://github.com/ymcui/Chinese-BERT-wwm\n    # there's also hfl/chinese-roberta-wwm-ext-large\n    # or hfl/chinese-electra-base-discriminator\n    # or hfl/chinese-electra-180g-large-discriminator,\n    #   which works better than the below roberta on constituency\n    # \"zh-hans\": \"hfl/chinese-roberta-wwm-ext\",\n    # conparse dev scores (averaged over 5):\n    #   google bert:  0.9422\n    #   hfl bert:     0.9469\n    #   hfl roberta:  0.9459\n    #   hfl electra:  0.9515\n    #   hfl macbert:  0.9530\n    # There is also a ShannonAI model, but our current codebase is\n    # somehow not compatible\n    # further comparing HFL:\n    #                    POS dev  Depparse dev LAS     NER dev\n    #   HFL Electra      96.90     85.66                77.90\n    #   HFL Macbert      96.53     84.72                78.46\n    # \"zh-hans\": \"hfl/chinese-macbert-large\",\n    \"zh-hans\": \"hfl/chinese-electra-180g-large-discriminator\",\n}\n\nTRANSFORMER_LAYERS = {\n    # not clear what the best number is without more experiments,\n    # but more than 4 is working better than just 4\n    \"vi\": 7,\n}\n\nTRANSFORMER_NICKNAMES = {\n    # ar\n    \"asafaya/bert-base-arabic\": \"asafaya-bert\",\n    \"aubmindlab/araelectra-base-discriminator\": \"aubmind-electra\",\n    \"aubmindlab/bert-base-arabertv2\": \"aubmind-bert\",\n\n    # da\n    \"vesteinn/ScandiBERT\": \"scandibert\",\n\n    # de\n    \"bert-base-german-cased\": \"bert-base-german-cased\",\n    \"dbmdz/bert-base-german-cased\": \"dbmdz-bert-german-cased\",\n    \"german-nlp-group/electra-base-german-uncased\": \"german-nlp-electra\",\n\n    # en\n    \"bert-base-multilingual-cased\": \"mbert\",\n    \"xlm-roberta-large\": \"xlm-roberta-large\",\n    
\"google/electra-large-discriminator\": \"electra-large\",\n    \"microsoft/deberta-v3-large\": \"deberta-v3-large\",\n    \"princeton-nlp/Sheared-LLaMA-1.3B\": \"sheared-llama-1b3\",\n\n    # es\n    \"bertin-project/bertin-roberta-base-spanish\": \"bertin-roberta\",\n\n    # fa\n    \"HooshvareLab/bert-base-parsbert-uncased\": \"parsbert\",\n\n    # fi\n    \"TurkuNLP/bert-base-finnish-cased-v1\": \"bert\",\n\n    # fr\n    \"benjamin/roberta-base-wechsel-french\": \"wechsel-roberta\",\n    \"camembert-base\": \"camembert-base\",\n    \"camembert/camembert-large\": \"camembert-large\",\n    \"dbmdz/electra-base-french-europeana-cased-discriminator\": \"dbmdz-electra\",\n\n    # grc\n    \"pranaydeeps/Ancient-Greek-BERT\": \"grc-pranaydeeps\",\n    \"lgessler/microbert-ancient-greek-m\": \"grc-microbert-m\",\n    \"lgessler/microbert-ancient-greek-mx\": \"grc-microbert-mx\",\n    \"lgessler/microbert-ancient-greek-mxp\": \"grc-microbert-mxp\",\n    \"altsoph/bert-base-ancientgreek-uncased\": \"grc-altsoph\",\n\n    # he\n    \"HeNLP/HeRo\": \"hero-roberta\",\n    \"imvladikon/alephbertgimmel-base-512\": \"alephbertgimmel\",\n    \"onlplab/alephbert-base\": \"alephbert\",\n\n    # hy\n    \"xlm-roberta-base\": \"xlm-roberta-base\",\n\n    # id\n    \"indolem/indobert-base-uncased\":         \"indobert\",\n    \"indobenchmark/indobert-large-p1\":       \"indobenchmark-large-p1\",\n    \"indobenchmark/indobert-base-p1\":        \"indobenchmark-base-p1\",\n    \"indobenchmark/indobert-lite-large-p1\":  \"indobenchmark-lite-large-p1\",\n    \"indobenchmark/indobert-lite-base-p1\":   \"indobenchmark-lite-base-p1\",\n    \"indobenchmark/indobert-large-p2\":       \"indobenchmark-large-p2\",\n    \"indobenchmark/indobert-base-p2\":        \"indobenchmark-base-p2\",\n    \"indobenchmark/indobert-lite-large-p2\":  \"indobenchmark-lite-large-p2\",\n    \"indobenchmark/indobert-lite-base-p2\":   \"indobenchmark-lite-base-p2\",\n\n    # it\n    
\"dbmdz/electra-base-italian-xxl-cased-discriminator\": \"electra\",\n\n    # ja\n    \"rinna/japanese-roberta-base\": \"rinna-roberta\",\n\n    # mr\n    \"l3cube-pune/marathi-roberta\": \"l3cube-marathi-roberta\",\n\n    # pl\n    \"allegro/herbert-base-cased\": \"herbert\",\n\n    # pt\n    \"neuralmind/bert-large-portuguese-cased\": \"bertimbau\",\n\n    # ta: tamil\n    \"monsoon-nlp/tamillion\":         \"tamillion\",\n    \"lgessler/microbert-tamil-m\":    \"ta-microbert-m\",\n    \"lgessler/microbert-tamil-mxp\":  \"ta-microbert-mxp\",\n    \"l3cube-pune/tamil-bert\":        \"l3cube-tamil-bert\",\n    \"d42kw01f/Tamil-RoBERTa\":        \"ta-d42kw01f-roberta\",\n\n    # th\n    \"airesearch/wangchanberta-base-att-spm-uncased\":   \"wangchanberta\",\n\n    # tr\n    \"dbmdz/bert-base-turkish-128k-cased\": \"bert\",\n\n    # vi\n    \"vinai/phobert-base\": \"phobert-base\",\n    \"vinai/phobert-large\": \"phobert-large\",\n\n    # zh\n    \"google-bert/bert-base-chinese\": \"google-bert-chinese\",\n    \"hfl/chinese-bert-wwm\": \"hfl-bert-chinese\",\n    \"hfl/chinese-macbert-large\": \"hfl-macbert-chinese\",\n    \"hfl/chinese-roberta-wwm-ext\": \"hfl-roberta-chinese\",\n    \"hfl/chinese-electra-180g-large-discriminator\": \"electra-large\",\n    \"ShannonAI/ChineseBERT-base\": \"shannonai-chinese-bert\",\n\n    # multi-lingual Indic\n    \"ai4bharat/indic-bert\": \"indic-bert\",\n    \"google/muril-base-cased\": \"muril-base-cased\",\n    \"google/muril-large-cased\": \"muril-large-cased\",\n\n    # multi-lingual\n    \"FacebookAI/xlm-roberta-large\": \"xlm-roberta-large\",\n}\n\ndef known_nicknames():\n    \"\"\"\n    Return a list of all the transformer nicknames\n\n    We return a list so that we can sort them in decreasing key length\n    \"\"\"\n    nicknames = list(value for key, value in TRANSFORMER_NICKNAMES.items())\n\n    # previously unspecific transformers get \"transformer\" as the nickname\n    nicknames.append(\"transformer\")\n\n    
nicknames = sorted(nicknames, key=lambda x: -len(x))\n\n    return nicknames\n"
  },
  {
    "path": "stanza/resources/installation.py",
    "content": "\"\"\"\nFunctions for setting up the environments.\n\"\"\"\n\nimport os\nimport logging\nimport zipfile\nimport shutil\n\nfrom stanza.resources.common import USER_CACHE_DIR, request_file, unzip, \\\n    get_root_from_zipfile, set_logging_level\n\nlogger = logging.getLogger('stanza')\n\nDEFAULT_CORENLP_MODEL_URL = os.getenv(\n    'CORENLP_MODEL_URL',\n    'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar'\n)\nBACKUP_CORENLP_MODEL_URL = \"http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar\"\n\nDEFAULT_CORENLP_URL = os.getenv(\n    'CORENLP_MODEL_URL',\n    'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip'\n)\n\nDEFAULT_CORENLP_DIR = os.getenv(\n    'CORENLP_HOME',\n    os.path.join(USER_CACHE_DIR, 'corenlp')\n)\n\nAVAILABLE_MODELS = set(['arabic', 'chinese', 'english-extra', 'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish'])\n\n\ndef download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None, force=True):\n    \"\"\"\n    A automatic way to download the CoreNLP models.\n\n    Args:\n        model: the name of the model, can be one of 'arabic', 'chinese', 'english',\n            'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish'\n        version: the version of the model\n        dir: the directory to download CoreNLP model into; alternatively can be\n            set up with environment variable $CORENLP_HOME\n        url: The link to download CoreNLP models.\n             It will need {model} and either {version} or {tag} to properly format the URL\n        logging_level: logging level to use during installation\n        force: Download model anyway, no matter model file exists or not\n    \"\"\"\n    dir = os.path.expanduser(dir)\n    if not model or not version:\n        raise ValueError(\n            \"Both model and model 
version should be specified.\"\n        )\n    logger.info(f\"Downloading {model} models (version {version}) into directory {dir}\")\n    model = model.strip().lower()\n    if model not in AVAILABLE_MODELS:\n        raise KeyError(\n            f'{model} is currently not supported. '\n            f'Must be one of: {list(AVAILABLE_MODELS)}.'\n        )\n    # for example, with the default url:\n    # https://huggingface.co/stanfordnlp/corenlp-french/resolve/v4.2.2/stanford-corenlp-models-french.jar\n    tag = version if version == 'main' else 'v' + version\n    download_url = url.format(tag=tag, model=model, version=version)\n    model_path = os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar')\n\n    if os.path.exists(model_path) and not force:\n        # logger.warn is a deprecated alias of logger.warning\n        logger.warning(\n            f\"Model file {model_path} already exists. \"\n            \"Please download this model to a new directory.\")\n        return\n\n    try:\n        request_file(\n            download_url,\n            model_path,\n            proxies\n        )\n    except (KeyboardInterrupt, SystemExit):\n        raise\n    except Exception as e:\n        raise RuntimeError(\n            \"Downloading CoreNLP model file failed. 
\"\n            \"Please try manual downloading at: https://stanfordnlp.github.io/CoreNLP/.\"\n        ) from e\n\n\ndef install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version=\"main\"):\n    \"\"\"\n    A fully automatic way to install and setting up the CoreNLP library \n    to use the client functionality.\n\n    Args:\n        dir: the directory to download CoreNLP model into; alternatively can be\n            set up with environment variable $CORENLP_HOME\n        url: The link to download CoreNLP models\n             Needs a {version} or {tag} parameter to specify the version\n        logging_level: logging level to use during installation\n    \"\"\"\n    dir = os.path.expanduser(dir)\n    set_logging_level(logging_level=logging_level, verbose=None)\n    if os.path.exists(dir) and len(os.listdir(dir)) > 0:\n        logger.warn(\n            f\"Directory {dir} already exists. \"\n            f\"Please install CoreNLP to a new directory.\")\n        return\n\n    logger.info(f\"Installing CoreNLP package into {dir}\")\n    # First download the URL package\n    logger.debug(f\"Download to destination file: {os.path.join(dir, 'corenlp.zip')}\")\n    tag = version if version == 'main' else 'v' + version\n    url = url.format(version=version, tag=tag)\n    try:\n        request_file(url, os.path.join(dir, 'corenlp.zip'), proxies)\n\n    except (KeyboardInterrupt, SystemExit):\n        raise\n    except Exception as e:\n        raise RuntimeError(\n            \"Downloading CoreNLP zip file failed. \"\n            \"Please try manual installation: https://stanfordnlp.github.io/CoreNLP/.\"\n        ) from e\n\n    # Unzip corenlp into dir\n    logger.debug(\"Unzipping downloaded zip file...\")\n    unzip(dir, 'corenlp.zip')\n\n    # By default CoreNLP will be unzipped into a version-dependent folder, \n    # e.g., stanford-corenlp-4.0.0. 
We work around that by moving the\n    # files back into our designated folder\n    logger.debug(f\"Moving files into the designated folder at: {dir}\")\n    corenlp_dirname = get_root_from_zipfile(os.path.join(dir, 'corenlp.zip'))\n    corenlp_dirname = os.path.join(dir, corenlp_dirname)\n    for f in os.listdir(corenlp_dirname):\n        shutil.move(os.path.join(corenlp_dirname, f), dir)\n\n    # Remove original zip and folder\n    logger.debug(\"Removing downloaded zip file...\")\n    os.remove(os.path.join(dir, 'corenlp.zip'))\n    shutil.rmtree(corenlp_dirname)\n\n    # Remind the user to set CORENLP_HOME when installing to a custom location\n    if dir != DEFAULT_CORENLP_DIR:\n        logger.warning(\n            \"For customized installation location, please set the `CORENLP_HOME` \"\n            \"environment variable to the location of the installation. \"\n            f\"In Unix, this is done with `export CORENLP_HOME={dir}`.\")\n\n"
  },
  {
    "path": "stanza/resources/prepare_resources.py",
    "content": "\"\"\"\nConverts a directory of models organized by type into a directory organized by language.\n\nAlso produces the resources.json file.\n\nFor example, on the cluster, you can do this:\n\npython3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0 > resources.out 2>&1\nnlprun -a stanza-1.2 -q john \"python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0\" -o resources.out\n\"\"\"\n\nimport argparse\nfrom collections import defaultdict\nimport json\nimport os\nfrom pathlib import Path\nimport hashlib\nimport shutil\nimport zipfile\n\nfrom stanza import __resources_version__\nfrom stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters, extra_lang_to_lcodes\nfrom stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES\nfrom stanza.resources.default_packages import *\nfrom stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_dir', type=str, default=\"/u/nlp/software/stanza/models/current-models-%s\" % __resources_version__, help='Input dir for various models.  Defaults to the recommended home on the nlp cluster')\n    parser.add_argument('--output_dir', type=str, default=\"/u/nlp/software/stanza/models/%s\" % __resources_version__, help='Output dir for various models.')\n    parser.add_argument('--packages_only', action='store_true', default=False, help='Only build the package maps instead of rebuilding everything')\n    parser.add_argument('--lang', type=str, default=None, help='Only process this language or a comma-separated list of languages.  
If left blank, will prepare all languages.  To use this argument, a previously prepared resources.json containing all of the languages is necessary.')\n    args = parser.parse_args()\n    args.input_dir = os.path.abspath(args.input_dir)\n    args.output_dir = os.path.abspath(args.output_dir)\n    if args.lang is not None:\n        args.lang = \",\".join(args.lang.strip().split())\n    return args\n\n\nallowed_empty_languages = [\n    # only tokenize and NER for Myanmar right now (soon...)\n    \"my\",\n    # currently only an NER, not even a tokenizer, for Oriya\n    \"or\",\n]\n\n# map processor name to file ending\n# the order of this dict determines the order in which default.zip files are built\n# changing it will necessitate rebuilding all of the default.zip files\n# not a disaster, but it would involve a bunch of uploading\nprocessor_to_ending = {\n    \"tokenize\": \"tokenizer\",\n    \"mwt\": \"mwt_expander\",\n    \"lemma\": \"lemmatizer\",\n    \"pos\": \"tagger\",\n    \"depparse\": \"parser\",\n    \"pretrain\": \"pretrain\",\n    \"ner\": \"nertagger\",\n    \"forward_charlm\": \"forward_charlm\",\n    \"backward_charlm\": \"backward_charlm\",\n    \"sentiment\": \"sentiment\",\n    \"constituency\": \"constituency\",\n    \"coref\": \"coref\",\n    \"langid\": \"langid\",\n}\nending_to_processor = {j: i for i, j in processor_to_ending.items()}\nPROCESSORS = list(processor_to_ending.keys())\n\ndef ensure_dir(dir):\n    Path(dir).mkdir(parents=True, exist_ok=True)\n\n\ndef copy_file(src, dst):\n    ensure_dir(Path(dst).parent)\n    shutil.copy2(src, dst)\n\n\ndef get_md5(path):\n    # use a context manager so the file handle is closed promptly\n    with open(path, 'rb') as fin:\n        data = fin.read()\n    return hashlib.md5(data).hexdigest()\n\n\ndef split_model_name(model):\n    \"\"\"\n    Split model names by _\n\n    Takes into account packages with _ and processor types with _\n    \"\"\"\n    model = model[:-3].replace('.', '_')\n    # sort by key length so that nertagger is checked before tagger, for example\n    for processor in 
sorted(ending_to_processor.keys(), key=lambda x: -len(x)):\n        if model.endswith(processor):\n            model = model[:-(len(processor)+1)]\n            processor = ending_to_processor[processor]\n            break\n    else:\n        raise AssertionError(f\"Could not find a processor type in {model}\")\n    lang, package = model.split('_', 1)\n    return lang, package, processor\n\ndef split_package(package, default_use_charlm=True):\n    if package.endswith(\"_finetuned\"):\n        package = package[:-10]\n\n    if package.endswith(\"_nopretrain\"):\n        package = package[:-11]\n        return package, False, False\n    if package.endswith(\"_nocharlm\"):\n        package = package[:-9]\n        return package, True, False\n    if package.endswith(\"_charlm\"):\n        package = package[:-7]\n        return package, True, True\n    underscore = package.rfind(\"_\")\n    if underscore >= 0:\n        # +1 to skip the underscore\n        nickname = package[underscore+1:]\n        if nickname in known_nicknames():\n            return package[:underscore], True, True\n\n    # guess it was a model which wasn't built with the new naming convention of putting the pretrain type at the end\n    # assume WV and charlm... 
if the language / package doesn't allow for one, that should be caught later\n    return package, True, default_use_charlm\n\ndef get_pretrain_package(lang, package, model_pretrains, default_pretrains):\n    package, uses_pretrain, _ = split_package(package)\n\n    if not uses_pretrain or lang in no_pretrain_languages:\n        return None\n    elif model_pretrains is not None and lang in model_pretrains and package in model_pretrains[lang]:\n        return model_pretrains[lang][package]\n    elif lang in default_pretrains:\n        return default_pretrains[lang]\n\n    raise RuntimeError(\"pretrain not specified for lang %s package %s\" % (lang, package))\n\ndef get_charlm_package(lang, package, model_charlms, default_charlms, default_use_charlm=True):\n    package, _, uses_charlm = split_package(package, default_use_charlm)\n\n    if not uses_charlm:\n        return None\n\n    if model_charlms is not None and lang in model_charlms and package in model_charlms[lang]:\n        return model_charlms[lang][package]\n    else:\n        return default_charlms.get(lang, None)\n\ndef get_con_dependencies(lang, package):\n    # so far, this invariant is true:\n    # constituency models use the default pretrain and charlm for the language\n    # sometimes there is no charlm for a language that has constituency, though\n    pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)\n    dependencies = [{'model': 'pretrain', 'package': pretrain_package}]\n\n    charlm_package = default_charlms.get(lang, None)\n    if charlm_package is not None:\n        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})\n\n    return dependencies\n\ndef get_pos_charlm_package(lang, package):\n    return get_charlm_package(lang, package, pos_charlms, default_charlms)\n\ndef get_pos_dependencies(lang, package):\n    dependencies = []\n\n    pretrain_package = 
get_pretrain_package(lang, package, pos_pretrains, default_pretrains)\n    if pretrain_package is not None:\n        dependencies.append({'model': 'pretrain', 'package': pretrain_package})\n\n    charlm_package = get_pos_charlm_package(lang, package)\n    if charlm_package is not None:\n        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})\n\n    return dependencies\n\ndef get_lemma_pretrain_package(lang, package):\n    package, uses_pretrain, uses_charlm = split_package(package)\n    if not uses_pretrain:\n        return None\n    if not uses_charlm:\n        # currently the contextual lemma classifier is only active\n        # for the charlm lemmatizers\n        return None\n    if \"%s_%s\" % (lang, package) not in LEMMA_CLASSIFIER_DATASETS:\n        return None\n    return get_pretrain_package(lang, package, {}, default_pretrains)\n\ndef get_lemma_charlm_package(lang, package):\n    return get_charlm_package(lang, package, lemma_charlms, default_charlms)\n\ndef get_lemma_dependencies(lang, package):\n    dependencies = []\n\n    pretrain_package = get_lemma_pretrain_package(lang, package)\n    if pretrain_package is not None:\n        dependencies.append({'model': 'pretrain', 'package': pretrain_package})\n\n    charlm_package = get_lemma_charlm_package(lang, package)\n    if charlm_package is not None:\n        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})\n\n    return dependencies\n\n\ndef get_tokenizer_charlm_package(lang, package):\n    return get_charlm_package(lang, package, tokenizer_charlms, default_charlms, default_use_charlm=False)\n\ndef get_tokenizer_dependencies(lang, package):\n    dependencies = []\n    charlm_package = get_tokenizer_charlm_package(lang, package)\n    if charlm_package is not None:\n        
dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n    return dependencies\n\ndef get_depparse_charlm_package(lang, package):\n    return get_charlm_package(lang, package, depparse_charlms, default_charlms)\n\ndef get_depparse_dependencies(lang, package):\n    dependencies = []\n\n    pretrain_package = get_pretrain_package(lang, package, depparse_pretrains, default_pretrains)\n    if pretrain_package is not None:\n        dependencies.append({'model': 'pretrain', 'package': pretrain_package})\n\n    charlm_package = get_depparse_charlm_package(lang, package)\n    if charlm_package is not None:\n        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})\n\n    return dependencies\n\ndef get_ner_charlm_package(lang, package):\n    return get_charlm_package(lang, package, ner_charlms, default_charlms)\n\ndef get_ner_pretrain_package(lang, package):\n    return get_pretrain_package(lang, package, ner_pretrains, default_pretrains)\n\ndef get_ner_dependencies(lang, package):\n    dependencies = []\n\n    pretrain_package = get_ner_pretrain_package(lang, package)\n    if pretrain_package is not None:\n        dependencies.append({'model': 'pretrain', 'package': pretrain_package})\n\n    charlm_package = get_ner_charlm_package(lang, package)\n    if charlm_package is not None:\n        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})\n\n    return dependencies\n\ndef get_sentiment_dependencies(lang, package):\n    \"\"\"\n    Return a list of dependencies for the sentiment model\n\n    Generally this will be pretrain, forward & backward charlm\n    So far, this invariant is true:\n    sentiment models use the default pretrain for the language\n    also, they all use the default charlm for a language\n    \"\"\"\n    
pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)\n    dependencies = [{'model': 'pretrain', 'package': pretrain_package}]\n\n    charlm_package = default_charlms.get(lang, None)\n    if charlm_package is not None:\n        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})\n        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})\n\n    return dependencies\n\ndef get_dependencies(processor, lang, package):\n    \"\"\"\n    Get the dependencies for a particular lang/package based on the package name\n\n    The package can include descriptors such as _nopretrain, _nocharlm, _charlm\n    which inform whether or not this particular model uses charlm or pretrain\n    \"\"\"\n    if processor == 'depparse':\n        return get_depparse_dependencies(lang, package)\n    elif processor == 'lemma':\n        return get_lemma_dependencies(lang, package)\n    elif processor == 'pos':\n        return get_pos_dependencies(lang, package)\n    elif processor == 'ner':\n        return get_ner_dependencies(lang, package)\n    elif processor == 'sentiment':\n        return get_sentiment_dependencies(lang, package)\n    elif processor == 'constituency':\n        return get_con_dependencies(lang, package)\n    elif processor == 'tokenize':\n        return get_tokenizer_dependencies(lang, package)\n    # any other processor has no dependencies; return a list like the branches above\n    return []\n\ndef process_dirs(args):\n    dirs = sorted(os.listdir(args.input_dir))\n    resources = {}\n    if args.lang:\n        resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))\n        # this one language gets overridden\n        # if this is not done, and we reuse the old resources,\n        # any models which were deleted will still be in the resources\n        for lang in args.lang.split(\",\"):\n            resources[lang] = {}\n\n    for model_dir in dirs:\n        print(f\"Processing models in {model_dir}\")\n        models = 
sorted(os.listdir(os.path.join(args.input_dir, model_dir)))\n        for model in tqdm(models):\n            if not model.endswith('.pt'): continue\n            # get processor\n            lang, package, processor = split_model_name(model)\n            if args.lang and lang not in args.lang.split(\",\"):\n                continue\n\n            # copy file\n            input_path = os.path.join(args.input_dir, model_dir, model)\n            output_path = os.path.join(args.output_dir, lang, \"models\", processor, package + '.pt')\n            copy_file(input_path, output_path)\n            # maintain md5\n            md5 = get_md5(output_path)\n            # maintain dependencies\n            dependencies = get_dependencies(processor, lang, package)\n            # maintain resources\n            if lang not in resources: resources[lang] = {}\n            if processor not in resources[lang]: resources[lang][processor] = {}\n            if dependencies:\n                resources[lang][processor][package] = {'md5': md5, 'dependencies': dependencies}\n            else:\n                resources[lang][processor][package] = {'md5': md5}\n    print(\"Processed initial model directories.  
Writing preliminary resources.json\")\n    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)\n\ndef get_default_pos_package(lang, ud_package):\n    charlm_package = get_pos_charlm_package(lang, ud_package)\n    if charlm_package is not None:\n        return ud_package + \"_charlm\"\n    if lang in no_pretrain_languages:\n        return ud_package + \"_nopretrain\"\n    return ud_package + \"_nocharlm\"\n\ndef get_default_depparse_package(lang, ud_package):\n    charlm_package = get_depparse_charlm_package(lang, ud_package)\n    if charlm_package is not None:\n        return ud_package + \"_charlm\"\n    if lang in no_pretrain_languages:\n        return ud_package + \"_nopretrain\"\n    return ud_package + \"_nocharlm\"\n\ndef process_default_zips(args):\n    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))\n    for lang in resources:\n        # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json\n        if lang == 'url':\n            continue\n        if 'alias' in resources[lang]:\n            continue\n        if all(k in (\"backward_charlm\", \"forward_charlm\", \"pretrain\", \"lang_name\") for k in resources[lang].keys()):\n            continue\n        if lang in allowed_empty_languages and lang not in default_treebanks:\n            continue\n        if lang not in default_treebanks:\n            raise AssertionError(f'{lang} not in default treebanks!!!')\n\n        if args.lang and lang not in args.lang.split(\",\"):\n            continue\n\n        print(f'Preparing default models for language {lang}')\n\n        models_needed = defaultdict(set)\n\n        packages = resources[lang][PACKAGES][\"default\"]\n        for processor, package in packages.items():\n            if processor == 'lemma' and package == 'identity':\n                continue\n            if processor == 'optional':\n                continue\n            
models_needed[processor].add(package)\n            dependencies = get_dependencies(processor, lang, package)\n            for dependency in dependencies:\n                models_needed[dependency['model']].add(dependency['package'])\n\n        model_files = []\n        for processor in PROCESSORS:\n            if processor in models_needed:\n                for package in sorted(models_needed[processor]):\n                    filename = os.path.join(args.output_dir, lang, \"models\", processor, package + '.pt')\n                    if os.path.exists(filename):\n                        print(\"   Model {} package {}: file {}\".format(processor, package, filename))\n                        model_files.append((filename, processor, package))\n                    else:\n                        raise FileNotFoundError(f\"Processor {processor} package {package} needed for {lang} but cannot be found at {filename}\")\n\n        with zipfile.ZipFile(os.path.join(args.output_dir, lang, 'models', 'default.zip'), 'w', zipfile.ZIP_DEFLATED) as zipf:\n            for filename, processor, package in model_files:\n                zipf.write(filename=filename, arcname=os.path.join(processor, package + '.pt'))\n\n        default_md5 = get_md5(os.path.join(args.output_dir, lang, 'models', 'default.zip'))\n        resources[lang]['default_md5'] = default_md5\n\n    print(\"Processed default model zips.  
Writing resources.json\")\n    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)\n\ndef get_default_processors(resources, lang):\n    \"\"\"\n    Build a default package for this language\n\n    Will add each of pos, lemma, depparse, etc if those are available\n    Uses the existing models scraped from the language directories into resources.json, as relevant\n    \"\"\"\n    if lang == \"multilingual\":\n        return {\"langid\": \"ud\"}\n\n    default_package = default_treebanks[lang]\n    default_processors = {}\n    if lang in default_tokenizer:\n        default_processors['tokenize'] = default_tokenizer[lang]\n    else:\n        tokenize_package = default_package\n        if tokenize_package not in resources[lang]['tokenize']:\n            tokenize_package = tokenize_package + \"_nocharlm\"\n        if tokenize_package not in resources[lang]['tokenize']:\n            raise AssertionError(\"Can't find a tokenizer package for %s!  Tried %s and %s\" % (lang, default_package, tokenize_package))\n        default_processors['tokenize'] = tokenize_package\n\n    if 'mwt' in resources[lang] and default_package in resources[lang]['mwt']:\n        # if this doesn't happen, we just skip MWT\n        default_processors['mwt'] = default_package\n\n    if 'lemma' in resources[lang]:\n        expected_lemma = default_package + \"_nocharlm\"\n        if expected_lemma in resources[lang]['lemma']:\n            default_processors['lemma'] = expected_lemma\n        else:\n            expected_lemma = default_package + \"_charlm\"\n            if expected_lemma in resources[lang]['lemma']:\n                default_processors['lemma'] = expected_lemma\n                print(\"WARNING: nocharlm lemmatizer for %s model does not exist, but %s does\" % (default_package, expected_lemma))\n    elif lang not in allowed_empty_languages:\n        default_processors['lemma'] = 'identity'\n\n    if 'pos' in resources[lang]:\n        
default_processors['pos'] = get_default_pos_package(lang, default_package)\n        if default_processors['pos'] not in resources[lang]['pos']:\n            raise AssertionError(\"Expected POS model not in resources: %s\" % default_processors['pos'])\n    elif lang not in allowed_empty_languages:\n        raise AssertionError(\"Expected to find POS models for language %s\" % lang)\n\n    if 'depparse' in resources[lang]:\n        default_processors['depparse'] = get_default_depparse_package(lang, default_package)\n        if default_processors['depparse'] not in resources[lang]['depparse']:\n            raise AssertionError(\"Expected depparse model not in resources: %s\" % default_processors['depparse'])\n    elif lang not in allowed_empty_languages:\n        raise AssertionError(\"Expected to find depparse models for language %s\" % lang)\n\n    if lang in default_ners:\n        default_processors['ner'] = default_ners[lang]\n\n    if lang in default_sentiment:\n        default_processors['sentiment'] = default_sentiment[lang]\n\n    if lang in default_constituency:\n        default_processors['constituency'] = default_constituency[lang]\n\n    optional = get_default_optional_processors(resources, lang)\n    if optional:\n        default_processors['optional'] = optional\n\n    return default_processors\n\ndef get_default_optional_processors(resources, lang):\n    optional_processors = {}\n    if lang in optional_constituency:\n        optional_processors['constituency'] = optional_constituency[lang]\n\n    if lang in optional_coref:\n        optional_processors['coref'] = optional_coref[lang]\n\n    return optional_processors\n\ndef update_processor_add_transformer(resources, lang, current_processors, processor, transformer):\n    if processor not in current_processors:\n        return\n\n    new_model = current_processors[processor].replace('_charlm', \"_\" + transformer).replace('_nocharlm', \"_\" + transformer)\n    if new_model in 
resources[lang][processor]:\n        current_processors[processor] = new_model\n    else:\n        print(\"WARNING: wanted to use %s for %s accurate %s, but that model does not exist\" % (new_model, lang, processor))\n\ndef get_default_accurate(resources, lang):\n    \"\"\"\n    A package that, if available, uses charlm and transformer models for each processor\n    \"\"\"\n    default_processors = get_default_processors(resources, lang)\n\n    tokenizer_model = default_processors['tokenize']\n    if tokenizer_model.endswith('_nocharlm'):\n        tokenizer_model = tokenizer_model.replace('_nocharlm', '_charlm')\n    elif 'charlm' not in tokenizer_model:\n        tokenizer_model = tokenizer_model + '_charlm'\n    if tokenizer_model.endswith('_charlm') and tokenizer_model in resources[lang]['tokenize']:\n        default_processors['tokenize'] = tokenizer_model\n        print(\"TOKENIZE found a charlm version %s for %s default_accurate\" % (tokenizer_model, lang))\n\n    if 'lemma' in default_processors and default_processors['lemma'] != 'identity':\n        lemma_model = default_processors['lemma']\n        lemma_model = lemma_model.replace('_nocharlm', '_charlm')\n        charlm_package = get_lemma_charlm_package(lang, lemma_model)\n        if charlm_package is not None:\n            if lemma_model in resources[lang]['lemma']:\n                default_processors['lemma'] = lemma_model\n            else:\n                print(\"WARNING: wanted to use %s for %s default_accurate lemma, but that model does not exist\" % (lemma_model, lang))\n\n    transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)\n    if transformer is not None:\n        for processor in ('pos', 'depparse', 'constituency', 'sentiment'):\n            update_processor_add_transformer(resources, lang, default_processors, processor, transformer)\n        if 'ner' in default_processors and (default_processors['ner'].endswith(\"_charlm\") or 
default_processors['ner'].endswith(\"_nocharlm\")):\n            update_processor_add_transformer(resources, lang, default_processors, \"ner\", transformer)\n\n    optional = get_optional_accurate(resources, lang)\n    if optional:\n        default_processors['optional'] = optional\n\n    return default_processors\n\ndef get_optional_accurate(resources, lang):\n    optional_processors = get_default_optional_processors(resources, lang)\n\n    transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)\n    if transformer is not None:\n        for processor in ('pos', 'depparse', 'constituency', 'sentiment'):\n            update_processor_add_transformer(resources, lang, optional_processors, processor, transformer)\n\n    if lang in optional_coref:\n        optional_processors['coref'] = optional_coref[lang]\n\n    return optional_processors\n\n\ndef get_default_fast(resources, lang):\n    \"\"\"\n    Build a packages entry which only has the nocharlm models\n\n    Will make it easy for people to use the lower tier of models\n\n    We do this by building the same default package as normal,\n    then switching everything out for the lower tier model when possible.\n    We also remove constituency, as it is super slow.\n    Note that in the case of a language which doesn't have a charlm,\n    that means we wind up building the same package for default and default_fast\n    \"\"\"\n    default_processors = get_default_processors(resources, lang)\n\n    # this is a slow model and we don't have non-charlm versions of it yet\n    if 'constituency' in default_processors:\n        default_processors.pop('constituency')\n\n    for processor, model in default_processors.items():\n        if \"_charlm\" in model:\n            nocharlm = model.replace(\"_charlm\", \"_nocharlm\")\n            if nocharlm not in resources[lang][processor]:\n                print(\"WARNING: wanted to use %s for %s default_fast processor %s, but that model does not exist\" % (nocharlm, 
lang, processor))\n            else:\n                default_processors[processor] = nocharlm\n\n    return default_processors\n\ndef process_packages(args):\n    \"\"\"\n    Build a package for a language's default processors and all of the treebanks specifically used for that language\n    \"\"\"\n    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))\n\n    for lang in resources:\n        # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json\n        if lang == 'url':\n            continue\n        if 'alias' in resources[lang]:\n            continue\n        if all(k in (\"backward_charlm\", \"forward_charlm\", \"pretrain\", \"lang_name\") for k in resources[lang].keys()):\n            continue\n        if lang in allowed_empty_languages and lang not in default_treebanks:\n            continue\n        if lang not in default_treebanks:\n            raise AssertionError(f'{lang} not in default treebanks!!!')\n\n        if args.lang and lang not in args.lang.split(\",\"):\n            continue\n\n        default_processors = get_default_processors(resources, lang)\n\n        # TODO: eventually we can remove default_processors\n        # For now, we want to keep this so that v1.5.1 is compatible\n        # with the next iteration of resources files\n        resources[lang]['default_processors'] = default_processors\n        resources[lang][PACKAGES] = {}\n        resources[lang][PACKAGES]['default'] = default_processors\n\n        if lang not in no_pretrain_languages and lang != \"multilingual\":\n            default_fast = get_default_fast(resources, lang)\n            resources[lang][PACKAGES]['default_fast'] = default_fast\n\n            default_accurate = get_default_accurate(resources, lang)\n            resources[lang][PACKAGES]['default_accurate'] = default_accurate\n\n        # Now we loop over each of the tokenizers for this language\n        # ... 
we use this as a proxy for the available UD treebanks\n        # This loop also catches things such as \"craft\" which are\n        # included treebanks that aren't UD\n        # We then create a package in the packages dict for each of those treebanks\n        if 'tokenize' in resources[lang]:\n            for package in resources[lang]['tokenize']:\n                package, _, _ = split_package(package)\n                if package in resources[lang][PACKAGES]:\n                    # can happen in the case of a _nocharlm and _charlm version of the tokenizer\n                    continue\n\n                processors = {}\n                # TODO: when we rebuild all the models, make all the tokenizers say _nocharlm\n                if package in resources[lang]['tokenize']:\n                    processors[\"tokenize\"] = package\n                elif package + \"_nocharlm\" in resources[lang]['tokenize']:\n                    processors[\"tokenize\"] = package + \"_nocharlm\"\n                else:\n                    raise AssertionError(\"Should have found a tokenizer for lang %s package %s\" % (lang, package))\n\n                if \"mwt\" in resources[lang] and package in resources[lang][\"mwt\"]:\n                    processors[\"mwt\"] = package\n\n                if \"pos\" in resources[lang]:\n                    if package + \"_charlm\" in resources[lang][\"pos\"]:\n                        processors[\"pos\"] = package + \"_charlm\"\n                    elif package + \"_nocharlm\" in resources[lang][\"pos\"]:\n                        processors[\"pos\"] = package + \"_nocharlm\"\n\n                if \"lemma\" in resources[lang] and \"pos\" in processors:\n                    lemma_package = package + \"_nocharlm\"\n                    if lemma_package in resources[lang][\"lemma\"]:\n                        processors[\"lemma\"] = lemma_package\n                    else:\n                        lemma_package = package + \"_charlm\"\n                     
   if lemma_package in resources[lang]['lemma']:\n                            processors['lemma'] = lemma_package\n                            print(\"WARNING: nocharlm lemmatizer for %s model does not exist, but %s does\" % (package, lemma_package))\n\n                if \"depparse\" in resources[lang] and \"pos\" in processors:\n                    depparse_package = None\n                    if package + \"_charlm\" in resources[lang][\"depparse\"]:\n                        depparse_package = package + \"_charlm\"\n                    elif package + \"_nocharlm\" in resources[lang][\"depparse\"]:\n                        depparse_package = package + \"_nocharlm\"\n                    # we want to set the lemma first if it's identity\n                    # THEN set the depparse\n                    if depparse_package is not None:\n                        if \"lemma\" not in processors:\n                            processors[\"lemma\"] = \"identity\"\n                        processors[\"depparse\"] = depparse_package\n\n                resources[lang][PACKAGES][package] = processors\n\n    print(\"Processed packages.  
Writing resources.json\")\n    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)\n\ndef process_lcode(args):\n    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))\n    resources_new = {}\n    resources_new[\"multilingual\"] = resources[\"multilingual\"]\n    for lang in resources:\n        if lang == 'multilingual':\n            continue\n        if 'alias' in resources[lang]:\n            continue\n        if lang not in lcode2lang:\n            print(lang + ' not found in lcode2lang!')\n            continue\n        lang_name = lcode2lang[lang]\n        resources[lang]['lang_name'] = lang_name\n        resources_new[lang.lower()] = resources[lang.lower()]\n        resources_new[lang_name.lower()] = {'alias': lang.lower()}\n        if lang.lower() in two_to_three_letters:\n            resources_new[two_to_three_letters[lang.lower()]] = {'alias': lang.lower()}\n        elif lang.lower() in three_to_two_letters:\n            resources_new[three_to_two_letters[lang.lower()]] = {'alias': lang.lower()}\n        if lang.lower() in extra_lang_to_lcodes:\n            alternative = extra_lang_to_lcodes[lang.lower()].lower()\n            if alternative not in resources_new:\n                resources_new[alternative] = {'alias': lang.lower()}\n    print(\"Processed lcode aliases.  Writing resources.json\")\n    json.dump(resources_new, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)\n\n\ndef process_misc(args):\n    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))\n    resources['no'] = {'alias': 'nb'}\n    resources['zh'] = {'alias': 'zh-hans'}\n    # This is intended to be unformatted.  
expand_model_url in common.py will fill in the raw string\n    # with the appropriate values in order to find the needed model file on huggingface\n    resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}'\n    print(\"Finalized misc attributes.  Writing resources.json\")\n    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)\n\n\ndef main():\n    args = parse_args()\n    print(\"Converting models from %s to %s\" % (args.input_dir, args.output_dir))\n    if not args.packages_only:\n        process_dirs(args)\n    process_packages(args)\n    if not args.packages_only:\n        process_default_zips(args)\n        process_lcode(args)\n        process_misc(args)\n\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/resources/print_charlm_depparse.py",
    "content": "\"\"\"\nA small utility script to output which depparse models use charlm\n\n(It should skip en_genia, en_craft, but currently doesn't)\n\nNot frequently useful, but seems like the kind of thing that might get used a couple times\n\"\"\"\n\nfrom stanza.resources.common import load_resources_json\nfrom stanza.resources.default_packages import default_charlms, depparse_charlms\n\ndef list_depparse():\n    charlm_langs = list(default_charlms.keys())\n    resources = load_resources_json()\n\n    models = [\"%s_%s\" % (lang, model) for lang in charlm_langs for model in resources[lang].get(\"depparse\", {})\n              if lang not in depparse_charlms or model not in depparse_charlms[lang] or depparse_charlms[lang][model] is not None]\n    return models\n\nif __name__ == \"__main__\":\n    models = list_depparse()\n    print(\" \".join(models))\n"
  },
  {
    "path": "stanza/server/__init__.py",
    "content": "from stanza.protobuf import to_text\nfrom stanza.protobuf import Document, Sentence, Token, IndexedWord, Span\nfrom stanza.protobuf import ParseTree, DependencyGraph, CorefChain\nfrom stanza.protobuf import Mention, NERMention, Entity, Relation, RelationTriple, Timex\nfrom stanza.protobuf import Quote, SpeakerInfo\nfrom stanza.protobuf import Operator, Polarity\nfrom stanza.protobuf import SentenceFragment, TokenLocation\nfrom stanza.protobuf import MapStringString, MapIntString\nfrom .client import CoreNLPClient, AnnotationException, TimeoutException, PermanentlyFailedException, StartServer\nfrom .annotator import Annotator\n"
  },
  {
    "path": "stanza/server/annotator.py",
    "content": "\"\"\"\nDefines a base class that can be used to annotate.\n\"\"\"\nimport io\nfrom multiprocessing import Process\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nfrom http import client as HTTPStatus\n\nfrom stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString\n\nclass Annotator(Process):\n    \"\"\"\n    This annotator base class hosts a lightweight server that accepts\n    annotation requests from CoreNLP.\n    Each annotator simply defines 3 functions: requires, provides and annotate.\n\n    This class takes care of defining appropriate endpoints to interface\n    with CoreNLP.\n    \"\"\"\n    @property\n    def name(self):\n        \"\"\"\n        Name of the annotator (used by CoreNLP)\n        \"\"\"\n        raise NotImplementedError()\n\n    @property\n    def requires(self):\n        \"\"\"\n        Requires has to specify all the annotations required before we\n        are called.\n        \"\"\"\n        raise NotImplementedError()\n\n    @property\n    def provides(self):\n        \"\"\"\n        The set of annotations guaranteed to be provided when we are done.\n        NOTE: that these annotations are either fully qualified Java\n        class names or refer to nested classes of\n        edu.stanford.nlp.ling.CoreAnnotations (as is the case below).\n        \"\"\"\n        raise NotImplementedError()\n\n    def annotate(self, ann):\n        \"\"\"\n        @ann: is a protobuf annotation object.\n        Actually populate @ann with tokens.\n        \"\"\"\n        raise NotImplementedError()\n\n    @property\n    def properties(self):\n        \"\"\"\n        Defines a Java property to define this annotator to CoreNLP.\n        \"\"\"\n        return {\n            \"customAnnotatorClass.{}\".format(self.name): \"edu.stanford.nlp.pipeline.GenericWebServiceAnnotator\",\n            \"generic.endpoint\": \"http://{}:{}\".format(self.host, self.port),\n            \"generic.requires\": 
\",\".join(self.requires),\n            \"generic.provides\": \",\".join(self.provides),\n            }\n\n    class _Handler(BaseHTTPRequestHandler):\n        annotator = None\n\n        def __init__(self, request, client_address, server):\n            BaseHTTPRequestHandler.__init__(self, request, client_address, server)\n\n        def do_GET(self):\n            \"\"\"\n            Handle a ping request\n            \"\"\"\n            if not self.path.endswith(\"/\"): self.path += \"/\"\n            if self.path == \"/ping/\":\n                msg = \"pong\".encode(\"UTF-8\")\n\n                self.send_response(HTTPStatus.OK)\n                self.send_header(\"Content-Type\", \"text/application\")\n                self.send_header(\"Content-Length\", len(msg))\n                self.end_headers()\n                self.wfile.write(msg)\n            else:\n                self.send_response(HTTPStatus.BAD_REQUEST)\n                self.end_headers()\n\n        def do_POST(self):\n            \"\"\"\n            Handle an annotate request\n            \"\"\"\n            if not self.path.endswith(\"/\"): self.path += \"/\"\n            if self.path == \"/annotate/\":\n                # Read message\n                length = int(self.headers.get('content-length'))\n                msg = self.rfile.read(length)\n\n                # Do the annotation\n                doc = Document()\n                parseFromDelimitedString(doc, msg)\n                self.annotator.annotate(doc)\n\n                with io.BytesIO() as stream:\n                    writeToDelimitedString(doc, stream)\n                    msg = stream.getvalue()\n\n                # write message\n                self.send_response(HTTPStatus.OK)\n                self.send_header(\"Content-Type\", \"application/x-protobuf\")\n                self.send_header(\"Content-Length\", len(msg))\n                self.end_headers()\n                self.wfile.write(msg)\n\n            else:\n                
self.send_response(HTTPStatus.BAD_REQUEST)\n                self.end_headers()\n\n    def __init__(self, host=\"\", port=8432):\n        \"\"\"\n        Launches a server endpoint to communicate with CoreNLP\n        \"\"\"\n        Process.__init__(self)\n        self.host, self.port = host, port\n        self._Handler.annotator = self\n\n    def run(self):\n        \"\"\"\n        Runs the server using Python's simple HTTPServer.\n        TODO: make this multithreaded.\n        \"\"\"\n        httpd = HTTPServer((self.host, self.port), self._Handler)\n        sa = httpd.socket.getsockname()\n        serve_message = \"Serving HTTP on {host} port {port} (http://{host}:{port}/) ...\"\n        print(serve_message.format(host=sa[0], port=sa[1]))\n        try:\n            httpd.serve_forever()\n        except KeyboardInterrupt:\n            print(\"\\nKeyboard interrupt received, exiting.\")\n            httpd.shutdown()\n"
  },
  {
    "path": "stanza/server/client.py",
    "content": "\"\"\"\nClient for accessing Stanford CoreNLP in Python\n\"\"\"\n\nimport atexit\nimport contextlib\nimport enum\nimport io\nimport os\nimport re\nimport requests\nimport logging\nimport json\nimport shlex\nimport socket\nimport subprocess\nimport time\nimport sys\nimport uuid\n\nfrom datetime import datetime\nfrom pathlib import Path\nfrom urllib.parse import urlparse\n\nfrom stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString, to_text\n__author__ = 'arunchaganty, kelvinguu, vzhong, wmonroe4'\n\nlogger = logging.getLogger('stanza')\n\n# pattern tmp props file should follow\nSERVER_PROPS_TMP_FILE_PATTERN = re.compile('corenlp_server-(.*).props')\n\n# Check if str is CoreNLP supported language\nCORENLP_LANGS = ['ar', 'arabic', 'chinese', 'zh', 'english', 'en', 'french', 'fr', 'de', 'german', 'hu', 'hungarian',\n                 'it', 'italian', 'es', 'spanish']\n\n# map shorthands to full language names\nLANGUAGE_SHORTHANDS_TO_FULL = {\n    \"ar\": \"arabic\",\n    \"zh\": \"chinese\",\n    \"en\": \"english\",\n    \"fr\": \"french\",\n    \"de\": \"german\",\n    \"hu\": \"hungarian\",\n    \"it\": \"italian\",\n    \"es\": \"spanish\"\n}\n\n\ndef is_corenlp_lang(props_str):\n    \"\"\" Check if a string references a CoreNLP language \"\"\"\n    return props_str.lower() in CORENLP_LANGS\n\n\n# Validate CoreNLP properties\nCORENLP_OUTPUT_VALS = [\"conll\", \"conllu\", \"json\", \"serialized\", \"text\", \"xml\", \"inlinexml\"]\n\n\ndef validate_corenlp_props(properties=None, annotators=None, output_format=None):\n    \"\"\" Do basic checks to validate CoreNLP properties \"\"\"\n    if output_format and output_format.lower() not in CORENLP_OUTPUT_VALS:\n        raise ValueError(f\"{output_format} not a valid CoreNLP outputFormat value! 
Choose from: {CORENLP_OUTPUT_VALS}\")\n    if type(properties) == dict:\n        if \"outputFormat\" in properties and properties[\"outputFormat\"].lower() not in CORENLP_OUTPUT_VALS:\n            raise ValueError(f\"{properties['outputFormat']} not a valid CoreNLP outputFormat value! Choose from: \"\n                             f\"{CORENLP_OUTPUT_VALS}\")\n\n\nclass AnnotationException(Exception):\n    \"\"\" Exception raised when there was an error communicating with the CoreNLP server. \"\"\"\n    pass\n\n\nclass TimeoutException(AnnotationException):\n    \"\"\" Exception raised when the CoreNLP server timed out. \"\"\"\n    pass\n\n\nclass ShouldRetryException(Exception):\n    \"\"\" Exception raised if the service should retry the request. \"\"\"\n    pass\n\n\nclass PermanentlyFailedException(Exception):\n    \"\"\" Exception raised if the service should NOT retry the request. \"\"\"\n    pass\n\nclass StartServer(enum.Enum):\n    DONT_START = 0\n    FORCE_START = 1\n    TRY_START = 2\n\n\ndef clean_props_file(props_file):\n    # check if there is a temp server props file to remove and remove it\n    if props_file:\n        if os.path.isfile(props_file) and SERVER_PROPS_TMP_FILE_PATTERN.match(os.path.basename(props_file)):\n            os.remove(props_file)\n\n\nclass RobustService(object):\n    \"\"\" Service that resuscitates itself if it is not available. 
\"\"\"\n    CHECK_ALIVE_TIMEOUT = 120\n\n    def __init__(self, start_cmd, stop_cmd, endpoint, stdout=None,\n                 stderr=None, be_quiet=False, host=None, port=None, ignore_binding_error=False):\n        self.start_cmd = start_cmd and shlex.split(start_cmd)\n        self.stop_cmd = stop_cmd and shlex.split(stop_cmd)\n        self.endpoint = endpoint\n        self.stdout = stdout\n        self.stderr = stderr\n\n        self.server = None\n        self.is_active = False\n        self.be_quiet = be_quiet\n        self.host = host\n        self.port = port\n        self.ignore_binding_error = ignore_binding_error\n        atexit.register(self.atexit_kill)\n\n    def is_alive(self):\n        try:\n            if not self.ignore_binding_error and self.server is not None and self.server.poll() is not None:\n                return False\n            return requests.get(self.endpoint + \"/ping\").ok\n        except requests.exceptions.ConnectionError as e:\n            raise ShouldRetryException(e)\n\n    def start(self):\n        if self.start_cmd:\n            if self.host and self.port:\n                with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:\n                    try:\n                        sock.bind((self.host, self.port))\n                    except socket.error as e:\n                        if self.ignore_binding_error:\n                            logger.info(f\"Connecting to existing CoreNLP server at {self.host}:{self.port}\")\n                            self.server = None\n                            return\n                        else:\n                            raise PermanentlyFailedException(\"Error: unable to start the CoreNLP server on port %d \"\n                                                             \"(possibly something is already running there)\" % self.port) from e\n            if self.be_quiet:\n                # Issue #26: subprocess.DEVNULL isn't supported in python 2.7.\n           
     if hasattr(subprocess, 'DEVNULL'):\n                    stderr = subprocess.DEVNULL\n                else:\n                    stderr = open(os.devnull, 'w')\n                stdout = stderr\n            else:\n                stdout = self.stdout\n                stderr = self.stderr\n            logger.info(f\"Starting server with command: {' '.join(self.start_cmd)}\")\n            try:\n                self.server = subprocess.Popen(self.start_cmd,\n                                               stderr=stderr,\n                                               stdout=stdout)\n            except FileNotFoundError as e:\n                raise FileNotFoundError(\"When trying to run CoreNLP, a FileNotFoundError occurred, which frequently means Java was not installed or was not in the classpath.\") from e\n\n    def atexit_kill(self):\n        # make some kind of effort to stop the service (such as a\n        # CoreNLP server) at the end of the program.  not waiting so\n        # that the python script exiting isn't delayed\n        if self.server and self.server.poll() is None:\n            self.server.terminate()\n\n    def stop(self):\n        if self.server:\n            self.server.terminate()\n            try:\n                self.server.wait(5)\n            except subprocess.TimeoutExpired:\n                # Resorting to more aggressive measures...\n                self.server.kill()\n                try:\n                    self.server.wait(5)\n                except subprocess.TimeoutExpired:\n                    # oh well\n                    pass\n            self.server = None\n        if self.stop_cmd:\n            subprocess.run(self.stop_cmd, check=True)\n        self.is_active = False\n\n    def __enter__(self):\n        self.start()\n        return self\n\n    def __exit__(self, _, __, ___):\n        self.stop()\n\n    def ensure_alive(self):\n        # Check if the service is active and alive\n        if self.is_active:\n            try:\n     
           if self.is_alive():\n                    return\n                else:\n                    self.stop()\n            except ShouldRetryException:\n                pass\n\n        # If not, try to start up the service.\n        if self.server is None:\n            self.start()\n\n        # Wait for the service to start up.\n        start_time = time.time()\n        while True:\n            try:\n                if self.is_alive():\n                    break\n            except ShouldRetryException:\n                pass\n\n            if time.time() - start_time < self.CHECK_ALIVE_TIMEOUT:\n                time.sleep(1)\n            else:\n                raise PermanentlyFailedException(\"Timed out waiting for service to come alive.\")\n\n        # At this point we are guaranteed that the service is alive.\n        self.is_active = True\n\n\ndef resolve_classpath(classpath=None):\n    \"\"\"\n    Returns the classpath to use for corenlp.\n\n    Prefers to use the given classpath parameter, if available.  If\n    not, uses the CORENLP_HOME environment variable.  Resolves $CLASSPATH\n    (the exact string) in either the classpath parameter or $CORENLP_HOME.\n    \"\"\"\n    if classpath == '$CLASSPATH' or (classpath is None and os.getenv(\"CORENLP_HOME\", None) == '$CLASSPATH'):\n        classpath = os.getenv(\"CLASSPATH\")\n    elif classpath is None:\n        classpath = os.getenv(\"CORENLP_HOME\", os.path.join(str(Path.home()), 'stanza_corenlp'))\n\n        if not os.path.exists(classpath):\n            raise FileNotFoundError(\"Please install CoreNLP by running `stanza.install_corenlp()`. If you have installed it, please define \"\n                                    \"$CORENLP_HOME to be location of your CoreNLP distribution or pass in a classpath parameter.  
\"\n                                    \"$CORENLP_HOME={}\".format(os.getenv(\"CORENLP_HOME\")))\n        classpath = os.path.join(classpath, \"*\")\n    return classpath\n\n\nclass CoreNLPClient(RobustService):\n    \"\"\" A client to the Stanford CoreNLP server. \"\"\"\n\n    DEFAULT_ENDPOINT = \"http://localhost:9000\"\n    DEFAULT_TIMEOUT = 60000\n    DEFAULT_THREADS = 5\n    DEFAULT_OUTPUT_FORMAT = \"serialized\"\n    DEFAULT_MEMORY = \"5G\"\n    DEFAULT_MAX_CHAR_LENGTH = 100000\n\n    def __init__(self, start_server=StartServer.FORCE_START,\n                 endpoint=DEFAULT_ENDPOINT,\n                 timeout=DEFAULT_TIMEOUT,\n                 threads=DEFAULT_THREADS,\n                 annotators=None,\n                 pretokenized=False,\n                 output_format=None,\n                 properties=None,\n                 stdout=None,\n                 stderr=None,\n                 memory=DEFAULT_MEMORY,\n                 be_quiet=False,\n                 max_char_length=DEFAULT_MAX_CHAR_LENGTH,\n                 preload=True,\n                 classpath=None,\n                 **kwargs):\n\n        # whether or not server should be started by client\n        self.start_server = start_server\n        self.server_props_path = None\n        self.server_start_time = None\n        self.server_host = None\n        self.server_port = None\n        self.server_classpath = None\n        # validate properties\n        validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)\n        # set up client defaults\n        self.properties = properties\n        self.annotators = annotators\n        self.pretokenized = pretokenized\n        self.output_format = output_format\n        self._setup_client_defaults()\n        # start the server\n        if isinstance(start_server, bool):\n            warning_msg = f\"Setting 'start_server' to a boolean value when constructing {self.__class__.__name__} is deprecated and will stop\" 
+ \\\n                \" to function in a future version of stanza. Please consider switching to using a value from stanza.server.StartServer.\"\n            logger.warning(warning_msg)\n            start_server = StartServer.FORCE_START if start_server is True else StartServer.DONT_START\n\n        # start the server\n        if start_server is StartServer.FORCE_START or start_server is StartServer.TRY_START:\n            # record info for server start\n            self.server_start_time = datetime.now()\n            # set up default properties for server\n            self._setup_server_defaults()\n            host, port = urlparse(endpoint).netloc.split(\":\")\n            port = int(port)\n            assert host == \"localhost\", \"If starting a server, endpoint must be localhost\"\n            classpath = resolve_classpath(classpath)\n            start_cmd = f\"java -Xmx{memory} -cp '{classpath}'  edu.stanford.nlp.pipeline.StanfordCoreNLPServer \" \\\n                        f\"-port {port} -timeout {timeout} -threads {threads} -maxCharLength {max_char_length} \" \\\n                        f\"-quiet {be_quiet} \"\n\n            self.server_classpath = classpath\n            self.server_host = host\n            self.server_port = port\n\n            # set up server defaults\n            if self.server_props_path is not None:\n                start_cmd += f\" -serverProperties {self.server_props_path}\"\n\n            # possibly set pretokenized\n            if self.pretokenized:\n                start_cmd += f\" -preTokenized\"\n\n            # set annotators for server default\n            if self.annotators is not None:\n                annotators_str = self.annotators if type(annotators) == str else \",\".join(annotators)\n                start_cmd += f\" -annotators {annotators_str}\"\n\n            # specify what to preload, if anything\n            if preload:\n                if type(preload) == bool:\n                    # -preload flag means to 
preload all default annotators\n                    start_cmd += \" -preload\"\n                elif type(preload) == list:\n                    # turn list into comma separated list string, only preload these annotators\n                    start_cmd += f\" -preload {','.join(preload)}\"\n                elif type(preload) == str:\n                    # comma separated list of annotators\n                    start_cmd += f\" -preload {preload}\"\n\n            # set outputFormat for server default\n            # if no output format requested by user, set to serialized\n            start_cmd += f\" -outputFormat {self.output_format}\"\n\n            # additional options for server:\n            # - server_id\n            # - ssl\n            # - status_port\n            # - uriContext\n            # - strict\n            # - key\n            # - username\n            # - password\n            # - blockList\n            for kw in ['ssl', 'strict']:\n                if kwargs.get(kw) is not None:\n                    start_cmd += f\" -{kw}\"\n            for kw in ['status_port', 'uriContext', 'key', 'username', 'password', 'blockList', 'server_id']:\n                if kwargs.get(kw) is not None:\n                    start_cmd += f\" -{kw} {kwargs.get(kw)}\"\n            stop_cmd = None\n        else:\n            start_cmd = stop_cmd = None\n            host = port = None\n\n        super(CoreNLPClient, self).__init__(start_cmd, stop_cmd, endpoint,\n                                            stdout, stderr, be_quiet, host=host, port=port, ignore_binding_error=(start_server == StartServer.TRY_START))\n\n        self.timeout = timeout\n\n    def _setup_client_defaults(self):\n        \"\"\"\n        Do some processing of annotators and output_format specified for the client.\n        If interacting with an externally started server, these will be defaults for annotate() calls.\n        :return: None\n        \"\"\"\n        # normalize annotators to str\n        if 
self.annotators is not None:\n            self.annotators = self.annotators if type(self.annotators) == str else \",\".join(self.annotators)\n\n        # handle case where no output format is specified\n        if self.output_format is None:\n            if type(self.properties) == dict and 'outputFormat' in self.properties:\n                self.output_format = self.properties['outputFormat']\n            else:\n                self.output_format = CoreNLPClient.DEFAULT_OUTPUT_FORMAT\n\n    def _setup_server_defaults(self):\n        \"\"\"\n        Set up the default properties for the server.\n\n        The properties argument can take on one of 3 value types:\n\n        1. File path on system or in CLASSPATH (e.g. /path/to/server.props or StanfordCoreNLP-french.properties)\n        2. Name of a Stanford CoreNLP supported language (e.g. french or fr)\n        3. Python dictionary (properties written to tmp file for Java server, erased at end)\n\n        In addition, an annotators list and output_format can be specified directly with arguments. These\n        will overwrite any settings in the specified properties.\n\n        If no properties are specified, the standard Stanford CoreNLP English server will be launched. 
The outputFormat\n        will be set to 'serialized' and use the ProtobufAnnotationSerializer.\n        \"\"\"\n\n        # ensure properties is str or dict\n        if self.properties is None or (not isinstance(self.properties, str) and not isinstance(self.properties, dict)):\n            if self.properties is not None:\n                logger.warning('properties passed invalid value (not a str or dict), setting properties = {}')\n            self.properties = {}\n        # check if properties is a string, pass on to server which can handle\n        if isinstance(self.properties, str):\n            # try to translate to Stanford CoreNLP language name, or assume properties is a path\n            if is_corenlp_lang(self.properties):\n                if self.properties.lower() in LANGUAGE_SHORTHANDS_TO_FULL:\n                    self.properties = LANGUAGE_SHORTHANDS_TO_FULL[self.properties]\n                logger.info(\n                    f\"Using CoreNLP default properties for: {self.properties}.  Make sure to have \"\n                    f\"{self.properties} models jar (available for download here: \"\n                    f\"https://stanfordnlp.github.io/CoreNLP/) in CLASSPATH\")\n            else:\n                if not os.path.isfile(self.properties):\n                    logger.warning(f\"{self.properties} does not correspond to a file path. 
Make sure this file is in \"\n                                   f\"your CLASSPATH.\")\n            self.server_props_path = self.properties\n        elif isinstance(self.properties, dict):\n            # make a copy\n            server_start_properties = dict(self.properties)\n            if self.annotators is not None:\n                server_start_properties['annotators'] = self.annotators\n            if self.output_format is not None and isinstance(self.output_format, str):\n                server_start_properties['outputFormat'] = self.output_format\n            # write desired server start properties to tmp file\n            # set up to erase on exit\n            tmp_path = write_corenlp_props(server_start_properties)\n            logger.info(f\"Writing properties to tmp file: {tmp_path}\")\n            atexit.register(clean_props_file, tmp_path)\n            self.server_props_path = tmp_path\n\n    def _request(self, buf, properties, reset_default=False, **kwargs):\n        \"\"\"\n        Send a request to the CoreNLP server.\n\n        :param (str | bytes) buf: data to be sent with the request\n        :param (dict) properties: properties that the server expects\n        :return: request result\n        \"\"\"\n        if self.start_server is not StartServer.DONT_START:\n            self.ensure_alive()\n\n        try:\n            input_format = properties.get(\"inputFormat\", \"text\")\n            if input_format == \"text\":\n                ctype = \"text/plain; charset=utf-8\"\n            elif input_format == \"serialized\":\n                ctype = \"application/x-protobuf\"\n            else:\n                raise ValueError(\"Unrecognized inputFormat \" + input_format)\n            # handle auth\n            if 'username' in kwargs and 'password' in kwargs:\n                kwargs['auth'] = requests.auth.HTTPBasicAuth(kwargs['username'], kwargs['password'])\n                kwargs.pop('username')\n                kwargs.pop('password')\n         
   r = requests.post(self.endpoint,\n                              params={'properties': str(properties), 'resetDefault': str(reset_default).lower()},\n                              data=buf, headers={'content-type': ctype},\n                              timeout=(self.timeout*2)/1000, **kwargs)\n            r.raise_for_status()\n            return r\n        except requests.exceptions.Timeout as e:\n            raise TimeoutException(\"Timeout requesting to CoreNLPServer. Maybe server is unavailable or your document is too long\")\n        except requests.exceptions.RequestException as e:\n            if e.response is not None and e.response.text is not None:\n                raise AnnotationException(e.response.text) from e\n            elif e.args:\n                raise AnnotationException(e.args[0]) from e\n            raise AnnotationException() from e\n\n    def annotate(self, text, annotators=None, output_format=None, properties=None, reset_default=None, **kwargs):\n        \"\"\"\n        Send a request to the CoreNLP server.\n\n        :param (str | unicode) text: raw text for the CoreNLPServer to parse\n        :param (list | string) annotators: list of annotators to use\n        :param (str) output_format: output type from server: serialized, json, text, conll, conllu, or xml\n        :param (dict) properties: additional request properties (written on top of defaults)\n        :param (bool) reset_default: don't use server defaults\n\n        Precedence for settings:\n\n        1. annotators and output_format args\n        2. Values from properties dict\n        3. Client defaults self.annotators and self.output_format (set during client construction)\n        4. 
Server defaults\n\n        Additional request parameters (apart from CoreNLP pipeline properties) such as 'username' and 'password'\n        can be specified with the kwargs.\n\n        :return: request result\n        \"\"\"\n\n        # validate request properties\n        validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)\n        # set request properties\n        request_properties = {}\n\n        # start with client defaults\n        if self.annotators is not None:\n            request_properties['annotators'] = self.annotators\n        if self.output_format is not None:\n            request_properties['outputFormat'] = self.output_format\n\n        # add values from properties arg\n        # handle str case\n        if type(properties) == str:\n            if is_corenlp_lang(properties):\n                properties = {'pipelineLanguage': properties.lower()}\n                if reset_default is None:\n                    reset_default = True\n            else:\n                raise ValueError(f\"Unrecognized properties keyword {properties}\")\n\n        if type(properties) == dict:\n            request_properties.update(properties)\n\n        # if annotators list is specified, override with that\n        # also can use the annotators field the object was created with\n        if annotators is not None and (type(annotators) == str or type(annotators) == list):\n            request_properties['annotators'] = annotators if type(annotators) == str else \",\".join(annotators)\n\n        # if output format is specified, override with that\n        if output_format is not None and type(output_format) == str:\n            request_properties['outputFormat'] = output_format\n\n        # make the request\n        # if not explicitly set or the case of pipelineLanguage, reset_default should be None\n        if reset_default is None:\n            reset_default = False\n        r = self._request(text.encode('utf-8'), 
request_properties, reset_default, **kwargs)\n        if request_properties[\"outputFormat\"] == \"json\":\n            return r.json()\n        elif request_properties[\"outputFormat\"] == \"serialized\":\n            doc = Document()\n            parseFromDelimitedString(doc, r.content)\n            return doc\n        elif request_properties[\"outputFormat\"] in [\"text\", \"conllu\", \"conll\", \"xml\"]:\n            return r.text\n        else:\n            return r\n\n    def update(self, doc, annotators=None, properties=None):\n        if properties is None:\n            properties = {}\n            properties.update({\n                'inputFormat': 'serialized',\n                'outputFormat': 'serialized',\n                'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'\n            })\n        if annotators:\n            properties['annotators'] = annotators if type(annotators) == str else \",\".join(annotators)\n        with io.BytesIO() as stream:\n            writeToDelimitedString(doc, stream)\n            msg = stream.getvalue()\n\n        r = self._request(msg, properties)\n        doc = Document()\n        parseFromDelimitedString(doc, r.content)\n        return doc\n\n    def tokensregex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):\n        # this is required for some reason\n        matches = self.__regex('/tokensregex', text, pattern, filter, annotators, properties)\n        if to_words:\n            matches = regex_matches_to_indexed_words(matches)\n        return matches\n\n    def semgrex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):\n        matches = self.__regex('/semgrex', text, pattern, filter, annotators, properties)\n        if to_words:\n            matches = regex_matches_to_indexed_words(matches)\n        return matches\n\n    def fill_tree_proto(self, tree, proto_tree):\n        if tree.label:\n            proto_tree.value = 
tree.label\n        for child in tree.children:\n            proto_child = proto_tree.child.add()\n            self.fill_tree_proto(child, proto_child)\n\n    def tregex(self, text=None, pattern=None, filter=False, annotators=None, properties=None, trees=None):\n        # parse is not included by default in some of the pipelines,\n        # so we may need to manually override the annotators\n        # to include parse in order for tregex to do anything\n        if annotators is None and self.annotators is not None:\n            assert isinstance(self.annotators, str)\n            pieces = self.annotators.split(\",\")\n            if \"parse\" not in pieces:\n                annotators = self.annotators + \",parse\"\n        else:\n            annotators = \"tokenize,ssplit,pos,parse\"\n        if pattern is None:\n            raise ValueError(\"Cannot have None as a pattern for tregex\")\n\n        # TODO: we could also allow for passing in a complete document,\n        # along with the original text, so that the spans returns are more accurate\n        if trees is not None:\n            if properties is None:\n                properties = {}\n            properties['inputFormat'] = 'serialized'\n            if 'serializer' not in properties:\n                properties['serializer'] = 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'\n            doc = Document()\n            full_text = []\n            for tree_idx, tree in enumerate(trees):\n                sentence = doc.sentence.add()\n                sentence.sentenceIndex = tree_idx\n                sentence.tokenOffsetBegin = len(full_text)\n                leaves = tree.leaf_labels()\n                full_text.extend(leaves)\n                sentence.tokenOffsetEnd = len(full_text)\n                self.fill_tree_proto(tree, sentence.parseTree)\n                for word in leaves:\n                    token = sentence.token.add()\n                    # the other side uses both value and word, 
weirdly enough\n                    token.value = word\n                    token.word = word\n                    # without the actual tokenization, at least we can\n                    # stop the words from running together\n                    token.after = \" \"\n            doc.text = \" \".join(full_text)\n            with io.BytesIO() as stream:\n                writeToDelimitedString(doc, stream)\n                text = stream.getvalue()\n\n        return self.__regex('/tregex', text, pattern, filter, annotators, properties)\n\n    def __regex(self, path, text, pattern, filter, annotators=None, properties=None):\n        \"\"\"\n        Send a regex-related request to the CoreNLP server.\n\n        :param (str | unicode) path: the path for the regex endpoint\n        :param (str | bytes) text: raw text (or an already serialized document) for the CoreNLPServer to apply the regex to\n        :param (str | unicode) pattern: regex pattern\n        :param (bool) filter: option to filter sentences that contain matches, if false returns matches\n        :param (list | str) annotators: annotators to run before applying the regex\n        :param (dict) properties: additional request properties\n        :return: request result\n        \"\"\"\n        if self.start_server is not StartServer.DONT_START:\n            self.ensure_alive()\n        if properties is None:\n            properties = {}\n            properties.update({\n                'inputFormat': 'text',\n                'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'\n            })\n        if annotators:\n            properties['annotators'] = \",\".join(annotators) if isinstance(annotators, list) else annotators\n\n        # force output for regex requests to be json\n        properties['outputFormat'] = 'json'\n        # if the server is trying to send back character offsets, it\n        # should send back codepoint counts as well in case the text\n        # has extra wide characters\n        properties['tokenize.codepoint'] = 'true'\n\n        try:\n            # Error 
occurs unless put properties in params\n            input_format = properties.get(\"inputFormat\", \"text\")\n            if input_format == \"text\":\n                ctype = \"text/plain; charset=utf-8\"\n            elif input_format == \"serialized\":\n                ctype = \"application/x-protobuf\"\n            else:\n                raise ValueError(\"Unrecognized inputFormat \" + input_format)\n            # change request method from `get` to `post` as required by CoreNLP\n            r = requests.post(\n                self.endpoint + path, params={\n                    'pattern': pattern,\n                    'filter': filter,\n                    'properties': str(properties)\n                },\n                data=text.encode('utf-8') if isinstance(text, str) else text,\n                headers={'content-type': ctype},\n                timeout=(self.timeout*2)/1000,\n            )\n            r.raise_for_status()\n            if r.encoding is None:\n                r.encoding = \"utf-8\"\n            return json.loads(r.text)\n        except requests.HTTPError as e:\n            if r.text.startswith(\"Timeout\"):\n                raise TimeoutException(r.text) from e\n            else:\n                raise AnnotationException(r.text) from e\n        except json.JSONDecodeError:\n            raise AnnotationException(r.text)\n\n\n    def scenegraph(self, text, properties=None):\n        \"\"\"\n        Send a request to the server which processes the text using SceneGraph\n\n        This requires CoreNLP 4.5.5 or later\n        \"\"\"\n        # since we're using requests ourselves,\n        # check if the server has started or not\n        if self.start_server is not StartServer.DONT_START:\n            self.ensure_alive()\n\n        if properties is None:\n            properties = {}\n        # the only thing the scenegraph knows how to use is text\n        properties['inputFormat'] = 'text'\n        ctype = \"text/plain; charset=utf-8\"\n   
     # the json output format is much more useful\n        properties['outputFormat'] = 'json'\n        try:\n            r = requests.post(\n                self.endpoint + \"/scenegraph\",\n                params={\n                    'properties': str(properties)\n                },\n                data=text.encode('utf-8') if isinstance(text, str) else text,\n                headers={'content-type': ctype},\n                timeout=(self.timeout*2)/1000,\n            )\n            r.raise_for_status()\n            if r.encoding is None:\n                r.encoding = \"utf-8\"\n            return json.loads(r.text)\n        except requests.HTTPError as e:\n            if r.text.startswith(\"Timeout\"):\n                raise TimeoutException(r.text)\n            else:\n                raise AnnotationException(r.text)\n        except json.JSONDecodeError:\n            raise AnnotationException(r.text)\n\n\ndef read_corenlp_props(props_path):\n    \"\"\" Read a Stanford CoreNLP properties file into a dict \"\"\"\n    props_dict = {}\n    with open(props_path) as props_file:\n        entry_lines = [entry_line for entry_line in props_file.read().split('\\n')\n                       if entry_line.strip() and not entry_line.startswith('#')]\n        for entry_line in entry_lines:\n            k = entry_line.split('=')[0]\n            k_len = len(k+\"=\")\n            v = entry_line[k_len:]\n            props_dict[k.strip()] = v\n    return props_dict\n\n\ndef write_corenlp_props(props_dict, file_path=None):\n    \"\"\" Write a Stanford CoreNLP properties dict to a file \"\"\"\n    if file_path is None:\n        file_path = f\"corenlp_server-{uuid.uuid4().hex[:16]}.props\"\n        # confirm tmp file path matches pattern\n        assert SERVER_PROPS_TMP_FILE_PATTERN.match(file_path)\n    with open(file_path, 'w') as props_file:\n        for k, v in props_dict.items():\n            if isinstance(v, list):\n                writeable_v = \",\".join(v)\n            
else:\n                writeable_v = v\n            props_file.write(f'{k} = {writeable_v}\\n\\n')\n    return file_path\n\n\ndef regex_matches_to_indexed_words(matches):\n    \"\"\"\n    Transforms tokensregex and semgrex matches to indexed words.\n    :param matches: unprocessed regex matches\n    :return: flat array of indexed words\n    \"\"\"\n    # copy each match and tag it with the index of the sentence it came from\n    words = [dict(v, sentence=i)\n             for i, s in enumerate(matches['sentences'])\n             for k, v in s.items() if k != 'length']\n    return words\n\n\n__all__ = [\"CoreNLPClient\", \"AnnotationException\", \"TimeoutException\", \"to_text\"]\n"
  },
  {
    "path": "stanza/server/dependency_converter.py",
    "content": "\"\"\"\nA converter from constituency trees to dependency trees using CoreNLP's UniversalEnglish converter.\n\nONLY works on English.\n\"\"\"\n\nimport stanza\nfrom stanza.protobuf import DependencyConverterRequest, DependencyConverterResponse\nfrom stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext\n\nCONVERTER_JAVA = \"edu.stanford.nlp.trees.ProcessDependencyConverterRequest\"\n\ndef send_converter_request(request, classpath=None):\n    return send_request(request, DependencyConverterResponse, CONVERTER_JAVA, classpath=classpath)\n\ndef build_request(doc):\n    \"\"\"\n    Request format is simple: one tree per sentence in the document\n    \"\"\"\n    request = DependencyConverterRequest()\n    for sentence in doc.sentences:\n        request.trees.append(build_tree(sentence.constituency, None))\n    return request\n\ndef process_doc(doc, classpath=None):\n    \"\"\"\n    Convert the constituency trees in the document,\n    then attach the resulting dependencies to the sentences\n    \"\"\"\n    request = build_request(doc)\n    response = send_converter_request(request, classpath=classpath)\n    attach_dependencies(doc, response)\n\ndef attach_dependencies(doc, response):\n    if len(doc.sentences) != len(response.conversions):\n        raise ValueError(\"Sent %d sentences but got back %d conversions\" % (len(doc.sentences), len(response.conversions)))\n    for sent_idx, (sentence, conversion) in enumerate(zip(doc.sentences, response.conversions)):\n        graph = conversion.graph\n\n        # The deterministic conversion should have an equal number of words and one fewer edge\n        # ... 
the root is represented by a word with no parent\n        if len(sentence.words) != len(graph.node):\n            raise ValueError(\"Sentence %d of the conversion should have %d words but got back %d nodes in the graph\" % (sent_idx, len(sentence.words), len(graph.node)))\n        if len(sentence.words) != len(graph.edge) + 1:\n            raise ValueError(\"Sentence %d of the conversion should have %d edges (one per word, except the root) but got back %d edges in the graph\" % (sent_idx, len(sentence.words) - 1, len(graph.edge)))\n\n        expected_nodes = set(range(1, len(sentence.words) + 1))\n        targets = set()\n        for edge in graph.edge:\n            if edge.target in targets:\n                raise ValueError(\"Found two parents of %d in sentence %d\" % (edge.target, sent_idx))\n            targets.add(edge.target)\n            # -1 since the words are 0 indexed in the sentence,\n            # but we count dependencies from 1\n            sentence.words[edge.target-1].head = edge.source\n            sentence.words[edge.target-1].deprel = edge.dep\n        roots = expected_nodes - targets\n        assert len(roots) == 1\n        for root in roots:\n            sentence.words[root-1].head = 0\n            sentence.words[root-1].deprel = \"root\"\n        sentence.build_dependencies()\n\n\nclass DependencyConverter(JavaProtobufContext):\n    \"\"\"\n    Context window for the dependency converter\n\n    This is a context window which keeps a process open.  
Should allow\n    for multiple requests without launching new java processes each time.\n    \"\"\"\n    def __init__(self, classpath=None):\n        super(DependencyConverter, self).__init__(classpath, DependencyConverterResponse, CONVERTER_JAVA)\n\n    def process(self, doc):\n        \"\"\"\n        Converts the constituency tree of each sentence in the document into a dependency tree\n        \"\"\"\n        request = build_request(doc)\n        response = self.process_request(request)\n        attach_dependencies(doc, response)\n        return doc\n\ndef main():\n    nlp = stanza.Pipeline('en',\n                          processors='tokenize,pos,constituency')\n\n    doc = nlp('I like blue antennae.')\n    print(\"{:C}\".format(doc))\n    process_doc(doc, classpath=\"$CLASSPATH\")\n    print(\"{:C}\".format(doc))\n\n    doc = nlp('And I cannot lie.')\n    print(\"{:C}\".format(doc))\n    with DependencyConverter(classpath=\"$CLASSPATH\") as converter:\n        converter.process(doc)\n        print(\"{:C}\".format(doc))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/java_protobuf_requests.py",
    "content": "from collections import deque\nimport subprocess\n\nfrom stanza.models.common.utils import misc_to_space_after\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.protobuf import DependencyGraph, FlattenedParseTree\nfrom stanza.server.client import resolve_classpath\n\ndef send_request(request, response_type, java_main, classpath=None):\n    \"\"\"\n    Use subprocess to run a Java protobuf processor on the given request\n\n    Returns the protobuf response\n    \"\"\"\n    classpath = resolve_classpath(classpath)\n    if classpath is None:\n        raise ValueError(\"Classpath is None,  Perhaps you need to set the $CLASSPATH or $CORENLP_HOME environment variable to point to a CoreNLP install.\")\n    pipe = subprocess.run([\"java\", \"-cp\", classpath, java_main],\n                          input=request.SerializeToString(),\n                          stdout=subprocess.PIPE,\n                          check=True)\n    response = response_type()\n    response.ParseFromString(pipe.stdout)\n    return response\n\ndef add_tree_nodes(proto_tree, tree, score):\n    # add an open node\n    node = proto_tree.nodes.add()\n    node.openNode = True\n    if score is not None:\n        node.score = score\n\n    # add the content of this node\n    node = proto_tree.nodes.add()\n    node.value = tree.label\n\n    # add all children...\n    # leaves get just one node\n    # branches are called recursively\n    for child in tree.children:\n        if child.is_leaf():\n            node = proto_tree.nodes.add()\n            node.value = child.label\n        else:\n            add_tree_nodes(proto_tree, child, None)\n\n    node = proto_tree.nodes.add()\n    node.closeNode = True\n\ndef build_tree(tree, score):\n    \"\"\"\n    Builds a FlattenedParseTree from CoreNLP.proto\n\n    Populates the value field from tree.label and iterates through the\n    children via tree.children.  
Should work on any tree structure\n    which follows that layout\n\n    The score will be added to the top node (if it is not None)\n\n    Operates by recursively calling add_tree_nodes\n    \"\"\"\n    proto_tree = FlattenedParseTree()\n    add_tree_nodes(proto_tree, tree, score)\n    return proto_tree\n\ndef from_tree(proto_tree):\n    \"\"\"\n    Convert a FlattenedParseTree back into a Tree\n\n    returns Tree, score\n      (score might be None if it is missing)\n    \"\"\"\n    score = None\n    stack = deque()\n    for node in proto_tree.nodes:\n        if node.HasField(\"score\") and score is None:\n            score = node.score\n\n        if node.openNode:\n            if len(stack) > 0 and isinstance(stack[-1], FlattenedParseTree.Node) and stack[-1].openNode:\n                raise ValueError(\"Got a proto with no label on a node: {}\".format(proto_tree))\n            stack.append(node)\n            continue\n        if not node.closeNode:\n            child = Tree(label=node.value)\n            # TODO: do something with the score\n            stack.append(child)\n            continue\n\n        # must be a close operation...\n        if len(stack) <= 1:\n            raise ValueError(\"Got a proto with too many close operations: {}\".format(proto_tree))\n        # on a close operation, pop until we hit the open\n        # then turn everything in that span into a new node\n        children = []\n        nextNode = stack.pop()\n        while not isinstance(nextNode, FlattenedParseTree.Node):\n            children.append(nextNode)\n            nextNode = stack.pop()\n        if len(children) == 0:\n            raise ValueError(\"Got a proto with an open immediately followed by a close: {}\".format(proto_tree))\n        children.reverse()\n        label = children[0]\n        children = children[1:]\n        subtree = Tree(label=label.label, children=children)\n        stack.append(subtree)\n\n    if len(stack) > 1:\n        raise ValueError(\"Got a proto 
which does not close all of the nodes: {}\".format(proto_tree))\n    tree = stack.pop()\n    if not isinstance(tree, Tree):\n        raise ValueError(\"Got a proto which was just one Open operation: {}\".format(proto_tree))\n    return tree, score\n\ndef add_token(token_list, word, token):\n    \"\"\"\n    Add a token to a proto request.\n\n    CoreNLP tokens have components of both word and token from stanza.\n\n    We pass along \"after\" but not \"before\"\n    \"\"\"\n    if token is None and isinstance(word.id, int):\n        raise AssertionError(\"Only expected word w/o token for 'extra' words\")\n\n    query_token = token_list.add()\n    query_token.word = word.text\n    query_token.value = word.text\n    if word.lemma is not None:\n        query_token.lemma = word.lemma\n    if word.xpos is not None:\n        query_token.pos = word.xpos\n    if word.upos is not None:\n        query_token.coarseTag = word.upos\n    if word.feats and word.feats != \"_\":\n        for feature in word.feats.split(\"|\"):\n            key, value = feature.split(\"=\", maxsplit=1)\n            query_token.conllUFeatures.key.append(key)\n            query_token.conllUFeatures.value.append(value)\n    if token is not None:\n        if token.ner is not None:\n            query_token.ner = token.ner\n        if token is not None and len(token.id) > 1:\n            query_token.mwtText = token.text\n            query_token.isMWT = True\n            query_token.isFirstMWT = token.id[0] == word.id\n        if token.id[-1] != word.id:\n            # if we are not the last word of an MWT token\n            # we are absolutely not followed by space\n            pass\n        else:\n            query_token.after = token.spaces_after\n\n        query_token.index = word.id\n    else:\n        # presumably empty words won't really be written this way,\n        # but we can still keep track of it\n        query_token.after = misc_to_space_after(word.misc)\n\n        query_token.index = 
word.id[0]\n        query_token.emptyIndex = word.id[1]\n\n    if word.misc and word.misc != \"_\":\n        query_token.conllUMisc = word.misc\n    if token is not None and token.misc and token.misc != \"_\":\n        query_token.mwtMisc = token.misc\n\ndef add_sentence(request_sentences, sentence, num_tokens):\n    \"\"\"\n    Add the tokens for this stanza sentence to a list of protobuf sentences\n    \"\"\"\n    request_sentence = request_sentences.add()\n    request_sentence.tokenOffsetBegin = num_tokens\n    request_sentence.tokenOffsetEnd = num_tokens + sum(len(token.words) for token in sentence.tokens)\n    for token in sentence.tokens:\n        for word in token.words:\n            add_token(request_sentence.token, word, token)\n    return request_sentence\n\ndef add_word_to_graph(graph, word, sent_idx):\n    \"\"\"\n    Add a node and possibly an edge for a word in a basic dependency graph.\n    \"\"\"\n    node = graph.node.add()\n    node.sentenceIndex = sent_idx+1\n    if isinstance(word.id, int):\n        node.index = word.id\n    else:\n        node.index = word.id[0]\n        node.emptyIndex = word.id[1]\n\n    if word.head != 0 and word.head is not None:\n        edge = graph.edge.add()\n        edge.source = word.head\n        if isinstance(word.id, int):\n            edge.target = word.id\n        else:\n            edge.target = word.id[0]\n            edge.targetEmpty = word.id[1]\n        if word.deprel is not None:\n            edge.dep = word.deprel\n        else:\n            # the receiving side doesn't like null as a dependency\n            edge.dep = \"_\"\n\ndef convert_networkx_graph(graph_proto, sentence, sent_idx):\n    \"\"\"\n    Turns a networkx graph into a DependencyGraph from the proto file\n    \"\"\"\n    for token in sentence.tokens:\n        for word in token.words:\n            add_token(graph_proto.token, word, token)\n    for word in sentence.empty_words:\n        add_token(graph_proto.token, word, None)\n\n    
dependencies = sentence._enhanced_dependencies\n    for target in dependencies:\n        if target == 0:\n            # don't need to send the explicit root\n            continue\n        for source in dependencies.predecessors(target):\n            if source == 0:\n                # unlike with basic, we need to send over the roots,\n                # as the enhanced can have loops\n                graph_proto.rootNode.append(len(graph_proto.node))\n                continue\n            for deprel in dependencies.get_edge_data(source, target):\n                edge = graph_proto.edge.add()\n                if isinstance(source, int):\n                    edge.source = source\n                else:\n                    edge.source = source[0]\n                    if source[1] != 0:\n                        edge.sourceEmpty = source[1]\n                if isinstance(target, int):\n                    edge.target = target\n                else:\n                    edge.target = target[0]\n                    if target[1] != 0:\n                        edge.targetEmpty = target[1]\n                edge.dep = deprel\n        node = graph_proto.node.add()\n        node.sentenceIndex = sent_idx + 1\n        # the nodes in the networkx graph are indexed from 1, not counting the root\n        if isinstance(target, int):\n            node.index = target\n        else:\n            node.index = target[0]\n            if target[1] != 0:\n                node.emptyIndex = target[1]\n    return graph_proto\n\ndef features_to_string(features):\n    if not features:\n        return None\n    if len(features.key) == 0:\n        return None\n    return \"|\".join(\"%s=%s\" % (key, value) for key, value in zip(features.key, features.value))\n\ndef misc_space_pieces(misc):\n    \"\"\"\n    Return only the space-related misc pieces\n    \"\"\"\n    if misc is None or misc == \"\" or misc == \"_\":\n        return misc\n    pieces = misc.split(\"|\")\n    pieces = [x for x in pieces 
if x.split(\"=\", maxsplit=1)[0] in (\"SpaceAfter\", \"SpacesAfter\", \"SpacesBefore\")]\n    if len(pieces) > 0:\n        return \"|\".join(pieces)\n    return None\n\ndef remove_space_misc(misc):\n    \"\"\"\n    Remove any pieces from misc which are space-related\n    \"\"\"\n    if misc is None or misc == \"\" or misc == \"_\":\n        return misc\n    pieces = misc.split(\"|\")\n    pieces = [x for x in pieces if x.split(\"=\", maxsplit=1)[0] not in (\"SpaceAfter\", \"SpacesAfter\", \"SpacesBefore\")]\n    if len(pieces) > 0:\n        return \"|\".join(pieces)\n    return None\n\ndef substitute_space_misc(misc, space_misc):\n    space_misc_pieces = space_misc.split(\"|\") if space_misc else []\n    space_misc_after = None\n    space_misc_before = None\n    for piece in space_misc_pieces:\n        # the key is SpacesBefore, matching misc_space_pieces and remove_space_misc above\n        if piece.startswith(\"SpacesBefore\"):\n            space_misc_before = piece\n        elif piece.startswith(\"SpaceAfter\") or piece.startswith(\"SpacesAfter\"):\n            space_misc_after = piece\n        else:\n            raise AssertionError(\"An unknown piece wound up in the misc space fields: %s\" % piece)\n\n    pieces = misc.split(\"|\")\n    new_pieces = []\n    for piece in pieces:\n        if piece.startswith(\"SpacesBefore\"):\n            if space_misc_before:\n                new_pieces.append(space_misc_before)\n                space_misc_before = None\n        elif piece.startswith(\"SpaceAfter\") or piece.startswith(\"SpacesAfter\"):\n            if space_misc_after:\n                new_pieces.append(space_misc_after)\n                space_misc_after = None\n        else:\n            new_pieces.append(piece)\n    if space_misc_after:\n        new_pieces.append(space_misc_after)\n    if space_misc_before:\n        new_pieces.append(space_misc_before)\n    if len(new_pieces) == 0:\n        return None\n    return \"|\".join(new_pieces)\n\nclass JavaProtobufContext(object):\n    \"\"\"\n    A generic context for sending requests to a java program 
using protobufs in a subprocess\n    \"\"\"\n    def __init__(self, classpath, build_response, java_main, extra_args=None):\n        self.classpath = resolve_classpath(classpath)\n        self.build_response = build_response\n        self.java_main = java_main\n\n        if extra_args is None:\n            extra_args = []\n        self.extra_args = extra_args\n        self.pipe = None\n\n    def open_pipe(self):\n        self.pipe = subprocess.Popen([\"java\", \"-cp\", self.classpath, self.java_main, \"-multiple\"] + self.extra_args,\n                                     stdin=subprocess.PIPE,\n                                     stdout=subprocess.PIPE)\n\n    def close_pipe(self):\n        if self.pipe.poll() is None:\n            self.pipe.stdin.write((0).to_bytes(4, 'big'))\n            self.pipe.stdin.flush()\n            self.pipe = None\n\n    def __enter__(self):\n        self.open_pipe()\n        return self\n\n    def __exit__(self, type, value, traceback):\n        self.close_pipe()\n\n    def process_request(self, request):\n        if self.pipe is None:\n            raise RuntimeError(\"Pipe to java process is not open or was closed\")\n\n        text = request.SerializeToString()\n        self.pipe.stdin.write(len(text).to_bytes(4, 'big'))\n        self.pipe.stdin.write(text)\n        self.pipe.stdin.flush()\n        response_length = self.pipe.stdout.read(4)\n        if len(response_length) < 4:\n            raise BrokenPipeError(\"Could not communicate with java process!\")\n        response_length = int.from_bytes(response_length, \"big\")\n        response_text = self.pipe.stdout.read(response_length)\n        response = self.build_response()\n        response.ParseFromString(response_text)\n        return response\n\n"
  },
  {
    "path": "stanza/server/main.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nSimple shell program to pipe in \n\"\"\"\n\nimport corenlp\n\nimport json\nimport re\nimport csv\nimport sys\nfrom collections import namedtuple, OrderedDict\n\nFLOAT_RE = re.compile(r\"\\d*\\.\\d+\")\nINT_RE = re.compile(r\"\\d+\")\n\ndef dictstr(arg):\n    \"\"\"\n    Parse a key=value string as a tuple (key, value) that can be provided as an argument to dict()\n    \"\"\"\n    key, value = arg.split(\"=\")\n\n    if value.lower() == \"true\" or value.lower() == \"false\":\n        value = bool(value)\n    elif INT_RE.match(value):\n        value = int(value)\n    elif FLOAT_RE.match(value):\n        value = float(value)\n    return (key, value)\n\n\ndef do_annotate(args):\n    args.props = dict(args.props) if args.props else {}\n    if args.sentence_mode:\n        args.props[\"ssplit.isOneSentence\"] = True\n\n    with corenlp.CoreNLPClient(annotators=args.annotators, properties=args.props, be_quiet=not args.verbose_server) as client:\n        for line in args.input:\n            if line.startswith(\"#\"): continue\n\n            ann = client.annotate(line.strip(), output_format=args.format)\n\n            if args.format == \"json\":\n                if args.sentence_mode:\n                    ann = ann[\"sentences\"][0]\n\n                args.output.write(json.dumps(ann))\n                args.output.write(\"\\n\")\n\ndef main():\n    import argparse\n    parser = argparse.ArgumentParser(description='Annotate data')\n    parser.add_argument('-i', '--input', type=argparse.FileType('r'), default=sys.stdin, help=\"Input file to process; each line contains one document (default: stdin)\")\n    parser.add_argument('-o', '--output', type=argparse.FileType('w'), default=sys.stdout, help=\"File to write annotations to (default: stdout)\")\n    parser.add_argument('-f', '--format', choices=[\"json\",], default=\"json\", help=\"Output format\")\n    parser.add_argument('-a', '--annotators', 
nargs=\"+\", type=str, default=[\"tokenize ssplit lemma pos\"], help=\"A list of annotators\")\n    parser.add_argument('-s', '--sentence-mode', action=\"store_true\",help=\"Assume each line of input is a sentence.\")\n    parser.add_argument('-v', '--verbose-server', action=\"store_true\",help=\"Server is made verbose\")\n    parser.add_argument('-m', '--memory', type=str, default=\"4G\", help=\"Memory to use for the server\")\n    parser.add_argument('-p', '--props', nargs=\"+\", type=dictstr, help=\"Properties as a list of key=value pairs\")\n    parser.set_defaults(func=do_annotate)\n\n    ARGS = parser.parse_args()\n    if ARGS.func is None:\n        parser.print_help()\n        sys.exit(1)\n    else:\n        ARGS.func(ARGS)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/server/morphology.py",
    "content": "\"\"\"\nDirect pipe connection to the Java CoreNLP Morphology class\n\nOnly effective for English.  Must be supplied with PTB scheme xpos, not upos\n\"\"\"\n\n\nfrom stanza.protobuf import MorphologyRequest, MorphologyResponse\nfrom stanza.server.java_protobuf_requests import send_request, JavaProtobufContext\n\n\nMORPHOLOGY_JAVA = \"edu.stanford.nlp.process.ProcessMorphologyRequest\"\n\ndef send_morphology_request(request):\n    return send_request(request, MorphologyResponse, MORPHOLOGY_JAVA)\n\ndef build_request(words, xpos_tags):\n    \"\"\"\n    Turn a list of words and a list of tags into a request\n\n    tags must be xpos, not upos\n    \"\"\"\n    request = MorphologyRequest()\n    for word, tag in zip(words, xpos_tags):\n        tagged_word = request.words.add()\n        tagged_word.word = word\n        tagged_word.xpos = tag\n    return request\n\n\ndef process_text(words, xpos_tags):\n    \"\"\"\n    Get the lemmata for each word/tag pair\n\n    Currently the return is a MorphologyResponse from CoreNLP.proto\n\n    tags must be xpos, not upos\n    \"\"\"\n    request = build_request(words, xpos_tags)\n\n    return send_morphology_request(request)\n\n\n\nclass Morphology(JavaProtobufContext):\n    \"\"\"\n    Morphology context window\n\n    This is a context window which keeps a process open.  
Should allow\n    for multiple requests without launching new java processes each time.\n\n    (much faster than calling process_text over and over)\n    \"\"\"\n    def __init__(self, classpath=None):\n        super(Morphology, self).__init__(classpath, MorphologyResponse, MORPHOLOGY_JAVA)\n\n    def process(self, words, xpos_tags):\n        \"\"\"\n        Get the lemmata for each word/tag pair\n        \"\"\"\n        request = build_request(words, xpos_tags)\n        return self.process_request(request)\n\n\ndef main():\n    # TODO: turn this into a unit test, once a new CoreNLP is released\n    words    = [\"Jennifer\", \"has\",  \"the\", \"prettiest\", \"antennae\"]\n    tags     = [\"NNP\",      \"VBZ\",  \"DT\",  \"JJS\",       \"NNS\"]\n    expected = [\"Jennifer\", \"have\", \"the\", \"pretty\",    \"antenna\"]\n    result = process_text(words, tags)\n    lemma = [x.lemma for x in result.words]\n    print(lemma)\n    assert lemma == expected\n\n    with Morphology() as morph:\n        result = morph.process(words, tags)\n        lemma = [x.lemma for x in result.words]\n        assert lemma == expected\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/parser_eval.py",
    "content": "\"\"\"\nThis class runs a Java process to evaluate a treebank prediction using CoreNLP\n\"\"\"\n\nfrom collections import namedtuple\nimport sys\n\nimport stanza\nfrom stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse\nfrom stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext\nfrom stanza.models.constituency.tree_reader import read_treebank\n\nEVALUATE_JAVA = \"edu.stanford.nlp.parser.metrics.EvaluateExternalParser\"\n\nParseResult = namedtuple(\"ParseResult\", ['gold', 'predictions', 'state', 'constituents'])\nScoredTree = namedtuple(\"ScoredTree\", ['tree', 'score'])\n\ndef build_request(treebank):\n    \"\"\"\n    treebank should be a list of pairs:  [gold, predictions]\n      each predictions is a list of tuples (prediction, score, state)\n      state is ignored and can be None\n    Note that for now, only one tree is measured, but this may be extensible in the future\n    Trees should be in the form of a Tree from parse_tree.py\n    \"\"\"\n    request = EvaluateParserRequest()\n    for raw_result in treebank:\n        gold = raw_result.gold\n        predictions = raw_result.predictions\n        parse_result = request.treebank.add()\n        parse_result.gold.CopyFrom(build_tree(gold, None))\n        for pred in predictions:\n            if isinstance(pred, tuple):\n                prediction, score = pred\n            else:\n                prediction = pred\n                score = None\n            try:\n                parse_result.predicted.append(build_tree(prediction, score))\n            except Exception as e:\n                raise RuntimeError(\"Unable to build parser request from tree {}\".format(pred)) from e\n\n    return request\n\ndef collate(gold_treebank, predictions_treebank):\n    \"\"\"\n    Turns a list of gold and prediction into a evaluation object\n    \"\"\"\n    treebank = []\n    for gold, prediction in zip(gold_treebank, predictions_treebank):\n        result = 
ParseResult(gold, [prediction], None, None)\n        treebank.append(result)\n    return treebank\n\n\nclass EvaluateParser(JavaProtobufContext):\n    \"\"\"\n    Parser evaluation context window\n\n    This is a context window which keeps a process open.  Should allow\n    for multiple requests without launching new java processes each time.\n    \"\"\"\n    def __init__(self, classpath=None, kbest=None, silent=False):\n        if kbest is not None:\n            extra_args = [\"-evalPCFGkBest\", \"{}\".format(kbest), \"-evals\", \"pcfgTopK\"]\n        else:\n            extra_args = []\n\n        if silent:\n            extra_args.extend([\"-evals\", \"summary=False\"])\n\n        super(EvaluateParser, self).__init__(classpath, EvaluateParserResponse, EVALUATE_JAVA, extra_args=extra_args)\n\n    def process(self, treebank):\n        request = build_request(treebank)\n        return self.process_request(request)\n\n\ndef main():\n    gold = read_treebank(sys.argv[1])\n    predictions = read_treebank(sys.argv[2])\n    treebank = collate(gold, predictions)\n\n    with EvaluateParser() as ep:\n        ep.process(treebank)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/semgrex.py",
    "content": "\"\"\"Invokes the Java semgrex on a document\n\nThe server client has a method \"semgrex\" which sends text to Java\nCoreNLP for processing with a semgrex (SEMantic GRaph regEX) query:\n\nhttps://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html\n\nHowever, this operates on text using the CoreNLP tools, which means\nthe dependency graphs may not align with stanza's depparse module, and\nthis also limits the languages for which it can be used.  This module\nallows for running semgrex commands on the graphs produced by\ndepparse.\n\nTo use, first process text into a doc using stanza.Pipeline\n\nNext, pass the processed doc and a list of semgrex patterns to\nprocess_doc in this module.  It will run the java semgrex module as a\nsubprocess and return the result in the form of a SemgrexResponse,\nwhose description is in the proto file included with stanza.\n\nA minimal example is the main method of this module.\n\nNote that launching the subprocess is potentially quite expensive\nrelative to the search if used many times on small documents.  Ideally\nlarger texts would be processed, and all of the desired semgrex\npatterns would be run at once.  The worst thing to do would be to call\nthis multiple times on a large document, one invocation per semgrex\npattern, as that would serialize the document each time.\nIncluded here is a context manager which allows for keeping the same\njava process open for multiple requests.  This saves on the subprocess\nlaunching time.  
It is still important not to wastefully serialize the\nsame document over and over, though.\n\"\"\"\n\nimport argparse\nfrom collections import namedtuple\nimport copy\nimport os\nimport re\n\nimport stanza\nfrom stanza.protobuf import SemgrexRequest, SemgrexResponse\nfrom stanza.server.java_protobuf_requests import send_request, add_token, add_word_to_graph, JavaProtobufContext, convert_networkx_graph\nfrom stanza.utils.conll import CoNLL\n\nSEMGREX_JAVA = \"edu.stanford.nlp.semgraph.semgrex.ProcessSemgrexRequest\"\n\nSemgrexQuery = namedtuple(\"SemgrexQuery\", \"pattern comments\")\n\ndef send_semgrex_request(request):\n    return send_request(request, SemgrexResponse, SEMGREX_JAVA)\n\ndef build_request(doc, semgrex_patterns, enhanced=False):\n    request = SemgrexRequest()\n    if isinstance(semgrex_patterns, str):\n        semgrex_patterns = [semgrex_patterns]\n    semgrex_patterns = [x if isinstance(x, SemgrexQuery) else SemgrexQuery(x, []) for x in semgrex_patterns]\n    for semgrex in semgrex_patterns:\n        request.semgrex.append(semgrex.pattern)\n\n    for sent_idx, sentence in enumerate(doc.sentences):\n        query = request.query.add()\n        if enhanced:\n            # tokens will be added on to the graph object\n            convert_networkx_graph(query.graph, sentence, sent_idx)\n        else:\n            word_idx = 0\n            for token in sentence.tokens:\n                for word in token.words:\n                    add_token(query.token, word, token)\n                    add_word_to_graph(query.graph, word, sent_idx)\n\n                    word_idx = word_idx + 1\n\n    return request\n\ndef process_doc(doc, *semgrex_patterns, enhanced=False):\n    \"\"\"\n    Returns the result of processing the given semgrex expression on the stanza doc.\n\n    Currently the return is a SemgrexResponse from CoreNLP.proto\n    \"\"\"\n    request = build_request(doc, semgrex_patterns, enhanced=enhanced)\n\n    return 
send_semgrex_request(request)\n\nclass Semgrex(JavaProtobufContext):\n    \"\"\"\n    Semgrex context window\n\n    This is a context window which keeps a process open.  Should allow\n    for multiple requests without launching new java processes each time.\n    \"\"\"\n    def __init__(self, classpath=None):\n        super(Semgrex, self).__init__(classpath, SemgrexResponse, SEMGREX_JAVA)\n\n    def process(self, doc, *semgrex_patterns):\n        \"\"\"\n        Apply each of the semgrex patterns to each of the dependency trees in doc\n        \"\"\"\n        request = build_request(doc, semgrex_patterns)\n        return self.process_request(request)\n\ndef annotate_doc(doc, semgrex_result, semgrex_patterns, matches_only, exclude_matches):\n    \"\"\"\n    Put comments on the sentences which describe the matching semgrex patterns\n    \"\"\"\n    doc = copy.deepcopy(doc)\n    if isinstance(semgrex_patterns, str):\n        semgrex_patterns = [semgrex_patterns]\n    semgrex_patterns = [x if isinstance(x, SemgrexQuery) else SemgrexQuery(x, []) for x in semgrex_patterns]\n    matched_ids = set()\n    for sentence_result in semgrex_result.result:\n        for pattern_result in sentence_result.result:\n            for match in pattern_result.match:\n                matched_ids.add(match.sentenceIndex)\n\n    pattern_texts = [semgrex_pattern.pattern.replace(\"\\n\", \" \") for semgrex_pattern in semgrex_patterns]\n\n    matching_sentences = []\n    for sentence_result in semgrex_result.result:\n        sentence_matched = False\n        matched_semgrex_ids = set()\n        for pattern_result in sentence_result.result:\n            if len(pattern_result.match) == 0:\n                continue\n\n            highlight_tokens = []\n            highlight_edges = []\n            for match in pattern_result.match:\n                sentence_matched = True\n                sentence = doc.sentences[match.sentenceIndex]\n                semgrex_pattern = 
semgrex_patterns[match.semgrexIndex]\n                pattern_text = pattern_texts[match.semgrexIndex]\n                matched_semgrex_ids.add(match.semgrexIndex)\n\n                match_word = \"%d:%s\" % (match.matchIndex, sentence.words[match.matchIndex-1].text)\n                if len(match.node) == 0:\n                    node_matches = \"\"\n                else:\n                    node_matches = [\"%s=%d:%s\" % (node.name, node.matchIndex, sentence.words[node.matchIndex-1].text)\n                                    for node in match.node]\n                    node_matches = \"  \" + \" \".join(node_matches)\n                if len(match.varstring) == 0:\n                    var_values = \"\"\n                else:\n                    var_values = [\"%s=%s\" % (v.name, v.value) for v in match.varstring]\n                    var_values = \"  \" + \" \".join(var_values)\n                sentence.add_comment(\"# semgrex pattern |%s| matched at %s%s%s\" % (pattern_text, match_word, node_matches, var_values))\n                for comment in semgrex_pattern.comments:\n                    sentence.add_comment(\"# semgrex comment: %s\" % comment)\n                highlight_tokens.append(match.matchIndex)\n                for edge in match.edge:\n                    highlight_edges.append(edge.target)\n            if len(highlight_tokens) > 0:\n                sentence.add_comment(\"# highlight tokens = %s\" % (\" \".join(\"%d\" % x for x in highlight_tokens)))\n            if len(highlight_edges) > 0:\n                sentence.add_comment(\"# highlight deprels = %s\" % (\" \".join(\"%d\" % x for x in highlight_edges)))\n\n        if sentence_matched and not matches_only:\n            for semgrex_idx, pattern_text in enumerate(pattern_texts):\n                if semgrex_idx not in matched_semgrex_ids:\n                    sentence.add_comment(\"# semgrex pattern |%s| did not match!\" % pattern_text)\n\n        if sentence_matched:\n            
matching_sentences.append(sentence)\n\n    nonmatching_sentences = [sentence for sentence_idx, sentence in enumerate(doc.sentences) if sentence_idx not in matched_ids]\n    for sentence in nonmatching_sentences:\n        for semgrex_idx, pattern_text in enumerate(pattern_texts):\n            sentence.add_comment(\"# semgrex pattern |%s| did not match!\" % pattern_text)\n\n    if matches_only:\n        doc.sentences = matching_sentences\n    elif exclude_matches:\n        doc.sentences = nonmatching_sentences\n    return doc\n\n\ndef main():\n    \"\"\"\n    Runs a toy example, or can run a given semgrex expression on the given input file.\n\n    For example:\n    python3 -m stanza.server.semgrex --input demo/semgrex_sample.conllu\n\n    --matches_only to only print sentences that match the semgrex pattern\n    --no_print_input to not print the input\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', type=str, default=None, help='Process this file or directory')\n    parser.add_argument('--input_filter', type=str, default=None, help='Only process files that match this regex')\n    parser.add_argument('semgrex', type=str, nargs=\"*\", default=[\"{}=source >obj=zzz {}=target\"], help=\"Semgrex to apply to the text.  
The default looks for sentences with objects\")\n    parser.add_argument('--semgrex_file', type=str, default=None, help=\"File to read semgrex patterns from - relevant in case the pattern you want to use doesn't work well on the command line, for example\")\n    parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help=\"Print the input alongside the output - gets kind of noisy\")\n    parser.add_argument('--no_print_input', dest='print_input', action='store_false', help=\"Don't print the input alongside the output - gets kind of noisy\")\n\n    parser.add_argument('--matches_only', action='store_true', default=True, help=\"Only print the matching sentences\")\n    parser.add_argument('--no_matches_only', dest='matches_only', action='store_false', help=\"Print all sentences, not only the matching ones\")\n    parser.add_argument('--exclude_matches', action='store_true', default=False, help=\"Only print the NON-matching sentences.  Only takes effect together with --no_matches_only\")\n\n    parser.add_argument('--enhanced', action='store_true', default=False, help='Use the enhanced dependencies instead of the basic')\n    parser.add_argument('--no_combined_doc', dest='combined_doc', action='store_false', default=True, help='By default, combine all the input docs into one big document.  
Allows for easier secondary processing like sorting')\n    args = parser.parse_args()\n\n    if args.semgrex_file:\n        with open(args.semgrex_file) as fin:\n            args.semgrex = [x.strip() for x in fin.readlines()]\n\n    semgrex_patterns = []\n    current_comments = []\n    for line in args.semgrex:\n        if not line:\n            current_comments = []\n        elif line.startswith(\"#\"):\n            current_comments.append(line[1:].strip())\n        else:\n            semgrex_patterns.append(SemgrexQuery(line, current_comments))\n            current_comments = []\n    args.semgrex = semgrex_patterns\n\n    if args.input:\n        if os.path.isfile(args.input):\n            docs = [CoNLL.conll2doc(input_file=args.input, ignore_gapping=False)]\n        else:\n            filenames = sorted(os.listdir(args.input))\n            if args.input_filter:\n                input_filter = re.compile(args.input_filter)\n                filenames = [x for x in filenames if input_filter.match(x)]\n            filenames = [os.path.join(args.input, filename) for filename in filenames]\n            filenames = [filename for filename in filenames if os.path.isfile(filename)]\n            docs = [CoNLL.conll2doc(input_file=filename, ignore_gapping=False) for filename in filenames]\n    else:\n        nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')\n        docs = [nlp('Uro ruined modern.  
Fortunately, Wotc banned him.')]\n\n    if args.combined_doc:\n        sentences = [sent for doc in docs for sent in doc.sentences]\n        docs = [docs[0]]\n        docs[0].sentences = sentences\n\n    for doc in docs:\n        if args.print_input:\n            print(\"{:C}\".format(doc))\n            print()\n            print(\"-\" * 75)\n            print()\n        semgrex_result = process_doc(doc, *args.semgrex, enhanced=args.enhanced)\n        doc = annotate_doc(doc, semgrex_result, args.semgrex, args.matches_only, args.exclude_matches)\n        if len(doc.sentences) > 0:\n            print(\"{:C}\\n\".format(doc))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/ssurgeon.py",
    "content": "\"\"\"Invokes the Java ssurgeon on a document\n\n\"ssurgeon\" sends text to Java CoreNLP for processing with a ssurgeon\n(Semantic graph SURGEON) query\n\nThe main program in this file gives a very short intro to how to use it.\n\"\"\"\n\n\nimport argparse\nfrom collections import namedtuple\nimport copy\nimport os\nimport re\nimport sys\n\nfrom stanza.models.common.utils import misc_to_space_after, space_after_to_misc\nfrom stanza.protobuf import SsurgeonRequest, SsurgeonResponse\nfrom stanza.server import java_protobuf_requests\nfrom stanza.utils.conll import CoNLL\n\nfrom stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, Word, Token, Sentence\n\nSSURGEON_JAVA = \"edu.stanford.nlp.semgraph.semgrex.ssurgeon.ProcessSsurgeonRequest\"\n\nSsurgeonEdit = namedtuple(\"SsurgeonEdit\",\n                          \"semgrex_pattern ssurgeon_edits ssurgeon_id notes language\",\n                          defaults=[None, None, \"UniversalEnglish\"])\n\ndef parse_ssurgeon_edits(ssurgeon_text):\n    ssurgeon_text = ssurgeon_text.strip()\n    ssurgeon_blocks = re.split(\"\\n\\n+\", ssurgeon_text)\n    ssurgeon_edits = []\n    for idx, block in enumerate(ssurgeon_blocks):\n        lines = block.split(\"\\n\")\n        comments = [line[1:].strip() for line in lines if line.startswith(\"#\")]\n        notes = \" \".join(comments)\n        lines = [x.strip() for x in lines if x.strip() and not x.startswith(\"#\")]\n        if len(lines) == 0:\n            # was a block of entirely comments\n            continue\n        semgrex = lines[0]\n        ssurgeon = lines[1:]\n        ssurgeon_edits.append(SsurgeonEdit(semgrex, ssurgeon, \"%d\" % (idx + 1), notes))\n    return ssurgeon_edits\n\ndef read_ssurgeon_edits(edit_file):\n    with open(edit_file, encoding=\"utf-8\") as fin:\n        return parse_ssurgeon_edits(fin.read())\n\ndef send_ssurgeon_request(request):\n    return 
java_protobuf_requests.send_request(request, SsurgeonResponse, SSURGEON_JAVA)\n\ndef build_request(doc, ssurgeon_edits):\n    request = SsurgeonRequest()\n\n    for ssurgeon in ssurgeon_edits:\n        ssurgeon_proto = request.ssurgeon.add()\n        ssurgeon_proto.semgrex = ssurgeon.semgrex_pattern\n        for operation in ssurgeon.ssurgeon_edits:\n            ssurgeon_proto.operation.append(operation)\n        if ssurgeon.ssurgeon_id is not None:\n            ssurgeon_proto.id = ssurgeon.ssurgeon_id\n        if ssurgeon.notes is not None:\n            ssurgeon_proto.notes = ssurgeon.notes\n        if ssurgeon.language is not None:\n            ssurgeon_proto.language = ssurgeon.language\n\n    try:\n        for sent_idx, sentence in enumerate(doc.sentences):\n            graph = request.graph.add()\n            word_idx = 0\n            for token in sentence.tokens:\n                for word in token.words:\n                    java_protobuf_requests.add_token(graph.token, word, token)\n                    java_protobuf_requests.add_word_to_graph(graph, word, sent_idx)\n\n                    word_idx = word_idx + 1\n    except Exception as e:\n        raise RuntimeError(\"Failed to process sentence {}:\\n{:C}\".format(sent_idx, sentence)) from e\n\n    return request\n\ndef build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):\n    ssurgeon_edit = SsurgeonEdit(semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)\n    return build_request(doc, [ssurgeon_edit])\n\ndef process_doc(doc, ssurgeon_edits):\n    \"\"\"\n    Returns the result of processing the given semgrex expression and ssurgeon edits on the stanza doc.\n\n    Currently the return is a SsurgeonResponse from CoreNLP.proto\n    \"\"\"\n    request = build_request(doc, ssurgeon_edits)\n\n    return send_ssurgeon_request(request)\n\ndef process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):\n    request = 
build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)\n\n    return send_ssurgeon_request(request)\n\ndef build_word_entry(word_index, graph_word):\n    word_entry = {\n        ID: word_index,\n        TEXT: graph_word.word if graph_word.word else None,\n        LEMMA: graph_word.lemma if graph_word.lemma else None,\n        UPOS: graph_word.coarseTag if graph_word.coarseTag else None,\n        XPOS: graph_word.pos if graph_word.pos else None,\n        FEATS: java_protobuf_requests.features_to_string(graph_word.conllUFeatures),\n        DEPS: None,\n        NER: graph_word.ner if graph_word.ner else None,\n        MISC: None,\n        START_CHAR: None,   # TODO: fix this?  one problem is the text positions\n        END_CHAR: None,     #   might change across all of the sentences\n        # presumably python will complain if this conflicts\n        # with one of the constants above\n        \"is_mwt\": graph_word.isMWT,\n        \"is_first_mwt\": graph_word.isFirstMWT,\n        \"mwt_text\": graph_word.mwtText,\n        \"mwt_misc\": graph_word.mwtMisc,\n    }\n    # TODO: do \"before\" as well\n    word_entry[MISC] = space_after_to_misc(graph_word.after)\n    if graph_word.conllUMisc:\n        word_entry[MISC] = java_protobuf_requests.substitute_space_misc(graph_word.conllUMisc, word_entry[MISC])\n    return word_entry\n\ndef convert_response_to_doc(doc, semgrex_response, add_missing_text):\n    doc = copy.deepcopy(doc)\n    try:\n        for sent_idx, (sentence, ssurgeon_result) in enumerate(zip(doc.sentences, semgrex_response.result)):\n            # EditNode is currently bugged... 
:/\n            # TODO: change this after next CoreNLP release (after 4.5.3)\n            #if not ssurgeon_result.changed:\n            #    continue\n\n            ssurgeon_graph = ssurgeon_result.graph\n            tokens = []\n            token_id_to_idx = {}\n            for graph_node, graph_word in zip(ssurgeon_graph.node, ssurgeon_graph.token):\n                word_entry = build_word_entry(graph_node.index, graph_word)\n                token_id_to_idx[graph_node.index] = len(tokens)\n                tokens.append(word_entry)\n            for root in ssurgeon_graph.root:\n                tokens[token_id_to_idx[root]][HEAD] = 0\n                tokens[token_id_to_idx[root]][DEPREL] = \"root\"\n            for edge in ssurgeon_graph.edge:\n                # can't do anything about the extra dependencies for now\n                # TODO: put them all in .deps\n                if edge.isExtra:\n                    continue\n                tokens[token_id_to_idx[edge.target]][HEAD] = edge.source\n                tokens[token_id_to_idx[edge.target]][DEPREL] = edge.dep\n\n            tokens.sort(key=lambda x: x[ID])\n\n            # for any MWT, produce a token_entry which represents the word range\n            mwt_tokens = []\n            for word_start_idx, word in enumerate(tokens):\n                if not word[\"is_first_mwt\"]:\n                    if word[\"is_mwt\"]:\n                        word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])\n                    mwt_tokens.append(word)\n                    continue\n                word_end_idx = word_start_idx + 1\n                while word_end_idx < len(tokens) and tokens[word_end_idx][\"is_mwt\"] and not tokens[word_end_idx][\"is_first_mwt\"]:\n                    word_end_idx += 1\n                mwt_token_entry = {\n                    # the tokens don't fencepost the way lists do\n                    ID: (tokens[word_start_idx][ID], tokens[word_end_idx-1][ID]),\n                    
TEXT: word[\"mwt_text\"],\n                    NER: word[NER],\n                    # use the SpaceAfter=No (or not) from the last word in the token\n                    MISC: None,\n                }\n                mwt_token_entry[MISC] = java_protobuf_requests.misc_space_pieces(tokens[word_end_idx-1][MISC])\n                if tokens[word_end_idx-1][\"mwt_misc\"]:\n                    mwt_token_entry[MISC] = java_protobuf_requests.substitute_space_misc(tokens[word_end_idx-1][\"mwt_misc\"], mwt_token_entry[MISC])\n                word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])\n                mwt_tokens.append(mwt_token_entry)\n                mwt_tokens.append(word)\n\n            old_comments = list(sentence.comments)\n            sentence = Sentence(mwt_tokens, doc)\n\n            token_text = []\n            for token_idx, token in enumerate(sentence.tokens):\n                token_text.append(token.text)\n                if token_idx == len(sentence.tokens) - 1:\n                    break\n                token_text.append(token.spaces_after)\n\n            sentence_text = \"\".join(token_text)\n\n            found_text = False\n            for comment in old_comments:\n                if comment.startswith(\"# text \") or comment.startswith(\"#text \") or comment.startswith(\"# text=\") or comment.startswith(\"#text=\"):\n                    sentence.add_comment(\"# text = \" + sentence_text)\n                    found_text = True\n                else:\n                    sentence.add_comment(comment)\n            if not found_text and add_missing_text:\n                sentence.add_comment(\"# text = \" + sentence_text)\n\n            doc.sentences[sent_idx] = sentence\n\n            sentence.rebuild_dependencies()\n    except Exception as e:\n        raise RuntimeError(\"Ssurgeon could not process sentence {}\\nSsurgeon result:\\n{}\\nOriginal sentence:\\n{:C}\".format(sent_idx, ssurgeon_result, sentence)) from e\n    return 
doc\n\nclass Ssurgeon(java_protobuf_requests.JavaProtobufContext):\n    \"\"\"\n    Ssurgeon context window\n\n    This is a context window which keeps a process open.  Should allow\n    for multiple requests without launching new java processes each time.\n    \"\"\"\n    def __init__(self, classpath=None):\n        super(Ssurgeon, self).__init__(classpath, SsurgeonResponse, SSURGEON_JAVA)\n\n    def process(self, doc, ssurgeon_edits):\n        \"\"\"\n        Apply each of the ssurgeon patterns to each of the dependency trees in doc\n        \"\"\"\n        request = build_request(doc, ssurgeon_edits)\n        return self.process_request(request)\n\n    def process_one_operation(self, doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):\n        \"\"\"\n        Convenience method - build one operation, then apply it\n        \"\"\"\n        request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)\n        return self.process_request(request)\n\nSAMPLE_DOC = \"\"\"\n# sent_id = 271\n# text = Hers is easy to clean.\n# previous = What did the dealer like about Alex's car?\n# comment = extraction/raising via \"tough extraction\" and clausal subject\n1\tHers\thers\tPRON\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnsubj\t_\t_\n2\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tcop\t_\t_\n3\teasy\teasy\tADJ\tJJ\tDegree=Pos\t0\troot\t_\t_\n4\tto\tto\tPART\tTO\t_\t5\tmark\t_\t_\n5\tclean\tclean\tVERB\tVB\tVerbForm=Inf\t3\tcsubj\t_\tSpaceAfter=No\n6\t.\t.\tPUNCT\t.\t_\t5\tpunct\t_\t_\n\"\"\"\n\ndef main():\n    # for Windows, so that we aren't randomly printing garbage (or just failing to print)\n    try:\n        sys.stdout.reconfigure(encoding='utf-8')\n    except AttributeError:\n        # TODO: deprecate 3.6 support after the next release\n        pass\n\n    # The default semgrex detects sentences in the UD_English-Pronouns dataset which have both nsubj and csubj on 
the same word.\n    # The default ssurgeon transforms the unwanted csubj to advcl\n    # See https://github.com/UniversalDependencies/docs/issues/923\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', type=str, default=None, help=\"Input file / directory to process (otherwise will process a sample text)\")\n    parser.add_argument('--output', type=str, default=None, help=\"Output location (otherwise will write back to the input directory)\")\n    parser.add_argument('--stdout', action='store_true', default=False, help='Output to stdout')\n    parser.add_argument('--input_filter', type=str, default=\".*[.]conllu\", help=\"If processing a directory, only process files from --input that match this filter - regex, not shell filter.  Default: %(default)s\")\n    parser.add_argument('--no_input_filter', action='store_const', const=None, dest=\"input_filter\", help=\"Remove the default input filename filter\")\n    parser.add_argument('--edit_file', type=str, default=None, help=\"File to get semgrex and ssurgeon rules from\")\n    parser.add_argument('--semgrex', type=str, default=\"{}=source >nsubj {} >csubj=bad {}\", help=\"Semgrex to apply to the text.  A default detects words which have both an nsubj and a csubj.  Default: %(default)s\")\n    parser.add_argument('ssurgeon', type=str, default=[\"relabelNamedEdge -edge bad -reln advcl\"], nargs=\"*\", help=\"Ssurgeon edits to apply based on the Semgrex.  Can have multiple edits in a row.  A default exists to transform csubj into advcl.  Default: %(default)s\")\n    parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help=\"Print the input alongside the output - gets kind of noisy.  
Default: %(default)s\")\n    parser.add_argument('--no_print_input', dest='print_input', action='store_false', help=\"Don't print the input alongside the output - gets kind of noisy\")\n    parser.add_argument('--no_add_missing_text', dest='add_missing_text', action='store_false', help=\"By default, the tool will add a #text comment if one does not exist.  This leaves that blank\")\n    args = parser.parse_args()\n\n    if args.edit_file:\n        ssurgeon_edits = read_ssurgeon_edits(args.edit_file)\n    else:\n        ssurgeon_edits = [SsurgeonEdit(args.semgrex, args.ssurgeon)]\n\n    if args.input:\n        if os.path.isfile(args.input):\n            docs = [CoNLL.conll2doc(input_file=args.input)]\n            if args.output is None:\n                outputs = [args.input]\n            else:\n                # TODO: could check if --output is a directory\n                outputs = [args.output]\n            input_output = zip(docs, outputs)\n        else:\n            if not args.output:\n                args.output = args.input\n            if not os.path.exists(args.output):\n                os.makedirs(args.output)\n            def read_docs():\n                for doc_filename in os.listdir(args.input):\n                    if args.input_filter:\n                        if not re.match(args.input_filter, doc_filename):\n                            continue\n                    doc_path = os.path.join(args.input, doc_filename)\n                    if not os.path.isfile(doc_path):\n                        continue\n                    output_path = os.path.join(args.output, doc_filename)\n                    print(\"Processing %s to %s\" % (doc_path, output_path))\n                    yield CoNLL.conll2doc(input_file=doc_path), output_path\n            input_output = read_docs()\n    else:\n        docs = [CoNLL.conll2doc(input_str=SAMPLE_DOC)]\n        outputs = [None]\n        input_output = zip(docs, outputs)\n        args.stdout = True\n\n    for doc, 
output in input_output:\n        if args.print_input:\n            print(\"{:C}\".format(doc))\n        ssurgeon_request = build_request(doc, ssurgeon_edits)\n        ssurgeon_response = send_ssurgeon_request(ssurgeon_request)\n        updated_doc = convert_response_to_doc(doc, ssurgeon_response, args.add_missing_text)\n        if output is not None:\n            with open(output, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(\"{:C}\\n\\n\".format(updated_doc))\n        if args.stdout:\n            print(\"{:C}\\n\".format(updated_doc))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/tokensregex.py",
    "content": "\"\"\"Invokes the Java tokensregex on a document\n\nThis operates tokensregex on docs processed with stanza models.\n\nhttps://nlp.stanford.edu/software/tokensregex.html\n\nA minimal example is the main method of this module.\n\"\"\"\n\nimport stanza\n\nfrom stanza.protobuf import TokensRegexRequest, TokensRegexResponse\nfrom stanza.server.java_protobuf_requests import send_request, add_sentence\n\ndef send_tokensregex_request(request):\n    return send_request(request, TokensRegexResponse,\n                        \"edu.stanford.nlp.ling.tokensregex.ProcessTokensRegexRequest\")\n\ndef process_doc(doc, *patterns):\n    request = TokensRegexRequest()\n    for pattern in patterns:\n        request.pattern.append(pattern)\n\n    request_doc = request.doc\n    request_doc.text = doc.text\n    num_tokens = 0\n    for sentence in doc.sentences:\n        add_sentence(request_doc.sentence, sentence, num_tokens)\n        num_tokens = num_tokens + sum(len(token.words) for token in sentence.tokens)\n\n    return send_tokensregex_request(request)\n\ndef main():\n    #nlp = stanza.Pipeline('en',\n    #                      processors='tokenize,pos,lemma,ner')\n    nlp = stanza.Pipeline('en',\n                          processors='tokenize')\n\n    doc = nlp('Uro ruined modern.  Fortunately, Wotc banned him')\n    print(process_doc(doc, \"him\", \"ruined\"))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/tsurgeon.py",
    "content": "\"\"\"Invokes the Java tsurgeon on a list of trees\n\nIncluded with CoreNLP is a mechanism for modifying trees based on\nexisting patterns within a tree.  The patterns are found using tregex:\n\nhttps://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/trees/tregex/TregexPattern.html\n\nThe modifications are then performed using tsurgeon:\n\nhttps://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/trees/tregex/tsurgeon/Tsurgeon.html\n\nThis module accepts Tree objects as produced by the conparser and\nreturns the modified trees that result from one or more tsurgeon\noperations.\n\"\"\"\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.protobuf import TsurgeonRequest, TsurgeonResponse\nfrom stanza.server.java_protobuf_requests import send_request, build_tree, from_tree, JavaProtobufContext\n\nTSURGEON_JAVA = \"edu.stanford.nlp.trees.tregex.tsurgeon.ProcessTsurgeonRequest\"\n\ndef send_tsurgeon_request(request):\n    return send_request(request, TsurgeonResponse, TSURGEON_JAVA)\n\n\ndef build_request(trees, operations):\n    \"\"\"\n    Build the TsurgeonRequest object\n\n    trees: a list of trees\n    operations: a list of (tregex, tsurgeon, tsurgeon, ...)\n    \"\"\"\n    if isinstance(trees, Tree):\n        trees = (trees,)\n\n    request = TsurgeonRequest()\n    for tree in trees:\n        request.trees.append(build_tree(tree, 0.0))\n    if all(isinstance(x, str) for x in operations):\n        operations = (operations,)\n    for operation in operations:\n        if len(operation) == 1:\n            raise ValueError(\"Expected [tregex, tsurgeon, ...] 
but just got a tregex\")\n        operation_request = request.operations.add()\n        operation_request.tregex = operation[0]\n        for tsurgeon in operation[1:]:\n            operation_request.tsurgeon.append(tsurgeon)\n    return request\n\n\ndef process_trees(trees, *operations):\n    \"\"\"\n    Returns the result of processing the given tsurgeon operations on the given trees\n\n    Returns a list of modified trees, i.e., the response is already converted back to Tree objects\n    \"\"\"\n    request = build_request(trees, operations)\n    result = send_tsurgeon_request(request)\n\n    return [from_tree(t)[0] for t in result.trees]\n\n\nclass Tsurgeon(JavaProtobufContext):\n    \"\"\"\n    Tsurgeon context window\n\n    This is a context window which keeps a process open.  Should allow\n    for multiple requests without launching new java processes each time.\n    \"\"\"\n    def __init__(self, classpath=None):\n        super(Tsurgeon, self).__init__(classpath, TsurgeonResponse, TSURGEON_JAVA)\n\n    def process(self, trees, *operations):\n        request = build_request(trees, operations)\n        result = self.process_request(request)\n        return [from_tree(t)[0] for t in result.trees]\n\n\ndef main():\n    \"\"\"\n    A small demonstration of a tsurgeon operation\n    \"\"\"\n    text = \"( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n    trees = tree_reader.read_trees(text)\n\n    tregex = \"WP=wp\"\n    tsurgeon = \"relabel wp WWWPPP\"\n\n    result = process_trees(trees, (tregex, tsurgeon))\n    print(result)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/server/ud_enhancer.py",
    "content": "\n\nimport stanza\nfrom stanza.protobuf import DependencyEnhancerRequest, Document, Language\nfrom stanza.server.java_protobuf_requests import send_request, add_sentence, JavaProtobufContext\n\nENHANCER_JAVA = \"edu.stanford.nlp.trees.ud.ProcessUniversalEnhancerRequest\"\n\ndef build_enhancer_request(doc, language, pronouns_pattern):\n    if bool(language) == bool(pronouns_pattern):\n        raise ValueError(\"Should set exactly one of language and pronouns_pattern\")\n\n    request = DependencyEnhancerRequest()\n    if pronouns_pattern:\n        # python protobuf messages have no Java-style setters, so assign the field directly\n        # (field name relativePronouns is assumed from the setter name)\n        request.relativePronouns = pronouns_pattern\n    elif language.lower() in (\"en\", \"english\"):\n        request.language = Language.UniversalEnglish\n    elif language.lower() in (\"zh\", \"zh-hans\", \"chinese\"):\n        request.language = Language.UniversalChinese\n    else:\n        raise ValueError(\"Sorry, but language \" + language + \" is not supported yet.  Either set a pronouns pattern or file an issue at https://stanfordnlp.github.io/stanza suggesting a mechanism for converting this language\")\n\n    request_doc = request.document\n    request_doc.text = doc.text\n    num_tokens = 0\n    for sent_idx, sentence in enumerate(doc.sentences):\n        request_sentence = add_sentence(request_doc.sentence, sentence, num_tokens)\n        num_tokens = num_tokens + sum(len(token.words) for token in sentence.tokens)\n\n        graph = request_sentence.basicDependencies\n        nodes = []\n        word_index = 0\n        for token in sentence.tokens:\n            for word in token.words:\n                # TODO: refactor with the bit in java_protobuf_requests\n                word_index = word_index + 1\n                node = graph.node.add()\n                node.sentenceIndex = sent_idx\n                node.index = word_index\n\n                if word.head != 0:\n                    edge = graph.edge.add()\n                    edge.source = word.head\n                    edge.target = 
word_index\n                    edge.dep = word.deprel\n\n    return request\n\ndef process_doc(doc, language=None, pronouns_pattern=None):\n    request = build_enhancer_request(doc, language, pronouns_pattern)\n    return send_request(request, Document, ENHANCER_JAVA)\n\nclass UniversalEnhancer(JavaProtobufContext):\n    \"\"\"\n    UniversalEnhancer context window\n\n    This is a context window which keeps a process open.  Should allow\n    for multiple requests without launching new java processes each time.\n    \"\"\"\n    def __init__(self, language=None, pronouns_pattern=None, classpath=None):\n        super(UniversalEnhancer, self).__init__(classpath, Document, ENHANCER_JAVA)\n        if bool(language) == bool(pronouns_pattern):\n            raise ValueError(\"Should set exactly one of language and pronouns_pattern\")\n        self.language = language\n        self.pronouns_pattern = pronouns_pattern\n\n    def process(self, doc):\n        request = build_enhancer_request(doc, self.language, self.pronouns_pattern)\n        return self.process_request(request)\n\ndef main():\n    nlp = stanza.Pipeline('en',\n                          processors='tokenize,pos,lemma,depparse')\n\n    with UniversalEnhancer(language=\"en\") as enhancer:\n        doc = nlp(\"This is the car that I bought\")\n        result = enhancer.process(doc)\n        print(result.sentence[0].enhancedDependencies)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/tests/__init__.py",
    "content": "\"\"\"\nUtilities for testing\n\"\"\"\n\nimport os\nimport re\n\nfrom platformdirs import user_cache_dir\n\nfrom stanza import __resources_version__\n\n# Environment Variables\n# set this to specify working directory of tests\nTEST_HOME_VAR = 'STANZA_TEST_HOME'\n\n# Global Variables\nTEST_DIR_BASE_NAME = 'stanza_test'\n\nTEST_WORKING_DIR = os.getenv(TEST_HOME_VAR, None)\nif not TEST_WORKING_DIR:\n    TEST_WORKING_DIR = user_cache_dir(TEST_DIR_BASE_NAME, 'StanfordNLP', __resources_version__)\n\nTEST_MODELS_DIR = f'{TEST_WORKING_DIR}/models'\nTEST_CORENLP_DIR = f'{TEST_WORKING_DIR}/corenlp_dir'\n\n# server resources\nSERVER_TEST_PROPS = f'{TEST_WORKING_DIR}/scripts/external_server.properties'\n\n# language resources\nLANGUAGE_RESOURCES = {}\n\nTOKENIZE_MODEL = 'tokenizer.pt'\nMWT_MODEL = 'mwt_expander.pt'\nPOS_MODEL = 'tagger.pt'\nPOS_PRETRAIN = 'pretrain.pt'\nLEMMA_MODEL = 'lemmatizer.pt'\nDEPPARSE_MODEL = 'parser.pt'\nDEPPARSE_PRETRAIN = 'pretrain.pt'\n\nMODEL_FILES = [TOKENIZE_MODEL, MWT_MODEL, POS_MODEL, POS_PRETRAIN, LEMMA_MODEL, DEPPARSE_MODEL, DEPPARSE_PRETRAIN]\n\n# English resources\nEN_KEY = 'en'\nEN_SHORTHAND = 'en_ewt'\n# models\nEN_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{EN_SHORTHAND}_models'\nEN_MODEL_FILES = [f'{EN_MODELS_DIR}/{EN_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES]\n\n# French resources\nFR_KEY = 'fr'\nFR_SHORTHAND = 'fr_gsd'\n# regression file paths\nFR_TEST_IN = f'{TEST_WORKING_DIR}/in/fr_gsd.test.txt'\nFR_TEST_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out'\nFR_TEST_GOLD_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out.gold'\n# models\nFR_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{FR_SHORTHAND}_models'\nFR_MODEL_FILES = [f'{FR_MODELS_DIR}/{FR_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES]\n\n# Other language resources\nAR_SHORTHAND = 'ar_padt'\nDE_SHORTHAND = 'de_gsd'\nKK_SHORTHAND = 'kk_ktb'\nKO_SHORTHAND = 'ko_gsd'\n\n\n# utils for clean up\n# only allow removal of dirs/files in this 
approved list\nREMOVABLE_PATHS = ['en_ewt_models', 'en_ewt_tokenizer.pt', 'en_ewt_mwt_expander.pt', 'en_ewt_tagger.pt',\n                   'en_ewt.pretrain.pt', 'en_ewt_lemmatizer.pt', 'en_ewt_parser.pt', 'fr_gsd_models',\n                   'fr_gsd_tokenizer.pt', 'fr_gsd_mwt_expander.pt', 'fr_gsd_tagger.pt', 'fr_gsd.pretrain.pt',\n                   'fr_gsd_lemmatizer.pt', 'fr_gsd_parser.pt', 'ar_padt_models', 'ar_padt_tokenizer.pt',\n                   'ar_padt_mwt_expander.pt', 'ar_padt_tagger.pt', 'ar_padt.pretrain.pt', 'ar_padt_lemmatizer.pt',\n                   'ar_padt_parser.pt', 'de_gsd_models', 'de_gsd_tokenizer.pt', 'de_gsd_mwt_expander.pt',\n                   'de_gsd_tagger.pt', 'de_gsd.pretrain.pt', 'de_gsd_lemmatizer.pt', 'de_gsd_parser.pt',\n                   'kk_ktb_models', 'kk_ktb_tokenizer.pt', 'kk_ktb_mwt_expander.pt', 'kk_ktb_tagger.pt',\n                   'kk_ktb.pretrain.pt', 'kk_ktb_lemmatizer.pt', 'kk_ktb_parser.pt', 'ko_gsd_models',\n                   'ko_gsd_tokenizer.pt', 'ko_gsd_mwt_expander.pt', 'ko_gsd_tagger.pt', 'ko_gsd.pretrain.pt',\n                   'ko_gsd_lemmatizer.pt', 'ko_gsd_parser.pt']\n\n\ndef safe_rm(path_to_rm):\n    \"\"\"\n    Safely remove a directory of files or a file\n    1.) check path exists, files are files, dirs are dirs\n    2.) only remove things on approved list REMOVABLE_PATHS\n    3.) 
assert no longer exists\n    \"\"\"\n    # just return if path doesn't exist\n    if not os.path.exists(path_to_rm):\n        return\n    # handle directory\n    if os.path.isdir(path_to_rm):\n        files_to_rm = [f'{path_to_rm}/{fname}' for fname in os.listdir(path_to_rm)]\n        dir_to_rm = path_to_rm\n    else:\n        files_to_rm = [path_to_rm]\n        dir_to_rm = None\n    # clear out files\n    for file_to_rm in files_to_rm:\n        if os.path.isfile(file_to_rm) and os.path.basename(file_to_rm) in REMOVABLE_PATHS:\n            os.remove(file_to_rm)\n            assert not os.path.exists(file_to_rm), f'Error removing: {file_to_rm}'\n    # clear out directory\n    if dir_to_rm is not None and os.path.isdir(dir_to_rm):\n        os.rmdir(dir_to_rm)\n        assert not os.path.exists(dir_to_rm), f'Error removing: {dir_to_rm}'\n\ndef compare_ignoring_whitespace(predicted, expected):\n    \"\"\"\n    Assert that two strings are equal after stripping the ends,\n    collapsing runs of spaces and tabs, and normalizing Windows line endings\n    \"\"\"\n    predicted = re.sub('[ \\t]+', ' ', predicted.strip())\n    predicted = re.sub('\\r\\n', '\\n', predicted)\n    expected = re.sub('[ \\t]+', ' ', expected.strip())\n    expected = re.sub('\\r\\n', '\\n', expected)\n    assert predicted == expected\n\n"
  },
  {
    "path": "stanza/tests/classifiers/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/classifiers/test_classifier.py",
    "content": "import glob\nimport os\n\nimport pytest\n\nimport numpy as np\nimport torch\n\nimport stanza\nimport stanza.models.classifier as classifier\nimport stanza.models.classifiers.data as data\nfrom stanza.models.classifiers.trainer import Trainer\nfrom stanza.models.common import pretrain\nfrom stanza.models.common import utils\n\nfrom stanza.tests import TEST_MODELS_DIR\nfrom stanza.tests.classifiers.test_data import train_file, dev_file, test_file, DATASET, SENTENCES\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nEMB_DIM = 5\n\n@pytest.fixture(scope=\"module\")\ndef fake_embeddings(tmp_path_factory):\n    \"\"\"\n    will return a path to a fake embeddings file with the words in SENTENCES\n    \"\"\"\n    # could set np random seed here\n    words = sorted(set([x.lower() for y in SENTENCES for x in y]))\n    words = words[:-1]\n    embedding_dir = tmp_path_factory.mktemp(\"data\")\n    embedding_txt = embedding_dir / \"embedding.txt\"\n    embedding_pt  = embedding_dir / \"embedding.pt\"\n    embedding = np.random.random((len(words), EMB_DIM))\n\n    with open(embedding_txt, \"w\", encoding=\"utf-8\") as fout:\n        for word, emb in zip(words, embedding):\n            fout.write(word)\n            fout.write(\"\\t\")\n            fout.write(\"\\t\".join(str(x) for x in emb))\n            fout.write(\"\\n\")\n\n    pt = pretrain.Pretrain(str(embedding_pt), str(embedding_txt))\n    pt.load()\n    assert os.path.exists(embedding_pt)\n    return embedding_pt\n\nclass TestClassifier:\n    def build_model(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):\n        \"\"\"\n        Build a model to be used by one of the later tests\n        \"\"\"\n        save_dir = str(tmp_path / \"classifier\")\n        save_name = \"model.pt\"\n        args = [\"--save_dir\", save_dir,\n                \"--save_name\", save_name,\n                \"--wordvec_pretrain_file\", str(fake_embeddings),\n               
 \"--filter_channels\", \"20\",\n                \"--fc_shapes\", \"20,10\",\n                \"--train_file\", str(train_file),\n                \"--dev_file\", str(dev_file),\n                \"--max_epochs\", \"2\",\n                \"--batch_size\", \"60\"]\n        if extra_args is not None:\n            args = args + extra_args\n        args = classifier.parse_args(args)\n        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)\n        if checkpoint_file:\n            trainer = Trainer.load(checkpoint_file, args, load_optimizer=True)\n        else:\n            trainer = Trainer.build_new_model(args, train_set)\n        return trainer, train_set, args\n\n    def run_training(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):\n        \"\"\"\n        Iterate a couple times over a model\n        \"\"\"\n        trainer, train_set, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args, checkpoint_file)\n        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)\n        labels = data.dataset_labels(train_set)\n\n        save_filename = os.path.join(args.save_dir, args.save_name)\n        if checkpoint_file is None:\n            checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)\n        classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)\n        return trainer, save_filename, checkpoint_file\n\n    def test_build_model(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test that building a basic model works\n        \"\"\"\n        self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\"])\n\n    def test_save_load(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test that a basic model can save & load\n        
\"\"\"\n        trainer, _, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\"])\n\n        save_filename = os.path.join(args.save_dir, args.save_name)\n        trainer.save(save_filename)\n\n        args.load_name = args.save_name\n        trainer = Trainer.load(args.load_name, args)\n        args.load_name = save_filename\n        trainer = Trainer.load(args.load_name, args)\n\n    def test_train_basic(self, tmp_path, fake_embeddings, train_file, dev_file):\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\"])\n\n    def test_train_bilstm(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test w/ and w/o bilstm variations of the classifier\n        \"\"\"\n        args = [\"--bilstm\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n        args = [\"--no_bilstm\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n    def test_train_maxpool_width(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test various maxpool widths\n\n        Also sets --filter_channels to a multiple of 2 but not of 3 for\n        the test to make sure the math is done correctly on a non-divisible width\n        \"\"\"\n        args = [\"--maxpool_width\", \"1\", \"--filter_channels\", \"20\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n        args = [\"--maxpool_width\", \"2\", \"--filter_channels\", \"20\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n        args = [\"--maxpool_width\", \"3\", \"--filter_channels\", \"20\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n    
def test_train_conv_2d(self, tmp_path, fake_embeddings, train_file, dev_file):\n        args = [\"--filter_sizes\", \"(3,4,5)\", \"--filter_channels\", \"20\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n        args = [\"--filter_sizes\", \"((3,2),)\", \"--filter_channels\", \"20\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n        args = [\"--filter_sizes\", \"((3,2),3)\", \"--filter_channels\", \"20\", \"--bilstm_hidden_dim\", \"20\"]\n        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n\n    def test_train_filter_channels(self, tmp_path, fake_embeddings, train_file, dev_file):\n        args = [\"--filter_sizes\", \"((3,2),3)\", \"--filter_channels\", \"20\", \"--no_bilstm\"]\n        trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n        assert trainer.model.fc_input_size == 40\n\n        args = [\"--filter_sizes\", \"((3,2),3)\", \"--filter_channels\", \"15,20\", \"--no_bilstm\"]\n        trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)\n        # 50 = 2x15 for the 2d conv (over 5 dim embeddings) + 20\n        assert trainer.model.fc_input_size == 50\n\n    def test_train_bert(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test on a tiny Bert WITHOUT finetuning, which hopefully does not take up too much disk space or memory\n        \"\"\"\n        bert_model = \"hf-internal-testing/tiny-bert\"\n\n        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model])\n        assert os.path.exists(save_filename)\n        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)\n        # check that the bert model wasn't saved as 
part of the classifier\n        assert not saved_model['params']['config']['force_bert_saved']\n        assert not any(x.startswith(\"bert_model\") for x in saved_model['params']['model'].keys())\n\n    def test_finetune_bert(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory\n        \"\"\"\n        bert_model = \"hf-internal-testing/tiny-bert\"\n\n        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model, \"--bert_finetune\"])\n        assert os.path.exists(save_filename)\n        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)\n        # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer\n        assert saved_model['params']['config']['force_bert_saved']\n        assert any(x.startswith(\"bert_model\") for x in saved_model['params']['model'].keys())\n\n    def test_finetune_bert_layers(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory, using 2 layers\n\n        As an added bonus (or eager test), load the finished model and continue\n        training from there.  
Then check that the initial model and\n        the middle model are different, then that the middle model and\n        final model are different\n\n        \"\"\"\n        bert_model = \"hf-internal-testing/tiny-bert\"\n\n        trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model, \"--bert_finetune\", \"--bert_hidden_layers\", \"2\", \"--save_intermediate_models\"])\n        assert os.path.exists(save_filename)\n\n        save_path = os.path.split(save_filename)[0]\n\n        initial_model = glob.glob(os.path.join(save_path, \"*E0000*\"))\n        assert len(initial_model) == 1\n        initial_model = initial_model[0]\n        initial_model = torch.load(initial_model, lambda storage, loc: storage, weights_only=True)\n\n        second_model_file = glob.glob(os.path.join(save_path, \"*E0002*\"))\n        assert len(second_model_file) == 1\n        second_model_file = second_model_file[0]\n        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)\n\n        for layer_idx in range(2):\n            bert_names = [x for x in second_model['params']['model'].keys() if x.startswith(\"bert_model\") and \"layer.%d.\" % layer_idx in x]\n            assert len(bert_names) > 0\n            assert all(x in initial_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)\n            assert not all(torch.allclose(initial_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)\n\n        # put some random marker in the file to look for later,\n        # check the continued training didn't clobber the expected file\n        assert \"asdf\" not in second_model\n        second_model[\"asdf\"] = 1234\n        torch.save(second_model, second_model_file)\n\n        trainer, save_filename, checkpoint_file = self.run_training(tmp_path, 
fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model, \"--bert_finetune\", \"--bert_hidden_layers\", \"2\", \"--save_intermediate_models\", \"--max_epochs\", \"5\"], checkpoint_file=checkpoint_file)\n\n        second_model_file_redo = glob.glob(os.path.join(save_path, \"*E0002*\"))\n        assert len(second_model_file_redo) == 1\n        assert second_model_file == second_model_file_redo[0]\n        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)\n        assert \"asdf\" in second_model\n\n        fifth_model_file = glob.glob(os.path.join(save_path, \"*E0005*\"))\n        assert len(fifth_model_file) == 1\n\n        final_model = torch.load(fifth_model_file[0], lambda storage, loc: storage, weights_only=True)\n        for layer_idx in range(2):\n            bert_names = [x for x in final_model['params']['model'].keys() if x.startswith(\"bert_model\") and \"layer.%d.\" % layer_idx in x]\n            assert len(bert_names) > 0\n            assert all(x in final_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)\n            assert not all(torch.allclose(final_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)\n\n    def test_finetune_peft(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test on a tiny Bert with PEFT finetuning\n        \"\"\"\n        bert_model = \"hf-internal-testing/tiny-bert\"\n\n        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model, \"--bert_finetune\", \"--use_peft\", \"--lora_modules_to_save\", \"pooler\"])\n        assert os.path.exists(save_filename)\n        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)\n        # after finetuning the bert model, make sure that 
the save file DOES contain parts of the transformer, but only in peft form\n        assert saved_model['params']['config']['bert_model'] == bert_model\n        assert saved_model['params']['config']['force_bert_saved']\n        assert saved_model['params']['config']['use_peft']\n\n        assert not saved_model['params']['config']['has_charlm_forward']\n        assert not saved_model['params']['config']['has_charlm_backward']\n\n        assert len(saved_model['params']['bert_lora']) > 0\n        assert any(x.find(\".pooler.\") >= 0 for x in saved_model['params']['bert_lora'])\n        assert any(x.find(\".encoder.\") >= 0 for x in saved_model['params']['bert_lora'])\n        assert not any(x.startswith(\"bert_model\") for x in saved_model['params']['model'].keys())\n\n        # The Pipeline should load and run a PEFT trained model,\n        # although obviously we don't expect the results to do\n        # anything correct\n        pipeline = stanza.Pipeline(\"en\", download_method=None, model_dir=TEST_MODELS_DIR, processors=\"tokenize,sentiment\", sentiment_model_path=save_filename, sentiment_pretrain_path=str(fake_embeddings))\n        doc = pipeline(\"This is a test\")\n\n    def test_finetune_peft_restart(self, tmp_path, fake_embeddings, train_file, dev_file):\n        \"\"\"\n        Test that if we restart training on a peft model, the peft weights change\n        \"\"\"\n        bert_model = \"hf-internal-testing/tiny-bert\"\n\n        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model, \"--bert_finetune\", \"--use_peft\", \"--lora_modules_to_save\", \"pooler\", \"--save_intermediate_models\"])\n\n        assert os.path.exists(save_file)\n        saved_model = torch.load(save_file, lambda storage, loc: storage, weights_only=True)\n        assert any(x.find(\".encoder.\") >= 0 for x in saved_model['params']['bert_lora'])\n\n\n        
trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=[\"--bilstm_hidden_dim\", \"20\", \"--bert_model\", bert_model, \"--bert_finetune\", \"--use_peft\", \"--lora_modules_to_save\", \"pooler\", \"--save_intermediate_models\", \"--max_epochs\", \"5\"], checkpoint_file=checkpoint_file)\n\n        save_path = os.path.split(save_file)[0]\n\n        initial_model_file = glob.glob(os.path.join(save_path, \"*E0000*\"))\n        assert len(initial_model_file) == 1\n        initial_model_file = initial_model_file[0]\n        initial_model = torch.load(initial_model_file, lambda storage, loc: storage, weights_only=True)\n\n        second_model_file = glob.glob(os.path.join(save_path, \"*E0002*\"))\n        assert len(second_model_file) == 1\n        second_model_file = second_model_file[0]\n        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)\n\n        final_model_file = glob.glob(os.path.join(save_path, \"*E0005*\"))\n        assert len(final_model_file) == 1\n        final_model_file = final_model_file[0]\n        final_model = torch.load(final_model_file, lambda storage, loc: storage, weights_only=True)\n\n        # params in initial_model & second_model start with \"base_model.model.\"\n        # whereas params in final_model start directly with \"encoder\" or \"pooler\"\n        initial_lora = initial_model['params']['bert_lora']\n        second_lora = second_model['params']['bert_lora']\n        final_lora = final_model['params']['bert_lora']\n        for side in (\"_A.\", \"_B.\"):\n            for layer in (\".0.\", \".1.\"):\n                initial_params = sorted([x for x in initial_lora if x.find(\".encoder.\") > 0 and x.find(side) > 0 and x.find(layer) > 0])\n                second_params = sorted([x for x in second_lora if x.find(\".encoder.\") > 0 and x.find(side) > 0 and x.find(layer) > 0])\n                final_params = sorted([x for x in 
final_lora if x.startswith(\"encoder.\") and x.find(side) > 0 and x.find(layer) > 0])\n                assert len(initial_params) > 0\n                assert len(initial_params) == len(second_params)\n                assert len(initial_params) == len(final_params)\n                for x, y in zip(second_params, final_params):\n                    assert x.endswith(y)\n                    if side != \"_A.\":  # the A tensors don't move very much, if at all\n                        assert not torch.allclose(initial_lora.get(x), second_lora.get(x))\n                        assert not torch.allclose(second_lora.get(x), final_lora.get(y))\n\n"
  },
  {
    "path": "stanza/tests/classifiers/test_constituency_classifier.py",
    "content": "import os\n\nimport pytest\n\nimport stanza\nimport stanza.models.classifier as classifier\nimport stanza.models.classifiers.data as data\nfrom stanza.models.classifiers.trainer import Trainer\nfrom stanza.tests import TEST_MODELS_DIR\nfrom stanza.tests.classifiers.test_classifier import fake_embeddings\nfrom stanza.tests.classifiers.test_data import train_file_with_trees, dev_file_with_trees\nfrom stanza.models.common import utils\nfrom stanza.tests.constituency.test_trainer import build_trainer, TREEBANK\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nclass TestConstituencyClassifier:\n    @pytest.fixture(scope=\"class\")\n    def constituency_model(self, fake_embeddings, tmp_path_factory):\n        args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']\n        trainer = build_trainer(str(fake_embeddings), *args, treebank=TREEBANK)\n\n        trainer_pt = str(tmp_path_factory.mktemp(\"constituency\") / \"constituency.pt\")\n        trainer.save(trainer_pt, save_optimizer=False)\n        return trainer_pt\n\n    def build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):\n        \"\"\"\n        Build a Constituency Classifier model to be used by one of the later tests\n        \"\"\"\n        save_dir = str(tmp_path / \"classifier\")\n        save_name = \"model.pt\"\n        args = [\"--save_dir\", save_dir,\n                \"--save_name\", save_name,\n                \"--model_type\", \"constituency\",\n                \"--constituency_model\", constituency_model,\n                \"--wordvec_pretrain_file\", str(fake_embeddings),\n                \"--fc_shapes\", \"20,10\",\n                \"--train_file\", str(train_file_with_trees),\n                \"--dev_file\", str(dev_file_with_trees),\n                \"--max_epochs\", \"2\",\n                \"--batch_size\", \"60\"]\n        if extra_args is 
not None:\n            args = args + extra_args\n        args = classifier.parse_args(args)\n        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)\n        trainer = Trainer.build_new_model(args, train_set)\n        return trainer, train_set, args\n\n    def run_training(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):\n        \"\"\"\n        Iterate a couple times over a model\n        \"\"\"\n        trainer, train_set, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args)\n        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)\n        labels = data.dataset_labels(train_set)\n\n        save_filename = os.path.join(args.save_dir, args.save_name)\n        checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)\n        classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)\n        return trainer, train_set, args\n\n    def test_build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        \"\"\"\n        Test that building a basic constituency-based model works\n        \"\"\"\n        self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)\n\n    def test_save_load(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        \"\"\"\n        Test that a constituency model can save & load\n        \"\"\"\n        trainer, _, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)\n\n        save_filename = os.path.join(args.save_dir, args.save_name)\n        trainer.save(save_filename)\n\n        args.load_name = args.save_name\n        trainer = 
Trainer.load(args.load_name, args)\n\n    def test_train_basic(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)\n\n    def test_train_pipeline(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        \"\"\"\n        Test that writing out a temp model, then loading it in the pipeline is a thing that works\n        \"\"\"\n        trainer, _, args = self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)\n        save_filename = os.path.join(args.save_dir, args.save_name)\n        assert os.path.exists(save_filename)\n        assert os.path.exists(args.constituency_model)\n\n        pipeline_args = {\"lang\": \"en\",\n                         \"download_method\": None,\n                         \"model_dir\": TEST_MODELS_DIR,\n                         \"processors\": \"tokenize,pos,constituency,sentiment\",\n                         \"tokenize_pretokenized\": True,\n                         \"constituency_model_path\": args.constituency_model,\n                         \"constituency_pretrain_path\": args.wordvec_pretrain_file,\n                         \"constituency_backward_charlm_path\": None,\n                         \"constituency_forward_charlm_path\": None,\n                         \"sentiment_model_path\": save_filename,\n                         \"sentiment_pretrain_path\": args.wordvec_pretrain_file,\n                         \"sentiment_backward_charlm_path\": None,\n                         \"sentiment_forward_charlm_path\": None}\n        pipeline = stanza.Pipeline(**pipeline_args)\n        doc = pipeline(\"This is a test\")\n        # since the model is random, we have no expectations for what the result actually is\n        assert doc.sentences[0].sentiment is not None\n\n\n    def 
test_train_all_words(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_all_words'])\n\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_all_words'])\n\n    def test_train_top_layer(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_top_layer'])\n\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_top_layer'])\n\n    def test_train_attn(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--no_constituency_all_words'])\n\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--constituency_all_words'])\n\n        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_node_attn'])\n\n"
  },
  {
    "path": "stanza/tests/classifiers/test_data.py",
    "content": "import json\nimport pytest\n\nimport stanza.models.classifiers.data as data\nfrom stanza.models.classifiers.utils import WVType\nfrom stanza.models.common.vocab import PAD, UNK\nfrom stanza.models.constituency.parse_tree import Tree\n\nSENTENCES = [\n    [\"I\", \"hate\", \"the\", \"Opal\", \"banning\"],\n    [\"Tell\", \"my\", \"wife\", \"hello\"], # obviously this is the neutral result\n    [\"I\", \"like\", \"Sh'reyan\", \"'s\", \"antennae\"],\n]\n\nDATASET = [\n    {\"sentiment\": \"0\", \"text\": SENTENCES[0]},\n    {\"sentiment\": \"1\", \"text\": SENTENCES[1]},\n    {\"sentiment\": \"2\", \"text\": SENTENCES[2]},\n]\n\nTREES = [\n    \"(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))\",\n    \"(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))\",\n    \"(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))\",\n]\n\nDATASET_WITH_TREES = [\n    {\"sentiment\": \"0\", \"text\": SENTENCES[0], \"constituency\": TREES[0]},\n    {\"sentiment\": \"1\", \"text\": SENTENCES[1], \"constituency\": TREES[1]},\n    {\"sentiment\": \"2\", \"text\": SENTENCES[2], \"constituency\": TREES[2]},\n]\n\n@pytest.fixture(scope=\"module\")\ndef train_file(tmp_path_factory):\n    train_set = DATASET * 20\n    train_filename = tmp_path_factory.mktemp(\"data\") / \"train.json\"\n    with open(train_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(train_set, fout, ensure_ascii=False)\n    return train_filename\n\n@pytest.fixture(scope=\"module\")\ndef dev_file(tmp_path_factory):\n    dev_set = DATASET * 2\n    dev_filename = tmp_path_factory.mktemp(\"data\") / \"dev.json\"\n    with open(dev_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(dev_set, fout, ensure_ascii=False)\n    return dev_filename\n\n@pytest.fixture(scope=\"module\")\ndef test_file(tmp_path_factory):\n    test_set = DATASET\n    test_filename = tmp_path_factory.mktemp(\"data\") / 
\"test.json\"\n    with open(test_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(test_set, fout, ensure_ascii=False)\n    return test_filename\n\n@pytest.fixture(scope=\"module\")\ndef train_file_with_trees(tmp_path_factory):\n    train_set = DATASET_WITH_TREES * 20\n    train_filename = tmp_path_factory.mktemp(\"data\") / \"train_trees.json\"\n    with open(train_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(train_set, fout, ensure_ascii=False)\n    return train_filename\n\n@pytest.fixture(scope=\"module\")\ndef dev_file_with_trees(tmp_path_factory):\n    dev_set = DATASET_WITH_TREES * 2\n    dev_filename = tmp_path_factory.mktemp(\"data\") / \"dev_trees.json\"\n    with open(dev_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(dev_set, fout, ensure_ascii=False)\n    return dev_filename\n\nclass TestClassifierData:\n    def test_read_data(self, train_file):\n        \"\"\"\n        Test reading of the json format\n        \"\"\"\n        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)\n        assert len(train_set) == 60\n\n    def test_read_data_with_trees(self, train_file, train_file_with_trees):\n        \"\"\"\n        Test reading of the json format\n        \"\"\"\n        train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)\n        assert len(train_trees_set) == 60\n        for idx, x in enumerate(train_trees_set):\n            assert isinstance(x.constituency, Tree)\n            assert str(x.constituency) == TREES[idx % len(TREES)]\n\n        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)\n\n    def test_dataset_vocab(self, train_file):\n        \"\"\"\n        Converting a dataset to vocab should have a specific set of words along with PAD and UNK\n        \"\"\"\n        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)\n        vocab = data.dataset_vocab(train_set)\n        expected = set([PAD, UNK] + [x.lower() for y in 
SENTENCES for x in y])\n        assert set(vocab) == expected\n\n    def test_dataset_labels(self, train_file):\n        \"\"\"\n        Test the extraction of labels from a dataset\n        \"\"\"\n        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)\n        labels = data.dataset_labels(train_set)\n        assert labels == [\"0\", \"1\", \"2\"]\n\n    def test_sort_by_length(self, train_file):\n        \"\"\"\n        There are two unique lengths in the toy dataset\n        \"\"\"\n        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)\n        sorted_dataset = data.sort_dataset_by_len(train_set)\n        assert list(sorted_dataset.keys()) == [4, 5]\n        assert len(sorted_dataset[4]) == len(train_set) // 3\n        assert len(sorted_dataset[5]) == 2 * len(train_set) // 3\n\n    def test_check_labels(self, train_file):\n        \"\"\"\n        Check that an exception is thrown for an unknown label\n        \"\"\"\n        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)\n        labels = sorted(set([x[\"sentiment\"] for x in DATASET]))\n        assert len(labels) > 1\n        data.check_labels(labels, train_set)\n        with pytest.raises(RuntimeError):\n            data.check_labels(labels[:1], train_set)\n\n"
  },
  {
    "path": "stanza/tests/classifiers/test_process_utils.py",
    "content": "\"\"\"\nA few tests of the utils module for the sentiment datasets\n\"\"\"\n\nimport os\nimport pytest\n\nimport stanza\n\nfrom stanza.models.classifiers import data\nfrom stanza.models.classifiers.data import SentimentDatum\nfrom stanza.models.classifiers.utils import WVType\nfrom stanza.utils.datasets.sentiment import process_utils\n\nfrom stanza.tests import TEST_MODELS_DIR\nfrom stanza.tests.classifiers.test_data import train_file, dev_file, test_file\n\n\ndef test_write_list(tmp_path, train_file):\n    \"\"\"\n    Test that writing a single list of items to an output file works\n    \"\"\"\n    train_set = data.read_dataset(train_file, WVType.OTHER, 1)\n\n    dataset_file = tmp_path / \"foo.json\"\n    process_utils.write_list(dataset_file, train_set)\n\n    train_copy = data.read_dataset(dataset_file, WVType.OTHER, 1)\n    assert train_copy == train_set\n\ndef test_write_dataset(tmp_path, train_file, dev_file, test_file):\n    \"\"\"\n    Test that writing all three parts of a dataset works\n    \"\"\"\n    dataset = [data.read_dataset(filename, WVType.OTHER, 1) for filename in (train_file, dev_file, test_file)]\n    process_utils.write_dataset(dataset, tmp_path, \"en_test\")\n\n    expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json']\n    dataset_files = os.listdir(tmp_path)\n    assert sorted(dataset_files) == sorted(expected_files)\n\n    for filename, expected in zip(expected_files, dataset):\n        written = data.read_dataset(tmp_path / filename, WVType.OTHER, 1)\n        assert written == expected\n\ndef test_read_snippets(tmp_path):\n    \"\"\"\n    Test the basic operation of the read_snippets function\n    \"\"\"\n    filename = tmp_path / \"foo.csv\"\n    with open(filename, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(\"FOO\\tThis is a test\\thappy\\n\")\n        fout.write(\"FOO\\tThis is a second sentence\\tsad\\n\")\n\n    nlp = stanza.Pipeline(\"en\", dir=TEST_MODELS_DIR, 
processors=\"tokenize\", download_method=None)\n\n    mapping = {\"happy\": 0, \"sad\": 1}\n\n    snippets = process_utils.read_snippets(filename, 2, 1, \"en\", mapping, nlp=nlp)\n    assert len(snippets) == 2\n    assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),\n                        SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])]\n\ndef test_read_snippets_two_columns(tmp_path):\n    \"\"\"\n    Test what happens when multiple columns are combined for the sentiment value\n    \"\"\"\n    filename = tmp_path / \"foo.csv\"\n    with open(filename, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(\"FOO\\tThis is a test\\thappy\\tfoo\\n\")\n        fout.write(\"FOO\\tThis is a second sentence\\tsad\\tbar\\n\")\n        fout.write(\"FOO\\tThis is a third sentence\\tsad\\tfoo\\n\")\n\n    nlp = stanza.Pipeline(\"en\", dir=TEST_MODELS_DIR, processors=\"tokenize\", download_method=None)\n\n    mapping = {(\"happy\", \"foo\"): 0, (\"sad\", \"bar\"): 1, (\"sad\", \"foo\"): 2}\n\n    snippets = process_utils.read_snippets(filename, (2,3), 1, \"en\", mapping, nlp=nlp)\n    assert len(snippets) == 3\n    assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),\n                        SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']),\n                        SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])]\n\n"
  },
  {
    "path": "stanza/tests/common/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/common/test_bert_embedding.py",
    "content": "import pytest\nimport torch\n\nfrom stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nBERT_MODEL = \"hf-internal-testing/tiny-bert\"\n\n@pytest.fixture(scope=\"module\")\ndef tiny_bert():\n    m, t = load_bert(BERT_MODEL)\n    return m, t\n\ndef test_load_bert(tiny_bert):\n    \"\"\"\n    Empty method that just tests loading the bert\n    \"\"\"\n    m, t = tiny_bert\n\ndef test_run_bert(tiny_bert):\n    m, t = tiny_bert\n    device = next(m.parameters()).device\n    extract_bert_embeddings(BERT_MODEL, t, m, [[\"This\", \"is\", \"a\", \"test\"]], device, True)\n\ndef test_run_bert_empty_word(tiny_bert):\n    m, t = tiny_bert\n    device = next(m.parameters()).device\n    foo = extract_bert_embeddings(BERT_MODEL, t, m, [[\"This\", \"is\", \"-\", \"a\", \"test\"]], device, True)\n    bar = extract_bert_embeddings(BERT_MODEL, t, m, [[\"This\", \"is\", \"\", \"a\", \"test\"]], device, True)\n\n    assert len(foo) == 1\n    assert torch.allclose(foo[0], bar[0])\n"
  },
  {
    "path": "stanza/tests/common/test_char_model.py",
    "content": "\"\"\"\nCurrently tests a few configurations of files for creating a charlm vocab\n\nAlso has a skeleton test of loading & saving a charlm\n\"\"\"\n\nfrom collections import Counter\nimport glob\nimport lzma\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.models import charlm\nfrom stanza.models.common import char_model\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nfake_text_1 = \"\"\"\nUnban mox opal!\nI hate watching Peppa Pig\n\"\"\"\n\nfake_text_2 = \"\"\"\nThis is plastic cheese\n\"\"\"\n\nclass TestCharModel:\n    def test_single_file_vocab(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            sample_file = os.path.join(tempdir, \"text.txt\")\n            with open(sample_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_1)\n            vocab = char_model.build_charlm_vocab(sample_file)\n\n        for i in fake_text_1:\n            assert i in vocab\n        assert \"Q\" not in vocab\n\n    def test_single_file_xz_vocab(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            sample_file = os.path.join(tempdir, \"text.txt.xz\")\n            with lzma.open(sample_file, \"wt\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_1)\n            vocab = char_model.build_charlm_vocab(sample_file)\n\n        for i in fake_text_1:\n            assert i in vocab\n        assert \"Q\" not in vocab\n\n    def test_single_file_dir_vocab(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            sample_file = os.path.join(tempdir, \"text.txt\")\n            with open(sample_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_1)\n            vocab = char_model.build_charlm_vocab(tempdir)\n\n        for i in fake_text_1:\n            assert i in vocab\n        assert \"Q\" not in vocab\n\n    def test_multiple_files_vocab(self):\n        with 
tempfile.TemporaryDirectory() as tempdir:\n            sample_file = os.path.join(tempdir, \"t1.txt\")\n            with open(sample_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_1)\n            sample_file = os.path.join(tempdir, \"t2.txt.xz\")\n            with lzma.open(sample_file, \"wt\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_2)\n            vocab = char_model.build_charlm_vocab(tempdir)\n\n        for i in fake_text_1:\n            assert i in vocab\n        for i in fake_text_2:\n            assert i in vocab\n        assert \"Q\" not in vocab\n\n    def test_cutoff_vocab(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            sample_file = os.path.join(tempdir, \"t1.txt\")\n            with open(sample_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_1)\n            sample_file = os.path.join(tempdir, \"t2.txt.xz\")\n            with lzma.open(sample_file, \"wt\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_2)\n\n            vocab = char_model.build_charlm_vocab(tempdir, cutoff=2)\n\n        counts = Counter(fake_text_1) + Counter(fake_text_2)\n        for letter, count in counts.most_common():\n            if count < 2:\n                assert letter not in vocab\n            else:\n                assert letter in vocab\n\n    def test_build_model(self):\n        \"\"\"\n        Test the whole thing on a small dataset for an iteration or two\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tempdir:\n            eval_file = os.path.join(tempdir, \"en_test.dev.txt\")\n            with open(eval_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(fake_text_1)\n            train_file = os.path.join(tempdir, \"en_test.train.txt\")\n            with open(train_file, \"w\", encoding=\"utf-8\") as fout:\n                for i in range(1000):\n                    fout.write(fake_text_1)\n           
         fout.write(\"\\n\")\n                    fout.write(fake_text_2)\n                    fout.write(\"\\n\")\n            save_name = 'en_test.forward.pt'\n            vocab_save_name = 'en_text.vocab.pt'\n            checkpoint_save_name = 'en_text.checkpoint.pt'\n            args = ['--train_file', train_file,\n                    '--eval_file', eval_file,\n                    '--eval_steps', '0', # eval once per epoch\n                    '--epochs', '2',\n                    '--cutoff', '1',\n                    '--batch_size', '%d' % len(fake_text_1),\n                    '--shorthand', 'en_test',\n                    '--save_dir', tempdir,\n                    '--save_name', save_name,\n                    '--vocab_save_name', vocab_save_name,\n                    '--checkpoint_save_name', checkpoint_save_name]\n            args = charlm.parse_args(args)\n            charlm.train(args)\n\n            assert os.path.exists(os.path.join(tempdir, vocab_save_name))\n\n            # test that saving & loading of the model worked\n            assert os.path.exists(os.path.join(tempdir, save_name))\n            model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, save_name))\n\n            # test that saving & loading of the checkpoint worked\n            assert os.path.exists(os.path.join(tempdir, checkpoint_save_name))\n            model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, checkpoint_save_name))\n            trainer = char_model.CharacterLanguageModelTrainer.load(args, os.path.join(tempdir, checkpoint_save_name))\n\n            assert trainer.global_step > 0\n            assert trainer.epoch == 2\n\n            # quick test to verify this method works with a trained model\n            charlm.get_current_lr(trainer, args)\n\n            # test loading a vocab built by the training method...\n            vocab = charlm.load_char_vocab(os.path.join(tempdir, vocab_save_name))\n            trainer = 
char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab)\n            # ... and test the get_current_lr for an untrained model as well\n            # this test is super \"eager\"\n            assert charlm.get_current_lr(trainer, args) == args['lr0']\n\n    @pytest.fixture(scope=\"class\")\n    def english_forward(self):\n        # eg, stanza_test/models/en/forward_charlm/1billion.pt\n        models_path = os.path.join(TEST_MODELS_DIR, \"en\", \"forward_charlm\", \"*\")\n        models = glob.glob(models_path)\n        # we expect at least one English model downloaded for the tests\n        assert len(models) >= 1\n        model_file = models[0]\n        return char_model.CharacterLanguageModel.load(model_file)\n\n    @pytest.fixture(scope=\"class\")\n    def english_backward(self):\n        # eg, stanza_test/models/en/backward_charlm/1billion.pt\n        models_path = os.path.join(TEST_MODELS_DIR, \"en\", \"backward_charlm\", \"*\")\n        models = glob.glob(models_path)\n        # we expect at least one English model downloaded for the tests\n        assert len(models) >= 1\n        model_file = models[0]\n        return char_model.CharacterLanguageModel.load(model_file)\n\n    def test_load_model(self, english_forward, english_backward):\n        \"\"\"\n        Check that basic loading functions work\n        \"\"\"\n        assert english_forward.is_forward_lm\n        assert not english_backward.is_forward_lm\n\n    def test_save_load_model(self, english_forward, english_backward):\n        \"\"\"\n        Load, save, and load again\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tempdir:\n            for model in (english_forward, english_backward):\n                save_file = os.path.join(tempdir, \"resaved\", \"charlm.pt\")\n                model.save(save_file)\n                reloaded = char_model.CharacterLanguageModel.load(save_file)\n                assert model.is_forward_lm == reloaded.is_forward_lm\n"
  },
  {
    "path": "stanza/tests/common/test_chuliu_edmonds.py",
    "content": "\"\"\"\nTest some use cases of the chuliu_edmonds algorithm\n\n(currently just the tarjan implementation)\n\"\"\"\n\nimport numpy as np\nimport pytest\n\nfrom stanza.models.common.chuliu_edmonds import tarjan\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_tarjan_basic():\n    simple = np.array([0, 4, 4, 4, 0])\n    result = tarjan(simple)\n    assert result == []\n\n    simple = np.array([0, 2, 0, 4, 2, 2])\n    result = tarjan(simple)\n    assert result == []\n\ndef test_tarjan_cycle():\n    cycle_graph = np.array([0, 3, 1, 2])\n    result = tarjan(cycle_graph)\n    expected = np.array([False,  True,  True,  True])\n    assert len(result) == 1\n    np.testing.assert_array_equal(result[0], expected)\n\n    cycle_graph = np.array([0, 3, 1, 2, 5, 6, 4])\n    result = tarjan(cycle_graph)\n    assert len(result) == 2\n    expected = [np.array([False,  True,  True,  True, False, False, False]),\n                np.array([False, False, False, False,  True,  True,  True])]\n    for r, e in zip(result, expected):\n        np.testing.assert_array_equal(r, e)\n"
  },
  {
    "path": "stanza/tests/common/test_common_data.py",
    "content": "import pytest\nimport stanza\n\nfrom stanza.tests import *\nfrom stanza.models.common.data import get_augment_ratio, augment_punct\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_augment_ratio():\n    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n    should_augment = lambda x: x >= 3\n    can_augment = lambda x: x >= 4\n    # check that zero is returned if no augmentation is needed\n    # which will be the case since 2 are already satisfactory\n    assert get_augment_ratio(data, should_augment, can_augment, desired_ratio=0.1) == 0.0\n\n    # this should throw an error\n    with pytest.raises(AssertionError):\n        get_augment_ratio(data, can_augment, should_augment)\n\n    # with a desired ratio of 0.4,\n    # there are already 2 that don't need augmenting\n    # and 7 that are eligible to be augmented\n    # so 2/7 will need to be augmented\n    assert get_augment_ratio(data, should_augment, can_augment, desired_ratio=0.4) == pytest.approx(2/7)\n\ndef test_augment_punct():\n    data = [[\"Simple\", \"test\", \".\"]]\n    should_augment = lambda x: x[-1] == \".\"\n    can_augment = should_augment\n    new_data = augment_punct(data, 1.0, should_augment, can_augment)\n    assert new_data == [[\"Simple\", \"test\"]]\n"
  },
  {
    "path": "stanza/tests/common/test_confusion.py",
    "content": "\"\"\"\nTest a couple simple confusion matrices and output formats\n\"\"\"\n\nfrom collections import defaultdict\nimport pytest\n\nfrom stanza.utils.confusion import format_confusion, confusion_to_f1, confusion_to_macro_f1, confusion_to_weighted_f1\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n@pytest.fixture\ndef simple_confusion():\n    confusion = defaultdict(lambda: defaultdict(int))\n    confusion[\"B-ORG\"][\"B-ORG\"] = 1\n    confusion[\"B-ORG\"][\"B-PER\"] = 1\n    confusion[\"E-ORG\"][\"E-ORG\"] = 1\n    confusion[\"E-ORG\"][\"E-PER\"] = 1\n    confusion[\"O\"][\"O\"] = 4\n    return confusion\n\n@pytest.fixture\ndef short_confusion():\n    \"\"\"\n    Same thing, but with a short name.  This should not be sorted by entity type\n    \"\"\"\n    confusion = defaultdict(lambda: defaultdict(int))\n    confusion[\"A\"][\"B-ORG\"] = 1\n    confusion[\"B-ORG\"][\"B-PER\"] = 1\n    confusion[\"E-ORG\"][\"E-ORG\"] = 1\n    confusion[\"E-ORG\"][\"E-PER\"] = 1\n    confusion[\"O\"][\"O\"] = 4\n    return confusion\n\nEXPECTED_SIMPLE_OUTPUT = \"\"\"\n     t\\\\p      O B-ORG E-ORG B-PER E-PER\n        O     4     0     0     0     0\n    B-ORG     0     1     0     1     0\n    E-ORG     0     0     1     0     1\n    B-PER     0     0     0     0     0\n    E-PER     0     0     0     0     0\n\"\"\"[1:-1]  # don't want to strip\n\nEXPECTED_SHORT_OUTPUT = \"\"\"\n     t\\\\p      O     A B-ORG B-PER E-ORG E-PER\n        O     4     0     0     0     0     0\n        A     0     0     1     0     0     0\n    B-ORG     0     0     0     1     0     0\n    B-PER     0     0     0     0     0     0\n    E-ORG     0     0     0     0     1     1\n    E-PER     0     0     0     0     0     0\n\"\"\"[1:-1]\n\nEXPECTED_HIDE_BLANK_SHORT_OUTPUT = \"\"\"\n     t\\\\p      O B-ORG E-ORG B-PER E-PER\n        O     4     0     0     0     0\n        A     0     1     0     0     0\n    B-ORG     0     0     0     1     0\n    E-ORG     0     0   
  1     0     1\n\"\"\"[1:-1]\n\ndef test_simple_output(simple_confusion):\n    assert EXPECTED_SIMPLE_OUTPUT == format_confusion(simple_confusion)\n\ndef test_short_output(short_confusion):\n    assert EXPECTED_SHORT_OUTPUT == format_confusion(short_confusion)\n\ndef test_hide_blank_short_output(short_confusion):\n    assert EXPECTED_HIDE_BLANK_SHORT_OUTPUT == format_confusion(short_confusion, hide_blank=True)\n\ndef test_macro_f1(simple_confusion, short_confusion):\n    assert confusion_to_macro_f1(simple_confusion) == pytest.approx(0.466666666666)\n    assert confusion_to_macro_f1(short_confusion) == pytest.approx(0.277777777777)\n\ndef test_weighted_f1(simple_confusion, short_confusion):\n    assert confusion_to_weighted_f1(simple_confusion) == pytest.approx(0.83333333)\n    assert confusion_to_weighted_f1(short_confusion) == pytest.approx(0.66666666)\n\n    assert confusion_to_weighted_f1(simple_confusion, exclude=[\"O\"]) == pytest.approx(0.66666666)\n    assert confusion_to_weighted_f1(short_confusion, exclude=[\"O\"]) == pytest.approx(0.33333333)\n\n"
  },
  {
    "path": "stanza/tests/common/test_constant.py",
    "content": "\"\"\"\nTest the conversion to lcodes and splitting of dataset names\n\"\"\"\n\nimport tempfile\n\nimport pytest\n\nimport stanza\nfrom stanza.models.common.constant import treebank_to_short_name, lang_to_langcode, is_right_to_left, two_to_three_letters, langlower2lcode\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_treebank():\n    \"\"\"\n    Test the entire treebank name conversion\n    \"\"\"\n    # conversion of a UD_ name\n    assert \"hi_hdtb\" == treebank_to_short_name(\"UD_Hindi-HDTB\")\n    # conversion of names without UD\n    assert \"hi_fire2013\" == treebank_to_short_name(\"Hindi-fire2013\")\n    assert \"hi_fire2013\" == treebank_to_short_name(\"Hindi-Fire2013\")\n    assert \"hi_fire2013\" == treebank_to_short_name(\"Hindi-FIRE2013\")\n    # already short names are generally preserved\n    assert \"hi_fire2013\" == treebank_to_short_name(\"hi-fire2013\")\n    assert \"hi_fire2013\" == treebank_to_short_name(\"hi_fire2013\")\n    # a special case\n    assert \"zh-hant_pud\" == treebank_to_short_name(\"UD_Chinese-PUD\")\n    # a special case already converted once\n    assert \"zh-hant_pud\" == treebank_to_short_name(\"zh-hant_pud\")\n    assert \"zh-hant_pud\" == treebank_to_short_name(\"zh-hant-pud\")\n    assert \"zh-hans_gsdsimp\" == treebank_to_short_name(\"zh-hans_gsdsimp\")\n\n    assert \"wo_masakhane\" == treebank_to_short_name(\"wo_masakhane\")\n    assert \"wo_masakhane\" == treebank_to_short_name(\"wol_masakhane\")\n    assert \"wo_masakhane\" == treebank_to_short_name(\"Wol_masakhane\")\n    assert \"wo_masakhane\" == treebank_to_short_name(\"wolof_masakhane\")\n    assert \"wo_masakhane\" == treebank_to_short_name(\"Wolof_masakhane\")\n\ndef test_lang_to_langcode():\n    assert \"hi\" == lang_to_langcode(\"Hindi\")\n    assert \"hi\" == lang_to_langcode(\"HINDI\")\n    assert \"hi\" == lang_to_langcode(\"hindi\")\n    assert \"hi\" == lang_to_langcode(\"HI\")\n    assert 
\"hi\" == lang_to_langcode(\"hi\")\n\ndef test_right_to_left():\n    assert is_right_to_left(\"ar\")\n    assert is_right_to_left(\"Arabic\")\n\n    assert not is_right_to_left(\"en\")\n    assert not is_right_to_left(\"English\")\n\ndef test_two_to_three():\n    assert lang_to_langcode(\"Wolof\") == \"wo\"\n    assert lang_to_langcode(\"wol\") == \"wo\"\n\n    assert \"wo\" in two_to_three_letters\n    assert two_to_three_letters[\"wo\"] == \"wol\"\n\ndef test_langlower():\n    assert lang_to_langcode(\"WOLOF\") == \"wo\"\n    assert lang_to_langcode(\"nOrWeGiAn\") == \"nb\"\n\n    assert \"soj\" == langlower2lcode[\"soi\"]\n    assert \"soj\" == langlower2lcode[\"sohi\"]\n"
  },
  {
    "path": "stanza/tests/common/test_data_conversion.py",
    "content": "\"\"\"\nBasic tests of the data conversion\n\"\"\"\n\nimport io\nimport pytest\nimport tempfile\nfrom zipfile import ZipFile\n\nimport stanza\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.doc import Document\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\n# data for testing\nCONLL = [[['1', 'Nous', 'il', 'PRON', '_', 'Number=Plur|Person=1|PronType=Prs', '3', 'nsubj', '_', 'start_char=0|end_char=4'],\n          ['2', 'avons', 'avoir', 'AUX', '_', 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', '3', 'aux:tense', '_', 'start_char=5|end_char=10'],\n          ['3', 'atteint', 'atteindre', 'VERB', '_', 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', '0', 'root', '_', 'start_char=11|end_char=18'],\n          ['4', 'la', 'le', 'DET', '_', 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', '5', 'det', '_', 'start_char=19|end_char=21'],\n          ['5', 'fin', 'fin', 'NOUN', '_', 'Gender=Fem|Number=Sing', '3', 'obj', '_', 'start_char=22|end_char=25'],\n          ['6-7', 'du', '_', '_', '_', '_', '_', '_', '_', 'start_char=26|end_char=28'],\n          ['6', 'de', 'de', 'ADP', '_', '_', '8', 'case', '_', '_'],\n          ['7', 'le', 'le', 'DET', '_', 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', '8', 'det', '_', '_'],\n          ['8', 'sentier', 'sentier', 'NOUN', '_', 'Gender=Masc|Number=Sing', '5', 'nmod', '_', 'start_char=29|end_char=36'],\n          ['9', '.', '.', 'PUNCT', '_', '_', '3', 'punct', '_', 'start_char=36|end_char=37']]]\n\n\nDICT = [[{'id': (1,), 'text': 'Nous', 'lemma': 'il', 'upos': 'PRON', 'feats': 'Number=Plur|Person=1|PronType=Prs', 'head': 3, 'deprel': 'nsubj', 'misc': 'start_char=0|end_char=4'},\n         {'id': (2,), 'text': 'avons', 'lemma': 'avoir', 'upos': 'AUX', 'feats': 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', 'head': 3, 'deprel': 'aux:tense', 'misc': 'start_char=5|end_char=10'},\n         {'id': (3,), 'text': 'atteint', 'lemma': 'atteindre', 'upos': 
'VERB', 'feats': 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', 'head': 0, 'deprel': 'root', 'misc': 'start_char=11|end_char=18'},\n         {'id': (4,), 'text': 'la', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', 'head': 5, 'deprel': 'det', 'misc': 'start_char=19|end_char=21'},\n         {'id': (5,), 'text': 'fin', 'lemma': 'fin', 'upos': 'NOUN', 'feats': 'Gender=Fem|Number=Sing', 'head': 3, 'deprel': 'obj', 'misc': 'start_char=22|end_char=25'},\n         {'id': (6, 7), 'text': 'du', 'misc': 'start_char=26|end_char=28'},\n         {'id': (6,), 'text': 'de', 'lemma': 'de', 'upos': 'ADP', 'head': 8, 'deprel': 'case'},\n         {'id': (7,), 'text': 'le', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', 'head': 8, 'deprel': 'det'},\n         {'id': (8,), 'text': 'sentier', 'lemma': 'sentier', 'upos': 'NOUN', 'feats': 'Gender=Masc|Number=Sing', 'head': 5, 'deprel': 'nmod', 'misc': 'start_char=29|end_char=36'},\n         {'id': (9,), 'text': '.', 'lemma': '.', 'upos': 'PUNCT', 'head': 3, 'deprel': 'punct', 'misc': 'start_char=36|end_char=37'}]]\n\ndef test_conll_to_dict():\n    dicts, empty = CoNLL.convert_conll(CONLL)\n    assert dicts == DICT\n    assert len(dicts) == len(empty)\n    assert all(len(x) == 0 for x in empty)\n\ndef test_dict_to_conll():\n    document = Document(DICT)\n    # :c = no comments\n    conll = [[sentence.split(\"\\t\") for sentence in doc.split(\"\\n\")] for doc in \"{:c}\".format(document).split(\"\\n\\n\")]\n    assert conll == CONLL\n\ndef test_dict_to_doc_and_doc_to_dict():\n    \"\"\"\n    Test the conversion from raw dict to Document and back\n\n    This code path will first turn start_char|end_char into start_char & end_char fields in the Document\n    That version to a dict will have separate fields for each of those\n    Finally, the conversion from that dict to a list of conll entries should convert that back to misc\n    \"\"\"\n    document 
= Document(DICT)\n    dicts = document.to_dict()\n    document = Document(dicts)\n    conll = [[sentence.split(\"\\t\") for sentence in doc.split(\"\\n\")] for doc in \"{:c}\".format(document).split(\"\\n\\n\")]\n    assert conll == CONLL\n\n# sample is two sentences long so that the tests check multiple sentences\nRUSSIAN_SAMPLE=\"\"\"\n# sent_id = yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253\n# genre = review\n# text = Как- то слишком мало цветов получают актёры после спектакля.\n1\tКак\tкак-то\tADV\t_\tDegree=Pos|PronType=Ind\t7\tadvmod\t_\tSpaceAfter=No\n2\t-\t-\tPUNCT\t_\t_\t3\tpunct\t_\t_\n3\tто\tто\tPART\t_\t_\t1\tlist\t_\tdeprel=list:goeswith\n4\tслишком\tслишком\tADV\t_\tDegree=Pos\t5\tadvmod\t_\t_\n5\tмало\tмало\tADV\t_\tDegree=Pos\t6\tadvmod\t_\t_\n6\tцветов\tцветок\tNOUN\t_\tAnimacy=Inan|Case=Gen|Gender=Masc|Number=Plur\t7\tobj\t_\t_\n7\tполучают\tполучать\tVERB\t_\tAspect=Imp|Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act\t0\troot\t_\t_\n8\tактёры\tактер\tNOUN\t_\tAnimacy=Anim|Case=Nom|Gender=Masc|Number=Plur\t7\tnsubj\t_\t_\n9\tпосле\tпосле\tADP\t_\t_\t10\tcase\t_\t_\n10\tспектакля\tспектакль\tNOUN\t_\tAnimacy=Inan|Case=Gen|Gender=Masc|Number=Sing\t7\tobl\t_\tSpaceAfter=No\n11\t.\t.\tPUNCT\t_\t_\t7\tpunct\t_\t_\n\n# sent_id = 4\n# genre = social\n# text = В женщине важна верность, а не красота.\n1\tВ\tв\tADP\t_\t_\t2\tcase\t_\t_\n2\tженщине\tженщина\tNOUN\t_\tAnimacy=Anim|Case=Loc|Gender=Fem|Number=Sing\t3\tobl\t_\t_\n3\tважна\tважный\tADJ\t_\tDegree=Pos|Gender=Fem|Number=Sing|Variant=Short\t0\troot\t_\t_\n4\tверность\tверность\tNOUN\t_\tAnimacy=Inan|Case=Nom|Gender=Fem|Number=Sing\t3\tnsubj\t_\tSpaceAfter=No\n5\t,\t,\tPUNCT\t_\t_\t8\tpunct\t_\t_\n6\tа\tа\tCCONJ\t_\t_\t8\tcc\t_\t_\n7\tне\tне\tPART\t_\tPolarity=Neg\t8\tadvmod\t_\t_\n8\tкрасота\tкрасота\tNOUN\t_\tAnimacy=Inan|Case=Nom|Gender=Fem|Number=Sing\t4\tconj\t_\tSpaceAfter=No\n9\t.\t.\tPUNCT\t_\t_\t3\tpunct\t_\t_\n\"\"\".strip()\n\nRUSSIAN_TEXT = [\"Как- то слишком мало 
цветов получают актёры после спектакля.\", \"В женщине важна верность, а не красота.\"]\nRUSSIAN_IDS = [\"yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253\", \"4\"]\n\ndef check_russian_doc(doc):\n    \"\"\"\n    Refactored the test for the Russian doc so we can use it to test various file methods\n    \"\"\"\n    lines = RUSSIAN_SAMPLE.split(\"\\n\")\n    assert len(doc.sentences) == 2\n    assert lines[0] == doc.sentences[0].comments[0]\n    assert lines[1] == doc.sentences[0].comments[1]\n    assert lines[2] == doc.sentences[0].comments[2]\n    for sent_idx, (expected_text, expected_id, sentence) in enumerate(zip(RUSSIAN_TEXT, RUSSIAN_IDS, doc.sentences)):\n        assert expected_text == sentence.text\n        assert expected_id == sentence.sent_id\n        assert sent_idx == sentence.index\n        assert len(sentence.comments) == 3\n        assert not sentence.has_enhanced_dependencies()\n\n    sentences = \"{:C}\".format(doc)\n    sentences = sentences.split(\"\\n\\n\")\n    assert len(sentences) == 2\n\n    sentence = sentences[0].split(\"\\n\")\n    assert len(sentence) == 14\n    assert lines[0] == sentence[0]\n    assert lines[1] == sentence[1]\n    assert lines[2] == sentence[2]\n\n    # assert that the weird deprel=list:goeswith was properly handled\n    assert doc.sentences[0].words[2].head == 1\n    assert doc.sentences[0].words[2].deprel == \"list:goeswith\"\n\ndef test_write_russian_doc(tmp_path):\n    \"\"\"\n    Specifically test the write_doc2conll method\n    \"\"\"\n    filename = tmp_path / \"russian.conll\"\n    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)\n    check_russian_doc(doc)\n    CoNLL.write_doc2conll(doc, filename)\n\n    with open(filename, encoding=\"utf-8\") as fin:\n        text = fin.read()\n\n    # the conll docs have to end with \\n\\n\n    assert text.endswith(\"\\n\\n\")\n\n    # but to compare against the original, strip off the whitespace\n    text = text.strip()\n\n    # we skip the first sentence because the 
\"deprel=list:goeswith\" is weird\n    # note that the deprel itself is checked in check_russian_doc\n    text = text[text.find(\"# sent_id = 4\"):]\n    sample = RUSSIAN_SAMPLE[RUSSIAN_SAMPLE.find(\"# sent_id = 4\"):]\n    assert text == sample\n\n    doc2 = CoNLL.conll2doc(filename)\n    check_russian_doc(doc2)\n\n# random sentence from EN_Pronouns\nENGLISH_SAMPLE = \"\"\"\n# newdoc\n# sent_id = 1\n# text = It is hers.\n# previous = Which person owns this?\n# comment = copular subject\n1\tIt\tit\tPRON\tPRP\tNumber=Sing|Person=3|PronType=Prs\t3\tnsubj\t_\t_\n2\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tcop\t_\t_\n3\thers\thers\tPRON\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t0\troot\t_\tSpaceAfter=No\n4\t.\t.\tPUNCT\t.\t_\t3\tpunct\t_\t_\n\"\"\".strip()\n\ndef test_write_to_io():\n    doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)\n    output = io.StringIO()\n    CoNLL.write_doc2conll(doc, output)\n    output_value = output.getvalue()\n    assert output_value.endswith(\"\\n\\n\")\n    assert output_value.strip() == ENGLISH_SAMPLE\n\ndef test_write_doc2conll_append(tmp_path):\n    doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)\n    filename = tmp_path / \"english.conll\"\n    CoNLL.write_doc2conll(doc, filename)\n    CoNLL.write_doc2conll(doc, filename, mode=\"a\")\n\n    with open(filename) as fin:\n        text = fin.read()\n    expected = ENGLISH_SAMPLE + \"\\n\\n\" + ENGLISH_SAMPLE + \"\\n\\n\"\n    assert text == expected\n\ndef test_doc_with_comments():\n    \"\"\"\n    Test that a doc with comments gets converted back with comments\n    \"\"\"\n    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)\n    check_russian_doc(doc)\n\ndef test_unusual_misc():\n    \"\"\"\n    The above RUSSIAN_SAMPLE resulted in a blank misc field in one particular implementation of the conll code\n    (the below test would fail)\n    \"\"\"\n    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)\n    sentences = 
\"{:C}\".format(doc).split(\"\\n\\n\")\n    assert len(sentences) == 2\n    sentence = sentences[0].split(\"\\n\")\n    assert len(sentence) == 14\n\n    for word in sentence:\n        pieces = word.split(\"\\t\")\n        assert len(pieces) == 1 or len(pieces) == 10\n        if len(pieces) == 10:\n            assert all(piece for piece in pieces)\n\ndef test_file():\n    \"\"\"\n    Test loading a doc from a file\n    \"\"\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        filename = os.path.join(tempdir, \"russian.conll\")\n        with open(filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(RUSSIAN_SAMPLE)\n        doc = CoNLL.conll2doc(input_file=filename)\n        check_russian_doc(doc)\n\ndef test_zip_file():\n    \"\"\"\n    Test loading a doc from a zip file\n    \"\"\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        zip_file = os.path.join(tempdir, \"russian.zip\")\n        filename = \"russian.conll\"\n        with ZipFile(zip_file, \"w\") as zout:\n            with zout.open(filename, \"w\") as fout:\n                fout.write(RUSSIAN_SAMPLE.encode())\n\n        doc = CoNLL.conll2doc(input_file=filename, zip_file=zip_file)\n        check_russian_doc(doc)\n\nSIMPLE_NER = \"\"\"\n# text = Teferi's best friend is Karn\n# sent_id = 0\n1\tTeferi\t_\t_\t_\t_\t0\t_\t_\tstart_char=0|end_char=6|ner=S-PERSON\n2\t's\t_\t_\t_\t_\t1\t_\t_\tstart_char=6|end_char=8|ner=O\n3\tbest\t_\t_\t_\t_\t2\t_\t_\tstart_char=9|end_char=13|ner=O\n4\tfriend\t_\t_\t_\t_\t3\t_\t_\tstart_char=14|end_char=20|ner=O\n5\tis\t_\t_\t_\t_\t4\t_\t_\tstart_char=21|end_char=23|ner=O\n6\tKarn\t_\t_\t_\t_\t5\t_\t_\tstart_char=24|end_char=28|ner=S-PERSON\n\"\"\".strip()\n\ndef test_simple_ner_conversion():\n    \"\"\"\n    Test that tokens get properly created with NER tags\n    \"\"\"\n    doc = CoNLL.conll2doc(input_str=SIMPLE_NER)\n    assert len(doc.sentences) == 1\n    sentence = doc.sentences[0]\n    assert len(sentence.tokens) == 6\n    
EXPECTED_NER = [\"S-PERSON\", \"O\", \"O\", \"O\", \"O\", \"S-PERSON\"]\n    for token, ner in zip(sentence.tokens, EXPECTED_NER):\n        assert token.ner == ner\n        # check that the ner, start_char, end_char fields were not put on the token's misc\n        # those should all be set as specific fields on the token\n        assert not token.misc\n        assert len(token.words) == 1\n        # they should also not reach the word's misc field\n        assert not token.words[0].misc\n\n    conll = \"{:C}\".format(doc)\n    assert conll == SIMPLE_NER\n\nMWT_NER = \"\"\"\n# text = This makes John's headache worse\n# sent_id = 0\n1\tThis\t_\t_\t_\t_\t0\t_\t_\tstart_char=0|end_char=4|ner=O\n2\tmakes\t_\t_\t_\t_\t1\t_\t_\tstart_char=5|end_char=10|ner=O\n3-4\tJohn's\t_\t_\t_\t_\t_\t_\t_\tstart_char=11|end_char=17|ner=S-PERSON\n3\tJohn\t_\t_\t_\t_\t2\t_\t_\t_\n4\t's\t_\t_\t_\t_\t3\t_\t_\t_\n5\theadache\t_\t_\t_\t_\t4\t_\t_\tstart_char=18|end_char=26|ner=O\n6\tworse\t_\t_\t_\t_\t5\t_\t_\tstart_char=27|end_char=32|ner=O\n\"\"\".strip()\n\ndef test_mwt_ner_conversion():\n    \"\"\"\n    Test that tokens including MWT get properly created with NER tags\n\n    Note that this kind of thing happens with the EWT tokenizer for English, for example\n    \"\"\"\n    doc = CoNLL.conll2doc(input_str=MWT_NER)\n    assert len(doc.sentences) == 1\n    sentence = doc.sentences[0]\n    assert len(sentence.tokens) == 5\n    assert not sentence.has_enhanced_dependencies()\n    EXPECTED_NER = [\"O\", \"O\", \"S-PERSON\", \"O\", \"O\"]\n    EXPECTED_WORDS = [1, 1, 2, 1, 1]\n    for token, ner, expected_words in zip(sentence.tokens, EXPECTED_NER, EXPECTED_WORDS):\n        assert token.ner == ner\n        # check that the ner, start_char, end_char fields were not put on the token's misc\n        # those should all be set as specific fields on the token\n        assert not token.misc\n        assert len(token.words) == expected_words\n        # they should also not reach the word's misc 
field\n        assert not token.words[0].misc\n\n    conll = \"{:C}\".format(doc)\n    assert conll == MWT_NER\n\nALL_OFFSETS_CONLLU = \"\"\"\n# text = This makes John's headache worse\n# sent_id = 0\n1\tThis\t_\t_\t_\t_\t0\t_\t_\tstart_char=0|end_char=4\n2\tmakes\t_\t_\t_\t_\t1\t_\t_\tstart_char=5|end_char=10\n3-4\tJohn's\t_\t_\t_\t_\t_\t_\t_\tstart_char=11|end_char=17\n3\tJohn\t_\t_\t_\t_\t2\t_\t_\tstart_char=11|end_char=15\n4\t's\t_\t_\t_\t_\t3\t_\t_\tstart_char=15|end_char=17\n5\theadache\t_\t_\t_\t_\t4\t_\t_\tstart_char=18|end_char=26\n6\tworse\t_\t_\t_\t_\t5\t_\t_\tSpaceAfter=No|start_char=27|end_char=32\n\"\"\".strip()\n\nNO_OFFSETS_CONLLU = \"\"\"\n# text = This makes John's headache worse\n# sent_id = 0\n1\tThis\t_\t_\t_\t_\t0\t_\t_\t_\n2\tmakes\t_\t_\t_\t_\t1\t_\t_\t_\n3-4\tJohn's\t_\t_\t_\t_\t_\t_\t_\t_\n3\tJohn\t_\t_\t_\t_\t2\t_\t_\t_\n4\t's\t_\t_\t_\t_\t3\t_\t_\t_\n5\theadache\t_\t_\t_\t_\t4\t_\t_\t_\n6\tworse\t_\t_\t_\t_\t5\t_\t_\tSpaceAfter=No\n\"\"\".strip()\n\nNO_COMMENTS_NO_OFFSETS_CONLLU = \"\"\"\n1\tThis\t_\t_\t_\t_\t0\t_\t_\t_\n2\tmakes\t_\t_\t_\t_\t1\t_\t_\t_\n3-4\tJohn's\t_\t_\t_\t_\t_\t_\t_\t_\n3\tJohn\t_\t_\t_\t_\t2\t_\t_\t_\n4\t's\t_\t_\t_\t_\t3\t_\t_\t_\n5\theadache\t_\t_\t_\t_\t4\t_\t_\t_\n6\tworse\t_\t_\t_\t_\t5\t_\t_\tSpaceAfter=No\n\"\"\".strip()\n\n\ndef test_no_offsets_output():\n    doc = CoNLL.conll2doc(input_str=ALL_OFFSETS_CONLLU)\n    assert len(doc.sentences) == 1\n    sentence = doc.sentences[0]\n    assert len(sentence.tokens) == 5\n\n    conll = \"{:C}\".format(doc)\n    assert conll == ALL_OFFSETS_CONLLU\n\n    conll = \"{:C-o}\".format(doc)\n    assert conll == NO_OFFSETS_CONLLU\n\n    conll = \"{:c-o}\".format(doc)\n    assert conll == NO_COMMENTS_NO_OFFSETS_CONLLU\n\n# A random sentence from et_ewt-ud-train.conllu\n# which we use to test the deps conversion for multiple deps\nESTONIAN_DEPS = \"\"\"\n# newpar\n# sent_id = aia_foorum_37\n# text = Sestpeale ei mõistagi neid, kes koduaias sortidega 
tegelevad.\n1\tSestpeale\tsest_peale\tADV\tD\t_\t3\tadvmod\t3:advmod\t_\n2\tei\tei\tAUX\tV\tPolarity=Neg\t3\taux\t3:aux\t_\n3\tmõistagi\tmõistma\tVERB\tV\tConnegative=Yes|Mood=Ind|Tense=Pres|VerbForm=Fin|Voice=Act\t0\troot\t0:root\t_\n4\tneid\ttema\tPRON\tP\tCase=Par|Number=Plur|Person=3|PronType=Prs\t3\tobj\t3:obj|9:nsubj\tSpaceAfter=No\n5\t,\t,\tPUNCT\tZ\t_\t9\tpunct\t9:punct\t_\n6\tkes\tkes\tPRON\tP\tCase=Nom|Number=Plur|PronType=Int,Rel\t9\tnsubj\t4:ref\t_\n7\tkoduaias\tkodu_aed\tNOUN\tS\tCase=Ine|Number=Sing\t9\tobl\t9:obl\t_\n8\tsortidega\tsort\tNOUN\tS\tCase=Com|Number=Plur\t9\tobl\t9:obl\t_\n9\ttegelevad\ttegelema\tVERB\tV\tMood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act\t4\tacl:relcl\t4:acl\tSpaceAfter=No\n10\t.\t.\tPUNCT\tZ\t_\t3\tpunct\t3:punct\t_\n\"\"\".strip()\n\ndef test_deps_conversion():\n    doc = CoNLL.conll2doc(input_str=ESTONIAN_DEPS)\n    assert len(doc.sentences) == 1\n    sentence = doc.sentences[0]\n    assert len(sentence.tokens) == 10\n    assert sentence.has_enhanced_dependencies()\n\n    word = doc.sentences[0].words[3]\n    assert word.deps == \"3:obj|9:nsubj\"\n\n    conll = \"{:C}\".format(doc)\n    assert conll == ESTONIAN_DEPS\n\nESTONIAN_EMPTY_DEPS = \"\"\"\n# sent_id = ewtb2_000035_15\n# text = Ja paari aasta pärast rôômalt maasikatele ...\n1\tJa\tja\tCCONJ\tJ\t_\t3\tcc\t5.1:cc\t_\n2\tpaari\tpaar\tNUM\tN\tCase=Gen|Number=Sing|NumForm=Word|NumType=Card\t3\tnummod\t3:nummod\t_\n3\taasta\taasta\tNOUN\tS\tCase=Gen|Number=Sing\t0\troot\t5.1:obl\t_\n4\tpärast\tpärast\tADP\tK\tAdpType=Post\t3\tcase\t3:case\t_\n5\trôômalt\trõõmsalt\tADV\tD\tTypo=Yes\t3\tadvmod\t5.1:advmod\tOrphan=Yes|CorrectForm=rõõmsalt\n5.1\tpanna\tpanema\tVERB\tV\tVerbForm=Inf\t_\t_\t0:root\tEmpty=5.1\n6\tmaasikatele\tmaasikas\tNOUN\tS\tCase=All|Number=Plur\t3\tobl\t5.1:obl\tOrphan=Yes\n7\t...\t...\tPUNCT\tZ\t_\t3\tpunct\t5.1:punct\t_\n\"\"\".strip()\n\nESTONIAN_EMPTY_END_DEPS = \"\"\"\n# sent_id = ewtb2_000035_15\n# text = Ja paari aasta pärast 
rôômalt maasikatele ...\n1\tJa\tja\tCCONJ\tJ\t_\t3\tcc\t5.1:cc\t_\n2\tpaari\tpaar\tNUM\tN\tCase=Gen|Number=Sing|NumForm=Word|NumType=Card\t3\tnummod\t3:nummod\t_\n3\taasta\taasta\tNOUN\tS\tCase=Gen|Number=Sing\t0\troot\t5.1:obl\t_\n4\tpärast\tpärast\tADP\tK\tAdpType=Post\t3\tcase\t3:case\t_\n5\trôômalt\trõõmsalt\tADV\tD\tTypo=Yes\t3\tadvmod\t5.1:advmod\tOrphan=Yes|CorrectForm=rõõmsalt\n5.1\tpanna\tpanema\tVERB\tV\tVerbForm=Inf\t_\t_\t0:root\tEmpty=5.1\n\"\"\".strip()\n\ndef test_empty_deps_conversion():\n    \"\"\"\n    Check that we can read and then output a sentence with empty dependencies\n    \"\"\"\n    check_empty_deps_conversion(ESTONIAN_EMPTY_DEPS, 7)\n\ndef test_empty_deps_at_end_conversion():\n    \"\"\"\n    The empty deps conversion should also work if the empty dep is at the end\n    \"\"\"\n    check_empty_deps_conversion(ESTONIAN_EMPTY_END_DEPS, 5)\n\ndef check_empty_deps_conversion(input_str, expected_words):\n    doc = CoNLL.conll2doc(input_str=input_str, ignore_gapping=False)\n    assert len(doc.sentences) == 1\n    assert len(doc.sentences[0].tokens) == expected_words\n    assert len(doc.sentences[0].words) == expected_words\n    assert len(doc.sentences[0].empty_words) == 1\n\n    sentence = doc.sentences[0]\n    conll = \"{:C}\".format(doc)\n    assert conll == input_str\n\n    sentence_dict = doc.sentences[0].to_dict()\n    assert len(sentence_dict) == expected_words + 1\n    # currently this is true for both of the examples we run\n    assert sentence_dict[5]['id'] == (5, 1)\n\n    # redo the above checks to make sure\n    # there are no weird bugs in the accessors\n    assert len(doc.sentences) == 1\n    assert len(doc.sentences[0].tokens) == expected_words\n    assert len(doc.sentences[0].words) == expected_words\n    assert len(doc.sentences[0].empty_words) == 1\n\n\nESTONIAN_DOC_ID = \"\"\"\n# doc_id = this_is_a_doc\n# sent_id = ewtb2_000035_15\n# text = Ja paari aasta pärast rôômalt maasikatele 
...\n1\tJa\tja\tCCONJ\tJ\t_\t3\tcc\t5.1:cc\t_\n2\tpaari\tpaar\tNUM\tN\tCase=Gen|Number=Sing|NumForm=Word|NumType=Card\t3\tnummod\t3:nummod\t_\n3\taasta\taasta\tNOUN\tS\tCase=Gen|Number=Sing\t0\troot\t5.1:obl\t_\n4\tpärast\tpärast\tADP\tK\tAdpType=Post\t3\tcase\t3:case\t_\n5\trôômalt\trõõmsalt\tADV\tD\tTypo=Yes\t3\tadvmod\t5.1:advmod\tOrphan=Yes|CorrectForm=rõõmsalt\n5.1\tpanna\tpanema\tVERB\tV\tVerbForm=Inf\t_\t_\t0:root\tEmpty=5.1\n6\tmaasikatele\tmaasikas\tNOUN\tS\tCase=All|Number=Plur\t3\tobl\t5.1:obl\tOrphan=Yes\n7\t...\t...\tPUNCT\tZ\t_\t3\tpunct\t5.1:punct\t_\n\"\"\".strip()\n\ndef test_read_doc_id():\n    doc = CoNLL.conll2doc(input_str=ESTONIAN_DOC_ID, ignore_gapping=False)\n    assert \"{:C}\".format(doc) == ESTONIAN_DOC_ID\n    assert doc.sentences[0].doc_id == 'this_is_a_doc'\n\nSIMPLE_DEPENDENCY_INDEX_ERROR = \"\"\"\n# text = Teferi's best friend is Karn\n# sent_id = 0\n# notes = this sentence has a dependency index outside the sentence.  it should throw an IndexError\n1\tTeferi\t_\t_\t_\t_\t0\troot\t_\tstart_char=0|end_char=6|ner=S-PERSON\n2\t's\t_\t_\t_\t_\t1\tdep\t_\tstart_char=6|end_char=8|ner=O\n3\tbest\t_\t_\t_\t_\t2\tdep\t_\tstart_char=9|end_char=13|ner=O\n4\tfriend\t_\t_\t_\t_\t3\tdep\t_\tstart_char=14|end_char=20|ner=O\n5\tis\t_\t_\t_\t_\t4\tdep\t_\tstart_char=21|end_char=23|ner=O\n6\tKarn\t_\t_\t_\t_\t8\tdep\t_\tstart_char=24|end_char=28|ner=S-PERSON\n\"\"\".strip()\n\ndef test_read_dependency_errors():\n    with pytest.raises(IndexError):\n        doc = CoNLL.conll2doc(input_str=SIMPLE_DEPENDENCY_INDEX_ERROR)\n\nMULTIPLE_DOC_IDS = \"\"\"\n# doc_id = doc_1\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0020\n# text = His mother was also killed in the 
attack.\n1\tHis\this\tPRON\tPRP$\tCase=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t2\tnmod:poss\t2:nmod:poss\t_\n2\tmother\tmother\tNOUN\tNN\tNumber=Sing\t5\tnsubj:pass\t5:nsubj:pass\t_\n3\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t5\taux:pass\t5:aux:pass\t_\n4\talso\talso\tADV\tRB\t_\t5\tadvmod\t5:advmod\t_\n5\tkilled\tkill\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n6\tin\tin\tADP\tIN\t_\t8\tcase\t8:case\t_\n7\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t8\tdet\t8:det\t_\n8\tattack\tattack\tNOUN\tNN\tNumber=Sing\t5\tobl\t5:obl:in\tSpaceAfter=No\n9\t.\t.\tPUNCT\t.\t_\t5\tpunct\t5:punct\t_\n\n# doc_id = doc_1\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0028\n# text = This item is a small one and easily missed.\n1\tThis\tthis\tDET\tDT\tNumber=Sing|PronType=Dem\t2\tdet\t2:det\t_\n2\titem\titem\tNOUN\tNN\tNumber=Sing\t6\tnsubj\t6:nsubj|9:nsubj:pass\t_\n3\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t6\tcop\t6:cop\t_\n4\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t6\tdet\t6:det\t_\n5\tsmall\tsmall\tADJ\tJJ\tDegree=Pos\t6\tamod\t6:amod\t_\n6\tone\tone\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\t_\n7\tand\tand\tCCONJ\tCC\t_\t9\tcc\t9:cc\t_\n8\teasily\teasily\tADV\tRB\t_\t9\tadvmod\t9:advmod\t_\n9\tmissed\tmiss\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t6\tconj\t6:conj:and\tSpaceAfter=No\n10\t.\t.\tPUNCT\t.\t_\t6\tpunct\t6:punct\t_\n\n# doc_id = doc_2\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0029\n# text = But in my view it is highly 
significant.\n1\tBut\tbut\tCCONJ\tCC\t_\t8\tcc\t8:cc\t_\n2\tin\tin\tADP\tIN\t_\t4\tcase\t4:case\t_\n3\tmy\tmy\tPRON\tPRP$\tCase=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs\t4\tnmod:poss\t4:nmod:poss\t_\n4\tview\tview\tNOUN\tNN\tNumber=Sing\t8\tobl\t8:obl:in\t_\n5\tit\tit\tPRON\tPRP\tCase=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs\t8\tnsubj\t8:nsubj\t_\n6\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t8\tcop\t8:cop\t_\n7\thighly\thighly\tADV\tRB\t_\t8\tadvmod\t8:advmod\t_\n8\tsignificant\tsignificant\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\tSpaceAfter=No\n9\t.\t.\tPUNCT\t.\t_\t8\tpunct\t8:punct\t_\n\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0040\n# text = The trial begins again Nov.28.\n1\tThe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t2\tdet\t2:det\t_\n2\ttrial\ttrial\tNOUN\tNN\tNumber=Sing\t3\tnsubj\t3:nsubj\t_\n3\tbegins\tbegin\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n4\tagain\tagain\tADV\tRB\t_\t3\tadvmod\t3:advmod\t_\n5\tNov.\tNovember\tPROPN\tNNP\tAbbr=Yes|Number=Sing\t3\tobl:tmod\t3:obl:tmod\tSpaceAfter=No\n6\t28\t28\tNUM\tCD\tNumForm=Digit|NumType=Card\t5\tnummod\t5:nummod\tSpaceAfter=No\n7\t.\t.\tPUNCT\t.\t_\t3\tpunct\t3:punct\t_\n\n\"\"\".lstrip()\n\ndef test_read_multiple_doc_ids():\n    docs = CoNLL.conll2multi_docs(input_str=MULTIPLE_DOC_IDS)\n    assert len(docs) == 2\n    assert len(docs[0].sentences) == 2\n    assert len(docs[1].sentences) == 2\n\n    # remove the first doc_id comment\n    text = \"\\n\".join(MULTIPLE_DOC_IDS.split(\"\\n\")[1:])\n    docs = CoNLL.conll2multi_docs(input_str=text)\n    assert len(docs) == 3\n    assert len(docs[0].sentences) == 1\n    assert len(docs[1].sentences) == 1\n    assert len(docs[2].sentences) == 2\n\nENGLISH_TEST_SENTENCE = \"\"\"\n# text = This is a test\n# sent_id = 
0\n1\tThis\tthis\tPRON\tDT\tNumber=Sing|PronType=Dem\t4\tnsubj\t_\tstart_char=0|end_char=4\n2\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\tstart_char=5|end_char=7\n3\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t4\tdet\t_\tstart_char=8|end_char=9\n4\ttest\ttest\tNOUN\tNN\tNumber=Sing\t0\troot\t_\tSpaceAfter=No|start_char=10|end_char=14\n\"\"\".lstrip()\n\ndef test_convert_dict():\n    doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE)\n    converted = CoNLL.convert_dict(doc.to_dict())\n\n    expected = [[['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'],\n                 ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'],\n                 ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'],\n                 ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 'SpaceAfter=No|start_char=10|end_char=14']]]\n\n    assert converted == expected\n\ndef test_line_numbers():\n    doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE, keep_line_numbers=True)\n    # currently the line numbers are not output in the conllu format\n    doc_conllu = \"{:C}\\n\".format(doc)\n    assert doc_conllu == ENGLISH_TEST_SENTENCE\n\n    # currently the line numbers are not output in the dict format\n    converted = CoNLL.convert_dict(doc.to_dict())\n    expected = [[['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'],\n                 ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'],\n                 ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'],\n                 ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 
'SpaceAfter=No|start_char=10|end_char=14']]]\n    assert converted == expected\n\n    for word_idx, word in enumerate(doc.sentences[0].words):\n        # the test sentence has two comments in it\n        assert word.line_number == word_idx + 2\n\n\nSPEAKER_EXAMPLE = \"\"\"\n# sent_id = GUM_fiction_pag-57\n# speaker = Siri\n# addressee = Pag\n# text = \"Sorry.\"\n1\t\"\t\"\tPUNCT\t``\t_\t2\tpunct\t2:punct\tDiscourse=joint-sequence_m:130->128:1:_|SpaceAfter=No\n2\tSorry\tsorry\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\tMSeg=Sorr-y|SpaceAfter=No\n3\t.\t.\tPUNCT\t.\t_\t2\tpunct\t2:punct\tSpaceAfter=No\n4\t\"\t\"\tPUNCT\t''\t_\t2\tpunct\t2:punct\t_\n\"\"\".lstrip()\n\ndef test_speaker():\n    doc = CoNLL.conll2doc(input_str=SPEAKER_EXAMPLE)\n    assert len(doc.sentences) == 1\n    assert doc.sentences[0].speaker == 'Siri'\n    assert \"# speaker = Siri\" in doc.sentences[0].comments\n\n    doc.sentences[0].speaker = \"foo\"\n    assert doc.sentences[0].speaker == 'foo'\n    assert any(comment.startswith(\"# speaker\") for comment in doc.sentences[0].comments)\n    assert \"# speaker = foo\" in doc.sentences[0].comments\n\n    doc.sentences[0].speaker = None\n    assert not any(comment.startswith(\"# speaker\") for comment in doc.sentences[0].comments)\n    assert doc.sentences[0].speaker is None\n\n    doc.sentences[0].speaker = \"Siri\"\n    assert doc.sentences[0].speaker == 'Siri'\n    assert \"# speaker = Siri\" in doc.sentences[0].comments\n"
  },
  {
    "path": "stanza/tests/common/test_data_objects.py",
    "content": "\"\"\"\nBasic tests of the stanza data objects, especially the setter/getter routines\n\"\"\"\nimport pytest\n\nimport stanza\nfrom stanza.models.common.doc import Document, Sentence, Word\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\n# data for testing\nEN_DOC = \"This is a test document. Pretty cool!\"\n\nEN_DOC_UPOS_XPOS = (('PRON_DT', 'AUX_VBZ', 'DET_DT', 'NOUN_NN', 'NOUN_NN', 'PUNCT_.'), ('ADV_RB', 'ADJ_JJ', 'PUNCT_.'))\n\nEN_DOC2 = \"Chris Manning wrote a sentence. Then another.\"\n\n@pytest.fixture(scope=\"module\")\ndef nlp_pipeline():\n    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en')\n    return nlp\n\ndef test_readonly(nlp_pipeline):\n    Document.add_property('some_property', 123)\n    doc = nlp_pipeline(EN_DOC)\n    assert doc.some_property == 123\n    with pytest.raises(ValueError):\n        doc.some_property = 456\n\n\ndef test_getter(nlp_pipeline):\n    Word.add_property('upos_xpos', getter=lambda self: f\"{self.upos}_{self.xpos}\")\n\n    doc = nlp_pipeline(EN_DOC)\n\n    assert EN_DOC_UPOS_XPOS == tuple(tuple(word.upos_xpos for word in sentence.words) for sentence in doc.sentences)\n\ndef test_setter_getter(nlp_pipeline):\n    int2str = {0: 'ok', 1: 'good', 2: 'bad'}\n    str2int = {'ok': 0, 'good': 1, 'bad': 2}\n    def setter(self, value):\n        self._classname = str2int[value]\n    Sentence.add_property('classname', getter=lambda self: int2str[self._classname] if self._classname is not None else None, setter=setter)\n\n    doc = nlp_pipeline(EN_DOC)\n    sentence = doc.sentences[0]\n    sentence.classname = 'good'\n    assert sentence._classname == 1\n\n    # don't try this at home\n    sentence._classname = 2\n    assert sentence.classname == 'bad'\n\ndef test_backpointer(nlp_pipeline):\n    doc = nlp_pipeline(EN_DOC2)\n    ent = doc.ents[0]\n    assert ent.sent is doc.sentences[0]\n    assert list(doc.iter_words())[0].sent is doc.sentences[0]\n    assert list(doc.iter_tokens())[-1].sent is 
doc.sentences[-1]\n"
  },
  {
    "path": "stanza/tests/common/test_doc.py",
    "content": "import pytest\n\nimport stanza\nfrom stanza.tests import *\nfrom stanza.models.common.doc import Document, ID, TEXT, NER, CONSTITUENCY, SENTIMENT\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n@pytest.fixture\ndef sentences_dict():\n    return [[{ID: 1, TEXT: \"unban\"},\n             {ID: 2, TEXT: \"mox\"},\n             {ID: 3, TEXT: \"opal\"}],\n            [{ID: 4, TEXT: \"ban\"},\n             {ID: 5, TEXT: \"Lurrus\"}]]\n\n@pytest.fixture\ndef doc(sentences_dict):\n    doc = Document(sentences_dict)\n    return doc\n\ndef test_basic_values(doc, sentences_dict):\n    \"\"\"\n    Test that sentences & token text are properly set when constructing a doc\n    \"\"\"\n    assert len(doc.sentences) == len(sentences_dict)\n\n    for sentence, raw_sentence in zip(doc.sentences, sentences_dict):\n        assert sentence.doc == doc\n        assert len(sentence.tokens) == len(raw_sentence)\n        for token, raw_token in zip(sentence.tokens, raw_sentence):\n            assert token.text == raw_token[TEXT]\n\ndef test_set_sentence(doc):\n    \"\"\"\n    Test setting a field on the sentences themselves\n    \"\"\"\n    doc.set(fields=\"sentiment\",\n            contents=[\"4\", \"0\"],\n            to_sentence=True)\n\n    assert doc.sentences[0].sentiment == \"4\"\n    assert doc.sentences[1].sentiment == \"0\"\n\ndef test_set_tokens(doc):\n    \"\"\"\n    Test setting values on tokens\n    \"\"\"\n    ner_contents = [\"O\", \"ARTIFACT\", \"ARTIFACT\", \"O\", \"CAT\"]\n    doc.set(fields=NER,\n            contents=ner_contents,\n            to_token=True)\n\n    result = doc.get(NER, from_token=True)\n    assert result == ner_contents\n\n\ndef test_constituency_comment(doc):\n    \"\"\"\n    Test that setting the constituency tree on a doc sets the constituency comment\n    \"\"\"\n    for sentence in doc.sentences:\n        assert len([x for x in sentence.comments if x.startswith(\"# constituency\")]) == 0\n\n    # currently nothing is 
checking that the items are actually trees\n    trees = [\"asdf\", \"zzzz\"]\n    doc.set(fields=CONSTITUENCY,\n            contents=trees,\n            to_sentence=True)\n\n    for sentence, expected in zip(doc.sentences, trees):\n        constituency_comments = [x for x in sentence.comments if x.startswith(\"# constituency\")]\n        assert len(constituency_comments) == 1\n        assert constituency_comments[0].endswith(expected)\n\n    # Test that if we replace the trees with an updated tree, the comment is also replaced\n    trees = [\"zzzz\", \"asdf\"]\n    doc.set(fields=CONSTITUENCY,\n            contents=trees,\n            to_sentence=True)\n\n    for sentence, expected in zip(doc.sentences, trees):\n        constituency_comments = [x for x in sentence.comments if x.startswith(\"# constituency\")]\n        assert len(constituency_comments) == 1\n        assert constituency_comments[0].endswith(expected)\n\ndef test_sentiment_comment(doc):\n    \"\"\"\n    Test that setting the sentiment on a doc sets the sentiment comment\n    \"\"\"\n    for sentence in doc.sentences:\n        assert len([x for x in sentence.comments if x.startswith(\"# sentiment\")]) == 0\n\n    # currently nothing is checking that the items are actually trees\n    sentiments = [\"1\", \"2\"]\n    doc.set(fields=SENTIMENT,\n            contents=sentiments,\n            to_sentence=True)\n\n    for sentence, expected in zip(doc.sentences, sentiments):\n        sentiment_comments = [x for x in sentence.comments if x.startswith(\"# sentiment\")]\n        assert len(sentiment_comments) == 1\n        assert sentiment_comments[0].endswith(expected)\n\n    # Test that if we replace the trees with an updated tree, the comment is also replaced\n    sentiments = [\"3\", \"4\"]\n    doc.set(fields=SENTIMENT,\n            contents=sentiments,\n            to_sentence=True)\n\n    for sentence, expected in zip(doc.sentences, sentiments):\n        sentiment_comments = [x for x in sentence.comments 
if x.startswith(\"# sentiment\")]\n        assert len(sentiment_comments) == 1\n        assert sentiment_comments[0].endswith(expected)\n\ndef test_sent_id_comment(doc):\n    \"\"\"\n    Test that setting the sent_id on a sentence sets the sent_id comment\n    \"\"\"\n    for sent_idx, sentence in enumerate(doc.sentences):\n        assert len([x for x in sentence.comments if x.startswith(\"# sent_id\")]) == 1\n        assert sentence.sent_id == \"%d\" % sent_idx\n    doc.sentences[0].sent_id = \"foo\"\n    assert doc.sentences[0].sent_id == \"foo\"\n    assert len([x for x in doc.sentences[0].comments if x.startswith(\"# sent_id\")]) == 1\n    assert \"# sent_id = foo\" in doc.sentences[0].comments\n\n    doc.reindex_sentences(10)\n    for sent_idx, sentence in enumerate(doc.sentences):\n        assert sentence.sent_id == \"%d\" % (sent_idx + 10)\n        assert len([x for x in sentence.comments if x.startswith(\"# sent_id\")]) == 1\n        assert \"# sent_id = %d\" % (sent_idx + 10) in sentence.comments\n\n    doc.sentences[0].add_comment(\"# sent_id = bar\")\n    assert doc.sentences[0].sent_id == \"bar\"\n    assert \"# sent_id = bar\" in doc.sentences[0].comments\n    assert len([x for x in doc.sentences[0].comments if x.startswith(\"# sent_id\")]) == 1\n\ndef test_doc_id_comment(doc):\n    \"\"\"\n    Test that setting the doc_id on a sentence sets the doc_id comment\n    \"\"\"\n    assert doc.sentences[0].doc_id is None\n    assert len([x for x in doc.sentences[0].comments if x.startswith(\"# doc_id\")]) == 0\n\n    doc.sentences[0].doc_id = \"foo\"\n    assert len([x for x in doc.sentences[0].comments if x.startswith(\"# doc_id\")]) == 1\n    assert \"# doc_id = foo\" in doc.sentences[0].comments\n    assert doc.sentences[0].doc_id == \"foo\"\n\n    doc.sentences[0].add_comment(\"# doc_id = bar\")\n    assert len([x for x in doc.sentences[0].comments if x.startswith(\"# doc_id\")]) == 1\n    assert doc.sentences[0].doc_id == 
\"bar\"\n\n@pytest.fixture(scope=\"module\")\ndef pipeline():\n    return stanza.Pipeline(dir=TEST_MODELS_DIR)\n\ndef test_serialized(pipeline):\n    \"\"\"\n    Brief test of the serialized format\n\n    Checks that NER entities are correctly set.\n    Also checks that constituency & sentiment are set on the sentences.\n    \"\"\"\n    text = \"John Bauer works at Stanford\"\n    doc = pipeline(text)\n    assert len(doc.ents) == 2\n    serialized = doc.to_serialized()\n    doc2 = Document.from_serialized(serialized)\n    assert len(doc2.sentences) == 1\n    assert len(doc2.ents) == 2\n    assert doc.sentences[0].constituency == doc2.sentences[0].constituency\n    assert doc.sentences[0].sentiment == doc2.sentences[0].sentiment\n"
  },
  {
    "path": "stanza/tests/common/test_dropout.py",
    "content": "import pytest\n\nimport torch\n\nimport stanza\nfrom stanza.models.common.dropout import WordDropout\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_word_dropout():\n    \"\"\"\n    Test that word_dropout is randomly dropping out the entire final dimension of a tensor\n\n    Doing 600 small rows should be super fast, but it leaves us with\n    something like a 1 in 10^180 chance of the test failing.  Not very\n    common, in other words\n    \"\"\"\n    wd = WordDropout(0.5)\n    batch = torch.randn(600, 4)\n    dropped = wd(batch)\n    # the one time any of this happens, it's going to be really confusing\n    assert not torch.allclose(batch, dropped)\n    num_zeros = 0\n    for i in range(batch.shape[0]):\n        assert torch.allclose(dropped[i], batch[i]) or torch.sum(dropped[i]) == 0.0\n        if torch.sum(dropped[i]) == 0.0:\n            num_zeros += 1\n    assert num_zeros > 0 and num_zeros < batch.shape[0]\n"
  },
  {
    "path": "stanza/tests/common/test_foundation_cache.py",
    "content": "import glob\nimport os\nimport shutil\nimport tempfile\n\nimport pytest\n\nimport stanza\nfrom stanza.models.common.foundation_cache import FoundationCache, load_charlm\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_charlm_cache():\n    models_path = os.path.join(TEST_MODELS_DIR, \"en\", \"backward_charlm\", \"*\")\n    models = glob.glob(models_path)\n    # we expect at least one English model downloaded for the tests\n    assert len(models) >= 1\n    model_file = models[0]\n\n    cache = FoundationCache()\n    with tempfile.TemporaryDirectory(dir=\".\") as test_dir:\n        temp_file = os.path.join(test_dir, \"charlm.pt\")\n        shutil.copy2(model_file, temp_file)\n        # this will work\n        model = load_charlm(temp_file)\n\n        # this will save the model\n        model = cache.load_charlm(temp_file)\n\n    # this should no longer work\n    with pytest.raises(FileNotFoundError):\n        model = load_charlm(temp_file)\n\n    # it should remember the cached version\n    model = cache.load_charlm(temp_file)\n"
  },
  {
    "path": "stanza/tests/common/test_pretrain.py",
    "content": "import os\nimport tempfile\n\nimport pytest\nimport numpy as np\nimport torch\n\nfrom stanza.models.common import pretrain\nfrom stanza.models.common.vocab import UNK_ID\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef check_vocab(vocab):\n    # 4 base vectors, plus the 3 vectors actually present in the file\n    assert len(vocab) == 7\n    assert 'unban' in vocab\n    assert 'mox' in vocab\n    assert 'opal' in vocab\n\ndef check_embedding(emb, unk=False):\n    expected = np.array([[ 0.,  0.,  0.,  0.,],\n                         [ 0.,  0.,  0.,  0.,],\n                         [ 0.,  0.,  0.,  0.,],\n                         [ 0.,  0.,  0.,  0.,],\n                         [ 1.,  2.,  3.,  4.,],\n                         [ 5.,  6.,  7.,  8.,],\n                         [ 9., 10., 11., 12.,]])\n    if unk:\n        expected[UNK_ID] = -1\n    np.testing.assert_allclose(emb, expected)\n\ndef check_pretrain(pt):\n    check_vocab(pt.vocab)\n    check_embedding(pt.emb)\n\ndef test_text_pretrain():\n    pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.txt', save_to_file=False)\n    check_pretrain(pt)\n\ndef test_xz_pretrain():\n    pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False)\n    check_pretrain(pt)\n\ndef test_gz_pretrain():\n    pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.gz', save_to_file=False)\n    check_pretrain(pt)\n\ndef test_zip_pretrain():\n    pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.zip', save_to_file=False)\n    check_pretrain(pt)\n\ndef test_csv_pretrain():\n    pt = pretrain.Pretrain(csv_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.csv', save_to_file=False)\n    check_pretrain(pt)\n\ndef test_resave_pretrain():\n    \"\"\"\n    Test saving a pretrain and then loading from the existing file\n    \"\"\"\n    test_pt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', 
suffix=\".pt\", delete=False)\n    try:\n        test_pt_file.close()\n        # note that this tests the ability to save a pretrain and the\n        # ability to fall back when the existing pretrain isn't working\n        pt = pretrain.Pretrain(filename=test_pt_file.name,\n                               vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')\n        check_pretrain(pt)\n\n        pt2 = pretrain.Pretrain(filename=test_pt_file.name,\n                               vec_filename=f'unban_mox_opal')\n        check_pretrain(pt2)\n\n        pt3 = torch.load(test_pt_file.name, weights_only=True)\n        check_embedding(pt3['emb'])\n    finally:\n        os.unlink(test_pt_file.name)\n\nSPACE_PRETRAIN=\"\"\"\n3 4\nunban mox 1 2 3 4\nopal 5 6 7 8\nfoo 9 10 11 12\n\"\"\".strip()\n\ndef test_whitespace():\n    \"\"\"\n    Test reading a pretrain with an ascii space in it\n\n    The vocab word with a space in it should have the correct number\n    of dimensions read, with the space converted to nbsp\n    \"\"\"\n    test_txt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=\".txt\", delete=False)\n    try:\n        test_txt_file.write(SPACE_PRETRAIN.encode())\n        test_txt_file.close()\n\n        pt = pretrain.Pretrain(vec_filename=test_txt_file.name, save_to_file=False)\n        check_embedding(pt.emb)\n        assert \"unban\\xa0mox\" in pt.vocab\n        # this one also works because of the normalize_unit in vocab.py\n        assert \"unban mox\" in pt.vocab\n    finally:\n        os.unlink(test_txt_file.name)\n\nNO_HEADER_PRETRAIN=\"\"\"\nunban 1 2 3 4\nmox 5 6 7 8\nopal 9 10 11 12\n\"\"\".strip()\n\ndef test_no_header():\n    \"\"\"\n    Check loading a pretrain with no rows,cols header\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:\n        filename = os.path.join(tmpdir, \"tiny.txt\")\n        with open(filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(NO_HEADER_PRETRAIN)\n    
    pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)\n        check_embedding(pt.emb)\n\nUNK_PRETRAIN=\"\"\"\nunban 1 2 3 4\nmox 5 6 7 8\nopal 9 10 11 12\n<unk> -1 -1 -1 -1\n\"\"\".strip()\n\ndef test_unk():\n    \"\"\"\n    Check loading a pretrain with <unk> at the end, like GloVe does\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:\n        filename = os.path.join(tmpdir, \"tiny.txt\")\n        with open(filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(UNK_PRETRAIN)\n        pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)\n        check_embedding(pt.emb, unk=True)\n"
  },
  {
    "path": "stanza/tests/common/test_relative_attn.py",
    "content": "import pytest\n\nimport torch\n\nfrom stanza.models.common.relative_attn import RelativeAttention\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n\ndef test_attn():\n    foo = RelativeAttention(d_model=100, num_heads=2, window=8, dropout=0.0)\n    bar = torch.randn(10, 13, 100)\n    result = foo(bar)\n    assert result.shape == bar.shape\n    value = foo.value(bar)\n    if not torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06):\n        raise ValueError(result[:, -1, :] - value[:, -1, :])\n    assert torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06)\n    assert not torch.allclose(result[:, 0, :], value[:, 0, :])\n\n\ndef test_shorter_sequence():\n    # originally this was failing because the sequence was shorter than the window\n    foo = RelativeAttention(d_model=20, num_heads=2, window=5, dropout=0.0)\n    bar = torch.randn(10, 3, 20)\n    result = foo(bar)\n    assert result.shape == bar.shape\n\n    value = foo.value(bar)\n    if not torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06):\n        raise ValueError(result[:, -1, :] - value[:, -1, :])\n    assert torch.allclose(result[:, -1, :], value[:, -1, :], atol=1e-06)\n    assert not torch.allclose(result[:, 0, :], value[:, 0, :])\n\ndef test_reverse():\n    foo = RelativeAttention(d_model=100, num_heads=2, window=8, reverse=True, dropout=0.0)\n    bar = torch.randn(10, 13, 100)\n    result = foo(bar)\n    assert result.shape == bar.shape\n    value = foo.value(bar)\n    if not torch.allclose(result[:, 0, :], value[:, 0, :], atol=1e-06):\n        raise ValueError(result[:, 0, :] - value[:, 0, :])\n    assert torch.allclose(result[:, 0, :], value[:, 0, :], atol=1e-06)\n    assert not torch.allclose(result[:, -1, :], value[:, -1, :])\n\n\n"
  },
  {
    "path": "stanza/tests/common/test_short_name_to_treebank.py",
    "content": "import pytest\n\nimport stanza\nfrom stanza.models.common import short_name_to_treebank\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_short_name():\n    assert short_name_to_treebank.short_name_to_treebank(\"en_ewt\") == \"UD_English-EWT\"\n\ndef test_canonical_name():\n    assert short_name_to_treebank.canonical_treebank_name(\"UD_URDU-UDTB\") == \"UD_Urdu-UDTB\"\n    assert short_name_to_treebank.canonical_treebank_name(\"ur_udtb\") == \"UD_Urdu-UDTB\"\n    assert short_name_to_treebank.canonical_treebank_name(\"Unban_Mox_Opal\") == \"Unban_Mox_Opal\"\n"
  },
  {
    "path": "stanza/tests/common/test_utils.py",
    "content": "import lzma\nimport os\nimport tempfile\n\nimport pytest\n\nimport stanza\nimport stanza.models.common.utils as utils\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_wordvec_not_found():\n    \"\"\"\n    get_wordvec_file should fail if neither word2vec nor fasttext exists\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:\n        with pytest.raises(FileNotFoundError):\n            utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')\n\n\ndef test_word2vec_xz():\n    \"\"\"\n    Test searching for word2vec and xz files\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:\n        # make a fake directory for English word vectors\n        word2vec_dir = os.path.join(temp_dir, 'word2vec', 'English')\n        os.makedirs(word2vec_dir)\n\n        # make a fake English word vector file\n        fake_file = os.path.join(word2vec_dir, 'en.vectors.xz')\n        fout = open(fake_file, 'w')\n        fout.close()\n\n        # get_wordvec_file should now find this fake file\n        filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')\n        assert filename == fake_file\n\ndef test_fasttext_txt():\n    \"\"\"\n    Test searching for fasttext and txt files\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:\n        # make a fake directory for English word vectors\n        fasttext_dir = os.path.join(temp_dir, 'fasttext', 'English')\n        os.makedirs(fasttext_dir)\n\n        # make a fake English word vector file\n        fake_file = os.path.join(fasttext_dir, 'en.vectors.txt')\n        fout = open(fake_file, 'w')\n        fout.close()\n\n        # get_wordvec_file should now find this fake file\n        filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')\n        assert filename == fake_file\n\ndef test_wordvec_type():\n    
\"\"\"\n    If we supply our own wordvec type, get_wordvec_file should find that\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:\n        # make a fake directory for English word vectors\n        google_dir = os.path.join(temp_dir, 'google', 'English')\n        os.makedirs(google_dir)\n\n        # make a fake English word vector file\n        fake_file = os.path.join(google_dir, 'en.vectors.txt')\n        fout = open(fake_file, 'w')\n        fout.close()\n\n        # get_wordvec_file should now find this fake file\n        filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo', wordvec_type='google')\n        assert filename == fake_file\n\n        # this file won't be found using the normal defaults\n        with pytest.raises(FileNotFoundError):\n            utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')\n\ndef test_sort_with_indices():\n    data = [[1, 2, 3], [4, 5], [6]]\n    ordered, orig_idx = utils.sort_with_indices(data, key=len)\n    assert ordered == ([6], [4, 5], [1, 2, 3])\n    assert orig_idx == (2, 1, 0)\n\n    unsorted = utils.unsort(ordered, orig_idx)\n    assert data == unsorted\n\ndef test_empty_sort_with_indices():\n    ordered, orig_idx = utils.sort_with_indices([])\n    assert len(ordered) == 0\n    assert len(orig_idx) == 0\n\n    unsorted = utils.unsort(ordered, orig_idx)\n    assert [] == unsorted\n\n\ndef test_split_into_batches():\n    data = []\n    for i in range(5):\n        data.append([\"Unban\", \"mox\", \"opal\", str(i)])\n\n    data.append([\"Do\", \"n't\", \"ban\", \"Urza\", \"'s\", \"Saga\", \"that\", \"card\", \"is\", \"great\"])\n    data.append([\"Ban\", \"Ragavan\"])\n\n    # small batches will put one element in each interval\n    batches = utils.split_into_batches(data, 5)\n    assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n\n    # this one has a batch interrupted in the middle by a large element\n    batches = 
utils.split_into_batches(data, 8)\n    assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)]\n\n    # this one has the large element at the start of its own batch\n    batches = utils.split_into_batches(data[1:], 8)\n    assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)]\n\n    # overloading the test!  assert that the key & reverse is working\n    ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True)\n    assert [len(x) for x in ordered] == [10, 4, 4, 4, 4, 4, 2]\n\n    # this has the large element at the start\n    batches = utils.split_into_batches(ordered, 8)\n    assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)]\n\n    # double check that unsort is working as expected\n    assert data == utils.unsort(ordered, orig_idx)\n\n\ndef test_find_missing_tags():\n    assert utils.find_missing_tags([\"O\", \"PER\", \"LOC\"], [\"O\", \"PER\", \"LOC\"]) == []\n    assert utils.find_missing_tags([\"O\", \"PER\", \"LOC\"], [\"O\", \"PER\", \"LOC\", \"ORG\"]) == ['ORG']\n    assert utils.find_missing_tags([[\"O\", \"PER\"], [\"O\", \"LOC\"]], [[\"O\", \"PER\"], [\"LOC\", \"ORG\"]]) == ['ORG']\n\n\ndef test_open_read_text():\n    \"\"\"\n    test that we can read either .xz or regular txt\n    \"\"\"\n    TEXT = \"this is a test\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        # test text file\n        filename = os.path.join(tempdir, \"foo.txt\")\n        with open(filename, \"w\") as fout:\n            fout.write(TEXT)\n        with utils.open_read_text(filename) as fin:\n            in_text = fin.read()\n            assert TEXT == in_text\n\n        assert fin.closed\n\n        # the context should close the file when we throw an exception!\n        try:\n            with utils.open_read_text(filename) as finex:\n                assert not finex.closed\n                raise ValueError(\"unban mox opal!\")\n        except ValueError:\n            pass\n        assert finex.closed\n\n        # test xz file\n        filename = 
os.path.join(tempdir, \"foo.txt.xz\")\n        with lzma.open(filename, \"wt\") as fout:\n            fout.write(TEXT)\n        with utils.open_read_text(filename) as finxz:\n            in_text = finxz.read()\n            assert TEXT == in_text\n\n        assert finxz.closed\n\n        # the context should close the file when we throw an exception!\n        try:\n            with utils.open_read_text(filename) as finexxz:\n                assert not finexxz.closed\n                raise ValueError(\"unban mox opal!\")\n        except ValueError:\n            pass\n        assert finexxz.closed\n\n\ndef test_checkpoint_name():\n    \"\"\"\n    Test some expected results for the checkpoint names\n    \"\"\"\n    # use os.path.split so that the test is agnostic of file separator on Linux or Windows\n    checkpoint = utils.checkpoint_name(\"saved_models\", \"kk_oscar_forward_charlm.pt\", None)\n    assert os.path.split(checkpoint) == (\"saved_models\", \"kk_oscar_forward_charlm_checkpoint.pt\")\n\n    checkpoint = utils.checkpoint_name(\"saved_models\", \"kk_oscar_forward_charlm\", None)\n    assert os.path.split(checkpoint) == (\"saved_models\", \"kk_oscar_forward_charlm_checkpoint\")\n\n    checkpoint = utils.checkpoint_name(\"saved_models\", \"kk_oscar_forward_charlm\", \"othername.pt\")\n    assert os.path.split(checkpoint) == (\"saved_models\", \"othername.pt\")\n\ndef test_punct_simplification():\n    \"\"\"\n    Test a punctuation simplification that should make it so unexpected\n    question/exclamation marks types are processed into ? and !\n    \"\"\"\n    test = [[[\"!!!!\"],\n             [\"‼‼‼‼\"],\n             [\"????\"],\n             [\"?!?!\"],\n             [\"?？︖\"],\n             [\"?foo\"],\n             [\"bar!\"]]]\n    test = utils.simplify_punct(test)\n    expected = [[['!'], ['!'], ['?'], ['?'], ['?'], ['?foo'], ['bar!']]]\n    assert test == expected\n\n"
  },
  {
    "path": "stanza/tests/constituency/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/constituency/test_convert_arboretum.py",
    "content": "\"\"\"\nTest a couple different classes of trees to check the output of the Arboretum conversion\n\nNote that the text has been removed\n\"\"\"\n\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.server import tsurgeon\nfrom stanza.tests import TEST_WORKING_DIR\nfrom stanza.utils.datasets.constituency import convert_arboretum\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n\nPROJ_EXAMPLE=\"\"\"\n<s id=\"s2\" ref=\"AACBPIGY\" source=\"id=AACBPIGY\" forest=\"1/1\" text=\"A B C D E F G H.\">\n\t<graph root=\"s2_500\">\n\t\t<terminals>\n\t\t\t<t id=\"s2_1\" word=\"A\" lemma=\"A\" pos=\"prop\" morph=\"NOM\" extra=\"PROP:A compound brand\"/>\n\t\t\t<t id=\"s2_2\" word=\"B\" lemma=\"B\" pos=\"v-fin\" morph=\"PR AKT\" extra=\"mv\"/>\n\t\t\t<t id=\"s2_3\" word=\"C\" lemma=\"C\" pos=\"pron-pers\" morph=\"2S ACC\" extra=\"--\"/>\n\t\t\t<t id=\"s2_4\" word=\"D\" lemma=\"D\" pos=\"adj\" morph=\"UTR S IDF NOM\" extra=\"F:u+afhængig\"/>\n\t\t\t<t id=\"s2_5\" word=\"E\" lemma=\"E\" pos=\"prp\" morph=\"--\" extra=\"--\"/>\n\t\t\t<t id=\"s2_6\" word=\"F\" lemma=\"F\" pos=\"art\" morph=\"NEU S DEF\" extra=\"--\"/>\n\t\t\t<t id=\"s2_7\" word=\"G\" lemma=\"G\" pos=\"adj\" morph=\"nG S DEF NOM\" extra=\"--\"/>\n\t\t\t<t id=\"s2_8\" word=\"H\" lemma=\"H\" pos=\"n\" morph=\"NEU S IDF NOM\" extra=\"N:lys+net\"/>\n\t\t\t<t id=\"s2_9\" word=\".\" lemma=\"--\" pos=\"pu\" morph=\"--\" extra=\"--\"/>\n\t\t</terminals>\n\n\t\t<nonterminals>\n\t\t\t<nt id=\"s2_500\" cat=\"s\">\n\t\t\t\t<edge label=\"STA\" idref=\"s2_501\"/>\n\t\t\t</nt>\n\t\t\t<nt id=\"s2_501\" cat=\"fcl\">\n\t\t\t\t<edge label=\"S\" idref=\"s2_1\"/>\n\t\t\t\t<edge label=\"P\" idref=\"s2_2\"/>\n\t\t\t\t<edge label=\"Od\" idref=\"s2_3\"/>\n\t\t\t\t<edge label=\"Co\" idref=\"s2_502\"/>\n\t\t\t\t<edge label=\"PU\" idref=\"s2_9\"/>\n\t\t\t</nt>\n\t\t\t<nt id=\"s2_502\" cat=\"adjp\">\n\t\t\t\t<edge label=\"H\" idref=\"s2_4\"/>\n\t\t\t\t<edge label=\"DA\" 
idref=\"s2_503\"/>\n\t\t\t</nt>\n\t\t\t<nt id=\"s2_503\" cat=\"pp\">\n\t\t\t\t<edge label=\"H\" idref=\"s2_5\"/>\n\t\t\t\t<edge label=\"DP\" idref=\"s2_504\"/>\n\t\t\t</nt>\n\t\t\t<nt id=\"s2_504\" cat=\"np\">\n\t\t\t\t<edge label=\"DN\" idref=\"s2_6\"/>\n\t\t\t\t<edge label=\"DN\" idref=\"s2_7\"/>\n\t\t\t\t<edge label=\"H\" idref=\"s2_8\"/>\n\t\t\t</nt>\n\t\t</nonterminals>\n\t</graph>\n</s>\n\"\"\"\n\nNOT_FIX_NONPROJ_EXAMPLE=\"\"\"\n<s id=\"s322\" ref=\"EDGBITSZ\" source=\"id=EDGBITSZ\" forest=\"1/2\" text=\"A B C D E, F G H I J.\">\n        <graph root=\"s322_500\">\n                <terminals>\n                        <t id=\"s322_1\" word=\"A\" lemma=\"A\" pos=\"prop\" morph=\"NOM\" extra=\"hum fem\"/>\n                        <t id=\"s322_2\" word=\"B\" lemma=\"B\" pos=\"v-fin\" morph=\"PR AKT\" extra=\"mv\"/>\n                        <t id=\"s322_3\" word=\"C\" lemma=\"C\" pos=\"pron-dem\" morph=\"UTR S\" extra=\"dem\"/>\n                        <t id=\"s322_4\" word=\"D\" lemma=\"D\" pos=\"n\" morph=\"UTR S IDF NOM\" extra=\"--\"/>\n                        <t id=\"s322_5\" word=\"E\" lemma=\"E\" pos=\"adv\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s322_6\" word=\",\" lemma=\"--\" pos=\"pu\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s322_7\" word=\"F\" lemma=\"F\" pos=\"pron-rel\" morph=\"--\" extra=\"rel\"/>\n                        <t id=\"s322_8\" word=\"G\" lemma=\"G\" pos=\"prop\" morph=\"NOM\" extra=\"hum\"/>\n                        <t id=\"s322_9\" word=\"H\" lemma=\"H\" pos=\"v-fin\" morph=\"IMPF AKT\" extra=\"mv\"/>\n                        <t id=\"s322_10\" word=\"I\" lemma=\"I\" pos=\"prp\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s322_11\" word=\"J\" lemma=\"J\" pos=\"n\" morph=\"UTR S DEF NOM\" extra=\"F:ur+premiere\"/>\n                        <t id=\"s322_12\" word=\".\" lemma=\"--\" pos=\"pu\" morph=\"--\" extra=\"--\"/>\n                </terminals>\n\n                
<nonterminals>\n                        <nt id=\"s322_500\" cat=\"s\">\n                                <edge label=\"STA\" idref=\"s322_501\"/>\n                        </nt>\n                        <nt id=\"s322_501\" cat=\"fcl\">\n                                <edge label=\"S\" idref=\"s322_1\"/>\n                                <edge label=\"P\" idref=\"s322_2\"/>\n                                <edge label=\"Od\" idref=\"s322_502\"/>\n                                <edge label=\"Vpart\" idref=\"s322_5\"/>\n                                <edge label=\"PU\" idref=\"s322_6\"/>\n                                <edge label=\"PU\" idref=\"s322_12\"/>\n                        </nt>\n                        <nt id=\"s322_502\" cat=\"np\">\n                                <edge label=\"DN\" idref=\"s322_3\"/>\n                                <edge label=\"H\" idref=\"s322_4\"/>\n                                <edge label=\"DN\" idref=\"s322_503\"/>\n                        </nt>\n                        <nt id=\"s322_503\" cat=\"fcl\">\n                                <edge label=\"Od\" idref=\"s322_7\"/>\n                                <edge label=\"S\" idref=\"s322_8\"/>\n                                <edge label=\"P\" idref=\"s322_9\"/>\n                                <edge label=\"Ao\" idref=\"s322_504\"/>\n                        </nt>\n                        <nt id=\"s322_504\" cat=\"pp\">\n                                <edge label=\"H\" idref=\"s322_10\"/>\n                                <edge label=\"DP\" idref=\"s322_11\"/>\n                        </nt>\n                </nonterminals>\n        </graph>\n</s>\n\"\"\"\n\n\nNONPROJ_EXAMPLE=\"\"\"\n<s id=\"s9\" ref=\"AATCNKQZ\" source=\"id=AATCNKQZ\" forest=\"1/1\" text=\"A B C D E F G H I.\">\n        <graph root=\"s9_500\">\n                <terminals>\n                        <t id=\"s9_1\" word=\"A\" lemma=\"A\" pos=\"adv\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s9_2\" 
word=\"B\" lemma=\"B\" pos=\"adv\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s9_3\" word=\"C\" lemma=\"C\" pos=\"v-fin\" morph=\"IMPF AKT\" extra=\"aux\"/>\n                        <t id=\"s9_4\" word=\"D\" lemma=\"D\" pos=\"prop\" morph=\"NOM\" extra=\"hum\"/>\n                        <t id=\"s9_5\" word=\"E\" lemma=\"E\" pos=\"adv\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s9_6\" word=\"F\" lemma=\"F\" pos=\"v-pcp2\" morph=\"PAS\" extra=\"mv\"/>\n                        <t id=\"s9_7\" word=\"G\" lemma=\"G\" pos=\"prp\" morph=\"--\" extra=\"--\"/>\n                        <t id=\"s9_8\" word=\"H\" lemma=\"H\" pos=\"num\" morph=\"--\" extra=\"card\"/>\n                        <t id=\"s9_9\" word=\"I\" lemma=\"I\" pos=\"n\" morph=\"UTR P IDF NOM\" extra=\"N:patrulje+vogn\"/>\n                        <t id=\"s9_10\" word=\".\" lemma=\"--\" pos=\"pu\" morph=\"--\" extra=\"--\"/>\n                </terminals>\n\n                <nonterminals>\n                        <nt id=\"s9_500\" cat=\"s\">\n                                <edge label=\"STA\" idref=\"s9_501\"/>\n                        </nt>\n                        <nt id=\"s9_501\" cat=\"fcl\">\n                                <edge label=\"fA\" idref=\"s9_502\"/>\n                                <edge label=\"P\" idref=\"s9_503\"/>\n                                <edge label=\"S\" idref=\"s9_4\"/>\n                                <edge label=\"fA\" idref=\"s9_5\"/>\n                                <edge label=\"fA\" idref=\"s9_504\"/>\n                                <edge label=\"PU\" idref=\"s9_10\"/>\n                        </nt>\n                        <nt id=\"s9_502\" cat=\"advp\">\n                                <edge label=\"DA\" idref=\"s9_1\"/>\n                                <edge label=\"H\" idref=\"s9_2\"/>\n                        </nt>\n                        <nt id=\"s9_503\" cat=\"vp\">\n                                <edge label=\"Vaux\" 
idref=\"s9_3\"/>\n                                <edge label=\"Vm\" idref=\"s9_6\"/>\n                        </nt>\n                        <nt id=\"s9_504\" cat=\"pp\">\n                                <edge label=\"H\" idref=\"s9_7\"/>\n                                <edge label=\"DP\" idref=\"s9_505\"/>\n                        </nt>\n                        <nt id=\"s9_505\" cat=\"np\">\n                                <edge label=\"DN\" idref=\"s9_8\"/>\n                                <edge label=\"H\" idref=\"s9_9\"/>\n                        </nt>\n                </nonterminals>\n        </graph>\n</s>\n\"\"\"\n\ndef test_projective_example():\n    \"\"\"\n    Test reading a basic tree, along with some further manipulations from the conversion program\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:\n        test_name = os.path.join(tempdir, \"proj.xml\")\n        with open(test_name, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(PROJ_EXAMPLE)\n        sentences = convert_arboretum.read_xml_file(test_name)\n        assert len(sentences) == 1\n\n    tree, words = convert_arboretum.process_tree(sentences[0])\n    expected_tree = \"(s (fcl (prop s2_1) (v-fin s2_2) (pron-pers s2_3) (adjp (adj s2_4) (pp (prp s2_5) (np (art s2_6) (adj s2_7) (n s2_8)))) (pu s2_9)))\"\n    assert str(tree) == expected_tree\n    assert [w.word for w in words.values()] == ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', '.']\n    assert not convert_arboretum.word_sequence_missing_words(tree)\n    with tsurgeon.Tsurgeon() as tsurgeon_processor:\n        assert tree == convert_arboretum.check_words(tree, tsurgeon_processor)\n\n    # check that the words can be replaced as expected\n    replaced_tree = convert_arboretum.replace_words(tree, words)\n    expected_tree = \"(s (fcl (prop A) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))\"\n    assert str(replaced_tree) == expected_tree\n    assert 
convert_arboretum.split_underscores(replaced_tree) == replaced_tree\n\n    # fake a word which should be split\n    words['s2_1'] = words['s2_1']._replace(word='foo_bar')\n    replaced_tree = convert_arboretum.replace_words(tree, words)\n    expected_tree = \"(s (fcl (prop foo_bar) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))\"\n    assert str(replaced_tree) == expected_tree\n    expected_tree = \"(s (fcl (np (prop foo) (prop bar)) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))\"\n    assert str(convert_arboretum.split_underscores(replaced_tree)) == expected_tree\n\n\ndef test_not_fix_example():\n    \"\"\"\n    Test that a non-projective tree which we don't have a heuristic for quietly fails\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:\n        test_name = os.path.join(tempdir, \"nofix.xml\")\n        with open(test_name, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(NOT_FIX_NONPROJ_EXAMPLE)\n        sentences = convert_arboretum.read_xml_file(test_name)\n        assert len(sentences) == 1\n\n    tree, words = convert_arboretum.process_tree(sentences[0])\n    assert not convert_arboretum.word_sequence_missing_words(tree)\n    with tsurgeon.Tsurgeon() as tsurgeon_processor:\n        assert convert_arboretum.check_words(tree, tsurgeon_processor) is None\n\n\ndef test_fix_proj_example():\n    \"\"\"\n    Test that a non-projective tree can be rearranged as expected\n\n    Note that there are several other classes of non-proj tree we could test as well...\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:\n        test_name = os.path.join(tempdir, \"fix.xml\")\n        with open(test_name, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(NONPROJ_EXAMPLE)\n        sentences = convert_arboretum.read_xml_file(test_name)\n        assert len(sentences) == 1\n\n    tree, words = 
convert_arboretum.process_tree(sentences[0])\n    assert not convert_arboretum.word_sequence_missing_words(tree)\n    # the 4 and 5 are moved inside the 3-6 node\n    expected_orig = \"(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (v-pcp2 s9_6)) (prop s9_4) (adv s9_5) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))\"\n    expected_proj = \"(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (prop s9_4) (adv s9_5) (v-pcp2 s9_6)) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))\"\n    assert str(tree) == expected_orig\n    with tsurgeon.Tsurgeon() as tsurgeon_processor:\n        assert str(convert_arboretum.check_words(tree, tsurgeon_processor)) == expected_proj\n\n"
  },
  {
    "path": "stanza/tests/constituency/test_convert_it_vit.py",
    "content": "\"\"\"\nTest a couple different classes of trees to check the output of the VIT conversion\n\nA couple representative trees are included, but hopefully not enough\nto be a problem in terms of our license.\n\nOne of the tests is currently disabled as it relies on tregex & tsurgeon features\nnot yet released\n\"\"\"\n\nimport io\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.server import tsurgeon\nfrom stanza.utils.conll import CoNLL\nfrom stanza.utils.datasets.constituency import convert_it_vit\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# just a sample!  don't sue us please\nCON_SAMPLE = \"\"\"\n#ID=sent_00002\tcp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]]\n\n#ID=sent_00318\tdirsp-[fc-[congf-tuttavia, f-[sn-[sq-[ind-qualche], n-problema], ir_infl-[vsupir-potrebbe, vcl-esserci], compc-[clit-ci, sp-[p-per, sn-[art-la, n-commissione, sa-[ag-esteri], f2-[sp-[part-alla, relob-cui, sn-[n-presidenza]], f-[ibar-[vc-è], compc-[sn-[n-candidato], sn-[art-l, n-esponente, spd-[pd-di, sn-[mw-Alleanza, npro-Nazionale]], sn-[mw-Mirko, nh-Tremaglia]]]]]]]]]], dirs-':', f3-[sn-[art-una, n-candidatura, sc-[q-più, 
sa-[ppas-subìta], sc-[ccong-che, sa-[ppas-gradita]], compt-[spda-[partda-dalla, sn-[mw-Lega, npro-Nord, punt-',', f2-[rel-che, fc-[congf-tuttavia, f-[ir_infl-[vsupir-dovrebbe, vit-rispettare], compt-[sn-[art-gli, n-accordi]]]]]]]]]], punto-.]]\n\n#ID=sent_00589\tf-[sn-[art-l, n-ottimismo, spd-[pd-di, sn-[nh-Kantor]]], ir_infl-[vsupir-potrebbe, congf-però, vcl-rivelarsi], compc-[sn-[in-ancora, art-una, nt-volta], sa-[ag-prematuro]], punto-.]\n\"\"\"\n\nUD_SAMPLE = \"\"\"\n# sent_id = VIT-2\n# text = Negli ultimi anni la dinamica dei polo di attrazione è stata sempre più caratterizzata dall'emergere di una crescente concorrenza che si è progressivamente spostata dalle singole imprese ai sistemi economici e territoriali, determinando l'esigenza di una riconsiderazione dei rapporti esistenti tra soggetti produttivi e ambiente in cui questi operano.\n1-2\tNegli\t_\t_\t_\t_\t_\t_\t_\t_\n1\tIn\tin\tADP\tE\t_\t4\tcase\t_\t_\n2\tgli\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Plur|PronType=Art\t4\tdet\t_\t_\n3\tultimi\tultimo\tADJ\tA\tGender=Masc|Number=Plur\t4\tamod\t_\t_\n4\tanni\tanno\tNOUN\tS\tGender=Masc|Number=Plur\t16\tobl\t_\t_\n5\tla\til\tDET\tRD\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t6\tdet\t_\t_\n6\tdinamica\tdinamica\tNOUN\tS\tGender=Fem|Number=Sing\t16\tnsubj:pass\t_\t_\n7-8\tdei\t_\t_\t_\t_\t_\t_\t_\t_\n7\tdi\tdi\tADP\tE\t_\t9\tcase\t_\t_\n8\ti\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Plur|PronType=Art\t9\tdet\t_\t_\n9\tpolo\tpolo\tNOUN\tS\tGender=Masc|Number=Sing\t6\tnmod\t_\t_\n10\tdi\tdi\tADP\tE\t_\t11\tcase\t_\t_\n11\tattrazione\tattrazione\tNOUN\tS\tGender=Fem|Number=Sing\t9\tnmod\t_\t_\n12\tè\tessere\tAUX\tVA\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t16\taux\t_\t_\n13\tstata\tessere\tAUX\tVA\tGender=Fem|Number=Sing|Tense=Past|VerbForm=Part\t16\taux:pass\t_\t_\n14\tsempre\tsempre\tADV\tB\t_\t15\tadvmod\t_\t_\n15\tpiù\tpiù\tADV\tB\t_\t16\tadvmod\t_\t_\n16\tcaratterizzata\tcaratterizzare\tVERB\tV\tGender=Fem|Number=Sing|Tense
=Past|VerbForm=Part\t0\troot\t_\t_\n17-18\tdall'\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No\n17\tda\tda\tADP\tE\t_\t19\tcase\t_\t_\n18\tl'\til\tDET\tRD\tDefinite=Def|Number=Sing|PronType=Art\t19\tdet\t_\t_\n19\temergere\temergere\tNOUN\tS\tGender=Masc|Number=Sing\t16\tobl\t_\t_\n20\tdi\tdi\tADP\tE\t_\t23\tcase\t_\t_\n21\tuna\tuno\tDET\tRI\tDefinite=Ind|Gender=Fem|Number=Sing|PronType=Art\t23\tdet\t_\t_\n22\tcrescente\tcrescente\tADJ\tA\tNumber=Sing\t23\tamod\t_\t_\n23\tconcorrenza\tconcorrenza\tNOUN\tS\tGender=Fem|Number=Sing\t19\tnmod\t_\t_\n24\tche\tche\tPRON\tPR\tPronType=Rel\t28\tnsubj\t_\t_\n25\tsi\tsi\tPRON\tPC\tClitic=Yes|Person=3|PronType=Prs\t28\texpl\t_\t_\n26\tè\tessere\tAUX\tVA\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t28\taux\t_\t_\n27\tprogressivamente\tprogressivamente\tADV\tB\t_\t28\tadvmod\t_\t_\n28\tspostata\tspostare\tVERB\tV\tGender=Fem|Number=Sing|Tense=Past|VerbForm=Part\t23\tacl:relcl\t_\t_\n29-30\tdalle\t_\t_\t_\t_\t_\t_\t_\t_\n29\tda\tda\tADP\tE\t_\t32\tcase\t_\t_\n30\tle\til\tDET\tRD\tDefinite=Def|Gender=Fem|Number=Plur|PronType=Art\t32\tdet\t_\t_\n31\tsingole\tsingolo\tADJ\tA\tGender=Fem|Number=Plur\t32\tamod\t_\t_\n32\timprese\timpresa\tNOUN\tS\tGender=Fem|Number=Plur\t28\tobl\t_\t_\n33-34\tai\t_\t_\t_\t_\t_\t_\t_\t_\n33\ta\ta\tADP\tE\t_\t35\tcase\t_\t_\n34\ti\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Plur|PronType=Art\t35\tdet\t_\t_\n35\tsistemi\tsistema\tNOUN\tS\tGender=Masc|Number=Plur\t28\tobl\t_\t_\n36\teconomici\teconomico\tADJ\tA\tGender=Masc|Number=Plur\t35\tamod\t_\t_\n37\te\te\tCCONJ\tCC\t_\t38\tcc\t_\t_\n38\tterritoriali\tterritoriale\tADJ\tA\tNumber=Plur\t36\tconj\t_\tSpaceAfter=No\n39\t,\t,\tPUNCT\tFF\t_\t28\tpunct\t_\t_\n40\tdeterminando\tdeterminare\tVERB\tV\tVerbForm=Ger\t28\tadvcl\t_\t_\n41\tl'\til\tDET\tRD\tDefinite=Def|Number=Sing|PronType=Art\t42\tdet\t_\tSpaceAfter=No\n42\tesigenza\tesigenza\tNOUN\tS\tGender=Fem|Number=Sing\t40\tobj\t_\t_\n43\tdi\tdi\tADP\tE\t_\t45\tcase\t_\t_\n44\tuna\tuno\tDET\tRI\tDefini
te=Ind|Gender=Fem|Number=Sing|PronType=Art\t45\tdet\t_\t_\n45\triconsiderazione\triconsiderazione\tNOUN\tS\tGender=Fem|Number=Sing\t42\tnmod\t_\t_\n46-47\tdei\t_\t_\t_\t_\t_\t_\t_\t_\n46\tdi\tdi\tADP\tE\t_\t48\tcase\t_\t_\n47\ti\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Plur|PronType=Art\t48\tdet\t_\t_\n48\trapporti\trapporto\tNOUN\tS\tGender=Masc|Number=Plur\t45\tnmod\t_\t_\n49\tesistenti\tesistente\tVERB\tV\tNumber=Plur\t48\tacl\t_\t_\n50\ttra\ttra\tADP\tE\t_\t51\tcase\t_\t_\n51\tsoggetti\tsoggetto\tNOUN\tS\tGender=Masc|Number=Plur\t49\tobl\t_\t_\n52\tproduttivi\tproduttivo\tADJ\tA\tGender=Masc|Number=Plur\t51\tamod\t_\t_\n53\te\te\tCCONJ\tCC\t_\t54\tcc\t_\t_\n54\tambiente\tambiente\tNOUN\tS\tGender=Masc|Number=Sing\t51\tconj\t_\t_\n55\tin\tin\tADP\tE\t_\t56\tcase\t_\t_\n56\tcui\tcui\tPRON\tPR\tPronType=Rel\t58\tobl\t_\t_\n57\tquesti\tquesto\tPRON\tPD\tGender=Masc|Number=Plur|PronType=Dem\t58\tnsubj\t_\t_\n58\toperano\toperare\tVERB\tV\tMood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin\t54\tacl:relcl\t_\tSpaceAfter=No\n59\t.\t.\tPUNCT\tFS\t_\t16\tpunct\t_\t_\n\n# sent_id = VIT-318\n# text = Tuttavia qualche problema potrebbe esserci per la commissione esteri alla cui presidenza è candidato l'esponente di Alleanza Nazionale Mirko Tremaglia: una candidatura più subìta che gradita dalla Lega Nord, che tuttavia dovrebbe rispettare gli 
accordi.\n1\tTuttavia\ttuttavia\tCCONJ\tCC\t_\t5\tcc\t_\t_\n2\tqualche\tqualche\tDET\tDI\tNumber=Sing|PronType=Ind\t3\tdet\t_\t_\n3\tproblema\tproblema\tNOUN\tS\tGender=Masc|Number=Sing\t5\tnsubj\t_\t_\n4\tpotrebbe\tpotere\tAUX\tVA\tMood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t5\taux\t_\t_\n5-6\tesserci\t_\t_\t_\t_\t_\t_\t_\t_\n5\tesser\tessere\tVERB\tV\tVerbForm=Inf\t0\troot\t_\t_\n6\tci\tci\tPRON\tPC\tClitic=Yes|Number=Plur|Person=1|PronType=Prs\t5\texpl\t_\t_\n7\tper\tper\tADP\tE\t_\t9\tcase\t_\t_\n8\tla\til\tDET\tRD\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t9\tdet\t_\t_\n9\tcommissione\tcommissione\tNOUN\tS\tGender=Fem|Number=Sing\t5\tobl\t_\t_\n10\testeri\testero\tADJ\tA\tGender=Masc|Number=Plur\t9\tamod\t_\t_\n11-12\talla\t_\t_\t_\t_\t_\t_\t_\t_\n11\ta\ta\tADP\tE\t_\t14\tcase\t_\t_\n12\tla\til\tDET\tRD\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t14\tdet\t_\t_\n13\tcui\tcui\tDET\tDR\tPronType=Rel\t14\tdet:poss\t_\t_\n14\tpresidenza\tpresidenza\tNOUN\tS\tGender=Fem|Number=Sing\t16\tobl\t_\t_\n15\tè\tessere\tAUX\tVA\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t16\taux:pass\t_\t_\n16\tcandidato\tcandidare\tVERB\tV\tGender=Masc|Number=Sing|Tense=Past|VerbForm=Part\t9\tacl:relcl\t_\t_\n17\tl'\til\tDET\tRD\tDefinite=Def|Number=Sing|PronType=Art\t18\tdet\t_\tSpaceAfter=No\n18\tesponente\tesponente\tNOUN\tS\tNumber=Sing\t16\tnsubj:pass\t_\t_\n19\tdi\tdi\tADP\tE\t_\t20\tcase\t_\t_\n20\tAlleanza\tAlleanza\tPROPN\tSP\t_\t18\tnmod\t_\t_\n21\tNazionale\tNazionale\tPROPN\tSP\t_\t20\tflat:name\t_\t_\n22\tMirko\tMirko\tPROPN\tSP\t_\t18\tnmod\t_\t_\n23\tTremaglia\tTremaglia\tPROPN\tSP\t_\t22\tflat:name\t_\tSpaceAfter=No\n24\t:\t:\tPUNCT\tFC\t_\t22\tpunct\t_\t_\n25\tuna\tuno\tDET\tRI\tDefinite=Ind|Gender=Fem|Number=Sing|PronType=Art\t26\tdet\t_\t_\n26\tcandidatura\tcandidatura\tNOUN\tS\tGender=Fem|Number=Sing\t22\tappos\t_\t_\n27\tpiù\tpiù\tADV\tB\t_\t28\tadvmod\t_\t_\n28\tsubìta\tsubire\tVERB\tV\tGender=Fem|Number=Sing|Tense=Past|VerbForm=Pa
rt\t26\tadvcl\t_\t_\n29\tche\tche\tCCONJ\tCC\t_\t30\tcc\t_\t_\n30\tgradita\tgradito\tADJ\tA\tGender=Fem|Number=Sing\t28\tamod\t_\t_\n31-32\tdalla\t_\t_\t_\t_\t_\t_\t_\t_\n31\tda\tda\tADP\tE\t_\t33\tcase\t_\t_\n32\tla\til\tDET\tRD\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t33\tdet\t_\t_\n33\tLega\tLega\tPROPN\tSP\t_\t28\tobl:agent\t_\t_\n34\tNord\tNord\tPROPN\tSP\t_\t33\tflat:name\t_\tSpaceAfter=No\n35\t,\t,\tPUNCT\tFC\t_\t33\tpunct\t_\t_\n36\tche\tche\tPRON\tPR\tPronType=Rel\t39\tnsubj\t_\t_\n37\ttuttavia\ttuttavia\tCCONJ\tCC\t_\t39\tcc\t_\t_\n38\tdovrebbe\tdovere\tAUX\tVM\tMood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t39\taux\t_\t_\n39\trispettare\trispettare\tVERB\tV\tVerbForm=Inf\t33\tacl:relcl\t_\t_\n40\tgli\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Plur|PronType=Art\t41\tdet\t_\t_\n41\taccordi\taccordio\tNOUN\tS\tGender=Masc|Number=Plur\t39\tobj\t_\tSpaceAfter=No\n42\t.\t.\tPUNCT\tFS\t_\t5\tpunct\t_\t_\n\n# sent_id = VIT-591\n# text = L'ottimismo di Kantor potrebbe però rivelarsi ancora una volta prematuro.\n1\tL'\til\tDET\tRD\tDefinite=Def|Number=Sing|PronType=Art\t2\tdet\t_\tSpaceAfter=No\n2\tottimismo\tottimismo\tNOUN\tS\tGender=Masc|Number=Sing\t7\tnsubj\t_\t_\n3\tdi\tdi\tADP\tE\t_\t4\tcase\t_\t_\n4\tKantor\tKantor\tPROPN\tSP\t_\t2\tnmod\t_\t_\n5\tpotrebbe\tpotere\tAUX\tVM\tMood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t7\taux\t_\t_\n6\tperò\tperò\tADV\tB\t_\t7\tadvmod\t_\t_\n7-8\trivelarsi\t_\t_\t_\t_\t_\t_\t_\t_\n7\trivelar\trivelare\tVERB\tV\tVerbForm=Inf\t0\troot\t_\t_\n8\tsi\tsi\tPRON\tPC\tClitic=Yes|Person=3|PronType=Prs\t7\texpl\t_\t_\n9\tancora\tancora\tADV\tB\t_\t7\tadvmod\t_\t_\n10\tuna\tuno\tDET\tRI\tDefinite=Ind|Gender=Fem|Number=Sing|PronType=Art\t11\tdet\t_\t_\n11\tvolta\tvolta\tNOUN\tS\tGender=Fem|Number=Sing\t7\tobl\t_\t_\n12\tprematuro\tprematuro\tADJ\tA\tGender=Masc|Number=Sing\t7\txcomp\t_\tSpaceAfter=No\n13\t.\t.\tPUNCT\tFS\t_\t7\tpunct\t_\t_\n\"\"\"\n\n\ndef test_process_mwts():\n    # dei appears multiple 
times\n    # the verb/pron esserci will be ignored\n    expected_mwts = {'Negli': ('In', 'gli'), 'dei': ('di', 'i'), \"dall'\": ('da', \"l'\"), 'dalle': ('da', 'le'), 'ai': ('a', 'i'), 'alla': ('a', 'la'), 'dalla': ('da', 'la')}\n\n    ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)\n\n    mwts = convert_it_vit.get_mwt(ud_train_data)\n    assert expected_mwts == mwts\n\ndef test_raw_tree():\n    con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))\n    expected_ids = [\"#ID=sent_00002\", \"#ID=sent_00318\", \"#ID=sent_00589\"]\n    expected_trees = [\"(ROOT (cp (sp (part negli) (sn (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd dei) (sn (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda dall) (sn (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda dalle) (sn (sa (ag singole)) (n imprese))) (sp (part ai) (sn (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd dei) (sn (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))\",\n                      \"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part alla) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) 
(n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda dalla) (sn (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))\",\n                      \"(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))\"]\n    assert len(con_sentences) == 3\n    for sentence, expected_id, expected_tree in zip(con_sentences, expected_ids, expected_trees):\n        assert sentence[0] == expected_id\n        tree = convert_it_vit.raw_tree(sentence[1])\n        assert str(tree) == expected_tree\n\ndef test_update_mwts():\n    con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))\n    ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)\n    mwt_map = convert_it_vit.get_mwt(ud_train_data)\n    expected_trees=[\"(ROOT (cp (sp (part In) (sn (art gli) (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd di) (sn (art i) (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda da) (sn (art l') (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda da) (sn (art le) (sa (ag singole)) (n imprese))) (sp (part a) (sn (art i) (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd di) (sn (art i) (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) 
(punto .))))))))))))))))))))))))))\",\n                    \"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part a) (art la) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda da) (sn (art la) (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))\",\n                    \"(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (clit si) (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))\"]\n    with tsurgeon.Tsurgeon() as tsurgeon_processor:\n        for con_sentence, ud_sentence, expected_tree in zip(con_sentences, ud_train_data.sentences, expected_trees):\n            con_tree = convert_it_vit.raw_tree(con_sentence[1])\n            updated_tree, _ = convert_it_vit.update_mwts_and_special_cases(con_tree, ud_sentence, mwt_map, tsurgeon_processor)\n            assert str(updated_tree) == expected_tree\n\n\nCON_PERCENT_SAMPLE = \"\"\"\nID#sent_00020 f-[sn-[art-il, n-tesoro], ibar-[vt-mette], compt-[sp-[part-sul, sn-[n-mercato]], sn-[art-il, num-51%, sp-[p-a, sn-[num-2, n-lire]], sp-[p-per, sn-[n-azione]]]], punto-.]\nID#sent_00022 dirsp-[f3-[sn-[art-le, n-novità]], dirs-':', f3-[coord-[sn-[n-voto, spd-[pd-di, sn-[n-lista]]], cong-e, sn-[n-tetto, sp-[part-agli, sn-[n-acquisti]], sv3-[vppt-limitato, comppas-[sp-[part-allo, sn-[num-0/5%]]]]]], punto-.]]\nID#sent_00517 dirsp-[fc-[f-[sn-[art-l, n-aumento, sa-[ag-mensile], spd-[pd-di, 
sn-[nt-aprile]]], ibar-[ause-è, vppc-stato], compc-[sq-[q-dell_, sn-[num-1/3%]], sp-[p-contro, sn-[art-lo, num-0/7/0/8%, spd-[partd-degli, sn-[sa-[ag-ultimi], num-due, sn-[nt-mesi]]]]]]]]]\nID#sent_01117 fc-[f-[sn-[art-La, sa-[ag-crescente], n-ripresa, spd-[partd-dei, sn-[n-beni, spd-[pd-di, sn-[n-consumo]]]]], ibar-[vin-deriva], savv-[avv-esclusivamente], compin-[spda-[partda-dal, sn-[n-miglioramento, f2-[spd-[pd-di, sn-[relob-cui]], f-[ibar-[ausa-hanno, vppin-beneficiato], compin-[sn-[n-beni, coord-[sa-[ag-durevoli, fp-[par-'(', sn-[num-plus4/5%], par-')']], cong-e, sa-[ag-semidurevoli, fp-[par-'(', sn-[num-plus1/5%], par-')']]]]]]]]]]], punt-',', fs-[cosu-mentre, f-[sn-[art-i, n-beni, sa-[neg-non, ag-durevoli], fp-[par-'(', sn-[num-min1%], par-')']], ibar-[vt-accusano], cong-ancora, compt-[sn-[art-un, sa-[ag-evidente], n-ritardo]]]], punto-.]\n\"\"\"\n\nCON_PERCENT_LEAVES = [\n    ['il', 'tesoro', 'mette', 'sul', 'mercato', 'il', '51', '%%', 'a', '2', 'lire', 'per', 'azione', '.'],\n    ['le', 'novità', ':', 'voto', 'di', 'lista', 'e', 'tetto', 'agli', 'acquisti', 'limitato', 'allo', '0,5', '%%', '.'],\n    ['l', 'aumento', 'mensile', 'di', 'aprile', 'è', 'stato', \"dell'\", '1,3', '%%', 'contro', 'lo', '0/7,0/8', '%%', 'degli', 'ultimi', 'due', 'mesi'],\n    # the plus and min look bad, but they get cleaned up when merging with the UD version of the dataset\n    ['La', 'crescente', 'ripresa', 'dei', 'beni', 'di', 'consumo', 'deriva', 'esclusivamente', 'dal', 'miglioramento', 'di', 'cui', 'hanno', 'beneficiato', 'beni', 'durevoli', '(', 'plus4,5', '%%', ')', 'e', 'semidurevoli', '(', 'plus1,5', '%%', ')', ',', 'mentre', 'i', 'beni', 'non', 'durevoli', '(', 'min1', '%%', ')', 'accusano', 'ancora', 'un', 'evidente', 'ritardo', '.']\n]\n\ndef test_read_percent():\n    con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_PERCENT_SAMPLE))\n    assert len(con_sentences) == len(CON_PERCENT_LEAVES)\n    for (_, raw_tree), expected_leaves in 
zip(con_sentences, CON_PERCENT_LEAVES):\n        tree = convert_it_vit.raw_tree(raw_tree)\n        words = tree.leaf_labels()\n        if expected_leaves is None:\n            print(words)\n        else:\n            assert words == expected_leaves\n"
  },
  {
    "path": "stanza/tests/constituency/test_convert_starlang.py",
    "content": "\"\"\"\nTest a couple different classes of trees to check the output of the Starlang conversion\n\"\"\"\n\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.utils.datasets.constituency import convert_starlang\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nTREE=\"( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}))  (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v}))  (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE}))  )\"\n\ndef test_read_tree():\n    \"\"\"\n    Test a basic tree read\n    \"\"\"\n    tree = convert_starlang.read_tree(TREE)\n    assert \"(ROOT (S (NP (NP Bayan) (NP Haag)) (VP (NP Elianti) (VP çalar)) (. .)))\" == str(tree)\n\ndef test_missing_word():\n    \"\"\"\n    Test that an error is thrown if the word is missing\n    \"\"\"\n    tree_text = TREE.replace(\"turkish=\", \"foo=\")\n    with pytest.raises(ValueError):\n        tree = convert_starlang.read_tree(tree_text)\n\ndef test_bad_label():\n    \"\"\"\n    Test that an unexpected label results in an error\n    \"\"\"\n    tree_text = TREE.replace(\"(S\", \"(s\")\n    with pytest.raises(ValueError):\n        tree = convert_starlang.read_tree(tree_text)\n"
  },
  {
    "path": "stanza/tests/constituency/test_ensemble.py",
    "content": "\"\"\"\nAdd a simple test of the Ensemble's inference path\n\nThis just reuses one model several times - that should still check the main loop, at least\n\"\"\"\n\nimport pytest\n\nfrom stanza import Pipeline\nfrom stanza.models.constituency import text_processing\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.ensemble import Ensemble, EnsembleTrainer\nfrom stanza.models.constituency.text_processing import parse_tokenized_sentences\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n\n@pytest.fixture(scope=\"module\")\ndef pipeline():\n    return Pipeline(dir=TEST_MODELS_DIR, lang=\"en\", processors=\"tokenize, pos, constituency\", tokenize_pretokenized=True)\n\n@pytest.fixture(scope=\"module\")\ndef saved_ensemble(tmp_path_factory, pipeline):\n    tmp_path = tmp_path_factory.mktemp(\"ensemble\")\n\n    # test the ensemble by reusing the same parser multiple times\n    con_processor = pipeline.processors[\"constituency\"]\n    model = con_processor._model\n    args = dict(model.args)\n    foundation_cache = pipeline.foundation_cache\n\n    model_path = con_processor._config['model_path']\n    # reuse the same model 3 times just to make sure the code paths are working\n    filenames = [model_path, model_path, model_path]\n\n    ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)\n    save_path = tmp_path / \"ensemble.pt\"\n\n    ensemble.save(save_path)\n    return ensemble, save_path, args, foundation_cache\n\ndef check_basic_predictions(trees):\n    predictions = [x.predictions for x in trees]\n    assert len(predictions) == 2\n    assert all(len(x) == 1 for x in predictions)\n    trees = [x[0].tree for x in predictions]\n    result = [\"{}\".format(tree) for tree in trees]\n    expected = [\"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\",\n                \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT 
another) (NN test)))))\"]\n    assert result == expected\n\ndef test_ensemble_inference(pipeline):\n    # test the ensemble by reusing the same parser multiple times\n    con_processor = pipeline.processors[\"constituency\"]\n    model = con_processor._model\n    args = dict(model.args)\n    foundation_cache = pipeline.foundation_cache\n\n    model_path = con_processor._config['model_path']\n    # reuse the same model 3 times just to make sure the code paths are working\n    filenames = [model_path, model_path, model_path]\n\n    ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)\n    ensemble = ensemble.model\n    sentences = [[\"This\", \"is\", \"a\", \"test\"], [\"This\", \"is\", \"another\", \"test\"]]\n    trees = parse_tokenized_sentences(args, ensemble, [pipeline], sentences)\n    check_basic_predictions(trees)\n\ndef test_ensemble_save(saved_ensemble):\n    \"\"\"\n    Depending on the saved_ensemble fixture should be enough to ensure\n    that the ensemble was correctly saved\n\n    (loading is tested separately)\n    \"\"\"\n\ndef test_ensemble_save_load(pipeline, saved_ensemble):\n    _, save_path, args, foundation_cache = saved_ensemble\n    ensemble = EnsembleTrainer.load(save_path, args, foundation_cache=foundation_cache)\n    sentences = [[\"This\", \"is\", \"a\", \"test\"], [\"This\", \"is\", \"another\", \"test\"]]\n    trees = parse_tokenized_sentences(args, ensemble.model, [pipeline], sentences)\n    check_basic_predictions(trees)\n\ndef test_parse_text(tmp_path, pipeline, saved_ensemble):\n    _, model_path, args, foundation_cache = saved_ensemble\n\n    raw_file = str(tmp_path / \"test_input.txt\")\n    with open(raw_file, \"w\") as fout:\n        fout.write(\"This is a test\\nThis is another test\\n\")\n    output_file = str(tmp_path / \"test_output.txt\")\n\n    args = dict(args)\n    args['tokenized_file'] = raw_file\n    args['predict_file'] = output_file\n\n    
text_processing.load_model_parse_text(args, model_path, [pipeline])\n    trees = tree_reader.read_treebank(output_file)\n    trees = [\"{}\".format(x) for x in trees]\n    expected_trees = [\"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\",\n                      \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))\"]\n    assert trees == expected_trees\n\ndef test_pipeline(saved_ensemble):\n    _, model_path, _, foundation_cache = saved_ensemble\n    nlp = Pipeline(\"en\", processors=\"tokenize,pos,constituency\", constituency_model_path=str(model_path), foundation_cache=foundation_cache, download_method=None)\n    doc = nlp(\"This is a test\")\n    tree = \"{}\".format(doc.sentences[0].constituency)\n    assert tree == \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\"\n"
  },
  {
    "path": "stanza/tests/constituency/test_in_order_compound_oracle.py",
    "content": "import pytest\n\nfrom stanza.models.constituency import in_order_compound_oracle\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme\nfrom stanza.models.constituency.transition_sequence import build_treebank\n\nfrom stanza.tests.constituency.test_transition_sequence import reconstruct_tree\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# A sample tree from PTB with a triple unary transition (at a location other than root)\n# Here we test the incorrect closing of various brackets\nTRIPLE_UNARY_START_TREE = \"\"\"\n( (S\n    (PRN\n      (S\n        (NP-SBJ (-NONE- *) )\n        (VP (VB See) )))\n    (, ,)\n    (NP-SBJ\n      (NP (DT the) (JJ other) (NN rule) )\n      (PP (IN of)\n        (NP (NN thumb) ))\n      (PP (IN about)\n        (NP (NN ballooning) )))))\n\"\"\"\n\nTREES = [TRIPLE_UNARY_START_TREE]\nTREEBANK = \"\\n\".join(TREES)\n\nROOT_LABELS = [\"ROOT\"]\n\n@pytest.fixture(scope=\"module\")\ndef trees():\n    trees = tree_reader.read_trees(TREEBANK)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n    assert len(trees) == len(TREES)\n\n    return trees\n\n@pytest.fixture(scope=\"module\")\ndef gold_sequences(trees):\n    gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)\n    return gold_sequences\n\ndef get_repairs(gold_sequence, wrong_transition, repair_fn):\n    \"\"\"\n    Use the repair function and the wrong transition to iterate over the gold sequence\n\n    Returns a list of possible repairs, one for each position in the sequence\n    Repairs are tuples, (idx, seq)\n    \"\"\"\n    repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None))\n               for idx, gold_transition in enumerate(gold_sequence)]\n    repairs = [x for x in repairs if x[1] is not None]\n    return repairs\n\ndef test_fix_shift_close():\n    
trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n    assert len(trees) == 1\n    tree = trees[0]\n\n    gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)\n\n    # there are three places in this tree where a long bracket (more than 2 subtrees)\n    # could theoretically be closed and then reopened\n    repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_shift_close_error)\n    assert len(repairs) == 3\n\n    expected_trees = [\"(ROOT (S (S (PRN (S (VP (VB See)))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))\",\n                      \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other)) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))\",\n                      \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)))) (PP (IN about) (NP (NN ballooning))))))\"]\n\n    for repair, expected in zip(repairs, expected_trees):\n        repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)\n        assert str(repaired_tree) == expected\n\ndef test_fix_open_close():\n    trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n    assert len(trees) == 1\n    tree = trees[0]\n\n    gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)\n\n    repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_open_close_error)\n    print(\"------------------\")\n    for repair in repairs:\n        print(repair)\n        repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)\n        print(\"{:P}\".format(repaired_tree))\n"
  },
  {
    "path": "stanza/tests/constituency/test_in_order_oracle.py",
    "content": "import itertools\nimport pytest\n\nfrom stanza.models.constituency import parse_transitions\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.base_model import SimpleModel\nfrom stanza.models.constituency.in_order_oracle import *\nfrom stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme\nfrom stanza.models.constituency.transition_sequence import build_treebank\n\nfrom stanza.tests import *\nfrom stanza.tests.constituency.test_transition_sequence import reconstruct_tree\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# A sample tree from PTB with a single unary transition (at a location other than root)\nSINGLE_UNARY_TREE = \"\"\"\n( (S\n    (NP-SBJ-1 (DT A) (NN record) (NN date) )\n    (VP (VBZ has) (RB n't)\n      (VP (VBN been)\n        (VP (VBN set)\n          (NP (-NONE- *-1) ))))\n    (. .) ))\n\"\"\"\n\n#  [Shift, OpenConstituent(('NP-SBJ-1',)), Shift, Shift, CloseConstituent, OpenConstituent(('S',)), Shift, OpenConstituent(('VP',)), Shift, Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), CloseConstituent, CloseConstituent, CloseConstituent, CloseConstituent, Shift, CloseConstituent, OpenConstituent(('ROOT',)), CloseConstituent]\n\n# A sample tree from PTB with a double unary transition (at a location other than root)\nDOUBLE_UNARY_TREE = \"\"\"\n( (S\n    (NP-SBJ\n      (NP (RB Not) (PDT all) (DT those) )\n      (SBAR\n        (WHNP-3 (WP who) )\n        (S\n          (NP-SBJ (-NONE- *T*-3) )\n          (VP (VBD wrote) ))))\n    (VP (VBP oppose)\n      (NP (DT the) (NNS changes) ))\n    (. .) 
))\n\"\"\"\n\n# A sample tree from PTB with a triple unary transition (at a location other than root)\n# The triple unary is at the START of the next bracket, which affects how the\n# dynamic oracle repairs the transition sequence\nTRIPLE_UNARY_START_TREE = \"\"\"\n( (S\n    (PRN\n      (S\n        (NP-SBJ (-NONE- *) )\n        (VP (VB See) )))\n    (, ,)\n    (NP-SBJ\n      (NP (DT the) (JJ other) (NN rule) )\n      (PP (IN of)\n        (NP (NN thumb) ))\n      (PP (IN about)\n        (NP (NN ballooning) )))))\n\"\"\"\n\n# A sample tree from PTB with a triple unary transition (at a location other than root)\n# The triple unary is at the END of the next bracket, which affects how the\n# dynamic oracle repairs the transition sequence\nTRIPLE_UNARY_END_TREE = \"\"\"\n( (S\n    (NP (NNS optimists) )\n    (VP (VBP expect) \n      (S \n        (NP-SBJ-4 (NNP Hong) (NNP Kong) )\n        (VP (TO to) \n          (VP (VB hum) \n            (ADVP-CLR (RB along) )\n            (SBAR-MNR (RB as) \n              (S \n                (NP-SBJ (-NONE- *-4) )\n                (VP (-NONE- *?*) \n                  (ADVP-TMP (IN before) ))))))))))\n\"\"\"\n\nTREES = [SINGLE_UNARY_TREE, DOUBLE_UNARY_TREE, TRIPLE_UNARY_START_TREE, TRIPLE_UNARY_END_TREE]\nTREEBANK = \"\\n\".join(TREES)\n\nNOUN_PHRASE_TREE = \"\"\"\n( (NP\n    (NP (NNP Chicago) (POS 's))\n    (NNP Goodman)\n    (NNP Theatre)))\n\"\"\"\n\nWIDE_NP_TREE = \"\"\"\n( (S\n    (NP-SBJ (DT These) (NNS studies))\n    (VP (VBP demonstrate)\n      (SBAR (IN that)\n        (S\n          (NP-SBJ (NNS mice))\n          (VP (VBP are)\n            (NP-PRD\n              (NP (DT a)\n                (ADJP (JJ practical)\n                  (CC and)\n                  (JJ powerful))\n                (JJ experimental) (NN system))\n              (SBAR\n                (WHADVP-2 (-NONE- *0*))\n                (S\n                  (NP-SBJ (-NONE- *PRO*))\n                  (VP (TO to)\n                    (VP (VB study)\n                      
(NP (DT the) (NN genetics)))))))))))))\n\"\"\"\n\nWIDE_TREES = [NOUN_PHRASE_TREE, WIDE_NP_TREE]\nWIDE_TREEBANK = \"\\n\".join(WIDE_TREES)\n\nROOT_LABELS = [\"ROOT\"]\n\ndef get_repairs(gold_sequence, wrong_transition, repair_fn):\n    \"\"\"\n    Use the repair function and the wrong transition to iterate over the gold sequence\n\n    Returns a list of possible repairs, one for each position in the sequence\n    Repairs are tuples, (idx, seq)\n    \"\"\"\n    repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None))\n               for idx, gold_transition in enumerate(gold_sequence)]\n    repairs = [x for x in repairs if x[1] is not None]\n    return repairs\n\n@pytest.fixture(scope=\"module\")\ndef unary_trees():\n    trees = tree_reader.read_trees(TREEBANK)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n    assert len(trees) == len(TREES)\n\n    return trees\n\n@pytest.fixture(scope=\"module\")\ndef gold_sequences(unary_trees):\n    gold_sequences = build_treebank(unary_trees, TransitionScheme.IN_ORDER)\n    return gold_sequences\n\n@pytest.fixture(scope=\"module\")\ndef wide_trees():\n    trees = tree_reader.read_trees(WIDE_TREEBANK)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n    assert len(trees) == len(WIDE_TREES)\n\n    return trees\n\ndef test_wrong_open_root(gold_sequences):\n    \"\"\"\n    Test the results of the dynamic oracle on a few trees if the ROOT is mishandled.\n    \"\"\"\n    wrong_transition = OpenConstituent(\"S\")\n    gold_transition = OpenConstituent(\"ROOT\")\n    close_transition = CloseConstituent()\n\n    for gold_sequence in gold_sequences:\n        # each of the sequences should be ended with ROOT, Close\n        assert gold_sequence[-2] == gold_transition\n\n        repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_root_error)\n        # there is only one spot in the sequence with a ROOT, so there should\n        # be exactly one 
location which affords an S/ROOT replacement\n        assert len(repairs) == 1\n        repair = repairs[0]\n\n        # the repair should occur at the -2 position, which is where ROOT is\n        assert repair[0] == len(gold_sequence) - 2\n        # and the resulting list should have the wrong transition followed by a Close\n        # to give the model another chance to close the tree\n        expected = gold_sequence[:-2] + [wrong_transition, close_transition] + gold_sequence[-2:]\n        assert repair[1] == expected\n\ndef test_wrong_open_unary_chain(gold_sequences):\n    \"\"\"\n    Test the repairs of an open/open error if it is effectively a skipped unary transition\n    \"\"\"\n    wrong_transition = OpenConstituent(\"S\")\n\n    repairs = get_repairs(gold_sequences[0], wrong_transition, fix_wrong_open_unary_chain)\n    assert len(repairs) == 0\n\n    # here we are simulating picking NT-S instead of NT-VP\n    # the DOUBLE_UNARY tree has one location where this is relevant, index 11\n    repairs = get_repairs(gold_sequences[1], wrong_transition, fix_wrong_open_unary_chain)\n    assert len(repairs) == 1\n    assert repairs[0][0] == 11\n    assert repairs[0][1] == gold_sequences[1][:11] + gold_sequences[1][13:]\n\n    # the TRIPLE_UNARY_START tree has two locations where this is relevant\n    # at index 1, the pattern goes (S (VP ...))\n    # so choosing S instead of VP means you can skip the VP and only miss that one bracket\n    # at index 5, the pattern goes (S (PRN (S (VP ...))) (...))\n    # note that this is capturing a unary transition into a larger constituent\n    # skipping the PRN is satisfactory\n    repairs = get_repairs(gold_sequences[2], wrong_transition, fix_wrong_open_unary_chain)\n    assert len(repairs) == 2\n    assert repairs[0][0] == 1\n    assert repairs[0][1] == gold_sequences[2][:1] + gold_sequences[2][3:]\n    assert repairs[1][0] == 5\n    assert repairs[1][1] == gold_sequences[2][:5] + gold_sequences[2][7:]\n\n    # The TRIPLE_UNARY_END tree 
has 2 sections of tree for a total of 3 locations\n    # where the repair might happen\n    # Surprisingly the unary transition at the very start can only be\n    # repaired by skipping it and using the outer S transition instead\n    # The second repair overall (first repair in the second location)\n    # should have a double skip to reach the S node\n    repairs = get_repairs(gold_sequences[3], wrong_transition, fix_wrong_open_unary_chain)\n    assert len(repairs) == 3\n    assert repairs[0][0] == 1\n    assert repairs[0][1] == gold_sequences[3][:1] + gold_sequences[3][3:]\n    assert repairs[1][0] == 21\n    assert repairs[1][1] == gold_sequences[3][:21] + gold_sequences[3][25:]\n    assert repairs[2][0] == 23\n    assert repairs[2][1] == gold_sequences[3][:23] + gold_sequences[3][25:]\n\n\ndef test_open_with_stuff(unary_trees, gold_sequences):\n    wrong_transition = OpenConstituent(\"S\")\n    expected_trees = [\n        \"(ROOT (S (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))\",\n        \"(ROOT (S (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. 
.)))\",\n        None,\n        \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))\"\n    ]\n\n    for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees):\n        repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_stuff_unary)\n        if expected is None:\n            assert len(repairs) == 0\n        else:\n            assert len(repairs) == 1\n            result = reconstruct_tree(tree, repairs[0][1])\n            assert str(result) == expected\n\ndef test_general_open(gold_sequences):\n    wrong_transition = OpenConstituent(\"SBARQ\")\n\n    for sequence in gold_sequences:\n        repairs = get_repairs(sequence, wrong_transition, fix_wrong_open_general)\n        assert len(repairs) == sum(isinstance(x, OpenConstituent) for x in sequence) - 1\n        for repair in repairs:\n            assert len(repair[1]) == len(sequence)\n            assert repair[1][repair[0]] == wrong_transition\n            assert repair[1][:repair[0]] == sequence[:repair[0]]\n            assert repair[1][repair[0]+1:] == sequence[repair[0]+1:]\n\ndef test_missed_unary(unary_trees, gold_sequences):\n    shift_transition = Shift()\n    close_transition = CloseConstituent()\n\n    expected_close_results = [\n        [(12, 2)],\n        [(11, 4), (13, 2)],\n        # (NP NN thumb) and (NP NN ballooning) are both candidates for this repair\n        [(18, 2), (24, 2)],\n        [(21, 6), (23, 4), (25, 2)],\n    ]\n\n    expected_shift_results = [\n        (),\n        (),\n        (),\n        # (ADVP-CLR (RB along)) is followed by a shift\n        [(16, 2)],\n    ]\n\n    for tree, sequence, expected_close, expected_shift in zip(unary_trees, gold_sequences, expected_close_results, expected_shift_results):\n        repairs = get_repairs(sequence, close_transition, fix_missed_unary)\n        assert len(repairs) == 
len(expected_close)\n        for repair, (expected_idx, expected_len) in zip(repairs, expected_close):\n            assert repair[0] == expected_idx\n            assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:]\n\n        repairs = get_repairs(sequence, shift_transition, fix_missed_unary)\n        assert len(repairs) == len(expected_shift)\n        for repair, (expected_idx, expected_len) in zip(repairs, expected_shift):\n            assert repair[0] == expected_idx\n            assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:]\n\ndef test_open_shift(unary_trees, gold_sequences):\n    shift_transition = Shift()\n\n    expected_repairs = [\n        [(7,  \"(ROOT (S (NP (DT A) (NN record) (NN date)) (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))\"),\n         (10, \"(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VBN been) (VP (VBN set))) (. .)))\")],\n        [(7,  \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WP who) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))\"),\n         (9,  \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))\"),\n         (19, \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose) (NP (DT the) (NNS changes)) (. .)))\"),\n         (21, \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (DT the) (NNS changes)) (. 
.)))\")],\n        [(14, \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))\"),\n         (16, \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))\"),\n         (22, \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about) (NP (NN ballooning)))))\")],\n        [(5,  \"(ROOT (S (NP (NNS optimists)) (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (10, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (12, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (14, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (19, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (RB as) (S (VP (ADVP (IN before))))))))))\")]\n    ]\n\n    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):\n        repairs = get_repairs(sequence, shift_transition, fix_open_shift)\n        assert len(repairs) == len(expected)\n        for repair, (idx, expected_tree) in zip(repairs, expected):\n            assert repair[0] == idx\n            result_tree = reconstruct_tree(tree, repair[1])\n            assert str(result_tree) == expected_tree\n\n\ndef test_open_close(unary_trees, gold_sequences):\n    close_transition = CloseConstituent()\n\n    expected_repairs = [\n        
[(7,  \"(ROOT (S (S (NP (DT A) (NN record) (NN date)) (VBZ has)) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))\"),\n         (10, \"(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VP (VBZ has) (RB n't) (VBN been)) (VP (VBN set))) (. .)))\")],\n        # missed the WHNP.  The surrounding SBAR cannot be created, either\n        [(7, \"(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))\"),\n         # missed the SBAR\n         (9, \"(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who))) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))\"),\n         # missed the VP around \"oppose the changes\"\n         (19, \"(ROOT (S (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose)) (NP (DT the) (NNS changes)) (. .)))\"),\n         # missed the NP in \"the changes\", looks pretty bad tbh\n         (21, \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VP (VBP oppose) (DT the)) (NNS changes)) (. 
.)))\")],\n        [(14, \"(ROOT (S (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule))) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))\"),\n         (16, \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (IN of)) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))\"),\n         (22, \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about)) (NP (NN ballooning)))))\")],\n        [(5, \"(ROOT (S (S (NP (NNS optimists)) (VBP expect)) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (10, \"(ROOT (S (NP (NNS optimists)) (VP (VP (VBP expect) (NP (NNP Hong) (NNP Kong))) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (12, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (S (NP (NNP Hong) (NNP Kong)) (TO to)) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (14, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (VP (TO to) (VB hum)) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))\"),\n         (19, \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (ADVP (RB along)) (RB as)) (S (VP (ADVP (IN before))))))))))\")]\n    ]\n\n    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):\n        repairs = get_repairs(sequence, close_transition, fix_open_close)\n\n        assert len(repairs) == len(expected)\n        for repair, (idx, expected_tree) in zip(repairs, expected):\n            assert repair[0] == idx\n            result_tree = reconstruct_tree(tree, repair[1])\n            assert str(result_tree) == expected_tree\n\ndef test_shift_close(unary_trees, gold_sequences):\n    \"\"\"\n    Test the fix for a shift 
-> close\n\n    These errors can occur pretty much everywhere, and the fix is quite simple,\n    so we only test a few cases.\n    \"\"\"\n\n    close_transition = CloseConstituent()\n\n    expected_tree = \"(ROOT (S (NP (NP (DT A)) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))\"\n\n    repairs = get_repairs(gold_sequences[0], close_transition, fix_shift_close)\n    assert len(repairs) == 7\n    result_tree = reconstruct_tree(unary_trees[0], repairs[0][1])\n    assert str(result_tree) == expected_tree\n\n    repairs = get_repairs(gold_sequences[1], close_transition, fix_shift_close)\n    assert len(repairs) == 8\n\n    repairs = get_repairs(gold_sequences[2], close_transition, fix_shift_close)\n    assert len(repairs) == 8\n\n    repairs = get_repairs(gold_sequences[3], close_transition, fix_shift_close)\n    assert len(repairs) == 9\n    for rep in repairs:\n        if rep[0] == 16:\n            # This one is special because it occurs as part of a unary\n            # in other words, it should go unary, shift\n            # and instead we are making it close where the unary should be\n            # ... the unary would create \"(ADVP (RB along))\"\n            expected_tree = \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))\"\n            result_tree = reconstruct_tree(unary_trees[3], rep[1])\n            assert str(result_tree) == expected_tree\n            break\n    else:\n        raise AssertionError(\"Did not find an expected repair location\")\n\ndef test_close_open_shift_nested(unary_trees, gold_sequences):\n    shift_transition = Shift()\n\n    expected_trees = [{},\n                      {4: \"(ROOT (S (NP (RB Not) (PDT all) (DT those) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. 
.)))\"},\n                      {4: \"(ROOT (S (VP (VB See)) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))\",\n                       13: \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))\"},\n                      {}]\n\n    for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees):\n        repairs = get_repairs(gold_sequence, shift_transition, fix_close_open_shift_nested)\n        assert len(repairs) == len(expected)\n        if len(expected) >= 1:\n            for repair in repairs:\n                assert repair[0] in expected.keys()\n                result_tree = reconstruct_tree(tree, repair[1])\n                assert str(result_tree) == expected[repair[0]]\n\ndef check_repairs(trees, gold_sequences, expected_trees, transition, repair_fn):\n    for tree_idx, (gold_tree, gold_sequence, expected) in enumerate(zip(trees, gold_sequences, expected_trees)):\n        repairs = get_repairs(gold_sequence, transition, repair_fn)\n        if expected is not None:\n            assert len(repairs) == len(expected)\n            for repair in repairs:\n                assert repair[0] in expected\n                result_tree = reconstruct_tree(gold_tree, repair[1])\n                assert str(result_tree) == expected[repair[0]]\n        else:\n            print(\"---------------------\")\n            print(\"{:P}\".format(gold_tree))\n            print(gold_sequence)\n            #print(repairs)\n            for repair in repairs:\n                print(\"---------------------\")\n                print(gold_sequence)\n                print(repair[1])\n                result_tree = reconstruct_tree(gold_tree, repair[1])\n                print(\"{:P}\".format(gold_tree))\n                print(\"{:P}\".format(result_tree))\n                print(tree_idx)\n                
print(repair[0])\n                print(result_tree)\n\ndef test_close_open_shift_unambiguous(unary_trees, gold_sequences):\n    shift_transition = Shift()\n\n    expected_trees = [{},\n                      {8: \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who) (S (VP (VBD wrote)))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))\"},\n                      {},\n                      {2: \"(ROOT (S (NP (NNS optimists) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))\",\n                       9: \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))\"}]\n    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_unambiguous_bracket)\n\ndef test_close_open_shift_ambiguous_early(unary_trees, gold_sequences):\n    shift_transition = Shift()\n\n    expected_trees = [{4: \"(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))))) (. .)))\"},\n                      {16: \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes)))) (. 
.)))\"},\n                      {2: \"(ROOT (S (PRN (S (VP (VB See) (, ,)))) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))\",\n                       6: \"(ROOT (S (PRN (S (VP (VB See))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))\"},\n                      {}]\n    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_early)\n\ndef test_close_open_shift_ambiguous_late(unary_trees, gold_sequences):\n    shift_transition = Shift()\n\n    expected_trees = [{4: \"(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))))\"},\n                      {16: \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .))))\"},\n                      {2: \"(ROOT (S (PRN (S (VP (VB See) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))))\",\n                       6: \"(ROOT (S (PRN (S (VP (VB See))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))\"},\n                      {}]\n    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_late)\n\n\ndef test_close_shift_shift(unary_trees, wide_trees):\n    \"\"\"\n    Test that close -> shift works when there is a single block shifted after\n\n    Includes a test specifically that there is no oracle action when there are two blocks after the missed close\n    \"\"\"\n    shift_transition = Shift()\n\n    expected_trees = [{15: \"(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. 
.))))\"},\n                      {24: \"(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes)) (. .))))\"},\n                      {20: \"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning)))))))\"},\n                      {17: \"(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))\"},\n                      {},\n                      {}]\n\n    test_trees = unary_trees + wide_trees\n    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)\n\n    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_unambiguous)\n\n\ndef test_close_shift_shift_early(unary_trees, wide_trees):\n    \"\"\"\n    Test that close -> shift works when there are multiple blocks shifted after\n\n    Also checks that the single block case is skipped, so as to keep them separate when testing\n\n    A tree with the expected property was specifically added for this test\n    \"\"\"\n    shift_transition = Shift()\n\n    test_trees = unary_trees + wide_trees\n    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)\n\n    expected_trees = [{},\n                      {},\n                      {},\n                      {},\n                      {},\n                      {21: \"(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental)) (NN system)) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))\"}]\n\n    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_early)\n\ndef test_close_shift_shift_late(unary_trees, wide_trees):\n  
  \"\"\"\n    Test that close -> shift works when there are multiple blocks shifted after\n\n    Also checks that the single block case is skipped, so as to keep them separate when testing\n\n    A tree with the expected property was specifically added for this test\n    \"\"\"\n    shift_transition = Shift()\n\n    test_trees = unary_trees + wide_trees\n    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)\n\n    expected_trees = [{},\n                      {},\n                      {},\n                      {},\n                      {},\n                      {21: \"(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental) (NN system))) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))\"}]\n\n    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_late)\n"
  },
  {
    "path": "stanza/tests/constituency/test_lstm_model.py",
    "content": "import os\n\nimport pytest\nimport torch\n\nfrom stanza.models.common import pretrain\nfrom stanza.models.common.utils import set_random_seed\nfrom stanza.models.constituency import parse_transitions\nfrom stanza.tests import *\nfrom stanza.tests.constituency import test_parse_transitions\nfrom stanza.tests.constituency.test_trainer import build_trainer\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n@pytest.fixture(scope=\"module\")\ndef pretrain_file():\n    return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\ndef build_model(pretrain_file, *args):\n    # By default, we turn off multistage, since that can turn off various other structures in the initial training\n    args = ['--no_multistage', '--pattn_num_layers', '4', '--pattn_d_model', '256', '--hidden_size', '128', '--use_lattn'] + list(args)\n    trainer = build_trainer(pretrain_file, *args)\n    return trainer.model\n\n@pytest.fixture(scope=\"module\")\ndef unary_model(pretrain_file):\n    return build_model(pretrain_file, \"--transition_scheme\", \"TOP_DOWN_UNARY\")\n\ndef test_initial_state(unary_model):\n    test_parse_transitions.test_initial_state(unary_model)\n\ndef test_shift(pretrain_file):\n    # TODO: might be good to include some tests specifically for shift\n    # in the context of a model with unaries\n    model = build_model(pretrain_file)\n    test_parse_transitions.test_shift(model)\n\ndef test_unary(unary_model):\n    test_parse_transitions.test_unary(unary_model)\n\ndef test_unary_requires_root(unary_model):\n    test_parse_transitions.test_unary_requires_root(unary_model)\n\ndef test_open(unary_model):\n    test_parse_transitions.test_open(unary_model)\n\ndef test_compound_open(pretrain_file):\n    model = build_model(pretrain_file, '--transition_scheme', \"TOP_DOWN_COMPOUND\")\n    test_parse_transitions.test_compound_open(model)\n\ndef test_in_order_open(pretrain_file):\n    model = build_model(pretrain_file, '--transition_scheme', \"IN_ORDER\")\n    
test_parse_transitions.test_in_order_open(model)\n\ndef test_close(unary_model):\n    test_parse_transitions.test_close(unary_model)\n\ndef run_forward_checks(model, num_states=1):\n    \"\"\"\n    Run a couple small transitions and a forward pass on the given model\n\n    Results are not checked in any way.  This function allows for\n    testing that building models with various options results in a\n    functional model.\n    \"\"\"\n    states = test_parse_transitions.build_initial_state(model, num_states)\n    model(states)\n\n    shift = parse_transitions.Shift()\n    shifts = [shift for _ in range(num_states)]\n    states = model.bulk_apply(states, shifts)\n    model(states)\n\n    open_transition = parse_transitions.OpenConstituent(\"NP\")\n    open_transitions = [open_transition for _ in range(num_states)]\n    assert open_transition.is_legal(states[0], model)\n    states = model.bulk_apply(states, open_transitions)\n    assert states[0].num_opens == 1\n    model(states)\n\n    states = model.bulk_apply(states, shifts)\n    model(states)\n    states = model.bulk_apply(states, shifts)\n    model(states)\n    assert states[0].num_opens == 1\n    # now should have \"mox\", \"opal\" on the constituents\n\n    close_transition = parse_transitions.CloseConstituent()\n    close_transitions = [close_transition for _ in range(num_states)]\n    assert close_transition.is_legal(states[0], model)\n    states = model.bulk_apply(states, close_transitions)\n    assert states[0].num_opens == 0\n\n    model(states)\n\ndef test_unary_forward(unary_model):\n    \"\"\"\n    Checks that the forward pass doesn't crash when run after various operations\n\n    Doesn't check the forward pass for making reasonable answers\n    \"\"\"\n    run_forward_checks(unary_model)\n\ndef test_lstm_forward(pretrain_file):\n    model = build_model(pretrain_file)\n    run_forward_checks(model, num_states=1)\n    run_forward_checks(model, num_states=2)\n\ndef test_lstm_layers(pretrain_file):\n    
model = build_model(pretrain_file, '--num_lstm_layers', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--num_lstm_layers', '2')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--num_lstm_layers', '3')\n    run_forward_checks(model)\n\ndef test_multiple_output_forward(pretrain_file):\n    \"\"\"\n    Test a couple different sizes of output layers\n    \"\"\"\n    model = build_model(pretrain_file, '--num_output_layers', '1', '--num_lstm_layers', '2')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--num_output_layers', '2', '--num_lstm_layers', '2')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--num_output_layers', '3', '--num_lstm_layers', '2')\n    run_forward_checks(model)\n\ndef test_no_tag_embedding_forward(pretrain_file):\n    \"\"\"\n    Test that the model continues to work if the tag embedding is turned on or off\n    \"\"\"\n    model = build_model(pretrain_file, '--tag_embedding_dim', '20')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--tag_embedding_dim', '0')\n    run_forward_checks(model)\n\ndef test_forward_combined_dummy(pretrain_file):\n    \"\"\"\n    Tests combined dummy and open node embeddings\n    \"\"\"\n    model = build_model(pretrain_file, '--combined_dummy_embedding')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--no_combined_dummy_embedding')\n    run_forward_checks(model)\n\ndef test_nonlinearity_init(pretrain_file):\n    \"\"\"\n    Tests that different initialization methods of the nonlinearities result in valid tensors\n    \"\"\"\n    model = build_model(pretrain_file, '--nonlinearity', 'relu')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--nonlinearity', 'tanh')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--nonlinearity', 'silu')\n    run_forward_checks(model)\n\ndef test_forward_charlm(pretrain_file):\n    
\"\"\"\n    Tests loading and running a charlm\n\n    Note that this doesn't test the results of the charlm itself,\n    just that the model is shaped correctly\n    \"\"\"\n    forward_charlm_path = os.path.join(TEST_MODELS_DIR, \"en\", \"forward_charlm\", \"1billion.pt\")\n    backward_charlm_path = os.path.join(TEST_MODELS_DIR, \"en\", \"backward_charlm\", \"1billion.pt\")\n    assert os.path.exists(forward_charlm_path), \"Need to download en test models (or update path to the forward charlm)\"\n    assert os.path.exists(backward_charlm_path), \"Need to download en test models (or update path to the backward charlm)\"\n\n    model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'none')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'words')\n    run_forward_checks(model)\n\ndef test_forward_bert(pretrain_file):\n    \"\"\"\n    Test on a tiny Bert, which hopefully does not take up too much disk space or memory\n    \"\"\"\n    bert_model = \"hf-internal-testing/tiny-bert\"\n\n    model = build_model(pretrain_file, '--bert_model', bert_model)\n    run_forward_checks(model)\n\n\ndef test_forward_xlnet(pretrain_file):\n    \"\"\"\n    Test on a tiny xlnet, which hopefully does not take up too much disk space or memory\n    \"\"\"\n    bert_model = \"hf-internal-testing/tiny-random-xlnet\"\n\n    model = build_model(pretrain_file, '--bert_model', bert_model)\n    run_forward_checks(model)\n\n\ndef test_forward_sentence_boundaries(pretrain_file):\n    \"\"\"\n    Test start & stop boundary vectors\n    \"\"\"\n    model = build_model(pretrain_file, '--sentence_boundary_vectors', 'everything')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--sentence_boundary_vectors', 'words')\n    
run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--sentence_boundary_vectors', 'none')\n    run_forward_checks(model)\n\ndef test_forward_constituency_composition(pretrain_file):\n    \"\"\"\n    Test different constituency composition functions\n    \"\"\"\n    model = build_model(pretrain_file, '--constituency_composition', 'bilstm')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'max')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'key')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'untied_key')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'untied_max')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'bilstm_max')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm_cx')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'bigram')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'attn')\n    run_forward_checks(model, num_states=2)\n\ndef test_forward_key_position(pretrain_file):\n    \"\"\"\n    Test KEY and UNTIED_KEY either with or without reduce_position\n    \"\"\"\n    model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '0')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '32')\n    run_forward_checks(model, 
num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '0')\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '32')\n    run_forward_checks(model, num_states=2)\n\n\ndef test_forward_attn_hidden_size(pretrain_file):\n    \"\"\"\n    Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size\n    \"\"\"\n    model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129')\n    assert model.hidden_size >= 129\n    assert model.hidden_size % model.reduce_heads == 0\n    run_forward_checks(model, num_states=2)\n\n    model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129', '--reduce_heads', '10')\n    assert model.hidden_size == 130\n    assert model.reduce_heads == 10\n\ndef test_forward_partitioned_attention(pretrain_file):\n    \"\"\"\n    Test with & without partitioned attention layers\n    \"\"\"\n    model = build_model(pretrain_file, '--pattn_num_heads', '8', '--pattn_num_layers', '8')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--pattn_num_heads', '0', '--pattn_num_layers', '0')\n    run_forward_checks(model)\n\ndef test_forward_labeled_attention(pretrain_file):\n    \"\"\"\n    Test with & without labeled attention layers\n    \"\"\"\n    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--lattn_d_proj', '0', '--lattn_d_l', '0')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input')\n    run_forward_checks(model)\n\ndef test_lattn_partitioned(pretrain_file):\n    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', 
'--lattn_partitioned')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned')\n    run_forward_checks(model)\n\n\ndef test_lattn_projection(pretrain_file):\n    \"\"\"\n    Test labeled attention with different input projection sizes\n    \"\"\"\n    with pytest.raises(ValueError):\n        # this is too small\n        model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256', '--lattn_partitioned')\n        run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768')\n    run_forward_checks(model)\n\n    # check that it works if we turn off the projection,\n    # in case having it on becomes the default\n    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '0')\n    run_forward_checks(model)\n\ndef test_forward_timing_choices(pretrain_file):\n    \"\"\"\n    Test different timing / position encodings\n    \"\"\"\n    model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'sin')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'learned')\n    run_forward_checks(model)\n\ndef test_transition_stack(pretrain_file):\n    \"\"\"\n    Test different transition stack types: lstm & attention\n    \"\"\"\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_stack', 'attn', '--transition_heads', '1')\n    run_forward_checks(model)\n\n    model = 
build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_stack', 'attn', '--transition_heads', '4')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_stack', 'lstm')\n    run_forward_checks(model)\n\ndef test_constituent_stack(pretrain_file):\n    \"\"\"\n    Test different constituent stack types: lstm & attention\n    \"\"\"\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--constituent_stack', 'attn', '--constituent_heads', '1')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--constituent_stack', 'attn', '--constituent_heads', '4')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--constituent_stack', 'lstm')\n    run_forward_checks(model)\n\ndef test_different_transition_sizes(pretrain_file):\n    \"\"\"\n    If the transition hidden size and embedding size are different, the model should still work\n    \"\"\"\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_embedding_dim', '10', '--transition_hidden_size', '10',\n                        '--sentence_boundary_vectors', 'everything')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_embedding_dim', '20', '--transition_hidden_size', '10',\n                        '--sentence_boundary_vectors', 'everything')\n    
run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_embedding_dim', '10', '--transition_hidden_size', '20',\n                        '--sentence_boundary_vectors', 'everything')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_embedding_dim', '10', '--transition_hidden_size', '10',\n                        '--sentence_boundary_vectors', 'none')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_embedding_dim', '20', '--transition_hidden_size', '10',\n                        '--sentence_boundary_vectors', 'none')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file,\n                        '--pattn_num_layers', '0', '--lattn_d_proj', '0',\n                        '--transition_embedding_dim', '10', '--transition_hidden_size', '20',\n                        '--sentence_boundary_vectors', 'none')\n    run_forward_checks(model)\n\ndef test_relative_attention(pretrain_file):\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat')\n    run_forward_checks(model)\n\ndef test_relative_attention_cat(pretrain_file):\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat')\n    run_forward_checks(model)\n    cat_size = model.word_input_size\n\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat')\n    run_forward_checks(model)\n    no_cat_size = model.word_input_size\n    assert cat_size > no_cat_size\n\ndef test_relative_attention_directional(pretrain_file):\n    model = 
build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_forward', '--no_rattn_cat')\n    run_forward_checks(model)\n\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_reverse', '--no_rattn_cat')\n    run_forward_checks(model)\n\ndef test_relative_attention_sinks(pretrain_file):\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_window', '2', '--rattn_sinks', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_sinks', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_window', '2', '--rattn_sinks', '2')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--no_rattn_cat', '--rattn_sinks', '2')\n    run_forward_checks(model)\n\ndef test_relative_attention_cat_sinks(pretrain_file):\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_window', '2', '--rattn_sinks', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_sinks', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_window', '2', '--rattn_sinks', '2')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_sinks', '2')\n    run_forward_checks(model)\n\ndef test_relative_attention_endpoint_sinks(pretrain_file):\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', 
'--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '1')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '2')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '2')\n    run_forward_checks(model)\n\ndef test_lstm_tree_forward(pretrain_file):\n    \"\"\"\n    Test the LSTM_TREE forward pass\n    \"\"\"\n    model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm')\n    run_forward_checks(model)\n\ndef test_lstm_tree_cx_forward(pretrain_file):\n    \"\"\"\n    Test the LSTM_TREE_CX forward pass\n    \"\"\"\n    model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm_cx')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm_cx')\n    run_forward_checks(model)\n    model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm_cx')\n    run_forward_checks(model)\n\ndef test_maxout(pretrain_file):\n    \"\"\"\n    Test with and without maxout layers for output\n    \"\"\"\n    model = build_model(pretrain_file, '--maxout_k', '0')\n    run_forward_checks(model)\n    # 
check the output size & implicitly check the type\n    # to check for a particularly silly bug\n    assert model.output_layers[-1].weight.shape[0] == len(model.transitions)\n\n    model = build_model(pretrain_file, '--maxout_k', '2')\n    run_forward_checks(model)\n    assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 2\n\n    model = build_model(pretrain_file, '--maxout_k', '3')\n    run_forward_checks(model)\n    assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 3\n\ndef check_structure_test(pretrain_file, args1, args2):\n    \"\"\"\n    Test that the \"copy\" method copies the parameters from one model to another\n\n    Also check that the copied models produce the same results\n    \"\"\"\n    set_random_seed(1000)\n    other = build_model(pretrain_file, *args1)\n    other.eval()\n\n    set_random_seed(1001)\n    model = build_model(pretrain_file, *args2)\n    model.eval()\n\n    assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)\n    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)\n\n    model.copy_with_new_structure(other)\n\n    assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)\n    assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)\n    # the norms will be the same, as the non-zero values are all the same\n    assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))\n\n    # now, check that applying one transition to an initial state\n    # results in the same values in the output states for both models\n    # as the pattn layer inputs are 0, the output values should be equal\n    shift = [parse_transitions.Shift()]\n    model_states = test_parse_transitions.build_initial_state(model, 1)\n    model_states = model.bulk_apply(model_states, shift)\n\n    other_states = 
test_parse_transitions.build_initial_state(other, 1)\n    other_states = other.bulk_apply(other_states, shift)\n\n    for i, j in zip(other_states[0].word_queue, model_states[0].word_queue):\n        assert torch.allclose(i.hx, j.hx, atol=1e-07)\n    for i, j in zip(other_states[0].transitions, model_states[0].transitions):\n        assert torch.allclose(i.lstm_hx, j.lstm_hx)\n        assert torch.allclose(i.lstm_cx, j.lstm_cx)\n    for i, j in zip(other_states[0].constituents, model_states[0].constituents):\n        assert (i.value is None) == (j.value is None)\n        if i.value is not None:\n            assert torch.allclose(i.value.tree_hx, j.value.tree_hx, atol=1e-07)\n        assert torch.allclose(i.lstm_hx, j.lstm_hx)\n        assert torch.allclose(i.lstm_cx, j.lstm_cx)\n\ndef test_copy_with_new_structure_same(pretrain_file):\n    \"\"\"\n    Test that copying the structure with no changes works as expected\n    \"\"\"\n    check_structure_test(pretrain_file,\n                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],\n                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'])\n\ndef test_copy_with_new_structure_untied(pretrain_file):\n    \"\"\"\n    Test that copying from a MAX composition model to an UNTIED_MAX model works as expected\n    \"\"\"\n    check_structure_test(pretrain_file,\n                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'MAX'],\n                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'UNTIED_MAX'])\n\ndef test_copy_with_new_structure_pattn(pretrain_file):\n    check_structure_test(pretrain_file,\n                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', 
'--delta_embedding_dim', '10'],\n                         ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])\n\ndef test_copy_with_new_structure_both(pretrain_file):\n    check_structure_test(pretrain_file,\n                         ['--pattn_num_layers', '0', '--lattn_d_proj',  '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],\n                         ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])\n\ndef test_copy_with_new_structure_lattn(pretrain_file):\n    check_structure_test(pretrain_file,\n                         ['--pattn_num_layers', '1', '--lattn_d_proj',  '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'],\n                         ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])\n\ndef test_parse_tagged_words(pretrain_file):\n    \"\"\"\n    Small test which doesn't check results, just execution\n    \"\"\"\n    model = build_model(pretrain_file)\n\n    sentence = [(\"I\", \"PRP\"), (\"am\", \"VBZ\"), (\"Luffa\", \"NNP\")]\n\n    # we don't expect a useful tree out of a random model\n    # so we don't check the result\n    # just check that it works without crashing\n    result = model.parse_tagged_words([sentence], 10)\n    assert len(result) == 1\n    pts = [x for x in result[0].yield_preterminals()]\n\n    for word, pt in zip(sentence, pts):\n        assert pt.children[0].label == word[0]\n        assert pt.label == word[1]\n"
  },
  {
    "path": "stanza/tests/constituency/test_parse_transitions.py",
    "content": "import pytest\n\nfrom stanza.models.constituency import parse_transitions\nfrom stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT\nfrom stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n\ndef build_initial_state(model, num_states=1):\n    words = [\"Unban\", \"Mox\", \"Opal\"]\n    tags = [\"VB\", \"NNP\", \"NNP\"]\n    sentences = [list(zip(words, tags)) for _ in range(num_states)]\n\n    states = model.initial_state_from_words(sentences)\n    assert len(states) == num_states\n    assert all(state.num_transitions == 0 for state in states)\n    return states\n\ndef test_initial_state(model=None):\n    if model is None:\n        model = SimpleModel()\n    states = build_initial_state(model)\n    assert len(states) == 1\n    state = states[0]\n\n    assert state.sentence_length == 3\n    assert state.num_opens == 0\n    # each stack has a sentinel value at the end\n    assert len(state.word_queue) == 5\n    assert len(state.constituents) == 1\n    assert len(state.transitions) == 1\n    assert state.word_position == 0\n\ndef test_shift(model=None):\n    if model is None:\n        model = SimpleModel()\n    state = build_initial_state(model)[0]\n\n    open_transition = parse_transitions.OpenConstituent(\"ROOT\")\n    state = open_transition.apply(state, model)\n    open_transition = parse_transitions.OpenConstituent(\"S\")\n    state = open_transition.apply(state, model)\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    assert len(state.word_queue) == 5\n    assert state.word_position == 0\n\n    state = shift.apply(state, model)\n    assert len(state.word_queue) == 5\n    # 4 because of the dummy created by the opens\n    assert len(state.constituents) == 4\n    assert len(state.transitions) == 4\n    assert shift.is_legal(state, model)\n    assert 
state.word_position == 1\n    assert not state.empty_word_queue()\n\n    state = shift.apply(state, model)\n    assert len(state.word_queue) == 5\n    assert len(state.constituents) == 5\n    assert len(state.transitions) == 5\n    assert shift.is_legal(state, model)\n    assert state.word_position == 2\n    assert not state.empty_word_queue()\n\n    state = shift.apply(state, model)\n    assert len(state.word_queue) == 5\n    assert len(state.constituents) == 6\n    assert len(state.transitions) == 6\n    assert not shift.is_legal(state, model)\n    assert state.word_position == 3\n    assert state.empty_word_queue()\n\n    constituents = state.constituents\n    assert model.get_top_constituent(constituents).children[0].label == 'Opal'\n    constituents = constituents.pop()\n    assert model.get_top_constituent(constituents).children[0].label == 'Mox'\n    constituents = constituents.pop()\n    assert model.get_top_constituent(constituents).children[0].label == 'Unban'\n\ndef test_initial_unary(model=None):\n    # it doesn't make sense to start with a CompoundUnary\n    if model is None:\n        model = SimpleModel()\n\n    state = build_initial_state(model)[0]\n    unary = parse_transitions.CompoundUnary('ROOT', 'VP')\n    assert unary.label == ('ROOT', 'VP',)\n    assert not unary.is_legal(state, model)\n    unary = parse_transitions.CompoundUnary('VP')\n    assert unary.label == ('VP',)\n    assert not unary.is_legal(state, model)\n\n\ndef test_unary(model=None):\n    if model is None:\n        model = SimpleModel()\n    state = build_initial_state(model)[0]\n\n    shift = parse_transitions.Shift()\n    state = shift.apply(state, model)\n\n    # this is technically the wrong parse but we're being lazy\n    unary = parse_transitions.CompoundUnary('S', 'VP')\n    assert unary.is_legal(state, model)\n    state = unary.apply(state, model)\n    assert not unary.is_legal(state, model)\n\n    tree = model.get_top_constituent(state.constituents)\n    assert tree.label 
== 'S'\n    assert len(tree.children) == 1\n    tree = tree.children[0]\n    assert tree.label == 'VP'\n    assert len(tree.children) == 1\n    tree = tree.children[0]\n    assert tree.label == 'VB'\n    assert tree.is_preterminal()\n\ndef test_unary_requires_root(model=None):\n    if model is None:\n        model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)\n    state = build_initial_state(model)[0]\n\n    open_transition = parse_transitions.OpenConstituent(\"S\")\n    assert open_transition.is_legal(state, model)\n    state = open_transition.apply(state, model)\n\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert not shift.is_legal(state, model)\n\n    close_transition = parse_transitions.CloseConstituent()\n    assert close_transition.is_legal(state, model)\n    state = close_transition.apply(state, model)\n    assert not open_transition.is_legal(state, model)\n    assert not close_transition.is_legal(state, model)\n\n    np_unary = parse_transitions.CompoundUnary(\"NP\")\n    assert not np_unary.is_legal(state, model)\n    root_unary = parse_transitions.CompoundUnary(\"ROOT\")\n    assert root_unary.is_legal(state, model)\n    assert not state.finished(model)\n    state = root_unary.apply(state, model)\n    assert not root_unary.is_legal(state, model)\n\n    assert state.finished(model)\n\ndef test_open(model=None):\n    if model is None:\n        model = SimpleModel()\n    state = build_initial_state(model)[0]\n\n    shift = parse_transitions.Shift()\n    state = shift.apply(state, model)\n    state = shift.apply(state, model)\n    assert state.num_opens == 0\n\n    open_transition = parse_transitions.OpenConstituent(\"VP\")\n    assert open_transition.is_legal(state, model)\n    state = 
open_transition.apply(state, model)\n    assert open_transition.is_legal(state, model)\n    assert state.num_opens == 1\n\n    # check that it is illegal if there are too many opens already\n    for i in range(20):\n        state = open_transition.apply(state, model)\n    assert not open_transition.is_legal(state, model)\n    assert state.num_opens == 21\n\n    # check that it is illegal if the state is out of words\n    state = build_initial_state(model)[0]\n    state = shift.apply(state, model)\n    state = shift.apply(state, model)\n    state = shift.apply(state, model)\n    assert not open_transition.is_legal(state, model)\n\ndef test_compound_open(model=None):\n    if model is None:\n        model = SimpleModel()\n    state = build_initial_state(model)[0]\n\n    open_transition = parse_transitions.OpenConstituent(\"ROOT\", \"S\")\n    assert open_transition.is_legal(state, model)\n    shift = parse_transitions.Shift()\n    close_transition = parse_transitions.CloseConstituent()\n\n    state = open_transition.apply(state, model)\n    state = shift.apply(state, model)\n    state = shift.apply(state, model)\n    state = shift.apply(state, model)\n    state = close_transition.apply(state, model)\n\n    tree = model.get_top_constituent(state.constituents)\n    assert tree.label == 'ROOT'\n    assert len(tree.children) == 1\n    tree = tree.children[0]\n    assert tree.label == 'S'\n    assert len(tree.children) == 3\n    assert tree.children[0].children[0].label == 'Unban'\n    assert tree.children[1].children[0].label == 'Mox'\n    assert tree.children[2].children[0].label == 'Opal'\n\ndef test_in_order_open(model=None):\n    if model is None:\n        model = SimpleModel(TransitionScheme.IN_ORDER)\n    state = build_initial_state(model)[0]\n\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert not shift.is_legal(state, model)\n\n    open_vp = parse_transitions.OpenConstituent(\"VP\")\n  
  assert open_vp.is_legal(state, model)\n    state = open_vp.apply(state, model)\n    assert not open_vp.is_legal(state, model)\n\n    close_trans = parse_transitions.CloseConstituent()\n    assert close_trans.is_legal(state, model)\n    state = close_trans.apply(state, model)\n\n    open_s = parse_transitions.OpenConstituent(\"S\")\n    assert open_s.is_legal(state, model)\n    state = open_s.apply(state, model)\n    assert not open_vp.is_legal(state, model)\n\n    # check that root transitions won't happen in the middle of a parse\n    open_root = parse_transitions.OpenConstituent(\"ROOT\")\n    assert not open_root.is_legal(state, model)\n\n    # build (NP (NNP Mox) (NNP Opal))\n    open_np = parse_transitions.OpenConstituent(\"NP\")\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert open_np.is_legal(state, model)\n    # make sure root can't happen in places where an arbitrary open is legal\n    assert not open_root.is_legal(state, model)\n    state = open_np.apply(state, model)\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert close_trans.is_legal(state, model)\n    state = close_trans.apply(state, model)\n\n    assert close_trans.is_legal(state, model)\n    state = close_trans.apply(state, model)\n\n    assert open_root.is_legal(state, model)\n    state = open_root.apply(state, model)\n\ndef test_too_many_unaries_close():\n    \"\"\"\n    This tests rejecting Close at the start of a sequence after too many unary transitions\n\n    The model should reject doing multiple \"unaries\" - eg, Open then Close - in an IN_ORDER sequence\n    \"\"\"\n    model = SimpleModel(TransitionScheme.IN_ORDER)\n    state = build_initial_state(model)[0]\n\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n\n    open_np = parse_transitions.OpenConstituent(\"NP\")\n    close_trans = parse_transitions.CloseConstituent()\n    for _ in 
range(UNARY_LIMIT):\n        assert open_np.is_legal(state, model)\n        state = open_np.apply(state, model)\n\n        assert close_trans.is_legal(state, model)\n        state = close_trans.apply(state, model)\n\n    assert open_np.is_legal(state, model)\n    state = open_np.apply(state, model)\n    assert not close_trans.is_legal(state, model)\n\ndef test_too_many_unaries_open():\n    \"\"\"\n    This tests rejecting Open in the middle of a sequence after too many unary transitions\n\n    The model should reject doing multiple \"unaries\" - eg, Open then Close - in an IN_ORDER sequence\n    \"\"\"\n    model = SimpleModel(TransitionScheme.IN_ORDER)\n    state = build_initial_state(model)[0]\n\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n\n    open_np = parse_transitions.OpenConstituent(\"NP\")\n    close_trans = parse_transitions.CloseConstituent()\n\n    assert open_np.is_legal(state, model)\n    state = open_np.apply(state, model)\n    assert not open_np.is_legal(state, model)\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n\n    for _ in range(UNARY_LIMIT):\n        assert open_np.is_legal(state, model)\n        state = open_np.apply(state, model)\n\n        assert close_trans.is_legal(state, model)\n        state = close_trans.apply(state, model)\n\n    assert not open_np.is_legal(state, model)\n\ndef test_close(model=None):\n    if model is None:\n        model = SimpleModel()\n\n    # this one actually tests an entire subtree building\n    state = build_initial_state(model)[0]\n\n    open_transition_vp = parse_transitions.OpenConstituent(\"VP\")\n    assert open_transition_vp.is_legal(state, model)\n    state = open_transition_vp.apply(state, model)\n    assert state.num_opens == 1\n\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n\n    open_transition_np = 
parse_transitions.OpenConstituent(\"NP\")\n    assert open_transition_np.is_legal(state, model)\n    state = open_transition_np.apply(state, model)\n    assert state.num_opens == 2\n\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert shift.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert not shift.is_legal(state, model)\n    assert state.num_opens == 2\n    # now should have \"mox\", \"opal\" on the constituents\n\n    close_transition = parse_transitions.CloseConstituent()\n    assert close_transition.is_legal(state, model)\n    state = close_transition.apply(state, model)\n    assert state.num_opens == 1\n    assert close_transition.is_legal(state, model)\n    state = close_transition.apply(state, model)\n    assert state.num_opens == 0\n    assert not close_transition.is_legal(state, model)\n\n    tree = model.get_top_constituent(state.constituents)\n    assert tree.label == 'VP'\n    assert len(tree.children) == 2\n    tree = tree.children[1]\n    assert tree.label == 'NP'\n    assert len(tree.children) == 2\n    assert tree.children[0].is_preterminal()\n    assert tree.children[1].is_preterminal()\n    assert tree.children[0].children[0].label == 'Mox'\n    assert tree.children[1].children[0].label == 'Opal'\n\n    # extra one for None at the start of the TreeStack\n    assert len(state.constituents) == 2\n\n    assert state.all_transitions(model) == [open_transition_vp, shift, open_transition_np, shift, shift, close_transition, close_transition]\n\ndef test_in_order_compound_finalize(model=None):\n    \"\"\"\n    Test the Finalize transition is only legal at the end of a sequence\n    \"\"\"\n    if model is None:\n        model = SimpleModel(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)\n\n    state = build_initial_state(model)[0]\n\n    finalize = parse_transitions.Finalize(\"ROOT\")\n\n    shift = parse_transitions.Shift()\n    assert shift.is_legal(state, model)\n    assert not 
finalize.is_legal(state, model)\n    state = shift.apply(state, model)\n\n    open_transition = parse_transitions.OpenConstituent(\"NP\")\n    assert open_transition.is_legal(state, model)\n    assert not finalize.is_legal(state, model)\n    state = open_transition.apply(state, model)\n    assert state.num_opens == 1\n\n    assert shift.is_legal(state, model)\n    assert not finalize.is_legal(state, model)\n    state = shift.apply(state, model)\n    assert shift.is_legal(state, model)\n    assert not finalize.is_legal(state, model)\n    state = shift.apply(state, model)\n\n    close_transition = parse_transitions.CloseConstituent()\n    assert close_transition.is_legal(state, model)\n    state = close_transition.apply(state, model)\n    assert state.num_opens == 0\n    assert not close_transition.is_legal(state, model)\n    assert finalize.is_legal(state, model)\n\n    state = finalize.apply(state, model)\n    assert not finalize.is_legal(state, model)\n    tree = model.get_top_constituent(state.constituents)\n    assert tree.label == 'ROOT'\n\ndef test_hashes():\n    transitions = set()\n\n    shift = parse_transitions.Shift()\n    assert shift not in transitions\n    transitions.add(shift)\n    assert shift in transitions\n    shift = parse_transitions.Shift()\n    assert shift in transitions\n\n    for i in range(5):\n        transitions.add(shift)\n    assert len(transitions) == 1\n\n    unary = parse_transitions.CompoundUnary(\"asdf\")\n    assert unary not in transitions\n    transitions.add(unary)\n    assert unary in transitions\n\n    unary = parse_transitions.CompoundUnary(\"asdf\", \"zzzz\")\n    assert unary not in transitions\n    transitions.add(unary)\n    transitions.add(unary)\n    transitions.add(unary)\n    unary = parse_transitions.CompoundUnary(\"asdf\", \"zzzz\")\n    assert unary in transitions\n\n    oc = parse_transitions.OpenConstituent(\"asdf\")\n    assert oc not in transitions\n    transitions.add(oc)\n    assert oc in transitions\n    
transitions.add(oc)\n    transitions.add(oc)\n    assert len(transitions) == 4\n    assert parse_transitions.OpenConstituent(\"asdf\") in transitions\n\n    cc = parse_transitions.CloseConstituent()\n    assert cc not in transitions\n    transitions.add(cc)\n    transitions.add(cc)\n    transitions.add(cc)\n    assert cc in transitions\n    cc = parse_transitions.CloseConstituent()\n    assert cc in transitions\n    assert len(transitions) == 5\n\n\ndef test_sort():\n    expected = []\n\n    expected.append(parse_transitions.Shift())\n    expected.append(parse_transitions.CloseConstituent())\n    expected.append(parse_transitions.CompoundUnary(\"NP\"))\n    expected.append(parse_transitions.CompoundUnary(\"NP\", \"VP\"))\n    expected.append(parse_transitions.OpenConstituent(\"mox\"))\n    expected.append(parse_transitions.OpenConstituent(\"opal\"))\n    expected.append(parse_transitions.OpenConstituent(\"unban\"))\n\n    transitions = set(expected)\n    transitions = sorted(transitions)\n    assert transitions == expected\n\ndef test_check_transitions():\n    \"\"\"\n    Test that check_transitions passes or fails a couple simple, small test cases\n    \"\"\"\n    transitions = {Shift(), CloseConstituent(), OpenConstituent(\"NP\"), OpenConstituent(\"VP\")}\n\n    other = {Shift(), CloseConstituent(), OpenConstituent(\"NP\"), OpenConstituent(\"VP\")}\n    parse_transitions.check_transitions(transitions, other, \"test\")\n\n    # This will get a pass because it is a unary made out of existing unaries\n    other = {Shift(), CloseConstituent(), OpenConstituent(\"NP\", \"VP\")}\n    parse_transitions.check_transitions(transitions, other, \"test\")\n\n    # This should fail\n    with pytest.raises(RuntimeError):\n        other = {Shift(), CloseConstituent(), OpenConstituent(\"NP\", \"ZP\")}\n        parse_transitions.check_transitions(transitions, other, \"test\")\n"
  },
  {
    "path": "stanza/tests/constituency/test_parse_tree.py",
    "content": "import pytest\n\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.models.constituency import tree_reader\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_leaf_preterminal():\n    foo = Tree(label=\"foo\")\n    assert foo.is_leaf()\n    assert not foo.is_preterminal()\n    assert len(foo.children) == 0\n    assert str(foo) == 'foo'\n\n    bar = Tree(label=\"bar\", children=foo)\n    assert not bar.is_leaf()\n    assert bar.is_preterminal()\n    assert len(bar.children) == 1\n    assert str(bar) == \"(bar foo)\"\n\n    baz = Tree(label=\"baz\", children=[bar])\n    assert not baz.is_leaf()\n    assert not baz.is_preterminal()\n    assert len(baz.children) == 1\n    assert str(baz) == \"(baz (bar foo))\"\n\n\ndef test_yield_preterminals():\n    text = \"((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\"\n    trees = tree_reader.read_trees(text)\n\n    preterminals = list(trees[0].yield_preterminals())\n    assert len(preterminals) == 3\n    assert str(preterminals) == \"[(VB Unban), (NNP Mox), (NNP Opal)]\"\n\ndef test_depth():\n    text = \"(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\"\n    trees = tree_reader.read_trees(text)\n    assert trees[0].depth() == 0\n    assert trees[1].depth() == 4\n\ndef test_unique_labels():\n    \"\"\"\n    Test getting the unique labels from a tree\n\n    Assumes tree_reader works, which should be fine since it is tested elsewhere\n    \"\"\"\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))\"\n\n    trees = tree_reader.read_trees(text)\n\n    labels = Tree.get_unique_constituent_labels(trees)\n    expected = ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']\n    assert labels == expected\n\ndef test_unique_tags():\n    \"\"\"\n    Test getting the unique tags from a tree\n    \"\"\"\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n\n    trees = tree_reader.read_trees(text)\n\n    tags = Tree.get_unique_tags(trees)\n    expected = ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']\n    assert tags == expected\n\n\ndef test_unique_words():\n    \"\"\"\n    Test getting the unique words from a tree\n    \"\"\"\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n\n    trees = tree_reader.read_trees(text)\n\n    words = Tree.get_unique_words(trees)\n    expected = ['?', 'Who', 'in', 'seat', 'sits', 'this']\n    assert words == expected\n\ndef test_rare_words():\n    \"\"\"\n    Test getting the rare words from a list of trees\n    \"\"\"\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))  ((SBARQ (NP (DT this) (NN seat)) (. ?)))\"\n\n    trees = tree_reader.read_trees(text)\n\n    words = Tree.get_rare_words(trees, 0.5)\n    expected = ['Who', 'in', 'sits']\n    assert words == expected\n\ndef test_common_words():\n    \"\"\"\n    Test getting the most common words from a list of trees\n    \"\"\"\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))  ((SBARQ (NP (DT this) (NN seat)) (. ?)))\"\n\n    trees = tree_reader.read_trees(text)\n\n    words = Tree.get_common_words(trees, 3)\n    expected = ['?', 'seat', 'this']\n    assert words == expected\n\ndef test_root_labels():\n    text=\"( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))\"\n    trees = tree_reader.read_trees(text)\n    assert [\"ROOT\"] == Tree.get_root_labels(trees)\n\n    text=(\"( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\" +\n          \"( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\" +\n          \"( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\")\n    trees = tree_reader.read_trees(text)\n    assert [\"ROOT\"] == Tree.get_root_labels(trees)\n\n    text=\"(FOO) (BAR)\"\n    trees = tree_reader.read_trees(text)\n    assert [\"BAR\", \"FOO\"] == Tree.get_root_labels(trees)\n\ndef test_prune_none():\n    text=[\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))\", # test one dead node\n          \"((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\", # test recursive dead nodes\n          \"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))\"] # test all children dead\n    expected=[\"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))\",\n              \"(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\",\n              \"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))\"]\n\n    for t, e in zip(text, expected):\n        trees = tree_reader.read_trees(t)\n        assert len(trees) == 1\n        tree = trees[0].prune_none()\n        assert e == str(tree)\n\ndef test_simplify_labels():\n    text=\"( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))\"\n    expected = \"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. 
?)))\"\n    trees = tree_reader.read_trees(text)\n    trees = [t.simplify_labels() for t in trees]\n    assert len(trees) == 1\n    assert expected == str(trees[0])\n\ndef test_remap_constituent_labels():\n    text=\"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))\"\n    expected=\"(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))\"\n\n    label_map = { \"SBARQ\": \"FOO\" }\n    trees = tree_reader.read_trees(text)\n    trees = [t.remap_constituent_labels(label_map) for t in trees]\n    assert len(trees) == 1\n    assert expected == str(trees[0])\n\ndef test_remap_constituent_words():\n    text=\"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))\"\n    expected=\"(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))\"\n\n    word_map = { \"Who\": \"unban\", \"sits\": \"mox\", \"in\": \"opal\" }\n    trees = tree_reader.read_trees(text)\n    trees = [t.remap_words(word_map) for t in trees]\n    assert len(trees) == 1\n    assert expected == str(trees[0])\n\ndef test_replace_words():\n    text=\"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))\"\n    expected=\"(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))\"\n    new_words = [\"unban\", \"mox\", \"opal\", \"?\"]\n\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    tree = trees[0]\n    new_tree = tree.replace_words(new_words)\n    assert expected == str(new_tree)\n\n\ndef test_compound_constituents():\n    # TODO: add skinny trees like this to the various transition tests\n    text=\"((VP (VB Unban)))\"\n    trees = tree_reader.read_trees(text)\n    assert Tree.get_compound_constituents(trees) == [('ROOT', 'VP')]\n\n    text=\"(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. 
?)))\"\n    trees = tree_reader.read_trees(text)\n    assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)]\n\n    text=\"((VP (VB Unban)))   (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))\"\n    trees = tree_reader.read_trees(text)\n    assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)]\n\ndef test_equals():\n    \"\"\"\n    Check one tree from the actual dataset for ==\n\n    when built with compound Open, this didn't work because of a silly bug\n    \"\"\"\n    text = \"(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. 
.)))\"\n\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    assert tree == tree\n\n    trees2 = tree_reader.read_trees(text)\n    tree2 = trees2[0]\n\n    assert tree is not tree2\n    assert tree == tree2\n\n\n# This tree was causing the model to barf on CTB7,\n# although it turns out the problem was just the\n# depth of the unary, not the list\nCHINESE_LONG_LIST_TREE = \"\"\"\n(ROOT\n (IP\n  (NP (NNP 证券法))\n  (VP\n   (PP\n    (IN 对)\n    (NP\n     (DNP\n      (NP\n       (NP (NNP 中国))\n       (NP\n        (NN 证券)\n        (NN 市场)))\n      (DEC 的))\n     (NP (NN 运作))))\n   (, ，)\n   (PP\n    (PP\n     (IN 从)\n     (NP\n      (NP (NN 股票))\n      (NP (VV 发行) (EC 、) (VV 交易))))\n    (, ，)\n    (PP\n     (VV 到)\n     (NP\n      (NP (NN 上市) (NN 公司) (NN 收购))\n      (EC 、)\n      (NP (NN 证券) (NN 交易所))\n      (EC 、)\n      (NP (NN 证券) (NN 公司))\n      (EC 、)\n      (NP (NN 登记) (NN 结算) (NN 机构))\n      (EC 、)\n      (NP (NN 交易) (NN 服务) (NN 机构))\n      (EC 、)\n      (NP (NN 证券业) (NN 协会))\n      (EC 、)\n      (NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))\n      (CC 和)\n      (NP\n       (DNP\n        (NP (CP (CP (IP (VP (VV 违法))))))\n        (DEC 的))\n       (NP (NN 法律) (NN 责任))))))\n   (ADVP (RB 都))\n   (VP\n    (VV 作)\n    (AS 了)\n    (NP\n     (ADJP (JJ 详细))\n     (NP (NN 规定)))))\n  (. 
。)))\n\"\"\"\n\nWEIRD_UNARY = \"\"\"\n  (DNP\n    (NP (CP (CP (IP (VP (ASDF\n      (NP (NN 上市) (NN 公司) (NN 收购))\n      (EC 、)\n      (NP (NN 证券) (NN 交易所))\n      (EC 、)\n      (NP (NN 证券) (NN 公司))\n      (EC 、)\n      (NP (NN 登记) (NN 结算) (NN 机构))\n      (EC 、)\n      (NP (NN 交易) (NN 服务) (NN 机构))\n      (EC 、)\n      (NP (NN 证券业) (NN 协会))\n      (EC 、)\n      (NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))))))))\n    (DEC 的))\n\"\"\"\n\n\ndef test_count_unaries():\n    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)\n    assert len(trees) == 1\n    assert trees[0].count_unary_depth() == 5\n\n    trees = tree_reader.read_trees(WEIRD_UNARY)\n    assert len(trees) == 1\n    assert trees[0].count_unary_depth() == 5\n\ndef test_str_bracket_labels():\n    text = \"((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\"\n    expected = \"(_ROOT (_S (_VP (_VB Unban )_VB )_VP (_NP (_NNP Mox )_NNP (_NNP Opal )_NNP )_NP )_S )_ROOT\"\n\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    assert \"{:L}\".format(trees[0]) == expected\n\ndef test_all_leaves_are_preterminals():\n    text = \"((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    assert trees[0].all_leaves_are_preterminals()\n\n    text = \"((S (VP (VB Unban)) (NP (Mox) (NNP Opal))))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    assert not trees[0].all_leaves_are_preterminals()\n\ndef test_latex():\n    \"\"\"\n    Test the latex format for trees\n    \"\"\"\n    expected = \"\\\\Tree [.S [.NP Jennifer ] [.VP has [.NP nice antennae ] ] ]\"\n    tree = \"(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ nice) (NNS antennae)))))\"\n    tree = tree_reader.read_trees(tree)[0]\n    text = \"{:T}\".format(tree)\n    assert text == expected\n\ndef test_pretty_print():\n    \"\"\"\n    Pretty print a couple trees - newlines & indentation\n    \"\"\"\n    text = \"(ROOT (S (VP (VB Unban)) (NP (NNP Mox) (NNP 
Opal)))) (ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric)))))))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 2\n\n    expected = \"\"\"(ROOT\n  (S\n    (VP (VB Unban))\n    (NP (NNP Mox) (NNP Opal))))\n\"\"\"\n\n    assert \"{:P}\".format(trees[0]) == expected\n\n    expected = \"\"\"(ROOT\n  (S\n    (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission))\n    (VP\n      (VBD authorized)\n      (NP\n        (NP\n          (DT an)\n          (ADJP (CD 11.5))\n          (NN %)\n          (NN rate)\n          (NN increase))\n        (PP\n          (IN at)\n          (NP (NNP Tucson) (NNP Electric)))))))\n\"\"\"\n    assert \"{:P}\".format(trees[1]) == expected\n\n    assert text == \"{:O} {:O}\".format(*trees)\n\ndef test_reverse():\n    text = \"(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    reversed_tree = trees[0].reverse()\n    assert str(reversed_tree) == \"(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))\"\n"
  },
  {
    "path": "stanza/tests/constituency/test_positional_encoding.py",
    "content": "import pytest\n\nimport torch\n\nfrom stanza import Pipeline\nfrom stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n\ndef test_positional_encoding():\n    encoding = SinusoidalEncoding(model_dim=10, max_len=6)\n    foo = encoding(torch.tensor([5]))\n    assert foo.shape == (1, 10)\n    # TODO: check the values\n\ndef test_resize():\n    encoding = SinusoidalEncoding(model_dim=10, max_len=3)\n    foo = encoding(torch.tensor([5]))\n    assert foo.shape == (1, 10)\n\n\ndef test_arange():\n    encoding = SinusoidalEncoding(model_dim=10, max_len=2)\n    foo = encoding(torch.arange(4))\n    assert foo.shape == (4, 10)\n    assert encoding.max_len() == 4\n\ndef test_add():\n    encoding = AddSinusoidalEncoding(d_model=10, max_len=4)\n    x = torch.zeros(1, 4, 10)\n    y = encoding(x)\n\n    r = torch.randn(1, 4, 10)\n    r2 = encoding(r)\n\n    assert torch.allclose(r2 - r, y, atol=1e-07)\n\n    r = torch.randn(2, 4, 10)\n    r2 = encoding(r)\n\n    assert torch.allclose(r2[0] - r[0], y, atol=1e-07)\n    assert torch.allclose(r2[1] - r[1], y, atol=1e-07)\n"
  },
  {
    "path": "stanza/tests/constituency/test_selftrain_vi_quad.py",
    "content": "\"\"\"\nTest some of the methods in the vi_quad dataset\n\nUses a small section of the dataset as a test\n\"\"\"\n\nimport pytest\n\nfrom stanza.utils.datasets.constituency import selftrain_vi_quad\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nSAMPLE_TEXT = \"\"\"\n{\"version\": \"1.1\", \"data\": [{\"title\": \"Ph\\u1ea1m V\\u0103n \\u0110\\u1ed3ng\", \"paragraphs\": [{\"qas\": [{\"question\": \"T\\u00ean g\\u1ecdi n\\u00e0o \\u0111\\u01b0\\u1ee3c Ph\\u1ea1m V\\u0103n \\u0110\\u1ed3ng s\\u1eed d\\u1ee5ng khi l\\u00e0m Ph\\u00f3 ch\\u1ee7 nhi\\u1ec7m c\\u01a1 quan Bi\\u1ec7n s\\u1ef1 x\\u1ee9 t\\u1ea1i Qu\\u1ebf L\\u00e2m?\", \"answers\": [{\"answer_start\": 507, \"text\": \"L\\u00e2m B\\u00e1 Ki\\u1ec7t\"}], \"id\": \"uit_01__05272_0_1\"}, {\"question\": \"Ph\\u1ea1m V\\u0103n \\u0110\\u1ed3ng gi\\u1eef ch\\u1ee9c v\\u1ee5 g\\u00ec trong b\\u1ed9 m\\u00e1y Nh\\u00e0 n\\u01b0\\u1edbc C\\u1ed9ng h\\u00f2a X\\u00e3 h\\u1ed9i ch\\u1ee7 ngh\\u0129a Vi\\u1ec7t Nam?\", \"answers\": [{\"answer_start\": 60, \"text\": \"Th\\u1ee7 t\\u01b0\\u1edbng\"}], \"id\": \"uit_01__05272_0_2\"}, {\"question\": \"Giai \\u0111o\\u1ea1n n\\u0103m 1955-1976, Ph\\u1ea1m V\\u0103n \\u0110\\u1ed3ng n\\u1eafm gi\\u1eef ch\\u1ee9c v\\u1ee5 g\\u00ec?\", \"answers\": [{\"answer_start\": 245, \"text\": \"Th\\u1ee7 t\\u01b0\\u1edbng Ch\\u00ednh ph\\u1ee7 Vi\\u1ec7t Nam D\\u00e2n ch\\u1ee7 C\\u1ed9ng h\\u00f2a\"}], \"id\": \"uit_01__05272_0_3\"}], \"context\": \"Ph\\u1ea1m V\\u0103n \\u0110\\u1ed3ng (1 th\\u00e1ng 3 n\\u0103m 1906 \\u2013 29 th\\u00e1ng 4 n\\u0103m 2000) l\\u00e0 Th\\u1ee7 t\\u01b0\\u1edbng \\u0111\\u1ea7u ti\\u00ean c\\u1ee7a n\\u01b0\\u1edbc C\\u1ed9ng h\\u00f2a X\\u00e3 h\\u1ed9i ch\\u1ee7 ngh\\u0129a Vi\\u1ec7t Nam t\\u1eeb n\\u0103m 1976 (t\\u1eeb n\\u0103m 1981 g\\u1ecdi l\\u00e0 Ch\\u1ee7 t\\u1ecbch H\\u1ed9i \\u0111\\u1ed3ng B\\u1ed9 tr\\u01b0\\u1edfng) cho \\u0111\\u1ebfn khi ngh\\u1ec9 h\\u01b0u n\\u0103m 1987. 
Tr\\u01b0\\u1edbc \\u0111\\u00f3 \\u00f4ng t\\u1eebng gi\\u1eef ch\\u1ee9c v\\u1ee5 Th\\u1ee7 t\\u01b0\\u1edbng Ch\\u00ednh ph\\u1ee7 Vi\\u1ec7t Nam D\\u00e2n ch\\u1ee7 C\\u1ed9ng h\\u00f2a t\\u1eeb n\\u0103m 1955 \\u0111\\u1ebfn n\\u0103m 1976. \\u00d4ng l\\u00e0 v\\u1ecb Th\\u1ee7 t\\u01b0\\u1edbng Vi\\u1ec7t Nam t\\u1ea1i v\\u1ecb l\\u00e2u nh\\u1ea5t (1955\\u20131987). \\u00d4ng l\\u00e0 h\\u1ecdc tr\\u00f2, c\\u1ed9ng s\\u1ef1 c\\u1ee7a Ch\\u1ee7 t\\u1ecbch H\\u1ed3 Ch\\u00ed Minh. \\u00d4ng c\\u00f3 t\\u00ean g\\u1ecdi th\\u00e2n m\\u1eadt l\\u00e0 T\\u00f4, \\u0111\\u00e2y t\\u1eebng l\\u00e0 b\\u00ed danh c\\u1ee7a \\u00f4ng. \\u00d4ng c\\u00f2n c\\u00f3 t\\u00ean g\\u1ecdi l\\u00e0 L\\u00e2m B\\u00e1 Ki\\u1ec7t khi l\\u00e0m Ph\\u00f3 ch\\u1ee7 nhi\\u1ec7m c\\u01a1 quan Bi\\u1ec7n s\\u1ef1 x\\u1ee9 t\\u1ea1i Qu\\u1ebf L\\u00e2m (Ch\\u1ee7 nhi\\u1ec7m l\\u00e0 H\\u1ed3 H\\u1ecdc L\\u00e3m).\"}, {\"qas\": [{\"question\": \"S\\u1ef1 ki\\u1ec7n quan tr\\u1ecdng n\\u00e0o \\u0111\\u00e3 di\\u1ec5n ra v\\u00e0o ng\\u00e0y 20/7/1954?\", \"answers\": [{\"answer_start\": 364, \"text\": \"b\\u1ea3n Hi\\u1ec7p \\u0111\\u1ecbnh \\u0111\\u00ecnh ch\\u1ec9 chi\\u1ebfn s\\u1ef1 \\u1edf Vi\\u1ec7t Nam, Campuchia v\\u00e0 L\\u00e0o \\u0111\\u00e3 \\u0111\\u01b0\\u1ee3c k\\u00fd k\\u1ebft th\\u1eeba nh\\u1eadn t\\u00f4n tr\\u1ecdng \\u0111\\u1ed9c l\\u1eadp, ch\\u1ee7 quy\\u1ec1n, c\\u1ee7a n\\u01b0\\u1edbc Vi\\u1ec7t Nam, L\\u00e0o v\\u00e0 Campuchia\"}], \"id\": \"uit_01__05272_1_1\"}, {\"question\": \"Ch\\u1ee9c v\\u1ee5 m\\u00e0 Ph\\u1ea1m V\\u0103n \\u0110\\u1ed3ng \\u0111\\u1ea3m nhi\\u1ec7m t\\u1ea1i H\\u1ed9i ngh\\u1ecb Gen\\u00e8ve v\\u1ec1 \\u0110\\u00f4ng D\\u01b0\\u01a1ng?\", \"answers\": [{\"answer_start\": 33, \"text\": \"Tr\\u01b0\\u1edfng ph\\u00e1i \\u0111o\\u00e0n Ch\\u00ednh ph\\u1ee7\"}], \"id\": \"uit_01__05272_1_2\"}, {\"question\": \"H\\u1ed9i ngh\\u1ecb Gen\\u00e8ve v\\u1ec1 \\u0110\\u00f4ng D\\u01b0\\u01a1ng c\\u00f3 t\\u00ednh ch\\u1ea5t nh\\u01b0 
th\\u1ebf n\\u00e0o?\", \"answers\": [{\"answer_start\": 262, \"text\": \"r\\u1ea5t c\\u0103ng th\\u1eb3ng v\\u00e0 ph\\u1ee9c t\\u1ea1p\"}], \"id\": \"uit_01__05272_1_3\"}]}]}]}\n\"\"\"\n\nEXPECTED = ['Tên gọi nào được Phạm Văn Đồng sử dụng khi làm Phó chủ nhiệm cơ quan Biện sự xứ tại Quế Lâm?', 'Phạm Văn Đồng giữ chức vụ gì trong bộ máy Nhà nước Cộng hòa Xã hội chủ nghĩa Việt Nam?', 'Giai đoạn năm 1955-1976, Phạm Văn Đồng nắm giữ chức vụ gì?', 'Sự kiện quan trọng nào đã diễn ra vào ngày 20/7/1954?', 'Chức vụ mà Phạm Văn Đồng đảm nhiệm tại Hội nghị Genève về Đông Dương?', 'Hội nghị Genève về Đông Dương có tính chất như thế nào?']\n\ndef test_read_file():\n    results = selftrain_vi_quad.parse_quad(SAMPLE_TEXT)\n    assert results == EXPECTED\n"
  },
  {
    "path": "stanza/tests/constituency/test_text_processing.py",
    "content": "\"\"\"\nRun through the various text processing methods for using the parser on text files / directories\n\nUses a simple tree where the parser should always get it right, but things could potentially go wrong\n\"\"\"\n\nimport glob\nimport os\nimport pytest\n\nfrom stanza import Pipeline\n\nfrom stanza.models.constituency import text_processing\nfrom stanza.models.constituency import tree_reader\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n@pytest.fixture(scope=\"module\")\ndef pipeline():\n    return Pipeline(dir=TEST_MODELS_DIR, lang=\"en\", processors=\"tokenize, pos, constituency\", tokenize_pretokenized=True)\n\ndef test_read_tokenized_file(tmp_path):\n    filename = str(tmp_path / \"test_input.txt\")\n    with open(filename, \"w\") as fout:\n        # test that the underscore token comes back with spaces\n        fout.write(\"This is a_small test\\nLine two\\n\")\n    text, ids = text_processing.read_tokenized_file(filename)\n    assert text == [['This', 'is', 'a small', 'test'], ['Line', 'two']]\n    assert ids == [None, None]\n\ndef test_parse_tokenized_sentences(pipeline):\n    con_processor = pipeline.processors[\"constituency\"]\n    model = con_processor._model\n    args = model.args\n\n    sentences = [[\"This\", \"is\", \"a\", \"test\"]]\n    trees = text_processing.parse_tokenized_sentences(args, model, [pipeline], sentences)\n    predictions = [x.predictions for x in trees]\n    assert len(predictions) == 1\n    scored_trees = predictions[0]\n    assert len(scored_trees) == 1\n    result = \"{}\".format(scored_trees[0].tree)\n    expected = \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\"\n    assert result == expected\n\ndef test_parse_text(tmp_path, pipeline):\n    con_processor = pipeline.processors[\"constituency\"]\n    model = con_processor._model\n    args = model.args\n\n    raw_file = str(tmp_path / \"test_input.txt\")\n    with open(raw_file, 
\"w\") as fout:\n        fout.write(\"This is a test\\nThis is another test\\n\")\n    output_file = str(tmp_path / \"test_output.txt\")\n    text_processing.parse_text(args, model, [pipeline], tokenized_file=raw_file, predict_file=output_file)\n\n    trees = tree_reader.read_treebank(output_file)\n    trees = [\"{}\".format(x) for x in trees]\n    expected_trees = [\"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\",\n                      \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))\"]\n    assert trees == expected_trees\n\ndef test_parse_dir(tmp_path, pipeline):\n    con_processor = pipeline.processors[\"constituency\"]\n    model = con_processor._model\n    args = model.args\n\n    raw_dir = str(tmp_path / \"input\")\n    os.makedirs(raw_dir)\n    raw_f1 = str(tmp_path / \"input\" / \"f1.txt\")\n    raw_f2 = str(tmp_path / \"input\" / \"f2.txt\")\n    output_dir = str(tmp_path / \"output\")\n\n    with open(raw_f1, \"w\") as fout:\n        fout.write(\"This is a test\")\n    with open(raw_f2, \"w\") as fout:\n        fout.write(\"This is another test\")\n\n    text_processing.parse_dir(args, model, [pipeline], raw_dir, output_dir)\n    output_files = sorted(glob.glob(os.path.join(output_dir, \"*\")))\n    expected_trees = [\"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\",\n                      \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))\"]\n    for output_file, expected_tree in zip(output_files, expected_trees):\n        trees = tree_reader.read_treebank(output_file)\n        assert len(trees) == 1\n        assert \"{}\".format(trees[0]) == expected_tree\n\ndef test_load_model_parse_text(tmp_path, pipeline):\n    con_processor = pipeline.processors[\"constituency\"]\n    model = con_processor._model\n    args = dict(model.args)\n\n    model_path = con_processor._config['model_path']\n\n    raw_file = str(tmp_path / \"test_input.txt\")\n    with open(raw_file, \"w\") as fout:\n        
fout.write(\"This is a test\\nThis is another test\\n\")\n    output_file = str(tmp_path / \"test_output.txt\")\n\n    args['tokenized_file'] = raw_file\n    args['predict_file'] = output_file\n\n    text_processing.load_model_parse_text(args, model_path, [pipeline])\n    trees = tree_reader.read_treebank(output_file)\n    trees = [\"{}\".format(x) for x in trees]\n    expected_trees = [\"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))\",\n                      \"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))\"]\n    assert trees == expected_trees\n"
  },
  {
    "path": "stanza/tests/constituency/test_top_down_oracle.py",
    "content": "import pytest\n\nfrom stanza.models.constituency.base_model import SimpleModel\nfrom stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, TransitionScheme\nfrom stanza.models.constituency.top_down_oracle import *\nfrom stanza.models.constituency.transition_sequence import build_sequence\nfrom stanza.models.constituency.tree_reader import read_trees\n\nfrom stanza.tests.constituency.test_transition_sequence import reconstruct_tree\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nOPEN_SHIFT_EXAMPLE_TREE = \"\"\"\n( (S\n     (NP (NNP Jennifer) (NNP Sh\\'reyan))\n     (VP (VBZ has)\n         (NP (RB nice) (NNS antennae)))))\n\"\"\"\n\nOPEN_SHIFT_PROBLEM_TREE = \"\"\"\n(ROOT (S (NP (NP (NP (DT The) (`` ``) (JJ Thin) (NNP Man) ('' '') (NN series)) (PP (IN of) (NP (NNS movies)))) (, ,) (CONJP (RB as) (RB well) (IN as)) (NP (JJ many) (NNS others)) (, ,)) (VP (VBD based) (NP (PRP$ their) (JJ entire) (JJ comedic) (NN appeal)) (PP (IN on) (NP (NP (DT the) (NN star) (NNS detectives) (POS ')) (JJ witty) (NNS quips) (CC and) (NNS puns))) (SBAR (IN as) (S (NP (NP (JJ other) (NNS characters)) (PP (IN in) (NP (DT the) (NNS movies)))) (VP (VBD were) (VP (VBN murdered)))))) (. 
.)))\n\"\"\"\n\nROOT_LABELS = [\"ROOT\"]\n\ndef get_single_repair(gold_sequence, wrong_transition, repair_fn, idx, *args, **kwargs):\n    return repair_fn(gold_sequence[idx], wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None, *args, **kwargs)\n\ndef build_state(model, tree, num_transitions):\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    states = model.initial_state_from_gold_trees([tree], [transitions])\n    for idx, t in enumerate(transitions[:num_transitions]):\n        assert t.is_legal(states[0], model), \"Transition {} not legal at step {} in sequence {}\".format(t, idx, transitions)\n        states = model.bulk_apply(states, [t])\n    state = states[0]\n    return state\n\ndef test_fix_open_shift():\n    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    EXPECTED_FIX_EARLY = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    EXPECTED_FIX_LATE = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n\n    assert transitions == EXPECTED_ORIG\n\n    new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2)\n    assert new_transitions == EXPECTED_FIX_EARLY\n\n    new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 8)\n    
assert new_transitions == EXPECTED_FIX_LATE\n\ndef test_fix_open_shift_observed_error():\n    \"\"\"\n    Ran into an error on this tree, need to fix it\n\n    The problem is the multiple Open in a row all need to be removed when a Shift happens\n    \"\"\"\n    trees = read_trees(OPEN_SHIFT_PROBLEM_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2)\n    assert new_transitions is None\n\n    new_transitions = get_single_repair(transitions, Shift(), fix_multiple_open_shift, 2)\n\n    # Can break the expected transitions down like this:\n    # [OpenConstituent(('ROOT',)), OpenConstituent(('S',)),\n    # all gone: OpenConstituent(('NP',)), OpenConstituent(('NP',)), OpenConstituent(('NP',)),\n    # Shift, Shift, Shift, Shift, Shift, Shift,\n    # gone: CloseConstituent,\n    # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), Shift, CloseConstituent, CloseConstituent,\n    # gone: CloseConstituent,\n    # Shift, OpenConstituent(('CONJP',)), Shift, Shift, Shift, CloseConstituent, OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, Shift,\n    # gone: CloseConstituent,\n    # and then the rest:\n    # OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)),\n    # Shift, Shift, Shift, Shift, CloseConstituent,\n    # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)),\n    # OpenConstituent(('NP',)), Shift, Shift, Shift, Shift,\n    # CloseConstituent, Shift, Shift, Shift, Shift, CloseConstituent,\n    # CloseConstituent, OpenConstituent(('SBAR',)), Shift,\n    # OpenConstituent(('S',)), OpenConstituent(('NP',)),\n    # OpenConstituent(('NP',)), Shift, Shift, CloseConstituent,\n    # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)),\n    # Shift, Shift, CloseConstituent, CloseConstituent,\n    # CloseConstituent, OpenConstituent(('VP',)), Shift,\n    # 
OpenConstituent(('VP',)), Shift, CloseConstituent,\n    # CloseConstituent, CloseConstituent, CloseConstituent,\n    # CloseConstituent, Shift, CloseConstituent, CloseConstituent]\n    expected_transitions = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), Shift(), Shift(), Shift(), Shift(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), Shift(), OpenConstituent('CONJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('SBAR'), Shift(), OpenConstituent('S'), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('VP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]\n\n    assert new_transitions == expected_transitions\n\ndef test_open_open_ambiguous_unary_fix():\n    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    EXPECTED_FIX = [OpenConstituent('ROOT'), 
OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    assert transitions == EXPECTED_ORIG\n    new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_unary, 2)\n    assert new_transitions == EXPECTED_FIX\n\n\ndef test_open_open_ambiguous_later_fix():\n    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    assert transitions == EXPECTED_ORIG\n    new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_later, 2)\n    assert new_transitions == EXPECTED_FIX\n\n\nCLOSE_SHIFT_EXAMPLE_TREE = \"\"\"\n( (NP (DT a)\n   (ADJP (NN stock) (HYPH -) (VBG picking))\n   (NN tool)))\n\"\"\"\n\n# not intended to be a correct tree\nCLOSE_SHIFT_DEEP_EXAMPLE_TREE = \"\"\"\n( (NP (DT a)\n   (VP (ADJP (NN stock) (HYPH -) (VBG picking)))\n   (NN tool)))\n\"\"\"\n\n# not intended to be a correct tree\nCLOSE_SHIFT_OPEN_EXAMPLE_TREE = \"\"\"\n( (NP (DT a)\n   (ADJP (NN stock) (HYPH -) (VBG picking))\n   (NP (NN tool))))\n\"\"\"\n\nCLOSE_SHIFT_AMBIGUOUS_TREE = 
\"\"\"\n( (NP (DT a)\n   (ADJP (NN stock) (HYPH -) (VBG picking))\n   (NN tool)\n   (NN foo)))\n\"\"\"\n\ndef test_fix_close_shift_ambiguous_immediate():\n    \"\"\"\n    Test the result when a close/shift error occurs and we want to close the new, incorrect constituent immediately\n    \"\"\"\n    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_later, 7)\n    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    assert transitions == expected_original\n    assert new_sequence == expected_update\n\ndef test_fix_close_shift_ambiguous_later():\n    # test that the one with two shifts, which is ambiguous, gets rejected\n    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_immediate, 7)\n    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]\n    assert transitions == expected_original\n    assert 
new_sequence == expected_update\n\ndef test_oracle_with_optional_level():\n    tree = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)[0]\n    gold_sequence = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    assert transitions == gold_sequence\n\n    oracle = TopDownOracle(ROOT_LABELS, 1, \"\", \"\")\n\n    model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY, root_labels=ROOT_LABELS)\n    state = build_state(model, tree, 7)\n    fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8],\n                                         model=model,\n                                         state=state)\n    assert fix is RepairType.OTHER_CLOSE_SHIFT\n    assert new_sequence is None\n\n    oracle = TopDownOracle(ROOT_LABELS, 1, \"CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR\", \"\")\n    fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8],\n                                         model=model,\n                                         state=state)\n    assert fix is RepairType.CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR\n    assert new_sequence == expected_update\n\n\ndef test_fix_close_shift():\n    \"\"\"\n    Test a tree of the kind we expect the close/shift to be able to get right\n    \"\"\"\n    trees = read_trees(CLOSE_SHIFT_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7)\n\n    expected_original = 
[OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]\n    expected_update   = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    assert transitions == expected_original\n    assert new_sequence == expected_update\n\n    # test that the one with two shifts, which is ambiguous, gets rejected\n    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7)\n    assert new_sequence is None\n\ndef test_fix_close_shift_deeper_tree():\n    \"\"\"\n    Test a tree of the kind we expect the close/shift to be able to get right\n    \"\"\"\n    trees = read_trees(CLOSE_SHIFT_DEEP_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    for count_opens in [True, False]:\n        new_sequence = get_single_repair(transitions, transitions[10], fix_close_shift, 8, count_opens=count_opens)\n\n        expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]\n        expected_update   = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n        assert transitions == expected_original\n        assert new_sequence == expected_update\n\ndef test_fix_close_shift_open_tree():\n    
\"\"\"\n    We would like the close/shift to get this case right as well\n    \"\"\"\n    trees = read_trees(CLOSE_SHIFT_OPEN_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift, 7, count_opens=False)\n    assert new_sequence is None\n\n    new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift_with_opens, 7)\n\n    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    expected_update   = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    assert transitions == expected_original\n    assert new_sequence == expected_update\n\nCLOSE_OPEN_EXAMPLE_TREE = \"\"\"\n( (VP (VBZ eat)\n   (NP (NN spaghetti))\n   (PP (IN with) (DT a) (NN fork))))\n\"\"\"\n\nCLOSE_OPEN_DIFFERENT_LABEL_TREE = \"\"\"\n( (VP (VBZ eat)\n   (NP (NN spaghetti))\n   (NP (DT a) (NN fork))))\n\"\"\"\n\nCLOSE_OPEN_TWO_LABELS_TREE = \"\"\"\n( (VP (VBZ eat)\n   (NP (NN spaghetti))\n   (PP (IN with) (DT a) (NN fork))\n   (PP (IN in) (DT a) (NN restaurant))))\n\"\"\"\n\ndef test_fix_close_open():\n    trees = read_trees(CLOSE_OPEN_EXAMPLE_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    assert isinstance(transitions[5], CloseConstituent)\n    assert transitions[6] == OpenConstituent(\"PP\")\n\n    new_transitions = get_single_repair(transitions, transitions[6], fix_close_open_correct_open, 5)\n\n    expected_original = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), 
OpenConstituent('NP'), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n    expected_update   = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]\n\n    assert transitions == expected_original\n    assert new_transitions == expected_update\n\ndef test_fix_close_open_invalid():\n    for TREE in (CLOSE_OPEN_DIFFERENT_LABEL_TREE, CLOSE_OPEN_TWO_LABELS_TREE):\n        trees = read_trees(TREE)\n        assert len(trees) == 1\n        tree = trees[0]\n\n        transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n\n        assert isinstance(transitions[5], CloseConstituent)\n        assert isinstance(transitions[6], OpenConstituent)\n\n        new_transitions = get_single_repair(transitions, OpenConstituent(\"PP\"), fix_close_open_correct_open, 5)\n        assert new_transitions is None\n\ndef test_fix_close_open_ambiguous_immediate():\n    \"\"\"\n    Test that a fix for an ambiguous close/open works as expected\n    \"\"\"\n    trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    assert isinstance(transitions[5], CloseConstituent)\n    assert isinstance(transitions[6], OpenConstituent)\n\n    reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN)\n    assert tree == reconstructed\n\n    new_transitions = get_single_repair(transitions, OpenConstituent(\"PP\"), fix_close_open_correct_open, 5, check_close=False)\n    reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    expected = \"\"\"\n    ( (VP (VBZ eat)\n        (NP (NN spaghetti)\n          (PP (IN with) 
(DT a) (NN fork)))\n        (PP (IN in) (DT a) (NN restaurant))))\n    \"\"\"\n    expected = read_trees(expected)[0]\n    assert reconstructed == expected\n\ndef test_fix_close_open_ambiguous_later():\n    \"\"\"\n    Test that a fix for an ambiguous close/open works as expected\n    \"\"\"\n    trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    assert isinstance(transitions[5], CloseConstituent)\n    assert isinstance(transitions[6], OpenConstituent)\n\n    reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN)\n    assert tree == reconstructed\n\n    new_transitions = get_single_repair(transitions, OpenConstituent(\"PP\"), fix_close_open_correct_open_ambiguous_later, 5, check_close=False)\n    reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    expected = \"\"\"\n    ( (VP (VBZ eat)\n        (NP (NN spaghetti)\n          (PP (IN with) (DT a) (NN fork))\n          (PP (IN in) (DT a) (NN restaurant)))))\n    \"\"\"\n    expected = read_trees(expected)[0]\n    assert reconstructed == expected\n\n\nSHIFT_CLOSE_EXAMPLES = [\n    (\"((S (NP (DT an) (NML (NNP Oct) (CD 19)) (NN review))))\", \"((S (NP (DT an) (NML (NNP Oct) (CD 19))) (NN review)))\", 8),\n    (\"((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))\",\n     \"((S (NP (` `) (NP (DT The)) (NN Misanthrope) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))\", 6),\n    (\"((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))\",\n     \"((S (NP (` `) (NP (DT The) (NN Misanthrope))) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre)))))\", 8),\n    (\"((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))\",\n     \"((S (NP (` `) (NP (DT 
The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman)) (NNP Theatre)))))\", 13),\n]\n\ndef test_shift_close():\n    for idx, (orig_tree, expected_tree, shift_position) in enumerate(SHIFT_CLOSE_EXAMPLES):\n        trees = read_trees(orig_tree)\n        assert len(trees) == 1\n        tree = trees[0]\n\n        transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n        if shift_position is None:\n            print(transitions)\n            continue\n\n        assert isinstance(transitions[shift_position], Shift)\n        new_transitions = get_single_repair(transitions, CloseConstituent(), fix_shift_close, shift_position)\n        reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN)\n        if expected_tree is None:\n            print(transitions)\n            print(new_transitions)\n\n            print(\"{:P}\".format(reconstructed))\n        else:\n            expected_tree = read_trees(expected_tree)\n            assert len(expected_tree) == 1\n            expected_tree = expected_tree[0]\n\n            assert reconstructed == expected_tree\n\ndef test_shift_open_ambiguous_unary():\n    \"\"\"\n    Test what happens if a Shift is turned into an Open in an ambiguous manner\n    \"\"\"\n    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    assert transitions == expected_original\n\n    new_sequence = get_single_repair(transitions, OpenConstituent(\"ZZ\"), fix_shift_open_ambiguous_unary, 4)\n    expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), 
CloseConstituent(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    assert new_sequence == expected_updated\n\ndef test_shift_open_ambiguous_later():\n    \"\"\"\n    Test what happens if a Shift is turned into an Open in an ambiguous manner\n    \"\"\"\n    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)\n    assert len(trees) == 1\n    tree = trees[0]\n\n    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)\n    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    assert transitions == expected_original\n\n    new_sequence = get_single_repair(transitions, OpenConstituent(\"ZZ\"), fix_shift_open_ambiguous_later, 4)\n    expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]\n    assert new_sequence == expected_updated\n"
  },
  {
    "path": "stanza/tests/constituency/test_trainer.py",
    "content": "from collections import defaultdict\nimport logging\nimport pathlib\nimport tempfile\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom torch import optim\n\nfrom stanza import Pipeline\n\nfrom stanza.models import constituency_parser\nfrom stanza.models.common import pretrain\nfrom stanza.models.common.bert_embedding import load_bert, load_tokenizer\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.common.utils import set_random_seed\nfrom stanza.models.constituency import lstm_model\nfrom stanza.models.constituency.parse_transitions import Transition\nfrom stanza.models.constituency import parser_training\nfrom stanza.models.constituency import trainer\nfrom stanza.models.constituency import tree_reader\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nlogger = logging.getLogger('stanza.constituency.trainer')\nlogger.setLevel(logging.WARNING)\n\nTREEBANK = \"\"\"\n( (S\n    (VP (VBG Enjoying)\n      (NP (PRP$  my) (JJ favorite) (NN Friday) (NN tradition)))\n    (. .)))\n\n( (NP\n    (VP (VBG Sitting)\n      (PP (IN in)\n        (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station)))\n      (VP (VBG waiting)\n        (PP (IN for)\n          (NP (PRP$  my) (JJ delayed) (NNP @MBTA) (NN train)))))\n    (. .)))\n\n( (S\n    (NP (PRP I))\n    (VP\n      (ADVP (RB really))\n      (VBP hate)\n      (NP (DT the) (NNP @MBTA)))))\n\n( (S\n    (S (VP (VB Seek)))\n    (CC and)\n    (S (NP (PRP ye))\n      (VP (MD shall)\n        (VP (VB find))))\n    (. 
.)))\n\"\"\"\n\ndef build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):\n    # TODO: build a fake embedding some other way?\n    train_trees = tree_reader.read_trees(treebank)\n    dev_trees = train_trees[-1:]\n    silver_trees = []\n\n    args = ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args)\n    args = constituency_parser.parse_args(args)\n\n    foundation_cache = FoundationCache()\n    # might be None, unless we're testing loading an existing model\n    model_load_name = args['load_name']\n\n    model, _, _, _ = parser_training.build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_name)\n    assert isinstance(model.model, lstm_model.LSTMModel)\n    return model\n\nclass TestTrainer:\n    @pytest.fixture(scope=\"class\")\n    def wordvec_pretrain_file(self):\n        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\n    @pytest.fixture(scope=\"class\")\n    def tiny_random_xlnet(self, tmp_path_factory):\n        \"\"\"\n        Download the tiny-random-xlnet model and make a concrete copy of it\n\n        The issue here is that the \"random\" nature of the original\n        makes it difficult or impossible to test that the values in\n        the transformer don't change during certain operations.\n        Saving a concrete instantiation of those random numbers makes\n        it so we can test there is no difference when training only a\n        subset of the layers, for example\n        \"\"\"\n        xlnet_name = 'hf-internal-testing/tiny-random-xlnet'\n        xlnet_model, xlnet_tokenizer = load_bert(xlnet_name)\n        path = str(tmp_path_factory.mktemp('tiny-random-xlnet'))\n        xlnet_model.save_pretrained(path)\n        xlnet_tokenizer.save_pretrained(path)\n        return path\n\n    @pytest.fixture(scope=\"class\")\n    def tiny_random_bart(self, tmp_path_factory):\n        \"\"\"\n        Download the tiny-random-bart model and make a concrete copy of it\n\n        Issue is the same as with 
tiny_random_xlnet\n        \"\"\"\n        bart_name = 'hf-internal-testing/tiny-random-bart'\n        bart_model, bart_tokenizer = load_bert(bart_name)\n        path = str(tmp_path_factory.mktemp('tiny-random-bart'))\n        bart_model.save_pretrained(path)\n        bart_tokenizer.save_pretrained(path)\n        return path\n\n    def test_initial_model(self, wordvec_pretrain_file):\n        \"\"\"\n        does nothing, just tests that the construction went okay\n        \"\"\"\n        build_trainer(wordvec_pretrain_file)\n\n\n    def test_save_load_model(self, wordvec_pretrain_file):\n        \"\"\"\n        Just tests that saving and loading works without crashing.\n\n        Currently no test of the values themselves\n        (checks some fields to make sure they are regenerated correctly)\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            tr = build_trainer(wordvec_pretrain_file)\n            transitions = tr.model.transitions\n\n            # attempt saving\n            filename = os.path.join(tmpdirname, \"parser.pt\")\n            tr.save(filename)\n\n            assert os.path.exists(filename)\n\n            # load it back in\n            tr2 = tr.load(filename)\n            trans2 = tr2.model.transitions\n            assert(transitions == trans2)\n            assert all(isinstance(x, Transition) for x in trans2)\n\n    def test_relearn_structure(self, wordvec_pretrain_file):\n        \"\"\"\n        Test that starting a trainer with --relearn_structure copies the old model\n        \"\"\"\n\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            set_random_seed(1000)\n            args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']\n            tr = build_trainer(wordvec_pretrain_file, *args)\n\n            # attempt saving\n            
filename = os.path.join(tmpdirname, \"parser.pt\")\n            tr.save(filename)\n\n            set_random_seed(1001)\n            args = ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--relearn_structure', '--load_name', filename]\n            tr2 = build_trainer(wordvec_pretrain_file, *args)\n\n            assert torch.allclose(tr.model.delta_embedding.weight, tr2.model.delta_embedding.weight)\n            assert torch.allclose(tr.model.output_layers[0].weight, tr2.model.output_layers[0].weight)\n            # the norms will be the same, as the non-zero values are all the same\n            assert torch.allclose(torch.linalg.norm(tr.model.word_lstm.weight_ih_l0), torch.linalg.norm(tr2.model.word_lstm.weight_ih_l0))\n\n    def write_treebanks(self, tmpdirname):\n        train_treebank_file = os.path.join(tmpdirname, \"train.mrg\")\n        with open(train_treebank_file, 'w', encoding='utf-8') as fout:\n            fout.write(TREEBANK)\n            fout.write(TREEBANK)\n\n        eval_treebank_file = os.path.join(tmpdirname, \"eval.mrg\")\n        with open(eval_treebank_file, 'w', encoding='utf-8') as fout:\n            fout.write(TREEBANK)\n\n        return train_treebank_file, eval_treebank_file\n\n    def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args):\n        # let's not make the model huge...\n        args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10',\n                '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname,\n                '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_start', '0', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'),\n                '--train_file', train_treebank_file, '--eval_file', eval_treebank_file,\n                '--epoch_size', '6', 
'--train_batch_size', '3',\n                '--shorthand', 'en_test']\n        args = args + list(additional_args)\n        args = constituency_parser.parse_args(args)\n        # just in case we change the defaults in the future\n        args['wandb'] = None\n        return args\n\n    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False, foundation_cache=None):\n        \"\"\"\n        Runs a test of the trainer for a few iterations.\n\n        Checks some basic properties of the saved model, but doesn't\n        check for the accuracy of the results\n        \"\"\"\n        if extra_args is None:\n            extra_args = []\n        extra_args += ['--epochs', '%d' % num_epochs]\n\n        train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)\n        if use_silver:\n            extra_args += ['--silver_file', str(eval_treebank_file)]\n        args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)\n\n        each_name = args['save_each_name']\n        if not exists_ok:\n            assert not os.path.exists(args['save_name'])\n        retag_pipeline = Pipeline(lang=\"en\", processors=\"tokenize, pos\", tokenize_pretokenized=True, dir=TEST_MODELS_DIR, foundation_cache=foundation_cache, download_method=None)\n        trained_model = parser_training.train(args, None, [retag_pipeline])\n        # check that hooks are in the model if expected\n        for p in trained_model.model.parameters():\n            if p.requires_grad:\n                if args['grad_clipping'] is not None:\n                    assert len(p._backward_hooks) == 1\n                else:\n                    assert p._backward_hooks is None\n\n        # check that the model can be loaded back\n        assert os.path.exists(args['save_name'])\n        peft_name = trained_model.model.peft_name\n        tr = trainer.Trainer.load(args['save_name'], 
load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)\n        assert tr.optimizer is not None\n        assert tr.scheduler is not None\n        assert tr.epochs_trained >= 1\n        for p in tr.model.parameters():\n            if p.requires_grad:\n                assert p._backward_hooks is None\n\n        tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)\n        assert tr.optimizer is not None\n        assert tr.scheduler is not None\n        assert tr.epochs_trained == num_epochs\n\n        for i in range(1, num_epochs+1):\n            model_name = each_name % i\n            assert os.path.exists(model_name)\n            tr = trainer.Trainer.load(model_name, load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)\n            assert tr.epochs_trained == i\n            assert tr.batches_trained == (4 * i if use_silver else 2 * i)\n\n        return args, trained_model\n\n    def test_train(self, wordvec_pretrain_file):\n        \"\"\"\n        Test the whole thing for a few iterations on the fake data\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            self.run_train_test(wordvec_pretrain_file, tmpdirname)\n\n    def test_early_dropout(self, wordvec_pretrain_file):\n        \"\"\"\n        Test the whole thing for a few iterations on the fake data\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            args = ['--early_dropout', '3']\n            _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)\n            model = model.model\n            dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]\n            assert len(dropouts) > 0, 
\"Didn't find any dropouts in the model!\"\n            for name, module in dropouts:\n                assert module.p == 0.0, \"Dropout module %s was not set to 0 with early_dropout\"\n\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            # test that when turned off, early_dropout doesn't happen\n            args = ['--early_dropout', '-1']\n            _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)\n            model = model.model\n            dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]\n            assert len(dropouts) > 0, \"Didn't find any dropouts in the model!\"\n            if all(module.p == 0.0 for _, module in dropouts):\n                raise AssertionError(\"All dropouts were 0 after training even though early_dropout was set to -1\")\n\n    def test_train_silver(self, wordvec_pretrain_file):\n        \"\"\"\n        Test the whole thing for a few iterations on the fake data\n\n        This tests that it works if you give it a silver file\n        The check for the use of the silver data is that the\n        number of batches trained should go up\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)\n\n    def test_train_checkpoint(self, wordvec_pretrain_file):\n        \"\"\"\n        Test the whole thing for a few iterations, then restart\n\n        This tests that the 5th iteration save file is not rewritten\n        and that the iterations continue to 10\n\n        TODO: could make it more robust by verifying that only 5 more\n        epochs are trained.  
Perhaps a \"most recent epochs\" could be\n        saved in the trainer\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False)\n            save_5 = args['save_each_name'] % 5\n            save_10 = args['save_each_name'] % 10\n            assert os.path.exists(save_5)\n            assert not os.path.exists(save_10)\n\n            save_5_stat = pathlib.Path(save_5).stat()\n\n            self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True)\n            assert os.path.exists(save_5)\n            assert os.path.exists(save_10)\n\n            assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime\n\n    def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):\n            train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)\n            args = ['--multistage', '--pattn_num_layers', '1']\n            if use_lattn:\n                args += ['--lattn_d_proj', '16']\n            if extra_args:\n                args += extra_args\n            args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)\n            each_name = os.path.join(args['save_dir'], 'each_%02d.pt')\n\n            word_input_sizes = defaultdict(list)\n            for i in range(1, 9):\n                model_name = each_name % i\n                assert os.path.exists(model_name)\n                tr = trainer.Trainer.load(model_name, load_optimizer=True)\n                assert tr.epochs_trained == i\n                word_input_sizes[tr.model.word_input_size].append(i)\n            if use_lattn:\n                # there should be three stages: no attn, pattn, pattn+lattn\n                assert len(word_input_sizes) == 3\n                word_input_keys = sorted(word_input_sizes.keys())\n                assert 
word_input_sizes[word_input_keys[0]] == [1, 2, 3]\n                assert word_input_sizes[word_input_keys[1]] == [4, 5]\n                assert word_input_sizes[word_input_keys[2]] == [6, 7, 8]\n            else:\n                # with no lattn, there are two stages: no attn, pattn\n                assert len(word_input_sizes) == 2\n                word_input_keys = sorted(word_input_sizes.keys())\n                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]\n                assert word_input_sizes[word_input_keys[1]] == [4, 5, 6, 7, 8]\n\n    def test_multistage_lattn(self, wordvec_pretrain_file):\n        \"\"\"\n        Test a multistage training for a few iterations on the fake data\n\n        This should start with no pattn or lattn, have pattn in the middle, then lattn at the end\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)\n\n    def test_multistage_no_lattn(self, wordvec_pretrain_file):\n        \"\"\"\n        Test a multistage training for a few iterations on the fake data\n\n        This should start with no pattn or lattn, then use pattn (but never lattn) for the remaining epochs\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)\n\n    def test_multistage_optimizer(self, wordvec_pretrain_file):\n        \"\"\"\n        Test that the correct optimizers are built for a multistage training process\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            extra_args = ['--optim', 'adamw']\n            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)\n\n            # check that the optimizers which get rebuilt when loading\n            # the models are adadelta for the first half of 
the\n            # multistage, then adamw\n            each_name = os.path.join(tmpdirname, 'each_%02d.pt')\n            for i in range(1, 3):\n                model_name = each_name % i\n                tr = trainer.Trainer.load(model_name, load_optimizer=True)\n                assert tr.epochs_trained == i\n                assert isinstance(tr.optimizer, optim.Adadelta)\n                # double check that this is actually a valid test\n                assert not isinstance(tr.optimizer, optim.AdamW)\n\n            for i in range(4, 8):\n                model_name = each_name % i\n                tr = trainer.Trainer.load(model_name, load_optimizer=True)\n                assert tr.epochs_trained == i\n                assert isinstance(tr.optimizer, optim.AdamW)\n\n\n    def test_grad_clip_hooks(self, wordvec_pretrain_file):\n        \"\"\"\n        Verify that grad clipping is not saved with the model, but is attached at training time\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            args = ['--grad_clipping', '25']\n            self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)\n\n    def test_analyze_trees(self, wordvec_pretrain_file):\n        test_str = \"(ROOT (S (NP (PRP I)) (VP (VBP wan) (S (VP (TO na) (VP (VB lick) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae))))))))  (ROOT (S (NP (DT This) (NN interface)) (VP (VBZ sucks))))\"\n\n        test_tree = tree_reader.read_trees(test_str)\n        assert len(test_tree) == 2\n\n        args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']\n        tr = build_trainer(wordvec_pretrain_file, *args)\n\n        results = tr.model.analyze_trees(test_tree)\n        assert len(results) == 2\n        assert len(results[0].predictions) == 1\n        assert results[0].predictions[0].tree == test_tree[0]\n        assert results[0].state is not None\n        assert 
isinstance(results[0].state.score, torch.Tensor)\n        assert results[0].state.score.shape == torch.Size([])\n        assert len(results[0].constituents) == 9\n        assert results[0].constituents[-1].value == test_tree[0]\n        # the way the results are built, the next-to-last entry\n        # should be the thing just below the root\n        assert results[0].constituents[-2].value == test_tree[0].children[0]\n\n        assert len(results[1].predictions) == 1\n        assert results[1].predictions[0].tree == test_tree[1]\n        assert results[1].state is not None\n        assert isinstance(results[1].state.score, torch.Tensor)\n        assert results[1].state.score.shape == torch.Size([])\n        assert len(results[1].constituents) == 4\n        assert results[1].constituents[-1].value == test_tree[1]\n        assert results[1].constituents[-2].value == test_tree[1].children[0]\n\n    def bert_weights_allclose(self, bert_model, parser_model):\n        \"\"\"\n        Return True if all bert weights are close, False otherwise\n        \"\"\"\n        for name, parameter in bert_model.named_parameters():\n            other_name = \"bert_model.\" + name\n            other_parameter = parser_model.model.get_parameter(other_name)\n            if not torch.allclose(parameter.cpu(), other_parameter.cpu()):\n                return False\n        return True\n\n    def frozen_transformer_test(self, wordvec_pretrain_file, transformer_name):\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            foundation_cache = FoundationCache()\n            args = ['--bert_model', transformer_name]\n            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args, foundation_cache=foundation_cache)\n            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)\n            assert self.bert_weights_allclose(bert_model, trained_model)\n\n            checkpoint = 
torch.load(args['save_name'], lambda storage, loc: storage, weights_only=True)\n            params = checkpoint['params']\n            # check that the bert model wasn't saved in the model\n            assert all(not x.startswith(\"bert_model.\") for x in params['model'].keys())\n            # make sure we're looking at the right thing\n            assert any(x.startswith(\"output_layers.\") for x in params['model'].keys())\n\n            # check that the cached model is used as expected when loading a bert model\n            trained_model = trainer.Trainer.load(args['save_name'], foundation_cache=foundation_cache)\n            assert trained_model.model.bert_model is bert_model\n\n    def test_bert_frozen(self, wordvec_pretrain_file):\n        \"\"\"\n        Check that the parameters of the bert model don't change when training a basic model\n        \"\"\"\n        self.frozen_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')\n\n    def test_xlnet_frozen(self, wordvec_pretrain_file, tiny_random_xlnet):\n        \"\"\"\n        Check that the parameters of an xlnet model don't change when training a basic model\n        \"\"\"\n        self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)\n\n    def test_bart_frozen(self, wordvec_pretrain_file, tiny_random_bart):\n        \"\"\"\n        Check that the parameters of a bart model don't change when training a basic model\n        \"\"\"\n        self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_bart)\n\n    def test_bert_finetune_one_epoch(self, wordvec_pretrain_file):\n        \"\"\"\n        Check that the parameters of the bert model DO change over a single training step\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            transformer_name = 'hf-internal-testing/tiny-bert'\n            args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adadelta']\n            args, trained_model = 
self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=1, extra_args=args)\n\n            # check that the weights are different\n            foundation_cache = FoundationCache()\n            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)\n            assert not self.bert_weights_allclose(bert_model, trained_model)\n\n            # double check that a new bert is created instead of using the FoundationCache when the bert has been trained\n            model_name = args['save_name']\n            assert os.path.exists(model_name)\n            no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, \"--no_bert_finetune\", \"--no_stage1_bert_finetune\", '--bert_model', transformer_name)\n            tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)\n            assert tr.model.bert_model is not bert_model\n            assert not self.bert_weights_allclose(bert_model, tr)\n            assert self.bert_weights_allclose(trained_model.model.bert_model, tr)\n\n            new_save_name = os.path.join(tmpdirname, \"test_resave_bert.pt\")\n            assert not os.path.exists(new_save_name)\n            tr.save(new_save_name, save_optimizer=False)\n            tr2 = trainer.Trainer.load(new_save_name, args=no_finetune_args, foundation_cache=foundation_cache)\n            # check that the resaved model included its finetuned bert weights\n            assert tr2.model.bert_model is not bert_model\n            # the finetuned bert weights should also be scheduled for saving the next time as well\n            assert not tr2.model.is_unsaved_module(\"bert_model\")\n\n    def finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):\n        \"\"\"\n        Check that the parameters of the transformer DO change when using bert_finetune\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            args = 
['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw']\n            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)\n\n            # check that the weights are different\n            foundation_cache = FoundationCache()\n            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)\n            assert not self.bert_weights_allclose(bert_model, trained_model)\n\n            # double check that a new bert is created instead of using the FoundationCache when the bert has been trained\n            no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, \"--no_bert_finetune\", \"--no_stage1_bert_finetune\", '--bert_model', transformer_name)\n            trained_model = trainer.Trainer.load(args['save_name'], args=no_finetune_args, foundation_cache=foundation_cache)\n            assert not trained_model.model.args['bert_finetune']\n            assert not trained_model.model.args['stage1_bert_finetune']\n            assert trained_model.model.bert_model is not bert_model\n\n    def test_bert_finetune(self, wordvec_pretrain_file):\n        \"\"\"\n        Check that the parameters of a bert model DO change when using bert_finetune\n        \"\"\"\n        self.finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')\n\n    def test_xlnet_finetune(self, wordvec_pretrain_file, tiny_random_xlnet):\n        \"\"\"\n        Check that the parameters of an xlnet model DO change when using bert_finetune\n        \"\"\"\n        self.finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)\n\n    def test_stage1_bert_finetune(self, wordvec_pretrain_file):\n        \"\"\"\n        Check that the parameters of the bert model DO change when using stage1_bert_finetune, but only for the first couple of steps\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            bert_model_name = 
'hf-internal-testing/tiny-bert'\n            args = ['--bert_model', bert_model_name, '--stage1_bert_finetune', '--optim', 'adamw']\n            # need to use num_epochs==6 so that epochs 1 and 2 are saved to be different\n            # a test of 5 or less means that sometimes it will reload the params\n            # at step 2 to get ready for the following iterations with adamw\n            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)\n\n            # check that the weights are different\n            foundation_cache = FoundationCache()\n            bert_model, bert_tokenizer = foundation_cache.load_bert(bert_model_name)\n            assert not self.bert_weights_allclose(bert_model, trained_model)\n\n            # double check that a new bert is created instead of using the FoundationCache when the bert has been trained\n            no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, \"--no_bert_finetune\", \"--no_stage1_bert_finetune\", '--bert_model', bert_model_name, '--optim', 'adamw')\n            num_epochs = trained_model.model.args['epochs']\n            each_name = os.path.join(tmpdirname, 'each_%02d.pt')\n            for i in range(1, num_epochs+1):\n                model_name = each_name % i\n                assert os.path.exists(model_name)\n                tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)\n                assert tr.model.bert_model is not bert_model\n                assert not self.bert_weights_allclose(bert_model, tr)\n                if i >= num_epochs // 2:\n                    assert self.bert_weights_allclose(trained_model.model.bert_model, tr)\n\n            # verify that models 1 and 2 are saved to be different\n            model_name_1 = each_name % 1\n            model_name_2 = each_name % 2\n            tr_1 = trainer.Trainer.load(model_name_1, args=no_finetune_args, 
foundation_cache=foundation_cache)\n            tr_2 = trainer.Trainer.load(model_name_2, args=no_finetune_args, foundation_cache=foundation_cache)\n            assert not self.bert_weights_allclose(tr_1.model.bert_model, tr_2)\n\n\n    def one_layer_finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):\n        \"\"\"\n        Check that only the parameters of the last transformer layer change when using bert_finetune with bert_finetune_layers set to 1\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            args = ['--bert_model', transformer_name, '--bert_finetune', '--bert_finetune_layers', '1', '--optim', 'adamw']\n            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)\n\n            # check that the weights of the last layer are different,\n            # but the weights of the earlier layers and\n            # non-transformer-layers are the same\n            foundation_cache = FoundationCache()\n            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)\n            assert bert_model.config.num_hidden_layers > 1\n            layer_name = \"layer.%d.\" % (bert_model.config.num_hidden_layers - 1)\n            for name, parameter in bert_model.named_parameters():\n                other_name = \"bert_model.\" + name\n                other_parameter = trained_model.model.get_parameter(other_name)\n                if layer_name in name:\n                    if 'rel_attn.seg_embed' in name or 'rel_attn.r_s_bias' in name:\n                        # not sure why this happens for xlnet, just roll with it\n                        continue\n                    assert not torch.allclose(parameter.cpu(), other_parameter.cpu())\n                else:\n                    assert torch.allclose(parameter.cpu(), other_parameter.cpu())\n\n    def test_bert_finetune_one_layer(self, wordvec_pretrain_file):\n        
self.one_layer_finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')\n\n    def test_xlnet_finetune_one_layer(self, wordvec_pretrain_file, tiny_random_xlnet):\n        self.one_layer_finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)\n\n    def test_peft_finetune(self, tmp_path, wordvec_pretrain_file):\n        transformer_name = 'hf-internal-testing/tiny-bert'\n        args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw', '--use_peft']\n        args, trained_model = self.run_train_test(wordvec_pretrain_file, str(tmp_path), extra_args=args)\n\n    def test_peft_twostage_finetune(self, wordvec_pretrain_file):\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            num_epochs = 6\n            transformer_name = 'hf-internal-testing/tiny-bert'\n            args = ['--bert_model', transformer_name, '--stage1_bert_finetune', '--optim', 'adamw', '--use_peft']\n            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=num_epochs, extra_args=args)\n            for epoch in range(num_epochs):\n                filename_prev = args['save_each_name'] % epoch\n                filename_next = args['save_each_name'] % (epoch+1)\n                trainer_prev = trainer.Trainer.load(filename_prev, args=args, load_optimizer=False)\n                trainer_next = trainer.Trainer.load(filename_next, args=args, load_optimizer=False)\n\n                lora_names = [name for name, _ in trainer_prev.model.bert_model.named_parameters() if name.find(\"lora\") >= 0]\n                if epoch < 2:\n                    assert not any(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),\n                                                  trainer_next.model.bert_model.get_parameter(name).cpu())\n                                   for name in lora_names)\n                elif epoch > 2:\n                    assert 
all(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),\n                                              trainer_next.model.bert_model.get_parameter(name).cpu())\n                               for name in lora_names)\n"
  },
  {
    "path": "stanza/tests/constituency/test_transformer_tree_stack.py",
    "content": "import pytest\n\nimport torch\n\nfrom stanza.models.constituency.transformer_tree_stack import TransformerTreeStack\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_initial_state():\n    \"\"\"\n    Test that the initial state has the expected shapes\n    \"\"\"\n    ts = TransformerTreeStack(3, 5, 0.0)\n    initial = ts.initial_state()\n    assert len(initial) == 1\n    assert initial.value.output.shape == torch.Size([5])\n    assert initial.value.key_stack.shape == torch.Size([1, 5])\n    assert initial.value.value_stack.shape == torch.Size([1, 5])\n\ndef test_output():\n    \"\"\"\n    Test that you can get an expected output shape from the TTS\n    \"\"\"\n    ts = TransformerTreeStack(3, 5, 0.0)\n    initial = ts.initial_state()\n    out = ts.output(initial)\n    assert out.shape == torch.Size([5])\n    assert torch.allclose(initial.value.output, out)\n\ndef test_push_state_single():\n    \"\"\"\n    Test that stacks are being updated correctly when using a single stack\n\n    Values of the attention are not verified, though\n    \"\"\"\n    ts = TransformerTreeStack(3, 5, 0.0)\n    initial = ts.initial_state()\n    rand_input = torch.randn(1, 3)\n    stacks = ts.push_states([initial], [\"A\"], rand_input)\n    stacks = ts.push_states(stacks, [\"B\"], rand_input)\n    assert len(stacks) == 1\n    assert len(stacks[0]) == 3\n    assert stacks[0].value.value == \"B\"\n    assert stacks[0].pop().value.value == \"A\"\n    assert stacks[0].pop().pop().value.value is None\n\ndef test_push_state_same_length():\n    \"\"\"\n    Test that stacks are being updated correctly when using 3 stacks of the same length\n\n    Values of the attention are not verified, though\n    \"\"\"\n    ts = TransformerTreeStack(3, 5, 0.0)\n    initial = ts.initial_state()\n    rand_input = torch.randn(3, 3)\n    stacks = ts.push_states([initial, initial, initial], [\"A\", \"A\", \"A\"], rand_input)\n    stacks = ts.push_states(stacks, [\"B\", \"B\", 
\"B\"], rand_input)\n    stacks = ts.push_states(stacks, [\"C\", \"C\", \"C\"], rand_input)\n    assert len(stacks) == 3\n    for s in stacks:\n        assert len(s) == 4\n        assert s.value.key_stack.shape == torch.Size([4, 5])\n        assert s.value.value_stack.shape == torch.Size([4, 5])\n        assert s.value.value == \"C\"\n        assert s.pop().value.value == \"B\"\n        assert s.pop().pop().value.value == \"A\"\n        assert s.pop().pop().pop().value.value is None\n\ndef test_push_state_different_length():\n    \"\"\"\n    Test what happens if stacks of different lengths are passed in\n    \"\"\"\n    ts = TransformerTreeStack(3, 5, 0.0)\n    initial = ts.initial_state()\n    rand_input = torch.randn(2, 3)\n    one_step = ts.push_states([initial], [\"A\"], rand_input[0:1, :])[0]\n    stacks = [one_step, initial]\n    stacks = ts.push_states(stacks, [\"B\", \"C\"], rand_input)\n    assert len(stacks) == 2\n    assert len(stacks[0]) == 3\n    assert len(stacks[1]) == 2\n    assert stacks[0].pop().value.value == 'A'\n    assert stacks[0].value.value == 'B'\n    assert stacks[1].value.value == 'C'\n\n    assert stacks[0].value.key_stack.shape == torch.Size([3, 5])\n    assert stacks[1].value.key_stack.shape == torch.Size([2, 5])\n\ndef test_mask():\n    \"\"\"\n    Test that a mask prevents the softmax from picking up unwanted values\n    \"\"\"\n    ts = TransformerTreeStack(3, 5, 0.0)\n\n    random_v = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]])\n    double_v = random_v * 2\n    value = torch.cat([random_v, double_v], axis=1)\n    random_k = torch.randn(1, 1, 5)\n    key = torch.cat([random_k, random_k], axis=1)\n    query = torch.randn(1, 5)\n\n    output = ts.attention(key, query, value)\n    # when the two keys are equal, we expect the attention to be 50/50\n    expected_output = (random_v + double_v) / 2\n    assert torch.allclose(output, expected_output)\n\n    # If the first entry is masked out, the second one should be the\n    # only one 
represented\n    mask = torch.zeros(1, 2, dtype=torch.bool)\n    mask[0][0] = True\n    output = ts.attention(key, query, value, mask)\n    assert torch.allclose(output, double_v)\n\n    # If the second entry is masked out, the first one should be the\n    # only one represented\n    mask = torch.zeros(1, 2, dtype=torch.bool)\n    mask[0][1] = True\n    output = ts.attention(key, query, value, mask)\n    assert torch.allclose(output, random_v)\n\ndef test_position():\n    \"\"\"\n    Test that nothing goes horribly wrong when position encodings are used\n\n    Does not actually test the results of the encodings\n    \"\"\"\n    ts = TransformerTreeStack(4, 5, 0.0, use_position=True)\n    initial = ts.initial_state()\n    assert len(initial) == 1\n    assert initial.value.output.shape == torch.Size([5])\n    assert initial.value.key_stack.shape == torch.Size([1, 5])\n    assert initial.value.value_stack.shape == torch.Size([1, 5])\n\n    rand_input = torch.randn(2, 4)\n    one_step = ts.push_states([initial], [\"A\"], rand_input[0:1, :])[0]\n    stacks = [one_step, initial]\n    stacks = ts.push_states(stacks, [\"B\", \"C\"], rand_input)\n\ndef test_length_limit():\n    \"\"\"\n    Test that the length limit drops nodes as the length limit is exceeded\n    \"\"\"\n    ts = TransformerTreeStack(4, 5, 0.0, length_limit = 2)\n    initial = ts.initial_state()\n    assert len(initial) == 1\n    assert initial.value.output.shape == torch.Size([5])\n    assert initial.value.key_stack.shape == torch.Size([1, 5])\n    assert initial.value.value_stack.shape == torch.Size([1, 5])\n\n    data = torch.tensor([[0.1, 0.2, 0.3, 0.4]])\n    stacks = ts.push_states([initial], [\"A\"], data)\n\n    stacks = ts.push_states(stacks, [\"B\"], data)\n    assert len(stacks) == 1\n    assert len(stacks[0]) == 3\n    assert stacks[0].value.key_stack.shape[0] == 3\n    assert stacks[0].value.value_stack.shape[0] == 3\n\n    stacks = ts.push_states(stacks, [\"C\"], data)\n    assert len(stacks) 
== 1\n    assert len(stacks[0]) == 4\n    assert stacks[0].value.key_stack.shape[0] == 3\n    assert stacks[0].value.value_stack.shape[0] == 3\n\n    stacks = ts.push_states(stacks, [\"D\"], data)\n    assert len(stacks) == 1\n    assert len(stacks[0]) == 5\n    assert stacks[0].value.key_stack.shape[0] == 3\n    assert stacks[0].value.value_stack.shape[0] == 3\n\ndef test_two_heads():\n    \"\"\"\n    Test that stacks are being updated correctly when using two attention heads\n\n    Values of the attention are not verified, though\n    \"\"\"\n    ts = TransformerTreeStack(4, 6, 0.0, num_heads = 2)\n    initial = ts.initial_state()\n    assert len(initial) == 1\n    assert initial.value.output.shape == torch.Size([6])\n    assert initial.value.key_stack.shape == torch.Size([1, 6])\n    assert initial.value.value_stack.shape == torch.Size([1, 6])\n\n    rand_input = torch.randn(2, 4)\n    one_step = ts.push_states([initial], [\"A\"], rand_input[0:1, :])[0]\n    stacks = [one_step, initial]\n    stacks = ts.push_states(stacks, [\"B\", \"C\"], rand_input)\n    assert len(stacks) == 2\n    assert len(stacks[0]) == 3\n    assert len(stacks[1]) == 2\n    assert stacks[0].pop().value.value == 'A'\n    assert stacks[0].value.value == 'B'\n    assert stacks[1].value.value == 'C'\n\n    assert stacks[0].value.key_stack.shape == torch.Size([3, 6])\n    assert stacks[1].value.key_stack.shape == torch.Size([2, 6])\n\n"
  },
  {
    "path": "stanza/tests/constituency/test_transition_sequence.py",
    "content": "import pytest\nfrom stanza.models.constituency import parse_transitions\nfrom stanza.models.constituency import transition_sequence\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT\nfrom stanza.models.constituency.parse_transitions import *\n\nfrom stanza.tests import *\nfrom stanza.tests.constituency.test_parse_tree import CHINESE_LONG_LIST_TREE\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef reconstruct_tree(tree, sequence, transition_scheme=TransitionScheme.IN_ORDER, unary_limit=UNARY_LIMIT, reverse=False):\n    \"\"\"\n    Starting from a tree and a list of transitions, build the tree caused by the transitions\n    \"\"\"\n    model = SimpleModel(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse)\n    states = model.initial_state_from_gold_trees([tree])\n    assert(len(states)) == 1\n    assert states[0].num_transitions == 0\n\n    # TODO: could fold this into parse_sentences (similar to verify_transitions in trainer.py)\n    for idx, t in enumerate(sequence):\n        assert t.is_legal(states[0], model), \"Transition {} not legal at step {} in sequence {}\".format(t, idx, sequence)\n        states = model.bulk_apply(states, [t])\n\n    result_tree = states[0].constituents.value\n    if reverse:\n        result_tree = result_tree.reverse()\n    return result_tree\n\ndef check_reproduce_tree(transition_scheme):\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))\"\n    trees = tree_reader.read_trees(text)\n\n    model = SimpleModel(transition_scheme)\n    transitions = transition_sequence.build_sequence(trees[0], transition_scheme)\n    states = model.initial_state_from_gold_trees(trees)\n    assert(len(states)) == 1\n    state = states[0]\n    assert state.num_transitions == 0\n\n    for t in transitions:\n        assert t.is_legal(state, model)\n        state = t.apply(state, model)\n\n    # one item for the final tree\n    # one item for the sentinel at the end\n    assert len(state.constituents) == 2\n    # the transition sequence should put all of the words\n    # from the buffer onto the tree\n    # one spot left for the sentinel value\n    assert len(state.word_queue) == 8\n    assert state.sentence_length == 6\n    assert state.word_position == state.sentence_length\n    assert len(state.transitions) == len(transitions) + 1\n\n    result_tree = state.constituents.value\n    assert result_tree == trees[0]\n\ndef test_top_down_unary():\n    check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)\n\ndef test_top_down_no_unary():\n    check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN)\n\ndef test_in_order():\n    check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER)\n\ndef test_in_order_compound():\n    check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)\n\ndef test_in_order_unary():\n    check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_UNARY)\n\ndef test_all_transitions():\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))\"\n    trees = tree_reader.read_trees(text)\n    model = SimpleModel()\n    transitions = transition_sequence.build_treebank(trees)\n\n    expected = [Shift(), CloseConstituent(), CompoundUnary(\"ROOT\"), CompoundUnary(\"SQ\"), CompoundUnary(\"WHNP\"), OpenConstituent(\"NP\"), OpenConstituent(\"PP\"), OpenConstituent(\"SBARQ\"), OpenConstituent(\"VP\")]\n    assert transition_sequence.all_transitions(transitions) == expected\n\n\ndef test_all_transitions_no_unary():\n    text=\"((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n    trees = tree_reader.read_trees(text)\n    model = SimpleModel()\n    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)\n\n    expected = [Shift(), CloseConstituent(), OpenConstituent(\"NP\"), OpenConstituent(\"PP\"), OpenConstituent(\"ROOT\"), OpenConstituent(\"SBARQ\"), OpenConstituent(\"SQ\"), OpenConstituent(\"VP\"), OpenConstituent(\"WHNP\")]\n    assert transition_sequence.all_transitions(transitions) == expected\n\ndef test_top_down_compound_unary():\n    text = \"(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. 
.)))\"\n\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n\n    model = SimpleModel()\n    transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND)\n\n    states = model.initial_state_from_gold_trees(trees)\n    assert len(states) == 1\n    state = states[0]\n\n    for t in transitions:\n        assert t.is_legal(state, model)\n        state = t.apply(state, model)\n\n    result = model.get_top_constituent(state.constituents)\n    assert trees[0] == result\n\n\ndef test_chinese_tree():\n    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)\n\n    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)\n    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)\n    assert redone == trees[0]\n\n    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER)\n    with pytest.raises(AssertionError):\n        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER)\n\n    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)\n    assert redone == trees[0]\n\n\ndef test_chinese_tree_reversed():\n    \"\"\"\n    test that the reversed transitions also work\n    \"\"\"\n    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)\n\n    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)\n    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)\n    assert redone == trees[0]\n\n    with pytest.raises(AssertionError):\n        # turn off reverse - it should fail to rebuild the tree\n        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)\n        assert redone == trees[0]\n\n    transitions = 
transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER, reverse=True)\n    with pytest.raises(AssertionError):\n        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, reverse=True)\n\n    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6, reverse=True)\n    assert redone == trees[0]\n\n    with pytest.raises(AssertionError):\n        # turn off reverse - it should fail to rebuild the tree\n        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)\n        assert redone == trees[0]\n"
  },
  {
    "path": "stanza/tests/constituency/test_tree_reader.py",
    "content": "import pytest\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.tree_reader import MixedTreeError, UnclosedTreeError, UnlabeledTreeError\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_simple():\n    \"\"\"\n    Tests reading two simple trees from the same text\n    \"\"\"\n    text = \"(VB Unban) (NNP Opal)\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 2\n    assert trees[0].is_preterminal()\n    assert trees[0].label == 'VB'\n    assert trees[0].children[0].label == 'Unban'\n    assert trees[1].is_preterminal()\n    assert trees[1].label == 'NNP'\n    assert trees[1].children[0].label == 'Opal'\n\ndef test_newlines():\n    \"\"\"\n    The same test should work if there are newlines\n    \"\"\"\n    text = \"(VB Unban)\\n\\n(NNP Opal)\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 2\n\ndef test_parens():\n    \"\"\"\n    Parens should be escaped in the tree files and escaped when written\n    \"\"\"\n    text = \"(-LRB- -LRB-) (-RRB- -RRB-)\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 2\n\n    assert trees[0].label == '-LRB-'\n    assert trees[0].children[0].label == '('\n    assert \"{}\".format(trees[0]) == '(-LRB- -LRB-)'\n\n    assert trees[1].label == '-RRB-'\n    assert trees[1].children[0].label == ')'\n    assert \"{}\".format(trees[1]) == '(-RRB- -RRB-)'\n\ndef test_complicated():\n    \"\"\"\n    A more complicated tree that should successfully read\n    \"\"\"\n    text=\"( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    tree = trees[0]\n    assert not tree.is_leaf()\n    assert not tree.is_preterminal()\n    assert tree.label == 'ROOT'\n    assert len(tree.children) == 1\n    assert tree.children[0].label == 'SBARQ'\n    assert len(tree.children[0].children) == 3\n    assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.']\n    # etc etc\n\ndef test_one_word():\n    \"\"\"\n    Check that one node trees are correctly read\n\n    probably not super relevant for the parsing use case\n    \"\"\"\n    text=\"(FOO) (BAR)\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 2\n\n    assert trees[0].is_leaf()\n    assert trees[0].label == 'FOO'\n\n    assert trees[1].is_leaf()\n    assert trees[1].label == 'BAR'\n\ndef test_missing_close_parens():\n    \"\"\"\n    Test the unclosed error condition\n    \"\"\"\n    text = \"(Foo) \\n (Bar \\n zzz\"\n    try:\n        trees = tree_reader.read_trees(text)\n        raise AssertionError(\"Expected an exception\")\n    except UnclosedTreeError as e:\n        assert e.line_num == 1\n\ndef test_mixed_tree():\n    \"\"\"\n    Test the mixed error condition\n    \"\"\"\n    text = \"(Foo) \\n (Bar) \\n (Unban (Mox) Opal)\"\n    try:\n        trees = tree_reader.read_trees(text)\n        raise AssertionError(\"Expected an exception\")\n    except MixedTreeError as e:\n        assert e.line_num == 2\n\n    trees = tree_reader.read_trees(text, broken_ok=True)\n    assert len(trees) == 3\n\ndef test_unlabeled_tree():\n    \"\"\"\n    Test the unlabeled error condition\n    \"\"\"\n    text = \"(ROOT ((Foo) (Bar)))\"\n    try:\n        trees = tree_reader.read_trees(text)\n        raise AssertionError(\"Expected an exception\")\n    except UnlabeledTreeError as e:\n        assert e.line_num == 0\n\n    trees = tree_reader.read_trees(text, broken_ok=True)\n    assert len(trees) == 1\n\n    \n"
  },
  {
    "path": "stanza/tests/constituency/test_tree_stack.py",
    "content": "import pytest\n\nfrom stanza.models.constituency.tree_stack import TreeStack\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_simple():\n    stack = TreeStack(value=5, parent=None, length=1)\n    stack = stack.push(3)\n    stack = stack.push(1)\n\n    expected_values = [1, 3, 5]\n    for value in expected_values:\n        assert stack.value == value\n        stack = stack.pop()\n    assert stack is None\n\ndef test_iter():\n    stack = TreeStack(value=5, parent=None, length=1)\n    stack = stack.push(3)\n    stack = stack.push(1)\n\n    stack_list = list(stack)\n    assert list(stack) == [1, 3, 5]\n\ndef test_str():\n    stack = TreeStack(value=5, parent=None, length=1)\n    stack = stack.push(3)\n    stack = stack.push(1)\n\n    assert str(stack) == \"TreeStack(1, 3, 5)\"\n\ndef test_len():\n    stack = TreeStack(value=5, parent=None, length=1)\n    assert len(stack) == 1\n\n    stack = stack.push(3)\n    stack = stack.push(1)\n    assert len(stack) == 3\n\ndef test_long_len():\n    \"\"\"\n    Original stack had a bug where this took exponential time...\n    \"\"\"\n    stack = TreeStack(value=0, parent=None, length=1)\n    for i in range(1, 40):\n        stack = stack.push(i)\n    assert len(stack) == 40\n"
  },
  {
    "path": "stanza/tests/constituency/test_utils.py",
    "content": "import pytest\n\nfrom stanza import Pipeline\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency import utils\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n\n@pytest.fixture(scope=\"module\")\ndef pipeline():\n    return Pipeline(dir=TEST_MODELS_DIR, lang=\"en\", processors=\"tokenize, pos\", tokenize_pretokenized=True)\n\n\n\ndef test_xpos_retag(pipeline):\n    \"\"\"\n    Test using the English tagger that trees will be correctly retagged by read_trees using xpos\n    \"\"\"\n    text = \"((S (VP (X Find)) (NP (X Mox) (X Opal))))   ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))\"\n    expected = \"((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))\"\n\n    trees = tree_reader.read_trees(text)\n\n    new_trees = utils.retag_trees(trees, [pipeline], xpos=True)\n    assert new_trees == tree_reader.read_trees(expected)\n\n\n\ndef test_upos_retag(pipeline):\n    \"\"\"\n    Test using the English tagger that trees will be correctly retagged by read_trees using upos\n    \"\"\"\n    text = \"((S (VP (X Find)) (NP (X Mox) (X Opal))))   ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))\"\n    expected = \"((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))\"\n\n    trees = tree_reader.read_trees(text)\n\n    new_trees = utils.retag_trees(trees, [pipeline], xpos=False)\n    assert new_trees == tree_reader.read_trees(expected)\n\n\ndef test_replace_tags():\n    \"\"\"\n    Test the underlying replace_tags method\n\n    Also tests that the method throws exceptions when it is supposed to\n    \"\"\"\n    text = \"((S (VP (X Find)) (NP (X Mox) (X Opal))))\"\n    expected = \"((S (VP (A Find)) (NP (B Mox) (C Opal))))\"\n\n    trees = tree_reader.read_trees(text)\n\n    new_tags = [\"A\", \"B\", 
\"C\"]\n    new_tree = trees[0].replace_tags(new_tags)\n\n    assert new_tree == tree_reader.read_trees(expected)[0]\n\n    with pytest.raises(ValueError):\n        new_tags = [\"A\", \"B\"]\n        new_tree = trees[0].replace_tags(new_tags)\n\n    with pytest.raises(ValueError):\n        new_tags = [\"A\", \"B\", \"C\", \"D\"]\n        new_tree = trees[0].replace_tags(new_tags)\n\n"
  },
  {
    "path": "stanza/tests/constituency/test_vietnamese.py",
    "content": "\"\"\"\nA few tests for Vietnamese parsing, which has some difficulties related to spaces in words\n\nTechnically some other languages can have this, too, like that one French token\n\"\"\"\n\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.models.common import pretrain\nfrom stanza.models.constituency import tree_reader\n\nfrom stanza.tests.constituency.test_trainer import build_trainer\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nVI_TREEBANK               = '(ROOT (S-TTL (NP (\" \") (N-H Đảo) (Np Đài Loan) (\" \") (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'\n\nVI_TREEBANK_UNDERSCORE    = '(ROOT (S-TTL (NP (\" \") (N-H Đảo) (Np Đài_Loan) (\" \") (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .)))'\n\nVI_TREEBANK_SIMPLE        = '(ROOT (S (NP (\" \") (N Đảo) (Np Đài Loan) (\" \") (PP (E ở) (NP (N đồng bằng) (NP (N sông) (Np Cửu Long))))) (. .)))'\n\nVI_TREEBANK_PAREN         = '(ROOT (S-TTL (NP (PUNCT -LRB-) (N-H Đảo) (Np Đài Loan) (PUNCT -RRB-) (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'\nVI_TREEBANK_VLSP          = '<s>\\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\\n</s>'\nVI_TREEBANK_VLSP_50       = '<s id=50>\\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\\n</s>'\nVI_TREEBANK_VLSP_100      = '<s id=100>\\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\\n</s>'\n\nEXPECTED_LABELED_BRACKETS = '(_ROOT (_S (_NP (_\" \" )_\" (_N Đảo )_N (_Np Đài_Loan )_Np (_\" \" )_\" (_PP (_E ở )_E (_NP (_N đồng_bằng )_N (_NP (_N sông )_N (_Np Cửu_Long )_Np )_NP )_NP )_PP )_NP (_. . )_. 
)_S )_ROOT'\n\n\ndef test_read_vi_tree():\n    \"\"\"\n    Test that an individual tree with spaces in the leaves is being processed as we expect\n    \"\"\"\n    text = VI_TREEBANK.split(\"\\n\")[0]\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    assert str(trees[0]) == text\n    # this is the first NP\n    #   the third node of that NP, eg (Np Đài Loan)\n    node = trees[0].children[0].children[0].children[2]\n    assert node.is_preterminal()\n    assert node.children[0].label == \"Đài Loan\"\n\nVI_EMBEDDING = \"\"\"\n4 4\nĐảo          0.11 0.21 0.31 0.41\nĐài Loan     0.12 0.22 0.32 0.42\nđồng bằng    0.13 0.23 0.33 0.43\nsông         0.14 0.24 0.34 0.44\n\"\"\".strip()\n\ndef test_vi_embedding():\n    \"\"\"\n    Test that a VI embedding's words are correctly found when processing trees\n    \"\"\"\n    text = VI_TREEBANK.split(\"\\n\")[0]\n    trees = tree_reader.read_trees(text)\n    words = set(trees[0].leaf_labels())\n\n    with tempfile.TemporaryDirectory() as tempdir:\n        emb_filename = os.path.join(tempdir, \"emb.txt\")\n        pt_filename = os.path.join(tempdir, \"emb.pt\")\n        with open(emb_filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(VI_EMBEDDING)\n        pt = pretrain.Pretrain(filename=pt_filename, vec_filename=emb_filename, save_to_file=True)\n        pt.load()\n\n        trainer = build_trainer(pt_filename)\n        model = trainer.model\n\n    assert model.num_words_known(words) == 4\n\n\ndef test_space_formatting():\n    \"\"\"\n    By default, spaces are left as spaces, but there is a format option to change spaces\n    \"\"\"\n    text = VI_TREEBANK.split(\"\\n\")[0]\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    assert str(trees[0]) == text\n\n    assert \"{}\".format(trees[0]) == VI_TREEBANK\n    assert \"{:_O}\".format(trees[0]) == VI_TREEBANK_UNDERSCORE\n\ndef test_vlsp_formatting():\n    text = VI_TREEBANK_PAREN.split(\"\\n\")[0]\n    trees = 
tree_reader.read_trees(text)\n    assert len(trees) == 1\n    assert str(trees[0]) == text\n\n    assert \"{:_V}\".format(trees[0]) == VI_TREEBANK_VLSP\n    trees[0].tree_id = 50\n    assert \"{:_Vi}\".format(trees[0]) == VI_TREEBANK_VLSP_50\n    trees[0].tree_id = 100\n    assert \"{:_Vi}\".format(trees[0]) == VI_TREEBANK_VLSP_100\n\n    empty = tree_reader.read_trees(\"(ROOT)\")[0]\n    with pytest.raises(ValueError):\n        \"{:V}\".format(empty)\n\n    branches = tree_reader.read_trees(\"(ROOT (1) (2) (3))\")[0]\n    with pytest.raises(ValueError):\n        \"{:V}\".format(branches)\n\ndef test_language_formatting():\n    \"\"\"\n    Test turning the parse tree into a 'language' for GPT\n    \"\"\"\n    text = VI_TREEBANK.split(\"\\n\")[0]\n    trees = tree_reader.read_trees(text)\n    trees = [t.prune_none().simplify_labels() for t in trees]\n    assert len(trees) == 1\n    assert str(trees[0]) == VI_TREEBANK_SIMPLE\n\n    text = \"{:L}\".format(trees[0])\n    assert text == EXPECTED_LABELED_BRACKETS\n\n"
  },
  {
    "path": "stanza/tests/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/datasets/coref/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/datasets/coref/test_hebrew_iahlt.py",
    "content": "import pytest\n\nfrom stanza import Pipeline\nfrom stanza.tests import TEST_MODELS_DIR\nfrom stanza.utils.datasets.coref.convert_hebrew_iahlt import extract_doc\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n@pytest.fixture(scope=\"module\")\ndef tokenizer():\n    pipe = Pipeline(lang=\"he\", processors=\"tokenize\", dir=TEST_MODELS_DIR, download_method=None)\n    return pipe\n\nTEXT = \"\"\"\n\n\n\nמבולבלים​? גם אנחנו​: ל​מסעדנים ו​ה​מלצרים יש עוד סימני שאלה על ה​טיפים​\n\nה​פער בין פסיקת בית ה​דין ל​עבודה לבין פסיקה קודמת של בג\"ץ​, משאיר את ה​ענף ב​חוסר וודאות​, ו​ה -​1 ב​ינואר כבר מעבר ל​פינה . \"​מ​בחינת​י , הייתי מוסיף ל​תפריט תוספת שירות של 17​% \"​, אמר בעלים של מסעדה ב​שדרות​\n\nב​רשות ה​מיסים מסתפקים ב​מסר עמום באשר ל​כוונותי​הם לאור פסק דין ה​טיפים ש​צפוי להיכנס ל​תוקפ​ו ב​-​1 ב​ינואר . על פי פרשנות​ם ה​מקצועית , הבהירו​, יש מקום לחייב את כספי ה​טיפים ב​מע\"מ , \"​עם זאת​, ה​רשות עדין בוחנת את ה​סוגיה ו​טרם התקבלה החלטה אופרטיבית ב​עניין \"​. ו​איך אמורים ה​מסעדנים להיערך בינתיים ל​יישום ה​פסיקה ו​ל​מחזור ה​שנה ה​באה ? ב​יום חמישי יפגשו אנשי ארגון '​מסעדנים חזקים ביחד​' עם מנהל רשות ה​מיסים ערן יעקב​, ו​ידרשו תשובות ברורות​.​\n\n\"​אני עדיין לא מדבר עם ה​עובדים של​י , ו​אני גם לא יודע איך להיערך החל מ​עוד שבועיים​\"​, אמר ל​'​דבר ראשון​' ניר שוחט​, ה​בעלים של מסעדת סושי מוטו ב​שדרות ו​מוסיף כי יהיה קשה להתאים את ה​פסיקה ל​מציאות ב​שטח . \"​אף אחד לא יודע​. יש המון סתירות – עורך ה​דין אומר דבר אחד ו​רואה ה​חשבון דבר אחר​. עדיין לא הצליחו להבין את ה​חוק ל​אשור​ו \"​.​\n\n\"​מ​בחינת​י , הייתי מוסיף ל​תפריט תוספת שירות של 17​% . זה יגלם גם את ה​מע\"מ ו​ה​טיפים ו​מ​זה אני אשלם ל​מלצרים . 
די כבר עם ה​טיפים ה​אלה , מספיק​.​\"​\n\"\"\"\n\nCLUSTER = {'metadata': {'name': 'המסעדנים', 'entity': 'person'}, 'mentions': [[28, 35, {}], [572, 581, {}]]}\n\ndef test_extract_doc(tokenizer):\n    doc = {'text': TEXT,\n           'clusters': [CLUSTER],\n           'metadata': {\n               'doc_id': 'test'\n           }\n           }\n    extracted = extract_doc(tokenizer, [doc])\n    assert len(extracted) == 1\n    assert len(extracted[0].coref_spans) == 2\n    assert extracted[0].coref_spans[1] == [(0, 4, 4)]\n    assert extracted[0].coref_spans[6] == [(0, 3, 4)]\n"
  },
  {
    "path": "stanza/tests/datasets/ner/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/datasets/ner/test_prepare_ner_file.py",
    "content": "\"\"\"\nTest some simple conversions of NER bio files\n\"\"\"\n\nimport pytest\n\nimport json\n\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.datasets.ner.prepare_ner_file import process_dataset\n\nBIO_1 = \"\"\"\nJennifer\tB-PERSON\nSh'reyan\tI-PERSON\nhas\tO\nlovely\tO\nantennae\tO\n\"\"\".strip()\n\nBIO_2 = \"\"\"\nbut\tO\nI\tO\ndon't\tO\nlike\tO\nthe\tO\nway\tO\nJennifer\tB-PERSON\ntreated\tO\nBeckett\tB-PERSON\non\tO\nthe\tO\nCerritos\tB-LOCATION\n\"\"\".strip()\n\ndef check_json_file(doc, raw_text, expected_sentences, expected_tokens):\n    raw_sentences = raw_text.strip().split(\"\\n\\n\")\n    assert len(raw_sentences) == expected_sentences\n    if isinstance(expected_tokens, int):\n        expected_tokens = [expected_tokens]\n    for raw_sentence, expected_len in zip(raw_sentences, expected_tokens):\n        assert len(raw_sentence.strip().split(\"\\n\")) == expected_len\n\n    assert len(doc.sentences) == expected_sentences\n    for sentence, expected_len in zip(doc.sentences, expected_tokens):\n        assert len(sentence.tokens) == expected_len\n    for sentence, raw_sentence in zip(doc.sentences, raw_sentences):\n        for token, line in zip(sentence.tokens, raw_sentence.strip().split(\"\\n\")):\n            word, tag = line.strip().split()\n            assert token.text == word\n            assert token.ner == tag\n\ndef write_and_convert(tmp_path, raw_text):\n    bio_file = tmp_path / \"test.bio\"\n    with open(bio_file, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(raw_text)\n\n    json_file = tmp_path / \"json.bio\"\n    process_dataset(bio_file, json_file)\n\n    with open(json_file) as fin:\n        doc = Document(json.load(fin))\n\n    return doc\n\ndef run_test(tmp_path, raw_text, expected_sentences, expected_tokens):\n    doc = write_and_convert(tmp_path, raw_text)\n    check_json_file(doc, raw_text, expected_sentences, expected_tokens)\n\ndef test_simple(tmp_path):\n    run_test(tmp_path, 
BIO_1, 1, 5)\n\ndef test_ner_at_end(tmp_path):\n    run_test(tmp_path, BIO_2, 1, 12)\n\ndef test_two_sentences(tmp_path):\n    raw_text = BIO_1 + \"\\n\\n\" + BIO_2\n    run_test(tmp_path, raw_text, 2, [5, 12])\n"
  },
  {
    "path": "stanza/tests/datasets/ner/test_utils.py",
    "content": "\"\"\"\nTest the utils file of the NER dataset processing\n\"\"\"\n\nimport pytest\n\nfrom stanza.utils.datasets.ner.utils import list_doc_entities\nfrom stanza.tests.datasets.ner.test_prepare_ner_file import BIO_1, BIO_2, write_and_convert\n\ndef test_list_doc_entities(tmp_path):\n    \"\"\"\n    Test the function which lists all of the entities in a doc\n    \"\"\"\n    doc = write_and_convert(tmp_path, BIO_1)\n    entities = list_doc_entities(doc)\n    expected = [(('Jennifer', \"Sh'reyan\"), 'PERSON')]\n    assert expected == entities\n\n    doc = write_and_convert(tmp_path, BIO_2)\n    entities = list_doc_entities(doc)\n    expected = [(('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]\n    assert expected == entities    \n\n    doc = write_and_convert(tmp_path, \"\\n\\n\".join([BIO_1, BIO_2]))\n    entities = list_doc_entities(doc)\n    expected = [(('Jennifer', \"Sh'reyan\"), 'PERSON'), (('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]\n    assert expected == entities\n\n    doc = write_and_convert(tmp_path, \"\\n\\n\".join([BIO_1, BIO_1, BIO_2]))\n    entities = list_doc_entities(doc)\n    expected = [(('Jennifer', \"Sh'reyan\"), 'PERSON'), (('Jennifer', \"Sh'reyan\"), 'PERSON'), (('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]\n    assert expected == entities\n\n\n"
  },
  {
    "path": "stanza/tests/datasets/test_common.py",
    "content": "\"\"\"\nTest conllu manipulating routines in stanza/utils/dataset/common.py\n\"\"\"\n\nimport pytest\n\n\nfrom stanza.utils.datasets.common import maybe_add_fake_dependencies\n# from stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nDEPS_EXAMPLE=\"\"\"\n# text = Sh'reyan's antennae are hella thicc\n1\tSh'reyan\tSh'reyan\tPROPN\tNNP\tNumber=Sing\t3\tnmod:poss\t3:nmod:poss\tSpaceAfter=No\n2\t's\t's\tPART\tPOS\t_\t1\tcase\t1:case\t_\n3\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t6\tnsubj\t6:nsubj\t_\n4\tare\tbe\tVERB\tVBP\tMood=Ind|Tense=Pres|VerbForm=Fin\t6\tcop\t6:cop\t_\n5\thella\thella\tADV\tRB\t_\t6\tadvmod\t6:advmod\t_\n6\tthicc\tthicc\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\t_\n\"\"\".strip().split(\"\\n\")\n\n\nONLY_ROOT_EXAMPLE=\"\"\"\n# text = Sh'reyan's antennae are hella thicc\n1\tSh'reyan\tSh'reyan\tPROPN\tNNP\tNumber=Sing\t_\t_\t_\tSpaceAfter=No\n2\t's\t's\tPART\tPOS\t_\t_\t_\t_\t_\n3\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t_\t_\t_\t_\n4\tare\tbe\tVERB\tVBP\tMood=Ind|Tense=Pres|VerbForm=Fin\t_\t_\t_\t_\n5\thella\thella\tADV\tRB\t_\t_\t_\t_\t_\n6\tthicc\tthicc\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\t_\n\"\"\".strip().split(\"\\n\")\n\nONLY_ROOT_EXPECTED=\"\"\"\n# text = Sh'reyan's antennae are hella thicc\n1\tSh'reyan\tSh'reyan\tPROPN\tNNP\tNumber=Sing\t6\tdep\t_\tSpaceAfter=No\n2\t's\t's\tPART\tPOS\t_\t1\tdep\t_\t_\n3\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t1\tdep\t_\t_\n4\tare\tbe\tVERB\tVBP\tMood=Ind|Tense=Pres|VerbForm=Fin\t1\tdep\t_\t_\n5\thella\thella\tADV\tRB\t_\t1\tdep\t_\t_\n6\tthicc\tthicc\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\t_\n\"\"\".strip().split(\"\\n\")\n\nNO_DEPS_EXAMPLE=\"\"\"\n# text = Sh'reyan's antennae are hella 
thicc\n1\tSh'reyan\tSh'reyan\tPROPN\tNNP\tNumber=Sing\t_\t_\t_\tSpaceAfter=No\n2\t's\t's\tPART\tPOS\t_\t_\t_\t_\t_\n3\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t_\t_\t_\t_\n4\tare\tbe\tVERB\tVBP\tMood=Ind|Tense=Pres|VerbForm=Fin\t_\t_\t_\t_\n5\thella\thella\tADV\tRB\t_\t_\t_\t_\t_\n6\tthicc\tthicc\tADJ\tJJ\tDegree=Pos\t_\t_\t_\t_\n\"\"\".strip().split(\"\\n\")\n\nNO_DEPS_EXPECTED=\"\"\"\n# text = Sh'reyan's antennae are hella thicc\n1\tSh'reyan\tSh'reyan\tPROPN\tNNP\tNumber=Sing\t0\troot\t_\tSpaceAfter=No\n2\t's\t's\tPART\tPOS\t_\t1\tdep\t_\t_\n3\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t1\tdep\t_\t_\n4\tare\tbe\tVERB\tVBP\tMood=Ind|Tense=Pres|VerbForm=Fin\t1\tdep\t_\t_\n5\thella\thella\tADV\tRB\t_\t1\tdep\t_\t_\n6\tthicc\tthicc\tADJ\tJJ\tDegree=Pos\t1\tdep\t_\t_\n\"\"\".strip().split(\"\\n\")\n\n\ndef test_fake_deps_no_change():\n    result = maybe_add_fake_dependencies(DEPS_EXAMPLE)\n    assert result == DEPS_EXAMPLE\n\ndef test_fake_deps_all_tokens():\n    result = maybe_add_fake_dependencies(NO_DEPS_EXAMPLE)\n    assert result == NO_DEPS_EXPECTED\n\n\ndef test_fake_deps_only_root():\n    result = maybe_add_fake_dependencies(ONLY_ROOT_EXAMPLE)\n    assert result == ONLY_ROOT_EXPECTED\n"
  },
  {
    "path": "stanza/tests/datasets/test_vietnamese_renormalization.py",
    "content": "import pytest\nimport os\n\nfrom stanza.utils.datasets.vietnamese import renormalize\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_replace_all():\n    text     = \"SỌAmple tụy test file\"\n    expected = \"SOẠmple tuỵ test file\"\n\n    assert renormalize.replace_all(text) == expected\n\ndef test_replace_file(tmp_path):\n    text     = \"SỌAmple tụy test file\"\n    expected = \"SOẠmple tuỵ test file\"\n\n    orig = tmp_path / \"orig.txt\"\n    converted = tmp_path / \"converted.txt\"\n\n    with open(orig, \"w\", encoding=\"utf-8\") as fout:\n        for i in range(10):\n            fout.write(text)\n            fout.write(\"\\n\")\n\n    renormalize.convert_file(orig, converted)\n\n    assert os.path.exists(converted)\n    with open(converted, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n\n    assert len(lines) == 10\n    for i in lines:\n        assert i.strip() == expected\n        \n"
  },
  {
    "path": "stanza/tests/depparse/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/depparse/test_depparse_data.py",
    "content": "\"\"\"\nTest some pieces of the depparse dataloader\n\"\"\"\nimport pytest\nfrom stanza.models import parser\nfrom stanza.models.depparse.data import data_to_batches, DataLoader\nfrom stanza.utils.conll import CoNLL\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef make_fake_data(*lengths):\n    data = []\n    for i, length in enumerate(lengths):\n        word = chr(ord('A') + i)\n        chunk = [[word] * length]\n        data.append(chunk)\n    return data\n\ndef check_batches(batched_data, expected_sizes, expected_order):\n    for chunk, size in zip(batched_data, expected_sizes):\n        assert sum(len(x[0]) for x in chunk) == size\n    word_order = []\n    for chunk in batched_data:\n        for sentence in chunk:\n            word_order.append(sentence[0][0])\n    assert word_order == expected_order\n\ndef test_data_to_batches_eval_mode():\n    \"\"\"\n    Tests the chunking of batches in eval_mode\n\n    A few options are tested, such as whether or not to sort and the maximum sentence size\n    \"\"\"\n    data = make_fake_data(1, 2, 3)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)\n    check_batches(batched_data[0], [5, 1], ['C', 'B', 'A'])\n\n    data = make_fake_data(1, 2, 6)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)\n    check_batches(batched_data[0], [6, 3], ['C', 'B', 'A'])\n\n    data = make_fake_data(3, 2, 1)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)\n    check_batches(batched_data[0], [5, 1], ['A', 'B', 'C'])\n\n    data = make_fake_data(3, 5, 2)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)\n    check_batches(batched_data[0], [5, 5], ['B', 'A', 'C'])\n\n    data = 
make_fake_data(3, 5, 2)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)\n    check_batches(batched_data[0], [3, 5, 2], ['A', 'B', 'C'])\n\n    data = make_fake_data(4, 1, 1)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)\n    check_batches(batched_data[0], [4, 2], ['A', 'B', 'C'])\n\n    data = make_fake_data(1, 4, 1)\n    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)\n    check_batches(batched_data[0], [1, 4, 1], ['A', 'B', 'C'])\n\n\nEWT_PUNCT_SAMPLE = \"\"\"\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048\n# text = Bush asked for permission to go to Alabama to work on a Senate campaign.\n1\tBush\tBush\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t2:nsubj\t_\n2\tasked\task\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n3\tfor\tfor\tADP\tIN\t_\t4\tcase\t4:case\t_\n4\tpermission\tpermission\tNOUN\tNN\tNumber=Sing\t2\tobl\t2:obl:for\t_\n5\tto\tto\tPART\tTO\t_\t6\tmark\t6:mark\t_\n6\tgo\tgo\tVERB\tVB\tVerbForm=Inf\t4\tacl\t4:acl:to\t_\n7\tto\tto\tADP\tIN\t_\t8\tcase\t8:case\t_\n8\tAlabama\tAlabama\tPROPN\tNNP\tNumber=Sing\t6\tobl\t6:obl:to\t_\n9\tto\tto\tPART\tTO\t_\t10\tmark\t10:mark\t_\n10\twork\twork\tVERB\tVB\tVerbForm=Inf\t6\tadvcl\t6:advcl:to\t_\n11\ton\ton\tADP\tIN\t_\t14\tcase\t14:case\t_\n12\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t14\tdet\t14:det\t_\n13\tSenate\tSenate\tPROPN\tNNP\tNumber=Sing\t14\tcompound\t14:compound\t_\n14\tcampaign\tcampaign\tNOUN\tNN\tNumber=Sing\t10\tobl\t10:obl:on\tSpaceAfter=No\n15\t!!!!!\t!\tPUNCT\t.\t_\t2\tpunct\t2:punct\t_\n\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049\n# text = His superior officers said 
OK.\n1\tHis\this\tPRON\tPRP$\tCase=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnmod:poss\t3:nmod:poss\t_\n2\tsuperior\tsuperior\tADJ\tJJ\tDegree=Pos\t3\tamod\t3:amod\t_\n3\tofficers\tofficer\tNOUN\tNNS\tNumber=Plur\t4\tnsubj\t4:nsubj\t_\n4\tsaid\tsay\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n5\tOK\tok\tINTJ\tUH\t_\t4\tobj\t4:obj\tSpaceAfter=No\n6\t?????\t?\tPUNCT\t.\t_\t4\tpunct\t4:punct\t_\n\"\"\"\n\n\ndef test_punct_simplification():\n    \"\"\"\n    Test the punctuation simplification, which should collapse unexpected\n    question/exclamation mark variants (e.g. ????? or !!!!!) into ? and !\n    \"\"\"\n    sample = CoNLL.conll2doc(input_str=EWT_PUNCT_SAMPLE)\n\n    args = parser.parse_args(args=[\"--batch_size\", \"1000\", \"--shorthand\", \"en_test\"])\n    data = DataLoader(sample, 5000, args, None)\n\n    batches = [batch for batch in data]\n    assert batches[0][-1] == [['Bush', 'asked', 'for', 'permission', 'to', 'go', 'to', 'Alabama', 'to', 'work', 'on', 'a', 'Senate', 'campaign', '!'],\n                              ['His', 'superior', 'officers', 'said', 'OK', '?']]\n\n\nif __name__ == '__main__':\n    test_data_to_batches_eval_mode()\n\n"
  },
  {
    "path": "stanza/tests/depparse/test_parser.py",
    "content": "\"\"\"\nRun the tagger for a couple iterations on some fake data\n\nUses a couple sentences of UD_English-EWT as training/dev data\n\"\"\"\n\nimport os\nimport pytest\nimport zipfile\n\nimport torch\n\nfrom stanza.models import parser\nfrom stanza.models.common import pretrain\nfrom stanza.models.depparse.trainer import Trainer\nfrom stanza.tests import TEST_WORKING_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nTRAIN_DATA = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003\n# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.\n1\tDPA\tDPA\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n2\t:\t:\tPUNCT\t:\t_\t1\tpunct\t1:punct\t_\n3\tIraqi\tIraqi\tADJ\tJJ\tDegree=Pos\t4\tamod\t4:amod\t_\n4\tauthorities\tauthority\tNOUN\tNNS\tNumber=Plur\t5\tnsubj\t5:nsubj\t_\n5\tannounced\tannounce\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t1\tparataxis\t1:parataxis\t_\n6\tthat\tthat\tSCONJ\tIN\t_\t9\tmark\t9:mark\t_\n7\tthey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur|Person=3|PronType=Prs\t9\tnsubj\t9:nsubj\t_\n8\thad\thave\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t9\taux\t9:aux\t_\n9\tbusted\tbust\tVERB\tVBN\tTense=Past|VerbForm=Part\t5\tccomp\t5:ccomp\t_\n10\tup\tup\tADP\tRP\t_\t9\tcompound:prt\t9:compound:prt\t_\n11\t3\t3\tNUM\tCD\tNumForm=Digit|NumType=Card\t13\tnummod\t13:nummod\t_\n12\tterrorist\tterrorist\tADJ\tJJ\tDegree=Pos\t13\tamod\t13:amod\t_\n13\tcells\tcell\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n14\toperating\toperate\tVERB\tVBG\tVerbForm=Ger\t13\tacl\t13:acl\t_\n15\tin\tin\tADP\tIN\t_\t16\tcase\t16:case\t_\n16\tBaghdad\tBaghdad\tPROPN\tNNP\tNumber=Sing\t14\tobl\t14:obl:in\tSpaceAfter=No\n17\t.\t.\tPUNCT\t.\t_\t1\tpunct\t1:punct\t_\n\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004\n# text = Two of them were being run by 2 officials of the Ministry of 
the Interior!\n1\tTwo\ttwo\tNUM\tCD\tNumForm=Word|NumType=Card\t6\tnsubj:pass\t6:nsubj:pass\t_\n2\tof\tof\tADP\tIN\t_\t3\tcase\t3:case\t_\n3\tthem\tthey\tPRON\tPRP\tCase=Acc|Number=Plur|Person=3|PronType=Prs\t1\tnmod\t1:nmod:of\t_\n4\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n5\tbeing\tbe\tAUX\tVBG\tVerbForm=Ger\t6\taux:pass\t6:aux:pass\t_\n6\trun\trun\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n7\tby\tby\tADP\tIN\t_\t9\tcase\t9:case\t_\n8\t2\t2\tNUM\tCD\tNumForm=Digit|NumType=Card\t9\tnummod\t9:nummod\t_\n9\tofficials\tofficial\tNOUN\tNNS\tNumber=Plur\t6\tobl\t6:obl:by\t_\n10\tof\tof\tADP\tIN\t_\t12\tcase\t12:case\t_\n11\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t12\tdet\t12:det\t_\n12\tMinistry\tMinistry\tPROPN\tNNP\tNumber=Sing\t9\tnmod\t9:nmod:of\t_\n13\tof\tof\tADP\tIN\t_\t15\tcase\t15:case\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tInterior\tInterior\tPROPN\tNNP\tNumber=Sing\t12\tnmod\t12:nmod:of\tSpaceAfter=No\n16\t!\t!\tPUNCT\t.\t_\t6\tpunct\t6:punct\t_\n\n\"\"\".lstrip()\n\n\nDEV_DATA = \"\"\"\n1\tFrom\tfrom\tADP\tIN\t_\t3\tcase\t3:case\t_\n2\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n3\tAP\tAP\tPROPN\tNNP\tNumber=Sing\t4\tobl\t4:obl:from\t_\n4\tcomes\tcome\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n5\tthis\tthis\tDET\tDT\tNumber=Sing|PronType=Dem\t6\tdet\t6:det\t_\n6\tstory\tstory\tNOUN\tNN\tNumber=Sing\t4\tnsubj\t4:nsubj\t_\n7\t:\t:\tPUNCT\t:\t_\t4\tpunct\t4:punct\t_\n\n\"\"\".lstrip()\n\n\n\nclass TestParser:\n    @pytest.fixture(scope=\"class\")\n    def wordvec_pretrain_file(self):\n        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\n    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None, zip_train_data=False):\n        \"\"\"\n        Run the training for a few iterations, load & return the model\n        
\"\"\"\n        train_file = str(tmp_path / \"train.zip\") if zip_train_data else str(tmp_path / \"train.conllu\")\n        dev_file = str(tmp_path / \"dev.conllu\")\n        pred_file = str(tmp_path / \"pred.conllu\")\n\n        save_name = \"test_parser.pt\"\n        save_file = str(tmp_path / save_name)\n\n        if zip_train_data:\n            with zipfile.ZipFile(train_file, \"w\") as zout:\n                with zout.open('train.conllu', 'w') as fout:\n                    fout.write(train_text.encode())\n        else:\n            with open(train_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(train_text)\n\n        with open(dev_file, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(dev_text)\n\n        args = [\"--wordvec_pretrain_file\", wordvec_pretrain_file,\n                \"--train_file\", train_file,\n                \"--eval_file\", dev_file,\n                \"--output_file\", pred_file,\n                \"--log_step\", \"10\",\n                \"--eval_interval\", \"20\",\n                \"--max_steps\", \"100\",\n                \"--shorthand\", \"en_test\",\n                \"--save_dir\", str(tmp_path),\n                \"--save_name\", save_name,\n                # in case we are doing a bert test\n                \"--bert_start_finetuning\", \"10\",\n                \"--bert_warmup_steps\", \"10\",\n                \"--lang\", \"en\"]\n        if not augment_nopunct:\n            args.extend([\"--augment_nopunct\", \"0.0\"])\n        if extra_args is not None:\n            args = args + extra_args\n        trainer, _ = parser.main(args)\n\n        assert os.path.exists(save_file)\n        pt = pretrain.Pretrain(wordvec_pretrain_file)\n        # test loading the saved model\n        saved_model = Trainer(pretrain=pt, model_file=save_file)\n        return trainer\n\n    def test_train(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Simple test of a few 'epochs' of tagger training\n        
\"\"\"\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)\n\n    def test_arc_embedding(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Simple test w/ and w/o arc embedding\n        \"\"\"\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--use_arc_embedding'])\n\n    def test_no_arc_embedding(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Simple test w/ and w/o arc embedding\n        \"\"\"\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--no_use_arc_embedding'])\n\n    def test_zipfile_train(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Simple test of a few 'epochs' of tagger training with a zipfile\n        \"\"\"\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, zip_train_data=True)\n\n    def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])\n\n    def test_with_bert_finetuning(self, tmp_path, wordvec_pretrain_file):\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])\n        assert 'bert_optimizer' in trainer.optimizer.keys()\n        assert 'bert_scheduler' in trainer.scheduler.keys()\n\n    def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Check that if we save, then load, then save a model with a finetuned bert, that bert isn't lost\n        \"\"\"\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])\n        assert 
'bert_optimizer' in trainer.optimizer.keys()\n        assert 'bert_scheduler' in trainer.scheduler.keys()\n\n        save_name = trainer.args['save_name']\n        filename = tmp_path / save_name\n        assert os.path.exists(filename)\n        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        assert any(x.startswith(\"bert_model\") for x in checkpoint['model'].keys())\n\n        # Test loading the saved model, saving it, and still having bert in it\n        # even if we have set bert_finetune to False for this incarnation\n        pt = pretrain.Pretrain(wordvec_pretrain_file)\n        args = {\"bert_finetune\": False}\n        saved_model = Trainer(pretrain=pt, model_file=filename, args=args)\n\n        saved_model.save(filename)\n\n        # This is the part that would fail if the force_bert_saved option did not exist\n        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)\n        assert any(x.startswith(\"bert_model\") for x in checkpoint['model'].keys())\n\n    def test_with_peft(self, tmp_path, wordvec_pretrain_file):\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2', '--use_peft'])\n        assert 'bert_optimizer' in trainer.optimizer.keys()\n        assert 'bert_scheduler' in trainer.scheduler.keys()\n\n    def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])\n\n        save_dir = trainer.args['save_dir']\n        save_name = trainer.args['save_name']\n        checkpoint_name = trainer.args[\"checkpoint_save_name\"]\n\n        assert os.path.exists(os.path.join(save_dir, save_name))\n        assert checkpoint_name is not None\n        assert os.path.exists(checkpoint_name)\n\n        
assert len(trainer.optimizer) == 1\n        for opt in trainer.optimizer.values():\n            assert isinstance(opt, torch.optim.Adam)\n\n        pt = pretrain.Pretrain(wordvec_pretrain_file)\n        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)\n        assert checkpoint.optimizer is not None\n        assert len(checkpoint.optimizer) == 1\n        for opt in checkpoint.optimizer.values():\n            assert isinstance(opt, torch.optim.Adam)\n\n    def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file):\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40'])\n\n        save_dir = trainer.args['save_dir']\n        save_name = trainer.args['save_name']\n        checkpoint_name = trainer.args[\"checkpoint_save_name\"]\n\n        assert os.path.exists(os.path.join(save_dir, save_name))\n        assert checkpoint_name is not None\n        assert os.path.exists(checkpoint_name)\n\n        assert len(trainer.optimizer) == 1\n        for opt in trainer.optimizer.values():\n            assert isinstance(opt, torch.optim.SGD)\n\n        pt = pretrain.Pretrain(wordvec_pretrain_file)\n        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)\n        assert checkpoint.optimizer is not None\n        assert len(checkpoint.optimizer) == 1\n        for opt in checkpoint.optimizer.values():\n            assert isinstance(opt, torch.optim.SGD)\n\n"
  },
  {
    "path": "stanza/tests/langid/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/langid/test_langid.py",
    "content": "\"\"\"\nBasic tests of langid module\n\"\"\"\n\nimport pytest\n\nfrom stanza.models.common.doc import Document\nfrom stanza.pipeline.core import Pipeline\nfrom stanza.pipeline.langid_processor import LangIDProcessor\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n#pytestmark = pytest.mark.skip\n\n@pytest.fixture(scope=\"module\")\ndef basic_multilingual():\n    return Pipeline(dir=TEST_MODELS_DIR, lang='multilingual', processors=\"langid\")\n\n@pytest.fixture(scope=\"module\")\ndef enfr_multilingual():\n    return Pipeline(dir=TEST_MODELS_DIR, lang=\"multilingual\", processors=\"langid\", langid_lang_subset=[\"en\", \"fr\"])\n\n@pytest.fixture(scope=\"module\")\ndef en_multilingual():\n    return Pipeline(dir=TEST_MODELS_DIR, lang=\"multilingual\", processors=\"langid\", langid_lang_subset=[\"en\"])\n\n@pytest.fixture(scope=\"module\")\ndef clean_multilingual():\n    return Pipeline(dir=TEST_MODELS_DIR, lang=\"multilingual\", processors=\"langid\", langid_clean_text=True)\n\ndef test_langid(basic_multilingual):\n    \"\"\"\n    Basic test of language identification\n    \"\"\"\n    english_text = \"This is an English sentence.\"\n    french_text = \"C'est une phrase française.\"\n    docs = [english_text, french_text]\n\n    docs = [Document([], text=text) for text in docs]\n    basic_multilingual(docs)\n    predictions = [doc.lang for doc in docs]\n    assert predictions == [\"en\", \"fr\"]\n\ndef test_langid_benchmark(basic_multilingual):\n    \"\"\"\n    Run lang id model on 500 examples, confirm reasonable accuracy.\n    \"\"\"\n    examples = [\n    {\"text\": \"contingentiam in naturalibus causis.\", \"label\": \"la\"},\n    {\"text\": \"I jak opowiadał nieżyjący już pan Czesław\", \"label\": \"pl\"},\n    {\"text\": \"Sonera gilt seit längerem als Übernahmekandidat\", \"label\": \"de\"},\n    {\"text\": \"与银类似，汞也可以与空气中的硫化氢反应。\", \"label\": \"zh-hans\"},\n    {\"text\": \"contradictionem 
implicat.\", \"label\": \"la\"},\n    {\"text\": \"Bis zu Prozent gingen die Offerten etwa im\", \"label\": \"de\"},\n    {\"text\": \"inneren Sicherheit vorgeschlagene Ausweitung der\", \"label\": \"de\"},\n    {\"text\": \"Multimedia-PDA mit Mini-Tastatur\", \"label\": \"de\"},\n    {\"text\": \"Ponášalo sa to na rovnicu o dvoch neznámych.\", \"label\": \"sk\"},\n    {\"text\": \"이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의\", \"label\": \"ko\"},\n    {\"text\": \"Die Arbeitsgruppe bedauert , dass der weit über\", \"label\": \"de\"},\n    {\"text\": \"И только раз довелось поговорить с ним не вполне\", \"label\": \"ru\"},\n    {\"text\": \"de a-l lovi cu piciorul și conștiința că era\", \"label\": \"ro\"},\n    {\"text\": \"relación coas pretensións do demandante e que, nos\", \"label\": \"gl\"},\n    {\"text\": \"med petdeset in sedemdeset\", \"label\": \"sl\"},\n    {\"text\": \"Catalunya; el Consell Comarcal del Vallès Oriental\", \"label\": \"ca\"},\n    {\"text\": \"kunnen worden.\", \"label\": \"nl\"},\n    {\"text\": \"Witkin je ve většině ohledů zcela jiný.\", \"label\": \"cs\"},\n    {\"text\": \"lernen, so zu agieren, dass sie positive oder auch\", \"label\": \"de\"},\n    {\"text\": \"olurmuş...\", \"label\": \"tr\"},\n    {\"text\": \"sarcasmo de Altman, desde as «peruas» que discutem\", \"label\": \"pt\"},\n    {\"text\": \"خلاف فوجداری مقدمہ درج کرے۔\", \"label\": \"ur\"},\n    {\"text\": \"Norddal kommune :\", \"label\": \"no\"},\n    {\"text\": \"dem Windows-.-Zeitalter , soll in diesem Jahr\", \"label\": \"de\"},\n    {\"text\": \"przeklętych ucieleśniają mit poety-cygana,\", \"label\": \"pl\"},\n    {\"text\": \"We do not believe the suspect has ties to this\", \"label\": \"en\"},\n    {\"text\": \"groziņu pīšanu.\", \"label\": \"lv\"},\n    {\"text\": \"Senior Vice-President David M. 
Thomas möchte\", \"label\": \"de\"},\n    {\"text\": \"neomylně vybral nějakou knihu a začetl se.\", \"label\": \"cs\"},\n    {\"text\": \"Statt dessen darf beispielsweise der Browser des\", \"label\": \"de\"},\n    {\"text\": \"outubro, alcançando R $ bilhões em .\", \"label\": \"pt\"},\n    {\"text\": \"(Porte, ), as it does other disciplines\", \"label\": \"en\"},\n    {\"text\": \"uskupení se mylně domnívaly, že podporu\", \"label\": \"cs\"},\n    {\"text\": \"Übernahme von Next Ende an dem System herum , das\", \"label\": \"de\"},\n    {\"text\": \"No podemos decir a la Hacienda que los alemanes\", \"label\": \"es\"},\n    {\"text\": \"и рѣста еи братья\", \"label\": \"orv\"},\n    {\"text\": \"الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية\", \"label\": \"ar\"},\n    {\"text\": \"uurides Rootsi sõjaarhiivist toodud . sajandi\", \"label\": \"et\"},\n    {\"text\": \"selskapets penger til å pusse opp sin enebolig på\", \"label\": \"no\"},\n    {\"text\": \"средней полосе и севернее в Ярославской,\", \"label\": \"ru\"},\n    {\"text\": \"il-massa żejda fil-ġemgħat u superġemgħat ta'\", \"label\": \"mt\"},\n    {\"text\": \"The Global Beauties on internetilehekülg, mida\", \"label\": \"et\"},\n    {\"text\": \"이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며\", \"label\": \"ko\"},\n    {\"text\": \"Snad ještě dodejme jeden ekonomický argument.\", \"label\": \"cs\"},\n    {\"text\": \"Spalio d. 
vykusiame pirmajame rinkimų ture\", \"label\": \"lt\"},\n    {\"text\": \"und schlechter Journalismus ein gutes Geschäft .\", \"label\": \"de\"},\n    {\"text\": \"Du sodiečiai sėdi ant potvynio apsemtų namų stogo.\", \"label\": \"lt\"},\n    {\"text\": \"цей є автентичним.\", \"label\": \"uk\"},\n    {\"text\": \"Și îndegrabă fu cu îngerul mulțime de șireaguri\", \"label\": \"ro\"},\n    {\"text\": \"sobra personal cualificado.\", \"label\": \"es\"},\n    {\"text\": \"Tako se u Njemačkoj dvije trećine liječnika služe\", \"label\": \"hr\"},\n    {\"text\": \"Dual-Athlon-Chipsatz noch in diesem Jahr\", \"label\": \"de\"},\n    {\"text\": \"यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का\", \"label\": \"hi\"},\n    {\"text\": \"Li forestier du mont avale\", \"label\": \"fro\"},\n    {\"text\": \"Netzwerken für Privatanwender zu bewundern .\", \"label\": \"de\"},\n    {\"text\": \"만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다\", \"label\": \"ko\"},\n    {\"text\": \"balance and weight distribution but not really for\", \"label\": \"en\"},\n    {\"text\": \"og så e # tente vi opp den om morgonen å sfyrte\", \"label\": \"nn\"},\n    {\"text\": \"변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .\", \"label\": \"ko\"},\n    {\"text\": \"puteare fac aceastea.\", \"label\": \"ro\"},\n    {\"text\": \"Waitt seine Führungsmannschaft nicht dem\", \"label\": \"de\"},\n    {\"text\": \"juhtimisega, tulid sealt.\", \"label\": \"et\"},\n    {\"text\": \"Veränderungen .\", \"label\": \"de\"},\n    {\"text\": \"banda en el Bayer Leverkusen de la Bundesliga de\", \"label\": \"es\"},\n    {\"text\": \"В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава\", \"label\": \"orv\"},\n    {\"text\": \"пославъ приведе я мастеры ѿ грекъ\", \"label\": \"orv\"},\n    {\"text\": \"En un nou escenari difícil d'imaginar fa poques\", \"label\": \"ca\"},\n    {\"text\": \"καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου\", \"label\": \"grc\"},\n    {\"text\": \"직접적인 관련이 있다 .\", \"label\": \"ko\"},\n    
{\"text\": \"가까운 듯하면서도 멀다 .\", \"label\": \"ko\"},\n    {\"text\": \"Er bietet ein ähnliches Leistungsniveau und\", \"label\": \"de\"},\n    {\"text\": \"民都洛水牛是獨居的，並不會以群族聚居。\", \"label\": \"zh-hant\"},\n    {\"text\": \"την τρομοκρατία.\", \"label\": \"el\"},\n    {\"text\": \"hurbiltzen diren neurrian.\", \"label\": \"eu\"},\n    {\"text\": \"Ah dimenticavo, ma tutta sta caciara per fare un\", \"label\": \"it\"},\n    {\"text\": \"На первом этапе (-) прошла так называемая\", \"label\": \"ru\"},\n    {\"text\": \"of games are on the market.\", \"label\": \"en\"},\n    {\"text\": \"находится Мост дружбы, соединяющий узбекский и\", \"label\": \"ru\"},\n    {\"text\": \"lessié je voldroie que li saint fussent aporté\", \"label\": \"fro\"},\n    {\"text\": \"Дошла очередь и до Гималаев.\", \"label\": \"ru\"},\n    {\"text\": \"vzácným suknem táhly pouští, si jednou chtěl do\", \"label\": \"cs\"},\n    {\"text\": \"E no terceiro tipo sitúa a familias (%), nos que a\", \"label\": \"gl\"},\n    {\"text\": \"وجابت دوريات امريكية وعراقية شوارع المدينة، فيما\", \"label\": \"ar\"},\n    {\"text\": \"Jeg har bodd her i år .\", \"label\": \"no\"},\n    {\"text\": \"Pohrozil, že odbory zostří postoj, pokud se\", \"label\": \"cs\"},\n    {\"text\": \"tinham conseguido.\", \"label\": \"pt\"},\n    {\"text\": \"Nicht-Erkrankten einen Anfangsverdacht für einen\", \"label\": \"de\"},\n    {\"text\": \"permanece em aberto.\", \"label\": \"pt\"},\n    {\"text\": \"questi possono promettere rendimenti fino a un\", \"label\": \"it\"},\n    {\"text\": \"Tema juurutatud kahevedurisüsteemita oleksid\", \"label\": \"et\"},\n    {\"text\": \"Поведение внешне простой игрушки оказалось\", \"label\": \"ru\"},\n    {\"text\": \"Bundesländern war vom Börsenverein des Deutschen\", \"label\": \"de\"},\n    {\"text\": \"acció, 'a mesura que avanci l'estiu, amb l'augment\", \"label\": \"ca\"},\n    {\"text\": \"Dove trovare queste risorse? 
Jay Naidoo, ministro\", \"label\": \"it\"},\n    {\"text\": \"essas gordurinhas.\", \"label\": \"pt\"},\n    {\"text\": \"Im zweiten Schritt sollen im übernächsten Jahr\", \"label\": \"de\"},\n    {\"text\": \"allveelaeva pole enam vaja, kuna külm sõda on läbi\", \"label\": \"et\"},\n    {\"text\": \"उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा\", \"label\": \"hi\"},\n    {\"text\": \"@user nella sfortuna sei fortunata ..\", \"label\": \"it\"},\n    {\"text\": \"математических школ в виде грозовых туч.\", \"label\": \"ru\"},\n    {\"text\": \"No cambiaremos nunca nuestra forma de jugar por un\", \"label\": \"es\"},\n    {\"text\": \"dla tej klasy ani wymogów minimalnych, z wyjątkiem\", \"label\": \"pl\"},\n    {\"text\": \"en todo el mundo, mientras que en España consiguió\", \"label\": \"es\"},\n    {\"text\": \"политики считать надежное обеспечение военной\", \"label\": \"ru\"},\n    {\"text\": \"gogoratzen du, genio alemana delakoaren\", \"label\": \"eu\"},\n    {\"text\": \"Бычий глаз.\", \"label\": \"ru\"},\n    {\"text\": \"Opeření se v pravidelných obdobích obnovuje\", \"label\": \"cs\"},\n    {\"text\": \"I no és només la seva, es tracta d'una resposta\", \"label\": \"ca\"},\n    {\"text\": \"오경을 가르쳤다 .\", \"label\": \"ko\"},\n    {\"text\": \"Nach der so genannten Start-up-Periode vergibt die\", \"label\": \"de\"},\n    {\"text\": \"Saulista huomasi jo lapsena , että hänellä on\", \"label\": \"fi\"},\n    {\"text\": \"Министерство культуры сочло нецелесообразным, и\", \"label\": \"ru\"},\n    {\"text\": \"znepřátelené tábory v Tádžikistánu předseda\", \"label\": \"cs\"},\n    {\"text\": \"καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον\", \"label\": \"grc\"},\n    {\"text\": \"Вечером, в продукте, этот же человек говорил о\", \"label\": \"ru\"},\n    {\"text\": \"lugar á formación de xuizos máis complexos.\", \"label\": \"gl\"},\n    {\"text\": \"cheaper, in the end?\", \"label\": \"en\"},\n    {\"text\": \"الوزارة في شأن صفقات بيع الشركات العامة 
التي تم\", \"label\": \"ar\"},\n    {\"text\": \"tärkeintä elämässäni .\", \"label\": \"fi\"},\n    {\"text\": \"Виконання Мінських угод було заблоковано Росією та\", \"label\": \"uk\"},\n    {\"text\": \"Aby szybko rozpoznać żołnierzy desantu, należy\", \"label\": \"pl\"},\n    {\"text\": \"Bankengeschäfte liegen vorn , sagte Strothmann .\", \"label\": \"de\"},\n    {\"text\": \"продолжение работы.\", \"label\": \"ru\"},\n    {\"text\": \"Metro AG plant Online-Offensive\", \"label\": \"de\"},\n    {\"text\": \"nu vor veni, și să vor osîndi, aceia nu pot porni\", \"label\": \"ro\"},\n    {\"text\": \"Ich denke , es geht in Wirklichkeit darum , NT bei\", \"label\": \"de\"},\n    {\"text\": \"de turism care încasează contravaloarea\", \"label\": \"ro\"},\n    {\"text\": \"Aurkaria itotzea da helburua, baloia lapurtu eta\", \"label\": \"eu\"},\n    {\"text\": \"com a centre de formació en Tecnologies de la\", \"label\": \"ca\"},\n    {\"text\": \"oportet igitur quod omne agens in agendo intendat\", \"label\": \"la\"},\n    {\"text\": \"Jerzego Andrzejewskiego, oparty na chińskich\", \"label\": \"pl\"},\n    {\"text\": \"sau một vài câu chuyện xã giao không dính dáng tới\", \"label\": \"vi\"},\n    {\"text\": \"что экономическому прорыву жесткий авторитарный\", \"label\": \"ru\"},\n    {\"text\": \"DRAM-Preisen scheinen DSPs ein\", \"label\": \"de\"},\n    {\"text\": \"Jos dajan nubbái: Mana!\", \"label\": \"sme\"},\n    {\"text\": \"toți carii ascultară de el să răsipiră.\", \"label\": \"ro\"},\n    {\"text\": \"odpowiedzialności, które w systemie własności\", \"label\": \"pl\"},\n    {\"text\": \"Dvomesečno potovanje do Mollenda v Peruju je\", \"label\": \"sl\"},\n    {\"text\": \"d'entre les agències internacionals.\", \"label\": \"ca\"},\n    {\"text\": \"Fahrzeugzugangssysteme gefertigt und an viele\", \"label\": \"de\"},\n    {\"text\": \"in an answer to the sharers' petition in Cuthbert\", \"label\": \"en\"},\n    {\"text\": \"Europa-Domain per Verordnung zu 
regeln .\", \"label\": \"de\"},\n    {\"text\": \"#Balotelli. Su ebay prezzi stracciati per Silvio\", \"label\": \"it\"},\n    {\"text\": \"Ne na košickém trávníku, ale už včera v letadle se\", \"label\": \"cs\"},\n    {\"text\": \"zaměstnanosti a investičních strategií.\", \"label\": \"cs\"},\n    {\"text\": \"Tatínku, udělej den\", \"label\": \"cs\"},\n    {\"text\": \"frecuencia con Mary.\", \"label\": \"es\"},\n    {\"text\": \"Свеаборге.\", \"label\": \"ru\"},\n    {\"text\": \"opatření slovenské strany o certifikaci nejvíce\", \"label\": \"cs\"},\n    {\"text\": \"En todas me decían: 'Espera que hagamos un estudio\", \"label\": \"es\"},\n    {\"text\": \"Die Demonstration sollte nach Darstellung der\", \"label\": \"de\"},\n    {\"text\": \"Ci vorrà un assoluto rigore se dietro i disavanzi\", \"label\": \"it\"},\n    {\"text\": \"Tatínku, víš, že Honzovi odešla maminka?\", \"label\": \"cs\"},\n    {\"text\": \"Die Anzahl der Rechner wuchs um % auf und die\", \"label\": \"de\"},\n    {\"text\": \"האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין\", \"label\": \"he\"},\n    {\"text\": \"Volán Egyesülés, a Közlekedési Főfelügyelet is.\", \"label\": \"hu\"},\n    {\"text\": \"Schejbala, který stejnou hru s velkým úspěchem\", \"label\": \"cs\"},\n    {\"text\": \"depends on the data type of the field.\", \"label\": \"en\"},\n    {\"text\": \"Umsatzwarnung zu Wochenbeginn zeitweise auf ein\", \"label\": \"de\"},\n    {\"text\": \"niin heti nukun .\", \"label\": \"fi\"},\n    {\"text\": \"Mobilfunkunternehmen gegen die Anwendung der so\", \"label\": \"de\"},\n    {\"text\": \"sapessi le intenzioni del governo Monti e dell'UE\", \"label\": \"it\"},\n    {\"text\": \"Di chi è figlia Martine Aubry?\", \"label\": \"it\"},\n    {\"text\": \"avec le reste du monde.\", \"label\": \"fr\"},\n    {\"text\": \"Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի\", \"label\": \"hy\"},\n    {\"text\": \"și în cazul destrămării cenaclului.\", \"label\": \"ro\"},\n    {\"text\": 
\"befriedigen kann , und ohne die auftretenden\", \"label\": \"de\"},\n    {\"text\": \"Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.\", \"label\": \"grc\"},\n    {\"text\": \"færdiguddannede.\", \"label\": \"da\"},\n    {\"text\": \"Schmidt war Sohn eines Rittergutsbesitzers.\", \"label\": \"de\"},\n    {\"text\": \"и вдаша попадь ѡпрати\", \"label\": \"orv\"},\n    {\"text\": \"cine nu știe învățătură”.\", \"label\": \"ro\"},\n    {\"text\": \"détacha et cette dernière tenta de tuer le jeune\", \"label\": \"fr\"},\n    {\"text\": \"Der har saka også ei lengre forhistorie.\", \"label\": \"nn\"},\n    {\"text\": \"Pieprz roztłuc w moździerzu, dodać do pasty,\", \"label\": \"pl\"},\n    {\"text\": \"Лежа за гребнем оврага, как за бруствером, Ушаков\", \"label\": \"ru\"},\n    {\"text\": \"gesucht habe, vielen Dank nochmals!\", \"label\": \"de\"},\n    {\"text\": \"инструментальных сталей, повышения\", \"label\": \"ru\"},\n    {\"text\": \"im Halbfinale Patrick Smith und im Finale dann\", \"label\": \"de\"},\n    {\"text\": \"البنوك التريث في منح تسهيلات جديدة لمنتجي حديد\", \"label\": \"ar\"},\n    {\"text\": \"una bolsa ventral, la cual se encuentra debajo de\", \"label\": \"es\"},\n    {\"text\": \"za SETimes.\", \"label\": \"sr\"},\n    {\"text\": \"de Irak, a un piloto italiano que había violado el\", \"label\": \"es\"},\n    {\"text\": \"Er könne sich nicht erklären , wie die Zeitung auf\", \"label\": \"de\"},\n    {\"text\": \"Прохорова.\", \"label\": \"ru\"},\n    {\"text\": \"la democrazia perde sulla tecnocrazia? 
#\", \"label\": \"it\"},\n    {\"text\": \"entre ambas instituciones, confirmó al medio que\", \"label\": \"es\"},\n    {\"text\": \"Austlandet, vart det funne om lag førti\", \"label\": \"nn\"},\n    {\"text\": \"уровнями власти.\", \"label\": \"ru\"},\n    {\"text\": \"Dá tedy primáři úplatek, a často ne malý.\", \"label\": \"cs\"},\n    {\"text\": \"brillantes del acto, al llevar a cabo en el\", \"label\": \"es\"},\n    {\"text\": \"eee druga zadeva je majhen priročen gre kamorkoli\", \"label\": \"sl\"},\n    {\"text\": \"Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse\", \"label\": \"de\"},\n    {\"text\": \"Za vodné bylo v prvním pololetí zaplaceno v ČR\", \"label\": \"cs\"},\n    {\"text\": \"Даже на полсантиметра.\", \"label\": \"ru\"},\n    {\"text\": \"com la del primer tinent d'alcalde en funcions,\", \"label\": \"ca\"},\n    {\"text\": \"кількох оповідань в цілості — щось на зразок того\", \"label\": \"uk\"},\n    {\"text\": \"sed ad divitias congregandas, vel superfluum\", \"label\": \"la\"},\n    {\"text\": \"Norma Talmadge, spela mot Valentino i en version\", \"label\": \"sv\"},\n    {\"text\": \"Dlatego chciał się jej oświadczyć w niezwykłym\", \"label\": \"pl\"},\n    {\"text\": \"будут выступать на одинаковых снарядах.\", \"label\": \"ru\"},\n    {\"text\": \"Orang-orang terbunuh di sana.\", \"label\": \"id\"},\n    {\"text\": \"لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب\", \"label\": \"ar\"},\n    {\"text\": \"Wirklichkeit verlagerten und kaum noch\", \"label\": \"de\"},\n    {\"text\": \"как перемешивают костяшки перед игрой в домино, и\", \"label\": \"ru\"},\n    {\"text\": \"В средине дня, когда солнце светило в нашу\", \"label\": \"ru\"},\n    {\"text\": \"d'aventure aux rôles de jeune romantique avec une\", \"label\": \"fr\"},\n    {\"text\": \"My teď hledáme organizace, jež by s námi chtěly\", \"label\": \"cs\"},\n    {\"text\": \"Urteilsfähigkeit einbüßen , wenn ich eigene\", \"label\": \"de\"},\n    {\"text\": \"sua appartenenza 
anche a voci diverse da quella in\", \"label\": \"it\"},\n    {\"text\": \"Aufträge dieses Jahr verdoppeln werden .\", \"label\": \"de\"},\n    {\"text\": \"M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę\", \"label\": \"pl\"},\n    {\"text\": \"secundum contactum virtutis, cum careat dimensiva\", \"label\": \"la\"},\n    {\"text\": \"ezinbestekoa dela esan zuen.\", \"label\": \"eu\"},\n    {\"text\": \"Anek hurbiltzeko eskatzen zion besaulkitik, eta\", \"label\": \"eu\"},\n    {\"text\": \"perfectius alio videat, quamvis uterque videat\", \"label\": \"la\"},\n    {\"text\": \"Die Strecke war anspruchsvoll und führte unter\", \"label\": \"de\"},\n    {\"text\": \"саморазоблачительным уроком, западные СМИ не\", \"label\": \"ru\"},\n    {\"text\": \"han representerer radikal islamisme .\", \"label\": \"no\"},\n    {\"text\": \"Què s'hi respira pel que fa a la reforma del\", \"label\": \"ca\"},\n    {\"text\": \"previsto para também ser desconstruido.\", \"label\": \"pt\"},\n    {\"text\": \"Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ\", \"label\": \"grc\"},\n    {\"text\": \"para jovens de a anos nos Cieps.\", \"label\": \"pt\"},\n    {\"text\": \"संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।\", \"label\": \"hi\"},\n    {\"text\": \"objeví i u nás.\", \"label\": \"cs\"},\n    {\"text\": \"kvitteringer.\", \"label\": \"da\"},\n    {\"text\": \"This report is no exception.\", \"label\": \"en\"},\n    {\"text\": \"Разлепват доносниците до избирателните списъци\", \"label\": \"bg\"},\n    {\"text\": \"anderem ihre Bewegungsfreiheit in den USA\", \"label\": \"de\"},\n    {\"text\": \"Ñu tegoon ca kaw gor ña ay njotti bopp yu kenn\", \"label\": \"wo\"},\n    {\"text\": \"Struktur kann beispielsweise der Schwerpunkt mehr\", \"label\": \"de\"},\n    {\"text\": \"% la velocidad permitida, la sanción es muy grave.\", \"label\": \"es\"},\n    {\"text\": \"Teles-Einstieg in ADSL-Markt\", \"label\": \"de\"},\n    {\"text\": \"ettekäändeks liiga suure osamaksu.\", \"label\": 
\"et\"},\n    {\"text\": \"als Indiz für die geänderte Marktpolitik des\", \"label\": \"de\"},\n    {\"text\": \"quod quidem aperte consequitur ponentes\", \"label\": \"la\"},\n    {\"text\": \"de negociación para el próximo de junio.\", \"label\": \"es\"},\n    {\"text\": \"Tyto důmyslné dekorace doznaly v poslední době\", \"label\": \"cs\"},\n    {\"text\": \"največjega uspeha doslej.\", \"label\": \"sl\"},\n    {\"text\": \"Paul Allen je jedan od suosnivača Interval\", \"label\": \"hr\"},\n    {\"text\": \"Federal (Seac / DF) eo Sindicato das Empresas de\", \"label\": \"pt\"},\n    {\"text\": \"Quartal mit . Mark gegenüber dem gleichen Quartal\", \"label\": \"de\"},\n    {\"text\": \"otros clubes y del Barça B saldrán varios\", \"label\": \"es\"},\n    {\"text\": \"Jaskula (Pol.) -\", \"label\": \"cs\"},\n    {\"text\": \"umožnily říci, že je možné přejít k mnohem\", \"label\": \"cs\"},\n    {\"text\": \"اعلن الجنرال تومي فرانكس قائد القوات الامريكية\", \"label\": \"ar\"},\n    {\"text\": \"Telekom-Chef Ron Sommer und der Vorstandssprecher\", \"label\": \"de\"},\n    {\"text\": \"My, jako průmyslový a finanční holding, můžeme\", \"label\": \"cs\"},\n    {\"text\": \"voorlichting onder andere betrekking kan hebben:\", \"label\": \"nl\"},\n    {\"text\": \"Hinrichtung geistig Behinderter applaudiert oder\", \"label\": \"de\"},\n    {\"text\": \"wie beispielsweise Anzahl erzielte Klicks ,\", \"label\": \"de\"},\n    {\"text\": \"Intel-PC-SDRAM-Spezifikation in der Version . 
(\", \"label\": \"de\"},\n    {\"text\": \"plângere în termen de zile de la comunicarea\", \"label\": \"ro\"},\n    {\"text\": \"и Испания ще изгубят втория си комисар в ЕК.\", \"label\": \"bg\"},\n    {\"text\": \"इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।\", \"label\": \"hi\"},\n    {\"text\": \"aunque se mostró contrario a establecer un\", \"label\": \"es\"},\n    {\"text\": \"des letzten Jahres von auf Millionen Euro .\", \"label\": \"de\"},\n    {\"text\": \"Ankara se također poziva da u cijelosti ratificira\", \"label\": \"hr\"},\n    {\"text\": \"herunterlädt .\", \"label\": \"de\"},\n    {\"text\": \"стрессовую ситуацию для организма, каковой\", \"label\": \"ru\"},\n    {\"text\": \"Státního shromáždění (parlamentu).\", \"label\": \"cs\"},\n    {\"text\": \"diskutieren , ob und wie dieser Dienst weiterhin\", \"label\": \"de\"},\n    {\"text\": \"Verbindungen zu FPÖ-nahen Polizisten gepflegt und\", \"label\": \"de\"},\n    {\"text\": \"Pražského volebního lídra ovšem nevybírá Miloš\", \"label\": \"cs\"},\n    {\"text\": \"Nach einem Bericht der Washington Post bleibt das\", \"label\": \"de\"},\n    {\"text\": \"للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما\", \"label\": \"ar\"},\n    {\"text\": \"не желаят запазването на статуквото.\", \"label\": \"bg\"},\n    {\"text\": \"Offenburg gewesen .\", \"label\": \"de\"},\n    {\"text\": \"ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε\", \"label\": \"grc\"},\n    {\"text\": \"all'odiato compagno di squadra Prost, il quale\", \"label\": \"it\"},\n    {\"text\": \"historischen Gänselieselbrunnens.\", \"label\": \"de\"},\n    {\"text\": \"למידע מלווייני הריגול האמריקאיים העוקבים אחר\", \"label\": \"he\"},\n    {\"text\": \"οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν\", \"label\": \"grc\"},\n    {\"text\": \"movementos migratorios.\", \"label\": \"gl\"},\n    {\"text\": \"Handy und ein Spracherkennungsprogramm sämtliche\", \"label\": \"de\"},\n    {\"text\": \"Kümne aasta jooksul on Eestisse ohjeldamatult\", \"label\": 
\"et\"},\n    {\"text\": \"H.G. Bücknera.\", \"label\": \"pl\"},\n    {\"text\": \"protiv krijumčarenja, ili pak traženju ukidanja\", \"label\": \"hr\"},\n    {\"text\": \"Topware-Anteile mehrere Millionen Mark gefordert\", \"label\": \"de\"},\n    {\"text\": \"Maar de mensen die nu over Van Dijk bij FC Twente\", \"label\": \"nl\"},\n    {\"text\": \"poidan experimentar as percepcións do interesado,\", \"label\": \"gl\"},\n    {\"text\": \"Miał przecież w kieszeni nóż.\", \"label\": \"pl\"},\n    {\"text\": \"Avšak žádná z nich nepronikla za hranice přímé\", \"label\": \"cs\"},\n    {\"text\": \"esim. helpottamalla luottoja muiden\", \"label\": \"fi\"},\n    {\"text\": \"Podle předběžných výsledků zvítězila v\", \"label\": \"cs\"},\n    {\"text\": \"Nicht nur das Web-Frontend , auch die\", \"label\": \"de\"},\n    {\"text\": \"Regierungsinstitutionen oder Universitäten bei\", \"label\": \"de\"},\n    {\"text\": \"Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս\", \"label\": \"hy\"},\n    {\"text\": \"Афганистана, где в последние дни идут ожесточенные\", \"label\": \"ru\"},\n    {\"text\": \"лѧхове же не идоша\", \"label\": \"orv\"},\n    {\"text\": \"Mit Hilfe von IBMs Chip-Management-Systemen sollen\", \"label\": \"de\"},\n    {\"text\": \", als Manager zu Telefonica zu wechseln .\", \"label\": \"de\"},\n    {\"text\": \"którym zajmuje się człowiek, zmienia go i pozwala\", \"label\": \"pl\"},\n    {\"text\": \"činí kyperských liber, to je asi USD.\", \"label\": \"cs\"},\n    {\"text\": \"Studienplätze getauscht werden .\", \"label\": \"de\"},\n    {\"text\": \"учёных, орнитологов признают вид.\", \"label\": \"ru\"},\n    {\"text\": \"acordare a concediilor prevăzute de legislațiile\", \"label\": \"ro\"},\n    {\"text\": \"at større innsats for fornybar, berekraftig energi\", \"label\": \"nn\"},\n    {\"text\": \"Politiet veit ikkje kor mange personar som deltok\", \"label\": \"nn\"},\n    {\"text\": \"offentligheten av unge , sinte menn som har\", \"label\": 
\"no\"},\n    {\"text\": \"însuși în jurul lapunei, care încet DISPARE în\", \"label\": \"ro\"},\n    {\"text\": \"O motivo da decisão é evitar uma sobrecarga ainda\", \"label\": \"pt\"},\n    {\"text\": \"El Apostolado de la prensa contribuye en modo\", \"label\": \"es\"},\n    {\"text\": \"Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer\", \"label\": \"de\"},\n    {\"text\": \"grozījumus un iesniegt tos Apvienoto Nāciju\", \"label\": \"lv\"},\n    {\"text\": \"Gestalt einer deutschen Nationalmannschaft als\", \"label\": \"de\"},\n    {\"text\": \"D überholt zu haben , konterte am heutigen Montag\", \"label\": \"de\"},\n    {\"text\": \"Softwarehersteller Oracle hat im dritten Quartal\", \"label\": \"de\"},\n    {\"text\": \"Během nich se ekonomické podmínky mohou radikálně\", \"label\": \"cs\"},\n    {\"text\": \"Dziki kot w górach zeskakuje z kamienia.\", \"label\": \"pl\"},\n    {\"text\": \"Ačkoliv ligový nováček prohrál, opět potvrdil, že\", \"label\": \"cs\"},\n    {\"text\": \"des Tages , Portraits internationaler Stars sowie\", \"label\": \"de\"},\n    {\"text\": \"Communicator bekannt wurde .\", \"label\": \"de\"},\n    {\"text\": \"τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν\", \"label\": \"grc\"},\n    {\"text\": \"Triadú tenia, mentre redactava 'Dies de memòria',\", \"label\": \"ca\"},\n    {\"text\": \"دسته‌جمعی در درخشندگی ماه سیم‌گون زمزمه ستاینده و\", \"label\": \"fa\"},\n    {\"text\": \"Книгу, наполненную мелочной заботой об одежде,\", \"label\": \"ru\"},\n    {\"text\": \"putares canem leporem persequi.\", \"label\": \"la\"},\n    {\"text\": \"В дальнейшем эта яркость слегка померкла, но в\", \"label\": \"ru\"},\n    {\"text\": \"offizielles Verfahren gegen die Telekom\", \"label\": \"de\"},\n    {\"text\": \"podrían haber sido habitantes de la Península\", \"label\": \"es\"},\n    {\"text\": \"Grundlage für dieses Verfahren sind spezielle\", \"label\": \"de\"},\n    {\"text\": \"Rechtsausschuß vorgelegten Entwurf der Richtlinie\", \"label\": 
\"de\"},\n    {\"text\": \"Im so genannten Portalgeschäft sei das Unternehmen\", \"label\": \"de\"},\n    {\"text\": \"ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ\", \"label\": \"cop\"},\n    {\"text\": \"juego podían matar a cualquier herbívoro, pero\", \"label\": \"es\"},\n    {\"text\": \"Nach Angaben von Axent nutzen Unternehmen aus der\", \"label\": \"de\"},\n    {\"text\": \"hrdiny Havlovy Zahradní slavnosti (premiéra ) se\", \"label\": \"cs\"},\n    {\"text\": \"Een zin van heb ik jou daar\", \"label\": \"nl\"},\n    {\"text\": \"hat sein Hirn an der CeBIT-Kasse vergessen .\", \"label\": \"de\"},\n    {\"text\": \"καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους\", \"label\": \"grc\"},\n    {\"text\": \"nachgewiesenen langfristigen Kosten , sowie den im\", \"label\": \"de\"},\n    {\"text\": \"jučer nakon četiri dana putovanja u Helsinki.\", \"label\": \"hr\"},\n    {\"text\": \"pašto paslaugos teikėjas gali susitarti su\", \"label\": \"lt\"},\n    {\"text\": \"В результате, эти золотые кадры переходят из одной\", \"label\": \"ru\"},\n    {\"text\": \"द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन\", \"label\": \"hi\"},\n    {\"text\": \"výpis o počtu akcií.\", \"label\": \"cs\"},\n    {\"text\": \"Enfin, elles arrivent à un pavillon chinois\", \"label\": \"fr\"},\n    {\"text\": \"Tentu saja, tren yang berhubungandengan\", \"label\": \"id\"},\n    {\"text\": \"Arbeidarpartiet og SV har sikra seg fleirtal mot\", \"label\": \"nn\"},\n    {\"text\": \"eles: 'Tudo isso está errado' , disse um\", \"label\": \"pt\"},\n    {\"text\": \"The islands are in their own time zone, minutes\", \"label\": \"en\"},\n    {\"text\": \"Auswahl debütierte er am .\", \"label\": \"de\"},\n    {\"text\": \"Bu komisyonlar, arazilerini satın almak için\", \"label\": \"tr\"},\n    {\"text\": \"Geschütze gegen Redmond aufgefahren .\", \"label\": \"de\"},\n    {\"text\": \"Time scything the hours, but at the top, over the\", \"label\": \"en\"},\n    {\"text\": \"Di musim semi , 
berharap mengadaptasi Tintin untuk\", \"label\": \"id\"},\n    {\"text\": \"крупнейшей геополитической катастрофой XX века.\", \"label\": \"ru\"},\n    {\"text\": \"Rajojen avaaminen ei suju ongelmitta .\", \"label\": \"fi\"},\n    {\"text\": \"непроницаемым, как для СССР.\", \"label\": \"ru\"},\n    {\"text\": \"Ma non mancano le polemiche.\", \"label\": \"it\"},\n    {\"text\": \"Internet als Ort politischer Diskussion und auch\", \"label\": \"de\"},\n    {\"text\": \"incomplets.\", \"label\": \"ca\"},\n    {\"text\": \"Su padre luchó al lado de Luis Moya, primer Jefe\", \"label\": \"es\"},\n    {\"text\": \"informazione.\", \"label\": \"it\"},\n    {\"text\": \"Primacom bietet für Telekom-Kabelnetz\", \"label\": \"de\"},\n    {\"text\": \"Oświadczenie prezydencji w imieniu Unii\", \"label\": \"pl\"},\n    {\"text\": \"foran rattet i familiens gamle Baleno hvis døra på\", \"label\": \"no\"},\n    {\"text\": \"[speaker:laughter]\", \"label\": \"sl\"},\n    {\"text\": \"Dog med langt mindre utstyr med seg.\", \"label\": \"nn\"},\n    {\"text\": \"dass es nicht schon mit der anfänglichen\", \"label\": \"de\"},\n    {\"text\": \"इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।\", \"label\": \"hi\"},\n    {\"text\": \"کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ\", \"label\": \"ur\"},\n    {\"text\": \"dell'Assemblea Costituente che posseggono i\", \"label\": \"it\"},\n    {\"text\": \"и аште вьси съблазнѧтъ сѧ нъ не азъ\", \"label\": \"cu\"},\n    {\"text\": \"In Irvine hat auch das Logistikunternehmen Atlas\", \"label\": \"de\"},\n    {\"text\": \"законодательных норм, принимаемых существующей\", \"label\": \"ru\"},\n    {\"text\": \"Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν\", \"label\": \"grc\"},\n    {\"text\": \"МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.\", \"label\": \"ru\"},\n    {\"text\": \"unterschiedlicher Meinung .\", \"label\": \"de\"},\n    {\"text\": \"Jospa joku ystävällinen sielu auttaisi kassieni\", \"label\": \"fi\"},\n    {\"text\": \"Añadió 
que, en el futuro se harán otros\", \"label\": \"es\"},\n    {\"text\": \"Sessiz tonlama hem Fince, hem de Kuzey Sami\", \"label\": \"tr\"},\n    {\"text\": \"nicht ihnen gehört und sie nicht alles , was sie\", \"label\": \"de\"},\n    {\"text\": \"Etelästä Kuivajärveen laskee Tammelan Liesjärvestä\", \"label\": \"fi\"},\n    {\"text\": \"ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis\", \"label\": \"de\"},\n    {\"text\": \"Norsk politikk frå til kan dermed, i\", \"label\": \"nn\"},\n    {\"text\": \"Głosowało posłów.\", \"label\": \"pl\"},\n    {\"text\": \"Danny Jones -- smithjones@ev.net\", \"label\": \"en\"},\n    {\"text\": \"sebeuvědomění moderní civilizace sehrála lučavka\", \"label\": \"cs\"},\n    {\"text\": \"относительно спокойный сон: тому гарантия\", \"label\": \"ru\"},\n    {\"text\": \"A halte voiz prist li pedra a crïer\", \"label\": \"fro\"},\n    {\"text\": \"آن‌ها امیدوارند این واکسن به‌زودی در دسترس بیماران\", \"label\": \"fa\"},\n    {\"text\": \"vlastní důstojnou vousatou tváří.\", \"label\": \"cs\"},\n    {\"text\": \"ora aprire la strada a nuove cause e alimentare il\", \"label\": \"it\"},\n    {\"text\": \"Die Zahl der Vielleser nahm von auf Prozent zu ,\", \"label\": \"de\"},\n    {\"text\": \"Finanzvorstand von Hotline-Dienstleister InfoGenie\", \"label\": \"de\"},\n    {\"text\": \"entwickeln .\", \"label\": \"de\"},\n    {\"text\": \"incolumità pubblica.\", \"label\": \"it\"},\n    {\"text\": \"lehtija televisiomainonta\", \"label\": \"fi\"},\n    {\"text\": \"joistakin kohdista eri mieltä.\", \"label\": \"fi\"},\n    {\"text\": \"Hlavně anglická nezávislá scéna, Dead Can Dance,\", \"label\": \"cs\"},\n    {\"text\": \"pásmech od do bodů bodové stupnice.\", \"label\": \"cs\"},\n    {\"text\": \"Zu Beginn des Ersten Weltkrieges zählte das\", \"label\": \"de\"},\n    {\"text\": \"Així van sorgir, damunt els antics cementiris,\", \"label\": \"ca\"},\n    {\"text\": \"In manchem Gedicht der spätern Alten, wie zum\", \"label\": 
\"de\"},\n    {\"text\": \"gaweihaida jah insandida in þana fairƕu jus qiþiþ\", \"label\": \"got\"},\n    {\"text\": \"Beides sollte gelöscht werden!\", \"label\": \"de\"},\n    {\"text\": \"modifiqués la seva petició inicial de anys de\", \"label\": \"ca\"},\n    {\"text\": \"В день открытия симпозиума состоялась закладка\", \"label\": \"ru\"},\n    {\"text\": \"tõestatud.\", \"label\": \"et\"},\n    {\"text\": \"ἵππῳ πίπτει αὐτοῦ ταύτῃ\", \"label\": \"grc\"},\n    {\"text\": \"bisher nie enttäuscht!\", \"label\": \"de\"},\n    {\"text\": \"De bohte ollu tuollárat ja suttolaččat ja\", \"label\": \"sme\"},\n    {\"text\": \"Klarsignal från röstlängdsläsaren, tre tryck i\", \"label\": \"sv\"},\n    {\"text\": \"Tvůrcem nového termínu je Joseph Fisher.\", \"label\": \"cs\"},\n    {\"text\": \"Nie miałem czasu na reakcję twierdzi Norbert,\", \"label\": \"pl\"},\n    {\"text\": \"potentia Schöpfer.\", \"label\": \"de\"},\n    {\"text\": \"Un poquito caro, pero vale mucho la pena;\", \"label\": \"es\"},\n    {\"text\": \"οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος\", \"label\": \"grc\"},\n    {\"text\": \"vajec, sladového výtažku a některých vitamínových\", \"label\": \"cs\"},\n    {\"text\": \"Настоящие герои, те, чьи истории потом\", \"label\": \"ru\"},\n    {\"text\": \"praesumptio:\", \"label\": \"la\"},\n    {\"text\": \"Olin justkui nende vastutusel.\", \"label\": \"et\"},\n    {\"text\": \"Jokainen keinahdus tuo lähemmäksi hetkeä jolloin\", \"label\": \"fi\"},\n    {\"text\": \"ekonomicky výhodných způsobů odvodnění těžkých,\", \"label\": \"cs\"},\n    {\"text\": \"Poprvé ve své historii dokázala v kvalifikaci pro\", \"label\": \"cs\"},\n    {\"text\": \"zpracovatelského a spotřebního průmyslu bude nutné\", \"label\": \"cs\"},\n    {\"text\": \"Windows CE zu integrieren .\", \"label\": \"de\"},\n    {\"text\": \"Armangué, a través d'un decret, ordenés l'aturada\", \"label\": \"ca\"},\n    {\"text\": \"to, co nás Evropany spojuje, než to, co nás od\", 
\"label\": \"cs\"},\n    {\"text\": \"ergänzt durch einen gesetzlich verankertes\", \"label\": \"de\"},\n    {\"text\": \"Насчитал, что с начала года всего три дня были\", \"label\": \"ru\"},\n    {\"text\": \"Borisovu tražeći od njega da prihvati njenu\", \"label\": \"sr\"},\n    {\"text\": \"la presenza di ben veleni diversi: . chili di\", \"label\": \"it\"},\n    {\"text\": \"καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς\", \"label\": \"grc\"},\n    {\"text\": \"pretraživale obližnju bolnicu i stambene zgrade u\", \"label\": \"hr\"},\n    {\"text\": \"An rund Katzen habe Wolf seine Spiele getestet ,\", \"label\": \"de\"},\n    {\"text\": \"investigating since March.\", \"label\": \"en\"},\n    {\"text\": \"Tonböden (Mullböden).\", \"label\": \"de\"},\n    {\"text\": \"Stálý dopisovatel LN v SRN Bedřich Utitz\", \"label\": \"cs\"},\n    {\"text\": \"červnu předložené smlouvy.\", \"label\": \"cs\"},\n    {\"text\": \"πνεύματι ᾧ ἐλάλει\", \"label\": \"grc\"},\n    {\"text\": \".%의 신장세를 보였다.\", \"label\": \"ko\"},\n    {\"text\": \"Foae verde, foi de nuc, Prin pădure, prin colnic,\", \"label\": \"ro\"},\n    {\"text\": \"διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι\", \"label\": \"grc\"},\n    {\"text\": \"المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في\", \"label\": \"ar\"},\n    {\"text\": \"As informações são da Dow Jones.\", \"label\": \"pt\"},\n    {\"text\": \"Milliarde DM ausgestattet sein .\", \"label\": \"de\"},\n    {\"text\": \"De utgår fortfarande från att kvinnans jämlikhet\", \"label\": \"sv\"},\n    {\"text\": \"Sneeuw maakte in Davos bij de voorbereiding een\", \"label\": \"nl\"},\n    {\"text\": \"De ahí que en este mercado puedan negociarse\", \"label\": \"es\"},\n    {\"text\": \"intenzívnějšímu sbírání a studiu.\", \"label\": \"cs\"},\n    {\"text\": \"और औसकर ४.० पैकेज का प्रयोग किया गया है ।\", \"label\": \"hi\"},\n    {\"text\": \"Adipati Kuningan karena Kuningan menjadi bagian\", \"label\": \"id\"},\n    {\"text\": \"Svako je bar 
jednom poželeo da mašine prosto umeju\", \"label\": \"sr\"},\n    {\"text\": \"Im vergangenen Jahr haben die Regierungen einen\", \"label\": \"de\"},\n    {\"text\": \"durat motus, aliquid fit et non est;\", \"label\": \"la\"},\n    {\"text\": \"Dominować będą piosenki do tekstów Edwarda\", \"label\": \"pl\"},\n    {\"text\": \"beantwortet .\", \"label\": \"de\"},\n    {\"text\": \"О гуманитариях было кому рассказывать, а вот за\", \"label\": \"ru\"},\n    {\"text\": \"Helsingin kaupunki riitautti vuokrasopimuksen\", \"label\": \"fi\"},\n    {\"text\": \"chợt tan biến.\", \"label\": \"vi\"},\n    {\"text\": \"avtomobil ločuje od drugih.\", \"label\": \"sl\"},\n    {\"text\": \"Congress has proven itself ineffective as a body.\", \"label\": \"en\"},\n    {\"text\": \"मैक्सिको ने इस तरह का शो इस समय आयोजित करने का\", \"label\": \"hi\"},\n    {\"text\": \"No minimum order amount.\", \"label\": \"en\"},\n    {\"text\": \"Convertassa .\", \"label\": \"fi\"},\n    {\"text\": \"Как это можно сделать?\", \"label\": \"ru\"},\n    {\"text\": \"tha mi creidsinn gu robh iad ceart cho saor shuas\", \"label\": \"gd\"},\n    {\"text\": \"실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고\", \"label\": \"ko\"},\n    {\"text\": \"Da un semplice richiamo all'ordine fino a grandi\", \"label\": \"it\"},\n    {\"text\": \"pozoruhodný nejen po umělecké stránce, jež\", \"label\": \"cs\"},\n    {\"text\": \"La comida y el servicio aprueban.\", \"label\": \"es\"},\n    {\"text\": \"again, connected not with each other but to the\", \"label\": \"en\"},\n    {\"text\": \"Protokol výslovně stanoví, že nikdo nemůže být\", \"label\": \"cs\"},\n    {\"text\": \"ఒక విషయం అడగాలని ఉంది .\", \"label\": \"te\"},\n    {\"text\": \"Безгранично почитая дирекцию, ловя на лету каждое\", \"label\": \"ru\"},\n    {\"text\": \"rovnoběžných růstových vrstev, zůstávají krychlové\", \"label\": \"cs\"},\n    {\"text\": \"प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री\", \"label\": \"hi\"},\n    {\"text\": 
\"Bronzen medaille in de Europese marathon.\", \"label\": \"nl\"},\n    {\"text\": \"- gadu vecumā viņi to nesaprot.\", \"label\": \"lv\"},\n    {\"text\": \"Realizó sus estudios primarios en la Escuela Julia\", \"label\": \"es\"},\n    {\"text\": \"cuartos de final, su clasificación para la final a\", \"label\": \"es\"},\n    {\"text\": \"Sem si pro něho přiletí americký raketoplán, na\", \"label\": \"cs\"},\n    {\"text\": \"Way to go!\", \"label\": \"en\"},\n    {\"text\": \"gehört der neuen SPD-Führung unter Parteichef\", \"label\": \"de\"},\n    {\"text\": \"Somit simuliert der Player mit einer GByte-Platte\", \"label\": \"de\"},\n    {\"text\": \"Berufung auf kommissionsnahe Kreise , die bereits\", \"label\": \"de\"},\n    {\"text\": \"Dist Clarïen\", \"label\": \"fro\"},\n    {\"text\": \"Schon nach den Gerüchten , die Telekom wolle den\", \"label\": \"de\"},\n    {\"text\": \"Software von NetObjects ist nach Angaben des\", \"label\": \"de\"},\n    {\"text\": \"si enim per legem iustitia ergo Christus gratis\", \"label\": \"la\"},\n    {\"text\": \"ducerent in ipsam magis quam in corpus christi,\", \"label\": \"la\"},\n    {\"text\": \"Neustar-Melbourne-IT-Partnerschaft NeuLevel .\", \"label\": \"de\"},\n    {\"text\": \"forderte dagegen seine drastische Verschärfung.\", \"label\": \"de\"},\n    {\"text\": \"pemmican på hundrede forskellige måder.\", \"label\": \"da\"},\n    {\"text\": \"Lehån, själv matematiklärare, visar hur den nya\", \"label\": \"sv\"},\n    {\"text\": \"I highly recommend his shop.\", \"label\": \"en\"},\n    {\"text\": \"verità, giovani fedeli prostratevi #amen\", \"label\": \"it\"},\n    {\"text\": \"उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार\", \"label\": \"hi\"},\n    {\"text\": \"() روزی مےں کشادگی ہوتی ہے۔\", \"label\": \"ur\"},\n    {\"text\": \"Prozessorgeschäft profitieren kann , stellen\", \"label\": \"de\"},\n    {\"text\": \"školy začalo počítat pytle s moukou a zjistilo, že\", \"label\": \"cs\"},\n    {\"text\": 
\"प्रभावशाली पर गैर सरकारी लोगों के घरों में भी\", \"label\": \"hi\"},\n    {\"text\": \"geschichtslos , oder eine Farce , wie sich\", \"label\": \"de\"},\n    {\"text\": \"Ústrednými mocnosťami v marci však spôsobilo, že\", \"label\": \"sk\"},\n    {\"text\": \"التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض\", \"label\": \"ar\"},\n    {\"text\": \"Například Pedagogická fakulta Univerzity Karlovy\", \"label\": \"cs\"},\n    {\"text\": \"nostris ut eriperet nos de praesenti saeculo\", \"label\": \"la\"}]\n    \n    docs = [Document([], text=example[\"text\"]) for example in examples]\n    gold_labels = [example[\"label\"] for example in examples]\n    basic_multilingual(docs)\n    accuracy = sum([(doc.lang == label) for doc,label in zip(docs,gold_labels)])/len(docs)\n    assert accuracy >= 0.98\n\n\ndef test_text_cleaning(basic_multilingual, clean_multilingual):\n    \"\"\"\n    Basic test of cleaning text\n    \"\"\"\n    docs = [\"Bonjour le monde! #thisisfrench #ilovefrance\",\n            \"Bonjour le monde! https://t.co/U0Zjp3tusD\"]\n    docs = [Document([], text=text) for text in docs]\n    \n    basic_multilingual(docs)\n    assert [doc.lang for doc in docs] == [\"it\", \"it\"]\n    \n    assert clean_multilingual.processors[\"langid\"]._clean_text\n    clean_multilingual(docs)\n    assert [doc.lang for doc in docs] == [\"fr\", \"fr\"]\n\ndef test_emoji_cleaning():\n    TEXT = [\"Sh'reyan has nice antennae :thumbs_up:\",\n            \"This is🐱 a cat\"]\n    EXPECTED = [\"Sh'reyan has nice antennae\",\n                \"This is  a cat\"]\n    for text, expected in zip(TEXT, EXPECTED):\n        assert LangIDProcessor.clean_text(text) == expected\n\ndef test_lang_subset(basic_multilingual, enfr_multilingual, en_multilingual):\n    \"\"\"\n    Basic test of restricting output to subset of languages\n    \"\"\"\n    docs = [\"Bonjour le monde! #thisisfrench #ilovefrance\",\n            \"Bonjour le monde! 
https://t.co/U0Zjp3tusD\"]\n    docs = [Document([], text=text) for text in docs]\n\n    basic_multilingual(docs)\n    assert [doc.lang for doc in docs] == [\"it\", \"it\"]\n\n    assert enfr_multilingual.processors[\"langid\"]._model.lang_subset == [\"en\", \"fr\"]\n    enfr_multilingual(docs)\n    assert [doc.lang for doc in docs] == [\"fr\", \"fr\"]\n\n    assert en_multilingual.processors[\"langid\"]._model.lang_subset == [\"en\"]\n    en_multilingual(docs)\n    assert [doc.lang for doc in docs] == [\"en\", \"en\"]\n\ndef test_lang_subset_unlikely_language(en_multilingual):\n    \"\"\"\n    Test that the language subset masking chooses a legal language, even if all legal languages are extremely unlikely\n    \"\"\"\n    sentences = [\"你好\" * 200]\n    docs = [Document([], text=text) for text in sentences]\n    en_multilingual(docs)\n    assert [doc.lang for doc in docs] == [\"en\"]\n\n    processor = en_multilingual.processors['langid']\n    model = processor._model\n    text_tensor = processor._text_to_tensor(sentences)\n    en_idx = model.tag_to_idx['en']\n    predictions = model(text_tensor)\n    assert predictions[0, en_idx] < 0, \"If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English.  Update the test by picking a different combination of languages & input\"\n\n"
  },
  {
    "path": "stanza/tests/langid/test_multilingual.py",
    "content": "\"\"\"\nTests specifically for the MultilingualPipeline\n\"\"\"\n\nfrom collections import defaultdict\n\nimport pytest\n\nfrom stanza.pipeline.multilingual import MultilingualPipeline\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=True, **kwargs):\n    english_text = \"This is an English sentence.\"\n    english_words = [\"This\", \"is\", \"an\", \"English\", \"sentence\", \".\"]\n    english_deps_gold = \"\\n\".join((\n        \"('This', 5, 'nsubj')\",\n        \"('is', 5, 'cop')\",\n        \"('an', 5, 'det')\",\n        \"('English', 5, 'amod')\",\n        \"('sentence', 0, 'root')\",\n        \"('.', 5, 'punct')\"\n    ))\n    if not en_has_dependencies:\n        english_deps_gold = \"\"\n\n    french_text = \"C'est une phrase française.\"\n    french_words = [\"C'\", \"est\", \"une\", \"phrase\", \"française\", \".\"]\n    french_deps_gold = \"\\n\".join((\n        \"(\\\"C'\\\", 4, 'nsubj')\",\n        \"('est', 4, 'cop')\",\n        \"('une', 4, 'det')\",\n        \"('phrase', 0, 'root')\",\n        \"('française', 4, 'amod')\",\n        \"('.', 4, 'punct')\"\n    ))\n    if not fr_has_dependencies:\n        french_deps_gold = \"\"\n\n    if 'lang_configs' in kwargs:\n        nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, **kwargs)\n    else:\n        lang_configs = {\"en\": {\"processors\": \"tokenize,pos,lemma,depparse\"},\n                        \"fr\": {\"processors\": \"tokenize,pos,lemma,depparse\"}}\n        nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, lang_configs=lang_configs, **kwargs)\n    docs = [english_text, french_text]\n    docs = nlp(docs)\n\n    assert docs[0].lang == \"en\"\n    assert len(docs[0].sentences) == 1\n    assert [x.text for x in docs[0].sentences[0].words] == english_words\n    assert 
docs[0].sentences[0].dependencies_string() == english_deps_gold\n\n    assert len(docs[1].sentences) == 1\n    assert docs[1].lang == \"fr\"\n    assert [x.text for x in docs[1].sentences[0].words] == french_words\n    assert docs[1].sentences[0].dependencies_string() == french_deps_gold\n\n\ndef test_multilingual_pipeline():\n    \"\"\"\n    Basic test of multilingual pipeline\n    \"\"\"\n    run_multilingual_pipeline()\n\ndef test_multilingual_pipeline_small_cache():\n    \"\"\"\n    Test with the cache size 1\n    \"\"\"\n    run_multilingual_pipeline(max_cache_size=1)\n\n\ndef test_multilingual_config():\n    \"\"\"\n    Test with only tokenize for the EN pipeline\n    \"\"\"\n    lang_configs = {\n        \"en\": {\"processors\": \"tokenize\"}\n    }\n\n    run_multilingual_pipeline(en_has_dependencies=False, lang_configs=lang_configs)\n\ndef test_multilingual_processors_limited():\n    \"\"\"\n    Test loading an available subset of processors\n    \"\"\"\n    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors=\"tokenize\")\n    run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs={\"en\": {\"processors\": \"tokenize,pos,lemma,depparse\"}}, processors=\"tokenize\")\n    # this should not fail, as it will drop the zzzzzzzzzz processor for the languages which don't have it\n    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors=\"tokenize,zzzzzzzzzz\")\n\n\ndef test_defaultdict_config():\n    \"\"\"\n    Test that you can pass in a defaultdict for the lang_configs argument\n    \"\"\"\n    lang_configs = defaultdict(lambda: dict(processors=\"tokenize\"))\n    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs=lang_configs)\n\n    lang_configs = defaultdict(lambda: dict(processors=\"tokenize\"))\n    lang_configs[\"en\"] = {\"processors\": 
\"tokenize,pos,lemma,depparse\"}\n    run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs=lang_configs)\n"
  },
  {
    "path": "stanza/tests/lemma/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/lemma/test_data.py",
    "content": "\"\"\"\nTest a couple basic data functions, such as processing a doc for its lemmas\n\"\"\"\n\nimport pytest\n\nfrom stanza.models.common.doc import Document\nfrom stanza.models.lemma.data import DataLoader\nfrom stanza.utils.conll import CoNLL\n\nTRAIN_DATA = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003\n# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.\n1\tDPA\tDPA\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n2\t:\t:\tPUNCT\t:\t_\t1\tpunct\t1:punct\t_\n3\tIraqi\tIraqi\tADJ\tJJ\tDegree=Pos\t4\tamod\t4:amod\t_\n4\tauthorities\tauthority\tNOUN\tNNS\tNumber=Plur\t5\tnsubj\t5:nsubj\t_\n5\tannounced\tannounce\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t1\tparataxis\t1:parataxis\t_\n6\tthat\tthat\tSCONJ\tIN\t_\t9\tmark\t9:mark\t_\n7\tthey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur|Person=3|PronType=Prs\t9\tnsubj\t9:nsubj\t_\n8\thad\thave\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t9\taux\t9:aux\t_\n9\tbusted\tbust\tVERB\tVBN\tTense=Past|VerbForm=Part\t5\tccomp\t5:ccomp\t_\n10\tup\tup\tADP\tRP\t_\t9\tcompound:prt\t9:compound:prt\t_\n11\t3\t3\tNUM\tCD\tNumForm=Digit|NumType=Card\t13\tnummod\t13:nummod\t_\n12\tterrorist\tterrorist\tADJ\tJJ\tDegree=Pos\t13\tamod\t13:amod\t_\n13\tcells\tcell\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n14\toperating\toperate\tVERB\tVBG\tVerbForm=Ger\t13\tacl\t13:acl\t_\n15\tin\tin\tADP\tIN\t_\t16\tcase\t16:case\t_\n16\tBaghdad\tBaghdad\tPROPN\tNNP\tNumber=Sing\t14\tobl\t14:obl:in\tSpaceAfter=No\n17\t.\t.\tPUNCT\t.\t_\t1\tpunct\t1:punct\t_\n\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004\n# text = Two of them were being run by 2 officials of the Ministry of the 
Interior!\n1\tTwo\ttwo\tNUM\tCD\tNumForm=Word|NumType=Card\t6\tnsubj:pass\t6:nsubj:pass\t_\n2\tof\tof\tADP\tIN\t_\t3\tcase\t3:case\t_\n3\tthem\tthey\tPRON\tPRP\tCase=Acc|Number=Plur|Person=3|PronType=Prs\t1\tnmod\t1:nmod:of\t_\n4\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n5\tbeing\tbe\tAUX\tVBG\tVerbForm=Ger\t6\taux:pass\t6:aux:pass\t_\n6\trun\trun\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n7\tby\tby\tADP\tIN\t_\t9\tcase\t9:case\t_\n8\t2\t2\tNUM\tCD\tNumForm=Digit|NumType=Card\t9\tnummod\t9:nummod\t_\n9\tofficials\tofficial\tNOUN\tNNS\tNumber=Plur\t6\tobl\t6:obl:by\t_\n10\tof\tof\tADP\tIN\t_\t12\tcase\t12:case\t_\n11\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t12\tdet\t12:det\t_\n12\tMinistry\tMinistry\tPROPN\tNNP\tNumber=Sing\t9\tnmod\t9:nmod:of\t_\n13\tof\tof\tADP\tIN\t_\t15\tcase\t15:case\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tInterior\tInterior\tPROPN\tNNP\tNumber=Sing\t12\tnmod\t12:nmod:of\tSpaceAfter=No\n16\t!\t!\tPUNCT\t.\t_\t6\tpunct\t6:punct\t_\n\n\"\"\".lstrip()\n\nGOESWITH_DATA = \"\"\"\n# sent_id = email-enronsent27_01-0041\n# newpar id = email-enronsent27_01-p0005\n# text = Ken Rice@ENRON COMMUNICATIONS\n1\tKen\tkenrice@enroncommunications\tX\tGW\tTypo=Yes\t0\troot\t0:root\t_\n2\tRice@ENRON\t_\tX\tGW\t_\t1\tgoeswith\t1:goeswith\t_\n3\tCOMMUNICATIONS\t_\tX\tADD\t_\t1\tgoeswith\t1:goeswith\t_\n\n\"\"\".lstrip()\n\nCORRECT_FORM_DATA = \"\"\"\n# sent_id = weblog-blogspot.com_healingiraq_20040409053012_ENG_20040409_053012-0019\n# text = They are targetting 
ambulances\n1\tThey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur|Person=3|PronType=Prs\t3\tnsubj\t3:nsubj\t_\n2\tare\tbe\tAUX\tVBP\tMood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin\t3\taux\t3:aux\t_\n3\ttargetting\ttarget\tVERB\tVBG\tTense=Pres|Typo=Yes|VerbForm=Part\t0\troot\t0:root\tCorrectForm=targeting\n4\tambulances\tambulance\tNOUN\tNNS\tNumber=Plur\t3\tobj\t3:obj\tSpaceAfter=No\n\"\"\"\n\nBLANKS_DATA = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0018\n# text = Guerrillas killed an engineer, Asi Ali, from Tikrit.\n1\tGuerrillas\t_\tNOUN\tNNS\tNumber=Plur\t2\tnsubj\t2:nsubj\t_\n2\tkilled\t_\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n3\tan\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t4\tdet\t4:det\t_\n4\tengineer\t_\tNOUN\tNN\tNumber=Sing\t2\tobj\t2:obj\tSpaceAfter=No\n\n\"\"\".lstrip()\n\n\ndef test_load_document():\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)\n    assert len(data) == 33 # meticulously counted by hand\n    assert all(len(x) == 3 for x in data)\n\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)\n    assert len(data) == 33\n    assert all(len(x) == 3 for x in data)\n\ndef test_load_goeswith():\n    raw_data = TRAIN_DATA + GOESWITH_DATA\n    train_doc = CoNLL.conll2doc(input_str=raw_data)\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)\n    assert len(data) == 36 # will be the same as in test_load_document with three additional words\n    assert all(len(x) == 3 for x in data)\n\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)\n    assert len(data) == 33 # will be the same as in test_load_document, but with the trailing 3 GOESWITH removed\n    assert all(len(x) == 3 for x in data)\n\ndef 
test_correct_form():\n    raw_data = TRAIN_DATA + CORRECT_FORM_DATA\n    train_doc = CoNLL.conll2doc(input_str=raw_data)\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)\n    assert len(data) == 37\n    # the 'targeting' correction should not be applied if evaluation=True\n    # when evaluation=False, then the CorrectForms will be applied\n    assert not any(x[0] == 'targeting' for x in data)\n\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)\n    assert len(data) == 38 # the same, but with an extra row so the model learns both 'targetting' and 'targeting'\n    assert any(x[0] == 'targeting' for x in data)\n    assert any(x[0] == 'targetting' for x in data)\n\ndef test_load_blank():\n    raw_data = TRAIN_DATA + BLANKS_DATA\n    train_doc = CoNLL.conll2doc(input_str=raw_data)\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)\n    assert len(data) == 37 # will be the same as in test_load_document with FOUR additional words\n    assert all(len(x) == 3 for x in data)\n\n    data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=True, evaluation=False)\n    assert len(data) == 34 # will be the same as in test_load_document, but one extra word is added.  others were blank\n    assert all(len(x) == 3 for x in data)\n\n"
  },
  {
    "path": "stanza/tests/lemma/test_lemma_trainer.py",
    "content": "\"\"\"\nTest a couple basic functions - load & save an existing model\n\"\"\"\n\nimport pytest\n\nimport glob\nimport os\nimport tempfile\n\nimport torch\n\nfrom stanza.models import lemmatizer\nfrom stanza.models.lemma import trainer\nfrom stanza.tests import *\nfrom stanza.utils.training.common import choose_lemma_charlm, build_charlm_args\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n@pytest.fixture(scope=\"module\")\ndef english_model():\n    models_path = os.path.join(TEST_MODELS_DIR, \"en\", \"lemma\", \"*\")\n    models = glob.glob(models_path)\n    # we expect at least one English model downloaded for the tests\n    assert len(models) >= 1, \"No English lemma models downloaded during setup!  Please make sure to run the setup script.\"\n    for model_file in models:\n        if \"nocharlm\" in model_file:\n            return trainer.Trainer(model_file=model_file)\n    raise FileNotFoundError(\"Should have downloaded the nocharlm English lemmatizer during setup.  
Please rerun the setup script.\")\n\ndef test_load_model(english_model):\n    \"\"\"\n    Does nothing, just tests that loading works\n    \"\"\"\n\ndef test_save_load_model(english_model):\n    \"\"\"\n    Load, save, and load again\n    \"\"\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        save_file = os.path.join(tempdir, \"resaved\", \"lemma.pt\")\n        english_model.save(save_file)\n        reloaded = trainer.Trainer(model_file=save_file)\n\nTRAIN_DATA = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003\n# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.\n1\tDPA\tDPA\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n2\t:\t:\tPUNCT\t:\t_\t1\tpunct\t1:punct\t_\n3\tIraqi\tIraqi\tADJ\tJJ\tDegree=Pos\t4\tamod\t4:amod\t_\n4\tauthorities\tauthority\tNOUN\tNNS\tNumber=Plur\t5\tnsubj\t5:nsubj\t_\n5\tannounced\tannounce\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t1\tparataxis\t1:parataxis\t_\n6\tthat\tthat\tSCONJ\tIN\t_\t9\tmark\t9:mark\t_\n7\tthey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur|Person=3|PronType=Prs\t9\tnsubj\t9:nsubj\t_\n8\thad\thave\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t9\taux\t9:aux\t_\n9\tbusted\tbust\tVERB\tVBN\tTense=Past|VerbForm=Part\t5\tccomp\t5:ccomp\t_\n10\tup\tup\tADP\tRP\t_\t9\tcompound:prt\t9:compound:prt\t_\n11\t3\t3\tNUM\tCD\tNumForm=Digit|NumType=Card\t13\tnummod\t13:nummod\t_\n12\tterrorist\tterrorist\tADJ\tJJ\tDegree=Pos\t13\tamod\t13:amod\t_\n13\tcells\tcell\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n14\toperating\toperate\tVERB\tVBG\tVerbForm=Ger\t13\tacl\t13:acl\t_\n15\tin\tin\tADP\tIN\t_\t16\tcase\t16:case\t_\n16\tBaghdad\tBaghdad\tPROPN\tNNP\tNumber=Sing\t14\tobl\t14:obl:in\tSpaceAfter=No\n17\t.\t.\tPUNCT\t.\t_\t1\tpunct\t1:punct\t_\n\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004\n# text = Two of them were being run by 2 officials of 
the Ministry of the Interior!\n1\tTwo\ttwo\tNUM\tCD\tNumForm=Word|NumType=Card\t6\tnsubj:pass\t6:nsubj:pass\t_\n2\tof\tof\tADP\tIN\t_\t3\tcase\t3:case\t_\n3\tthem\tthey\tPRON\tPRP\tCase=Acc|Number=Plur|Person=3|PronType=Prs\t1\tnmod\t1:nmod:of\t_\n4\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n5\tbeing\tbe\tAUX\tVBG\tVerbForm=Ger\t6\taux:pass\t6:aux:pass\t_\n6\trun\trun\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n7\tby\tby\tADP\tIN\t_\t9\tcase\t9:case\t_\n8\t2\t2\tNUM\tCD\tNumForm=Digit|NumType=Card\t9\tnummod\t9:nummod\t_\n9\tofficials\tofficial\tNOUN\tNNS\tNumber=Plur\t6\tobl\t6:obl:by\t_\n10\tof\tof\tADP\tIN\t_\t12\tcase\t12:case\t_\n11\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t12\tdet\t12:det\t_\n12\tMinistry\tMinistry\tPROPN\tNNP\tNumber=Sing\t9\tnmod\t9:nmod:of\t_\n13\tof\tof\tADP\tIN\t_\t15\tcase\t15:case\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tInterior\tInterior\tPROPN\tNNP\tNumber=Sing\t12\tnmod\t12:nmod:of\tSpaceAfter=No\n16\t!\t!\tPUNCT\t.\t_\t6\tpunct\t6:punct\t_\n\n\"\"\".lstrip()\n\nDEV_DATA = \"\"\"\n1\tFrom\tfrom\tADP\tIN\t_\t3\tcase\t3:case\t_\n2\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n3\tAP\tAP\tPROPN\tNNP\tNumber=Sing\t4\tobl\t4:obl:from\t_\n4\tcomes\tcome\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n5\tthis\tthis\tDET\tDT\tNumber=Sing|PronType=Dem\t6\tdet\t6:det\t_\n6\tstory\tstory\tNOUN\tNN\tNumber=Sing\t4\tnsubj\t4:nsubj\t_\n7\t:\t:\tPUNCT\t:\t_\t4\tpunct\t4:punct\t_\n\n\"\"\".lstrip()\n\nclass TestLemmatizer:\n    @pytest.fixture(scope=\"class\")\n    def charlm_args(self):\n        charlm = choose_lemma_charlm(\"en\", \"test\", \"default\")\n        charlm_args = build_charlm_args(\"en\", charlm, model_dir=TEST_MODELS_DIR)\n        return charlm_args\n\n\n    def run_training(self, tmp_path, train_text, dev_text, extra_args=None):\n        \"\"\"\n        Run 
the training for a few iterations, load & return the model\n        \"\"\"\n        pred_file = str(tmp_path / \"pred.conllu\")\n\n        save_name = \"test_tagger.pt\"\n        save_file = str(tmp_path / save_name)\n\n        train_file = str(tmp_path / \"train.conllu\")\n        with open(train_file, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(train_text)\n\n        dev_file = str(tmp_path / \"dev.conllu\")\n        with open(dev_file, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(dev_text)\n\n        args = [\"--train_file\", train_file,\n                \"--eval_file\", dev_file,\n                \"--output_file\", pred_file,\n                \"--num_epoch\", \"2\",\n                \"--log_step\", \"10\",\n                \"--save_dir\", str(tmp_path),\n                \"--save_name\", save_name,\n                \"--shorthand\", \"en_test\"]\n        if extra_args is not None:\n            args = args + extra_args\n        lemmatizer.main(args)\n\n        assert os.path.exists(save_file)\n        saved_model = trainer.Trainer(model_file=save_file)\n        return saved_model\n\n    def test_basic_train(self, tmp_path):\n        \"\"\"\n        Simple test of a few 'epochs' of lemmatizer training\n        \"\"\"\n        self.run_training(tmp_path, TRAIN_DATA, DEV_DATA)\n\n    def test_charlm_train(self, tmp_path, charlm_args):\n        \"\"\"\n        Simple test of a few 'epochs' of lemmatizer training with a charlm\n        \"\"\"\n        saved_model = self.run_training(tmp_path, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)\n\n        # check that the charlm weights were not saved in the lemmatizer checkpoint\n        args = saved_model.args\n        save_name = os.path.join(args['save_dir'], args['save_name'])\n        checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)\n        assert not any(x.startswith(\"contextual_embedding\") for x in checkpoint['model'].keys())\n"
  },
  {
    "path": "stanza/tests/lemma/test_lowercase.py",
    "content": "import pytest\n\nfrom stanza.models.lemmatizer import all_lowercase\nfrom stanza.utils.conll import CoNLL\n\nLATIN_CONLLU = \"\"\"\n# sent_id = train-s1\n# text = unde et philosophus dicit felicitatem esse operationem perfectam.\n# reference = ittb-scg-s4203\n1\tunde\tunde\tADV\tO4\tAdvType=Loc|PronType=Rel\t4\tadvmod:lmod\t_\t_\n2\tet\tet\tCCONJ\tO4\t_\t3\tadvmod:emph\t_\t_\n3\tphilosophus\tphilosophus\tNOUN\tB1|grn1|casA|gen1\tCase=Nom|Gender=Masc|InflClass=IndEurO|Number=Sing\t4\tnsubj\t_\t_\n4\tdicit\tdico\tVERB\tN3|modA|tem1|gen6\tAspect=Imp|InflClass=LatX|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act\t0\troot\t_\tTraditionalMood=Indicativus|TraditionalTense=Praesens\n5\tfelicitatem\tfelicitas\tNOUN\tC1|grn1|casD|gen2\tCase=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing\t7\tnsubj\t_\t_\n6\tesse\tsum\tAUX\tN3|modH|tem1\tAspect=Imp|Tense=Pres|VerbForm=Inf\t7\tcop\t_\t_\n7\toperationem\toperatio\tNOUN\tC1|grn1|casD|gen2|vgr1\tCase=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing\t4\tccomp\t_\t_\n8\tperfectam\tperfectus\tADJ\tA1|grn1|casD|gen2\tCase=Acc|Gender=Fem|InflClass=IndEurA|Number=Sing\t7\tamod\t_\tSpaceAfter=No\n9\t.\t.\tPUNCT\tPunc\t_\t4\tpunct\t_\t_\n\n# sent_id = train-s2\n# text = perfectio autem operationis dependet ex quatuor.\n# reference = 
ittb-scg-s4204\n1\tperfectio\tperfectio\tNOUN\tC1|grn1|casA|gen2\tCase=Nom|Gender=Fem|InflClass=IndEurX|Number=Sing\t4\tnsubj\t_\t_\n2\tautem\tautem\tPART\tO4\t_\t4\tdiscourse\t_\t_\n3\toperationis\toperatio\tNOUN\tC1|grn1|casB|gen2|vgr1\tCase=Gen|Gender=Fem|InflClass=IndEurX|Number=Sing\t1\tnmod\t_\t_\n4\tdependet\tdependeo\tVERB\tK3|modA|tem1|gen6\tAspect=Imp|InflClass=LatE|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act\t0\troot\t_\tTraditionalMood=Indicativus|TraditionalTense=Praesens\n5\tex\tex\tADP\tS4|vgr2\t_\t6\tcase\t_\t_\n6\tquatuor\tquattuor\tNUM\tG1|gen3|vgr1\tNumForm=Word|NumType=Card\t4\tobl:arg\t_\tSpaceAfter=No\n7\t.\t.\tPUNCT\tPunc\t_\t4\tpunct\t_\t_\n\"\"\".lstrip()\n\nENG_CONLLU = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007\n# text = You wonder if he was manipulating the market with his bombing targets.\n1\tYou\tyou\tPRON\tPRP\tCase=Nom|Person=2|PronType=Prs\t2\tnsubj\t2:nsubj\t_\n2\twonder\twonder\tVERB\tVBP\tMood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n3\tif\tif\tSCONJ\tIN\t_\t6\tmark\t6:mark\t_\n4\the\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t6\tnsubj\t6:nsubj\t_\n5\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n6\tmanipulating\tmanipulate\tVERB\tVBG\tTense=Pres|VerbForm=Part\t2\tccomp\t2:ccomp\t_\n7\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t8\tdet\t8:det\t_\n8\tmarket\tmarket\tNOUN\tNN\tNumber=Sing\t6\tobj\t6:obj\t_\n9\twith\twith\tADP\tIN\t_\t12\tcase\t12:case\t_\n10\this\this\tPRON\tPRP$\tCase=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t12\tnmod:poss\t12:nmod:poss\t_\n11\tbombing\tbombing\tNOUN\tNN\tNumber=Sing\t12\tcompound\t12:compound\t_\n12\ttargets\ttarget\tNOUN\tNNS\tNumber=Plur\t6\tobl\t6:obl:with\tSpaceAfter=No\n13\t.\t.\tPUNCT\t.\t_\t2\tpunct\t2:punct\t_\n\"\"\".lstrip()\n\n\ndef test_all_lowercase():\n    doc = 
CoNLL.conll2doc(input_str=LATIN_CONLLU)\n    assert all_lowercase(doc)\n\ndef test_not_all_lowercase():\n    doc = CoNLL.conll2doc(input_str=ENG_CONLLU)\n    assert not all_lowercase(doc)\n"
  },
  {
    "path": "stanza/tests/lemma_classifier/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/lemma_classifier/test_data_preparation.py",
    "content": "import os\n\nimport pytest\n\nimport stanza.models.lemma_classifier.utils as utils\nimport stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nEWT_ONE_SENTENCE = \"\"\"\n# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002\n# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002\n# text = Here's a Miami Herald interview\n1-2\tHere's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tHere\there\tADV\tRB\tPronType=Dem\t0\troot\t0:root\t_\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t1\tcop\t1:cop\t_\n3\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t6\tdet\t6:det\t_\n4\tMiami\tMiami\tPROPN\tNNP\tNumber=Sing\t5\tcompound\t5:compound\t_\n5\tHerald\tHerald\tPROPN\tNNP\tNumber=Sing\t6\tcompound\t6:compound\t_\n6\tinterview\tinterview\tNOUN\tNN\tNumber=Sing\t1\tnsubj\t1:nsubj\t_\n\"\"\".lstrip()\n\n\nEWT_TRAIN_SENTENCES = \"\"\"\n# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002\n# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002\n# text = Here's a Miami Herald interview\n1-2\tHere's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tHere\there\tADV\tRB\tPronType=Dem\t0\troot\t0:root\t_\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t1\tcop\t1:cop\t_\n3\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t6\tdet\t6:det\t_\n4\tMiami\tMiami\tPROPN\tNNP\tNumber=Sing\t5\tcompound\t5:compound\t_\n5\tHerald\tHerald\tPROPN\tNNP\tNumber=Sing\t6\tcompound\t6:compound\t_\n6\tinterview\tinterview\tNOUN\tNN\tNumber=Sing\t1\tnsubj\t1:nsubj\t_\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0027\n# text = But Posada's nearly 80 years 
old\n1\tBut\tbut\tCCONJ\tCC\t_\t7\tcc\t7:cc\t_\n2-3\tPosada's\t_\t_\t_\t_\t_\t_\t_\t_\n2\tPosada\tPosada\tPROPN\tNNP\tNumber=Sing\t7\tnsubj\t7:nsubj\t_\n3\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t7\tcop\t7:cop\t_\n4\tnearly\tnearly\tADV\tRB\t_\t5\tadvmod\t5:advmod\t_\n5\t80\t80\tNUM\tCD\tNumForm=Digit|NumType=Card\t6\tnummod\t6:nummod\t_\n6\tyears\tyear\tNOUN\tNNS\tNumber=Plur\t7\tobl:npmod\t7:obl:npmod\t_\n7\told\told\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\tSpaceAfter=No\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0067\n# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0011\n# text = Now that's a post I can relate to.\n1\tNow\tnow\tADV\tRB\t_\t5\tadvmod\t5:advmod\t_\n2-3\tthat's\t_\t_\t_\t_\t_\t_\t_\t_\n2\tthat\tthat\tPRON\tDT\tNumber=Sing|PronType=Dem\t5\tnsubj\t5:nsubj\t_\n3\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t5\tcop\t5:cop\t_\n4\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t5\tdet\t5:det\t_\n5\tpost\tpost\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\t_\n6\tI\tI\tPRON\tPRP\tCase=Nom|Number=Sing|Person=1|PronType=Prs\t8\tnsubj\t8:nsubj\t_\n7\tcan\tcan\tAUX\tMD\tVerbForm=Fin\t8\taux\t8:aux\t_\n8\trelate\trelate\tVERB\tVB\tVerbForm=Inf\t5\tacl:relcl\t5:acl:relcl\t_\n9\tto\tto\tADP\tIN\t_\t8\tobl\t8:obl\tSpaceAfter=No\n10\t.\t.\tPUNCT\t.\t_\t5\tpunct\t5:punct\t_\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0073\n# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0012\n# text = hey that's a great 
blog\n1\they\they\tINTJ\tUH\t_\t6\tdiscourse\t6:discourse\t_\n2-3\tthat's\t_\t_\t_\t_\t_\t_\t_\t_\n2\tthat\tthat\tPRON\tDT\tNumber=Sing|PronType=Dem\t6\tnsubj\t6:nsubj\t_\n3\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t6\tcop\t6:cop\t_\n4\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t6\tdet\t6:det\t_\n5\tgreat\tgreat\tADJ\tJJ\tDegree=Pos\t6\tamod\t6:amod\t_\n6\tblog\tblog\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0089\n# text = And It's Not Hard To Do\n1\tAnd\tand\tCCONJ\tCC\t_\t5\tcc\t5:cc\t_\n2-3\tIt's\t_\t_\t_\t_\t_\t_\t_\t_\n2\tIt\tit\tPRON\tPRP\tCase=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs\t5\texpl\t5:expl\t_\n3\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t5\tcop\t5:cop\t_\n4\tNot\tnot\tPART\tRB\t_\t5\tadvmod\t5:advmod\t_\n5\tHard\thard\tADJ\tJJ\tDegree=Pos\t0\troot\t0:root\t_\n6\tTo\tto\tPART\tTO\t_\t7\tmark\t7:mark\t_\n7\tDo\tdo\tVERB\tVB\tVerbForm=Inf\t5\tcsubj\t5:csubj\tSpaceAfter=No\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0029\n# text = Meanwhile, a decision's been reached\n1\tMeanwhile\tmeanwhile\tADV\tRB\t_\t7\tadvmod\t7:advmod\tSpaceAfter=No\n2\t,\t,\tPUNCT\t,\t_\t1\tpunct\t1:punct\t_\n3\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t4\tdet\t4:det\t_\n4-5\tdecision's\t_\t_\t_\t_\t_\t_\t_\t_\n4\tdecision\tdecision\tNOUN\tNN\tNumber=Sing\t7\tnsubj:pass\t7:nsubj:pass\t_\n5\t's\thave\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t7\taux\t7:aux\t_\n6\tbeen\tbe\tAUX\tVBN\tTense=Past|VerbForm=Part\t7\taux:pass\t7:aux:pass\t_\n7\treached\treach\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0138\n# text = It's become a guardian of 
morality\n1-2\tIt's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tIt\tit\tPRON\tPRP\tCase=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs\t3\tnsubj\t3:nsubj|5:nsubj:xsubj\t_\n2\t's\thave\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\taux\t3:aux\t_\n3\tbecome\tbecome\tVERB\tVBN\tTense=Past|VerbForm=Part\t0\troot\t0:root\t_\n4\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t5\tdet\t5:det\t_\n5\tguardian\tguardian\tNOUN\tNN\tNumber=Sing\t3\txcomp\t3:xcomp\t_\n6\tof\tof\tADP\tIN\t_\t7\tcase\t7:case\t_\n7\tmorality\tmorality\tNOUN\tNN\tNumber=Sing\t5\tnmod\t5:nmod:of\t_\n\n# sent_id = email-enronsent15_01-0018\n# text = It's got its own bathroom and tv\n1-2\tIt's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tIt\tit\tPRON\tPRP\tCase=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs\t3\tnsubj\t3:nsubj|13:nsubj\t_\n2\t's\thave\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\taux\t3:aux\t_\n3\tgot\tget\tVERB\tVBN\tTense=Past|VerbForm=Part\t0\troot\t0:root\t_\n4\tits\tits\tPRON\tPRP$\tCase=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t6\tnmod:poss\t6:nmod:poss\t_\n5\town\town\tADJ\tJJ\tDegree=Pos\t6\tamod\t6:amod\t_\n6\tbathroom\tbathroom\tNOUN\tNN\tNumber=Sing\t3\tobj\t3:obj\t_\n7\tand\tand\tCCONJ\tCC\t_\t8\tcc\t8:cc\t_\n8\ttv\tTV\tNOUN\tNN\tNumber=Sing\t6\tconj\t3:obj|6:conj:and\tSpaceAfter=No\n\n# sent_id = newsgroup-groups.google.com_alt.animals.cat_01ff709c4bf2c60c_ENG_20040418_040100-0022\n# text = It's also got the website\n1-2\tIt's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tIt\tit\tPRON\tPRP\tCase=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs\t4\tnsubj\t4:nsubj\t_\n2\t's\thave\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\taux\t4:aux\t_\n3\talso\talso\tADV\tRB\t_\t4\tadvmod\t4:advmod\t_\n4\tgot\tget\tVERB\tVBN\tTense=Past|VerbForm=Part\t0\troot\t0:root\t_\n5\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t6\tdet\t6:det\t_\n6\twebsite\twebsite\tNOUN\tNN\tNumber=Sing\t4\tobj\t4:obj|12:obl\t_\n\"\"\".lstrip()\n\n\n# from the train set, 
actually\nEWT_DEV_SENTENCES = \"\"\"\n# sent_id = answers-20111108104724AAuBUR7_ans-0044\n# text = He's only exhibited weight loss and some muscle atrophy\n1-2\tHe's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tHe\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t4\tnsubj\t4:nsubj\t_\n2\t's\thave\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\taux\t4:aux\t_\n3\tonly\tonly\tADV\tRB\t_\t4\tadvmod\t4:advmod\t_\n4\texhibited\texhibit\tVERB\tVBN\tTense=Past|VerbForm=Part\t0\troot\t0:root\t_\n5\tweight\tweight\tNOUN\tNN\tNumber=Sing\t6\tcompound\t6:compound\t_\n6\tloss\tloss\tNOUN\tNN\tNumber=Sing\t4\tobj\t4:obj\t_\n7\tand\tand\tCCONJ\tCC\t_\t10\tcc\t10:cc\t_\n8\tsome\tsome\tDET\tDT\tPronType=Ind\t10\tdet\t10:det\t_\n9\tmuscle\tmuscle\tNOUN\tNN\tNumber=Sing\t10\tcompound\t10:compound\t_\n10\tatrophy\tatrophy\tNOUN\tNN\tNumber=Sing\t6\tconj\t4:obj|6:conj:and\tSpaceAfter=No\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0097\n# text = It's a good thing too.\n1-2\tIt's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tIt\tit\tPRON\tPRP\tCase=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs\t5\tnsubj\t5:nsubj\t_\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t5\tcop\t5:cop\t_\n3\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t5\tdet\t5:det\t_\n4\tgood\tgood\tADJ\tJJ\tDegree=Pos\t5\tamod\t5:amod\t_\n5\tthing\tthing\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\t_\n6\ttoo\ttoo\tADV\tRB\t_\t5\tadvmod\t5:advmod\tSpaceAfter=No\n7\t.\t.\tPUNCT\t.\t_\t5\tpunct\t5:punct\t_\n\"\"\".lstrip()\n\n# from the train set, actually\nEWT_TEST_SENTENCES = \"\"\"\n# sent_id = reviews-162422-0015\n# text = He said he's had a long and bad 
day.\n1\tHe\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t2\tnsubj\t2:nsubj\t_\n2\tsaid\tsay\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n3-4\the's\t_\t_\t_\t_\t_\t_\t_\t_\n3\the\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t5\tnsubj\t5:nsubj\t_\n4\t's\thave\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t5\taux\t5:aux\t_\n5\thad\thave\tVERB\tVBN\tTense=Past|VerbForm=Part\t2\tccomp\t2:ccomp\t_\n6\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t10\tdet\t10:det\t_\n7\tlong\tlong\tADJ\tJJ\tDegree=Pos\t10\tamod\t10:amod\t_\n8\tand\tand\tCCONJ\tCC\t_\t9\tcc\t9:cc\t_\n9\tbad\tbad\tADJ\tJJ\tDegree=Pos\t7\tconj\t7:conj:and|10:amod\t_\n10\tday\tday\tNOUN\tNN\tNumber=Sing\t5\tobj\t5:obj\tSpaceAfter=No\n11\t.\t.\tPUNCT\t.\t_\t2\tpunct\t2:punct\t_\n\n# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0100\n# text = What's a few dead soldiers\n1-2\tWhat's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tWhat\twhat\tPRON\tWP\tPronType=Int\t6\tnsubj\t6:nsubj\t_\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t6\tcop\t6:cop\t_\n3\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t6\tdet\t6:det\t_\n4\tfew\tfew\tADJ\tJJ\tDegree=Pos\t6\tamod\t6:amod\t_\n5\tdead\tdead\tADJ\tJJ\tDegree=Pos\t6\tamod\t6:amod\t_\n6\tsoldiers\tsoldier\tNOUN\tNNS\tNumber=Plur\t0\troot\t0:root\t_\n\"\"\"\n\ndef write_test_dataset(tmp_path, texts, datasets):\n    ud_path = tmp_path / \"ud\"\n    input_path = ud_path / \"UD_English-EWT\"\n    output_path = tmp_path / \"data\" / \"lemma_classifier\"\n\n    os.makedirs(input_path, exist_ok=True)\n\n    for text, dataset in zip(texts, datasets):\n        sample_file = input_path / (\"en_ewt-ud-%s.conllu\" % dataset)\n        with open(sample_file, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(text)\n\n    paths = {\"UDBASE\": ud_path,\n             \"LEMMA_CLASSIFIER_DATA_DIR\": output_path}\n\n    return 
paths\n\ndef write_english_test_dataset(tmp_path):\n    texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES)\n    datasets = prepare_lemma_classifier.SECTIONS\n    return write_test_dataset(tmp_path, texts, datasets)\n\ndef convert_english_dataset(tmp_path):\n    paths = write_english_test_dataset(tmp_path)\n    converted_files = prepare_lemma_classifier.process_treebank(paths, \"en_ewt\", \"'s\", \"AUX\", \"be|have\")\n    assert len(converted_files) == 3\n\n    return converted_files\n\ndef test_convert_one_sentence(tmp_path):\n    texts = [EWT_ONE_SENTENCE]\n    datasets = [\"train\"]\n    paths = write_test_dataset(tmp_path, texts, datasets)\n\n    converted_files = prepare_lemma_classifier.process_treebank(paths, \"en_ewt\", \"'s\", \"AUX\", \"be|have\", [\"train\"])\n    assert len(converted_files) == 1\n\n    dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)\n\n    assert len(dataset) == 1\n    assert dataset.label_decoder == {'be': 0}\n    id_to_upos = {y: x for x, y in dataset.upos_to_id.items()}\n\n    for text_batches, _, upos_batches, _ in dataset:\n        assert text_batches == [['Here', \"'s\", 'a', 'Miami', 'Herald', 'interview']]\n        upos = [id_to_upos[x] for x in upos_batches[0]]\n        assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']\n\ndef test_convert_dataset(tmp_path):\n    converted_files = convert_english_dataset(tmp_path)\n\n    dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)\n\n    assert len(dataset) == 1\n    label_decoder = dataset.label_decoder\n    assert len(label_decoder) == 2\n    assert \"be\" in label_decoder\n    assert \"have\" in label_decoder\n    for text_batches, _, _, _ in dataset:\n        assert len(text_batches) == 9\n\n    dataset = utils.Dataset(converted_files[1], get_counts=True, batch_size=10, shuffle=False)\n    assert len(dataset) == 1\n    for text_batches, _, _, _ in dataset:\n        assert 
len(text_batches) == 2\n\n    dataset = utils.Dataset(converted_files[2], get_counts=True, batch_size=10, shuffle=False)\n    assert len(dataset) == 1\n    for text_batches, _, _, _ in dataset:\n        assert len(text_batches) == 2\n\n"
  },
  {
    "path": "stanza/tests/lemma_classifier/test_training.py",
    "content": "import glob\nimport os\n\nimport pytest\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nfrom stanza.models.lemma_classifier import train_lstm_model\nfrom stanza.models.lemma_classifier import train_transformer_model\nfrom stanza.models.lemma_classifier.base_model import LemmaClassifier\nfrom stanza.models.lemma_classifier.evaluate_models import evaluate_model\n\nfrom stanza.tests import TEST_WORKING_DIR\nfrom stanza.tests.lemma_classifier.test_data_preparation import convert_english_dataset\n\n@pytest.fixture(scope=\"module\")\ndef pretrain_file():\n    return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\ndef test_train_lstm(tmp_path, pretrain_file):\n    converted_files = convert_english_dataset(tmp_path)\n\n    save_name = str(tmp_path / 'lemma.pt')\n\n    train_file = converted_files[0]\n    eval_file = converted_files[1]\n    train_args = ['--wordvec_pretrain_file', pretrain_file,\n                  '--save_name', save_name,\n                  '--train_file', train_file,\n                  '--eval_file', eval_file]\n    trainer = train_lstm_model.main(train_args)\n\n    evaluate_model(trainer.model, eval_file)\n    # test that loading the model works\n    model = LemmaClassifier.load(save_name, None)\n\ndef test_train_transformer(tmp_path, pretrain_file):\n    converted_files = convert_english_dataset(tmp_path)\n\n    save_name = str(tmp_path / 'lemma.pt')\n\n    train_file = converted_files[0]\n    eval_file = converted_files[1]\n    train_args = ['--bert_model', 'hf-internal-testing/tiny-bert',\n                  '--save_name', save_name,\n                  '--train_file', train_file,\n                  '--eval_file', eval_file]\n    trainer = train_transformer_model.main(train_args)\n\n    evaluate_model(trainer.model, eval_file)\n\n    # test that loading the model works\n    model = LemmaClassifier.load(save_name, None)\n"
  },
  {
    "path": "stanza/tests/morphseg/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/morphseg/conftest.py",
    "content": "\"\"\"\nShared pytest fixtures and configuration\n\"\"\"\n\nimport pytest\nfrom morphseg import MorphemeSegmenter\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n@pytest.fixture(scope=\"session\")\ndef english_segmenter():\n    \"\"\"\n    Load English segmenter once for the entire test session\n    \"\"\"\n    return MorphemeSegmenter('en')\n\n\n@pytest.fixture(scope=\"session\")\ndef all_segmenters():\n    \"\"\"\n    Load all supported language segmenters\n    \"\"\"\n    segmenters = {}\n    for lang in MorphemeSegmenter.PRETRAINED_MODEL_LANGS:\n        segmenters[lang] = MorphemeSegmenter(lang)\n    return segmenters\n\n\ndef pytest_configure(config):\n    \"\"\"\n    Custom pytest configuration\n    \"\"\"\n    config.addinivalue_line(\n        \"markers\", \"slow: marks tests as slow (deselect with '-m \\\"not slow\\\"')\"\n    )\n    config.addinivalue_line(\n        \"markers\", \"multilingual: marks tests that test multiple languages\"\n    )\n"
  },
  {
    "path": "stanza/tests/morphseg/test_integration.py",
    "content": "\"\"\"\nIntegration tests for morphseg\n\"\"\"\n\nimport pytest\nfrom morphseg import MorphemeSegmenter\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nclass TestIntegration:\n\n    def test_full_pipeline(self):\n        \"\"\"Test complete segmentation pipeline\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        text = \"According to all known laws of aviation, there is no way a bee should be able to fly.\"\n        result = segmenter.segment(text, output_string=False)\n\n        # Should segment multiple words\n        assert len(result) > 10\n\n        # Each word should have at least one morpheme\n        for word_morphemes in result:\n            assert len(word_morphemes) >= 1\n\n    def test_consistency_across_modes(self):\n        \"\"\"Test that list and string output modes are consistent\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        words = ['running', 'dogs', 'aviation']\n\n        for word in words:\n            list_result = segmenter.segment(word, output_string=False)\n            string_result = segmenter.segment(word, output_string=True, delimiter=' @@')\n\n            # String result should be reconstructable from list result\n            expected_string = ' @@'.join(list_result[0])\n            assert string_result == expected_string, \\\n                f\"List and string outputs don't match for '{word}'\"\n\n    def test_unicode_handling(self):\n        \"\"\"Test handling of unicode characters\"\"\"\n        segmenter = MorphemeSegmenter('fr')\n\n        text = \"café résumé\"\n        result = segmenter.segment(text, output_string=False)\n\n        assert isinstance(result, list)\n        assert len(result) >= 1\n\n    def test_mixed_case(self):\n        \"\"\"Test handling of mixed case input\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        # Should normalize to lowercase\n        result1 = segmenter.segment('Running', output_string=False)\n        result2 = 
segmenter.segment('RUNNING', output_string=False)\n        result3 = segmenter.segment('running', output_string=False)\n\n        # All should produce the same result\n        assert result1 == result2 == result3\n"
  },
  {
    "path": "stanza/tests/morphseg/test_morpheme_segmenter.py",
    "content": "\"\"\"\nTests for MorphemeSegmenter class\n\"\"\"\n\nimport pytest\nfrom morphseg import MorphemeSegmenter\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nclass TestMorphemeSegmenter:\n\n    @pytest.fixture(scope=\"class\")\n    def english_segmenter(self):\n        \"\"\"Load English model once for all tests\"\"\"\n        return MorphemeSegmenter('en')\n\n    def test_basic_segmentation(self, english_segmenter):\n        \"\"\"Test basic morpheme segmentation\"\"\"\n        result = english_segmenter.segment('running', output_string=False)\n\n        assert isinstance(result, list)\n        assert len(result) == 1\n        assert isinstance(result[0], list)\n        assert len(result[0]) >= 1\n\n    def test_multiple_words(self, english_segmenter):\n        \"\"\"Test segmentation of multiple words\"\"\"\n        result = english_segmenter.segment('running quickly', output_string=False)\n\n        assert isinstance(result, list)\n        assert len(result) == 2\n        for segmentation in result:\n            assert isinstance(segmentation, list)\n            assert len(segmentation) >= 1\n\n    def test_known_segmentations(self, english_segmenter):\n        \"\"\"Test known morpheme segmentations\"\"\"\n        test_cases = {\n            'dogs': ['dog', 's'],\n            'aviation': ['aviate', 'ion'],\n            'known': ['know', 'n'],\n        }\n\n        for word, expected in test_cases.items():\n            result = english_segmenter.segment(word, output_string=False)\n            assert result[0] == expected, f\"Expected {expected}, got {result[0]} for '{word}'\"\n\n    def test_output_string_mode(self, english_segmenter):\n        \"\"\"Test string output mode\"\"\"\n        result = english_segmenter.segment('running quickly', output_string=True)\n\n        assert isinstance(result, str)\n        assert ' @@' in result  # Default delimiter\n\n    def test_custom_delimiter(self, english_segmenter):\n        \"\"\"Test 
custom delimiter in output\"\"\"\n        result = english_segmenter.segment('running', output_string=True, delimiter='-')\n\n        assert isinstance(result, str)\n        assert '-' in result or result == 'running'  # May be unsegmented\n\n    def test_empty_input(self, english_segmenter):\n        \"\"\"Test handling of empty input\"\"\"\n        result = english_segmenter.segment('', output_string=False)\n        assert result == []\n\n        result = english_segmenter.segment('', output_string=True)\n        assert result == \"\"\n\n    def test_single_character(self, english_segmenter):\n        \"\"\"Test single character input\"\"\"\n        result = english_segmenter.segment('a', output_string=False)\n\n        assert isinstance(result, list)\n        assert len(result) == 1\n        assert result[0] == ['a']\n\n    def test_punctuation(self, english_segmenter):\n        \"\"\"Test handling of punctuation\"\"\"\n        result = english_segmenter.segment('Hello, world!', output_string=False)\n\n        assert isinstance(result, list)\n        # Should segment only words, not punctuation\n        assert len(result) > 0\n\n\nclass TestDeterminism:\n    \"\"\"\n    Tests to ensure predictions are deterministic\n    \"\"\"\n\n    def test_deterministic_predictions(self):\n        \"\"\"Test that same input produces same output consistently\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        test_words = ['running', 'dogs', 'quickly', 'aviation']\n\n        for word in test_words:\n            results = []\n            for _ in range(5):\n                result = segmenter.segment(word, output_string=False)\n                results.append(result)\n\n            # All results should be identical\n            for i in range(1, len(results)):\n                assert results[i] == results[0], \\\n                    f\"Non-deterministic results for '{word}': {results[0]} vs {results[i]}\"\n\n    def test_deterministic_batch(self):\n        \"\"\"Test 
determinism with batch processing\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        text = \"The dogs are running quickly through the fields.\"\n\n        results = []\n        for _ in range(3):\n            result = segmenter.segment(text, output_string=False)\n            results.append(result)\n\n        # All results should be identical\n        for i in range(1, len(results)):\n            assert results[i] == results[0], \\\n                f\"Non-deterministic batch results: {results[0]} vs {results[i]}\"\n\n\nclass TestMultilingual:\n\n    @pytest.mark.parametrize(\"lang\", ['cs', 'en', 'es', 'fr', 'hu', 'it', 'la', 'ru'])\n    def test_language_loading(self, lang):\n        \"\"\"Test that all supported languages can be loaded\"\"\"\n        segmenter = MorphemeSegmenter(lang)\n        assert segmenter.lang == lang\n        assert segmenter.sequence_labeller is not None\n\n    @pytest.mark.parametrize(\"lang,word\", [\n        ('en', 'running'),\n        ('es', 'corriendo'),\n        ('fr', 'rapidement'),\n        ('ru', 'бегущий'),  # Russian instead of German\n    ])\n    def test_multilingual_segmentation(self, lang, word):\n        \"\"\"Test segmentation across languages\"\"\"\n        if lang not in MorphemeSegmenter.PRETRAINED_MODEL_LANGS:\n            pytest.skip(f\"Language {lang} not supported\")\n\n        segmenter = MorphemeSegmenter(lang)\n        result = segmenter.segment(word, output_string=False)\n\n        assert isinstance(result, list)\n        assert len(result) >= 1\n\n\nclass TestErrorHandling:\n\n    def test_invalid_language(self):\n        \"\"\"Test handling of invalid language code\"\"\"\n        with pytest.warns(UserWarning):\n            segmenter = MorphemeSegmenter('invalid_lang')\n            assert segmenter.sequence_labeller is None\n\n    def test_invalid_input_type(self):\n        \"\"\"Test handling of invalid input types\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        with 
pytest.raises(ValueError, match=\"Input sequence must be a string\"):\n            segmenter.segment(123)\n\n        with pytest.raises(ValueError, match=\"Input sequence must be a string\"):\n            segmenter.segment(['not', 'a', 'string'])\n\n    def test_invalid_output_string_type(self):\n        \"\"\"Test handling of invalid output_string parameter\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        with pytest.raises(ValueError, match=\"output_string must be a boolean\"):\n            segmenter.segment('test', output_string='yes')\n\n    def test_invalid_delimiter_type(self):\n        \"\"\"Test handling of invalid delimiter parameter\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        with pytest.raises(ValueError, match=\"Delimiter must be a string\"):\n            segmenter.segment('test', delimiter=123)\n\n    def test_model_not_trained(self):\n        \"\"\"Test error when using untrained model\"\"\"\n        segmenter = MorphemeSegmenter('en')\n        segmenter.sequence_labeller = None\n\n        with pytest.raises(RuntimeError, match=\"Model not trained\"):\n            segmenter.segment('test')\n\n\nclass TestModelState:\n    \"\"\"\n    Tests to ensure model is in correct state\n    \"\"\"\n\n    def test_model_in_eval_mode(self):\n        \"\"\"Test that loaded model is in eval mode\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        # Check that model is in eval mode\n        assert not segmenter.sequence_labeller.model.model.training, \\\n            \"Model should be in eval mode after loading\"\n\n    def test_model_stays_in_eval_mode(self):\n        \"\"\"Test that model stays in eval mode after predictions\"\"\"\n        segmenter = MorphemeSegmenter('en')\n\n        # Make several predictions\n        for _ in range(3):\n            segmenter.segment('running', output_string=False)\n\n        # Model should still be in eval mode\n        assert not segmenter.sequence_labeller.model.model.training, \\\n      
      \"Model should remain in eval mode after predictions\"\n"
  },
  {
    "path": "stanza/tests/morphseg/test_stanza_integration.py",
    "content": "\"\"\"\nIntegration tests for Stanza MorphSeg Processor\nTests the morpheme segmentation processor within the Stanza pipeline\n\"\"\"\n\nimport pytest\nimport stanza\nfrom stanza.models.common.doc import Document\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nclass TestMorphSegProcessor:\n    \"\"\"Tests for the MorphSeg processor in Stanza pipeline\"\"\"\n\n    @pytest.fixture(scope=\"class\")\n    def en_pipeline(self):\n        \"\"\"Create English pipeline with morphseg processor\"\"\"\n        return stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            model_dir=TEST_MODELS_DIR,\n            download_method=None\n        )\n\n    def test_processor_loads(self, en_pipeline):\n        \"\"\"Test that morphseg processor loads successfully\"\"\"\n        assert 'morphseg' in en_pipeline.processors\n        assert en_pipeline.processors['morphseg'] is not None\n\n    def test_basic_segmentation(self, en_pipeline):\n        \"\"\"Test basic morpheme segmentation through pipeline\"\"\"\n        doc = en_pipeline(\"running\")\n\n        assert len(doc.sentences) == 1\n        assert len(doc.sentences[0].words) == 1\n\n        word = doc.sentences[0].words[0]\n        assert hasattr(word, 'morphemes')\n        assert isinstance(word.morphemes, list)\n        assert len(word.morphemes) >= 1\n\n    def test_known_segmentations(self, en_pipeline):\n        \"\"\"Test known morpheme segmentations\"\"\"\n        # Note: These are actual segmentations from the en2 model\n        # Some words may be unsegmented depending on the model\n        test_cases = {\n            'dogs': ['dog', 's'],\n            'aviation': ['aviate', 'ion'],\n            'known': ['know', 'n'],\n        }\n\n        for word_text, expected in test_cases.items():\n            doc = en_pipeline(word_text)\n            word = doc.sentences[0].words[0]\n            assert word.morphemes 
== expected, \\\n                f\"Expected {expected}, got {word.morphemes} for '{word_text}'\"\n\n    def test_segmentation_consistency(self, en_pipeline):\n        \"\"\"Test that segmentation is consistent and produces valid output\"\"\"\n        words = ['running', 'quickly', 'walked', 'playing']\n\n        for word_text in words:\n            doc = en_pipeline(word_text)\n            word = doc.sentences[0].words[0]\n\n            # Should have morphemes attribute\n            assert hasattr(word, 'morphemes')\n            assert isinstance(word.morphemes, list)\n            assert len(word.morphemes) >= 1\n\n            # All morphemes should be strings\n            for morpheme in word.morphemes:\n                assert isinstance(morpheme, str)\n                assert len(morpheme) > 0\n\n    def test_multiple_words(self, en_pipeline):\n        \"\"\"Test segmentation of multiple words in a sentence\"\"\"\n        doc = en_pipeline(\"The dogs are running quickly.\")\n\n        # Check that all words have morphemes attribute\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n                assert isinstance(word.morphemes, list)\n                assert len(word.morphemes) >= 1\n\n    def test_punctuation_handling(self, en_pipeline):\n        \"\"\"Test that punctuation is handled correctly\"\"\"\n        doc = en_pipeline(\"Hello, world!\")\n\n        # All tokens should have morphemes, including punctuation\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n                # Punctuation should be unsegmented\n                if word.text in [',', '!', '.']:\n                    assert word.morphemes == [word.text]\n\n    def test_long_text(self, en_pipeline):\n        \"\"\"Test processing of longer text\"\"\"\n        text = \"According to all known laws of aviation, there is no way a bee 
should be able to fly.\"\n        doc = en_pipeline(text)\n\n        # Should have multiple sentences or one long sentence\n        assert len(doc.sentences) >= 1\n\n        # Count words with morpheme segmentation\n        total_words = sum(len(sent.words) for sent in doc.sentences)\n        assert total_words > 10\n\n        # All words should have morphemes\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n\n    def test_empty_input(self, en_pipeline):\n        \"\"\"Test handling of empty input\"\"\"\n        doc = en_pipeline(\"\")\n        assert len(doc.sentences) == 0\n\n    def test_single_character(self, en_pipeline):\n        \"\"\"Test single character input\"\"\"\n        doc = en_pipeline(\"I\")\n\n        assert len(doc.sentences) == 1\n        word = doc.sentences[0].words[0]\n        assert word.morphemes == ['i']  # Normalized to lowercase\n\n    def test_morphemes_attribute_persistence(self, en_pipeline):\n        \"\"\"Test that morphemes attribute persists through pipeline\"\"\"\n        doc = en_pipeline(\"running quickly\")\n\n        # Store a copy of each word's morphemes\n        morphemes_list = []\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                morphemes_list.append(list(word.morphemes))\n\n        # Access again and check that the stored values are unchanged\n        all_words = [word for sentence in doc.sentences for word in sentence.words]\n        assert len(all_words) == len(morphemes_list)\n        for word, stored in zip(all_words, morphemes_list):\n            assert hasattr(word, 'morphemes')\n            assert word.morphemes is not None\n            assert word.morphemes == stored\n\n\nclass TestMultilingualMorphSeg:\n    \"\"\"Test morpheme segmentation across different languages\"\"\"\n\n    @pytest.mark.parametrize(\"lang,text,expected_word\", [\n        ('en', 'running', 'running'),\n        ('es', 'corriendo', 'corriendo'),\n        ('fr', 'rapidement', 'rapidement'),\n        ('cs', 'běžící', 'běžící'),\n        ('it', 'correndo', 'correndo'),\n    ])\n    def 
test_multilingual_support(self, lang, text, expected_word):\n        \"\"\"Test that different languages can be processed\"\"\"\n        try:\n            nlp = stanza.Pipeline(\n                lang=lang,\n                processors='tokenize,morphseg',\n                download_method=None\n            )\n            doc = nlp(text)\n\n            assert len(doc.sentences) >= 1\n            assert len(doc.sentences[0].words) >= 1\n\n            word = doc.sentences[0].words[0]\n            assert hasattr(word, 'morphemes')\n            assert isinstance(word.morphemes, list)\n\n        except Exception as e:\n            pytest.skip(f\"Language {lang} not available: {e}\")\n\n\nclass TestMorphSegWithOtherProcessors:\n    \"\"\"Test morphseg processor in combination with other processors\"\"\"\n\n    def test_with_mwt(self):\n        \"\"\"Test morphseg with MWT processor\"\"\"\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,mwt,morphseg',\n            download_method=None\n        )\n\n        doc = nlp(\"The dogs are running.\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n\n    def test_with_pos(self):\n        \"\"\"Test morphseg with POS tagging\"\"\"\n        try:\n            nlp = stanza.Pipeline(\n                lang='en',\n                processors='tokenize,pos,morphseg',\n                download_method=None\n            )\n\n            doc = nlp(\"running quickly\")\n\n            for sentence in doc.sentences:\n                for word in sentence.words:\n                    # Should have both POS and morphemes\n                    assert hasattr(word, 'morphemes')\n                    assert hasattr(word, 'upos') or hasattr(word, 'xpos')\n\n        except Exception as e:\n            pytest.skip(f\"POS processor not available: {e}\")\n\n    def test_with_lemma(self):\n        \"\"\"Test morphseg with 
lemmatization\"\"\"\n        try:\n            nlp = stanza.Pipeline(\n                lang='en',\n                processors='tokenize,pos,lemma,morphseg',\n                download_method=None\n            )\n\n            doc = nlp(\"The dogs were running quickly.\")\n\n            for sentence in doc.sentences:\n                for word in sentence.words:\n                    # Should have both lemma and morphemes\n                    assert hasattr(word, 'morphemes')\n                    assert hasattr(word, 'lemma')\n\n        except Exception as e:\n            pytest.skip(f\"Lemma processor not available: {e}\")\n\n\nclass TestMorphSegDeterminism:\n    \"\"\"Test that morphseg processor produces deterministic results\"\"\"\n\n    def test_deterministic_results(self):\n        \"\"\"Test that same input produces same output\"\"\"\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n        text = \"running dogs aviation\"\n\n        results = []\n        for _ in range(3):\n            doc = nlp(text)\n            morphemes = [word.morphemes for sent in doc.sentences for word in sent.words]\n            results.append(morphemes)\n\n        # All results should be identical\n        for i in range(1, len(results)):\n            assert results[i] == results[0], \\\n                f\"Non-deterministic results: {results[0]} vs {results[i]}\"\n\n    def test_batch_determinism(self):\n        \"\"\"Test determinism with batch processing\"\"\"\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n        texts = [\n            \"The dogs are running.\",\n            \"Aviation is amazing.\",\n            \"Known facts are helpful.\"\n        ]\n\n        # Process multiple times\n        all_results = []\n        for _ in range(2):\n            batch_results = []\n            
for text in texts:\n                doc = nlp(text)\n                morphemes = [word.morphemes for sent in doc.sentences for word in sent.words]\n                batch_results.append(morphemes)\n            all_results.append(batch_results)\n\n        # Results should be identical\n        assert all_results[0] == all_results[1]\n\n\nclass TestMorphSegEdgeCases:\n    \"\"\"Test edge cases and special inputs\"\"\"\n\n    @pytest.fixture(scope=\"class\")\n    def en_pipeline(self):\n        \"\"\"Create English pipeline\"\"\"\n        return stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n    def test_numbers(self, en_pipeline):\n        \"\"\"Test handling of numbers\"\"\"\n        doc = en_pipeline(\"123 456\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n\n    def test_mixed_case(self, en_pipeline):\n        \"\"\"Test mixed case handling\"\"\"\n        # Should normalize to same result\n        doc1 = en_pipeline(\"Running\")\n        doc2 = en_pipeline(\"RUNNING\")\n        doc3 = en_pipeline(\"running\")\n\n        morphemes1 = doc1.sentences[0].words[0].morphemes\n        morphemes2 = doc2.sentences[0].words[0].morphemes\n        morphemes3 = doc3.sentences[0].words[0].morphemes\n\n        assert morphemes1 == morphemes2 == morphemes3\n\n    def test_unicode_characters(self, en_pipeline):\n        \"\"\"Test handling of unicode characters\"\"\"\n        doc = en_pipeline(\"café résumé\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n                assert isinstance(word.morphemes, list)\n\n    def test_special_characters(self, en_pipeline):\n        \"\"\"Test handling of special characters\"\"\"\n        doc = en_pipeline(\"test@example.com $100 50%\")\n\n        for sentence in doc.sentences:\n  
          for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n\n    def test_very_long_word(self, en_pipeline):\n        \"\"\"Test handling of very long words\"\"\"\n        long_word = \"antidisestablishmentarianism\"\n        doc = en_pipeline(long_word)\n\n        word = doc.sentences[0].words[0]\n        assert hasattr(word, 'morphemes')\n        assert len(word.morphemes) >= 1\n\n    def test_repeated_words(self, en_pipeline):\n        \"\"\"Test handling of repeated words\"\"\"\n        doc = en_pipeline(\"running running running\")\n\n        # All instances should have same segmentation\n        morphemes_list = [word.morphemes for word in doc.sentences[0].words]\n        assert morphemes_list[0] == morphemes_list[1] == morphemes_list[2]\n\n    def test_whitespace_handling(self, en_pipeline):\n        \"\"\"Test handling of various whitespace\"\"\"\n        doc = en_pipeline(\"word1    word2\\tword3\\nword4\")\n\n        # Should properly segment all words despite whitespace\n        word_count = sum(len(sent.words) for sent in doc.sentences)\n        assert word_count >= 4\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n\n\nclass TestMorphSegConfiguration:\n    \"\"\"Test different configurations of morphseg processor\"\"\"\n\n    def test_custom_model_path(self):\n        \"\"\"Test loading with custom model path configuration\"\"\"\n        # Test that the configuration accepts model_path parameter\n        # Using default behavior (no custom path)\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n        doc = nlp(\"testing\")\n        assert len(doc.sentences) > 0\n        assert hasattr(doc.sentences[0].words[0], 'morphemes')\n\n    def test_custom_model_path_with_file(self):\n        \"\"\"Test loading with an actual custom model file 
path\"\"\"\n        # This test would require a custom model file to exist\n        # Skip if no custom model is available\n        pytest.skip(\"Custom model path test requires a specific model file\")\n\n    def test_processor_requirements(self):\n        \"\"\"Test that morphseg requires tokenize\"\"\"\n        # MorphSeg requires TOKENIZE processor\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n        # Verify tokenize is present\n        assert 'tokenize' in nlp.processors or 'tokenize' in str(nlp.processors)\n\n\nclass TestMorphSegOutputFormat:\n    \"\"\"Test output format of morpheme segmentations\"\"\"\n\n    @pytest.fixture(scope=\"class\")\n    def en_pipeline(self):\n        \"\"\"Create English pipeline\"\"\"\n        return stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n    def test_morphemes_is_list(self, en_pipeline):\n        \"\"\"Test that morphemes attribute is always a list\"\"\"\n        doc = en_pipeline(\"The dogs are running quickly.\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert isinstance(word.morphemes, list)\n\n    def test_morphemes_are_strings(self, en_pipeline):\n        \"\"\"Test that all morphemes are strings\"\"\"\n        doc = en_pipeline(\"The dogs are running quickly.\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                for morpheme in word.morphemes:\n                    assert isinstance(morpheme, str)\n\n    def test_morphemes_non_empty(self, en_pipeline):\n        \"\"\"Test that morphemes list is never empty\"\"\"\n        doc = en_pipeline(\"The dogs are running quickly.\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert len(word.morphemes) >= 1\n\n    def 
test_unsegmented_words(self, en_pipeline):\n        \"\"\"Test that unsegmented words have single morpheme\"\"\"\n        # Words like 'the', 'is', 'a' typically don't segment\n        doc = en_pipeline(\"The dog is a pet.\")\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                # Even if unsegmented, should have the word itself as morpheme\n                if len(word.morphemes) == 1:\n                    # The single morpheme should match the normalized word\n                    assert isinstance(word.morphemes[0], str)\n\n\nclass TestMorphSegRepeatedly:\n    \"\"\"Test repeated processing of multiple documents\"\"\"\n\n    def test_sequential_document_processing(self):\n        \"\"\"Test processing multiple documents one after another\"\"\"\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n        texts = [\n            \"The dogs are running.\",\n            \"Aviation is fascinating.\",\n            \"Programming requires patience.\"\n        ]\n\n        for text in texts:\n            doc = nlp(text)\n            for sentence in doc.sentences:\n                for word in sentence.words:\n                    assert hasattr(word, 'morphemes')\n                    assert isinstance(word.morphemes, list)\n\n    def test_multi_sentence_document(self):\n        \"\"\"Test processing a document with multiple sentences (internal batching)\"\"\"\n        nlp = stanza.Pipeline(\n            lang='en',\n            processors='tokenize,morphseg',\n            download_method=None\n        )\n\n        doc = nlp(\"The dogs are running. Aviation is fascinating. Programming requires patience.\")\n\n        assert len(doc.sentences) == 3\n\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                assert hasattr(word, 'morphemes')\n                assert isinstance(word.morphemes, list)\n"
  },
  {
    "path": "stanza/tests/mwt/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/mwt/test_character_classifier.py",
    "content": "import os\nimport pytest\n\nfrom stanza.models import mwt_expander\nfrom stanza.models.mwt.character_classifier import CharacterClassifier\nfrom stanza.models.mwt.data import DataLoader\nfrom stanza.models.mwt.trainer import Trainer\nfrom stanza.utils.conll import CoNLL\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nENG_TRAIN = \"\"\"\n# text = Elena's motorcycle tour\n1-2\tElena's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tElena\tElena\tPROPN\tNNP\tNumber=Sing\t4\tnmod:poss\t4:nmod:poss\t_\n2\t's\t's\tPART\tPOS\t_\t1\tcase\t1:case\t_\n3\tmotorcycle\tmotorcycle\tNOUN\tNN\tNumber=Sing\t4\tcompound\t4:compound\t_\n4\ttour\ttour\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\t_\n\n\n# text = women's reproductive health\n1-2\twomen's\t_\t_\t_\t_\t_\t_\t_\t_\n1\twomen\twoman\tNOUN\tNNS\tNumber=Plur\t4\tnmod:poss\t4:nmod:poss\t_\n2\t's\t's\tPART\tPOS\t_\t1\tcase\t1:case\t_\n3\treproductive\treproductive\tADJ\tJJ\tDegree=Pos\t4\tamod\t4:amod\t_\n4\thealth\thealth\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n\n\n# text = The Chernobyl Children's Project\n1\tThe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n2\tChernobyl\tChernobyl\tPROPN\tNNP\tNumber=Sing\t3\tcompound\t3:compound\t_\n3-4\tChildren's\t_\t_\t_\t_\t_\t_\t_\t_\n3\tChildren\tChildren\tPROPN\tNNP\tNumber=Sing\t5\tnmod:poss\t5:nmod:poss\t_\n4\t's\t's\tPART\tPOS\t_\t3\tcase\t3:case\t_\n5\tProject\tProject\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\t_\n\n\"\"\".lstrip()\n\nENG_DEV = \"\"\"\n# text = The Chernobyl Children's Project\n1\tThe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n2\tChernobyl\tChernobyl\tPROPN\tNNP\tNumber=Sing\t3\tcompound\t3:compound\t_\n3-4\tChildren's\t_\t_\t_\t_\t_\t_\t_\t_\n3\tChildren\tChildren\tPROPN\tNNP\tNumber=Sing\t5\tnmod:poss\t5:nmod:poss\t_\n4\t's\t's\tPART\tPOS\t_\t3\tcase\t3:case\t_\n5\tProject\tProject\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\t_\n\n\"\"\".lstrip()\n\ndef test_train(tmp_path):\n    test_train = 
str(os.path.join(tmp_path, \"en_test.train.conllu\"))\n    with open(test_train, \"w\") as fout:\n        fout.write(ENG_TRAIN)\n\n    test_dev = str(os.path.join(tmp_path, \"en_test.dev.conllu\"))\n    with open(test_dev, \"w\") as fout:\n        fout.write(ENG_DEV)\n\n    test_output = str(os.path.join(tmp_path, \"en_test.dev.pred.conllu\"))\n    model_name = \"en_test_mwt.pt\"\n\n    args = [\n        \"--data_dir\", str(tmp_path),\n        \"--train_file\", test_train,\n        \"--eval_file\", test_dev,\n        \"--gold_file\", test_dev,\n        \"--lang\", \"en\",\n        \"--shorthand\", \"en_test\",\n        \"--output_file\", test_output,\n        \"--save_dir\", str(tmp_path),\n        \"--save_name\", model_name,\n        \"--num_epoch\", \"10\",\n    ]\n\n    mwt_expander.main(args=args)\n\n    model = Trainer(model_file=os.path.join(tmp_path, model_name))\n    assert model.model is not None\n    assert isinstance(model.model, CharacterClassifier)\n\n    doc = CoNLL.conll2doc(input_str=ENG_DEV)\n    dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True)\n    preds = []\n    for i, batch in enumerate(dataloader.to_loader()):\n        assert i == 0 # there should only be one batch\n        preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab)\n    assert len(preds) == 1\n    # it is possible to make a version of the test where this happens almost every time\n    # for example, running for 100 epochs makes the model succeed 30 times in a row\n    # (never saw a failure)\n    # but the one time that failure happened, it would be really annoying\n    #assert preds[0] == \"Children 's\"\n"
  },
  {
    "path": "stanza/tests/mwt/test_english_corner_cases.py",
    "content": "\"\"\"\nTest a couple English MWT corner cases which might be more widely applicable to other MWT languages\n\n- unknown English character doesn't result in bizarre splits\n- Casing or CASING doesn't get lost in the dictionary lookup\n\nIn the English UD datasets, the MWT are composed exactly of the\nsubwords, so the MWT model should be chopping up the input text rather\nthan generating new text.\n\nFurthermore, SHE'S and She's should be split \"SHE 'S\" and \"She 's\" respectively\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_mwt_unknown_char():\n    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)\n\n    mwt_trainer = pipeline.processors['mwt']._trainer\n\n    assert mwt_trainer.args['force_exact_pieces']\n\n    # find a letter 'i' which isn't in the training data\n    # the MWT model should still recognize a possessive containing this letter\n    assert \"i\" in mwt_trainer.vocab\n    for letter in \"ĩîíìī\":\n        if letter not in mwt_trainer.vocab:\n            break\n    else:\n        raise AssertionError(\"Need to update the MWT test - all of the non-standard letters 'i' are now in the MWT vocab\")\n\n    word = \"Jenn\" + letter + \"fer\"\n    possessive = word + \"'s\"\n    text = \"I wanna lick \" + possessive + \" antennae\"\n    doc = pipeline(text)\n    assert doc.sentences[0].tokens[1].text == 'wanna'\n    assert len(doc.sentences[0].tokens[1].words) == 2\n    assert \"\".join(x.text for x in doc.sentences[0].tokens[1].words) == 'wanna'\n\n    assert doc.sentences[0].tokens[3].text == possessive\n    assert len(doc.sentences[0].tokens[3].words) == 2\n    assert \"\".join(x.text for x in doc.sentences[0].tokens[3].words) == possessive\n\n\ndef test_english_mwt_casing():\n    \"\"\"\n    Test that for a word where the lowercase split is known, the correct casing is 
still used\n\n    Once upon a time, the logic used in the MWT expander would split\n      SHE'S -> she 's\n\n    which is a very surprising tokenization to people expecting\n    the original text in the output document\n    \"\"\"\n    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)\n\n    mwt_trainer = pipeline.processors['mwt']._trainer\n    for i in range(1, 20):\n        # many test cases follow this pattern for some reason,\n        # so we should proactively look for a test case which hasn't\n        # made its way into the MWT dictionary\n        unknown_name = \"jennife\" + \"r\" * i + \"'s\"\n        if unknown_name not in mwt_trainer.expansion_dict and unknown_name.upper() not in mwt_trainer.expansion_dict:\n            unknown_name = unknown_name.upper()\n            break\n    else:\n        raise AssertionError(\"Need a new heuristic for the unknown word in the English MWT!\")\n\n    # this SHOULD show up in the expansion dict\n    assert \"she's\" in mwt_trainer.expansion_dict, \"Expected |she's| to be in the English MWT expansion dict... perhaps find a different test case\"\n\n    text = [x.text for x in pipeline(\"JENNIFER HAS NICE ANTENNAE\").sentences[0].words]\n    assert text == ['JENNIFER', 'HAS', 'NICE', 'ANTENNAE']\n\n    text = [x.text for x in pipeline(unknown_name + \" GOT NICE ANTENNAE\").sentences[0].words]\n    assert text == [unknown_name[:-2], \"'S\", 'GOT', 'NICE', 'ANTENNAE']\n\n    text = [x.text for x in pipeline(\"SHE'S GOT NICE ANTENNAE\").sentences[0].words]\n    assert text == ['SHE', \"'S\", 'GOT', 'NICE', 'ANTENNAE']\n\n    text = [x.text for x in pipeline(\"She's GOT NICE ANTENNAE\").sentences[0].words]\n    assert text == ['She', \"'s\", 'GOT', 'NICE', 'ANTENNAE']\n\n"
  },
  {
    "path": "stanza/tests/mwt/test_prepare_mwt.py",
    "content": "\nimport pytest\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nfrom stanza.utils.datasets.prepare_mwt_treebank import check_mwt_composition\n\nSAMPLE_GOOD_TEXT = \"\"\"\n# sent_id = weblog-typepad.com_ripples_20040407125600_ENG_20040407_125600-0057\n# text = The Chernobyl Children's Project (http://www.adiccp.org/home/default.asp) offers several ways to help the children of that region.\n1\tThe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n2\tChernobyl\tChernobyl\tPROPN\tNNP\tNumber=Sing\t3\tcompound\t3:compound\t_\n3-4\tChildren's\t_\t_\t_\t_\t_\t_\t_\t_\n3\tChildren\tChildren\tPROPN\tNNP\tNumber=Sing\t5\tnmod:poss\t5:nmod:poss\t_\n4\t's\t's\tPART\tPOS\t_\t3\tcase\t3:case\t_\n5\tProject\tProject\tPROPN\tNNP\tNumber=Sing\t9\tnsubj\t9:nsubj\t_\n6\t(\t(\tPUNCT\t-LRB-\t_\t7\tpunct\t7:punct\tSpaceAfter=No\n7\thttp://www.adiccp.org/home/default.asp\thttp://www.adiccp.org/home/default.asp\tX\tADD\t_\t5\tappos\t5:appos\tSpaceAfter=No\n8\t)\t)\tPUNCT\t-RRB-\t_\t7\tpunct\t7:punct\t_\n9\toffers\toffer\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n10\tseveral\tseveral\tADJ\tJJ\tDegree=Pos\t11\tamod\t11:amod\t_\n11\tways\tway\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n12\tto\tto\tPART\tTO\t_\t13\tmark\t13:mark\t_\n13\thelp\thelp\tVERB\tVB\tVerbForm=Inf\t11\tacl\t11:acl:to\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tchildren\tchild\tNOUN\tNNS\tNumber=Plur\t13\tobj\t13:obj\t_\n16\tof\tof\tADP\tIN\t_\t18\tcase\t18:case\t_\n17\tthat\tthat\tDET\tDT\tNumber=Sing|PronType=Dem\t18\tdet\t18:det\t_\n18\tregion\tregion\tNOUN\tNN\tNumber=Sing\t15\tnmod\t15:nmod:of\tSpaceAfter=No\n19\t.\t.\tPUNCT\t.\t_\t9\tpunct\t9:punct\t_\n\"\"\".lstrip()\n\nSAMPLE_BAD_TEXT = \"\"\"\n# sent_id = weblog-typepad.com_ripples_20040407125600_ENG_20040407_125600-0057\n# text = The Chernobyl Children's Project (http://www.adiccp.org/home/default.asp) offers several ways to help the children of 
that region.\n1\tThe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n2\tChernobyl\tChernobyl\tPROPN\tNNP\tNumber=Sing\t3\tcompound\t3:compound\t_\n3-4\tChildren's\t_\t_\t_\t_\t_\t_\t_\t_\n3\tChildrez\tChildren\tPROPN\tNNP\tNumber=Sing\t5\tnmod:poss\t5:nmod:poss\t_\n4\t's\t's\tPART\tPOS\t_\t3\tcase\t3:case\t_\n5\tProject\tProject\tPROPN\tNNP\tNumber=Sing\t9\tnsubj\t9:nsubj\t_\n6\t(\t(\tPUNCT\t-LRB-\t_\t7\tpunct\t7:punct\tSpaceAfter=No\n7\thttp://www.adiccp.org/home/default.asp\thttp://www.adiccp.org/home/default.asp\tX\tADD\t_\t5\tappos\t5:appos\tSpaceAfter=No\n8\t)\t)\tPUNCT\t-RRB-\t_\t7\tpunct\t7:punct\t_\n9\toffers\toffer\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n10\tseveral\tseveral\tADJ\tJJ\tDegree=Pos\t11\tamod\t11:amod\t_\n11\tways\tway\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n12\tto\tto\tPART\tTO\t_\t13\tmark\t13:mark\t_\n13\thelp\thelp\tVERB\tVB\tVerbForm=Inf\t11\tacl\t11:acl:to\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tchildren\tchild\tNOUN\tNNS\tNumber=Plur\t13\tobj\t13:obj\t_\n16\tof\tof\tADP\tIN\t_\t18\tcase\t18:case\t_\n17\tthat\tthat\tDET\tDT\tNumber=Sing|PronType=Dem\t18\tdet\t18:det\t_\n18\tregion\tregion\tNOUN\tNN\tNumber=Sing\t15\tnmod\t15:nmod:of\tSpaceAfter=No\n19\t.\t.\tPUNCT\t.\t_\t9\tpunct\t9:punct\t_\n\"\"\".lstrip()\n\ndef test_check_mwt_composition(tmp_path):\n    mwt_file = tmp_path / \"good.mwt\"\n    with open(mwt_file, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(SAMPLE_GOOD_TEXT)\n    check_mwt_composition(mwt_file)\n\n    mwt_file = tmp_path / \"bad.mwt\"\n    with open(mwt_file, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(SAMPLE_BAD_TEXT)\n    with pytest.raises(ValueError):\n        check_mwt_composition(mwt_file)\n"
  },
  {
    "path": "stanza/tests/mwt/test_utils.py",
    "content": "\"\"\"\nTest the MWT resplitting of preexisting tokens without word splits\n\"\"\"\n\nimport pytest\n\nimport stanza\nfrom stanza.models.mwt.utils import resplit_mwt\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n@pytest.fixture(scope=\"module\")\ndef pipeline():\n    \"\"\"\n    A reusable pipeline with the tokenize and MWT processors\n    \"\"\"\n    return stanza.Pipeline(\"en\", dir=TEST_MODELS_DIR, processors=\"tokenize,mwt\", package=\"gum\")\n\n\ndef test_resplit_keep_tokens(pipeline):\n    \"\"\"\n    Test splitting with enforced token boundaries\n    \"\"\"\n    tokens = [[\"I\", \"can't\", \"believe\", \"it\"], [\"I can't\", \"sleep\"]]\n    doc = resplit_mwt(tokens, pipeline)\n    assert len(doc.sentences) == 2\n    assert len(doc.sentences[0].tokens) == 4\n    assert len(doc.sentences[0].tokens[1].words) == 2\n    assert doc.sentences[0].tokens[1].words[0].text == \"ca\"\n    assert doc.sentences[0].tokens[1].words[1].text == \"n't\"\n\n    assert len(doc.sentences[1].tokens) == 2\n    # updated GUM MWT splits \"I can't\" into three segments\n    # the way we want, \"I - ca - n't\"\n    # previously it would split \"I - can - 't\"\n    assert len(doc.sentences[1].tokens[0].words) == 3\n    assert doc.sentences[1].tokens[0].words[0].text == \"I\"\n    assert doc.sentences[1].tokens[0].words[1].text == \"ca\"\n    assert doc.sentences[1].tokens[0].words[2].text == \"n't\"\n\n\ndef test_resplit_no_keep_tokens(pipeline):\n    \"\"\"\n    Test splitting without enforced token boundaries\n    \"\"\"\n    tokens = [[\"I\", \"can't\", \"believe\", \"it\"], [\"I can't\", \"sleep\"]]\n    doc = resplit_mwt(tokens, pipeline, keep_tokens=False)\n    assert len(doc.sentences) == 2\n    assert len(doc.sentences[0].tokens) == 4\n    assert len(doc.sentences[0].tokens[1].words) == 2\n    assert doc.sentences[0].tokens[1].words[0].text == \"ca\"\n    assert doc.sentences[0].tokens[1].words[1].text == \"n't\"\n\n  
  assert len(doc.sentences[1].tokens) == 3\n    assert len(doc.sentences[1].tokens[1].words) == 2\n    assert doc.sentences[1].tokens[1].words[0].text == \"ca\"\n    assert doc.sentences[1].tokens[1].words[1].text == \"n't\"\n"
  },
  {
    "path": "stanza/tests/ner/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/ner/test_bsf_2_beios.py",
    "content": "\"\"\"\nTests the conversion code for the lang_uk NER dataset\n\"\"\"\n\nimport unittest\nfrom stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo\n\nimport pytest\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nclass TestBsf2Beios(unittest.TestCase):\n    \n    def test_empty_markup(self):\n        res = convert_bsf('', '')\n        self.assertEqual('', res)\n\n    def test_1line_markup(self):\n        data = 'тележурналіст Василь'\n        bsf_markup = 'T1\tPERS 14 20\tВасиль'\n        expected = '''тележурналіст O\nВасиль S-PERS'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup))\n\n    def test_1line_follow_markup(self):\n        data = 'тележурналіст Василь .'\n        bsf_markup = 'T1\tPERS 14 20\tВасиль'\n        expected = '''тележурналіст O\nВасиль S-PERS\n. O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup))\n\n    def test_1line_2tok_markup(self):\n        data = 'тележурналіст Василь Нагірний .'\n        bsf_markup = 'T1\tPERS 14 29\tВасиль Нагірний'\n        expected = '''тележурналіст O\nВасиль B-PERS\nНагірний E-PERS\n. O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup))\n\n    def test_1line_Long_tok_markup(self):\n        data = 'А в музеї Гуцульщини і Покуття можна '\n        bsf_markup = 'T12\tORG 4 30\tмузеї Гуцульщини і Покуття'\n        expected = '''А O\nв O\nмузеї B-ORG\nГуцульщини I-ORG\nі I-ORG\nПокуття E-ORG\nможна O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup))\n\n    def test_2line_2tok_markup(self):\n        data = '''тележурналіст Василь Нагірний .\nВ івано-франківському видавництві «Лілея НВ» вийшла друком'''\n        bsf_markup = '''T1\tPERS 14 29\tВасиль Нагірний\nT2\tORG 67 75\tЛілея НВ'''\n        expected = '''тележурналіст O\nВасиль B-PERS\nНагірний E-PERS\n. 
O\n\n\nВ O\nівано-франківському O\nвидавництві O\n« O\nЛілея B-ORG\nНВ E-ORG\n» O\nвийшла O\nдруком O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup))\n\n    def test_real_markup(self):\n        data = '''Через напіввоєнний стан в Україні та збільшення телефонних терористичних погроз українці купуватимуть sim-карти тільки за паспортами .\nПро це повідомив начальник управління зв'язків зі ЗМІ адміністрації Держспецзв'язку Віталій Кукса .\nВін зауважив , що днями відомство опублікує проект змін до правил надання телекомунікаційних послуг , де будуть прописані норми ідентифікації громадян .\nАбонентів , які на сьогодні вже мають sim-карту , за словами Віталія Кукси , реєструватимуть , коли ті звертатимуться в службу підтримки свого оператора мобільного зв'язку .\nОднак мобільні оператори побоюються , що таке нововведення помітно зменшить продаж стартових пакетів , адже спеціалізовані магазини є лише у містах .\nВідтак купити сімку в невеликих населених пунктах буде неможливо .\nКрім того , нова процедура ідентифікації абонентів вимагатиме від операторів мобільного зв'язку додаткових витрат .\n- Близько 90 % українських абонентів - це абоненти передоплати .\nЯкщо мова буде йти навіть про поетапну їх ідентифікацію , зробити це буде складно , довго і дорого .\nМобільним операторам доведеться йти на чималі витрати , пов'язані з укладанням і зберіганням договорів , веденням баз даних , - розповіла « Економічній правді » начальник відділу зв'язків з громадськістю « МТС-Україна » Вікторія Рубан .\n'''\n        bsf_markup = '''T1\tLOC 26 33\tУкраїні\nT2\tORG 203 218\tДержспецзв'язку\nT3\tPERS 219 232\tВіталій Кукса\nT4\tPERS 449 462\tВіталія Кукси\nT5\tORG 1201 1219\tЕкономічній правді\nT6\tORG 1267 1278\tМТС-Україна\nT7\tPERS 1281 1295\tВікторія Рубан\n'''\n        expected = '''Через O\nнапіввоєнний O\nстан O\nв O\nУкраїні S-LOC\nта O\nзбільшення O\nтелефонних O\nтерористичних O\nпогроз O\nукраїнці O\nкупуватимуть O\nsim-карти O\nтільки O\nза 
O\nпаспортами O\n. O\n\n\nПро O\nце O\nповідомив O\nначальник O\nуправління O\nзв'язків O\nзі O\nЗМІ O\nадміністрації O\nДержспецзв'язку S-ORG\nВіталій B-PERS\nКукса E-PERS\n. O\n\n\nВін O\nзауважив O\n, O\nщо O\nднями O\nвідомство O\nопублікує O\nпроект O\nзмін O\nдо O\nправил O\nнадання O\nтелекомунікаційних O\nпослуг O\n, O\nде O\nбудуть O\nпрописані O\nнорми O\nідентифікації O\nгромадян O\n. O\n\n\nАбонентів O\n, O\nякі O\nна O\nсьогодні O\nвже O\nмають O\nsim-карту O\n, O\nза O\nсловами O\nВіталія B-PERS\nКукси E-PERS\n, O\nреєструватимуть O\n, O\nколи O\nті O\nзвертатимуться O\nв O\nслужбу O\nпідтримки O\nсвого O\nоператора O\nмобільного O\nзв'язку O\n. O\n\n\nОднак O\nмобільні O\nоператори O\nпобоюються O\n, O\nщо O\nтаке O\nнововведення O\nпомітно O\nзменшить O\nпродаж O\nстартових O\nпакетів O\n, O\nадже O\nспеціалізовані O\nмагазини O\nє O\nлише O\nу O\nмістах O\n. O\n\n\nВідтак O\nкупити O\nсімку O\nв O\nневеликих O\nнаселених O\nпунктах O\nбуде O\nнеможливо O\n. O\n\n\nКрім O\nтого O\n, O\nнова O\nпроцедура O\nідентифікації O\nабонентів O\nвимагатиме O\nвід O\nоператорів O\nмобільного O\nзв'язку O\nдодаткових O\nвитрат O\n. O\n\n\n- O\nБлизько O\n90 O\n% O\nукраїнських O\nабонентів O\n- O\nце O\nабоненти O\nпередоплати O\n. O\n\n\nЯкщо O\nмова O\nбуде O\nйти O\nнавіть O\nпро O\nпоетапну O\nїх O\nідентифікацію O\n, O\nзробити O\nце O\nбуде O\nскладно O\n, O\nдовго O\nі O\nдорого O\n. O\n\n\nМобільним O\nоператорам O\nдоведеться O\nйти O\nна O\nчималі O\nвитрати O\n, O\nпов'язані O\nз O\nукладанням O\nі O\nзберіганням O\nдоговорів O\n, O\nведенням O\nбаз O\nданих O\n, O\n- O\nрозповіла O\n« O\nЕкономічній B-ORG\nправді E-ORG\n» O\nначальник O\nвідділу O\nзв'язків O\nз O\nгромадськістю O\n« O\nМТС-Україна S-ORG\n» O\nВікторія B-PERS\nРубан E-PERS\n. 
O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup))\n\n\nclass TestBsf(unittest.TestCase):\n\n    def test_empty_bsf(self):\n        self.assertEqual(parse_bsf(''), [])\n\n    def test_empty2_bsf(self):\n        self.assertEqual(parse_bsf(' \\n \\n'), [])\n\n    def test_1line_bsf(self):\n        bsf = 'T1\tPERS 103 118\tВасиль Нагірний'\n        res = parse_bsf(bsf)\n        expected = BsfInfo('T1', 'PERS', 103, 118, 'Василь Нагірний')\n        self.assertEqual(len(res), 1)\n        self.assertEqual(res, [expected])\n\n    def test_2line_bsf(self):\n        bsf = '''T9\tPERS 778 783\tКарла\nT10\tMISC 814 819\tміста'''\n        res = parse_bsf(bsf)\n        expected = [BsfInfo('T9', 'PERS', 778, 783, 'Карла'),\n                    BsfInfo('T10', 'MISC', 814, 819, 'міста')]\n        self.assertEqual(len(res), 2)\n        self.assertEqual(res, expected)\n\n    def test_multiline_bsf(self):\n        bsf = '''T3\tPERS 220 235\tАндрієм Кіщуком\nT4\tMISC 251 285\tА .\nKubler .\nСвітло і тіні маестро\nT5\tPERS 363 369\tКіблер'''\n        res = parse_bsf(bsf)\n        expected = [BsfInfo('T3', 'PERS', 220, 235, 'Андрієм Кіщуком'),\n                    BsfInfo('T4', 'MISC', 251, 285, '''А .\nKubler .\nСвітло і тіні маестро'''),\n                    BsfInfo('T5', 'PERS', 363, 369, 'Кіблер')]\n        self.assertEqual(len(res), len(expected))\n        self.assertEqual(res, expected)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "stanza/tests/ner/test_bsf_2_iob.py",
    "content": "\"\"\"\nTests the conversion code for the lang_uk NER dataset\n\"\"\"\n\nimport unittest\nfrom stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo\n\nimport pytest\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nclass TestBsf2Iob(unittest.TestCase):\n\n    def test_1line_follow_markup_iob(self):\n        data = 'тележурналіст Василь .'\n        bsf_markup = 'T1\tPERS 14 20\tВасиль'\n        expected = '''тележурналіст O\nВасиль B-PERS\n. O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))\n\n    def test_1line_2tok_markup_iob(self):\n        data = 'тележурналіст Василь Нагірний .'\n        bsf_markup = 'T1\tPERS 14 29\tВасиль Нагірний'\n        expected = '''тележурналіст O\nВасиль B-PERS\nНагірний I-PERS\n. O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))\n\n    def test_1line_Long_tok_markup_iob(self):\n        data = 'А в музеї Гуцульщини і Покуття можна '\n        bsf_markup = 'T12\tORG 4 30\tмузеї Гуцульщини і Покуття'\n        expected = '''А O\nв O\nмузеї B-ORG\nГуцульщини I-ORG\nі I-ORG\nПокуття I-ORG\nможна O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))\n\n    def test_2line_2tok_markup_iob(self):\n        data = '''тележурналіст Василь Нагірний .\nВ івано-франківському видавництві «Лілея НВ» вийшла друком'''\n        bsf_markup = '''T1\tPERS 14 29\tВасиль Нагірний\nT2\tORG 67 75\tЛілея НВ'''\n        expected = '''тележурналіст O\nВасиль B-PERS\nНагірний I-PERS\n. 
O\n\n\nВ O\nівано-франківському O\nвидавництві O\n« O\nЛілея B-ORG\nНВ I-ORG\n» O\nвийшла O\nдруком O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))\n\n    def test_all_multiline_iob(self):\n        data = '''його книжечка «А .\nKubler .\nСвітло і тіні маестро» .\nПричому'''\n        bsf_markup = '''T4\tMISC 15 49\tА .\nKubler .\nСвітло і тіні маестро\n'''\n        expected = '''його O\nкнижечка O\n« O\nА B-MISC\n. I-MISC\nKubler I-MISC\n. I-MISC\nСвітло I-MISC\nі I-MISC\nтіні I-MISC\nмаестро I-MISC\n» O\n. O\n\n\nПричому O'''\n        self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "stanza/tests/ner/test_combine_ner_datasets.py",
    "content": "import json\nimport os\nimport pytest\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nfrom stanza.models.common.doc import Document\nfrom stanza.tests.ner.test_ner_training import write_temp_file, EN_TRAIN_BIO, EN_DEV_BIO\nfrom stanza.utils.datasets.ner import combine_ner_datasets\n\n\ndef test_combine(tmp_path):\n    \"\"\"\n    Test that if we write two short datasets and combine them, we get back\n    one slightly longer dataset\n\n    To simplify matters, we just use the same input text with longer\n    amounts of text for each shard.\n    \"\"\"\n    SHARDS = (\"train\", \"dev\", \"test\")\n    for s_num, shard in enumerate(SHARDS):\n        t1_json = tmp_path / (\"en_t1.%s.json\" % shard)\n        # eg, 1x, 2x, 3x the test data from test_ner_training\n        write_temp_file(t1_json, \"\\n\\n\".join([EN_TRAIN_BIO] * (s_num + 1)))\n\n        t2_json = tmp_path / (\"en_t2.%s.json\" % shard)\n        write_temp_file(t2_json, \"\\n\\n\".join([EN_DEV_BIO] * (s_num + 1)))\n\n    args = [\"--output_dataset\", \"en_c\", \"en_t1\", \"en_t2\", \"--input_dir\", str(tmp_path), \"--output_dir\", str(tmp_path)]\n    combine_ner_datasets.main(args)\n\n    for s_num, shard in enumerate(SHARDS):\n        filename = tmp_path / (\"en_c.%s.json\" % shard)\n        assert os.path.exists(filename)\n\n        with open(filename, encoding=\"utf-8\") as fin:\n            doc = Document(json.load(fin))\n            assert len(doc.sentences) == (s_num + 1) * 3\n\n"
  },
  {
    "path": "stanza/tests/ner/test_convert_amt.py",
    "content": "\"\"\"\nTest some of the functions used for converting an AMT json to a Stanza json\n\"\"\"\n\n\nimport os\n\nimport pytest\n\nimport stanza\nfrom stanza.utils.datasets.ner import convert_amt\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nTEXT = \"Jennifer Sh'reyan has lovely antennae.\"\n\ndef fake_label(label, start_char, end_char):\n    return {'label': label,\n            'startOffset': start_char,\n            'endOffset': end_char}\n\nLABELS = [\n    fake_label('Person', 0, 8),\n    fake_label('Person', 9, 17),\n    fake_label('Person', 0, 17),\n    fake_label('Andorian', 0, 8),\n    fake_label('Appendage', 29, 37),\n    fake_label('Person', 1, 8),\n    fake_label('Person', 0, 7),\n    fake_label('Person', 0, 9),\n    fake_label('Appendage', 29, 38),\n]\n\ndef fake_labels(*indices):\n    return [LABELS[x] for x in indices]\n\ndef fake_docs(*indices):\n    return [(TEXT, fake_labels(*indices))]\n\ndef test_remove_nesting():\n    \"\"\"\n    Test a few orders on nested items to make sure the desired results are coming back\n    \"\"\"\n    # this should be unchanged\n    result = convert_amt.remove_nesting(fake_docs(0, 1))\n    assert result == fake_docs(0, 1)\n\n    # this should be returned sorted\n    result = convert_amt.remove_nesting(fake_docs(0, 4, 1))\n    assert result == fake_docs(0, 1, 4)\n\n    # this should just have one copy\n    result = convert_amt.remove_nesting(fake_docs(0, 0))\n    assert result == fake_docs(0)\n    \n    # outer one preferred\n    result = convert_amt.remove_nesting(fake_docs(0, 2))\n    assert result == fake_docs(2)\n    result = convert_amt.remove_nesting(fake_docs(1, 2))\n    assert result == fake_docs(2)\n    result = convert_amt.remove_nesting(fake_docs(5, 2))\n    assert result == fake_docs(2)\n    # order doesn't matter\n    result = convert_amt.remove_nesting(fake_docs(0, 4, 2))\n    assert result == fake_docs(2, 4)\n    result = 
convert_amt.remove_nesting(fake_docs(2, 4, 0))\n    assert result == fake_docs(2, 4)\n    \n    # first one preferred\n    result = convert_amt.remove_nesting(fake_docs(0, 3))\n    assert result == fake_docs(0)\n    result = convert_amt.remove_nesting(fake_docs(3, 0))\n    assert result == fake_docs(3)\n\ndef test_process_doc():\n    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize\", download_method=None)\n\n    def check_results(doc, *expected):\n        ner = [x[1] for x in doc[0]]\n        assert ner == list(expected)\n\n    # test a standard case of all the values lining up\n    doc = convert_amt.process_doc(TEXT, fake_labels(2, 4), nlp)\n    check_results(doc, \"B-Person\", \"I-Person\", \"O\", \"O\", \"B-Appendage\", \"O\")\n\n    # test a slightly wrong start index\n    doc = convert_amt.process_doc(TEXT, fake_labels(5, 1, 4), nlp)\n    check_results(doc, \"B-Person\", \"B-Person\", \"O\", \"O\", \"B-Appendage\", \"O\")\n\n    # test a slightly wrong end index\n    doc = convert_amt.process_doc(TEXT, fake_labels(6, 1, 4), nlp)\n    check_results(doc, \"B-Person\", \"B-Person\", \"O\", \"O\", \"B-Appendage\", \"O\")\n\n    # test an end index that overshoots the first token\n    doc = convert_amt.process_doc(TEXT, fake_labels(7, 4), nlp)\n    check_results(doc, \"B-Person\", \"O\", \"O\", \"O\", \"B-Appendage\", \"O\")\n\n    # test a period at the end of a text - should not be captured\n    doc = convert_amt.process_doc(TEXT, fake_labels(7, 8), nlp)\n    check_results(doc, \"B-Person\", \"O\", \"O\", \"O\", \"B-Appendage\", \"O\")\n\n    \n"
  },
  {
    "path": "stanza/tests/ner/test_convert_nkjp.py",
    "content": "import pytest\n\nimport io\nimport os\nimport xml.etree.ElementTree as ET\n\nfrom stanza.utils.datasets.ner.convert_nkjp import MORPH_FILE, NER_FILE, extract_entities_from_subfolder, extract_entities_from_sentence, extract_unassigned_subfolder_entities\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nEXPECTED_ENTITIES = {\n    '1-p': {\n        '1.39-s': [{'ent_id': 'named_1.39-s_n1', 'index': 0, 'orth': 'Sił Zbrojnych', 'ner_type': 'orgName', 'ner_subtype': None, 'targets': ['1.37-seg', '1.38-seg']}],\n        '1.56-s': [],\n        '1.79-s': []\n    },\n    '2-p': {\n        '2.30-s': [],\n        '2.45-s': []\n    },\n    '3-p': {\n        '3.70-s': []\n    }\n}\n\n\n@pytest.fixture(scope=\"module\")\ndef dataset(tmp_path_factory):\n    dataset_path = tmp_path_factory.mktemp(\"nkjp_dataset\")\n    sample_path = dataset_path / \"sample\"\n    os.mkdir(sample_path)\n    ann_path = sample_path / NER_FILE\n    with open(ann_path, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(SAMPLE_ANN)\n    morph_path = sample_path / MORPH_FILE\n    with open(morph_path, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(SAMPLE_MORPHO)\n    return dataset_path\n\nEXPECTED_TOKENS = [\n    {'seg_id': '1.1-seg', 'i': 0, 'orth': '2', 'text': '2', 'tag': '_', 'ner': 'O', 'ner_subtype': None},\n    {'seg_id': '1.37-seg', 'i': 36, 'orth': 'Sił', 'text': 'Sił', 'tag': '_', 'ner': 'B-orgName', 'ner_subtype': None},\n    {'seg_id': '1.38-seg', 'i': 37, 'orth': 'Zbrojnych', 'text': 'Zbrojnych', 'tag': '_', 'ner': 'I-orgName', 'ner_subtype': None},\n]\n\ndef test_extract_entities_from_subfolder(dataset):\n    entities = extract_entities_from_subfolder(\"sample\", dataset)\n    assert len(entities) == 1\n    assert len(entities['1-p']) == 1\n    assert len(entities['1-p']['1.39-s']) == 39\n    assert entities['1-p']['1.39-s']['1.1-seg'] == EXPECTED_TOKENS[0]\n    assert entities['1-p']['1.39-s']['1.37-seg'] == EXPECTED_TOKENS[1]\n    assert 
entities['1-p']['1.39-s']['1.38-seg'] == EXPECTED_TOKENS[2]\n\n\ndef test_extract_unassigned(dataset):\n    entities = extract_unassigned_subfolder_entities(\"sample\", dataset)\n    assert entities == EXPECTED_ENTITIES\n\nSENTENCE_SAMPLE = \"\"\"\n          <s xmlns=\"http://www.tei-c.org/ns/1.0\" xmlns:xi=\"http://www.w3.org/2001/XInclude\" xml:id=\"named_1.39-s\" corresp=\"ann_morphosyntax.xml#morph_1.39-s\">\n            <seg xml:id=\"named_1.39-s_n1\">\n              <fs type=\"named\">\n                <f name=\"type\">\n                  <symbol value=\"orgName\"/>\n                </f>\n                <f name=\"orth\">\n                  <string>Si&#322; Zbrojnych</string>\n                </f>\n                <f name=\"base\">\n                  <string>Si&#322;y Zbrojne</string>\n                </f>\n                <f name=\"certainty\">\n                  <symbol value=\"high\"/>\n                </f>\n              </fs>\n              <ptr target=\"ann_morphosyntax.xml#morph_1.37-seg\"/>\n              <ptr target=\"ann_morphosyntax.xml#morph_1.38-seg\"/>\n            </seg>\n          </s>\n\"\"\".strip()\n\n\nEMPTY_SENTENCE = \"\"\"<s xml:id=\"named_1.56-s\" corresp=\"ann_morphosyntax.xml#morph_1.56-s\"/>\"\"\"\n\ndef test_extract_entities_from_sentence():\n    rt = ET.fromstring(SENTENCE_SAMPLE)\n    entities = extract_entities_from_sentence(rt)\n    assert entities == EXPECTED_ENTITIES['1-p']['1.39-s']\n\n    rt = ET.fromstring(EMPTY_SENTENCE)\n    entities = extract_entities_from_sentence(rt)\n    assert entities == []\n\n\n\n# picked completely at random, one sample file for testing:\n# 610-1-000248/ann_named.xml\n# only the first sentence is used in the morpho file\nSAMPLE_ANN = \"\"\"\n<?xml version='1.0' encoding='UTF-8'?>\n<teiCorpus xmlns:xi=\"http://www.w3.org/2001/XInclude\" xmlns=\"http://www.tei-c.org/ns/1.0\">\n  <xi:include href=\"NKJP_1M_header.xml\"/>\n  <TEI>\n    <xi:include href=\"header.xml\"/>\n    <text xml:lang=\"pl\">\n   
   <body>\n        <p xml:id=\"named_1-p\" corresp=\"ann_morphosyntax.xml#morph_1-p\">\n          <s xml:id=\"named_1.39-s\" corresp=\"ann_morphosyntax.xml#morph_1.39-s\">\n            <seg xml:id=\"named_1.39-s_n1\">\n              <fs type=\"named\">\n                <f name=\"type\">\n                  <symbol value=\"orgName\"/>\n                </f>\n                <f name=\"orth\">\n                  <string>Sił Zbrojnych</string>\n                </f>\n                <f name=\"base\">\n                  <string>Siły Zbrojne</string>\n                </f>\n                <f name=\"certainty\">\n                  <symbol value=\"high\"/>\n                </f>\n              </fs>\n              <ptr target=\"ann_morphosyntax.xml#morph_1.37-seg\"/>\n              <ptr target=\"ann_morphosyntax.xml#morph_1.38-seg\"/>\n            </seg>\n          </s>\n          <s xml:id=\"named_1.56-s\" corresp=\"ann_morphosyntax.xml#morph_1.56-s\"/>\n          <s xml:id=\"named_1.79-s\" corresp=\"ann_morphosyntax.xml#morph_1.79-s\"/>\n        </p>\n        <p xml:id=\"named_2-p\" corresp=\"ann_morphosyntax.xml#morph_2-p\">\n          <s xml:id=\"named_2.30-s\" corresp=\"ann_morphosyntax.xml#morph_2.30-s\"/>\n          <s xml:id=\"named_2.45-s\" corresp=\"ann_morphosyntax.xml#morph_2.45-s\"/>\n        </p>\n        <p xml:id=\"named_3-p\" corresp=\"ann_morphosyntax.xml#morph_3-p\">\n          <s xml:id=\"named_3.70-s\" corresp=\"ann_morphosyntax.xml#morph_3.70-s\"/>\n        </p>\n      </body>\n    </text>\n  </TEI>\n</teiCorpus>\n\"\"\".lstrip()\n\n\n\nSAMPLE_MORPHO = \"\"\"\n<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!-- w indeksach elementów wciąganych mogą zdarzyć się nieciągłości (z alternatyw segmentacyjnych)  --><teiCorpus xmlns=\"http://www.tei-c.org/ns/1.0\" xmlns:nkjp=\"http://www.nkjp.pl/ns/1.0\" xmlns:xi=\"http://www.w3.org/2001/XInclude\">\n <xi:include href=\"NKJP_1M_header.xml\"/>\n <TEI>\n  <xi:include href=\"header.xml\"/>\n  <text>\n   <body>\n    <!-- 
morph_1-p is akapit 7626 with instances (akapit_transzy-s) 15244, 15269 in batches (transza-s) 1525, 1528 resp. -->\n    <p corresp=\"ann_segmentation.xml#segm_1-p\" xml:id=\"morph_1-p\">\n     <s corresp=\"ann_segmentation.xml#segm_1.39-s\" xml:id=\"morph_1.39-s\">\n      <seg corresp=\"ann_segmentation.xml#segm_1.1-seg\" xml:id=\"morph_1.1-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>2</string>\n        </f>\n        <!-- 2 [0,1] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.1.1-lex\">\n          <f name=\"base\">\n           <string/>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ign\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.1.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.1.2-lex\">\n          <f name=\"base\">\n           <string>2</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"adj\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol nkjp:manual=\"true\" value=\"sg:nom:n:pos\" xml:id=\"morph_1.1.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.1.2.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>2:adj:sg:nom:n:pos</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.2-seg\" xml:id=\"morph_1.2-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>.</string>\n        </f>\n        <!-- . 
[1,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.2.1-lex\">\n          <f name=\"base\">\n           <string>.</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.2.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.2.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>.:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.3-seg\" xml:id=\"morph_1.3-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>Wezwanie</string>\n        </f>\n        <!-- Wezwanie [3,8] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.3.1-lex\">\n          <f name=\"base\">\n           <string>wezwanie</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:nom:n\" xml:id=\"morph_1.3.1.1-msd\"/>\n            <symbol value=\"sg:acc:n\" xml:id=\"morph_1.3.1.2-msd\"/>\n            <symbol value=\"sg:voc:n\" xml:id=\"morph_1.3.1.3-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.3.2-lex\">\n          <f name=\"base\">\n           <string>wezwać</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ger\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:nom:n:perf:aff\" xml:id=\"morph_1.3.2.1-msd\"/>\n            <symbol value=\"sg:acc:n:perf:aff\" xml:id=\"morph_1.3.2.2-msd\"/>\n           </vAlt>\n          
</f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.3.1.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>wezwanie:subst:sg:acc:n</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.4-seg\" xml:id=\"morph_1.4-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>,</string>\n        </f>\n        <!-- , [11,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.4.1-lex\">\n          <f name=\"base\">\n           <string>,</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.4.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.4.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>,:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.5-seg\" xml:id=\"morph_1.5-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>o</string>\n        </f>\n        <!-- o [13,1] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.5.1-lex\">\n          <f name=\"base\">\n           <string>o</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interj\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.5.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.5.2-lex\">\n          <f name=\"base\">\n           <string>o</string>\n     
     </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"acc\" xml:id=\"morph_1.5.2.1-msd\"/>\n            <symbol value=\"loc\" xml:id=\"morph_1.5.2.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.5.3-lex\">\n          <f name=\"base\">\n           <string>ojciec</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.5.3.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.5.2.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>o:prep:loc</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.6-seg\" xml:id=\"morph_1.6-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>którym</string>\n        </f>\n        <!-- którym [15,6] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.6.1-lex\">\n          <f name=\"base\">\n           <string>który</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"adj\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:inst:m1:pos\" xml:id=\"morph_1.6.1.1-msd\"/>\n            <symbol value=\"sg:inst:m2:pos\" xml:id=\"morph_1.6.1.2-msd\"/>\n            <symbol value=\"sg:inst:m3:pos\" xml:id=\"morph_1.6.1.3-msd\"/>\n            <symbol value=\"sg:inst:n:pos\" xml:id=\"morph_1.6.1.4-msd\"/>\n            <symbol value=\"sg:loc:m1:pos\" xml:id=\"morph_1.6.1.5-msd\"/>\n            <symbol value=\"sg:loc:m2:pos\" xml:id=\"morph_1.6.1.6-msd\"/>\n            <symbol value=\"sg:loc:m3:pos\" 
xml:id=\"morph_1.6.1.7-msd\"/>\n            <symbol value=\"sg:loc:n:pos\" xml:id=\"morph_1.6.1.8-msd\"/>\n            <symbol value=\"pl:dat:m1:pos\" xml:id=\"morph_1.6.1.9-msd\"/>\n            <symbol value=\"pl:dat:m2:pos\" xml:id=\"morph_1.6.1.10-msd\"/>\n            <symbol value=\"pl:dat:m3:pos\" xml:id=\"morph_1.6.1.11-msd\"/>\n            <symbol value=\"pl:dat:f:pos\" xml:id=\"morph_1.6.1.12-msd\"/>\n            <symbol value=\"pl:dat:n:pos\" xml:id=\"morph_1.6.1.13-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.6.1.8-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>który:adj:sg:loc:n:pos</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.7-seg\" xml:id=\"morph_1.7-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>mowa</string>\n        </f>\n        <!-- mowa [22,4] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.7.1-lex\">\n          <f name=\"base\">\n           <string>mowa</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:nom:f\" xml:id=\"morph_1.7.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.7.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>mowa:subst:sg:nom:f</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.8-seg\" xml:id=\"morph_1.8-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>w</string>\n        </f>\n        <!-- w [27,1] -->\n        <f 
name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.8.1-lex\">\n          <f name=\"base\">\n           <string>w</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"acc:nwok\" xml:id=\"morph_1.8.1.1-msd\"/>\n            <symbol value=\"loc:nwok\" xml:id=\"morph_1.8.1.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.8.2-lex\">\n          <f name=\"base\">\n           <string>wiek</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.8.2.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.8.3-lex\">\n          <f name=\"base\">\n           <string>wielki</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.8.3.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.8.4-lex\">\n          <f name=\"base\">\n           <string>wiersz</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.8.4.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.8.5-lex\">\n          <f name=\"base\">\n           <string>wieś</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.8.5.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.8.6-lex\">\n          <f name=\"base\">\n           <string>wyspa</string>\n          </f>\n          <f 
name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.8.6.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.8.1.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>w:prep:loc:nwok</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.9-seg\" xml:id=\"morph_1.9-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>ust</string>\n        </f>\n        <!-- ust [29,3] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.9.1-lex\">\n          <f name=\"base\">\n           <string>usta</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pl:gen:n\" xml:id=\"morph_1.9.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.9.2-lex\">\n          <f name=\"base\">\n           <string>ustęp</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol nkjp:manual=\"true\" value=\"pun\" xml:id=\"morph_1.9.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.9.2.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>ustęp:brev:pun</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.10-seg\" xml:id=\"morph_1.10-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>.</string>\n        </f>\n        <!-- . 
[32,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.10.1-lex\">\n          <f name=\"base\">\n           <string>.</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.10.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.10.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>.:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.11-seg\" xml:id=\"morph_1.11-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>1</string>\n        </f>\n        <!-- 1 [34,1] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.11.1-lex\">\n          <f name=\"base\">\n           <string/>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ign\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.11.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.11.2-lex\">\n          <f name=\"base\">\n           <string>1</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"adj\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol nkjp:manual=\"true\" value=\"sg:loc:m3:pos\" xml:id=\"morph_1.11.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.11.2.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>1:adj:sg:loc:m3:pos</string>\n          </f>\n         </fs>\n        </f>\n     
  </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.12-seg\" xml:id=\"morph_1.12-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>,</string>\n        </f>\n        <!-- , [35,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.12.1-lex\">\n          <f name=\"base\">\n           <string>,</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.12.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.12.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>,:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.13-seg\" xml:id=\"morph_1.13-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>doręcza</string>\n        </f>\n        <!-- doręcza [37,7] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.13.1-lex\">\n          <f name=\"base\">\n           <string>doręczać</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"fin\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:ter:imperf\" xml:id=\"morph_1.13.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.13.2-lex\">\n          <f name=\"base\">\n           <string>doręcze</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:n\" xml:id=\"morph_1.13.2.1-msd\"/>\n            <symbol value=\"pl:nom:n\" 
xml:id=\"morph_1.13.2.2-msd\"/>\n            <symbol value=\"pl:acc:n\" xml:id=\"morph_1.13.2.3-msd\"/>\n            <symbol value=\"pl:voc:n\" xml:id=\"morph_1.13.2.4-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.13.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>doręczać:fin:sg:ter:imperf</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.14-seg\" xml:id=\"morph_1.14-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>się</string>\n        </f>\n        <!-- się [45,3] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.14.1-lex\">\n          <f name=\"base\">\n           <string>się</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"qub\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.14.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.14.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>się:qub</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.15-seg\" xml:id=\"morph_1.15-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>na</string>\n        </f>\n        <!-- na [49,2] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.15.1-lex\">\n          <f name=\"base\">\n           <string>na</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interj\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.15.1.1-msd\"/>\n  
        </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.15.2-lex\">\n          <f name=\"base\">\n           <string>na</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"acc\" xml:id=\"morph_1.15.2.1-msd\"/>\n            <symbol value=\"loc\" xml:id=\"morph_1.15.2.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.15.2.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>na:prep:acc</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.16-seg\" xml:id=\"morph_1.16-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>czternaście</string>\n        </f>\n        <!-- czternaście [52,11] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.16.1-lex\">\n          <f name=\"base\">\n           <string>czternaście</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"num\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"pl:nom:m2:rec\" xml:id=\"morph_1.16.1.1-msd\"/>\n            <symbol value=\"pl:nom:m3:rec\" xml:id=\"morph_1.16.1.2-msd\"/>\n            <symbol value=\"pl:nom:f:rec\" xml:id=\"morph_1.16.1.3-msd\"/>\n            <symbol value=\"pl:nom:n:rec\" xml:id=\"morph_1.16.1.4-msd\"/>\n            <symbol value=\"pl:acc:m2:rec\" xml:id=\"morph_1.16.1.5-msd\"/>\n            <symbol value=\"pl:acc:m3:rec\" xml:id=\"morph_1.16.1.6-msd\"/>\n            <symbol value=\"pl:acc:f:rec\" xml:id=\"morph_1.16.1.7-msd\"/>\n            <symbol value=\"pl:acc:n:rec\" xml:id=\"morph_1.16.1.8-msd\"/>\n            <symbol value=\"pl:voc:m2:rec\" 
xml:id=\"morph_1.16.1.9-msd\"/>\n            <symbol value=\"pl:voc:m3:rec\" xml:id=\"morph_1.16.1.10-msd\"/>\n            <symbol value=\"pl:voc:f:rec\" xml:id=\"morph_1.16.1.11-msd\"/>\n            <symbol value=\"pl:voc:n:rec\" xml:id=\"morph_1.16.1.12-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.16.1.6-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>czternaście:num:pl:acc:m3:rec</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.17-seg\" xml:id=\"morph_1.17-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>dni</string>\n        </f>\n        <!-- dni [64,3] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.17.1-lex\">\n          <f name=\"base\">\n           <string>dni</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"pl:nom:n\" xml:id=\"morph_1.17.1.1-msd\"/>\n            <symbol value=\"pl:gen:n\" xml:id=\"morph_1.17.1.2-msd\"/>\n            <symbol value=\"pl:acc:n\" xml:id=\"morph_1.17.1.3-msd\"/>\n            <symbol value=\"pl:voc:n\" xml:id=\"morph_1.17.1.4-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.17.2-lex\">\n          <f name=\"base\">\n           <string>dzień</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"pl:nom:m3\" xml:id=\"morph_1.17.2.1-msd\"/>\n            <symbol value=\"pl:gen:m3\" xml:id=\"morph_1.17.2.2-msd\"/>\n            <symbol value=\"pl:acc:m3\" xml:id=\"morph_1.17.2.3-msd\"/>\n            
<symbol value=\"pl:voc:m3\" xml:id=\"morph_1.17.2.4-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.17.2.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>dzień:subst:pl:gen:m3</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.18-seg\" xml:id=\"morph_1.18-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>przed</string>\n        </f>\n        <!-- przed [68,5] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.18.1-lex\">\n          <f name=\"base\">\n           <string>przed</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"acc:nwok\" xml:id=\"morph_1.18.1.1-msd\"/>\n            <symbol value=\"inst:nwok\" xml:id=\"morph_1.18.1.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.18.1.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>przed:prep:inst:nwok</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.19-seg\" xml:id=\"morph_1.19-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>terminem</string>\n        </f>\n        <!-- terminem [74,8] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.19.1-lex\">\n          <f name=\"base\">\n           <string>termin</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol 
value=\"sg:inst:m3\" xml:id=\"morph_1.19.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.19.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>termin:subst:sg:inst:m3</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.20-seg\" xml:id=\"morph_1.20-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>wykonania</string>\n        </f>\n        <!-- wykonania [83,9] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.20.1-lex\">\n          <f name=\"base\">\n           <string>wykonanie</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:n\" xml:id=\"morph_1.20.1.1-msd\"/>\n            <symbol value=\"pl:nom:n\" xml:id=\"morph_1.20.1.2-msd\"/>\n            <symbol value=\"pl:acc:n\" xml:id=\"morph_1.20.1.3-msd\"/>\n            <symbol value=\"pl:voc:n\" xml:id=\"morph_1.20.1.4-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.20.2-lex\">\n          <f name=\"base\">\n           <string>wykonać</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ger\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:gen:n:perf:aff\" xml:id=\"morph_1.20.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.20.2.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>wykonać:ger:sg:gen:n:perf:aff</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg 
corresp=\"ann_segmentation.xml#segm_1.21-seg\" xml:id=\"morph_1.21-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>świadczenia</string>\n        </f>\n        <!-- świadczenia [93,11] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.21.1-lex\">\n          <f name=\"base\">\n           <string>świadczenie</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:n\" xml:id=\"morph_1.21.1.1-msd\"/>\n            <symbol value=\"pl:nom:n\" xml:id=\"morph_1.21.1.2-msd\"/>\n            <symbol value=\"pl:acc:n\" xml:id=\"morph_1.21.1.3-msd\"/>\n            <symbol value=\"pl:voc:n\" xml:id=\"morph_1.21.1.4-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.21.2-lex\">\n          <f name=\"base\">\n           <string>świadczyć</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ger\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:gen:n:imperf:aff\" xml:id=\"morph_1.21.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.21.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>świadczenie:subst:sg:gen:n</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.22-seg\" xml:id=\"morph_1.22-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>,</string>\n        </f>\n        <!-- , [104,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.22.1-lex\">\n          <f name=\"base\">\n           <string>,</string>\n          
</f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.22.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.22.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>,:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.23-seg\" xml:id=\"morph_1.23-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>z</string>\n        </f>\n        <!-- z [106,1] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.23.1-lex\">\n          <f name=\"base\">\n           <string>z</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"gen:nwok\" xml:id=\"morph_1.23.1.1-msd\"/>\n            <symbol value=\"acc:nwok\" xml:id=\"morph_1.23.1.2-msd\"/>\n            <symbol value=\"inst:nwok\" xml:id=\"morph_1.23.1.3-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.23.2-lex\">\n          <f name=\"base\">\n           <string>z</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"qub\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.23.2.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.23.3-lex\">\n          <f name=\"base\">\n           <string>zeszyt</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.23.3.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f 
name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.23.1.3-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>z:prep:inst:nwok</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.24-seg\" xml:id=\"morph_1.24-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>wyjątkiem</string>\n        </f>\n        <!-- wyjątkiem [108,9] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.24.1-lex\">\n          <f name=\"base\">\n           <string>wyjątek</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:inst:m3\" xml:id=\"morph_1.24.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.24.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>wyjątek:subst:sg:inst:m3</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.25-seg\" xml:id=\"morph_1.25-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>przypadków</string>\n        </f>\n        <!-- przypadków [118,10] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.25.1-lex\">\n          <f name=\"base\">\n           <string>przypadek</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pl:gen:m3\" xml:id=\"morph_1.25.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.25.1.1-msd\" 
name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>przypadek:subst:pl:gen:m3</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.26-seg\" xml:id=\"morph_1.26-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>,</string>\n        </f>\n        <!-- , [128,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.26.1-lex\">\n          <f name=\"base\">\n           <string>,</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.26.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.26.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>,:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.27-seg\" xml:id=\"morph_1.27-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>w</string>\n        </f>\n        <!-- w [130,1] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.27.1-lex\">\n          <f name=\"base\">\n           <string>w</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"acc:nwok\" xml:id=\"morph_1.27.1.1-msd\"/>\n            <symbol value=\"loc:nwok\" xml:id=\"morph_1.27.1.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.27.2-lex\">\n          <f name=\"base\">\n           <string>wiek</string>\n          </f>\n          <f 
name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.27.2.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.27.3-lex\">\n          <f name=\"base\">\n           <string>wielki</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.27.3.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.27.4-lex\">\n          <f name=\"base\">\n           <string>wiersz</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.27.4.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.27.5-lex\">\n          <f name=\"base\">\n           <string>wieś</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.27.5.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.27.6-lex\">\n          <f name=\"base\">\n           <string>wyspa</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.27.6.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.27.1.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>w:prep:loc:nwok</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.28-seg\" xml:id=\"morph_1.28-seg\">\n       <fs 
type=\"morph\">\n        <f name=\"orth\">\n         <string>których</string>\n        </f>\n        <!-- których [132,7] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.28.1-lex\">\n          <f name=\"base\">\n           <string>który</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"adj\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"pl:gen:m1:pos\" xml:id=\"morph_1.28.1.1-msd\"/>\n            <symbol value=\"pl:gen:m2:pos\" xml:id=\"morph_1.28.1.2-msd\"/>\n            <symbol value=\"pl:gen:m3:pos\" xml:id=\"morph_1.28.1.3-msd\"/>\n            <symbol value=\"pl:gen:f:pos\" xml:id=\"morph_1.28.1.4-msd\"/>\n            <symbol value=\"pl:gen:n:pos\" xml:id=\"morph_1.28.1.5-msd\"/>\n            <symbol value=\"pl:loc:m1:pos\" xml:id=\"morph_1.28.1.6-msd\"/>\n            <symbol value=\"pl:loc:m2:pos\" xml:id=\"morph_1.28.1.7-msd\"/>\n            <symbol value=\"pl:loc:m3:pos\" xml:id=\"morph_1.28.1.8-msd\"/>\n            <symbol value=\"pl:loc:f:pos\" xml:id=\"morph_1.28.1.9-msd\"/>\n            <symbol value=\"pl:loc:n:pos\" xml:id=\"morph_1.28.1.10-msd\"/>\n            <symbol value=\"pl:acc:m1:pos\" xml:id=\"morph_1.28.1.11-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.28.1.8-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>który:adj:pl:loc:m3:pos</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.29-seg\" xml:id=\"morph_1.29-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>wykonanie</string>\n        </f>\n        <!-- wykonanie [140,9] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.29.1-lex\">\n          <f name=\"base\">\n           
<string>wykonanie</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:nom:n\" xml:id=\"morph_1.29.1.1-msd\"/>\n            <symbol value=\"sg:acc:n\" xml:id=\"morph_1.29.1.2-msd\"/>\n            <symbol value=\"sg:voc:n\" xml:id=\"morph_1.29.1.3-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.29.2-lex\">\n          <f name=\"base\">\n           <string>wykonać</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ger\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:nom:n:perf:aff\" xml:id=\"morph_1.29.2.1-msd\"/>\n            <symbol value=\"sg:acc:n:perf:aff\" xml:id=\"morph_1.29.2.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.29.2.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>wykonać:ger:sg:nom:n:perf:aff</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.30-seg\" xml:id=\"morph_1.30-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>świadczenia</string>\n        </f>\n        <!-- świadczenia [150,11] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.30.1-lex\">\n          <f name=\"base\">\n           <string>świadczenie</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:n\" xml:id=\"morph_1.30.1.1-msd\"/>\n            <symbol value=\"pl:nom:n\" xml:id=\"morph_1.30.1.2-msd\"/>\n            <symbol value=\"pl:acc:n\" 
xml:id=\"morph_1.30.1.3-msd\"/>\n            <symbol value=\"pl:voc:n\" xml:id=\"morph_1.30.1.4-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.30.2-lex\">\n          <f name=\"base\">\n           <string>świadczyć</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ger\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:gen:n:imperf:aff\" xml:id=\"morph_1.30.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.30.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>świadczenie:subst:sg:gen:n</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.31-seg\" xml:id=\"morph_1.31-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>następuje</string>\n        </f>\n        <!-- następuje [162,9] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.31.1-lex\">\n          <f name=\"base\">\n           <string>następować</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"fin\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:ter:imperf\" xml:id=\"morph_1.31.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.31.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>następować:fin:sg:ter:imperf</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.32-seg\" xml:id=\"morph_1.32-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>w</string>\n        </f>\n        <!-- w [172,1] 
-->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.32.1-lex\">\n          <f name=\"base\">\n           <string>w</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"prep\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"acc:nwok\" xml:id=\"morph_1.32.1.1-msd\"/>\n            <symbol value=\"loc:nwok\" xml:id=\"morph_1.32.1.2-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.32.2-lex\">\n          <f name=\"base\">\n           <string>wiek</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.32.2.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.32.3-lex\">\n          <f name=\"base\">\n           <string>wielki</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.32.3.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.32.4-lex\">\n          <f name=\"base\">\n           <string>wiersz</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.32.4.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.32.5-lex\">\n          <f name=\"base\">\n           <string>wieś</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.32.5.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.32.6-lex\">\n          <f name=\"base\">\n           <string>wyspa</string>\n          
</f>\n          <f name=\"ctag\">\n           <symbol value=\"brev\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pun\" xml:id=\"morph_1.32.6.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.32.1.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>w:prep:loc:nwok</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.33-seg\" xml:id=\"morph_1.33-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>celu</string>\n        </f>\n        <!-- celu [174,4] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.33.1-lex\">\n          <f name=\"base\">\n           <string>Cela</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:voc:f\" xml:id=\"morph_1.33.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.33.2-lex\">\n          <f name=\"base\">\n           <string>cel</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:m3\" xml:id=\"morph_1.33.2.1-msd\"/>\n            <symbol value=\"sg:loc:m3\" xml:id=\"morph_1.33.2.2-msd\"/>\n            <symbol value=\"sg:voc:m3\" xml:id=\"morph_1.33.2.3-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.33.2.2-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>cel:subst:sg:loc:m3</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg 
corresp=\"ann_segmentation.xml#segm_1.34-seg\" xml:id=\"morph_1.34-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>sprawdzenia</string>\n        </f>\n        <!-- sprawdzenia [179,11] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.34.1-lex\">\n          <f name=\"base\">\n           <string>sprawdzić</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"ger\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"sg:gen:n:perf:aff\" xml:id=\"morph_1.34.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.34.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>sprawdzić:ger:sg:gen:n:perf:aff</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.35-seg\" xml:id=\"morph_1.35-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>gotowości</string>\n        </f>\n        <!-- gotowości [191,9] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.35.1-lex\">\n          <f name=\"base\">\n           <string>gotowość</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:f\" xml:id=\"morph_1.35.1.1-msd\"/>\n            <symbol value=\"sg:dat:f\" xml:id=\"morph_1.35.1.2-msd\"/>\n            <symbol value=\"sg:loc:f\" xml:id=\"morph_1.35.1.3-msd\"/>\n            <symbol value=\"sg:voc:f\" xml:id=\"morph_1.35.1.4-msd\"/>\n            <symbol value=\"pl:nom:f\" xml:id=\"morph_1.35.1.5-msd\"/>\n            <symbol value=\"pl:gen:f\" xml:id=\"morph_1.35.1.6-msd\"/>\n            <symbol value=\"pl:acc:f\" xml:id=\"morph_1.35.1.7-msd\"/>\n            <symbol 
value=\"pl:voc:f\" xml:id=\"morph_1.35.1.8-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.35.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>gotowość:subst:sg:gen:f</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.36-seg\" xml:id=\"morph_1.36-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>mobilizacyjnej</string>\n        </f>\n        <!-- mobilizacyjnej [201,14] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.36.1-lex\">\n          <f name=\"base\">\n           <string>mobilizacyjny</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"adj\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"sg:gen:f:pos\" xml:id=\"morph_1.36.1.1-msd\"/>\n            <symbol value=\"sg:dat:f:pos\" xml:id=\"morph_1.36.1.2-msd\"/>\n            <symbol value=\"sg:loc:f:pos\" xml:id=\"morph_1.36.1.3-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.36.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>mobilizacyjny:adj:sg:gen:f:pos</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.37-seg\" xml:id=\"morph_1.37-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>Sił</string>\n        </f>\n        <!-- Sił [216,3] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.37.1-lex\">\n          <f name=\"base\">\n           <string>siła</string>\n          </f>\n          <f name=\"ctag\">\n          
 <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pl:gen:f\" xml:id=\"morph_1.37.1.1-msd\"/>\n          </f>\n         </fs>\n         <fs type=\"lex\" xml:id=\"morph_1.37.2-lex\">\n          <f name=\"base\">\n           <string>siły</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"subst\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"pl:gen:n\" xml:id=\"morph_1.37.2.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.37.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>siła:subst:pl:gen:f</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.38-seg\" xml:id=\"morph_1.38-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>Zbrojnych</string>\n        </f>\n        <!-- Zbrojnych [220,9] -->\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.38.1-lex\">\n          <f name=\"base\">\n           <string>zbrojny</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"adj\"/>\n          </f>\n          <f name=\"msd\">\n           <vAlt>\n            <symbol value=\"pl:gen:m1:pos\" xml:id=\"morph_1.38.1.1-msd\"/>\n            <symbol value=\"pl:gen:m2:pos\" xml:id=\"morph_1.38.1.2-msd\"/>\n            <symbol value=\"pl:gen:m3:pos\" xml:id=\"morph_1.38.1.3-msd\"/>\n            <symbol value=\"pl:gen:f:pos\" xml:id=\"morph_1.38.1.4-msd\"/>\n            <symbol value=\"pl:gen:n:pos\" xml:id=\"morph_1.38.1.5-msd\"/>\n            <symbol value=\"pl:loc:m1:pos\" xml:id=\"morph_1.38.1.6-msd\"/>\n            <symbol value=\"pl:loc:m2:pos\" xml:id=\"morph_1.38.1.7-msd\"/>\n            <symbol value=\"pl:loc:m3:pos\" xml:id=\"morph_1.38.1.8-msd\"/>\n            
<symbol value=\"pl:loc:f:pos\" xml:id=\"morph_1.38.1.9-msd\"/>\n            <symbol value=\"pl:loc:n:pos\" xml:id=\"morph_1.38.1.10-msd\"/>\n            <symbol value=\"pl:acc:m1:pos\" xml:id=\"morph_1.38.1.11-msd\"/>\n           </vAlt>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.38.1.4-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>zbrojny:adj:pl:gen:f:pos</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n      <seg corresp=\"ann_segmentation.xml#segm_1.39-seg\" xml:id=\"morph_1.39-seg\">\n       <fs type=\"morph\">\n        <f name=\"orth\">\n         <string>.</string>\n        </f>\n        <!-- . [229,1] -->\n        <f name=\"nps\">\n         <binary value=\"true\"/>\n        </f>\n        <f name=\"interps\">\n         <fs type=\"lex\" xml:id=\"morph_1.39.1-lex\">\n          <f name=\"base\">\n           <string>.</string>\n          </f>\n          <f name=\"ctag\">\n           <symbol value=\"interp\"/>\n          </f>\n          <f name=\"msd\">\n           <symbol value=\"\" xml:id=\"morph_1.39.1.1-msd\"/>\n          </f>\n         </fs>\n        </f>\n        <f name=\"disamb\">\n         <fs feats=\"#an8003\" type=\"tool_report\">\n          <f fVal=\"#morph_1.39.1.1-msd\" name=\"choice\"/>\n          <f name=\"interpretation\">\n           <string>.:interp</string>\n          </f>\n         </fs>\n        </f>\n       </fs>\n      </seg>\n     </s>\n    </p>\n   </body>\n  </text>\n </TEI>\n</teiCorpus>\n\"\"\".lstrip()\n"
  },
  {
    "path": "stanza/tests/ner/test_convert_starlang_ner.py",
    "content": "\"\"\"\nTest a couple different classes of trees to check the output of the Starlang conversion for NER\n\"\"\"\n\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.utils.datasets.ner import convert_starlang_ner\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nTREE=\"( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}))  (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v}))  (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE}))  )\"\n\ndef test_read_tree():\n    \"\"\"\n    Test a basic tree read\n    \"\"\"\n    sentence = convert_starlang_ner.read_tree(TREE)\n    expected = [('Bayan', 'PERSON'), ('Haag', 'PERSON'), ('Elianti', 'O'), ('çalar', 'O'), ('.', 'O')]\n    assert sentence == expected\n\n"
  },
  {
    "path": "stanza/tests/ner/test_data.py",
    "content": "import json\nimport pytest\n\nfrom stanza.models import ner_tagger\nfrom stanza.models.common.doc import Document\nfrom stanza.models.ner.data import DataLoader\nfrom stanza.tests import TEST_WORKING_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n\nONE_SENTENCE = \"\"\"\n[\n [\n  {\n   \"text\": \"EU\",\n   \"ner\": \"B-ORG\"\n  },\n  {\n   \"text\": \"rejects\",\n   \"ner\": \"O\"\n  },\n  {\n   \"text\": \"German\",\n   \"ner\": \"B-MISC\"\n  },\n  {\n   \"text\": \"call\",\n   \"ner\": \"O\"\n  },\n  {\n   \"text\": \"to\",\n   \"ner\": \"O\"\n  },\n  {\n   \"text\": \"boycott\",\n   \"ner\": \"O\"\n  },\n  {\n   \"text\": \"Mox\",\n   \"ner\": \"B-MISC\"\n  },\n  {\n   \"text\": \"Opal\",\n   \"ner\": \"I-MISC\"\n  },\n  {\n   \"text\": \".\",\n   \"ner\": \"O\"\n  }\n ]\n]\n\"\"\"\n\n@pytest.fixture(scope=\"module\")\ndef pretrain_file():\n    return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\n\n@pytest.fixture(scope=\"module\")\ndef one_sentence_json_path(tmpdir_factory):\n    filename = tmpdir_factory.mktemp('data').join(\"sentence.json\")\n    with open(filename, 'w') as fout:\n        fout.write(ONE_SENTENCE)\n    return filename\n\n\ndef test_build_vocab(pretrain_file, one_sentence_json_path, tmp_path):\n    \"\"\"\n    Test that when loading a data file, we get back \n    \"\"\"\n    args = ner_tagger.parse_args([\"--wordvec_pretrain_file\", pretrain_file])\n    pt = ner_tagger.load_pretrain(args)\n\n    with open(one_sentence_json_path) as fin:\n        train_doc = Document(json.load(fin))\n\n    train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])\n\n    vocab = train_batch.vocab\n    pt_words = list(vocab['word'])\n    assert pt_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'unban', 'mox', 'opal']\n    delta_words = list(vocab['delta'])\n    assert delta_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 
'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'mox', 'opal', '.']\n    tags = list(vocab['tag'])\n    assert tags == [['<PAD>'], ['<UNK>'], [], ['<ROOT>'], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']]\n\n\ndef test_build_vocab_ignore_repeats(pretrain_file, one_sentence_json_path, tmp_path):\n    \"\"\"\n    Test that with --emb_finetune_known_only, the delta vocab only has the words also found in the pretrain\n    \"\"\"\n    args = ner_tagger.parse_args([\"--wordvec_pretrain_file\", pretrain_file, \"--emb_finetune_known_only\"])\n    pt = ner_tagger.load_pretrain(args)\n\n    with open(one_sentence_json_path) as fin:\n        train_doc = Document(json.load(fin))\n\n    train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])\n\n    vocab = train_batch.vocab\n    pt_words = list(vocab['word'])\n    assert pt_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'unban', 'mox', 'opal']\n    delta_words = list(vocab['delta'])\n    assert delta_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'mox', 'opal']\n    tags = list(vocab['tag'])\n    assert tags == [['<PAD>'], ['<UNK>'], [], ['<ROOT>'], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']]\n"
  },
  {
    "path": "stanza/tests/ner/test_from_conllu.py",
    "content": "import pytest\n\nfrom stanza import Pipeline\nfrom stanza.utils.conll import CoNLL\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_from_conllu():\n    \"\"\"\n    If the doc does not have the entire text available, make sure it still safely processes the text\n\n    Test case supplied from user - see issue #1428\n    \"\"\"\n    pipe = Pipeline(\"en\", dir=TEST_MODELS_DIR, processors=\"tokenize,ner\", download_method=None)\n    doc = pipe(\"In February, I traveled to Seattle.  Dr. Pritchett gave me a new hip\")\n    ents = [x.text for x in doc.ents]\n    # the default NER model ought to find these three\n    assert ents == ['February', 'Seattle', 'Pritchett']\n\n    doc_conllu = \"{:C}\\n\\n\".format(doc)\n    doc = CoNLL.conll2doc(input_str=doc_conllu)\n    pipe = Pipeline(\"en\", dir=TEST_MODELS_DIR, processors=\"tokenize,ner\", tokenize_pretokenized=True, download_method=None)\n    pipe(doc)\n    ents = [x.text for x in doc.ents]\n    # this should still work when processed from a CoNLLu document\n    # the bug previously caused a crash because the text to construct\n    # the entities was not available, since the Document wouldn't have\n    # the entire document text available\n    assert ents == ['February', 'Seattle', 'Pritchett']\n"
  },
  {
    "path": "stanza/tests/ner/test_models_ner_scorer.py",
    "content": "\"\"\"\nSimple test of the scorer module for NER\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import *\nfrom stanza.models.ner.scorer import score_by_token, score_by_entity\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_ner_scorer():\n    pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'],\n                    ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']]\n    gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'],\n                    ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']]\n    \n    token_p, token_r, token_f, confusion = score_by_token(pred_sequences, gold_sequences)\n    assert pytest.approx(token_p, abs=0.00001) == 0.625\n    assert pytest.approx(token_r, abs=0.00001) == 0.5\n    assert pytest.approx(token_f, abs=0.00001) == 0.55555\n\n    entity_p, entity_r, entity_f, entity_f1 = score_by_entity(pred_sequences, gold_sequences)\n    assert pytest.approx(entity_p, abs=0.00001) == 0.4\n    assert pytest.approx(entity_r, abs=0.00001) == 0.33333\n    assert pytest.approx(entity_f, abs=0.00001) == 0.36363\n    assert entity_f1 == {'LOC': 0.0, 'MISC': 1.0, 'ORG': 0.0, 'PER': 0.5}\n"
  },
  {
    "path": "stanza/tests/ner/test_ner_tagger.py",
    "content": "\"\"\"\nBasic testing of the NER tagger.\n\"\"\"\n\nimport os\nimport pytest\nimport stanza\n\nfrom stanza.tests import *\nfrom stanza.models import ner_tagger\nfrom stanza.utils.confusion import confusion_to_macro_f1\nimport stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file\nfrom stanza.utils.training.run_ner import build_pretrain_args\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nEN_DOC = \"Chris Manning is a good man. He works in Stanford University.\"\n\nEN_DOC_GOLD = \"\"\"\n<Span text=Chris Manning;type=PERSON;start_char=0;end_char=13>\n<Span text=Stanford University;type=ORG;start_char=41;end_char=60>\n\"\"\".strip()\n\nEN_BIO = \"\"\"\nChris B-PERSON\nManning E-PERSON\nis O\na O\ngood O\nman O\n. O\n\nHe O\nworks O\nin O\nStanford B-ORG\nUniversity E-ORG\n. O\n\"\"\".strip().replace(\" \", \"\\t\")\n\nEN_EXPECTED_OUTPUT = \"\"\"\nChris B-PERSON B-PERSON\nManning E-PERSON E-PERSON\nis O O\na O O\ngood O O\nman O O\n. O O\n\nHe O O\nworks O O\nin O O\nStanford B-ORG B-ORG\nUniversity E-ORG E-ORG\n. 
O O\n\"\"\".strip().replace(\" \", \"\\t\")\n\n\ndef test_ner():\n    nlp = stanza.Pipeline(**{'processors': 'tokenize,ner', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'logging_level': 'error'})\n    doc = nlp(EN_DOC)\n    assert EN_DOC_GOLD == '\\n'.join([ent.pretty_print() for ent in doc.ents])\n\ndef test_evaluate(tmp_path):\n    \"\"\"\n    This simple example should have a 1.0 f1 for the ontonote model\n    \"\"\"\n    package = \"ontonotes-ww-multi_charlm\"\n    model_path = os.path.join(TEST_MODELS_DIR, \"en\", \"ner\", package + \".pt\")\n    assert os.path.exists(model_path), \"The {} model should be downloaded as part of setup.py\".format(package)\n\n    os.makedirs(tmp_path, exist_ok=True)\n\n    test_bio_filename = tmp_path / \"test.bio\"\n    test_json_filename = tmp_path / \"test.json\"\n    test_output_filename = tmp_path / \"output.bio\"\n    with open(test_bio_filename, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(EN_BIO)\n\n    prepare_ner_file.process_dataset(test_bio_filename, test_json_filename)\n\n    args = [\"--save_name\", str(model_path),\n            \"--eval_file\", str(test_json_filename),\n            \"--eval_output_file\", str(test_output_filename),\n            \"--mode\", \"predict\"]\n    args = args + build_pretrain_args(\"en\", package, model_dir=TEST_MODELS_DIR, extra_args=[])\n    args = ner_tagger.parse_args(args=args)\n    confusion = ner_tagger.evaluate(args)\n    assert confusion_to_macro_f1(confusion) == pytest.approx(1.0)\n\n    with open(test_output_filename, encoding=\"utf-8\") as fin:\n        results = fin.read().strip()\n\n    assert results == EN_EXPECTED_OUTPUT\n"
  },
  {
    "path": "stanza/tests/ner/test_ner_trainer.py",
    "content": "import pytest\n\nfrom stanza.tests import *\n\nfrom stanza.models.ner import trainer\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_fix_singleton_tags():\n    TESTS = [\n        ([\"O\"], [\"O\"]),\n        ([\"B-PER\"], [\"S-PER\"]),\n        ([\"B-PER\", \"I-PER\"], [\"B-PER\", \"E-PER\"]),\n        ([\"B-PER\", \"O\", \"B-PER\"], [\"S-PER\", \"O\", \"S-PER\"]),\n        ([\"B-PER\", \"B-PER\", \"I-PER\"], [\"S-PER\", \"B-PER\", \"E-PER\"]),\n        ([\"B-PER\", \"I-PER\", \"O\", \"B-PER\"], [\"B-PER\", \"E-PER\", \"O\", \"S-PER\"]),\n        ([\"B-PER\", \"B-PER\", \"I-PER\", \"B-PER\"], [\"S-PER\", \"B-PER\", \"E-PER\", \"S-PER\"]),\n        ([\"B-PER\", \"I-ORG\", \"O\", \"B-PER\"], [\"S-PER\", \"S-ORG\", \"O\", \"S-PER\"]),\n        ([\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"], [\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"]),\n        ([\"S-PER\", \"B-PER\", \"E-PER\"], [\"S-PER\", \"B-PER\", \"E-PER\"]),\n        ([\"E-PER\"], [\"S-PER\"]),\n        ([\"E-PER\", \"O\", \"E-PER\"], [\"S-PER\", \"O\", \"S-PER\"]),\n        ([\"B-PER\", \"E-ORG\", \"O\", \"B-PER\"], [\"S-PER\", \"S-ORG\", \"O\", \"S-PER\"]),\n        ([\"I-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"], [\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"]),\n        ([\"B-PER\", \"I-PER\", \"I-PER\", \"O\", \"B-PER\", \"E-PER\"], [\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"]),\n        ([\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"I-PER\", \"E-PER\"], [\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"]),\n        ([\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"I-PER\"], [\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"]),\n        ([\"I-PER\", \"I-PER\", \"I-PER\", \"O\", \"I-PER\", \"I-PER\"], [\"B-PER\", \"I-PER\", \"E-PER\", \"O\", \"B-PER\", \"E-PER\"]),\n    ]\n             \n    for unfixed, expected in TESTS:\n        assert 
trainer.fix_singleton_tags(unfixed) == expected, \"Error converting {} to {}\".format(unfixed, expected)\n"
  },
  {
    "path": "stanza/tests/ner/test_ner_training.py",
    "content": "import json\nimport logging\nimport os\nimport warnings\n\nimport pytest\nimport torch\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nfrom stanza.models import ner_tagger\nfrom stanza.models.ner.trainer import Trainer\nfrom stanza.tests import TEST_WORKING_DIR\nfrom stanza.utils.datasets.ner.prepare_ner_file import process_dataset\n\nlogger = logging.getLogger('stanza')\n\nEN_TRAIN_BIO = \"\"\"\nChris B-PERSON\nManning E-PERSON\nis O\na O\ngood O\nman O\n. O\n\nHe O\nworks O\nin O\nStanford B-ORG\nUniversity E-ORG\n. O\n\"\"\".lstrip().replace(\" \", \"\\t\")\n\nEN_DEV_BIO = \"\"\"\nChris B-PERSON\nManning E-PERSON\nis O\npart O\nof O\nComputer B-ORG\nScience E-ORG\n\"\"\".lstrip().replace(\" \", \"\\t\")\n\nEN_TRAIN_2TAG = \"\"\"\nChris B-PERSON B-PER\nManning E-PERSON E-PER\nis O O\na O O\ngood O O\nman O O\n. O O\n\nHe O O\nworks O O\nin O O\nStanford B-ORG B-ORG\nUniversity E-ORG B-ORG\n. O O\n\"\"\".strip().replace(\" \", \"\\t\")\n\nEN_TRAIN_2TAG_EMPTY2 = \"\"\"\nChris B-PERSON -\nManning E-PERSON -\nis O -\na O -\ngood O -\nman O -\n. O -\n\nHe O -\nworks O -\nin O -\nStanford B-ORG -\nUniversity E-ORG -\n. 
O -\n\"\"\".strip().replace(\" \", \"\\t\")\n\nEN_DEV_2TAG = \"\"\"\nChris B-PERSON B-PER\nManning E-PERSON E-PER\nis O O\npart O O\nof O O\nComputer B-ORG B-ORG\nScience E-ORG E-ORG\n\"\"\".strip().replace(\" \", \"\\t\")\n\n@pytest.fixture(scope=\"module\")\ndef pretrain_file():\n    return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\ndef write_temp_file(filename, bio_data):\n    bio_filename = os.path.splitext(filename)[0] + \".bio\"\n    with open(bio_filename, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(bio_data)\n    process_dataset(bio_filename, filename)\n\ndef write_temp_2tag(filename, bio_data):\n    doc = []\n    sentences = bio_data.split(\"\\n\\n\")\n    for sentence in sentences:\n        doc.append([])\n        for word in sentence.split(\"\\n\"):\n            text, tags = word.split(\"\\t\", maxsplit=1)\n            doc[-1].append({\n                \"text\": text,\n                \"multi_ner\": tags.split()\n            })\n\n    with open(filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(doc, fout)\n\ndef get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args):\n    save_dir = tmp_path / \"models\"\n    args = [\"--data_dir\", str(tmp_path),\n            \"--wordvec_pretrain_file\", pretrain_file,\n            \"--train_file\", str(train_json),\n            \"--eval_file\", str(dev_json),\n            \"--shorthand\", \"en_test\",\n            \"--max_steps\", \"100\",\n            \"--eval_interval\", \"40\",\n            \"--save_dir\", str(save_dir)]\n\n    args = args + list(extra_args)\n    return args\n\ndef run_two_tag_training(pretrain_file, tmp_path, *extra_args, train_data=EN_TRAIN_2TAG):\n    train_json = tmp_path / \"en_test.train.json\"\n    write_temp_2tag(train_json, train_data)\n\n    dev_json = tmp_path / \"en_test.dev.json\"\n    write_temp_2tag(dev_json, EN_DEV_2TAG)\n\n    args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args)\n    return ner_tagger.main(args)\n\ndef 
test_basic_two_tag_training(pretrain_file, tmp_path):\n    trainer = run_two_tag_training(pretrain_file, tmp_path)\n    assert len(trainer.model.tag_clfs) == 2\n    assert len(trainer.model.crits) == 2\n    assert len(trainer.vocab['tag'].lens()) == 2\n\ndef test_two_tag_training_backprop(pretrain_file, tmp_path):\n    \"\"\"\n    Test that the training is backproping both tags\n\n    We can do this by using the \"finetune\" mechanism and verifying\n    that the output tensors are different\n    \"\"\"\n    trainer = run_two_tag_training(pretrain_file, tmp_path)\n\n    # first, need to save the final model before restarting\n    # (alternatively, could reload the final checkpoint)\n    trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name']))\n    new_trainer = run_two_tag_training(pretrain_file, tmp_path, \"--finetune\")\n\n    assert len(trainer.model.tag_clfs) == 2\n    assert len(new_trainer.model.tag_clfs) == 2\n    for old_clf, new_clf in zip(trainer.model.tag_clfs, new_trainer.model.tag_clfs):\n        assert not torch.allclose(old_clf.weight, new_clf.weight)\n\ndef test_two_tag_training_c2_backprop(pretrain_file, tmp_path):\n    \"\"\"\n    Test that the training is backproping only one tag if one column is blank\n\n    We can do this by using the \"finetune\" mechanism and verifying\n    that the output tensors are different in just the first column\n    \"\"\"\n    trainer = run_two_tag_training(pretrain_file, tmp_path)\n\n    # first, need to save the final model before restarting\n    # (alternatively, could reload the final checkpoint)\n    trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name']))\n    new_trainer = run_two_tag_training(pretrain_file, tmp_path, \"--finetune\", train_data=EN_TRAIN_2TAG_EMPTY2)\n\n    assert len(trainer.model.tag_clfs) == 2\n    assert len(new_trainer.model.tag_clfs) == 2\n    assert not torch.allclose(trainer.model.tag_clfs[0].weight, new_trainer.model.tag_clfs[0].weight)\n    
assert torch.allclose(trainer.model.tag_clfs[1].weight, new_trainer.model.tag_clfs[1].weight)\n\ndef test_connected_two_tag_training(pretrain_file, tmp_path):\n    trainer = run_two_tag_training(pretrain_file, tmp_path, \"--connect_output_layers\")\n    assert len(trainer.model.tag_clfs) == 2\n    assert len(trainer.model.crits) == 2\n    assert len(trainer.vocab['tag'].lens()) == 2\n\n    # this checks that with the connected output layers,\n    # the second output layer has its size increased\n    # by the number of tags known to the first output layer\n    assert trainer.model.tag_clfs[1].weight.shape[1] == trainer.vocab['tag'].lens()[0] + trainer.model.tag_clfs[0].weight.shape[1]\n\ndef run_training(pretrain_file, tmp_path, *extra_args):\n    train_json = tmp_path / \"en_test.train.json\"\n    write_temp_file(train_json, EN_TRAIN_BIO)\n\n    dev_json = tmp_path / \"en_test.dev.json\"\n    write_temp_file(dev_json, EN_DEV_BIO)\n\n    args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args)\n    return ner_tagger.main(args)\n\n\ndef test_train_model_gpu(pretrain_file, tmp_path):\n    \"\"\"\n    Briefly train an NER model (no expectation of correctness) and check that it is on the GPU\n    \"\"\"\n    trainer = run_training(pretrain_file, tmp_path)\n    if not torch.cuda.is_available():\n        warnings.warn(\"Cannot check that the NER model is on the GPU, since GPU is not available\")\n        return\n\n    model = trainer.model\n    device = next(model.parameters()).device\n    assert str(device).startswith(\"cuda\")\n\n\ndef test_train_model_cpu(pretrain_file, tmp_path):\n    \"\"\"\n    Briefly train an NER model (no expectation of correctness) and check that it is on the CPU\n    \"\"\"\n    trainer = run_training(pretrain_file, tmp_path, \"--cpu\")\n\n    model = trainer.model\n    device = next(model.parameters()).device\n    assert str(device).startswith(\"cpu\")\n\ndef model_file_has_bert(filename):\n    checkpoint = 
torch.load(filename, lambda storage, loc: storage, weights_only=True)\n    return any(x.startswith(\"bert_model.\") for x in checkpoint['model'].keys())\n\ndef test_with_bert(pretrain_file, tmp_path):\n    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert')\n    model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])\n    assert not model_file_has_bert(model_file)\n\ndef test_with_bert_finetune(pretrain_file, tmp_path):\n    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune')\n    model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])\n    assert model_file_has_bert(model_file)\n\n    foo_save_filename = os.path.join(tmp_path, \"foo_\" + trainer.args['save_name'])\n    bar_save_filename = os.path.join(tmp_path, \"bar_\" + trainer.args['save_name'])\n    trainer.save(foo_save_filename)\n    assert model_file_has_bert(foo_save_filename)\n\n    # TODO: technically this should still work if we turn off bert finetuning when reloading\n    reloaded_trainer = Trainer(args=trainer.args, model_file=foo_save_filename)\n    reloaded_trainer.save(bar_save_filename)\n    assert model_file_has_bert(bar_save_filename)\n\ndef test_with_peft_finetune(pretrain_file, tmp_path):\n    # TODO: check that the peft tensors are moving when training?\n    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')\n    model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])\n    checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True)\n    assert 'bert_lora' in checkpoint\n    assert not any(x.startswith(\"bert_model.\") for x in checkpoint['model'].keys())\n\n    # test loading\n    reloaded_trainer = Trainer(args=trainer.args, model_file=model_file)\n"
  },
  {
    "path": "stanza/tests/ner/test_ner_utils.py",
    "content": "import pytest\n\nfrom stanza.tests import *\n\nfrom stanza.models.common.vocab import EMPTY\nfrom stanza.models.ner import utils\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nWORDS       = [[\"Unban\",   \"Mox\",   \"Opal\"], [\"Ragavan\",  \"is\",     \"red\"], [\"Urza\",   \"Lord\",  \"High\", \"Artificer\", \"goes\", \"infinite\", \"with\",  \"Thopter\",    \"Sword\"]]\nBIO_TAGS    = [[\"O\",     \"B-ART\",  \"I-ART\"], [\"B-MONKEY\", \"O\",  \"B-COLOR\"], [\"B-PER\", \"I-PER\", \"I-PER\", \"I-PER\",        \"O\",        \"O\",    \"O\", \"B-WEAPON\", \"B-WEAPON\"]]\nBIO_U_TAGS  = [[\"O\",     \"B_ART\",  \"I_ART\"], [\"B_MONKEY\", \"O\",  \"B_COLOR\"], [\"B_PER\", \"I_PER\", \"I_PER\", \"I_PER\",        \"O\",        \"O\",    \"O\", \"B_WEAPON\", \"B_WEAPON\"]]\nBIOES_TAGS  = [[\"O\",     \"B-ART\",  \"E-ART\"], [\"S-MONKEY\", \"O\",  \"S-COLOR\"], [\"B-PER\", \"I-PER\", \"I-PER\", \"E-PER\",        \"O\",        \"O\",    \"O\", \"S-WEAPON\", \"S-WEAPON\"]]\n# note the problem with not using BIO tags - the consecutive tags for thopter/sword get treated as one item\nBASIC_TAGS  = [[\"O\",       \"ART\",    \"ART\"], [\"MONKEY\",   \"O\",    \"COLOR\"], [  \"PER\",   \"PER\",   \"PER\",   \"PER\",        \"O\",        \"O\",    \"O\",   \"WEAPON\",   \"WEAPON\"]]\nBASIC_BIOES = [[\"O\",     \"B-ART\",  \"E-ART\"], [\"S-MONKEY\", \"O\",  \"S-COLOR\"], [\"B-PER\", \"I-PER\", \"I-PER\", \"E-PER\",        \"O\",        \"O\",    \"O\", \"B-WEAPON\", \"E-WEAPON\"]]\nALT_BIO     = [[\"O\",    \"B-MANA\", \"I-MANA\"], [\"B-CRE\",    \"O\",        \"O\"], [\"B-CRE\", \"I-CRE\", \"I-CRE\", \"I-CRE\",        \"O\",        \"O\",    \"O\",    \"B-ART\",    \"B-ART\"]]\nALT_BIOES   = [[\"O\",    \"B-MANA\", \"E-MANA\"], [\"S-CRE\",    \"O\",        \"O\"], [\"B-CRE\", \"I-CRE\", \"I-CRE\", \"E-CRE\",        \"O\",        \"O\",    \"O\",    \"S-ART\",    \"S-ART\"]]\nNONE_BIO    = [[\"O\",    \"B-MANA\", \"I-MANA\"], [None,      None,       
None], [\"B-CRE\", \"I-CRE\", \"I-CRE\", \"I-CRE\",        \"O\",        \"O\",    \"O\",    \"B-ART\",    \"B-ART\"]]\nNONE_BIOES  = [[\"O\",    \"B-MANA\", \"E-MANA\"], [None,      None,       None], [\"B-CRE\", \"I-CRE\", \"I-CRE\", \"E-CRE\",        \"O\",        \"O\",    \"O\",    \"S-ART\",    \"S-ART\"]]\nEMPTY_BIO   = [[\"O\",    \"B-MANA\", \"I-MANA\"], [EMPTY,     EMPTY,     EMPTY], [\"B-CRE\", \"I-CRE\", \"I-CRE\", \"I-CRE\",        \"O\",        \"O\",    \"O\",    \"B-ART\",    \"B-ART\"]]\n\ndef test_normalize_empty_tags():\n    sentences = [[(word[0], (word[1],)) for word in zip(*sentence)] for sentence in zip(WORDS, NONE_BIO)]\n    new_sentences = utils.normalize_empty_tags(sentences)\n    expected = [[(word[0], (word[1],)) for word in zip(*sentence)] for sentence in zip(WORDS, EMPTY_BIO)]\n    assert new_sentences == expected\n\ndef check_reprocessed_tags(words, input_tags, expected_tags):\n    sentences = [list(zip(x, y)) for x, y in zip(words, input_tags)]\n    retagged = utils.process_tags(sentences=sentences, scheme=\"bioes\")\n    # process_tags selectively returns tuples or strings based on the input\n    # so we don't need to fiddle with the expected output format here\n    expected_retagged = [list(zip(x, y)) for x, y in zip(words, expected_tags)]\n    assert retagged == expected_retagged\n\ndef test_process_tags_bio():\n    check_reprocessed_tags(WORDS, BIO_TAGS, BIOES_TAGS)\n    # check that the alternate version is correct as well\n    # that way we can independently check the two layer version\n    check_reprocessed_tags(WORDS, ALT_BIO, ALT_BIOES)\n\ndef test_process_tags_with_none():\n    # if there is a block of tags with None in them, the Nones should be skipped over\n    check_reprocessed_tags(WORDS, NONE_BIO, NONE_BIOES)\n\ndef merge_tags(*tags):\n    merged_tags = [[tuple(x) for x in zip(*sentences)]   # combine tags such as (\"O\", \"O\"), (\"B-ART\", \"B-MANA\"), ...\n                   for sentences in zip(*tags)]          # 
... for each set of sentences\n    return merged_tags\n\ndef test_combined_tags_bio():\n    bio_tags = merge_tags(BIO_TAGS, ALT_BIO)\n    expected = merge_tags(BIOES_TAGS, ALT_BIOES)\n    check_reprocessed_tags(WORDS, bio_tags, expected)\n\ndef test_combined_tags_mixed():\n    bio_tags = merge_tags(BIO_TAGS, ALT_BIOES)\n    expected = merge_tags(BIOES_TAGS, ALT_BIOES)\n    check_reprocessed_tags(WORDS, bio_tags, expected)\n\ndef test_process_tags_basic():\n    check_reprocessed_tags(WORDS, BASIC_TAGS, BASIC_BIOES)\n\ndef test_process_tags_bioes():\n    \"\"\"\n    This one should not change, naturally\n    \"\"\"\n    check_reprocessed_tags(WORDS, BIOES_TAGS, BIOES_TAGS)\n    check_reprocessed_tags(WORDS, BASIC_BIOES, BASIC_BIOES)\n\ndef run_flattened(fn, tags):\n    return fn([x for y in tags for x in y])\n\ndef test_check_bio():\n    assert     utils.is_bio_scheme([x for y in BIO_TAGS for x in y])\n    assert not utils.is_bio_scheme([x for y in BIOES_TAGS for x in y])\n    assert not utils.is_bio_scheme([x for y in BASIC_TAGS for x in y])\n    assert not utils.is_bio_scheme([x for y in BASIC_BIOES for x in y])\n\ndef test_check_basic():\n    assert not utils.is_basic_scheme([x for y in BIO_TAGS for x in y])\n    assert not utils.is_basic_scheme([x for y in BIOES_TAGS for x in y])\n    assert     utils.is_basic_scheme([x for y in BASIC_TAGS for x in y])\n    assert not utils.is_basic_scheme([x for y in BASIC_BIOES for x in y])\n\ndef test_underscores():\n    \"\"\"\n    Check that the methods work if the inputs are underscores instead of dashes\n    \"\"\"\n    assert not utils.is_basic_scheme([x for y in BIO_U_TAGS for x in y])\n    check_reprocessed_tags(WORDS, BIO_U_TAGS, BIOES_TAGS)\n\ndef test_merge_tags():\n    \"\"\"\n    Check a few versions of the tag sequence merging\n    \"\"\"\n    seq1     = [     \"O\",     \"O\",     \"O\", \"B-FOO\", \"E-FOO\",     \"O\"]\n    seq2     = [ \"S-FOO\",     \"O\", \"B-FOO\", \"E-FOO\",     \"O\",     \"O\"]\n    seq3  
   = [ \"B-FOO\", \"E-FOO\", \"B-FOO\", \"E-FOO\",     \"O\",     \"O\"]\n    seq_err  = [     \"O\", \"B-FOO\",     \"O\", \"B-FOO\", \"E-FOO\",     \"O\"]\n    seq_err2 = [     \"O\", \"B-FOO\",     \"O\", \"B-FOO\", \"B-FOO\",     \"O\"]\n    seq_err3 = [     \"O\", \"B-FOO\",     \"O\", \"B-FOO\", \"I-FOO\",     \"O\"]\n    seq_err4 = [     \"O\", \"B-FOO\",     \"O\", \"B-FOO\", \"I-FOO\", \"I-FOO\"]\n\n    result = utils.merge_tags(seq1, seq2)\n    expected = [ \"S-FOO\",     \"O\",     \"O\", \"B-FOO\", \"E-FOO\",     \"O\"]\n    assert result == expected\n\n    result = utils.merge_tags(seq2, seq1)\n    expected = [ \"S-FOO\",     \"O\", \"B-FOO\", \"E-FOO\",     \"O\",     \"O\"]\n    assert result == expected\n\n    result = utils.merge_tags(seq1, seq3)\n    expected = [ \"B-FOO\", \"E-FOO\",     \"O\", \"B-FOO\", \"E-FOO\",     \"O\"]\n    assert result == expected\n\n    with pytest.raises(ValueError):\n        result = utils.merge_tags(seq1, seq_err)\n\n    with pytest.raises(ValueError):\n        result = utils.merge_tags(seq1, seq_err2)\n\n    with pytest.raises(ValueError):\n        result = utils.merge_tags(seq1, seq_err3)\n\n    with pytest.raises(ValueError):\n        result = utils.merge_tags(seq1, seq_err4)\n\n"
  },
  {
    "path": "stanza/tests/ner/test_pay_amt_annotators.py",
    "content": "\"\"\"\nSimple test for tracking AMT annotator work\n\"\"\"\n\nimport os\nimport zipfile\n\nimport pytest\n\nfrom stanza.tests import TEST_WORKING_DIR\nfrom stanza.utils.ner import paying_annotators\n\nDATA_SOURCE = os.path.join(TEST_WORKING_DIR, \"in\", \"aws_annotations.zip\")\n\n@pytest.fixture(scope=\"module\")\ndef completed_amt_job_metadata(tmp_path_factory):\n    assert os.path.exists(DATA_SOURCE)\n    unzip_path = tmp_path_factory.mktemp(\"amt_test\")\n    input_path = unzip_path / \"ner\" / \"aws_labeling_copy\"\n    with zipfile.ZipFile(DATA_SOURCE, 'r') as zin:\n        zin.extractall(unzip_path)\n    return input_path\n\ndef test_amt_annotator_track(completed_amt_job_metadata):\n    workers = {\n        \"7efc17ac-3397-4472-afe5-89184ad145d0\": \"Worker1\",\n        \"afce8c28-969c-4e73-a20f-622ef122f585\": \"Worker2\",\n        \"91f6236e-63c6-4a84-8fd6-1efbab6dedab\": \"Worker3\",\n        \"6f202e93-e6b6-4e1d-8f07-0484b9a9093a\": \"Worker4\",\n        \"2b674d33-f656-44b0-8f90-d70a1ab71ec2\": \"Worker5\"\n    }  # map AMT annotator subs to relevant identifier\n\n    tracked_work = paying_annotators.track_tasks(completed_amt_job_metadata, workers)\n    assert tracked_work == {'Worker4': 20, 'Worker5': 20, 'Worker2': 3, 'Worker3': 16}\n\n\ndef test_amt_annotator_track_no_map(completed_amt_job_metadata):\n    sub_to_count = paying_annotators.track_tasks(completed_amt_job_metadata)\n    assert sub_to_count == {'6f202e93-e6b6-4e1d-8f07-0484b9a9093a': 20, '2b674d33-f656-44b0-8f90-d70a1ab71ec2': 20,\n                            'afce8c28-969c-4e73-a20f-622ef122f585': 3, '91f6236e-63c6-4a84-8fd6-1efbab6dedab': 16}\n\n\ndef main():\n    test_amt_annotator_track()\n    test_amt_annotator_track_no_map()\n\n\nif __name__ == \"__main__\":\n    main()\n    print(\"TESTS COMPLETED!\")\n"
  },
  {
    "path": "stanza/tests/ner/test_split_wikiner.py",
    "content": "\"\"\"\nRuns a few tests on the split_wikiner file\n\"\"\"\n\nimport os\nimport tempfile\n\nimport pytest\n\nfrom stanza.utils.datasets.ner import split_wikiner\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# two sentences from the Italian dataset, split into many pieces\n# to test the splitting functionality\nFBK_SAMPLE = \"\"\"\nIl\tO\nPapa\tO\nsi\tO\naggrava\tO\n\nLe\tO\ncondizioni\tO\ndi\tO\n\nPapa\tO\nGiovanni\tPER\nPaolo\tPER\nII\tPER\nsi\tO\n\nsono\tO\naggravate\tO\nin\tO\nil\tO\ncorso\tO\n\ndi\tO\nla\tO\ngiornata\tO\ndi\tO\ngiovedì\tO\n.\tO\n\nIl\tO\nportavoce\tO\nNavarro\tPER\nValls\tPER\n\nha\tO\ndichiarato\tO\nche\tO\n\nil\tO\nSanto\tO\nPadre\tO\n\nin\tO\nla\tO\ngiornata\tO\n\ndi\tO\noggi\tO\nè\tO\nstato\tO\n\ncolpito\tO\nda\tO\nuna\tO\naffezione\tO\n\naltamente\tO\nfebbrile\tO\nprovocata\tO\nda\tO\nuna\tO\n\ninfezione\tO\ndocumentata\tO\n\ndi\tO\nle\tO\nvie\tO\nurinarie\tO\n.\tO\n\nA\tO\nil\tO\nmomento\tO\n\nnon\tO\nè\tO\nprevisto\tO\nil\tO\nricovero\tO\n\na\tO\nil\tO\nPoliclinico\tLOC\nGemelli\tLOC\n,\tO\n\ncome\tO\nha\tO\nprecisato\tO\nil\tO\n\nresponsabile\tO\ndi\tO\nil\tO\ndipartimento\tO\n\ndi\tO\nemergenza\tO\nprofessor\tO\nRodolfo\tPER\nProietti\tPER\n.\tO\n\"\"\"\n\n\ndef test_read_sentences():\n    with tempfile.TemporaryDirectory() as tempdir:\n        raw_filename = os.path.join(tempdir, \"raw.tsv\")\n        with open(raw_filename, \"w\") as fout:\n            fout.write(FBK_SAMPLE)\n\n        sentences = split_wikiner.read_sentences(raw_filename, \"utf-8\")\n        assert len(sentences) == 20\n        text = [[\"\\t\".join(word) for word in sent] for sent in sentences]\n        text = [\"\\n\".join(sent) for sent in text]\n        text = \"\\n\\n\".join(text)\n        assert FBK_SAMPLE.strip() == text\n\ndef test_write_sentences():\n    with tempfile.TemporaryDirectory() as tempdir:\n        raw_filename = os.path.join(tempdir, \"raw.tsv\")\n        with open(raw_filename, 
\"w\") as fout:\n            fout.write(FBK_SAMPLE)\n\n        sentences = split_wikiner.read_sentences(raw_filename, \"utf-8\")\n        copy_filename = os.path.join(tempdir, \"copy.tsv\")\n        split_wikiner.write_sentences_to_file(sentences, copy_filename)\n\n        sent2 = split_wikiner.read_sentences(copy_filename, \"utf-8\")\n        assert sent2 == sentences\n\ndef run_split_wikiner(expected_train=14, expected_dev=3, expected_test=3, **kwargs):\n    \"\"\"\n    Runs a test using various parameters to check the results of the splitting process\n    \"\"\"\n    with tempfile.TemporaryDirectory() as indir:\n        raw_filename = os.path.join(indir, \"raw.tsv\")\n        with open(raw_filename, \"w\") as fout:\n            fout.write(FBK_SAMPLE)\n\n        with tempfile.TemporaryDirectory() as outdir:\n            split_wikiner.split_wikiner(outdir, raw_filename, **kwargs)\n\n            train_file = os.path.join(outdir, \"it_fbk.train.bio\")\n            dev_file = os.path.join(outdir, \"it_fbk.dev.bio\")\n            test_file = os.path.join(outdir, \"it_fbk.test.bio\")\n\n            assert os.path.exists(train_file)\n            assert os.path.exists(dev_file)\n            if kwargs[\"test_section\"]:\n                assert os.path.exists(test_file)\n            else:\n                assert not os.path.exists(test_file)\n\n            train_sent = split_wikiner.read_sentences(train_file, \"utf-8\")\n            dev_sent = split_wikiner.read_sentences(dev_file, \"utf-8\")\n            assert len(train_sent) == expected_train\n            assert len(dev_sent) == expected_dev\n            if kwargs[\"test_section\"]:\n                test_sent = split_wikiner.read_sentences(test_file, \"utf-8\")\n                assert len(test_sent) == expected_test\n            else:\n                test_sent = []\n\n            if kwargs[\"shuffle\"]:\n                orig_sents = sorted(split_wikiner.read_sentences(raw_filename, \"utf-8\"))\n                
split_sents = sorted(train_sent + dev_sent + test_sent)\n            else:\n                orig_sents = split_wikiner.read_sentences(raw_filename, \"utf-8\")\n                split_sents = train_sent + dev_sent + test_sent\n            assert orig_sents == split_sents\n\ndef test_no_shuffle_split():\n    run_split_wikiner(prefix=\"it_fbk\", shuffle=False, test_section=True)\n\ndef test_shuffle_split():\n    run_split_wikiner(prefix=\"it_fbk\", shuffle=True, test_section=True)\n\ndef test_resize():\n    run_split_wikiner(expected_train=12, expected_dev=2, expected_test=6, train_fraction=0.6, dev_fraction=0.1, prefix=\"it_fbk\", shuffle=True, test_section=True)\n\ndef test_no_test_split():\n    run_split_wikiner(expected_train=17, train_fraction=0.85, prefix=\"it_fbk\", shuffle=False, test_section=False)\n\n"
  },
  {
    "path": "stanza/tests/ner/test_suc3.py",
    "content": "\"\"\"\nTests the conversion code for the SUC3 NER dataset\n\"\"\"\n\nimport os\nimport tempfile\nfrom zipfile import ZipFile\n\nimport pytest\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nimport stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob\n\nTEST_CONLL = \"\"\"\n1\tDen\tden\tPN\tPN\tUTR|SIN|DEF|SUB/OBJ\t_\t_\t_\t_\tO\t_\tac01b-030:2328\n2\tGud\tGud\tPM\tPM\tNOM\t_\t_\t_\t_\tB\tmyth\tac01b-030:2329\n3\tgiver\tgiva\tVB\tVB\tPRS|AKT\t_\t_\t_\t_\tO\t_\tac01b-030:2330\n4\tämbetet\tämbete\tNN\tNN\tNEU|SIN|DEF|NOM\t_\t_\t_\t_\tO\t_\tac01b-030:2331\n5\tfår\tfå\tVB\tVB\tPRS|AKT\t_\t_\t_\t_\tO\t_\tac01b-030:2332\n6\tockså\tockså\tAB\tAB\t\t_\t_\t_\t_\tO\t_\tac01b-030:2333\n7\tförståndet\tförstånd\tNN\tNN\tNEU|SIN|DEF|NOM\t_\t_\t_\t_\tO\t_\tac01b-030:2334\n8\t.\t.\tMAD\tMAD\t\t_\t_\t_\t_\tO\t_\tac01b-030:2335\n\n1\tHan\than\tPN\tPN\tUTR|SIN|DEF|SUB\t_\t_\t_\t_\tO\t_\taa01a-017:227\n2\tberättar\tberätta\tVB\tVB\tPRS|AKT\t_\t_\t_\t_\tO\t_\taa01a-017:228\n3\tanekdoten\tanekdot\tNN\tNN\tUTR|SIN|DEF|NOM\t_\t_\t_\t_\tO\t_\taa01a-017:229\n4\tsom\tsom\tHP\tHP\t-|-|-\t_\t_\t_\t_\tO\t_\taa01a-017:230\n5\tFN-medlaren\tFN-medlare\tNN\tNN\tUTR|SIN|DEF|NOM\t_\t_\t_\t_\tO\t_\taa01a-017:231\n6\tBrian\tBrian\tPM\tPM\tNOM\t_\t_\t_\t_\tB\tperson\taa01a-017:232\n7\tUrquhart\tUrquhart\tPM\tPM\tNOM\t_\t_\t_\t_\tI\tperson\taa01a-017:233\n8\tmyntat\tmynta\tVB\tVB\tSUP|AKT\t_\t_\t_\t_\tO\t_\taa01a-017:234\n9\t:\t:\tMAD\tMAD\t\t_\t_\t_\t_\tO\t_\taa01a-017:235\n\"\"\"\n\nEXPECTED_IOB = \"\"\"\nDen\tO\nGud\tB-myth\ngiver\tO\nämbetet\tO\nfår\tO\nockså\tO\nförståndet\tO\n.\tO\n\nHan\tO\nberättar\tO\nanekdoten\tO\nsom\tO\nFN-medlaren\tO\nBrian\tB-person\nUrquhart\tI-person\nmyntat\tO\n:\tO\n\"\"\"\n\ndef test_read_zip():\n    \"\"\"\n    Test creating a fake zip file, then converting it to an .iob file\n    \"\"\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        zip_name = os.path.join(tempdir, \"test.zip\")\n        in_filename = \"conll\"\n    
    with ZipFile(zip_name, \"w\") as zout:\n            with zout.open(in_filename, \"w\") as fout:\n                fout.write(TEST_CONLL.encode())\n\n        out_filename = os.path.join(tempdir, \"iob\")\n        num = suc_conll_to_iob.extract_from_zip(zip_name, in_filename, out_filename)\n        assert num == 2\n\n        with open(out_filename) as fin:\n            result = fin.read()\n        assert EXPECTED_IOB.strip() == result.strip()\n\ndef test_read_raw():\n    \"\"\"\n    Test a direct text file conversion w/o the zip file\n    \"\"\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        in_filename = os.path.join(tempdir, \"test.txt\")\n        with open(in_filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(TEST_CONLL)\n\n        out_filename = os.path.join(tempdir, \"iob\")\n        with open(in_filename, encoding=\"utf-8\") as fin, open(out_filename, \"w\", encoding=\"utf-8\") as fout:\n            num = suc_conll_to_iob.extract(fin, fout)\n        assert num == 2\n\n        with open(out_filename, encoding=\"utf-8\") as fin:\n            result = fin.read()\n        assert EXPECTED_IOB.strip() == result.strip()\n"
  },
  {
    "path": "stanza/tests/pipeline/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/pipeline/pipeline_device_tests.py",
    "content": "\"\"\"\nUtility methods to check that all processors are on the expected device\n\nRefactored since it can be used for multiple pipelines\n\"\"\"\n\nimport warnings\n\nimport torch\n\ndef check_on_gpu(pipeline):\n    \"\"\"\n    Check that the processors are all on the GPU and that basic execution works\n    \"\"\"\n    if not torch.cuda.is_available():\n        warnings.warn(\"Unable to run the test that checks the pipeline is on the GPU, as there is no GPU available!\")\n        return\n\n    for name, proc in pipeline.processors.items():\n        if proc.trainer is not None:\n            device = next(proc.trainer.model.parameters()).device\n        else:\n            device = next(proc._model.parameters()).device\n\n        assert str(device).startswith(\"cuda\"), \"Processor %s was not on the GPU\" % name\n\n    # just check that there are no cpu/cuda tensor conflicts\n    # when running on the GPU\n    pipeline(\"This is a small test\")\n\ndef check_on_cpu(pipeline):\n    \"\"\"\n    Check that the processors are all on the CPU and that basic execution works\n    \"\"\"\n    for name, proc in pipeline.processors.items():\n        if proc.trainer is not None:\n            device = next(proc.trainer.model.parameters()).device\n        else:\n            device = next(proc._model.parameters()).device\n\n        assert str(device).startswith(\"cpu\"), \"Processor %s was not on the CPU\" % name\n\n    # just check that there are no cpu/cuda tensor conflicts\n    # when running on the CPU\n    pipeline(\"This is a small test\")\n"
  },
  {
    "path": "stanza/tests/pipeline/test_arabic_pipeline.py",
    "content": "\"\"\"\nSmall test of loading the Arabic pipeline\n\nThe main goal is to check that nothing goes wrong with RtL languages,\nbut incidentally this would have caught a bug where the xpos tags\nwere split into individual pieces instead of reassembled as expected\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = pytest.mark.pipeline\n\ndef test_arabic_pos_pipeline():\n    pipe = stanza.Pipeline(**{'processors': 'tokenize,pos', 'dir': TEST_MODELS_DIR, 'download_method': None, 'lang': 'ar'})\n    text = \"ولم يتم اعتقال احد بحسب المتحدث باسم الشرطة.\"\n\n    doc = pipe(text)\n    # the first token translates to \"and not\", seems common enough\n    # that we should be able to rely on it having a stable MWT and tag\n\n    assert len(doc.sentences) == 1\n    assert doc.sentences[0].tokens[0].text == \"ولم\"\n    assert doc.sentences[0].words[0].xpos == \"C---------\"\n    assert doc.sentences[0].words[1].xpos == \"F---------\"\n"
  },
  {
    "path": "stanza/tests/pipeline/test_core.py",
    "content": "import pytest\nimport shutil\nimport tempfile\n\nimport stanza\n\nfrom stanza.tests import *\n\nfrom stanza.pipeline import core\nfrom stanza.resources.common import get_md5, load_resources_json\n\npytestmark = pytest.mark.pipeline\n\ndef test_pretagged():\n    \"\"\"\n    Test that the pipeline does or doesn't build if pos is left out and pretagged is specified\n    \"\"\"\n    nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=\"tokenize,pos,lemma,depparse\")\n    with pytest.raises(core.PipelineRequirementsException):\n        nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=\"tokenize,lemma,depparse\")\n    nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=\"tokenize,lemma,depparse\", depparse_pretagged=True)\n    nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=\"tokenize,lemma,depparse\", pretagged=True)\n    # test that the module specific flag overrides the general flag\n    nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=\"tokenize,lemma,depparse\", depparse_pretagged=True, pretagged=False)\n\ndef test_download_missing_ner_model():\n    \"\"\"\n    Test that the pipeline will automatically download missing models\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"combined\", verbose=False)\n        pipe = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize,ner\", package={\"ner\": (\"ontonotes_charlm\")})\n\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        en_dir = os.path.join(test_dir, 'en')\n        en_dir_listing = sorted(os.listdir(en_dir))\n        assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']\n        assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt']\n\n\ndef test_download_missing_resources():\n    \"\"\"\n    Test 
that the pipeline will automatically download missing models\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        pipe = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize,ner\", package={\"tokenize\": \"combined\", \"ner\": \"ontonotes_charlm\"})\n\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        en_dir = os.path.join(test_dir, 'en')\n        en_dir_listing = sorted(os.listdir(en_dir))\n        assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']\n        assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt']\n\n\ndef test_download_resources_overwrites():\n    \"\"\"\n    Test that the DOWNLOAD_RESOURCES method overwrites an existing resources.json\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        pipe = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize\", package={\"tokenize\": \"combined\"})\n\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        resources_path = os.path.join(test_dir, 'resources.json')\n        mod_time = os.path.getmtime(resources_path)\n\n        pipe = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize\", package={\"tokenize\": \"combined\"})\n        new_mod_time = os.path.getmtime(resources_path)\n        assert mod_time != new_mod_time\n\ndef test_reuse_resources_overwrites():\n    \"\"\"\n    Test that the REUSE_RESOURCES method does *not* overwrite an existing resources.json\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        pipe = stanza.Pipeline(\"en\",\n                               download_method=core.DownloadMethod.REUSE_RESOURCES,\n                               model_dir=test_dir,\n                               processors=\"tokenize\",\n                               package={\"tokenize\": \"combined\"})\n\n        assert 
sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        resources_path = os.path.join(test_dir, 'resources.json')\n        mod_time = os.path.getmtime(resources_path)\n\n        pipe = stanza.Pipeline(\"en\",\n                               download_method=core.DownloadMethod.REUSE_RESOURCES,\n                               model_dir=test_dir,\n                               processors=\"tokenize\",\n                               package={\"tokenize\": \"combined\"})\n        new_mod_time = os.path.getmtime(resources_path)\n        assert mod_time == new_mod_time\n\n\ndef test_download_not_repeated():\n    \"\"\"\n    Test that a model is only downloaded once if it already matches the expected model from the resources file\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"combined\")\n\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        en_dir = os.path.join(test_dir, 'en')\n        en_dir_listing = sorted(os.listdir(en_dir))\n        assert en_dir_listing == ['mwt', 'tokenize']\n        tokenize_path = os.path.join(en_dir, \"tokenize\", \"combined.pt\")\n        mod_time = os.path.getmtime(tokenize_path)\n\n        pipe = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize\", package={\"tokenize\": \"combined\"})\n        assert os.path.getmtime(tokenize_path) == mod_time\n\ndef test_download_none():\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"it\", model_dir=test_dir, processors=\"tokenize\", package=\"combined\")\n        stanza.download(\"it\", model_dir=test_dir, processors=\"tokenize\", package=\"vit\")\n\n        it_dir = os.path.join(test_dir, 'it')\n        it_dir_listing = sorted(os.listdir(it_dir))\n        assert sorted(it_dir_listing) == ['mwt', 'tokenize']\n        combined_path = os.path.join(it_dir, \"tokenize\", 
\"combined.pt\")\n        vit_path = os.path.join(it_dir, \"tokenize\", \"vit.pt\")\n\n        assert os.path.exists(combined_path)\n        assert os.path.exists(vit_path)\n\n        combined_md5 = get_md5(combined_path)\n        vit_md5 = get_md5(vit_path)\n        # check that the models are different\n        # otherwise the test is not testing anything\n        assert combined_md5 != vit_md5\n\n        shutil.copyfile(vit_path, combined_path)\n        assert get_md5(combined_path) == vit_md5\n\n        pipe = stanza.Pipeline(\"it\", model_dir=test_dir, processors=\"tokenize\", package={\"tokenize\": \"combined\"}, download_method=None)\n        assert get_md5(combined_path) == vit_md5\n\n        pipe = stanza.Pipeline(\"it\", model_dir=test_dir, processors=\"tokenize\", package={\"tokenize\": \"combined\"})\n        assert get_md5(combined_path) != vit_md5\n\n\ndef check_download_method_updates(download_method):\n    \"\"\"\n    Run a single test of creating a pipeline with a given download_method, checking that the model is updated\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"combined\")\n\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        en_dir = os.path.join(test_dir, 'en')\n        en_dir_listing = sorted(os.listdir(en_dir))\n        assert en_dir_listing == ['mwt', 'tokenize']\n        tokenize_path = os.path.join(en_dir, \"tokenize\", \"combined.pt\")\n\n        with open(tokenize_path, \"w\") as fout:\n            fout.write(\"Unban mox opal!\")\n        mod_time = os.path.getmtime(tokenize_path)\n\n        pipe = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize\", package={\"tokenize\": \"combined\"}, download_method=download_method)\n        assert os.path.getmtime(tokenize_path) != mod_time\n\ndef test_download_fixed():\n    \"\"\"\n    Test that a model is fixed if the 
existing model doesn't match the md5sum\n    \"\"\"\n    for download_method in (core.DownloadMethod.REUSE_RESOURCES, core.DownloadMethod.DOWNLOAD_RESOURCES):\n        check_download_method_updates(download_method)\n\ndef test_download_strings():\n    \"\"\"\n    Same as the test of the download_method, but tests that the pipeline works for string download_method\n    \"\"\"\n    for download_method in (\"reuse_resources\", \"download_resources\"):\n        check_download_method_updates(download_method)\n\ndef test_limited_pipeline():\n    \"\"\"\n    Test loading a pipeline, but then only using a couple processors\n    \"\"\"\n    pipe = stanza.Pipeline(processors=\"tokenize,pos,lemma,depparse,ner\", dir=TEST_MODELS_DIR)\n    doc = pipe(\"John Bauer works at Stanford\")\n    assert all(word.upos is not None for sentence in doc.sentences for word in sentence.words)\n    assert all(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)\n\n    doc = pipe(\"John Bauer works at Stanford\", processors=[\"tokenize\",\"pos\"])\n    assert all(word.upos is not None for sentence in doc.sentences for word in sentence.words)\n    assert not any(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)\n\n    doc = pipe(\"John Bauer works at Stanford\", processors=\"tokenize\")\n    assert not any(word.upos is not None for sentence in doc.sentences for word in sentence.words)\n    assert not any(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)\n\n    doc = pipe(\"John Bauer works at Stanford\", processors=\"tokenize,ner\")\n    assert not any(word.upos is not None for sentence in doc.sentences for word in sentence.words)\n    assert all(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)\n\n    with pytest.raises(ValueError):\n        # this should fail\n        doc = pipe(\"John Bauer works at Stanford\", 
processors=\"tokenize,depparse\")\n\n@pytest.fixture(scope=\"module\")\ndef unknown_language_name():\n    resources = load_resources_json(model_dir=TEST_MODELS_DIR)\n    name = \"en\"\n    while name in resources:\n        name = name + \"z\"\n    assert name != \"en\"\n    return name\n\ndef test_empty_unknown_language(unknown_language_name):\n    \"\"\"\n    Check that there is an error for trying to load an unknown language\n    \"\"\"\n    with pytest.raises(ValueError):\n        pipe = stanza.Pipeline(unknown_language_name, model_dir=TEST_MODELS_DIR, download_method=None)\n\ndef test_unknown_language_tokenizer(unknown_language_name):\n    \"\"\"\n    Test that loading tokenize works for an unknown language\n    \"\"\"\n    base_pipe = stanza.Pipeline(\"en\", dir=TEST_MODELS_DIR, processors=\"tokenize\", download_method=None)\n    # even if we one day add MWT to English, the tokenizer by itself should still work\n    tokenize_processor = base_pipe.processors[\"tokenize\"]\n\n    pipe=stanza.Pipeline(unknown_language_name,\n                         processors=\"tokenize\",\n                         allow_unknown_language=True,\n                         tokenize_model_path=tokenize_processor.config['model_path'],\n                         download_method=None)\n    doc = pipe(\"This is a test\")\n    words = [x.text for x in doc.sentences[0].words]\n    assert words == ['This', 'is', 'a', 'test']\n\n\ndef test_unknown_language_mwt(unknown_language_name):\n    \"\"\"\n    Test that loading tokenize & mwt works for an unknown language\n    \"\"\"\n    base_pipe = stanza.Pipeline(\"fr\", dir=TEST_MODELS_DIR, processors=\"tokenize,mwt\", download_method=None)\n    assert len(base_pipe.processors) == 2\n    tokenize_processor = base_pipe.processors[\"tokenize\"]\n    mwt_processor = base_pipe.processors[\"mwt\"]\n\n    pipe=stanza.Pipeline(unknown_language_name,\n                         model_dir=TEST_MODELS_DIR,\n                         
processors=\"tokenize,mwt\",\n                         allow_unknown_language=True,\n                         tokenize_model_path=tokenize_processor.config['model_path'],\n                         mwt_model_path=mwt_processor.config['model_path'],\n                         download_method=None)\n"
  },
  {
    "path": "stanza/tests/pipeline/test_decorators.py",
    "content": "\"\"\"\nBasic tests of the depparse processor boolean flags\n\"\"\"\nimport pytest\n\nimport stanza\nfrom stanza.models.common.doc import Document\nfrom stanza.pipeline.core import PipelineRequirementsException\nfrom stanza.pipeline.processor import Processor, ProcessorVariant, register_processor, register_processor_variant, ProcessorRegisterException\nfrom stanza.utils.conll import CoNLL\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\n# data for testing\nEN_DOC = \"This is a test sentence. This is another!\"\n\nEN_DOC_LOWERCASE_TOKENS = '''<Token id=1;words=[<Word id=1;text=this>]>\n<Token id=2;words=[<Word id=2;text=is>]>\n<Token id=3;words=[<Word id=3;text=a>]>\n<Token id=4;words=[<Word id=4;text=test>]>\n<Token id=5;words=[<Word id=5;text=sentence>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=this>]>\n<Token id=2;words=[<Word id=2;text=is>]>\n<Token id=3;words=[<Word id=3;text=another>]>\n<Token id=4;words=[<Word id=4;text=!>]>'''\n\nEN_DOC_LOL_TOKENS = '''<Token id=1;words=[<Word id=1;text=LOL>]>\n<Token id=2;words=[<Word id=2;text=LOL>]>\n<Token id=3;words=[<Word id=3;text=LOL>]>\n<Token id=4;words=[<Word id=4;text=LOL>]>\n<Token id=5;words=[<Word id=5;text=LOL>]>\n<Token id=6;words=[<Word id=6;text=LOL>]>\n<Token id=7;words=[<Word id=7;text=LOL>]>\n<Token id=8;words=[<Word id=8;text=LOL>]>'''\n\nEN_DOC_COOL_LEMMAS = '''<Token id=1;words=[<Word id=1;text=This;lemma=cool;upos=PRON;xpos=DT;feats=Number=Sing|PronType=Dem>]>\n<Token id=2;words=[<Word id=2;text=is;lemma=cool;upos=AUX;xpos=VBZ;feats=Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin>]>\n<Token id=3;words=[<Word id=3;text=a;lemma=cool;upos=DET;xpos=DT;feats=Definite=Ind|PronType=Art>]>\n<Token id=4;words=[<Word id=4;text=test;lemma=cool;upos=NOUN;xpos=NN;feats=Number=Sing>]>\n<Token id=5;words=[<Word id=5;text=sentence;lemma=cool;upos=NOUN;xpos=NN;feats=Number=Sing>]>\n<Token id=6;words=[<Word 
id=6;text=.;lemma=cool;upos=PUNCT;xpos=.>]>\n\n<Token id=1;words=[<Word id=1;text=This;lemma=cool;upos=PRON;xpos=DT;feats=Number=Sing|PronType=Dem>]>\n<Token id=2;words=[<Word id=2;text=is;lemma=cool;upos=AUX;xpos=VBZ;feats=Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin>]>\n<Token id=3;words=[<Word id=3;text=another;lemma=cool;upos=DET;xpos=DT;feats=PronType=Ind>]>\n<Token id=4;words=[<Word id=4;text=!;lemma=cool;upos=PUNCT;xpos=.>]>'''\n\n@register_processor(\"lowercase\")\nclass LowercaseProcessor(Processor):\n    ''' Processor that lowercases all text '''\n    _requires = set(['tokenize'])\n    _provides = set(['lowercase'])\n\n    def __init__(self, config, pipeline, device):\n        pass\n\n    def _set_up_model(self, *args):\n        pass\n\n    def process(self, doc):\n        doc.text = doc.text.lower()\n        for sent in doc.sentences:\n            for tok in sent.tokens:\n                tok.text = tok.text.lower()\n\n            for word in sent.words:\n                word.text = word.text.lower()\n\n        return doc\n\ndef test_register_processor():\n    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,lowercase', download_method=None)\n    doc = nlp(EN_DOC)\n    assert EN_DOC_LOWERCASE_TOKENS == '\\n\\n'.join(sent.tokens_string() for sent in doc.sentences)\n\ndef test_register_nonprocessor():\n    with pytest.raises(ProcessorRegisterException):\n        @register_processor(\"nonprocessor\")\n        class NonProcessor:\n            pass\n\n@register_processor_variant(\"tokenize\", \"lol\")\nclass LOLTokenizer(ProcessorVariant):\n    ''' An alternative tokenizer that splits text by space and replaces all tokens with LOL '''\n\n    def __init__(self, lang):\n        pass\n\n    def process(self, text):\n        sentence = [{'id': (i+1, ), 'text': 'LOL'} for i, tok in enumerate(text.split())]\n        return Document([sentence], text)\n\ndef test_register_processor_variant():\n    nlp = 
stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={\"tokenize\": \"lol\"}, package=None, download_method=None)\n    doc = nlp(EN_DOC)\n    assert EN_DOC_LOL_TOKENS == '\\n\\n'.join(sent.tokens_string() for sent in doc.sentences)\n\n@register_processor_variant(\"lemma\", \"cool\")\nclass CoolLemmatizer(ProcessorVariant):\n    ''' An alternative lemmatizer that lemmatizes every word to \"cool\". '''\n\n    OVERRIDE = True\n\n    def __init__(self, lang):\n        pass\n\n    def process(self, document):\n        for sentence in document.sentences:\n            for word in sentence.words:\n                word.lemma = \"cool\"\n\n        return document\n\ndef test_register_processor_variant_with_override():\n    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={\"tokenize\": \"combined\", \"pos\": \"combined\", \"lemma\": \"cool\"}, package=None, download_method=None)\n    doc = nlp(EN_DOC)\n    result = '\\n\\n'.join(sent.tokens_string() for sent in doc.sentences)\n    assert EN_DOC_COOL_LEMMAS == result\n\ndef test_register_nonprocessor_variant():\n    with pytest.raises(ProcessorRegisterException):\n        @register_processor_variant(\"tokenize\", \"nonvariant\")\n        class NonVariant:\n            pass\n"
  },
  {
    "path": "stanza/tests/pipeline/test_depparse.py",
    "content": "\"\"\"\nBasic tests of the depparse processor boolean flags\n\"\"\"\nimport gc\n\nimport pytest\n\nimport stanza\nfrom stanza.pipeline.core import PipelineRequirementsException\nfrom stanza.utils.conll import CoNLL\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\n# data for testing\nEN_DOC = \"Barack Obama was born in Hawaii.  He was elected president in 2008.  Obama attended Harvard.\"\n\nEN_DOC_CONLLU_PRETAGGED = \"\"\"\n1\tBarack\tBarack\tPROPN\tNNP\tNumber=Sing\t0\t_\t_\t_\n2\tObama\tObama\tPROPN\tNNP\tNumber=Sing\t1\t_\t_\t_\n3\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t2\t_\t_\t_\n4\tborn\tbear\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t3\t_\t_\t_\n5\tin\tin\tADP\tIN\t_\t4\t_\t_\t_\n6\tHawaii\tHawaii\tPROPN\tNNP\tNumber=Sing\t5\t_\t_\t_\n7\t.\t.\tPUNCT\t.\t_\t6\t_\t_\t_\n\n1\tHe\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t0\t_\t_\t_\n2\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t1\t_\t_\t_\n3\telected\telect\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t2\t_\t_\t_\n4\tpresident\tpresident\tPROPN\tNNP\tNumber=Sing\t3\t_\t_\t_\n5\tin\tin\tADP\tIN\t_\t4\t_\t_\t_\n6\t2008\t2008\tNUM\tCD\tNumType=Card\t5\t_\t_\t_\n7\t.\t.\tPUNCT\t.\t_\t6\t_\t_\t_\n\n1\tObama\tObama\tPROPN\tNNP\tNumber=Sing\t0\t_\t_\t_\n2\tattended\tattend\tVERB\tVBD\tMood=Ind|Tense=Past|VerbForm=Fin\t1\t_\t_\t_\n3\tHarvard\tHarvard\tPROPN\tNNP\tNumber=Sing\t2\t_\t_\t_\n4\t.\t.\tPUNCT\t.\t_\t3\t_\t_\t_\n\n\n\"\"\".lstrip()\n\nEN_DOC_DEPENDENCY_PARSES_GOLD = \"\"\"\n('Barack', 4, 'nsubj:pass')\n('Obama', 1, 'flat')\n('was', 4, 'aux:pass')\n('born', 0, 'root')\n('in', 6, 'case')\n('Hawaii', 4, 'obl')\n('.', 4, 'punct')\n\n('He', 3, 'nsubj:pass')\n('was', 3, 'aux:pass')\n('elected', 0, 'root')\n('president', 3, 'xcomp')\n('in', 6, 'case')\n('2008', 3, 'obl')\n('.', 3, 'punct')\n\n('Obama', 2, 'nsubj')\n('attended', 0, 'root')\n('Harvard', 2, 'obj')\n('.', 2, 
'punct')\n\"\"\".strip()\n\n@pytest.fixture(scope=\"module\")\ndef en_depparse_pipeline():\n    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,pos,lemma,depparse')\n    gc.collect()\n    return nlp\n\ndef test_depparse(en_depparse_pipeline):\n    doc = en_depparse_pipeline(EN_DOC)\n    assert EN_DOC_DEPENDENCY_PARSES_GOLD == '\\n\\n'.join([sent.dependencies_string() for sent in doc.sentences])\n\n\ndef test_depparse_with_pretagged_doc():\n    nlp = stanza.Pipeline(**{'processors': 'depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en',\n                                  'depparse_pretagged': True})\n\n    doc = CoNLL.conll2doc(input_str=EN_DOC_CONLLU_PRETAGGED)\n    processed_doc = nlp(doc)\n\n    assert EN_DOC_DEPENDENCY_PARSES_GOLD == '\\n\\n'.join(\n        [sent.dependencies_string() for sent in processed_doc.sentences])\n\n\ndef test_raises_requirements_exception_if_pretagged_not_passed():\n    with pytest.raises(PipelineRequirementsException):\n        stanza.Pipeline(**{'processors': 'depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'})\n"
  },
  {
    "path": "stanza/tests/pipeline/test_english_pipeline.py",
    "content": "\"\"\"\nBasic testing of the English pipeline\n\"\"\"\n\nimport pytest\nimport stanza\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.doc import Document\n\nfrom stanza.tests import *\nfrom stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# data for testing\nEN_DOC = \"Barack Obama was born in Hawaii.  He was elected president in 2008.  Obama attended Harvard.\"\n\nEN_DOCS = [\"Barack Obama was born in Hawaii.\", \"He was elected president in 2008.\", \"Obama attended Harvard.\"]\n\nEN_DOC_TOKENS_GOLD = \"\"\"\n<Token id=1;words=[<Word id=1;text=Barack;lemma=Barack;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=nsubj:pass>]>\n<Token id=2;words=[<Word id=2;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=1;deprel=flat>]>\n<Token id=3;words=[<Word id=3;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=4;deprel=aux:pass>]>\n<Token id=4;words=[<Word id=4;text=born;lemma=bear;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>]>\n<Token id=5;words=[<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>]>\n<Token id=6;words=[<Word id=6;text=Hawaii;lemma=Hawaii;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=obl>]>\n<Token id=7;words=[<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=4;deprel=punct>]>\n\n<Token id=1;words=[<Word id=1;text=He;lemma=he;upos=PRON;xpos=PRP;feats=Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs;head=3;deprel=nsubj:pass>]>\n<Token id=2;words=[<Word id=2;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=3;deprel=aux:pass>]>\n<Token id=3;words=[<Word id=3;text=elected;lemma=elect;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>]>\n<Token id=4;words=[<Word 
id=4;text=president;lemma=president;upos=NOUN;xpos=NN;feats=Number=Sing;head=3;deprel=xcomp>]>\n<Token id=5;words=[<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>]>\n<Token id=6;words=[<Word id=6;text=2008;lemma=2008;upos=NUM;xpos=CD;feats=NumForm=Digit|NumType=Card;head=3;deprel=obl>]>\n<Token id=7;words=[<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=3;deprel=punct>]>\n\n<Token id=1;words=[<Word id=1;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=nsubj>]>\n<Token id=2;words=[<Word id=2;text=attended;lemma=attend;upos=VERB;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=0;deprel=root>]>\n<Token id=3;words=[<Word id=3;text=Harvard;lemma=Harvard;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=obj>]>\n<Token id=4;words=[<Word id=4;text=.;lemma=.;upos=PUNCT;xpos=.;head=2;deprel=punct>]>\n\"\"\".strip()\n\nEN_DOC_WORDS_GOLD = \"\"\"\n<Word id=1;text=Barack;lemma=Barack;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=nsubj:pass>\n<Word id=2;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=1;deprel=flat>\n<Word id=3;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=4;deprel=aux:pass>\n<Word id=4;text=born;lemma=bear;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>\n<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>\n<Word id=6;text=Hawaii;lemma=Hawaii;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=obl>\n<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=4;deprel=punct>\n\n<Word id=1;text=He;lemma=he;upos=PRON;xpos=PRP;feats=Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs;head=3;deprel=nsubj:pass>\n<Word id=2;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=3;deprel=aux:pass>\n<Word id=3;text=elected;lemma=elect;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>\n<Word 
id=4;text=president;lemma=president;upos=NOUN;xpos=NN;feats=Number=Sing;head=3;deprel=xcomp>\n<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>\n<Word id=6;text=2008;lemma=2008;upos=NUM;xpos=CD;feats=NumForm=Digit|NumType=Card;head=3;deprel=obl>\n<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=3;deprel=punct>\n\n<Word id=1;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=nsubj>\n<Word id=2;text=attended;lemma=attend;upos=VERB;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=0;deprel=root>\n<Word id=3;text=Harvard;lemma=Harvard;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=obj>\n<Word id=4;text=.;lemma=.;upos=PUNCT;xpos=.;head=2;deprel=punct>\n\"\"\".strip()\n\nEN_DOC_DEPENDENCY_PARSES_GOLD = \"\"\"\n('Barack', 4, 'nsubj:pass')\n('Obama', 1, 'flat')\n('was', 4, 'aux:pass')\n('born', 0, 'root')\n('in', 6, 'case')\n('Hawaii', 4, 'obl')\n('.', 4, 'punct')\n\n('He', 3, 'nsubj:pass')\n('was', 3, 'aux:pass')\n('elected', 0, 'root')\n('president', 3, 'xcomp')\n('in', 6, 'case')\n('2008', 3, 'obl')\n('.', 3, 'punct')\n\n('Obama', 2, 'nsubj')\n('attended', 0, 'root')\n('Harvard', 2, 'obj')\n('.', 2, 'punct')\n\"\"\".strip()\n\nEN_DOC_CONLLU_GOLD = \"\"\"\n# text = Barack Obama was born in Hawaii.\n# sent_id = 0\n# constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. 
.)))\n# sentiment = 1\n1\tBarack\tBarack\tPROPN\tNNP\tNumber=Sing\t4\tnsubj:pass\t_\tstart_char=0|end_char=6|ner=B-PERSON\n2\tObama\tObama\tPROPN\tNNP\tNumber=Sing\t1\tflat\t_\tstart_char=7|end_char=12|ner=E-PERSON\n3\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t4\taux:pass\t_\tstart_char=13|end_char=16|ner=O\n4\tborn\tbear\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t_\tstart_char=17|end_char=21|ner=O\n5\tin\tin\tADP\tIN\t_\t6\tcase\t_\tstart_char=22|end_char=24|ner=O\n6\tHawaii\tHawaii\tPROPN\tNNP\tNumber=Sing\t4\tobl\t_\tSpaceAfter=No|start_char=25|end_char=31|ner=S-GPE\n7\t.\t.\tPUNCT\t.\t_\t4\tpunct\t_\tSpacesAfter=\\\\s\\\\s|start_char=31|end_char=32|ner=O\n\n# text = He was elected president in 2008.\n# sent_id = 1\n# constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))\n# sentiment = 1\n1\tHe\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t3\tnsubj:pass\t_\tstart_char=34|end_char=36|ner=O\n2\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t3\taux:pass\t_\tstart_char=37|end_char=40|ner=O\n3\telected\telect\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t_\tstart_char=41|end_char=48|ner=O\n4\tpresident\tpresident\tNOUN\tNN\tNumber=Sing\t3\txcomp\t_\tstart_char=49|end_char=58|ner=O\n5\tin\tin\tADP\tIN\t_\t6\tcase\t_\tstart_char=59|end_char=61|ner=O\n6\t2008\t2008\tNUM\tCD\tNumForm=Digit|NumType=Card\t3\tobl\t_\tSpaceAfter=No|start_char=62|end_char=66|ner=S-DATE\n7\t.\t.\tPUNCT\t.\t_\t3\tpunct\t_\tSpacesAfter=\\\\s\\\\s|start_char=66|end_char=67|ner=O\n\n# text = Obama attended Harvard.\n# sent_id = 2\n# constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. 
.)))\n# sentiment = 1\n1\tObama\tObama\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t_\tstart_char=69|end_char=74|ner=S-PERSON\n2\tattended\tattend\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t_\tstart_char=75|end_char=83|ner=O\n3\tHarvard\tHarvard\tPROPN\tNNP\tNumber=Sing\t2\tobj\t_\tSpaceAfter=No|start_char=84|end_char=91|ner=S-ORG\n4\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\tSpaceAfter=No|start_char=91|end_char=92|ner=O\n\"\"\".strip()\n\nEN_DOC_CONLLU_GOLD_MULTIDOC = \"\"\"\n# text = Barack Obama was born in Hawaii.\n# sent_id = 0\n# constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))\n# sentiment = 1\n1\tBarack\tBarack\tPROPN\tNNP\tNumber=Sing\t4\tnsubj:pass\t_\tstart_char=0|end_char=6|ner=B-PERSON\n2\tObama\tObama\tPROPN\tNNP\tNumber=Sing\t1\tflat\t_\tstart_char=7|end_char=12|ner=E-PERSON\n3\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t4\taux:pass\t_\tstart_char=13|end_char=16|ner=O\n4\tborn\tbear\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t_\tstart_char=17|end_char=21|ner=O\n5\tin\tin\tADP\tIN\t_\t6\tcase\t_\tstart_char=22|end_char=24|ner=O\n6\tHawaii\tHawaii\tPROPN\tNNP\tNumber=Sing\t4\tobl\t_\tSpaceAfter=No|start_char=25|end_char=31|ner=S-GPE\n7\t.\t.\tPUNCT\t.\t_\t4\tpunct\t_\tSpaceAfter=No|start_char=31|end_char=32|ner=O\n\n# text = He was elected president in 2008.\n# sent_id = 1\n# constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. 
.)))\n# sentiment = 1\n1\tHe\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t3\tnsubj:pass\t_\tstart_char=0|end_char=2|ner=O\n2\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t3\taux:pass\t_\tstart_char=3|end_char=6|ner=O\n3\telected\telect\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t_\tstart_char=7|end_char=14|ner=O\n4\tpresident\tpresident\tNOUN\tNN\tNumber=Sing\t3\txcomp\t_\tstart_char=15|end_char=24|ner=O\n5\tin\tin\tADP\tIN\t_\t6\tcase\t_\tstart_char=25|end_char=27|ner=O\n6\t2008\t2008\tNUM\tCD\tNumForm=Digit|NumType=Card\t3\tobl\t_\tSpaceAfter=No|start_char=28|end_char=32|ner=S-DATE\n7\t.\t.\tPUNCT\t.\t_\t3\tpunct\t_\tSpaceAfter=No|start_char=32|end_char=33|ner=O\n\n# text = Obama attended Harvard.\n# sent_id = 2\n# constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))\n# sentiment = 1\n1\tObama\tObama\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t_\tstart_char=0|end_char=5|ner=S-PERSON\n2\tattended\tattend\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t_\tstart_char=6|end_char=14|ner=O\n3\tHarvard\tHarvard\tPROPN\tNNP\tNumber=Sing\t2\tobj\t_\tSpaceAfter=No|start_char=15|end_char=22|ner=S-ORG\n4\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\tSpaceAfter=No|start_char=22|end_char=23|ner=O\n\"\"\".strip()\n\nPRETOKENIZED_TEXT = \"Jennifer has lovely blue antennae .\"\n\nPRETOKENIZED_PIECES = [PRETOKENIZED_TEXT.split()]\n\nEXPECTED_TOKENIZED_ONLY_CONLLU = \"\"\"\n# text = Jennifer has lovely blue antennae .\n# sent_id = 0\n1\tJennifer\t_\t_\t_\t_\t0\t_\t_\tstart_char=0|end_char=8\n2\thas\t_\t_\t_\t_\t1\t_\t_\tstart_char=9|end_char=12\n3\tlovely\t_\t_\t_\t_\t2\t_\t_\tstart_char=13|end_char=19\n4\tblue\t_\t_\t_\t_\t3\t_\t_\tstart_char=20|end_char=24\n5\tantennae\t_\t_\t_\t_\t4\t_\t_\tstart_char=25|end_char=33\n6\t.\t_\t_\t_\t_\t5\t_\t_\tSpaceAfter=No|start_char=34|end_char=35\n\"\"\".strip()\n\nEXPECTED_PRETOKENIZED_CONLLU = \"\"\"\n# text = Jennifer has lovely 
blue antennae .\n# sent_id = 0\n# constituency = (ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ lovely) (JJ blue) (NNS antennae))) (. .)))\n# sentiment = 2\n1\tJennifer\tJennifer\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t_\tstart_char=0|end_char=8|ner=S-PERSON\n2\thas\thave\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t_\tstart_char=9|end_char=12|ner=O\n3\tlovely\tlovely\tADJ\tJJ\tDegree=Pos\t5\tamod\t_\tstart_char=13|end_char=19|ner=O\n4\tblue\tblue\tADJ\tJJ\tDegree=Pos\t5\tamod\t_\tstart_char=20|end_char=24|ner=O\n5\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t2\tobj\t_\tstart_char=25|end_char=33|ner=O\n6\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\tSpaceAfter=No|start_char=34|end_char=35|ner=O\n\"\"\".strip()\n\nclass TestEnglishPipeline:\n    @pytest.fixture(scope=\"class\")\n    def pipeline(self):\n        return stanza.Pipeline(dir=TEST_MODELS_DIR, download_method=None)\n\n    @pytest.fixture(scope=\"class\")\n    def pretokenized_pipeline(self):\n        return stanza.Pipeline(dir=TEST_MODELS_DIR, tokenize_pretokenized=True, download_method=None)\n\n    @pytest.fixture(scope=\"class\")\n    def tokenizer_pipeline(self):\n        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize\", download_method=None)\n\n    @pytest.fixture(scope=\"class\")\n    def processed_doc(self, pipeline):\n        \"\"\" Document created by running full English pipeline on a few sentences \"\"\"\n        return pipeline(EN_DOC)\n\n    def test_text(self, processed_doc):\n        assert processed_doc.text == EN_DOC\n\n\n    def test_conllu(self, processed_doc):\n        assert \"{:C}\".format(processed_doc) == EN_DOC_CONLLU_GOLD\n\n    def test_process_conllu(self, pipeline):\n        \"\"\"\n        Process a conllu text directly\n\n        This can use the pipeline which still uses tokenization, as\n        process_conllu skips the tokenize and mwt processors\n        \"\"\"\n        doc = pipeline.process_conllu(EN_DOC_CONLLU_GOLD)\n        result = 
\"{:C}\".format(doc)\n        assert result == EN_DOC_CONLLU_GOLD\n\n    def test_tokens(self, processed_doc):\n        assert \"\\n\\n\".join([sent.tokens_string() for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD\n\n\n    def test_words(self, processed_doc):\n        assert \"\\n\\n\".join([sent.words_string() for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD\n\n\n    def test_dependency_parse(self, processed_doc):\n        assert \"\\n\\n\".join([sent.dependencies_string() for sent in processed_doc.sentences]) == \\\n               EN_DOC_DEPENDENCY_PARSES_GOLD\n\n    def test_empty(self, pipeline):\n        # make sure that various models handle the degenerate empty case\n        pipeline(\"\")\n        pipeline(\"--\")\n\n    def test_bulk_process(self, pipeline):\n        \"\"\" Double check that the bulk_process method in Pipeline converts documents as expected \"\"\"\n        # it should process strings\n        processed = pipeline.bulk_process(EN_DOCS)\n        assert \"\\n\\n\".join([\"{:C}\".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC\n\n        # it should pass Documents through successfully\n        docs = [Document([], text=t) for t in EN_DOCS]\n        processed = pipeline.bulk_process(docs)\n        assert \"\\n\\n\".join([\"{:C}\".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC\n\n    def test_empty_bulk_process(self, pipeline):\n        \"\"\" Previously we had a bug where an empty document list would cause a crash \"\"\"\n        processed = pipeline.bulk_process([])\n        assert processed == []\n\n    def test_pretokenized(self, pretokenized_pipeline, tokenizer_pipeline):\n        doc = pretokenized_pipeline(PRETOKENIZED_PIECES)\n        conllu = \"{:C}\".format(doc).strip()\n        assert conllu == EXPECTED_PRETOKENIZED_CONLLU\n\n        doc = tokenizer_pipeline(PRETOKENIZED_TEXT)\n        conllu = \"{:C}\".format(doc).strip()\n        assert conllu == 
EXPECTED_TOKENIZED_ONLY_CONLLU\n\n        # putting a doc with tokens into the pipeline should also work\n        reparsed = pretokenized_pipeline(doc)\n        conllu = \"{:C}\".format(reparsed).strip()\n        assert conllu == EXPECTED_PRETOKENIZED_CONLLU\n\n    def test_bulk_pretokenized(self, pretokenized_pipeline, tokenizer_pipeline):\n        doc = tokenizer_pipeline(PRETOKENIZED_TEXT)\n        conllu = \"{:C}\".format(doc).strip()\n        assert conllu == EXPECTED_TOKENIZED_ONLY_CONLLU\n\n        docs = pretokenized_pipeline([doc, doc])\n        assert len(docs) == 2\n        for doc in docs:\n            conllu = \"{:C}\".format(doc).strip()\n            assert conllu == EXPECTED_PRETOKENIZED_CONLLU\n\n    def test_conll2doc_pretokenized(self, pretokenized_pipeline):\n        doc = CoNLL.conll2doc(input_str=EXPECTED_TOKENIZED_ONLY_CONLLU)\n        # this was a bug in version 1.10.1, reported to us by a user\n        # the pretokenized tokenize_processor would try to whitespace tokenize a document\n        # even if the document already had sentences & words & stuff\n        # not only would that be wrong if the text wouldn't whitespace tokenize into the words\n        # (such as with punctuation and SpaceAfter=No),\n        # it wouldn't even work in the case of conll2doc, since the document.text wasn't set\n        docs = pretokenized_pipeline([doc, doc])\n        assert len(docs) == 2\n        for doc in docs:\n            conllu = \"{:C}\".format(doc).strip()\n            assert conllu == EXPECTED_PRETOKENIZED_CONLLU\n\n    def test_stream(self, pipeline):\n        \"\"\" Test the streaming interface to the Pipeline \"\"\"\n        # Test all of the documents in one batch\n        # (the default batch size is significantly more than |EN_DOCS|)\n        processed = [doc for doc in pipeline.stream(EN_DOCS)]\n        assert \"\\n\\n\".join([\"{:C}\".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC\n\n        # It should also work on an 
iterator rather than an iterable\n        processed = [doc for doc in pipeline.stream(iter(EN_DOCS))]\n        assert \"\\n\\n\".join([\"{:C}\".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC\n\n        # Stream one at a time\n        processed = [doc for doc in pipeline.stream(EN_DOCS, batch_size=1)]\n        processed = [\"{:C}\".format(doc) for doc in processed]\n        assert \"\\n\\n\".join(processed) == EN_DOC_CONLLU_GOLD_MULTIDOC\n\n    @pytest.fixture(scope=\"class\")\n    def processed_multidoc(self, pipeline):\n        \"\"\" Document created by running full English pipeline on a few sentences \"\"\"\n        docs = [Document([], text=t) for t in EN_DOCS]\n        return pipeline(docs)\n\n    def test_conllu_multidoc(self, processed_multidoc):\n        assert \"\\n\\n\".join([\"{:C}\".format(doc) for doc in processed_multidoc]) == EN_DOC_CONLLU_GOLD_MULTIDOC\n\n    def test_tokens_multidoc(self, processed_multidoc):\n        assert \"\\n\\n\".join([sent.tokens_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD\n\n\n    def test_words_multidoc(self, processed_multidoc):\n        assert \"\\n\\n\".join([sent.words_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD\n\n    def test_sentence_indices_multidoc(self, processed_multidoc):\n        sentences = [sent for doc in processed_multidoc for sent in doc.sentences]\n        for sent_idx, sentence in enumerate(sentences):\n            assert sent_idx == sentence.index\n\n    def test_dependency_parse_multidoc(self, processed_multidoc):\n        assert \"\\n\\n\".join([sent.dependencies_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == \\\n               EN_DOC_DEPENDENCY_PARSES_GOLD\n\n\n    @pytest.fixture(scope=\"class\")\n    def processed_multidoc_variant(self):\n        \"\"\" Document created by running full English pipeline on a 
few sentences \"\"\"\n        docs = [Document([], text=t) for t in EN_DOCS]\n        nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors={'tokenize': 'spacy'})\n        return nlp(docs)\n\n    def test_dependency_parse_multidoc_variant(self, processed_multidoc_variant):\n        assert \"\\n\\n\".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \\\n               EN_DOC_DEPENDENCY_PARSES_GOLD\n\n    def test_constituency_parser(self):\n        nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\")\n        doc = nlp(\"This is a test\")\n        assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))'\n\n    def test_on_gpu(self, pipeline):\n        \"\"\"\n        The default pipeline should have all the models on the GPU\n        \"\"\"\n        check_on_gpu(pipeline)\n\n    def test_on_cpu(self):\n        \"\"\"\n        Create a pipeline on the CPU, check that all the models are on the CPU\n        \"\"\"\n        pipeline = stanza.Pipeline(\"en\", dir=TEST_MODELS_DIR, use_gpu=False)\n        check_on_cpu(pipeline)\n"
  },
  {
    "path": "stanza/tests/pipeline/test_french_pipeline.py",
    "content": "\"\"\"\nBasic testing of French pipeline\n\nThe benefit of this test is to verify that the bulk processing works\nfor languages with MWT in them\n\"\"\"\n\nimport pytest\nimport stanza\nfrom stanza.models.common.doc import Document\n\nfrom stanza.tests import *\nfrom stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu\n\npytestmark = pytest.mark.pipeline\n\n\nFR_MWT_SENTENCE = \"Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de \" \\\n                  \"l'Industrie et du Numérique.\"\n\nEXPECTED_RESULT = \"\"\"\n[\n  [\n    {\n      \"id\": 1,\n      \"text\": \"Alors\",\n      \"lemma\": \"alors\",\n      \"upos\": \"ADV\",\n      \"head\": 3,\n      \"deprel\": \"advmod\",\n      \"start_char\": 0,\n      \"end_char\": 5\n    },\n    {\n      \"id\": 2,\n      \"text\": \"encore\",\n      \"lemma\": \"encore\",\n      \"upos\": \"ADV\",\n      \"head\": 3,\n      \"deprel\": \"advmod\",\n      \"start_char\": 6,\n      \"end_char\": 12\n    },\n    {\n      \"id\": 3,\n      \"text\": \"inconnu\",\n      \"lemma\": \"inconnu\",\n      \"upos\": \"ADJ\",\n      \"feats\": \"Gender=Masc|Number=Sing\",\n      \"head\": 11,\n      \"deprel\": \"advcl\",\n      \"start_char\": 13,\n      \"end_char\": 20\n    },\n    {\n      \"id\": [\n        4,\n        5\n      ],\n      \"text\": \"du\",\n      \"start_char\": 21,\n      \"end_char\": 23\n    },\n    {\n      \"id\": 4,\n      \"text\": \"de\",\n      \"lemma\": \"de\",\n      \"upos\": \"ADP\",\n      \"head\": 7,\n      \"deprel\": \"case\"\n    },\n    {\n      \"id\": 5,\n      \"text\": \"le\",\n      \"lemma\": \"le\",\n      \"upos\": \"DET\",\n      \"feats\": \"Definite=Def|Gender=Masc|Number=Sing|PronType=Art\",\n      \"head\": 7,\n      \"deprel\": \"det\"\n    },\n    {\n      \"id\": 6,\n      \"text\": \"grand\",\n      \"lemma\": \"grand\",\n      \"upos\": \"ADJ\",\n      \"feats\": 
\"Gender=Masc|Number=Sing\",\n      \"head\": 7,\n      \"deprel\": \"amod\",\n      \"start_char\": 24,\n      \"end_char\": 29\n    },\n    {\n      \"id\": 7,\n      \"text\": \"public\",\n      \"lemma\": \"public\",\n      \"upos\": \"NOUN\",\n      \"feats\": \"Gender=Masc|Number=Sing\",\n      \"head\": 3,\n      \"deprel\": \"obl:arg\",\n      \"start_char\": 30,\n      \"end_char\": 36,\n      \"misc\": \"SpaceAfter=No\"\n    },\n    {\n      \"id\": 8,\n      \"text\": \",\",\n      \"lemma\": \",\",\n      \"upos\": \"PUNCT\",\n      \"head\": 3,\n      \"deprel\": \"punct\",\n      \"start_char\": 36,\n      \"end_char\": 37\n    },\n    {\n      \"id\": 9,\n      \"text\": \"Emmanuel\",\n      \"lemma\": \"Emmanuel\",\n      \"upos\": \"PROPN\",\n      \"head\": 11,\n      \"deprel\": \"nsubj\",\n      \"start_char\": 38,\n      \"end_char\": 46\n    },\n    {\n      \"id\": 10,\n      \"text\": \"Macron\",\n      \"lemma\": \"Macron\",\n      \"upos\": \"PROPN\",\n      \"head\": 9,\n      \"deprel\": \"flat:name\",\n      \"start_char\": 47,\n      \"end_char\": 53\n    },\n    {\n      \"id\": 11,\n      \"text\": \"devient\",\n      \"lemma\": \"devenir\",\n      \"upos\": \"VERB\",\n      \"feats\": \"Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\",\n      \"head\": 0,\n      \"deprel\": \"root\",\n      \"start_char\": 54,\n      \"end_char\": 61\n    },\n    {\n      \"id\": 12,\n      \"text\": \"en\",\n      \"lemma\": \"en\",\n      \"upos\": \"ADP\",\n      \"head\": 13,\n      \"deprel\": \"case\",\n      \"start_char\": 62,\n      \"end_char\": 64\n    },\n    {\n      \"id\": 13,\n      \"text\": \"2014\",\n      \"lemma\": \"2014\",\n      \"upos\": \"NUM\",\n      \"feats\": \"Number=Plur\",\n      \"head\": 11,\n      \"deprel\": \"obl:mod\",\n      \"start_char\": 65,\n      \"end_char\": 69\n    },\n    {\n      \"id\": 14,\n      \"text\": \"ministre\",\n      \"lemma\": \"ministre\",\n      \"upos\": \"NOUN\",\n      
\"feats\": \"Gender=Masc|Number=Sing\",\n      \"head\": 11,\n      \"deprel\": \"xcomp\",\n      \"start_char\": 70,\n      \"end_char\": 78\n    },\n    {\n      \"id\": 15,\n      \"text\": \"de\",\n      \"lemma\": \"de\",\n      \"upos\": \"ADP\",\n      \"head\": 17,\n      \"deprel\": \"case\",\n      \"start_char\": 79,\n      \"end_char\": 81\n    },\n    {\n      \"id\": 16,\n      \"text\": \"l'\",\n      \"lemma\": \"le\",\n      \"upos\": \"DET\",\n      \"feats\": \"Definite=Def|Number=Sing|PronType=Art\",\n      \"head\": 17,\n      \"deprel\": \"det\",\n      \"start_char\": 82,\n      \"end_char\": 84,\n      \"misc\": \"SpaceAfter=No\"\n    },\n    {\n      \"id\": 17,\n      \"text\": \"Économie\",\n      \"lemma\": \"économie\",\n      \"upos\": \"NOUN\",\n      \"feats\": \"Gender=Fem|Number=Sing\",\n      \"head\": 14,\n      \"deprel\": \"nmod\",\n      \"start_char\": 84,\n      \"end_char\": 92,\n      \"misc\": \"SpaceAfter=No\"\n    },\n    {\n      \"id\": 18,\n      \"text\": \",\",\n      \"lemma\": \",\",\n      \"upos\": \"PUNCT\",\n      \"head\": 21,\n      \"deprel\": \"punct\",\n      \"start_char\": 92,\n      \"end_char\": 93\n    },\n    {\n      \"id\": 19,\n      \"text\": \"de\",\n      \"lemma\": \"de\",\n      \"upos\": \"ADP\",\n      \"head\": 21,\n      \"deprel\": \"case\",\n      \"start_char\": 94,\n      \"end_char\": 96\n    },\n    {\n      \"id\": 20,\n      \"text\": \"l'\",\n      \"lemma\": \"le\",\n      \"upos\": \"DET\",\n      \"feats\": \"Definite=Def|Number=Sing|PronType=Art\",\n      \"head\": 21,\n      \"deprel\": \"det\",\n      \"start_char\": 97,\n      \"end_char\": 99,\n      \"misc\": \"SpaceAfter=No\"\n    },\n    {\n      \"id\": 21,\n      \"text\": \"Industrie\",\n      \"lemma\": \"industrie\",\n      \"upos\": \"NOUN\",\n      \"feats\": \"Gender=Fem|Number=Sing\",\n      \"head\": 17,\n      \"deprel\": \"conj\",\n      \"start_char\": 99,\n      \"end_char\": 108\n    },\n    {\n      
\"id\": 22,\n      \"text\": \"et\",\n      \"lemma\": \"et\",\n      \"upos\": \"CCONJ\",\n      \"head\": 25,\n      \"deprel\": \"cc\",\n      \"start_char\": 109,\n      \"end_char\": 111\n    },\n    {\n      \"id\": [\n        23,\n        24\n      ],\n      \"text\": \"du\",\n      \"start_char\": 112,\n      \"end_char\": 114\n    },\n    {\n      \"id\": 23,\n      \"text\": \"de\",\n      \"lemma\": \"de\",\n      \"upos\": \"ADP\",\n      \"head\": 25,\n      \"deprel\": \"case\"\n    },\n    {\n      \"id\": 24,\n      \"text\": \"le\",\n      \"lemma\": \"le\",\n      \"upos\": \"DET\",\n      \"feats\": \"Definite=Def|Gender=Masc|Number=Sing|PronType=Art\",\n      \"head\": 25,\n      \"deprel\": \"det\"\n    },\n    {\n      \"id\": 25,\n      \"text\": \"Numérique\",\n      \"lemma\": \"numérique\",\n      \"upos\": \"NOUN\",\n      \"feats\": \"Gender=Masc|Number=Sing\",\n      \"head\": 17,\n      \"deprel\": \"conj\",\n      \"start_char\": 115,\n      \"end_char\": 124,\n      \"misc\": \"SpaceAfter=No\"\n    },\n    {\n      \"id\": 26,\n      \"text\": \".\",\n      \"lemma\": \".\",\n      \"upos\": \"PUNCT\",\n      \"head\": 11,\n      \"deprel\": \"punct\",\n      \"start_char\": 124,\n      \"end_char\": 125,\n      \"misc\": \"SpaceAfter=No\"\n    }\n  ]\n]\n\"\"\"\n\nclass TestFrenchPipeline:\n    @pytest.fixture(scope=\"class\")\n    def pipeline(self):\n        \"\"\" Create a pipeline with French models \"\"\"\n        pipeline = stanza.Pipeline(processors='tokenize,mwt,pos,lemma,depparse', dir=TEST_MODELS_DIR, lang='fr')\n        return pipeline\n\n    def test_single(self, pipeline):\n        doc = pipeline(FR_MWT_SENTENCE)\n        compare_ignoring_whitespace(str(doc), EXPECTED_RESULT)\n\n    def test_bulk(self, pipeline):\n        NUM_DOCS = 10\n        raw_text = [FR_MWT_SENTENCE] * NUM_DOCS\n        raw_doc = [Document([], text=doccontent) for doccontent in raw_text]\n\n        result = pipeline(raw_doc)\n\n        assert 
len(result) == NUM_DOCS\n        for doc in result:\n            compare_ignoring_whitespace(str(doc), EXPECTED_RESULT)\n            assert len(doc.sentences) == 1\n            assert doc.num_words == 26\n            assert doc.num_tokens == 24\n\n    def test_on_gpu(self, pipeline):\n        \"\"\"\n        The default pipeline should have all the models on the GPU\n        \"\"\"\n        check_on_gpu(pipeline)\n\n    def test_on_cpu(self):\n        \"\"\"\n        Create a pipeline on the CPU, check that all the models are on the CPU\n        \"\"\"\n        pipeline = stanza.Pipeline(\"fr\", dir=TEST_MODELS_DIR, use_gpu=False)\n        check_on_cpu(pipeline)\n"
  },
  {
    "path": "stanza/tests/pipeline/test_lemmatizer.py",
    "content": "\"\"\"\nBasic testing of lemmatization\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import *\nfrom stanza.models.common.doc import TEXT, UPOS, LEMMA\n\npytestmark = pytest.mark.pipeline\n\nEN_DOC = \"Joe Smith was born in California.\"\n\nEN_DOC_IDENTITY_GOLD = \"\"\"\nJoe Joe\nSmith Smith\nwas was\nborn born\nin in\nCalifornia California\n. .\n\"\"\".strip()\n\nEN_DOC_LEMMATIZER_MODEL_GOLD = \"\"\"\nJoe Joe\nSmith Smith\nwas be\nborn bear\nin in\nCalifornia California\n. .\n\"\"\".strip()\n\n\ndef test_identity_lemmatizer():\n    nlp = stanza.Pipeline(**{'processors': 'tokenize,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_use_identity': True}, download_method=None)\n    doc = nlp(EN_DOC)\n    word_lemma_pairs = []\n    for w in doc.iter_words():\n        word_lemma_pairs += [f\"{w.text} {w.lemma}\"]\n    assert EN_DOC_IDENTITY_GOLD == \"\\n\".join(word_lemma_pairs)\n\ndef test_full_lemmatizer():\n    nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, download_method=None)\n    doc = nlp(EN_DOC)\n    word_lemma_pairs = []\n    for w in doc.iter_words():\n        word_lemma_pairs += [f\"{w.text} {w.lemma}\"]\n    assert EN_DOC_LEMMATIZER_MODEL_GOLD == \"\\n\".join(word_lemma_pairs)\n\ndef find_unknown_word(lemmatizer, base):\n    for i in range(10):\n        base = base + \"z\"\n        if base not in lemmatizer.word_dict and all(x[0] != base for x in lemmatizer.composite_dict.keys()):\n            return base\n    raise RuntimeError(\"wtf?\")\n\ndef test_store_results():\n    nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, lemma_store_results=True, download_method=None)\n    lemmatizer = nlp.processors[\"lemma\"]._trainer\n\n    az = find_unknown_word(lemmatizer, \"a\")\n    bz = find_unknown_word(lemmatizer, \"b\")\n    cz = find_unknown_word(lemmatizer, \"c\")\n\n    # try sentences with the order long, short\n    doc = 
nlp(\"I found an \" + az + \" in my \" + bz + \".  It was a \" + cz)\n    stuff = doc.get([TEXT, UPOS, LEMMA])\n    assert len(stuff) == 12\n    assert stuff[3][0] == az\n    assert stuff[6][0] == bz\n    assert stuff[11][0] == cz\n\n    assert lemmatizer.composite_dict[(az, stuff[3][1])] == stuff[3][2]\n    assert lemmatizer.composite_dict[(bz, stuff[6][1])] == stuff[6][2]\n    assert lemmatizer.composite_dict[(cz, stuff[11][1])] == stuff[11][2]\n\n    doc2 = nlp(\"I found an \" + az + \" in my \" + bz + \".  It was a \" + cz)\n    stuff2 = doc2.get([TEXT, UPOS, LEMMA])\n\n    assert stuff == stuff2\n\n    dz = find_unknown_word(lemmatizer, \"d\")\n    ez = find_unknown_word(lemmatizer, \"e\")\n    fz = find_unknown_word(lemmatizer, \"f\")\n\n    # try sentences with the order long, short\n    doc = nlp(\"It was a \" + dz + \".  I found an \" + ez + \" in my \" + fz)\n    stuff = doc.get([TEXT, UPOS, LEMMA])\n    assert len(stuff) == 12\n    assert stuff[3][0] == dz\n    assert stuff[8][0] == ez\n    assert stuff[11][0] == fz\n\n    assert lemmatizer.composite_dict[(dz, stuff[3][1])] == stuff[3][2]\n    assert lemmatizer.composite_dict[(ez, stuff[8][1])] == stuff[8][2]\n    assert lemmatizer.composite_dict[(fz, stuff[11][1])] == stuff[11][2]\n\n    doc2 = nlp(\"It was a \" + dz + \".  
I found an \" + ez + \" in my \" + fz)\n    stuff2 = doc2.get([TEXT, UPOS, LEMMA])\n\n    assert stuff == stuff2\n\n    assert az not in lemmatizer.word_dict\n\ndef test_caseless_lemmatizer():\n    \"\"\"\n    Test that setting the lemmatizer as caseless at Pipeline time lowercases the text\n    \"\"\"\n    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)\n    # the capital letter here should throw off the lemmatizer & it won't remove the plural\n    # although weirdly the current English model *does* lowercase the A\n    doc = nlp(\"Here is an Excerpt\")\n    assert doc.sentences[0].words[-1].lemma == 'excerpt'\n\n    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None, lemma_caseless=True)\n    # with the model set to lowercasing, the word will be treated as if it were 'antennae'\n    doc = nlp(\"Here is an Excerpt\")\n    assert doc.sentences[0].words[-1].lemma == 'Excerpt'\n\ndef test_latin_caseless_lemmatizer():\n    \"\"\"\n    Test the Latin caseless lemmatizer\n    \"\"\"\n    nlp = stanza.Pipeline('la', package='ittb', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)\n    lemmatizer = nlp.processors['lemma']\n    assert lemmatizer.config['caseless']\n\n    doc = nlp(\"Quod Erat Demonstrandum\")\n    expected_lemmas = \"qui sum demonstro\".split()\n    assert len(doc.sentences) == 1\n    assert len(doc.sentences[0].words) == 3\n    for word, expected in zip(doc.sentences[0].words, expected_lemmas):\n        assert word.lemma == expected\n\ndef test_contextual_lemmatizer():\n    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, package={\"lemma\": \"default_accurate\"}, download_method=\"reuse_resources\")\n    lemmatizer = nlp.processors['lemma']._trainer\n    # the accurate model should have a 's classifier\n    assert len(lemmatizer.contextual_lemmatizers) > 0\n    # 
ideally the doc would have 'have' as the lemma for the second\n    # word, but maybe it's not always accurate.  actually, it works\n    # fine at the time of this test\n    doc = nlp(\"He's added a contextual lemmatizer\")\n"
  },
  {
    "path": "stanza/tests/pipeline/test_pipeline_constituency_processor.py",
    "content": "import gc\nimport pytest\nimport stanza\nfrom stanza.models.common.foundation_cache import FoundationCache\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# data for testing\nTEST_TEXT = \"This is a test.  Another sentence.  Are these sorted?\"\n\nTEST_TOKENS = [[\"This\", \"is\", \"a\", \"test\", \".\"], [\"Another\", \"sentence\", \".\"], [\"Are\", \"these\", \"sorted\", \"?\"]]\n\n@pytest.fixture(scope=\"module\")\ndef foundation_cache():\n    # the test suite sometimes winds up holding on to GPU memory for too long,\n    # resulting in an OOM error\n    # occasionally calling gc.collect() will help\n    gc.collect()\n    return FoundationCache()\n\ndef check_results(doc):\n    assert len(doc.sentences) == len(TEST_TOKENS)\n    for sentence, expected in zip(doc.sentences, TEST_TOKENS):\n        assert sentence.constituency.leaf_labels() == expected\n\ndef test_sorted_big_batch(foundation_cache):\n    pipe = stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\", foundation_cache=foundation_cache, download_method=None)\n    doc = pipe(TEST_TEXT)\n    check_results(doc)\n\ndef test_comments(foundation_cache):\n    \"\"\"\n    Test that the pipeline is creating constituency comments\n    \"\"\"\n    pipe = stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\", foundation_cache=foundation_cache, download_method=None)\n    doc = pipe(TEST_TEXT)\n    check_results(doc)\n    for sentence in doc.sentences:\n        assert any(x.startswith(\"# constituency = \") for x in sentence.comments)\n    doc.sentences[0].constituency = \"asdf\"\n    assert \"# constituency = asdf\" in doc.sentences[0].comments\n    for sentence in doc.sentences:\n        assert len([x for x in sentence.comments if x.startswith(\"# constituency\")]) == 1\n\ndef test_illegal_batch_size(foundation_cache):\n    stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, 
processors=\"tokenize,pos\", constituency_batch_size=\"zzz\", foundation_cache=foundation_cache, download_method=None)\n    with pytest.raises(ValueError):\n        stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\", constituency_batch_size=\"zzz\", foundation_cache=foundation_cache, download_method=None)\n\ndef test_sorted_one_batch(foundation_cache):\n    pipe = stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\", constituency_batch_size=1, foundation_cache=foundation_cache, download_method=None)\n    doc = pipe(TEST_TEXT)\n    check_results(doc)\n\ndef test_sorted_two_batch(foundation_cache):\n    pipe = stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\", constituency_batch_size=2, foundation_cache=foundation_cache, download_method=None)\n    doc = pipe(TEST_TEXT)\n    check_results(doc)\n\ndef test_get_constituents(foundation_cache):\n    pipe = stanza.Pipeline(\"en\", model_dir=TEST_MODELS_DIR, processors=\"tokenize,pos,constituency\", foundation_cache=foundation_cache, download_method=None)\n    assert \"SBAR\" in pipe.processors[\"constituency\"].get_constituents()\n"
  },
  {
    "path": "stanza/tests/pipeline/test_pipeline_depparse_processor.py",
    "content": "\"\"\"\nBasic testing of part of speech tagging\n\"\"\"\n\nimport pytest\nimport stanza\nfrom stanza.models.common.vocab import VOCAB_PREFIX\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nclass TestClassifier:\n    @pytest.fixture(scope=\"class\")\n    def english_depparse(self):\n        \"\"\"\n        Get a depparse_processor for English\n        \"\"\"\n        nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'})\n        assert 'depparse' in nlp.processors\n        return nlp.processors['depparse']\n\n    def test_get_known_relations(self, english_depparse):\n        \"\"\"\n        Test getting the known relations from a processor.\n\n        Doesn't test that all the relations exist, since who knows what will change in the future\n        \"\"\"\n        relations = english_depparse.get_known_relations()\n        assert len(relations) > 5\n        assert 'case' in relations\n        for i in VOCAB_PREFIX:\n            assert i not in relations\n"
  },
  {
    "path": "stanza/tests/pipeline/test_pipeline_mwt_expander.py",
    "content": "\"\"\"\nBasic testing of multi-word-token expansion\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# mwt data for testing\nFR_MWT_SENTENCE = \"Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de \" \\\n                  \"l'Industrie et du Numérique.\"\n\n\nFR_MWT_TOKEN_TO_WORDS_GOLD = \"\"\"\ntoken: Alors    \t\twords: [<Word id=1;text=Alors>]\ntoken: encore   \t\twords: [<Word id=2;text=encore>]\ntoken: inconnu  \t\twords: [<Word id=3;text=inconnu>]\ntoken: du       \t\twords: [<Word id=4;text=de>, <Word id=5;text=le>]\ntoken: grand    \t\twords: [<Word id=6;text=grand>]\ntoken: public   \t\twords: [<Word id=7;text=public>]\ntoken: ,        \t\twords: [<Word id=8;text=,>]\ntoken: Emmanuel \t\twords: [<Word id=9;text=Emmanuel>]\ntoken: Macron   \t\twords: [<Word id=10;text=Macron>]\ntoken: devient  \t\twords: [<Word id=11;text=devient>]\ntoken: en       \t\twords: [<Word id=12;text=en>]\ntoken: 2014     \t\twords: [<Word id=13;text=2014>]\ntoken: ministre \t\twords: [<Word id=14;text=ministre>]\ntoken: de       \t\twords: [<Word id=15;text=de>]\ntoken: l'       \t\twords: [<Word id=16;text=l'>]\ntoken: Économie \t\twords: [<Word id=17;text=Économie>]\ntoken: ,        \t\twords: [<Word id=18;text=,>]\ntoken: de       \t\twords: [<Word id=19;text=de>]\ntoken: l'       \t\twords: [<Word id=20;text=l'>]\ntoken: Industrie\t\twords: [<Word id=21;text=Industrie>]\ntoken: et       \t\twords: [<Word id=22;text=et>]\ntoken: du       \t\twords: [<Word id=23;text=de>, <Word id=24;text=le>]\ntoken: Numérique\t\twords: [<Word id=25;text=Numérique>]\ntoken: .        
\t\twords: [<Word id=26;text=.>]\n\"\"\".strip()\n\nFR_MWT_WORD_TO_TOKEN_GOLD = \"\"\"\nword: Alors    \t\ttoken parent:1-Alors\nword: encore   \t\ttoken parent:2-encore\nword: inconnu  \t\ttoken parent:3-inconnu\nword: de       \t\ttoken parent:4-5-du\nword: le       \t\ttoken parent:4-5-du\nword: grand    \t\ttoken parent:6-grand\nword: public   \t\ttoken parent:7-public\nword: ,        \t\ttoken parent:8-,\nword: Emmanuel \t\ttoken parent:9-Emmanuel\nword: Macron   \t\ttoken parent:10-Macron\nword: devient  \t\ttoken parent:11-devient\nword: en       \t\ttoken parent:12-en\nword: 2014     \t\ttoken parent:13-2014\nword: ministre \t\ttoken parent:14-ministre\nword: de       \t\ttoken parent:15-de\nword: l'       \t\ttoken parent:16-l'\nword: Économie \t\ttoken parent:17-Économie\nword: ,        \t\ttoken parent:18-,\nword: de       \t\ttoken parent:19-de\nword: l'       \t\ttoken parent:20-l'\nword: Industrie\t\ttoken parent:21-Industrie\nword: et       \t\ttoken parent:22-et\nword: de       \t\ttoken parent:23-24-du\nword: le       \t\ttoken parent:23-24-du\nword: Numérique\t\ttoken parent:25-Numérique\nword: .        
\t\ttoken parent:26-.\n\"\"\".strip()\n\n\ndef test_mwt():\n    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='fr', download_method=None)\n    doc = pipeline(FR_MWT_SENTENCE)\n    token_to_words = \"\\n\".join(\n        [f'token: {token.text.ljust(9)}\\t\\twords: [{\", \".join([word.pretty_print() for word in token.words])}]' for sent in doc.sentences for token in sent.tokens]\n    ).strip()\n    word_to_token = \"\\n\".join(\n        [f'word: {word.text.ljust(9)}\\t\\ttoken parent:{\"-\".join([str(x) for x in word.parent.id])}-{word.parent.text}'\n         for sent in doc.sentences for word in sent.words]).strip()\n    assert token_to_words == FR_MWT_TOKEN_TO_WORDS_GOLD\n    assert word_to_token == FR_MWT_WORD_TO_TOKEN_GOLD\n\ndef test_unknown_character():\n    \"\"\"\n    The MWT processor has a mechanism to temporarily add unknown characters to the vocab\n\n    Here we check that it is properly adding the characters from a test case a user sent us\n    \"\"\"\n    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)\n    text = \"Björkängshallen's\"\n    mwt_processor = pipeline.processors[\"mwt\"]\n    trainer = mwt_processor.trainer\n    # verify that the test case is still valid\n    # (perhaps an updated MWT model will have all of these characters in the future)\n    assert not all(x in trainer.vocab._unit2id for x in text)\n    doc = pipeline(text)\n    batch = mwt_processor.build_batch(doc)\n    # the vocab used in this batch should have the missing characters\n    assert all(x in batch.vocab._unit2id for x in text)\n\ndef test_unknown_word():\n    \"\"\"\n    Test a word which wasn't in the MWT training data\n\n    The seq2seq model for MWT was randomly hallucinating, but with the\n    CharacterClassifier, it should be able to process unusual MWT\n    without hallucinations\n    \"\"\"\n    pipe = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, 
lang='en', download_method=None)\n    doc = pipe(\"I read the newspaper's report.\")\n    assert len(doc.sentences) == 1\n    assert len(doc.sentences[0].tokens) == 6\n    assert len(doc.sentences[0].tokens[3].words) == 2\n    assert doc.sentences[0].tokens[3].words[0].text == 'newspaper'\n\n    # double check that this is something unknown to the model\n    mwt_processor = pipe.processors[\"mwt\"]\n    trainer = mwt_processor.trainer\n    expansion = trainer.dict_expansion(\"newspaper's\")\n    assert expansion is None\n"
  },
  {
    "path": "stanza/tests/pipeline/test_pipeline_ner_processor.py",
    "content": "\nimport pytest\nimport stanza\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.doc import Document\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# data for testing\nEN_DOCS = [\"Barack Obama was born in Hawaii.\", \"He was elected president in 2008.\", \"Obama attended Harvard.\"]\n\nEXPECTED_ENTS = [[{\n    \"text\": \"Barack Obama\",\n    \"type\": \"PERSON\",\n    \"start_char\": 0,\n    \"end_char\": 12\n}, {\n    \"text\": \"Hawaii\",\n    \"type\": \"GPE\",\n    \"start_char\": 25,\n    \"end_char\": 31\n}],\n[{\n    \"text\": \"2008\",\n    \"type\": \"DATE\",\n    \"start_char\": 28,\n    \"end_char\": 32\n}],\n[{\n    \"text\": \"Obama\",\n    \"type\": \"PERSON\",\n    \"start_char\": 0,\n    \"end_char\": 5\n}, {\n  \"text\": \"Harvard\",\n  \"type\": \"ORG\",\n  \"start_char\": 15,\n  \"end_char\": 22\n}]]\n\n\ndef check_entities_equal(doc, expected):\n    \"\"\"\n    Checks that the entities of a doc are equal to the given list of maps\n    \"\"\"\n    assert len(doc.ents) == len(expected)\n    for doc_entity, expected_entity in zip(doc.ents, expected):\n        for k in expected_entity:\n            assert getattr(doc_entity, k) == expected_entity[k]\n\nclass TestNERProcessor:\n    @pytest.fixture(scope=\"class\")\n    def pipeline(self):\n        \"\"\"\n        A reusable pipeline with the NER module\n        \"\"\"\n        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize,ner\")\n\n    @pytest.fixture(scope=\"class\")\n    def processed_doc(self, pipeline):\n        \"\"\" Document created by running full English pipeline on a few sentences \"\"\"\n        return [pipeline(text) for text in EN_DOCS]\n\n\n    @pytest.fixture(scope=\"class\")\n    def processed_bulk(self, pipeline):\n        \"\"\" Document created by running full English pipeline on a few sentences \"\"\"\n        docs = [Document([], text=t) for t in EN_DOCS]\n        return 
pipeline(docs)\n\n    def test_bulk_ents(self, processed_bulk):\n        assert len(processed_bulk) == len(EXPECTED_ENTS)\n        for doc, expected in zip(processed_bulk, EXPECTED_ENTS):\n            check_entities_equal(doc, expected)\n\n    def test_ents(self, processed_doc):\n        assert len(processed_doc) == len(EXPECTED_ENTS)\n        for doc, expected in zip(processed_doc, EXPECTED_ENTS):\n            check_entities_equal(doc, expected)\n\nEXPECTED_MULTI_ENTS = [{\n  \"text\": \"John Bauer\",\n  \"type\": \"PERSON\",\n  \"start_char\": 0,\n  \"end_char\": 10\n}, {\n  \"text\": \"Stanford\",\n  \"type\": \"ORG\",\n  \"start_char\": 20,\n  \"end_char\": 28\n}, {\n  \"text\": \"hip arthritis\",\n  \"type\": \"DISEASE\",\n  \"start_char\": 37,\n  \"end_char\": 50\n}, {\n  \"text\": \"Chris Manning\",\n  \"type\": \"PERSON\",\n  \"start_char\": 66,\n  \"end_char\": 79\n}]\n\n\nEXPECTED_MULTI_NER = [\n    [('O', 'B-PERSON'),\n     ('O', 'E-PERSON'),\n     ('O', 'O'),\n     ('O', 'O'),\n     ('O', 'S-ORG'),\n     ('O', 'O'),\n     ('O', 'O'),\n     ('B-DISEASE', 'O'),\n     ('E-DISEASE', 'O'),\n     ('O', 'O')],\n    [('O', 'O'),\n     ('O', 'O'),\n     ('O', 'O'),\n     ('O', 'B-PERSON'),\n     ('O', 'E-PERSON'),]]\n\n\n\nclass TestMultiNERProcessor:\n    @pytest.fixture(scope=\"class\")\n    def pipeline(self):\n        \"\"\"\n        A reusable pipeline with TWO ner models\n        \"\"\"\n        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize,ner\", package={\"ner\": [\"ncbi_disease\", \"ontonotes_charlm\"]})\n\n    def test_multi_example(self, pipeline):\n        doc = pipeline(\"John Bauer works at Stanford and has hip arthritis.  He works for Chris Manning\")\n        check_entities_equal(doc, EXPECTED_MULTI_ENTS)\n\n    def test_multi_ner(self, pipeline):\n        \"\"\"\n        Test that multiple NER labels are correctly assigned in tuples\n        \"\"\"\n        doc = pipeline(\"John Bauer works at Stanford and has hip arthritis.  
He works for Chris Manning\")\n        multi_ner = [[token.multi_ner for token in sentence.tokens] for sentence in doc.sentences]\n        assert multi_ner == EXPECTED_MULTI_NER\n\n    def test_known_tags(self, pipeline):\n        assert pipeline.processors[\"ner\"].get_known_tags() == [\"DISEASE\"]\n        assert len(pipeline.processors[\"ner\"].get_known_tags(1)) == 18\n"
  },
  {
    "path": "stanza/tests/pipeline/test_pipeline_pos_processor.py",
    "content": "\"\"\"\nBasic testing of part of speech tagging\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\nEN_DOC = \"Joe Smith was born in California.\"\n\nEN_DOC_GOLD = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe;upos=PROPN;xpos=NNP;feats=Number=Sing>]>\n<Token id=2;words=[<Word id=2;text=Smith;upos=PROPN;xpos=NNP;feats=Number=Sing>]>\n<Token id=3;words=[<Word id=3;text=was;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin>]>\n<Token id=4;words=[<Word id=4;text=born;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass>]>\n<Token id=5;words=[<Word id=5;text=in;upos=ADP;xpos=IN>]>\n<Token id=6;words=[<Word id=6;text=California;upos=PROPN;xpos=NNP;feats=Number=Sing>]>\n<Token id=7;words=[<Word id=7;text=.;upos=PUNCT;xpos=.>]>\n\"\"\".strip()\n\n@pytest.fixture(scope=\"module\")\ndef pos_pipeline():\n    return stanza.Pipeline(**{'processors': 'tokenize,pos', 'dir': TEST_MODELS_DIR, 'download_method': None, 'lang': 'en'})\n\ndef test_part_of_speech(pos_pipeline):\n    doc = pos_pipeline(EN_DOC)\n    assert EN_DOC_GOLD == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n\ndef test_get_known_xpos(pos_pipeline):\n    tags = pos_pipeline.processors['pos'].get_known_xpos()\n    # make sure we have xpos...\n    assert 'DT' in tags\n    # ... and not upos\n    assert 'DET' not in tags\n\ndef test_get_known_upos(pos_pipeline):\n    tags = pos_pipeline.processors['pos'].get_known_upos()\n    # make sure we have upos...\n    assert 'DET' in tags\n    # ... and not xpos\n    assert 'DT' not in tags\n\n\ndef test_get_known_feats(pos_pipeline):\n    feats = pos_pipeline.processors['pos'].get_known_feats()\n    # I appreciate how self-referential the Abbr feat is\n    assert 'Abbr' in feats\n    assert 'Yes' in feats['Abbr']\n"
  },
  {
    "path": "stanza/tests/pipeline/test_pipeline_sentiment_processor.py",
    "content": "import gc\n\nimport pytest\nimport stanza\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.doc import Document\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\n# data for testing\nEN_DOCS = [\"Ragavan is terrible and should go away.\",  \"Today is okay.\",  \"Urza's Saga is great.\"]\n\nEN_DOC = \"  \".join(EN_DOCS)\n\nEXPECTED = [0, 1, 2]\n\nclass TestSentimentPipeline:\n    @pytest.fixture(scope=\"class\")\n    def pipeline(self):\n        \"\"\"\n        A reusable pipeline with the NER module\n        \"\"\"\n        gc.collect()\n        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize,sentiment\")\n\n    def test_simple(self, pipeline):\n        results = []\n        for text in EN_DOCS:\n            doc = pipeline(text)\n            assert len(doc.sentences) == 1\n            results.append(doc.sentences[0].sentiment)\n        assert EXPECTED == results\n\n    def test_multiple_sentences(self, pipeline):\n        doc = pipeline(EN_DOC)\n        assert len(doc.sentences) == 3\n        results = [sentence.sentiment for sentence in doc.sentences]\n        assert EXPECTED == results\n\n    def test_empty_text(self, pipeline):\n        \"\"\"\n        Test empty text and a text which might get reduced to empty text by removing dashes\n        \"\"\"\n        doc = pipeline(\"\")\n        assert len(doc.sentences) == 0\n\n        doc = pipeline(\"--\")\n        assert len(doc.sentences) == 1\n"
  },
  {
    "path": "stanza/tests/pipeline/test_requirements.py",
    "content": "\"\"\"\nTest the requirements functionality for processors\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.pipeline.core import PipelineRequirementsException\nfrom stanza.pipeline.processor import ProcessorRequirementsException\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\ndef check_exception_vals(req_exception, req_exception_vals):\n    \"\"\"\n    Check the values of a ProcessorRequirementsException against a dict of expected values.\n    :param req_exception: the ProcessorRequirementsException to evaluate\n    :param req_exception_vals: expected values for the ProcessorRequirementsException\n    :return: None\n    \"\"\"\n    assert isinstance(req_exception, ProcessorRequirementsException)\n    assert req_exception.processor_type == req_exception_vals['processor_type']\n    assert req_exception.processors_list == req_exception_vals['processors_list']\n    assert req_exception.err_processor.requires == req_exception_vals['requires']\n\n\ndef test_missing_requirements():\n    \"\"\"\n    Try to build several pipelines with bad configs and check thrown exceptions against gold exceptions.\n    :return: None\n    \"\"\"\n    # list of (bad configs, list of gold ProcessorRequirementsExceptions that should be thrown) pairs\n    bad_config_lists = [\n        # missing tokenize\n        (\n            # input config\n            {'processors': 'pos,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'},\n            # 2 expected exceptions\n            [\n                {'processor_type': 'POSProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]),\n                 'requires': set(['tokenize'])},\n                {'processor_type': 'DepparseProcessor', 'processors_list': ['pos', 'depparse'],\n                 'provided_reqs': set([]), 'requires': set(['tokenize','pos', 'lemma'])}\n            ]\n        ),\n        # no pos when lemma_pos set to True; for english mwt should not be included in the loaded 
processor list\n        (\n            # input config\n            {'processors': 'tokenize,mwt,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_pos': True},\n            # 1 expected exception\n            [\n                {'processor_type': 'LemmaProcessor', 'processors_list': ['tokenize', 'mwt', 'lemma'],\n                 'provided_reqs': set(['tokenize', 'mwt']), 'requires': set(['tokenize', 'pos'])}\n            ]\n        )\n    ]\n    # try to build each bad config, catch exceptions, check against gold\n    pipeline_fails = 0\n    for bad_config, gold_exceptions in bad_config_lists:\n        try:\n            stanza.Pipeline(**bad_config)\n        except PipelineRequirementsException as e:\n            pipeline_fails += 1\n            assert isinstance(e, PipelineRequirementsException)\n            assert len(e.processor_req_fails) == len(gold_exceptions)\n            for processor_req_e, gold_exception in zip(e.processor_req_fails,gold_exceptions):\n                # compare the thrown ProcessorRequirementsExceptions against gold\n                check_exception_vals(processor_req_e, gold_exception)\n    # check pipeline building failed twice\n    assert pipeline_fails == 2\n\n\n"
  },
  {
    "path": "stanza/tests/pipeline/test_tokenizer.py",
    "content": "\"\"\"\nBasic testing of tokenization\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza.tests import *\n\npytestmark = pytest.mark.pipeline\n\nEN_DOC = \"Joe Smith lives in California. Joe's favorite food is pizza. He enjoys going to the beach.\"\nEN_DOC_WITH_EXTRA_WHITESPACE = \"Joe   Smith \\n lives in\\n California.   Joe's    favorite food \\tis pizza. \\t\\t\\tHe enjoys \\t\\tgoing to the beach.\"\nEN_DOC_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text=Smith>]>\n<Token id=3;words=[<Word id=3;text=lives>]>\n<Token id=4;words=[<Word id=4;text=in>]>\n<Token id=5;words=[<Word id=5;text=California>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1-2;words=[<Word id=1;text=Joe>, <Word id=2;text='s>]>\n<Token id=3;words=[<Word id=3;text=favorite>]>\n<Token id=4;words=[<Word id=4;text=food>]>\n<Token id=5;words=[<Word id=5;text=is>]>\n<Token id=6;words=[<Word id=6;text=pizza>]>\n<Token id=7;words=[<Word id=7;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=He>]>\n<Token id=2;words=[<Word id=2;text=enjoys>]>\n<Token id=3;words=[<Word id=3;text=going>]>\n<Token id=4;words=[<Word id=4;text=to>]>\n<Token id=5;words=[<Word id=5;text=the>]>\n<Token id=6;words=[<Word id=6;text=beach>]>\n<Token id=7;words=[<Word id=7;text=.>]>\n\"\"\".strip()\n\n# spaCy doesn't have MWT\nEN_DOC_SPACY_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text=Smith>]>\n<Token id=3;words=[<Word id=3;text=lives>]>\n<Token id=4;words=[<Word id=4;text=in>]>\n<Token id=5;words=[<Word id=5;text=California>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text='s>]>\n<Token id=3;words=[<Word id=3;text=favorite>]>\n<Token id=4;words=[<Word id=4;text=food>]>\n<Token id=5;words=[<Word id=5;text=is>]>\n<Token id=6;words=[<Word id=6;text=pizza>]>\n<Token id=7;words=[<Word id=7;text=.>]>\n\n<Token id=1;words=[<Word 
id=1;text=He>]>\n<Token id=2;words=[<Word id=2;text=enjoys>]>\n<Token id=3;words=[<Word id=3;text=going>]>\n<Token id=4;words=[<Word id=4;text=to>]>\n<Token id=5;words=[<Word id=5;text=the>]>\n<Token id=6;words=[<Word id=6;text=beach>]>\n<Token id=7;words=[<Word id=7;text=.>]>\n\"\"\".strip()\nEN_DOC_POSTPROCESSOR_TOKENS_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], [(\"Joe's\", True), 'favorite', 'food', 'is', 'pizza', '.'], ['He', 'enjoys', 'going', 'to', 'the', 'beach', '.']]\nEN_DOC_POSTPROCESSOR_COMBINED_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], ['Joe', \"'s\", 'favorite', 'food', 'is', 'pizza', '.'], ['He', 'enjoys', 'going', \"to the beach\", '.']]\n\nEN_DOC_POSTPROCESSOR_COMBINED_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text=Smith>]>\n<Token id=3;words=[<Word id=3;text=lives>]>\n<Token id=4;words=[<Word id=4;text=in>]>\n<Token id=5;words=[<Word id=5;text=California>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text='s>]>\n<Token id=3;words=[<Word id=3;text=favorite>]>\n<Token id=4;words=[<Word id=4;text=food>]>\n<Token id=5;words=[<Word id=5;text=is>]>\n<Token id=6;words=[<Word id=6;text=pizza>]>\n<Token id=7;words=[<Word id=7;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=He>]>\n<Token id=2;words=[<Word id=2;text=enjoys>]>\n<Token id=3;words=[<Word id=3;text=going>]>\n<Token id=4;words=[<Word id=4;text=to the beach>]>\n<Token id=5;words=[<Word id=5;text=.>]>\n\"\"\"\n\n# ensure that the entry above has spaces somewhere to test that spaces work in between tokens\n\nEN_DOC_GOLD_NOSSPLIT_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text=Smith>]>\n<Token id=3;words=[<Word id=3;text=lives>]>\n<Token id=4;words=[<Word id=4;text=in>]>\n<Token id=5;words=[<Word id=5;text=California>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n<Token id=7;words=[<Word 
id=7;text=Joe>]>\n<Token id=8;words=[<Word id=8;text='s>]>\n<Token id=9;words=[<Word id=9;text=favorite>]>\n<Token id=10;words=[<Word id=10;text=food>]>\n<Token id=11;words=[<Word id=11;text=is>]>\n<Token id=12;words=[<Word id=12;text=pizza>]>\n<Token id=13;words=[<Word id=13;text=.>]>\n<Token id=14;words=[<Word id=14;text=He>]>\n<Token id=15;words=[<Word id=15;text=enjoys>]>\n<Token id=16;words=[<Word id=16;text=going>]>\n<Token id=17;words=[<Word id=17;text=to>]>\n<Token id=18;words=[<Word id=18;text=the>]>\n<Token id=19;words=[<Word id=19;text=beach>]>\n<Token id=20;words=[<Word id=20;text=.>]>\n\"\"\".strip()\n\nEN_DOC_PRETOKENIZED = \\\n    \"Joe Smith lives in California .\\nJoe's favorite  food is  pizza .\\n\\nHe enjoys going to the beach.\\n\"\nEN_DOC_PRETOKENIZED_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text=Smith>]>\n<Token id=3;words=[<Word id=3;text=lives>]>\n<Token id=4;words=[<Word id=4;text=in>]>\n<Token id=5;words=[<Word id=5;text=California>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=Joe's>]>\n<Token id=2;words=[<Word id=2;text=favorite>]>\n<Token id=3;words=[<Word id=3;text=food>]>\n<Token id=4;words=[<Word id=4;text=is>]>\n<Token id=5;words=[<Word id=5;text=pizza>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=He>]>\n<Token id=2;words=[<Word id=2;text=enjoys>]>\n<Token id=3;words=[<Word id=3;text=going>]>\n<Token id=4;words=[<Word id=4;text=to>]>\n<Token id=5;words=[<Word id=5;text=the>]>\n<Token id=6;words=[<Word id=6;text=beach.>]>\n\"\"\".strip()\n\nEN_DOC_PRETOKENIZED_LIST = [['Joe', 'Smith', 'lives', 'in', 'California', '.'], ['He', 'loves', 'pizza', '.']]\nEN_DOC_PRETOKENIZED_LIST_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Joe>]>\n<Token id=2;words=[<Word id=2;text=Smith>]>\n<Token id=3;words=[<Word id=3;text=lives>]>\n<Token id=4;words=[<Word id=4;text=in>]>\n<Token id=5;words=[<Word 
id=5;text=California>]>\n<Token id=6;words=[<Word id=6;text=.>]>\n\n<Token id=1;words=[<Word id=1;text=He>]>\n<Token id=2;words=[<Word id=2;text=loves>]>\n<Token id=3;words=[<Word id=3;text=pizza>]>\n<Token id=4;words=[<Word id=4;text=.>]>\n\"\"\".strip()\n\nEN_DOC_NO_SSPLIT = [\"This is a sentence. This is another.\", \"This is a third.\"]\nEN_DOC_NO_SSPLIT_SENTENCES = [['This', 'is', 'a', 'sentence', '.', 'This', 'is', 'another', '.'], ['This', 'is', 'a', 'third', '.']]\n\nFR_DOC = \"Le prince va manger du poulet aux les magasins aujourd'hui.\"\nFR_DOC_POSTPROCESSOR_TOKENS_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', \"aujourd'hui\", '.']]\nFR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', (\"aujourd'hui\", [\"aujourd'\", \"hui\"]), '.']]\nFR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=Le>]>\n<Token id=2;words=[<Word id=2;text=prince>]>\n<Token id=3;words=[<Word id=3;text=va>]>\n<Token id=4;words=[<Word id=4;text=manger>]>\n<Token id=5-6;words=[<Word id=5;text=de>, <Word id=6;text=le>]>\n<Token id=7;words=[<Word id=7;text=poulet>]>\n<Token id=8-9;words=[<Word id=8;text=à>, <Word id=9;text=les>]>\n<Token id=10;words=[<Word id=10;text=les>]>\n<Token id=11;words=[<Word id=11;text=magasins>]>\n<Token id=12-13;words=[<Word id=12;text=aujourd'>, <Word id=13;text=hui>]>\n<Token id=14;words=[<Word id=14;text=.>]>\n\"\"\"\n\nJA_DOC = \"北京は中国の首都です。 北京の人口は2152万人です。\\n\" # add some random whitespaces that need to be skipped\nJA_DOC_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=は>]>\n<Token id=3;words=[<Word id=3;text=中国>]>\n<Token id=4;words=[<Word id=4;text=の>]>\n<Token id=5;words=[<Word id=5;text=首都>]>\n<Token id=6;words=[<Word id=6;text=です>]>\n<Token id=7;words=[<Word id=7;text=。>]>\n\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word 
id=2;text=の>]>\n<Token id=3;words=[<Word id=3;text=人口>]>\n<Token id=4;words=[<Word id=4;text=は>]>\n<Token id=5;words=[<Word id=5;text=2152万>]>\n<Token id=6;words=[<Word id=6;text=人>]>\n<Token id=7;words=[<Word id=7;text=です>]>\n<Token id=8;words=[<Word id=8;text=。>]>\n\"\"\".strip()\n\nJA_DOC_GOLD_NOSSPLIT_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=は>]>\n<Token id=3;words=[<Word id=3;text=中国>]>\n<Token id=4;words=[<Word id=4;text=の>]>\n<Token id=5;words=[<Word id=5;text=首都>]>\n<Token id=6;words=[<Word id=6;text=です>]>\n<Token id=7;words=[<Word id=7;text=。>]>\n<Token id=8;words=[<Word id=8;text=北京>]>\n<Token id=9;words=[<Word id=9;text=の>]>\n<Token id=10;words=[<Word id=10;text=人口>]>\n<Token id=11;words=[<Word id=11;text=は>]>\n<Token id=12;words=[<Word id=12;text=2152万>]>\n<Token id=13;words=[<Word id=13;text=人>]>\n<Token id=14;words=[<Word id=14;text=です>]>\n<Token id=15;words=[<Word id=15;text=。>]>\n\"\"\".strip()\n\nZH_DOC = \"北京是中国的首都。 北京有2100万人口，是一个直辖市。\\n\"\nZH_DOC1 = \"北\\n京是中\\n国的首\\n都。 北京有2100万人口，是一个直辖市。\\n\"\nZH_DOC2 = \"北\\n京是中\\n国的首\\n都。\\n\\n 北京有2100万人口，是一个直辖市。\\n\"\nZH_DOC_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=是>]>\n<Token id=3;words=[<Word id=3;text=中国>]>\n<Token id=4;words=[<Word id=4;text=的>]>\n<Token id=5;words=[<Word id=5;text=首都>]>\n<Token id=6;words=[<Word id=6;text=。>]>\n\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=有>]>\n<Token id=3;words=[<Word id=3;text=2100>]>\n<Token id=4;words=[<Word id=4;text=万>]>\n<Token id=5;words=[<Word id=5;text=人口>]>\n<Token id=6;words=[<Word id=6;text=，>]>\n<Token id=7;words=[<Word id=7;text=是>]>\n<Token id=8;words=[<Word id=8;text=一个>]>\n<Token id=9;words=[<Word id=9;text=直辖市>]>\n<Token id=10;words=[<Word id=10;text=。>]>\n\"\"\".strip()\n\nZH_DOC1_GOLD_TOKENS=\"\"\"\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=是>]>\n<Token id=3;words=[<Word 
id=3;text=中国>]>\n<Token id=4;words=[<Word id=4;text=的>]>\n<Token id=5;words=[<Word id=5;text=首都>]>\n<Token id=6;words=[<Word id=6;text=。>]>\n\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=有>]>\n<Token id=3;words=[<Word id=3;text=2100万>]>\n<Token id=4;words=[<Word id=4;text=人口>]>\n<Token id=5;words=[<Word id=5;text=，>]>\n<Token id=6;words=[<Word id=6;text=是>]>\n<Token id=7;words=[<Word id=7;text=一>]>\n<Token id=8;words=[<Word id=8;text=个>]>\n<Token id=9;words=[<Word id=9;text=直辖>]>\n<Token id=10;words=[<Word id=10;text=市>]>\n<Token id=11;words=[<Word id=11;text=。>]>\n\"\"\".strip()\n\nZH_DOC_GOLD_NOSSPLIT_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=北京>]>\n<Token id=2;words=[<Word id=2;text=是>]>\n<Token id=3;words=[<Word id=3;text=中国>]>\n<Token id=4;words=[<Word id=4;text=的>]>\n<Token id=5;words=[<Word id=5;text=首都>]>\n<Token id=6;words=[<Word id=6;text=。>]>\n<Token id=7;words=[<Word id=7;text=北京>]>\n<Token id=8;words=[<Word id=8;text=有>]>\n<Token id=9;words=[<Word id=9;text=2100>]>\n<Token id=10;words=[<Word id=10;text=万>]>\n<Token id=11;words=[<Word id=11;text=人口>]>\n<Token id=12;words=[<Word id=12;text=，>]>\n<Token id=13;words=[<Word id=13;text=是>]>\n<Token id=14;words=[<Word id=14;text=一个>]>\n<Token id=15;words=[<Word id=15;text=直辖市>]>\n<Token id=16;words=[<Word id=16;text=。>]>\n\"\"\".strip()\n\nZH_PARENS_DOC = \"我们一起学(猫叫)\"\n\nTH_DOC = \"ข้าราชการได้รับการหมุนเวียนเป็นระยะ และเขาได้รับมอบหมายให้ประจำในระดับภูมิภาค\"\nTH_DOC_GOLD_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=ข้าราชการ>]>\n<Token id=2;words=[<Word id=2;text=ได้รับ>]>\n<Token id=3;words=[<Word id=3;text=การ>]>\n<Token id=4;words=[<Word id=4;text=หมุนเวียน>]>\n<Token id=5;words=[<Word id=5;text=เป็นระยะ>]>\n\n<Token id=1;words=[<Word id=1;text=และ>]>\n<Token id=2;words=[<Word id=2;text=เขา>]>\n<Token id=3;words=[<Word id=3;text=ได้>]>\n<Token id=4;words=[<Word id=4;text=รับมอบหมาย>]>\n<Token id=5;words=[<Word id=5;text=ให้>]>\n<Token id=6;words=[<Word 
id=6;text=ประจำ>]>\n<Token id=7;words=[<Word id=7;text=ใน>]>\n<Token id=8;words=[<Word id=8;text=ระดับ>]>\n<Token id=9;words=[<Word id=9;text=ภูมิภาค>]>\n\"\"\".strip()\n\nTH_DOC_GOLD_NOSSPLIT_TOKENS = \"\"\"\n<Token id=1;words=[<Word id=1;text=ข้าราชการ>]>\n<Token id=2;words=[<Word id=2;text=ได้รับ>]>\n<Token id=3;words=[<Word id=3;text=การ>]>\n<Token id=4;words=[<Word id=4;text=หมุนเวียน>]>\n<Token id=5;words=[<Word id=5;text=เป็นระยะ>]>\n<Token id=6;words=[<Word id=6;text=และ>]>\n<Token id=7;words=[<Word id=7;text=เขา>]>\n<Token id=8;words=[<Word id=8;text=ได้>]>\n<Token id=9;words=[<Word id=9;text=รับมอบหมาย>]>\n<Token id=10;words=[<Word id=10;text=ให้>]>\n<Token id=11;words=[<Word id=11;text=ประจำ>]>\n<Token id=12;words=[<Word id=12;text=ใน>]>\n<Token id=13;words=[<Word id=13;text=ระดับ>]>\n<Token id=14;words=[<Word id=14;text=ภูมิภาค>]>\n\"\"\".strip()\n\n@pytest.fixture(scope=\"module\")\ndef basic_pipeline():\n    \"\"\" Create a pipeline with a basic English tokenizer \"\"\"\n    nlp = stanza.Pipeline(processors='tokenize', dir=TEST_MODELS_DIR, lang='en', download_method=None)\n    return nlp\n\n\n@pytest.fixture(scope=\"module\")\ndef pretokenized_pipeline():\n    \"\"\" Create a pipeline with a basic English pretokenized tokenizer \"\"\"\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'tokenize_pretokenized': True, 'download_method': None})\n    return nlp\n\n@pytest.fixture(scope=\"module\")\ndef zh_pipeline():\n    \"\"\" Create a pipeline with a basic Chinese tokenizer \"\"\"\n    nlp = stanza.Pipeline(lang='zh', processors='tokenize', dir=TEST_MODELS_DIR, download_method=None)\n    return nlp\n\ndef test_tokenize(basic_pipeline):\n    doc = basic_pipeline(EN_DOC)\n    assert EN_DOC_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef 
test_tokenize_ssplit_robustness(basic_pipeline):\n    doc = basic_pipeline(EN_DOC_WITH_EXTRA_WHITESPACE)\n    assert EN_DOC_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_pretokenized(pretokenized_pipeline):\n    doc = pretokenized_pipeline(EN_DOC_PRETOKENIZED)\n    assert EN_DOC_PRETOKENIZED_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n    doc = pretokenized_pipeline(EN_DOC_PRETOKENIZED_LIST)\n    assert EN_DOC_PRETOKENIZED_LIST_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_pretokenized_multidoc(pretokenized_pipeline):\n    doc = pretokenized_pipeline(EN_DOC_PRETOKENIZED)\n    assert EN_DOC_PRETOKENIZED_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n    doc = pretokenized_pipeline([stanza.Document([], text=EN_DOC_PRETOKENIZED_LIST)])[0]\n    assert EN_DOC_PRETOKENIZED_LIST_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_postprocessor():\n\n    def dummy_postprocessor(in_doc):\n        # Importantly, EN_DOC_POSTPROCESSOR_COMBINED_LIST returns a few tokens joined\n        # with space. 
As some languages (such as VN) contain tokens with space in between\n        # it's important to have joined space tested as one of the tokens\n        assert in_doc == EN_DOC_POSTPROCESSOR_TOKENS_LIST\n        return EN_DOC_POSTPROCESSOR_COMBINED_LIST\n\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR,\n                             'lang': 'en',\n                             'download_method': None,\n                             'tokenize_postprocessor': dummy_postprocessor})\n    doc = nlp(EN_DOC)\n    assert EN_DOC_POSTPROCESSOR_COMBINED_TOKENS.strip() == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences]).strip()\n\ndef test_postprocessor_mwt():\n\n    def dummy_postprocessor(input):\n        # Importantly, EN_DOC_POSTPROCESSOR_COMBINED_LIST returns a few tokens joined\n        # with space. As some languages (such as VN) contain tokens with space in between\n        # it's important to have joined space tested as one of the tokens\n        assert input == FR_DOC_POSTPROCESSOR_TOKENS_LIST\n        return FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST\n\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR,\n                             'lang': 'fr',\n                             'download_method': None,\n                             'tokenize_postprocessor': dummy_postprocessor})\n    doc = nlp(FR_DOC)\n    assert FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS.strip() == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences]).strip()\n\n\ndef test_postprocessor_typeerror():\n    with pytest.raises(ValueError):\n        nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en',\n                                 'download_method': None,\n                                 'tokenize_postprocessor': \"iamachicken\"})\n\ndef test_no_ssplit():\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en',\n                             'download_method': 
None,\n                             'tokenize_no_ssplit': True})\n\n    doc = nlp(EN_DOC_NO_SSPLIT)\n    assert EN_DOC_NO_SSPLIT_SENTENCES == [[w.text for w in s.words] for s in doc.sentences]\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_zh_tokenizer_skip_newline(zh_pipeline):\n    doc = zh_pipeline(ZH_DOC1)\n\n    assert ZH_DOC1_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char].replace('\\n', '') == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_zh_tokenizer_skip_newline_offsets(zh_pipeline):\n    doc = zh_pipeline(ZH_DOC2)\n\n    assert ZH_DOC1_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char].replace('\\n', '') == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_zh_tokenizer_parens(zh_pipeline):\n    \"\"\"\n    The original fix for newlines in Chinese text broke () in Chinese text\n    \"\"\"\n    doc = zh_pipeline(ZH_PARENS_DOC)\n\n    # ... 
the results are kind of bad for this expression, so no testing of the results yet\n    #assert ZH_PARENS_DOC_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n\ndef test_spacy():\n    nlp = stanza.Pipeline(processors='tokenize', dir=TEST_MODELS_DIR, lang='en', tokenize_with_spacy=True, download_method=None)\n    doc = nlp(EN_DOC)\n\n    # make sure the loaded tokenizer is actually spacy\n    assert \"SpacyTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert EN_DOC_SPACY_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_spacy_no_ssplit():\n    nlp = stanza.Pipeline(processors='tokenize', dir=TEST_MODELS_DIR, lang='en', tokenize_with_spacy=True, tokenize_no_ssplit=True, download_method=None)\n    doc = nlp(EN_DOC)\n\n    # make sure the loaded tokenizer is actually spacy\n    assert \"SpacyTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert EN_DOC_GOLD_NOSSPLIT_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_sudachipy():\n    nlp = stanza.Pipeline(lang='ja', dir=TEST_MODELS_DIR, processors={'tokenize': 'sudachipy'}, package=None, download_method=None)\n    doc = nlp(JA_DOC)\n\n    assert \"SudachiPyTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert JA_DOC_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_sudachipy_no_ssplit():\n    nlp = stanza.Pipeline(lang='ja', dir=TEST_MODELS_DIR, processors={'tokenize': 'sudachipy'}, tokenize_no_ssplit=True, 
package=None, download_method=None)\n    doc = nlp(JA_DOC)\n\n    assert \"SudachiPyTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert JA_DOC_GOLD_NOSSPLIT_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_jieba():\n    nlp = stanza.Pipeline(lang='zh', dir=TEST_MODELS_DIR, processors={'tokenize': 'jieba'}, package=None, download_method=None)\n    doc = nlp(ZH_DOC)\n\n    assert \"JiebaTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert ZH_DOC_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_jieba_no_ssplit():\n    nlp = stanza.Pipeline(lang='zh', dir=TEST_MODELS_DIR, processors={'tokenize': 'jieba'}, tokenize_no_ssplit=True, package=None, download_method=None)\n    doc = nlp(ZH_DOC)\n\n    assert \"JiebaTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert ZH_DOC_GOLD_NOSSPLIT_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_pythainlp():\n    nlp = stanza.Pipeline(lang='th', dir=TEST_MODELS_DIR, processors={'tokenize': 'pythainlp'}, package=None, download_method=None)\n    doc = nlp(TH_DOC)\n    assert \"PyThaiNLPTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert TH_DOC_GOLD_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\ndef test_pythainlp_no_ssplit():\n    nlp = stanza.Pipeline(lang='th', 
dir=TEST_MODELS_DIR, processors={'tokenize': 'pythainlp'}, tokenize_no_ssplit=True, package=None, download_method=None)\n    doc = nlp(TH_DOC)\n    assert \"PyThaiNLPTokenizer\" == nlp.processors['tokenize']._variant.__class__.__name__\n    assert TH_DOC_GOLD_NOSSPLIT_TOKENS == '\\n\\n'.join([sent.tokens_string() for sent in doc.sentences])\n    assert all([doc.text[token._start_char: token._end_char] == token.text for sent in doc.sentences for token in sent.tokens])\n\n"
  },
  {
    "path": "stanza/tests/pos/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/pos/test_data.py",
    "content": "\"\"\"\nA few tests of specific operations from the Dataset\n\"\"\"\n\nimport os\nimport pytest\n\nfrom stanza.models.common.doc import *\nfrom stanza.models import tagger\nfrom stanza.models.pos.data import Dataset, ShuffledDataset\nfrom stanza.utils.conll import CoNLL\n\nfrom stanza.tests.pos.test_tagger import TRAIN_DATA, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_FEATS\n\ndef test_basic_reading():\n    \"\"\"\n    Test that a dataset with no xpos is detected by the Dataset\n    \"\"\"\n    # empty args for building the data object\n    args = tagger.parse_args(args=[])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)\n\n    data = Dataset(train_doc, args, None)\n    assert data.has_upos\n    assert data.has_xpos\n    assert data.has_feats\n\ndef test_no_xpos():\n    \"\"\"\n    Test that a dataset with no xpos is detected by the Dataset\n    \"\"\"\n    # empty args for building the data object\n    args = tagger.parse_args(args=[])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA_NO_XPOS)\n\n    data = Dataset(train_doc, args, None)\n    assert data.has_upos\n    assert not data.has_xpos\n    assert data.has_feats\n\ndef test_no_upos():\n    \"\"\"\n    Test that a dataset with no upos is detected by the Dataset\n    \"\"\"\n    # empty args for building the data object\n    args = tagger.parse_args(args=[])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA_NO_UPOS)\n\n    data = Dataset(train_doc, args, None)\n    assert not data.has_upos\n    assert data.has_xpos\n    assert data.has_feats\n\ndef test_no_feats():\n    \"\"\"\n    Test that a dataset with no feats is detected by the Dataset\n    \"\"\"\n    # empty args for building the data object\n    args = tagger.parse_args(args=[])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA_NO_FEATS)\n\n    data = Dataset(train_doc, args, None)\n    assert data.has_upos\n    assert data.has_xpos\n    assert not data.has_feats\n\ndef test_no_augment():\n    
\"\"\"\n    Test that with no punct removing augmentation, the doc always has punct at the end\n    \"\"\"\n    args = tagger.parse_args(args=[\"--shorthand\", \"en_test\", \"--augment_nopunct\", \"0.0\"])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)\n    data = Dataset(train_doc, args, None)\n    data = data.to_loader(batch_size=2)\n\n    for i in range(50):\n        for batch in data:\n            for text in batch.text:\n                assert text[-1] in (\".\", \"!\")\n\ndef test_augment():\n    \"\"\"\n    Test that with 100% punct removing augmentation, the doc never has punct at the end\n    \"\"\"\n    args = tagger.parse_args(args=[\"--shorthand\", \"en_test\", \"--augment_nopunct\", \"1.0\"])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)\n    data = Dataset(train_doc, args, None)\n    data = data.to_loader(batch_size=2)\n\n    for i in range(50):\n        for batch in data:\n            for text in batch.text:\n                assert text[-1] not in (\".\", \"!\")\n\ndef test_sometimes_augment():\n    \"\"\"\n    Test 50% punct removing augmentation\n\n    With this frequency, we should get a reasonable number of docs\n    with a punct at the end and a reasonable without.\n    \"\"\"\n    args = tagger.parse_args(args=[\"--shorthand\", \"en_test\", \"--augment_nopunct\", \"0.5\"])\n\n    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)\n    data = Dataset(train_doc, args, None)\n    data = data.to_loader(batch_size=2)\n\n    count_with = 0\n    count_without = 0\n    for i in range(50):\n        for batch in data:\n            for text in batch.text:\n                if text[-1] in (\".\", \"!\"):\n                    count_with += 1\n                else:\n                    count_without += 1\n\n    # this should never happen\n    # literally less than 1 in 10^20th odds\n    assert count_with > 5\n    assert count_without > 5\n\n\nNO_XPOS_TEMPLATE = \"\"\"\n# text = Noxpos {indexp}\n# sent_id = 
{index}\n1\tNoxpos\tnoxpos\tNOUN\t_\tNumber=Sing\t0\troot\t_\tstart_char=0|end_char=8|ner=O\n2\t{indexp}\t{indexp}\tNUM\t_\tNumForm=Digit|NumType=Card\t1\tdep\t_\tstart_char=9|end_char=10|ner=S-CARDINAL\n\"\"\".strip()\n\nYES_XPOS_TEMPLATE = \"\"\"\n# text = Yesxpos {indexp}\n# sent_id = {index}\n1\tYesxpos\tyesxpos\tNOUN\tNN\tNumber=Sing\t0\troot\t_\tstart_char=0|end_char=8|ner=O\n2\t{indexp}\t{indexp}\tNUM\tCD\tNumForm=Digit|NumType=Card\t1\tdep\t_\tstart_char=9|end_char=10|ner=S-CARDINAL\n\"\"\".strip()\n\ndef test_shuffle(tmp_path):\n    args = tagger.parse_args(args=[\"--batch_size\", \"10\", \"--shorthand\", \"en_test\", \"--augment_nopunct\", \"0.0\"])\n\n    # 100 looked nice but was actually a 1/1000000 chance of the test failing\n    # so let's crank it up to 1000 and make it 1/10^58\n    no_xpos = [NO_XPOS_TEMPLATE.format(index=idx, indexp=idx+1) for idx in range(1000)]\n    no_doc = CoNLL.conll2doc(input_str=\"\\n\\n\".join(no_xpos))\n    no_data = Dataset(no_doc, args, None)\n\n    yes_xpos = [YES_XPOS_TEMPLATE.format(index=idx, indexp=idx+101) for idx in range(1000)]\n    yes_doc = CoNLL.conll2doc(input_str=\"\\n\\n\".join(yes_xpos))\n    yes_data = Dataset(yes_doc, args, None)\n\n    shuffled = ShuffledDataset([no_data, yes_data], 10)\n\n    assert sum(1 for _ in shuffled) == 200\n\n    num_with = 0\n    num_without = 0\n    for batch in shuffled:\n        if batch.xpos is not None:\n            num_with += 1\n        else:\n            num_without += 1\n        # at the halfway point of the iteration, there should be at\n        # least one in each category\n        # for example, if we had forgotten to shuffle, this assertion would fail\n        if num_with + num_without == 100:\n            assert num_with > 1\n            assert num_without > 1\n\n    assert num_with == 100\n    assert num_without == 100\n\n\nEWT_SAMPLE = \"\"\"\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048\n# text = Bush asked for 
permission to go to Alabama to work on a Senate campaign.\n1\tBush\tBush\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t2:nsubj\t_\n2\tasked\task\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n3\tfor\tfor\tADP\tIN\t_\t4\tcase\t4:case\t_\n4\tpermission\tpermission\tNOUN\tNN\tNumber=Sing\t2\tobl\t2:obl:for\t_\n5\tto\tto\tPART\tTO\t_\t6\tmark\t6:mark\t_\n6\tgo\tgo\tVERB\tVB\tVerbForm=Inf\t4\tacl\t4:acl:to\t_\n7\tto\tto\tADP\tIN\t_\t8\tcase\t8:case\t_\n8\tAlabama\tAlabama\tPROPN\tNNP\tNumber=Sing\t6\tobl\t6:obl:to\t_\n9\tto\tto\tPART\tTO\t_\t10\tmark\t10:mark\t_\n10\twork\twork\tVERB\tVB\tVerbForm=Inf\t6\tadvcl\t6:advcl:to\t_\n11\ton\ton\tADP\tIN\t_\t14\tcase\t14:case\t_\n12\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t14\tdet\t14:det\t_\n13\tSenate\tSenate\tPROPN\tNNP\tNumber=Sing\t14\tcompound\t14:compound\t_\n14\tcampaign\tcampaign\tNOUN\tNN\tNumber=Sing\t10\tobl\t10:obl:on\tSpaceAfter=No\n15\t.\t.\tPUNCT\t.\t_\t2\tpunct\t2:punct\t_\n\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049\n# text = His superior officers said OK.\n1\tHis\this\tPRON\tPRP$\tCase=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnmod:poss\t3:nmod:poss\t_\n2\tsuperior\tsuperior\tADJ\tJJ\tDegree=Pos\t3\tamod\t3:amod\t_\n3\tofficers\tofficer\tNOUN\tNNS\tNumber=Plur\t4\tnsubj\t4:nsubj\t_\n4\tsaid\tsay\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n5\tOK\tok\tINTJ\tUH\t_\t4\tobj\t4:obj\tSpaceAfter=No\n6\t.\t.\tPUNCT\t.\t_\t4\tpunct\t4:punct\t_\n\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0053\n# text = In ’72 or ’73, if you were a pilot, active or Guard, and you had an obligation and wanted to get out, no 
problem.\n1\tIn\tin\tADP\tIN\t_\t2\tcase\t2:case\t_\n2\t’72\t'72\tNUM\tCD\tNumForm=Digit|NumType=Card\t10\tobl\t10:obl:in\t_\n3\tor\tor\tCCONJ\tCC\t_\t4\tcc\t4:cc\t_\n4\t’73\t'73\tNUM\tCD\tNumForm=Digit|NumType=Card\t2\tconj\t2:conj:or|10:obl:in\tSpaceAfter=No\n5\t,\t,\tPUNCT\t,\t_\t2\tpunct\t2:punct\t_\n6\tif\tif\tSCONJ\tIN\t_\t10\tmark\t10:mark\t_\n7\tyou\tyou\tPRON\tPRP\tCase=Nom|Person=2|PronType=Prs\t10\tnsubj\t10:nsubj\t_\n8\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin\t10\tcop\t10:cop\t_\n9\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t10\tdet\t10:det\t_\n10\tpilot\tpilot\tNOUN\tNN\tNumber=Sing\t28\tadvcl\t28:advcl:if\tSpaceAfter=No\n11\t,\t,\tPUNCT\t,\t_\t12\tpunct\t12:punct\t_\n12\tactive\tactive\tADJ\tJJ\tDegree=Pos\t10\tamod\t10:amod\t_\n13\tor\tor\tCCONJ\tCC\t_\t14\tcc\t14:cc\t_\n14\tGuard\tGuard\tPROPN\tNNP\tNumber=Sing\t12\tconj\t10:amod|12:conj:or\tSpaceAfter=No\n15\t,\t,\tPUNCT\t,\t_\t18\tpunct\t18:punct\t_\n16\tand\tand\tCCONJ\tCC\t_\t18\tcc\t18:cc\t_\n17\tyou\tyou\tPRON\tPRP\tCase=Nom|Person=2|PronType=Prs\t18\tnsubj\t18:nsubj|22:nsubj|24:nsubj:xsubj\t_\n18\thad\thave\tVERB\tVBD\tMood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin\t10\tconj\t10:conj:and|28:advcl:if\t_\n19\tan\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t20\tdet\t20:det\t_\n20\tobligation\tobligation\tNOUN\tNN\tNumber=Sing\t18\tobj\t18:obj\t_\n21\tand\tand\tCCONJ\tCC\t_\t22\tcc\t22:cc\t_\n22\twanted\twant\tVERB\tVBD\tMood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin\t18\tconj\t18:conj:and\t_\n23\tto\tto\tPART\tTO\t_\t24\tmark\t24:mark\t_\n24\tget\tget\tVERB\tVB\tVerbForm=Inf\t22\txcomp\t22:xcomp\t_\n25\tout\tout\tADV\tRB\t_\t24\tadvmod\t24:advmod\tSpaceAfter=No\n26\t,\t,\tPUNCT\t,\t_\t10\tpunct\t10:punct\t_\n27\tno\tno\tDET\tDT\tPronType=Neg\t28\tdet\t28:det\t_\n28\tproblem\tproblem\tNOUN\tNN\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n29\t.\t.\tPUNCT\t.\t_\t28\tpunct\t28:punct\t_\n\n# sent_id = 
weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0054\n# text = In fact, you were helping them solve their problem.”\n1\tIn\tin\tADP\tIN\t_\t2\tcase\t2:case\t_\n2\tfact\tfact\tNOUN\tNN\tNumber=Sing\t6\tobl\t6:obl:in\tSpaceAfter=No\n3\t,\t,\tPUNCT\t,\t_\t2\tpunct\t2:punct\t_\n4\tyou\tyou\tPRON\tPRP\tCase=Nom|Person=2|PronType=Prs\t6\tnsubj\t6:nsubj\t_\n5\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n6\thelping\thelp\tVERB\tVBG\tTense=Pres|VerbForm=Part\t0\troot\t0:root\t_\n7\tthem\tthey\tPRON\tPRP\tCase=Acc|Number=Plur|Person=3|PronType=Prs\t6\tobj\t6:obj|8:nsubj:xsubj\t_\n8\tsolve\tsolve\tVERB\tVB\tVerbForm=Inf\t6\txcomp\t6:xcomp\t_\n9\ttheir\ttheir\tPRON\tPRP$\tCase=Gen|Number=Plur|Person=3|Poss=Yes|PronType=Prs\t10\tnmod:poss\t10:nmod:poss\t_\n10\tproblem\tproblem\tNOUN\tNN\tNumber=Sing\t8\tobj\t8:obj\tSpaceAfter=No\n11\t.\t.\tPUNCT\t.\t_\t6\tpunct\t6:punct\tSpaceAfter=No\n12\t”\t\"\tPUNCT\t''\t_\t6\tpunct\t6:punct\t_\n\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0055\n# text = So Bush stopped flying.\n1\tSo\tso\tADV\tRB\t_\t3\tadvmod\t3:advmod\t_\n2\tBush\tBush\tPROPN\tNNP\tNumber=Sing\t3\tnsubj\t3:nsubj|4:nsubj:xsubj\t_\n3\tstopped\tstop\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n4\tflying\tfly\tVERB\tVBG\tVerbForm=Ger\t3\txcomp\t3:xcomp\tSpaceAfter=No\n5\t.\t.\tPUNCT\t.\t_\t3\tpunct\t3:punct\t_\n\"\"\".lstrip()\n\ndef test_length_limited_dataloader():\n    sample = CoNLL.conll2doc(input_str=EWT_SAMPLE)\n\n    args = tagger.parse_args(args=[\"--batch_size\", \"10\", \"--shorthand\", \"en_test\", \"--augment_nopunct\", \"0.0\"])\n    data = Dataset(sample, args, None)\n\n    # this should read the whole dataset\n    dl = data.to_length_limited_loader(5, 1000)\n    batches = [batch.idx for batch in dl]\n    assert batches == [(0, 1, 2, 3, 4)]\n\n    dl = data.to_length_limited_loader(4, 1000)\n    batches = [batch.idx 
for batch in dl]\n    assert batches == [(0, 1, 2, 3), (4,)]\n\n    dl = data.to_length_limited_loader(2, 1000)\n    batches = [batch.idx for batch in dl]\n    assert batches == [(0, 1), (2, 3), (4,)]\n\n    # the first three sentences should reach this limit\n    dl = data.to_length_limited_loader(5, 55)\n    batches = [batch.idx for batch in dl]\n    assert batches == [(0, 1, 2), (3, 4)]\n\n    # the third sentence (2) is already past this limit by itself\n    dl = data.to_length_limited_loader(5, 25)\n    batches = [batch.idx for batch in dl]\n    assert batches == [(0, 1), (2,), (3, 4)]\n\n\nEWT_PUNCT_SAMPLE = \"\"\"\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0048\n# text = Bush asked for permission to go to Alabama to work on a Senate campaign.\n1\tBush\tBush\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t2:nsubj\t_\n2\tasked\task\tVERB\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n3\tfor\tfor\tADP\tIN\t_\t4\tcase\t4:case\t_\n4\tpermission\tpermission\tNOUN\tNN\tNumber=Sing\t2\tobl\t2:obl:for\t_\n5\tto\tto\tPART\tTO\t_\t6\tmark\t6:mark\t_\n6\tgo\tgo\tVERB\tVB\tVerbForm=Inf\t4\tacl\t4:acl:to\t_\n7\tto\tto\tADP\tIN\t_\t8\tcase\t8:case\t_\n8\tAlabama\tAlabama\tPROPN\tNNP\tNumber=Sing\t6\tobl\t6:obl:to\t_\n9\tto\tto\tPART\tTO\t_\t10\tmark\t10:mark\t_\n10\twork\twork\tVERB\tVB\tVerbForm=Inf\t6\tadvcl\t6:advcl:to\t_\n11\ton\ton\tADP\tIN\t_\t14\tcase\t14:case\t_\n12\ta\ta\tDET\tDT\tDefinite=Ind|PronType=Art\t14\tdet\t14:det\t_\n13\tSenate\tSenate\tPROPN\tNNP\tNumber=Sing\t14\tcompound\t14:compound\t_\n14\tcampaign\tcampaign\tNOUN\tNN\tNumber=Sing\t10\tobl\t10:obl:on\tSpaceAfter=No\n15\t!!!!!\t!\tPUNCT\t.\t_\t2\tpunct\t2:punct\t_\n\n# sent_id = weblog-blogspot.com_alaindewitt_20040929103700_ENG_20040929_103700-0049\n# text = His superior officers said 
OK.\n1\tHis\this\tPRON\tPRP$\tCase=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnmod:poss\t3:nmod:poss\t_\n2\tsuperior\tsuperior\tADJ\tJJ\tDegree=Pos\t3\tamod\t3:amod\t_\n3\tofficers\tofficer\tNOUN\tNNS\tNumber=Plur\t4\tnsubj\t4:nsubj\t_\n4\tsaid\tsay\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n5\tOK\tok\tINTJ\tUH\t_\t4\tobj\t4:obj\tSpaceAfter=No\n6\t?????\t?\tPUNCT\t.\t_\t4\tpunct\t4:punct\t_\n\"\"\"\n\n\ndef test_punct_simplification():\n    \"\"\"\n    Test a punctuation simplification that should make it so unexpected\n    question/exclamation marks types are processed into ? and !\n    \"\"\"\n    sample = CoNLL.conll2doc(input_str=EWT_PUNCT_SAMPLE)\n\n    args = tagger.parse_args(args=[\"--batch_size\", \"10\", \"--shorthand\", \"en_test\", \"--augment_nopunct\", \"0.0\"])\n    data = Dataset(sample, args, None)\n\n    dl = data.to_length_limited_loader(2, 1000)\n    batches = [batch for batch in dl]\n    batch_idx = [batch.idx for batch in batches]\n    assert batch_idx == [(0, 1)]\n\n    assert batches[0].text[0][-1] == '!'\n    assert batches[0].text[1][-1] == '?'\n    assert batches[0].text[0] == ['Bush', 'asked', 'for', 'permission', 'to', 'go', 'to', 'Alabama', 'to', 'work', 'on', 'a', 'Senate', 'campaign', '!']\n    assert batches[0].text[1] == ['His', 'superior', 'officers', 'said', 'OK', '?']\n"
  },
  {
    "path": "stanza/tests/pos/test_tagger.py",
    "content": "\"\"\"\nRun the tagger for a couple iterations on some fake data\n\nUses a couple sentences of UD_English-EWT as training/dev data\n\"\"\"\n\nimport os\nimport pytest\n\nimport torch\n\nimport stanza\nfrom stanza.models import tagger\nfrom stanza.models.common import pretrain\nfrom stanza.models.pos.trainer import Trainer\nfrom stanza.tests import TEST_WORKING_DIR, TEST_MODELS_DIR\nfrom stanza.utils.training.common import choose_pos_charlm, build_charlm_args\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nTRAIN_DATA = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003\n# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.\n1\tDPA\tDPA\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n2\t:\t:\tPUNCT\t:\t_\t1\tpunct\t1:punct\t_\n3\tIraqi\tIraqi\tADJ\tJJ\tDegree=Pos\t4\tamod\t4:amod\t_\n4\tauthorities\tauthority\tNOUN\tNNS\tNumber=Plur\t5\tnsubj\t5:nsubj\t_\n5\tannounced\tannounce\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t1\tparataxis\t1:parataxis\t_\n6\tthat\tthat\tSCONJ\tIN\t_\t9\tmark\t9:mark\t_\n7\tthey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur|Person=3|PronType=Prs\t9\tnsubj\t9:nsubj\t_\n8\thad\thave\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t9\taux\t9:aux\t_\n9\tbusted\tbust\tVERB\tVBN\tTense=Past|VerbForm=Part\t5\tccomp\t5:ccomp\t_\n10\tup\tup\tADP\tRP\t_\t9\tcompound:prt\t9:compound:prt\t_\n11\t3\t3\tNUM\tCD\tNumForm=Digit|NumType=Card\t13\tnummod\t13:nummod\t_\n12\tterrorist\tterrorist\tADJ\tJJ\tDegree=Pos\t13\tamod\t13:amod\t_\n13\tcells\tcell\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n14\toperating\toperate\tVERB\tVBG\tVerbForm=Ger\t13\tacl\t13:acl\t_\n15\tin\tin\tADP\tIN\t_\t16\tcase\t16:case\t_\n16\tBaghdad\tBaghdad\tPROPN\tNNP\tNumber=Sing\t14\tobl\t14:obl:in\tSpaceAfter=No\n17\t.\t.\tPUNCT\t.\t_\t1\tpunct\t1:punct\t_\n\n# sent_id = 
weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004\n# text = Two of them were being run by 2 officials of the Ministry of the Interior!\n1\tTwo\ttwo\tNUM\tCD\tNumForm=Word|NumType=Card\t6\tnsubj:pass\t6:nsubj:pass\t_\n2\tof\tof\tADP\tIN\t_\t3\tcase\t3:case\t_\n3\tthem\tthey\tPRON\tPRP\tCase=Acc|Number=Plur|Person=3|PronType=Prs\t1\tnmod\t1:nmod:of\t_\n4\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n5\tbeing\tbe\tAUX\tVBG\tVerbForm=Ger\t6\taux:pass\t6:aux:pass\t_\n6\trun\trun\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n7\tby\tby\tADP\tIN\t_\t9\tcase\t9:case\t_\n8\t2\t2\tNUM\tCD\tNumForm=Digit|NumType=Card\t9\tnummod\t9:nummod\t_\n9\tofficials\tofficial\tNOUN\tNNS\tNumber=Plur\t6\tobl\t6:obl:by\t_\n10\tof\tof\tADP\tIN\t_\t12\tcase\t12:case\t_\n11\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t12\tdet\t12:det\t_\n12\tMinistry\tMinistry\tPROPN\tNNP\tNumber=Sing\t9\tnmod\t9:nmod:of\t_\n13\tof\tof\tADP\tIN\t_\t15\tcase\t15:case\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tInterior\tInterior\tPROPN\tNNP\tNumber=Sing\t12\tnmod\t12:nmod:of\tSpaceAfter=No\n16\t!\t!\tPUNCT\t.\t_\t6\tpunct\t6:punct\t_\n\n\"\"\".lstrip()\n\nTRAIN_DATA_2 = \"\"\"\n# sent_id = 11\n# text = It's all hers!\n# previous = Which person owns this?\n# comment = predeterminer modifier\n1\tIt\tit\tPRON\tPRP\tNumber=Sing|Person=3|PronType=Prs\t4\tnsubj\t_\tSpaceAfter=No\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\t_\n3\tall\tall\tDET\tDT\tCase=Nom\t4\tdet:predet\t_\t_\n4\thers\thers\tPRON\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t0\troot\t_\tSpaceAfter=No\n5\t!\t!\tPUNCT\t.\t_\t4\tpunct\t_\t_\n\n\"\"\".lstrip()\n\nTRAIN_DATA_NO_UPOS = \"\"\"\n# sent_id = 11\n# text = It's all hers!\n# previous = Which person owns this?\n# comment = predeterminer 
modifier\n1\tIt\tit\t_\tPRP\tNumber=Sing|Person=3|PronType=Prs\t4\tnsubj\t_\tSpaceAfter=No\n2\t's\tbe\t_\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\t_\n3\tall\tall\t_\tDT\tCase=Nom\t4\tdet:predet\t_\t_\n4\thers\thers\t_\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t0\troot\t_\tSpaceAfter=No\n5\t!\t!\t_\t.\t_\t4\tpunct\t_\t_\n\n\"\"\".lstrip()\n\nTRAIN_DATA_NO_XPOS = \"\"\"\n# sent_id = 11\n# text = It's all hers!\n# previous = Which person owns this?\n# comment = predeterminer modifier\n1\tIt\tit\tPRON\t_\tNumber=Sing|Person=3|PronType=Prs\t4\tnsubj\t_\tSpaceAfter=No\n2\t's\tbe\tAUX\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\t_\n3\tall\tall\tDET\t_\tCase=Nom\t4\tdet:predet\t_\t_\n4\thers\thers\tPRON\t_\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t0\troot\t_\tSpaceAfter=No\n5\t!\t!\tPUNCT\t_\t_\t4\tpunct\t_\t_\n\n\"\"\".lstrip()\n\nTRAIN_DATA_NO_FEATS = \"\"\"\n# sent_id = 11\n# text = It's all hers!\n# previous = Which person owns this?\n# comment = predeterminer modifier\n1\tIt\tit\tPRON\tPRP\t_\t4\tnsubj\t_\tSpaceAfter=No\n2\t's\tbe\tAUX\tVBZ\t_\t4\tcop\t_\t_\n3\tall\tall\tDET\tDT\t_\t4\tdet:predet\t_\t_\n4\thers\thers\tPRON\tPRP\t_\t0\troot\t_\tSpaceAfter=No\n5\t!\t!\tPUNCT\t.\t_\t4\tpunct\t_\t_\n\n\"\"\".lstrip()\n\nDEV_DATA = \"\"\"\n1\tFrom\tfrom\tADP\tIN\t_\t3\tcase\t3:case\t_\n2\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t3\tdet\t3:det\t_\n3\tAP\tAP\tPROPN\tNNP\tNumber=Sing\t4\tobl\t4:obl:from\t_\n4\tcomes\tcome\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t0:root\t_\n5\tthis\tthis\tDET\tDT\tNumber=Sing|PronType=Dem\t6\tdet\t6:det\t_\n6\tstory\tstory\tNOUN\tNN\tNumber=Sing\t4\tnsubj\t4:nsubj\t_\n7\t:\t:\tPUNCT\t:\t_\t4\tpunct\t4:punct\t_\n\n\"\"\".lstrip()\n\nclass TestTagger:\n    @pytest.fixture(scope=\"class\")\n    def wordvec_pretrain_file(self):\n        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'\n\n    @pytest.fixture(scope=\"class\")\n   
 def charlm_args(self):\n        charlm = choose_pos_charlm(\"en\", \"test\", \"default\")\n        charlm_args = build_charlm_args(\"en\", charlm, model_dir=TEST_MODELS_DIR)\n        return charlm_args\n\n    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):\n        \"\"\"\n        Run the training for a few iterations, load & return the model\n        \"\"\"\n        dev_file = str(tmp_path / \"dev.conllu\")\n        pred_file = str(tmp_path / \"pred.conllu\")\n\n        save_name = \"test_tagger.pt\"\n        save_file = str(tmp_path / save_name)\n\n        if isinstance(train_text, str):\n            train_text = [train_text]\n        train_files = []\n        for idx, train_blob in enumerate(train_text):\n            train_file = str(tmp_path / (\"train_%d.conllu\" % idx))\n            with open(train_file, \"w\", encoding=\"utf-8\") as fout:\n                fout.write(train_blob)\n            train_files.append(train_file)\n        train_file = \";\".join(train_files)\n\n        with open(dev_file, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(dev_text)\n\n        args = [\"--wordvec_pretrain_file\", wordvec_pretrain_file,\n                \"--train_file\", train_file,\n                \"--eval_file\", dev_file,\n                \"--output_file\", pred_file,\n                \"--log_step\", \"10\",\n                \"--eval_interval\", \"20\",\n                \"--max_steps\", \"100\",\n                \"--shorthand\", \"en_test\",\n                \"--save_dir\", str(tmp_path),\n                \"--save_name\", save_name,\n                \"--lang\", \"en\"]\n        if not augment_nopunct:\n            args.extend([\"--augment_nopunct\", \"0.0\"])\n        if extra_args is not None:\n            args = args + extra_args\n        tagger.main(args)\n\n        assert os.path.exists(save_file)\n        pt = pretrain.Pretrain(wordvec_pretrain_file)\n        saved_model = 
Trainer(pretrain=pt, model_file=save_file)\n        return saved_model\n\n    def test_train(self, tmp_path, wordvec_pretrain_file, augment_nopunct=True):\n        \"\"\"\n        Simple test of a few 'epochs' of tagger training\n        \"\"\"\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)\n\n    def test_vocab_cutoff(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Test that the vocab cutoff leaves words we expect in the vocab, but not rare words\n        \"\"\"\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=[\"--word_cutoff\", \"3\"])\n        word_vocab = trainer.vocab['word']\n        assert 'of' in word_vocab\n        assert 'officials' in TRAIN_DATA\n        assert 'officials' not in word_vocab\n\n    def test_multiple_files(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Test that training with multiple train files works\n\n        Checks for evidence of it working by looking for words from the second file in the vocab\n        \"\"\"\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA, TRAIN_DATA_2 * 3], DEV_DATA, extra_args=[\"--word_cutoff\", \"3\"])\n        word_vocab = trainer.vocab['word']\n        assert 'of' in word_vocab\n        assert 'officials' in TRAIN_DATA\n        assert 'officials' not in word_vocab\n\n        assert '\thers\t' not in TRAIN_DATA\n        assert '\thers\t' in TRAIN_DATA_2\n        assert 'hers' in word_vocab\n\n    def test_train_zero_augment(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Train with the punct augmentation set to zero\n\n        Distinguishes cases where training works w/ or w/o augmentation\n        \"\"\"\n        extra_args = ['--augment_nopunct', '0.0']\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)\n\n    def test_train_100_augment(self, tmp_path, wordvec_pretrain_file):\n        
\"\"\"\n        Train with the punct augmentation set to 1.0\n\n        Distinguishes cases where training works w/ or w/o augmentation\n        \"\"\"\n        extra_args = ['--augment_nopunct', '1.0']\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)\n\n    def test_train_charlm(self, tmp_path, wordvec_pretrain_file, charlm_args):\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)\n\n    def test_train_charlm_projection(self, tmp_path, wordvec_pretrain_file, charlm_args):\n        extra_args = charlm_args + ['--charlm_transform_dim', '100']\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)\n\n    def test_missing_column(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Test that using train files with missing columns works\n\n        In this test, we create three separate files, each with a single training entry.\n        We then train on an amalgam of those three files with a batch size of 1, saving after each batch.\n        This will ensure that only one item is used for each training loop and we can inspect the models which were saved.\n\n        Since each of the three files has exactly one column missing\n        from the training data, we expect to see the output maps for\n        each column stay unchanged in one iteration and change in the\n        other two.\n        \"\"\"\n        # use SGD because some old versions of pytorch with Adam keep\n        # learning a value even if the loss is 0 in subsequent steps\n        # (perhaps it had a momentum by default?)\n        extra_args = ['--save_each', '--eval_interval', '1', '--max_steps', '3', '--batch_size', '1', '--optim', 'sgd']\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_FEATS], DEV_DATA, extra_args=extra_args)\n   
     save_each_name = tagger.save_each_file_name(trainer.args)\n        model_files = [save_each_name % i for i in range(4)]\n        assert all(os.path.exists(x) for x in model_files)\n        pt = pretrain.Pretrain(wordvec_pretrain_file)\n        saved_trainers = [Trainer(pretrain=pt, model_file=model_file) for model_file in model_files]\n\n        upos_unchanged = 0\n        xpos_unchanged = 0\n        ufeats_unchanged = 0\n        for t1, t2 in zip(saved_trainers[:-1], saved_trainers[1:]):\n            upos_unchanged += torch.allclose(t1.model.upos_clf.weight, t2.model.upos_clf.weight)\n            xpos_unchanged += torch.allclose(t1.model.xpos_clf.W_bilin.weight, t2.model.xpos_clf.W_bilin.weight)\n            ufeats_unchanged += all(torch.allclose(f1.W_bilin.weight, f2.W_bilin.weight) for f1, f2 in zip(t1.model.ufeats_clf, t2.model.ufeats_clf))\n        upos_norms = [torch.linalg.norm(t.model.upos_clf.weight) for t in saved_trainers]\n        assert upos_unchanged == 1, \"Unchanged: {} {} {} {}\".format(upos_unchanged, xpos_unchanged, ufeats_unchanged, upos_norms)\n        assert xpos_unchanged == 1, \"Unchanged: %d %d %d\" % (upos_unchanged, xpos_unchanged, ufeats_unchanged)\n        assert ufeats_unchanged == 1, \"Unchanged: %d %d %d\" % (upos_unchanged, xpos_unchanged, ufeats_unchanged)\n\n    def test_save_each(self, tmp_path, wordvec_pretrain_file):\n        extra_args = ['--save_each']\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)\n        save_each_name = tagger.save_each_file_name(trainer.args)\n        expected_models = sorted(set([save_each_name % i for i in range(0, trainer.args['max_steps']+1, trainer.args['eval_interval'])]))\n        assert len(expected_models) == 6\n        for model_name in expected_models:\n            assert os.path.exists(model_name)\n\n\n    def test_with_bert(self, tmp_path, wordvec_pretrain_file):\n        self.run_training(tmp_path, 
wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert'])\n\n    def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])\n\n    def test_with_bert_finetune(self, tmp_path, wordvec_pretrain_file):\n        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_learning_rate', '0.01', '--bert_hidden_layers', '2'])\n\n    def test_bert_pipeline(self, tmp_path, wordvec_pretrain_file):\n        \"\"\"\n        Test training the tagger, then using it in a pipeline\n\n        The pipeline use of the tagger also tests the longer-than-maxlen workaround for the transformer\n        \"\"\"\n        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert'])\n        save_name = trainer.args['save_name']\n        save_file = str(tmp_path / save_name)\n        assert os.path.exists(save_file)\n\n        pipe = stanza.Pipeline(\"en\", processors=\"tokenize,pos\", models_dir=TEST_MODELS_DIR, pos_model_path=save_file, pos_pretrain_path=wordvec_pretrain_file)\n        trainer = pipe.processors['pos'].trainer\n        assert trainer.args['save_name'] == save_name\n\n        # these should be one chunk only\n        doc = pipe(\"foo \" * 100)\n        doc = pipe(\"foo \" * 500)\n        # this is two chunks of bert embedding\n        doc = pipe(\"foo \" * 1000)\n        # this is multiple chunks\n        doc = pipe(\"foo \" * 2000)\n"
  },
  {
    "path": "stanza/tests/pos/test_xpos_vocab_factory.py",
    "content": "\"\"\"\nTest some pieces of the xpos vocab factory\n\"\"\"\nimport pytest\n\nimport logging\nimport os\nimport tempfile\n\nfrom stanza.models import tagger\nfrom stanza.models.common import pretrain\nfrom stanza.models.pos.data import Dataset\nfrom stanza.models.pos.trainer import Trainer\nfrom stanza.models.pos.vocab import WordVocab, XPOSVocab\nfrom stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory\nfrom stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory\nfrom stanza.utils.conll import CoNLL\n\nfrom stanza.tests import TEST_WORKING_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nlogger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')\n\nEN_EXAMPLE=\"\"\"\n1\tSh'reyan\tSh'reyan\tPROPN\tNNP%(tag)s\tNumber=Sing\t3\tnmod:poss\t3:nmod:poss\t_\n2\t's\t's\tPART\tPOS%(tag)s\t_\t1\tcase\t1:case\t_\n3\tantennae\tantenna\tNOUN%(tag)s\tNNS\tNumber=Plur\t6\tnsubj\t6:nsubj\t_\n4\tare\tbe\tVERB\tVBP%(tag)s\tMood=Ind|Tense=Pres|VerbForm=Fin\t6\tcop\t6:cop\t_\n5\thella\thella\tADV\tRB%(tag)s\t_\t6\tadvmod\t6:advmod\t_\n6\tthicc\tthicc\tADJ\tJJ%(tag)s\tDegree=Pos\t0\troot\t0:root\t_\n\"\"\"\n\nEMPTY_TAG = lambda x: \"\"\nDASH_TAGS = lambda x: \"-%d\" % x\n\ndef build_doc(iterations, suffix):\n    \"\"\"\n    Build N copies of the English text above, with a lambda function applied for the tag suffixes\n\n    for example:\n      lambda x: \"\" means the suffixes are all blank (NNP, POS, NNS, etc) for each iteration\n      lambda x: \"-%d\" % x means they go (NNP-0, NNP-1, NNP-2, etc) for the first word's tag\n    \"\"\"\n    texts = [EN_EXAMPLE % {\"tag\": suffix(i)} for i in range(iterations)]\n    text = \"\\n\\n\".join(texts)\n    doc = CoNLL.conll2doc(input_str=text)\n    return doc\n\ndef build_data(iterations, suffix):\n    \"\"\"\n    Same thing, but passes the Doc through a POS Tagger DataLoader\n    \"\"\"\n    doc = build_doc(iterations, suffix)\n    data 
= Dataset.load_doc(doc)\n    return data\n\nclass ErrorFatalHandler(logging.Handler):\n    \"\"\"\n    This handler turns any error logs into a fatal error\n\n    Theoretically you could change the level to make other things fatal as well\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n\n        self.setLevel(logging.ERROR)\n\n    def emit(self, record):\n        raise AssertionError(\"Oh no, we printed an error\")\n\nclass TestXPOSVocabFactory:\n    @classmethod\n    def setup_class(cls):\n        \"\"\"\n        Add a handler to the xpos factory logger so that it raises an AssertionError instead of just logging an error\n\n        We don't actually want assertions, since that would be a huge\n        pain in the event one of the models actually changes, so\n        instead we just logger.error in the factory.  Using this\n        handler is a simple way to check that the error is correctly\n        logged when something changes\n        \"\"\"\n        logger.info(\"About to start xpos_vocab_factory tests - logger.error in that module will now cause AssertionError\")\n\n        handler = ErrorFatalHandler()\n        logger.addHandler(handler)\n\n    @classmethod\n    def teardown_class(cls):\n        \"\"\"\n        Remove the handler we installed earlier\n        \"\"\"\n        handlers = [x for x in logger.handlers if isinstance(x, ErrorFatalHandler)]\n        for handler in handlers:\n            logger.removeHandler(handler)\n        logger.error(\"Done with xpos_vocab_factory tests - this should not throw an error\")\n\n    def test_basic_en_ewt(self):\n        \"\"\"\n        en_ewt is currently the basic vocab\n\n        note that this may change if the dataset is drastically relabeled in the future\n        \"\"\"\n        data = build_data(1, EMPTY_TAG)\n        vocab = xpos_vocab_factory(data, \"en_ewt\")\n        assert isinstance(vocab, WordVocab)\n\n\n    def test_basic_en_unknown(self):\n        \"\"\"\n        With only 6 tags, it 
should use a basic vocab for an unknown dataset\n        \"\"\"\n        data = build_data(10, EMPTY_TAG)\n        vocab = xpos_vocab_factory(data, \"en_unknown\")\n        assert isinstance(vocab, WordVocab)\n\n\n    def test_dash_en_unknown(self):\n        \"\"\"\n        With this many different tags, it should choose to reduce it to the base xpos removing the -\n        \"\"\"\n        data = build_data(10, DASH_TAGS)\n        vocab = xpos_vocab_factory(data, \"en_unknown\")\n        assert isinstance(vocab, XPOSVocab)\n        assert vocab.sep == \"-\"\n\n    def test_dash_en_ewt_wrong(self):\n        \"\"\"\n        The dataset looks like XPOS(-), which is wrong for en_ewt\n        \"\"\"\n        with pytest.raises(AssertionError):\n            data = build_data(10, DASH_TAGS)\n            vocab = xpos_vocab_factory(data, \"en_ewt\")\n            assert isinstance(vocab, XPOSVocab)\n            assert vocab.sep == \"-\"\n\n    def check_reload(self, pt, shorthand, iterations, suffix, expected_vocab):\n        \"\"\"\n        Build a Trainer (no actual training), save it, and load it back in to check the type of Vocab restored\n\n        TODO: This test may be a bit \"eager\" in that there are no other\n        tests which check building, saving, & loading a pos trainer.\n        Could add tests to test_trainer.py, for example\n        \"\"\"\n        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:\n            args = tagger.parse_args([\"--batch_size\", \"1\", \"--shorthand\", shorthand])\n            train_doc = build_doc(iterations, suffix)\n            train_batch = Dataset(train_doc, args, pt, evaluation=False)\n            vocab = train_batch.vocab\n            assert isinstance(vocab['xpos'], expected_vocab)\n\n            trainer = Trainer(args=args, vocab=vocab, pretrain=pt, device=\"cpu\")\n\n            model_file = os.path.join(tmpdirname, \"foo.pt\")\n            trainer.save(model_file)\n\n            new_trainer = 
Trainer(model_file=model_file, pretrain=pt)\n            assert isinstance(new_trainer.vocab['xpos'], expected_vocab)\n\n    @pytest.fixture(scope=\"class\")\n    def pt(self):\n        pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False)\n        return pt\n\n    def test_reload_word_vocab(self, pt):\n        \"\"\"\n        Test that building a model with a known word vocab shorthand, saving it, and loading it gets back a word vocab\n        \"\"\"\n        self.check_reload(pt, \"en_ewt\", 10, EMPTY_TAG, WordVocab)\n\n    def test_reload_unknown_word_vocab(self, pt):\n        \"\"\"\n        Test that building a model with an unknown word vocab, saving it, and loading it gets back a word vocab\n        \"\"\"\n        self.check_reload(pt, \"en_unknown\", 10, EMPTY_TAG, WordVocab)\n\n    def test_reload_unknown_xpos_vocab(self, pt):\n        \"\"\"\n        Test that building a model with an unknown xpos vocab, saving it, and loading it gets back an xpos vocab\n        \"\"\"\n        self.check_reload(pt, \"en_unknown\", 10, DASH_TAGS, XPOSVocab)\n\n"
  },
  {
    "path": "stanza/tests/pytest.ini",
    "content": "[pytest]\nmarkers =\n    travis: all tests that will be run in travis CI\n    client: all tests that are related to the CoreNLP client interface\n    pipeline: all tests that are related to the Stanza neural pipeline\n    morphseg: all tests that are related to morpheme segmentation"
  },
  {
    "path": "stanza/tests/resources/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/resources/test_charlm_depparse.py",
    "content": "import pytest\n\nfrom stanza.resources.default_packages import default_charlms, depparse_charlms\nfrom stanza.resources.print_charlm_depparse import list_depparse\n\ndef test_list_depparse():\n    models = list_depparse()\n\n    # check that it's picking up the models which don't have specific charlms\n    # first, make sure the default assumption of the test is still true...\n    # if this test fails, find a different language which isn't in depparse_charlms\n    assert \"af\" not in depparse_charlms\n    assert \"af\" in default_charlms\n    assert \"af_afribooms_charlm\" in models\n    assert \"af_afribooms_nocharlm\" in models\n\n    # assert that it's picking up the models which do have specific charlms that aren't None\n    # again, first make sure the default assumptions are true\n    # if one of these next few tests fail, just update the test\n    assert \"en\" in depparse_charlms\n    assert \"en\" in default_charlms\n    assert \"ewt\" not in depparse_charlms[\"en\"]\n    assert \"craft\" in depparse_charlms[\"en\"]\n    assert \"mimic\" in depparse_charlms[\"en\"]\n    # now, check the results\n    assert \"en_ewt_charlm\" in models\n    assert \"en_ewt_nocharlm\" in models\n    assert \"en_mimic_charlm\" in models\n    # haven't yet trained w/ and w/o for the bio models\n    assert \"en_mimic_nocharlm\" not in models\n    assert \"en_craft_charlm\" not in models\n    assert \"en_craft_nocharlm\" in models\n"
  },
  {
    "path": "stanza/tests/resources/test_common.py",
    "content": "\"\"\"\nTest various resource downloading functions from resources/common.py\n\"\"\"\n\nimport os\nimport pytest\nimport tempfile\n\nimport stanza\nfrom stanza.resources import common\nfrom stanza.tests import TEST_MODELS_DIR, TEST_WORKING_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\ndef test_assert_file_exists():\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        filename = os.path.join(test_dir, \"test.txt\")\n        with pytest.raises(FileNotFoundError):\n            common.assert_file_exists(filename)\n\n        with open(filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(\"Unban mox opal!\")\n        # MD5 of the fake model file, not any real model files in the system\n        EXPECTED_MD5 = \"44dbf21b4e89cea5184615a72a825a36\"\n        common.assert_file_exists(filename)\n        common.assert_file_exists(filename, md5=EXPECTED_MD5)\n\n        with pytest.raises(ValueError):\n            common.assert_file_exists(filename, md5=\"12345\")\n\n        with pytest.raises(ValueError):\n            common.assert_file_exists(filename, md5=\"12345\", alternate_md5=\"12345\")\n\n        common.assert_file_exists(filename, md5=\"12345\", alternate_md5=EXPECTED_MD5)\n\n\ndef test_download_tokenize_mwt():\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"ewt\", verbose=False)\n        pipeline = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"ewt\")\n        assert isinstance(pipeline, stanza.Pipeline)\n        # mwt should be added to the list\n        assert len(pipeline.loaded_processors) == 2\n\ndef test_download_non_default():\n    \"\"\"\n    Test the download path for a single file rather than the default zip\n\n    The expectation is that an NER model will also download two charlm models.\n    If that layout changes on purpose, this test 
will fail and will need to be updated\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"ner\", package=\"ontonotes_charlm\", verbose=False)\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        en_dir = os.path.join(test_dir, 'en')\n        en_dir_listing = sorted(os.listdir(en_dir))\n        assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'ner', 'pretrain']\n        assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt']\n        for i in en_dir_listing:\n            assert len(os.listdir(os.path.join(en_dir, i))) == 1\n\n\ndef test_download_two_models():\n    \"\"\"\n    Test the download path for two NER models\n\n    The package system should now allow for multiple NER models to be\n    specified, and a consequence of that is it should be possible to\n    download two models at once\n\n    The expectation is that the two different NER models both download\n    a different forward & backward charlm.  If that changes, the test\n    will fail.  
The best way to update it is to pick two different models\n    which download two different charlms\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"ner\", package={\"ner\": [\"ontonotes_charlm\", \"anatem\"]}, verbose=False)\n        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']\n        en_dir = os.path.join(test_dir, 'en')\n        en_dir_listing = sorted(os.listdir(en_dir))\n        assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'ner', 'pretrain']\n        assert sorted(os.listdir(os.path.join(en_dir, 'ner'))) == ['anatem.pt', 'ontonotes_charlm.pt']\n        for i in en_dir_listing:\n            assert len(os.listdir(os.path.join(en_dir, i))) == 2\n\n\ndef test_process_pipeline_parameters():\n    \"\"\"\n    Test a few options for specifying which processors to load\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        lang, model_dir, package, processors = common.process_pipeline_parameters(\"en\", test_dir, None, \"tokenize,pos\")\n        assert processors == {\"tokenize\": \"default\", \"pos\": \"default\"}\n        assert package is None\n\n        lang, model_dir, package, processors = common.process_pipeline_parameters(\"en\", test_dir, {\"tokenize\": \"spacy\"}, \"tokenize,pos\")\n        assert processors == {\"tokenize\": \"spacy\", \"pos\": \"default\"}\n        assert package is None\n\n        lang, model_dir, package, processors = common.process_pipeline_parameters(\"en\", test_dir, {\"pos\": \"ewt\"}, \"tokenize,pos\")\n        assert processors == {\"tokenize\": \"default\", \"pos\": \"ewt\"}\n        assert package is None\n\n        lang, model_dir, package, processors = common.process_pipeline_parameters(\"en\", test_dir, \"ewt\", \"tokenize,pos\")\n        assert processors == {\"tokenize\": \"ewt\", \"pos\": \"ewt\"}\n        assert package is None\n\ndef 
test_language_resources():\n    resources = common.load_resources_json(TEST_MODELS_DIR)\n\n    # check that an unknown language comes back as None\n    bad_lang = 'z'\n    while bad_lang in resources and len(bad_lang) < 100:\n        bad_lang = bad_lang + 'z'\n    assert bad_lang not in resources\n    assert common.get_language_resources(resources, bad_lang) is None\n\n    # check the parameters of the test make sense\n    # there should be 'zh' which is an alias of 'zh-hans'\n    assert \"zh\" in resources\n    assert \"alias\" in resources[\"zh\"]\n    assert resources[\"zh\"][\"alias\"] == \"zh-hans\"\n\n    # check that getting the resources for either 'zh' or 'zh-hans'\n    # returns the simplified Chinese resources\n    zh_resources = common.get_language_resources(resources, \"zh\")\n    assert \"tokenize\" in zh_resources\n    assert \"alias\" not in zh_resources\n    assert \"Chinese\" in zh_resources[\"lang_name\"]\n\n    zh_hans_resources = common.get_language_resources(resources, \"zh-hans\")\n    assert zh_resources == zh_hans_resources\n"
  },
  {
    "path": "stanza/tests/resources/test_default_packages.py",
    "content": "import pytest\n\nimport stanza\n\nfrom stanza.resources import default_packages\n\ndef test_default_pretrains():\n    \"\"\"\n    Test that all languages with a default treebank have a default pretrain or are specifically marked as not having a pretrain\n    \"\"\"\n    for lang in default_packages.default_treebanks.keys():\n        assert lang in default_packages.no_pretrain_languages or lang in default_packages.default_pretrains, \"Lang %s does not have a default pretrain marked!\" % lang\n\ndef test_no_pretrain_languages():\n    \"\"\"\n    Test that no language is listed in no_pretrain_languages despite having a default pretrain\n    \"\"\"\n    for lang in default_packages.no_pretrain_languages:\n        assert lang not in default_packages.default_pretrains, \"Lang %s is marked as no_pretrain but has a default pretrain!\" % lang\n"
  },
  {
    "path": "stanza/tests/resources/test_installation.py",
    "content": "\"\"\"\nTest installation functions.\n\"\"\"\n\nimport os\nimport pytest\nimport shutil\nimport tempfile\n\nimport stanza\nfrom stanza.tests import TEST_WORKING_DIR\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\ndef test_install_corenlp():\n    # we do not reset the CORENLP_HOME variable since this may impact the \n    # client tests\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n\n        # the download method doesn't install over existing directories\n        shutil.rmtree(test_dir)\n        stanza.install_corenlp(dir=test_dir)\n\n        assert os.path.isdir(test_dir), \"Installation destination directory not found.\"\n        jar_files = [f for f in os.listdir(test_dir) \\\n                     if f.endswith('.jar') and f.startswith('stanford-corenlp')]\n        assert len(jar_files) > 0, \\\n            \"Cannot find stanford-corenlp jar files in the installation directory.\"\n        assert not os.path.exists(os.path.join(test_dir, 'corenlp.zip')), \\\n            \"Downloaded zip file was not removed.\"\n    \ndef test_download_corenlp_models():\n    model_name = \"arabic\"\n    version = \"4.2.2\"\n\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download_corenlp_models(model=model_name, version=version, dir=test_dir)\n\n        dest_file = os.path.join(test_dir, f\"stanford-corenlp-{version}-models-{model_name}.jar\")\n        assert os.path.isfile(dest_file), \"Downloaded model file not found.\"\n\ndef test_download_tokenize_mwt():\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        stanza.download(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"ewt\", verbose=False)\n        pipeline = stanza.Pipeline(\"en\", model_dir=test_dir, processors=\"tokenize\", package=\"ewt\")\n        assert isinstance(pipeline, stanza.Pipeline)\n        # mwt should be added to the list\n        assert len(pipeline.loaded_processors) == 
2\n"
  },
  {
    "path": "stanza/tests/resources/test_prepare_resources.py",
    "content": "import pytest\n\nimport stanza\nimport stanza.resources.prepare_resources as prepare_resources\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_split_model_name():\n    # Basic test\n    lang, package, processor = prepare_resources.split_model_name('ro_nonstandard_tagger.pt')\n    assert lang == 'ro'\n    assert package == 'nonstandard'\n    assert processor == 'pos'\n\n    # Check that nertagger is found even though it also ends with tagger\n    # Check that ncbi_disease is correctly partitioned despite the extra _\n    lang, package, processor = prepare_resources.split_model_name('en_ncbi_disease_nertagger.pt')\n    assert lang == 'en'\n    assert package == 'ncbi_disease'\n    assert processor == 'ner'\n\n    # assert that processors with _ in them are also okay\n    lang, package, processor = prepare_resources.split_model_name('en_pubmed_forward_charlm.pt')\n    assert lang == 'en'\n    assert package == 'pubmed'\n    assert processor == 'forward_charlm'\n"
  },
  {
    "path": "stanza/tests/server/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/server/test_client.py",
    "content": "\"\"\"\nTests that call a running CoreNLPClient.\n\"\"\"\n\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nimport multiprocessing\nimport pytest\nimport requests\nimport stanza.server as corenlp\nimport stanza.server.client as client\nimport shlex\nimport subprocess\nimport time\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.tests import *\n\n# set the marker for this module\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\nTEXT = \"Chris wrote a simple sentence that he parsed with Stanford CoreNLP.\\n\"\n\nMAX_REQUEST_ATTEMPTS = 5\n\nEN_GOLD = \"\"\"\nSentence #1 (12 tokens):\nChris wrote a simple sentence that he parsed with Stanford CoreNLP.\n\nTokens:\n[Text=Chris CharacterOffsetBegin=0 CharacterOffsetEnd=5 PartOfSpeech=NNP]\n[Text=wrote CharacterOffsetBegin=6 CharacterOffsetEnd=11 PartOfSpeech=VBD]\n[Text=a CharacterOffsetBegin=12 CharacterOffsetEnd=13 PartOfSpeech=DT]\n[Text=simple CharacterOffsetBegin=14 CharacterOffsetEnd=20 PartOfSpeech=JJ]\n[Text=sentence CharacterOffsetBegin=21 CharacterOffsetEnd=29 PartOfSpeech=NN]\n[Text=that CharacterOffsetBegin=30 CharacterOffsetEnd=34 PartOfSpeech=WDT]\n[Text=he CharacterOffsetBegin=35 CharacterOffsetEnd=37 PartOfSpeech=PRP]\n[Text=parsed CharacterOffsetBegin=38 CharacterOffsetEnd=44 PartOfSpeech=VBD]\n[Text=with CharacterOffsetBegin=45 CharacterOffsetEnd=49 PartOfSpeech=IN]\n[Text=Stanford CharacterOffsetBegin=50 CharacterOffsetEnd=58 PartOfSpeech=NNP]\n[Text=CoreNLP CharacterOffsetBegin=59 CharacterOffsetEnd=66 PartOfSpeech=NNP]\n[Text=. 
CharacterOffsetBegin=66 CharacterOffsetEnd=67 PartOfSpeech=.]\n\"\"\".strip()\n\ndef run_webserver(port, timeout_secs):\n    class HTTPTimeoutHandler(BaseHTTPRequestHandler):\n        def do_POST(self):\n            time.sleep(timeout_secs)\n            self.send_response(200)\n            self.send_header('Content-type', 'text/plain; charset=utf-8')\n            self.end_headers()\n            # wfile is a binary stream, so the body must be bytes\n            self.wfile.write(b\"HTTPMockServerTimeout\")\n\n    HTTPServer(('127.0.0.1', port), HTTPTimeoutHandler).serve_forever()\n\nclass HTTPMockServerTimeoutContext:\n    \"\"\" For launching an HTTP server on a given port with a specified delay before each response \"\"\"\n    def __init__(self, port, timeout_secs):\n        self.port = port\n        self.timeout_secs = timeout_secs\n\n    def __enter__(self):\n        self.p = multiprocessing.Process(target=run_webserver, args=(self.port, self.timeout_secs))\n        self.p.daemon = True\n        self.p.start()\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        self.p.terminate()\n\nclass TestCoreNLPClient:\n    @pytest.fixture(scope=\"class\")\n    def corenlp_client(self):\n        \"\"\" Client to run tests on \"\"\"\n        client = corenlp.CoreNLPClient(annotators='tokenize,ssplit,pos,lemma,ner,depparse',\n                                       server_id='stanza_main_test_server')\n        yield client\n        client.stop()\n\n\n    def test_connect(self, corenlp_client):\n        corenlp_client.ensure_alive()\n        assert corenlp_client.is_active\n        assert corenlp_client.is_alive()\n\n\n    def test_context_manager(self):\n        with corenlp.CoreNLPClient(annotators=\"tokenize,ssplit\",\n                                   endpoint=\"http://localhost:9001\") as context_client:\n            ann = context_client.annotate(TEXT)\n            assert corenlp.to_text(ann.sentence[0]) == TEXT[:-1]\n\n    def test_no_duplicate_servers(self):\n        \"\"\"We expect a second server on the same port to 
fail\"\"\"\n        with pytest.raises(corenlp.PermanentlyFailedException):\n            with corenlp.CoreNLPClient(annotators=\"tokenize,ssplit\") as duplicate_server:\n                raise RuntimeError(\"This should have failed\")\n\n    def test_annotate(self, corenlp_client):\n        ann = corenlp_client.annotate(TEXT)\n        assert corenlp.to_text(ann.sentence[0]) == TEXT[:-1]\n\n\n    def test_update(self, corenlp_client):\n        ann = corenlp_client.annotate(TEXT)\n        ann = corenlp_client.update(ann)\n        assert corenlp.to_text(ann.sentence[0]) == TEXT[:-1]\n\n\n    def test_tokensregex(self, corenlp_client):\n        pattern = '([ner: PERSON]+) /wrote/ /an?/ []{0,3} /sentence|article/'\n        matches = corenlp_client.tokensregex(TEXT, pattern)\n        assert len(matches[\"sentences\"]) == 1\n        assert matches[\"sentences\"][0][\"length\"] == 1\n        assert matches == {\n            \"sentences\": [{\n                \"0\": {\n                    \"text\": \"Chris wrote a simple sentence\",\n                    \"begin\": 0,\n                    \"end\": 5,\n                    \"1\": {\n                        \"text\": \"Chris\",\n                        \"begin\": 0,\n                        \"end\": 1\n                    }},\n                \"length\": 1\n            },]}\n\n\n    def test_semgrex(self, corenlp_client):\n        pattern = '{word:wrote} >nsubj {}=subject >obj {}=object'\n        matches = corenlp_client.semgrex(TEXT, pattern, to_words=True)\n        assert matches == [\n            {\n                \"text\": \"wrote\",\n                \"begin\": 1,\n                \"end\": 2,\n                \"$subject\": {\n                    \"text\": \"Chris\",\n                    \"begin\": 0,\n                    \"end\": 1\n                },\n                \"$object\": {\n                    \"text\": \"sentence\",\n                    \"begin\": 4,\n                    \"end\": 5\n                },\n           
     \"sentence\": 0,}]\n\n    def test_tregex(self, corenlp_client):\n        # the PP should be easy to parse\n        pattern = 'PP < NP'\n        matches = corenlp_client.tregex(TEXT, pattern)\n        print(matches)\n        assert matches == {\n            'sentences': [\n                {'0': {'sentIndex': 0, 'characterOffsetBegin': 45, 'codepointOffsetBegin': 45, 'characterOffsetEnd': 66, 'codepointOffsetEnd': 66,\n                       'match': '(PP (IN with)\\n  (NP (NNP Stanford) (NNP CoreNLP)))\\n',\n                       'spanString': 'with Stanford CoreNLP', 'namedNodes': []}}\n            ]\n        }\n\n    def test_tregex_trees(self, corenlp_client):\n        \"\"\"\n        Test the results of tregex run on trees w/o parsing\n        \"\"\"\n        trees = tree_reader.read_trees(\"(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ blue) (NN skin)))))   (ROOT (S (NP (PRP I)) (VP (VBP like) (NP (PRP$ her) (NNS antennae)))))\")\n        pattern = \"VP < NP\"\n        matches = corenlp_client.tregex(pattern=pattern, trees=trees)\n        assert matches == {\n            'sentences': [\n                {'0': {'sentIndex': 0, 'match': '(VP (VBZ has)\\n  (NP (JJ blue) (NN skin)))\\n', 'spanString': 'has blue skin', 'namedNodes': []}},\n                {'0': {'sentIndex': 1, 'match': '(VP (VBP like)\\n  (NP (PRP$ her) (NNS antennae)))\\n', 'spanString': 'like her antennae', 'namedNodes': []}}\n            ]\n        }\n\n    @pytest.fixture\n    def external_server_9001(self):\n        corenlp_home = client.resolve_classpath(None)\n        start_cmd = f'java -Xmx5g -cp \"{corenlp_home}\" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9001 ' \\\n                    f'-timeout 60000 -server_id stanza_external_server -serverProperties {SERVER_TEST_PROPS}'\n        start_cmd = start_cmd and shlex.split(start_cmd)\n        external_server_process = subprocess.Popen(start_cmd)\n\n        yield external_server_process\n\n        assert 
external_server_process\n        external_server_process.terminate()\n        external_server_process.wait(5)\n\n    def test_external_server_legacy_start_server(self, external_server_9001):\n        \"\"\" Test starting up an external server and accessing with a client with start_server=False \"\"\"\n        with corenlp.CoreNLPClient(start_server=False, endpoint=\"http://localhost:9001\") as external_server_client:\n            ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n        assert ann.strip() == EN_GOLD\n\n    def test_external_server_available(self, external_server_9001):\n        \"\"\" Test starting up an external available server and accessing with a client with start_server=StartServer.DONT_START \"\"\"\n        time.sleep(5) # wait and make sure the external CoreNLP server is up and running\n        with corenlp.CoreNLPClient(start_server=corenlp.StartServer.DONT_START, endpoint=\"http://localhost:9001\") as external_server_client:\n            ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n        assert ann.strip() == EN_GOLD\n\n    def test_external_server_unavailable(self):\n        \"\"\" Test accessing with a client with start_server=StartServer.DONT_START to an external unavailable server \"\"\"\n        with pytest.raises(corenlp.AnnotationException):\n            with corenlp.CoreNLPClient(start_server=corenlp.StartServer.DONT_START, endpoint=\"http://localhost:9001\") as external_server_client:\n                ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n\n    def test_external_server_timeout(self):\n        \"\"\" Test starting up an external server with long response time (20 seconds) and accessing with a client with start_server=StartServer.DONT_START and timeout=5000\"\"\"\n        with HTTPMockServerTimeoutContext(9001, 20):\n            time.sleep(5) # wait and make sure 
the external HTTPMockServer server is up and running\n            with pytest.raises(corenlp.TimeoutException):\n                with corenlp.CoreNLPClient(start_server=corenlp.StartServer.DONT_START, endpoint=\"http://localhost:9001\", timeout=5000) as external_server_client:\n                    ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n\n    def test_external_server_try_start_with_external(self, external_server_9001):\n        \"\"\" Test starting up an external server and accessing with a client with start_server=StartServer.TRY_START \"\"\"\n        time.sleep(5) # wait and make sure the external CoreNLP server is up and running\n        with corenlp.CoreNLPClient(start_server=corenlp.StartServer.TRY_START,\n                                   annotators='tokenize,ssplit,pos',\n                                   endpoint=\"http://localhost:9001\") as external_server_client:\n            ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n            assert external_server_client.server is None, \"If this is not None, that indicates the client started a server instead of reusing an existing one\"\n        assert ann.strip() == EN_GOLD\n\n    def test_external_server_try_start(self):\n        \"\"\" Test starting up a server with a client with start_server=StartServer.TRY_START \"\"\"\n        with corenlp.CoreNLPClient(start_server=corenlp.StartServer.TRY_START,\n                                   annotators='tokenize,ssplit,pos',\n                                   endpoint=\"http://localhost:9001\") as external_server_client:\n            ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n        assert ann.strip() == EN_GOLD\n\n    def test_external_server_force_start(self, external_server_9001):\n        \"\"\" Test starting up an external server and accessing with a client with 
start_server=StartServer.FORCE_START \"\"\"\n        time.sleep(5) # wait and make sure the external CoreNLP server is up and running\n        with pytest.raises(corenlp.PermanentlyFailedException):\n            with corenlp.CoreNLPClient(start_server=corenlp.StartServer.FORCE_START, endpoint=\"http://localhost:9001\") as external_server_client:\n                ann = external_server_client.annotate(TEXT, annotators='tokenize,ssplit,pos', output_format='text')\n"
  },
  {
    "path": "stanza/tests/server/test_java_protobuf_requests.py",
    "content": "import tempfile\n\nimport pytest\n\nfrom stanza.models.common.utils import misc_to_space_after, space_after_to_misc\nfrom stanza.models.constituency import tree_reader\nfrom stanza.server import java_protobuf_requests\nfrom stanza.tests import *\nfrom stanza.utils.conll import CoNLL\nfrom stanza.protobuf import DependencyGraph\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef check_tree(proto_tree, py_tree, py_score):\n    tree, tree_score = java_protobuf_requests.from_tree(proto_tree)\n    assert tree_score == py_score\n    assert tree == py_tree\n\ndef test_build_tree():\n    text=\"((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\\n( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 2\n\n    # round-trip every tree, not just the first one\n    for tree in trees:\n        proto_tree = java_protobuf_requests.build_tree(tree, 1.0)\n        check_tree(proto_tree, tree, 1.0)\n\n\nESTONIAN_EMPTY_DEPS = \"\"\"\n# sent_id = ewtb2_000035_15\n# text = Ja paari aasta pärast rôômalt maasikatele ...\n1\tJa\tja\tCCONJ\tJ\t_\t3\tcc\t5.1:cc\t_\n2\tpaari\tpaar\tNUM\tN\tCase=Gen|Number=Sing|NumForm=Word|NumType=Card\t3\tnummod\t3:nummod\t_\n3\taasta\taasta\tNOUN\tS\tCase=Gen|Number=Sing\t0\troot\t5.1:obl\t_\n4\tpärast\tpärast\tADP\tK\tAdpType=Post\t3\tcase\t3:case\t_\n5\trôômalt\trõõmsalt\tADV\tD\tTypo=Yes\t3\tadvmod\t5.1:advmod\tOrphan=Yes|CorrectForm=rõõmsalt\n5.1\tpanna\tpanema\tVERB\tV\tVerbForm=Inf\t_\t_\t0:root\tEmpty=5.1\n6\tmaasikatele\tmaasikas\tNOUN\tS\tCase=All|Number=Plur\t3\tobl\t5.1:obl\tOrphan=Yes\n7\t...\t...\tPUNCT\tZ\t_\t3\tpunct\t5.1:punct\t_\n\"\"\".strip()\n\n\ndef test_convert_networkx_graph():\n    doc = CoNLL.conll2doc(input_str=ESTONIAN_EMPTY_DEPS, ignore_gapping=False)\n    deps = doc.sentences[0]._enhanced_dependencies\n\n    graph = DependencyGraph()\n    java_protobuf_requests.convert_networkx_graph(graph, doc.sentences[0], 0)\n    assert 
len(graph.rootNode) == 1\n    assert graph.rootNode[0] == 0\n    nodes = sorted([(x.index, x.emptyIndex) for x in graph.node])\n    expected_nodes = [(1,0), (2,0), (3,0), (4,0), (5,0), (5,1), (6,0), (7,0)]\n    assert nodes == expected_nodes\n\n    edges = [(x.target, x.dep) for x in graph.edge if x.source == 5 and x.sourceEmpty == 1]\n    edges = sorted(edges)\n    expected_edges = [(1, 'cc'), (3, 'obl'), (5, 'advmod'), (6, 'obl'), (7, 'punct')]\n    assert edges == expected_edges\n\nENGLISH_NBSP_SAMPLE=\"\"\"\n# sent_id = newsgroup-groups.google.com_n3td3v_e874a1e5eb995654_ENG_20060120_052200-0011\n# text = Please note that neither the e-mail address nor name of the sender have been verified.\n1\tPlease\tplease\tINTJ\tUH\t_\t2\tdiscourse\t_\t_\n2\tnote\tnote\tVERB\tVB\tMood=Imp|VerbForm=Fin\t0\troot\t_\t_\n3\tthat\tthat\tSCONJ\tIN\t_\t15\tmark\t_\t_\n4\tneither\tneither\tCCONJ\tCC\t_\t7\tcc:preconj\t_\t_\n5\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t7\tdet\t_\t_\n6\te-mail\te-mail\tNOUN\tNN\tNumber=Sing\t7\tcompound\t_\t_\n7\taddress\taddress\tNOUN\tNN\tNumber=Sing\t15\tnsubj:pass\t_\t_\n8\tnor\tnor\tCCONJ\tCC\t_\t9\tcc\t_\t_\n9\tname\tname\tNOUN\tNN\tNumber=Sing\t7\tconj\t_\t_\n10\tof\tof\tADP\tIN\t_\t12\tcase\t_\t_\n11\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t12\tdet\t_\t_\n12\tsender\tsender\tNOUN\tNN\tNumber=Sing\t7\tnmod\t_\t_\n13\thave\thave\tAUX\tVBP\tMood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin\t15\taux\t_\tSpacesAfter=\\\\u00A0\n14\tbeen\tbe\tAUX\tVBN\tTense=Past|VerbForm=Part\t15\taux:pass\t_\t_\n15\tverified\tverify\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t2\tccomp\t_\tSpaceAfter=No\n16\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\t_\n\"\"\".strip()\n\ndef test_nbsp_doc():\n    \"\"\"\n    Test that the space conversion methods will convert to and from NBSP\n    \"\"\"\n    doc = CoNLL.conll2doc(input_str=ENGLISH_NBSP_SAMPLE)\n\n    assert doc.sentences[0].text == \"Please note that neither the e-mail address nor name of the sender have 
been verified.\"\n    assert doc.sentences[0].tokens[12].spaces_after == \" \"\n    assert misc_to_space_after(\"SpacesAfter=\\\\u00A0\") == ' '\n    assert space_after_to_misc(' ') == \"SpacesAfter=\\\\u00A0\"\n\n    conllu = \"{:C}\".format(doc)\n    assert conllu == ENGLISH_NBSP_SAMPLE\n"
  },
  {
    "path": "stanza/tests/server/test_morphology.py",
    "content": "\"\"\"\nTest the most basic functionality of the morphology script\n\"\"\"\n\nimport pytest\n\nfrom stanza.server.morphology import Morphology, process_text\n\nwords    = [\"Jennifer\", \"has\",  \"the\", \"prettiest\", \"antennae\"]\ntags     = [\"NNP\",      \"VBZ\",  \"DT\",  \"JJS\",       \"NNS\"]\nexpected = [\"Jennifer\", \"have\", \"the\", \"pretty\",    \"antenna\"]\n\ndef test_process_text():\n    result = process_text(words, tags)\n    lemma = [x.lemma for x in result.words]\n    print(lemma)\n    assert lemma == expected\n\ndef test_basic_morphology():\n    with Morphology() as morph:\n        result = morph.process(words, tags)\n        lemma = [x.lemma for x in result.words]\n        assert lemma == expected\n"
  },
  {
    "path": "stanza/tests/server/test_parser_eval.py",
    "content": "\"\"\"\nTest the parser eval interface\n\"\"\"\n\nimport pytest\nimport stanza\nfrom stanza.models.constituency import tree_reader\nfrom stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse\nfrom stanza.server.parser_eval import build_request, collate, EvaluateParser, ParseResult\nfrom stanza.tests.server.test_java_protobuf_requests import check_tree\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\ndef build_one_tree_treebank(fake_scores=True):\n    text = \"((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\"\n    trees = tree_reader.read_trees(text)\n    assert len(trees) == 1\n    gold = trees[0]\n    if fake_scores:\n        prediction = (gold, 1.0)\n        treebank = [ParseResult(gold, [prediction], None, None)]\n        return treebank\n    else:\n        prediction = gold\n        return collate([gold], [prediction])\n\ndef check_build(fake_scores=True):\n    treebank = build_one_tree_treebank(fake_scores)\n    request = build_request(treebank)\n\n    assert len(request.treebank) == 1\n    check_tree(request.treebank[0].gold, treebank[0][0], None)\n    assert len(request.treebank[0].predicted) == 1\n    if fake_scores:\n        check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1])\n    else:\n        check_tree(request.treebank[0].predicted[0], treebank[0][1][0], None)\n\n\ndef test_build_tuple_request():\n    check_build(True)\n\ndef test_build_notuple_request():\n    check_build(False)\n\ndef test_score_one_tree_tuples():\n    treebank = build_one_tree_treebank(True)\n\n    with EvaluateParser() as ep:\n        response = ep.process(treebank)\n        assert response.f1 == pytest.approx(1.0)\n\ndef test_score_one_tree_notuples():\n    treebank = build_one_tree_treebank(False)\n\n    with EvaluateParser() as ep:\n        response = ep.process(treebank)\n        assert response.f1 == pytest.approx(1.0)\n"
  },
  {
    "path": "stanza/tests/server/test_protobuf.py",
    "content": "\"\"\"\nTests to read a stored protobuf.\nAlso serves as an example of how to parse sentences, tokens, pos, lemma,\nner, dependencies and mentions.\n\nThe test corresponds to annotations for the following sentence:\n    Chris wrote a simple sentence that he parsed with Stanford CoreNLP.\n\"\"\"\nimport os\nfrom pathlib import Path\nimport pytest\n\nfrom pytest import fixture\nfrom stanza.protobuf import Document, Sentence, Token, DependencyGraph,\\\n                             CorefChain\nfrom stanza.protobuf import parseFromDelimitedString, writeToDelimitedString, to_text\n\n# set the marker for this module\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\n# Text that was annotated\nTEXT = \"Chris wrote a simple sentence that he parsed with Stanford CoreNLP.\\n\"\n\n\n@fixture\ndef doc_pb():\n    test_dir = os.path.dirname(os.path.abspath(__file__))\n    test_dir = Path(test_dir).parent\n    test_data = os.path.join(test_dir, 'data', 'test.dat')\n    with open(test_data, 'rb') as f:\n        buf = f.read()\n    doc = Document()\n    parseFromDelimitedString(doc, buf)\n    return doc\n\n\ndef test_parse_protobuf(doc_pb):\n    assert doc_pb.ByteSize() == 4709\n\n\ndef test_write_protobuf(doc_pb):\n    stream = writeToDelimitedString(doc_pb)\n    buf = stream.getvalue()\n    stream.close()\n\n    doc_pb_ = Document()\n    parseFromDelimitedString(doc_pb_, buf)\n    assert doc_pb == doc_pb_\n\n\ndef test_document_text(doc_pb):\n    assert doc_pb.text == TEXT\n\n\ndef test_sentences(doc_pb):\n    assert len(doc_pb.sentence) == 1\n\n    sentence = doc_pb.sentence[0]\n    assert isinstance(sentence, Sentence)\n    # check sentence length\n    assert sentence.characterOffsetEnd - sentence.characterOffsetBegin == 67\n    # Note that the sentence text should actually be recovered from the tokens.\n    assert sentence.text == ''\n    assert to_text(sentence) == TEXT[:-1]\n\n\ndef test_tokens(doc_pb):\n    sentence = doc_pb.sentence[0]\n    tokens = 
sentence.token\n    assert len(tokens) == 12\n    assert isinstance(tokens[0], Token)\n\n    # Word\n    words = \"Chris wrote a simple sentence that he parsed with Stanford CoreNLP .\".split()\n    words_ = [t.word for t in tokens]\n    assert words_ == words\n\n    # Lemma\n    lemmas = \"Chris write a simple sentence that he parse with Stanford CoreNLP .\".split()\n    lemmas_ = [t.lemma for t in tokens]\n    assert lemmas_ == lemmas\n\n    # POS\n    pos = \"NNP VBD DT JJ NN IN PRP VBD IN NNP NNP .\".split()\n    pos_ = [t.pos for t in tokens]\n    assert pos_ == pos\n\n    # NER\n    ner = \"PERSON O O O O O O O O ORGANIZATION O O\".split()\n    ner_ = [t.ner for t in tokens]\n    assert ner_ == ner\n\n    # character offsets\n    begin = [int(i) for i in \"0 6 12 14 21 30 35 38 45 50 59 66\".split()]\n    end =   [int(i) for i in \"5 11 13 20 29 34 37 44 49 58 66 67\".split()]\n    begin_ = [t.beginChar for t in tokens]\n    end_ = [t.endChar for t in tokens]\n    assert begin_ == begin\n    assert end_ == end\n\n\ndef test_dependency_parse(doc_pb):\n    \"\"\"\n    Extract the dependency parse from the annotation.\n    \"\"\"\n    sentence = doc_pb.sentence[0]\n\n    # You can choose from the following types of dependencies.\n    # In general, you'll want enhancedPlusPlus\n    assert sentence.basicDependencies.ByteSize() > 0\n    assert sentence.enhancedDependencies.ByteSize() > 0\n    assert sentence.enhancedPlusPlusDependencies.ByteSize() > 0\n\n    tree = sentence.enhancedPlusPlusDependencies\n    assert isinstance(tree, DependencyGraph)\n    # Indices are 1-indexed with 0 being the \"pseudo root\"\n    assert tree.root  # 'wrote' is the root. 
== [2]\n    # There are as many nodes as there are tokens.\n    assert len(tree.node) == len(sentence.token)\n\n    # Enhanced++ dependencies often contain additional edges and are\n    # not trees -- here, 'parsed' would also have an edge to\n    # 'sentence'\n    assert len(tree.edge) == 12\n\n    # This edge goes from \"wrote\" to \"Chris\"\n    edge = tree.edge[0]\n    assert edge.source == 2\n    assert edge.target == 1\n    assert edge.dep == \"nsubj\"\n\n\ndef test_coref_chain(doc_pb):\n    \"\"\"\n    Extract the coreference chains from the annotation.\n    \"\"\"\n    # Coreference chains span sentences and are stored in the\n    # document.\n    chains = doc_pb.corefChain\n\n    # In this document there is 1 chain with Chris and he.\n    assert len(chains) == 1\n    chain = chains[0]\n    assert isinstance(chain, CorefChain)\n    assert chain.mention[0].beginIndex == 0  # 'Chris'\n    assert chain.mention[0].endIndex == 1\n    assert chain.mention[0].gender == \"MALE\"\n\n    assert chain.mention[1].beginIndex == 6  # 'he'\n    assert chain.mention[1].endIndex == 7\n    assert chain.mention[1].gender == \"MALE\"\n\n    assert chain.representative == 0  # Head of the chain is 'Chris'\n"
  },
  {
    "path": "stanza/tests/server/test_semgrex.py",
    "content": "\"\"\"\nTest the semgrex interface\n\"\"\"\n\nimport pytest\nimport stanza\nimport stanza.server.semgrex as semgrex\nfrom stanza.models.common.doc import Document\nfrom stanza.protobuf import SemgrexRequest\nfrom stanza.utils.conll import CoNLL\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\nTEST_ONE_SENTENCE = [[\n    {\n        \"id\": 1,\n        \"text\": \"Unban\",\n        \"lemma\": \"unban\",\n        \"upos\": \"VERB\",\n        \"xpos\": \"VB\",\n        \"feats\": \"Mood=Imp|VerbForm=Fin\",\n        \"head\": 0,\n        \"deprel\": \"root\",\n        \"misc\": \"start_char=0|end_char=5\"\n    },\n    {\n        \"id\": 2,\n        \"text\": \"Mox\",\n        \"lemma\": \"Mox\",\n        \"upos\": \"PROPN\",\n        \"xpos\": \"NNP\",\n        \"feats\": \"Number=Sing\",\n        \"head\": 3,\n        \"deprel\": \"compound\",\n        \"misc\": \"start_char=6|end_char=9\"\n    },\n    {\n        \"id\": 3,\n        \"text\": \"Opal\",\n        \"lemma\": \"Opal\",\n        \"upos\": \"PROPN\",\n        \"xpos\": \"NNP\",\n        \"feats\": \"Number=Sing\",\n        \"head\": 1,\n        \"deprel\": \"obj\",\n        \"misc\": \"start_char=10|end_char=14\",\n        \"ner\": \"GEM\"\n    },\n    {\n        \"id\": 4,\n        \"text\": \"!\",\n        \"lemma\": \"!\",\n        \"upos\": \"PUNCT\",\n        \"xpos\": \".\",\n        \"head\": 1,\n        \"deprel\": \"punct\",\n        \"misc\": \"start_char=14|end_char=15\"\n    }]]\n\nTEST_TWO_SENTENCES = [[\n    {\n      \"id\": 1,\n      \"text\": \"Unban\",\n      \"lemma\": \"unban\",\n      \"upos\": \"VERB\",\n      \"xpos\": \"VB\",\n      \"feats\": \"Mood=Imp|VerbForm=Fin\",\n      \"head\": 0,\n      \"deprel\": \"root\",\n      \"misc\": \"start_char=0|end_char=5\"\n    },\n    {\n      \"id\": 2,\n      \"text\": \"Mox\",\n      \"lemma\": \"Mox\",\n      \"upos\": \"PROPN\",\n      \"xpos\": \"NNP\",\n      \"feats\": 
\"Number=Sing\",\n      \"head\": 3,\n      \"deprel\": \"compound\",\n      \"misc\": \"start_char=6|end_char=9\"\n    },\n    {\n      \"id\": 3,\n      \"text\": \"Opal\",\n      \"lemma\": \"Opal\",\n      \"upos\": \"PROPN\",\n      \"xpos\": \"NNP\",\n      \"feats\": \"Number=Sing\",\n      \"head\": 1,\n      \"deprel\": \"obj\",\n      \"misc\": \"start_char=10|end_char=14\"\n    },\n    {\n      \"id\": 4,\n      \"text\": \"!\",\n      \"lemma\": \"!\",\n      \"upos\": \"PUNCT\",\n      \"xpos\": \".\",\n      \"head\": 1,\n      \"deprel\": \"punct\",\n      \"misc\": \"start_char=14|end_char=15\"\n    }],\n    [{\n      \"id\": 1,\n      \"text\": \"Unban\",\n      \"lemma\": \"unban\",\n      \"upos\": \"VERB\",\n      \"xpos\": \"VB\",\n      \"feats\": \"Mood=Imp|VerbForm=Fin\",\n      \"head\": 0,\n      \"deprel\": \"root\",\n      \"misc\": \"start_char=16|end_char=21\"\n    },\n    {\n      \"id\": 2,\n      \"text\": \"Mox\",\n      \"lemma\": \"Mox\",\n      \"upos\": \"PROPN\",\n      \"xpos\": \"NNP\",\n      \"feats\": \"Number=Sing\",\n      \"head\": 3,\n      \"deprel\": \"compound\",\n      \"misc\": \"start_char=22|end_char=25\"\n    },\n    {\n      \"id\": 3,\n      \"text\": \"Opal\",\n      \"lemma\": \"Opal\",\n      \"upos\": \"PROPN\",\n      \"xpos\": \"NNP\",\n      \"feats\": \"Number=Sing\",\n      \"head\": 1,\n      \"deprel\": \"obj\",\n      \"misc\": \"start_char=26|end_char=30\"\n    },\n    {\n      \"id\": 4,\n      \"text\": \"!\",\n      \"lemma\": \"!\",\n      \"upos\": \"PUNCT\",\n      \"xpos\": \".\",\n      \"head\": 1,\n      \"deprel\": \"punct\",\n      \"misc\": \"start_char=30|end_char=31\"\n    }]]\n\nONE_SENTENCE_DOC = Document(TEST_ONE_SENTENCE, \"Unban Mox Opal!\")\nTWO_SENTENCE_DOC = Document(TEST_TWO_SENTENCES, \"Unban Mox Opal! 
Unban Mox Opal!\")\n\n\ndef check_response(response, response_len=1, semgrex_len=1, source_index=1, target_index=3, reln='obj'):\n    assert len(response.result) == response_len\n    for sentence_idx, sentence_result in enumerate(response.result):\n        for semgrex_result in sentence_result.result:\n            for match in semgrex_result.match:\n                assert sentence_idx == match.sentenceIndex\n    assert len(response.result[0].result) == semgrex_len\n    for semgrex_result in response.result[0].result:\n        assert len(semgrex_result.match) == 1\n        assert semgrex_result.match[0].matchIndex == source_index\n        for match in semgrex_result.match:\n            assert len(match.node) == 2\n            assert match.node[0].name == 'source'\n            assert match.node[0].matchIndex == source_index\n            assert match.node[1].name == 'target'\n            assert match.node[1].matchIndex == target_index\n            assert len(match.reln) == 1\n            assert match.reln[0].name == 'zzz'\n            assert match.reln[0].reln == reln\n\ndef test_multi():\n    with semgrex.Semgrex() as sem:\n        response = sem.process(ONE_SENTENCE_DOC, \"{}=source >obj=zzz {}=target\")\n        check_response(response)\n        response = sem.process(ONE_SENTENCE_DOC, \"{}=source >obj=zzz {}=target\")\n        check_response(response)\n        response = sem.process(TWO_SENTENCE_DOC, \"{}=source >obj=zzz {}=target\")\n        check_response(response, response_len=2)\n\ndef test_single_sentence():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{}=source >obj=zzz {}=target\")\n    check_response(response)\n\ndef test_two_semgrex():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{}=source >obj=zzz {}=target\", \"{}=source >obj=zzz {}=target\")\n    check_response(response, semgrex_len=2)\n\ndef test_two_sentences():\n    response = semgrex.process_doc(TWO_SENTENCE_DOC, \"{}=source >obj=zzz {}=target\")\n    check_response(response, 
response_len=2)\n\ndef test_word_attribute():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{word:Mox}=source <=zzz {word:Opal}=target\")\n    check_response(response, response_len=1, source_index=2, reln='compound')\n\ndef test_lemma_attribute():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{lemma:Mox}=source <=zzz {lemma:Opal}=target\")\n    check_response(response, response_len=1, source_index=2, reln='compound')\n\ndef test_xpos_attribute():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{tag:NNP}=source <=zzz {word:Opal}=target\")\n    check_response(response, response_len=1, source_index=2, reln='compound')\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{pos:NNP}=source <=zzz {word:Opal}=target\")\n    check_response(response, response_len=1, source_index=2, reln='compound')\n\ndef test_upos_attribute():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{cpos:PROPN}=source <=zzz {word:Opal}=target\")\n    check_response(response, response_len=1, source_index=2, reln='compound')\n\ndef test_ner_attribute():\n    response = semgrex.process_doc(ONE_SENTENCE_DOC, \"{cpos:PROPN}=source <=zzz {ner:GEM}=target\")\n    check_response(response, response_len=1, source_index=2, reln='compound')\n\ndef test_hand_built_request():\n    \"\"\"\n    Essentially a test program: the result should be a response with\n    one match, two named nodes, one named relation\n    \"\"\"\n    request = SemgrexRequest()\n    request.semgrex.append(\"{}=source >obj=zzz {}=target\")\n    query = request.query.add()\n\n    for idx, word in enumerate(['Unban', 'Mox', 'Opal']):\n        token = query.token.add()\n        token.word = word\n        token.value = word\n\n        node = query.graph.node.add()\n        node.sentenceIndex = 1\n        node.index = idx+1\n\n    edge = query.graph.edge.add()\n    edge.source = 1\n    edge.target = 3\n    edge.dep = 'obj'\n\n    edge = query.graph.edge.add()\n    edge.source = 3\n    edge.target = 2\n    
edge.dep = 'compound'\n\n    response = semgrex.send_semgrex_request(request)\n    check_response(response)\n\nBLANK_DEPENDENCY_SENTENCE = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007\n# text = You wonder if he was manipulating the market with his bombing targets.\n1\tYou\tyou\tPRON\tPRP\tCase=Nom|Person=2|PronType=Prs\t2\tnsubj\t_\t_\n2\twonder\twonder\tVERB\tVBP\tMood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin\t1\t_\t_\t_\n3\tif\tif\tSCONJ\tIN\t_\t6\tmark\t_\t_\n4\the\the\tPRON\tPRP\tCase=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs\t6\tnsubj\t_\t_\n5\twas\tbe\tAUX\tVBD\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t_\t_\n6\tmanipulating\tmanipulate\tVERB\tVBG\tTense=Pres|VerbForm=Part\t2\tccomp\t_\t_\n7\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t8\tdet\t_\t_\n8\tmarket\tmarket\tNOUN\tNN\tNumber=Sing\t6\tobj\t_\t_\n9\twith\twith\tADP\tIN\t_\t12\tcase\t_\t_\n10\this\this\tPRON\tPRP$\tCase=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t12\tnmod:poss\t_\t_\n11\tbombing\tbombing\tNOUN\tNN\tNumber=Sing\t12\tcompound\t_\t_\n12\ttargets\ttarget\tNOUN\tNNS\tNumber=Plur\t6\tobl\t_\tSpaceAfter=No\n13\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\t_\n\"\"\".lstrip()\n\n\ndef test_blank_dependency():\n    \"\"\"\n    A user / contributor sent a dependency file with blank dependency labels and twisted up roots\n    \"\"\"\n    blank_dep_doc = CoNLL.conll2doc(input_str=BLANK_DEPENDENCY_SENTENCE)\n    blank_dep_request = semgrex.build_request(blank_dep_doc, \"{}=root <_=edge {}\")\n    response = semgrex.send_semgrex_request(blank_dep_request)\n    assert len(response.result) == 1\n    assert len(response.result[0].result) == 1\n    assert len(response.result[0].result[0].match) == 1\n    # there should be a named node...\n    assert len(response.result[0].result[0].match[0].node) == 1\n    assert response.result[0].result[0].match[0].node[0].name == 'root'\n    assert 
response.result[0].result[0].match[0].node[0].matchIndex == 2\n\n    # ... and a named edge\n    assert len(response.result[0].result[0].match[0].edge) == 1\n    assert response.result[0].result[0].match[0].edge[0].source == 1\n    assert response.result[0].result[0].match[0].edge[0].target == 2\n    assert response.result[0].result[0].match[0].edge[0].reln == \"_\"\n\nEXPECTED_ONE_SENTENCE_MATCH = \"\"\"\n# text = Unban Mox Opal!\n# sent_id = 0\n# semgrex pattern |{cpos:PROPN}=source <=zzz {ner:GEM}=target| matched at 2:Mox  source=2:Mox target=3:Opal\n# highlight tokens = 2\n# highlight deprels = 2\n1\tUnban\tunban\tVERB\tVB\tMood=Imp|VerbForm=Fin\t0\troot\t_\tstart_char=0|end_char=5\n2\tMox\tMox\tPROPN\tNNP\tNumber=Sing\t3\tcompound\t_\tstart_char=6|end_char=9\n3\tOpal\tOpal\tPROPN\tNNP\tNumber=Sing\t1\tobj\t_\tSpaceAfter=No|start_char=10|end_char=14|ner=GEM\n4\t!\t!\tPUNCT\t.\t_\t1\tpunct\t_\tSpaceAfter=No|start_char=14|end_char=15\n\"\"\".strip()\n\ndef test_ner_annotated():\n    semgrex_pattern = \"{cpos:PROPN}=source <=zzz {ner:GEM}=target\"\n    # not using the existing ONE_SENTENCE_DOC as the Document may be mutated\n    doc = Document(TEST_ONE_SENTENCE, \"Unban Mox Opal!\")\n    response = semgrex.process_doc(doc, semgrex_pattern)\n    doc = semgrex.annotate_doc(doc, response, semgrex_pattern, True, False)\n    formatted = \"{:C}\".format(doc).strip()\n    assert formatted == EXPECTED_ONE_SENTENCE_MATCH\n\nEXPECTED_ONE_SENTENCE_NO_MATCH = \"\"\"\n# text = Unban Mox Opal!\n# sent_id = 0\n# semgrex pattern |{cpos:ZZZZ}| did not match!\n1\tUnban\tunban\tVERB\tVB\tMood=Imp|VerbForm=Fin\t0\troot\t_\tstart_char=0|end_char=5\n2\tMox\tMox\tPROPN\tNNP\tNumber=Sing\t3\tcompound\t_\tstart_char=6|end_char=9\n3\tOpal\tOpal\tPROPN\tNNP\tNumber=Sing\t1\tobj\t_\tSpaceAfter=No|start_char=10|end_char=14|ner=GEM\n4\t!\t!\tPUNCT\t.\t_\t1\tpunct\t_\tSpaceAfter=No|start_char=14|end_char=15\n\"\"\".strip()\n\ndef test_not_annotated():\n    semgrex_pattern = \"{cpos:ZZZZ}\"\n    
# not using the existing ONE_SENTENCE_DOC as the Document may be mutated\n    doc = Document(TEST_ONE_SENTENCE, \"Unban Mox Opal!\")\n    response = semgrex.process_doc(doc, semgrex_pattern)\n    doc = semgrex.annotate_doc(doc, response, semgrex_pattern, False, False)\n    formatted = \"{:C}\".format(doc).strip()\n    assert formatted == EXPECTED_ONE_SENTENCE_NO_MATCH\n\n\ndef test_empty_not_annotated():\n    \"\"\"\n    If there are no responses and match_only is set, the returned doc should be empty\n    \"\"\"\n    semgrex_pattern = \"{cpos:ZZZZ}\"\n    # not using the existing ONE_SENTENCE_DOC as the Document may be mutated\n    doc = Document(TEST_ONE_SENTENCE, \"Unban Mox Opal!\")\n    response = semgrex.process_doc(doc, semgrex_pattern)\n    doc = semgrex.annotate_doc(doc, response, semgrex_pattern, True, False)\n    formatted = \"{:C}\".format(doc).strip()\n    assert formatted == \"\"\n\ndef test_only_not_annotated():\n    semgrex_pattern = \"{cpos:ZZZZ}\"\n    # not using the existing ONE_SENTENCE_DOC as the Document may be mutated\n    doc = Document(TEST_ONE_SENTENCE, \"Unban Mox Opal!\")\n    response = semgrex.process_doc(doc, semgrex_pattern)\n    doc = semgrex.annotate_doc(doc, response, semgrex_pattern, False, True)\n    formatted = \"{:C}\".format(doc).strip()\n    assert formatted == EXPECTED_ONE_SENTENCE_NO_MATCH\n\n"
  },
  {
    "path": "stanza/tests/server/test_server_misc.py",
    "content": "\"\"\"\nMisc tests for the server\n\"\"\"\n\nimport pytest\nimport re\nimport stanza.server as corenlp\nfrom stanza.tests import compare_ignoring_whitespace\n\npytestmark = pytest.mark.client\n\nEN_DOC = \"Joe Smith lives in California.\"\n\nEN_DOC_GOLD = \"\"\"\nSentence #1 (6 tokens):\nJoe Smith lives in California.\n\nTokens:\n[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP Lemma=Joe NamedEntityTag=PERSON]\n[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP Lemma=Smith NamedEntityTag=PERSON]\n[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ Lemma=live NamedEntityTag=O]\n[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN Lemma=in NamedEntityTag=O]\n[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP Lemma=California NamedEntityTag=STATE_OR_PROVINCE]\n[Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=. Lemma=. NamedEntityTag=O]\n\nDependency Parse (enhanced plus plus dependencies):\nroot(ROOT-0, lives-3)\ncompound(Smith-2, Joe-1)\nnsubj(lives-3, Smith-2)\ncase(California-5, in-4)\nobl:in(lives-3, California-5)\npunct(lives-3, .-6)\n\nExtracted the following NER entity mentions:\nJoe Smith       PERSON  PERSON:0.9972202681743931\nCalifornia      STATE_OR_PROVINCE       LOCATION:0.9990868267559281\n\nExtracted the following KBP triples:\n1.0     Joe Smith       per:statesorprovinces_of_residence      California\n\"\"\"\n\n\nEN_DOC_POS_ONLY_GOLD = \"\"\"\nSentence #1 (6 tokens):\nJoe Smith lives in California.\n\nTokens:\n[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP]\n[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP]\n[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ]\n[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN]\n[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP]\n[Text=. 
CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]\n\"\"\"\n\ndef test_english_request():\n    \"\"\" Test case of starting server with Spanish defaults, and then requesting default English properties \"\"\"\n    with corenlp.CoreNLPClient(properties='spanish', server_id='test_spanish_english_request') as client:\n        ann = client.annotate(EN_DOC, properties='english', output_format='text')\n        compare_ignoring_whitespace(ann, EN_DOC_GOLD)\n\n    # Rerun the test with a server created in English mode to verify\n    # that the expected output is what the defaults actually give us\n    with corenlp.CoreNLPClient(properties='english', server_id='test_english_request') as client:\n        ann = client.annotate(EN_DOC, output_format='text')\n        compare_ignoring_whitespace(ann, EN_DOC_GOLD)\n\n\ndef test_default_annotators():\n    \"\"\"\n    Test case of creating a client with start_server=False and a set of annotators\n    The annotators should be used instead of the server's default annotators\n    \"\"\"\n    with corenlp.CoreNLPClient(server_id='test_default_annotators',\n                               output_format='text',\n                               annotators=['tokenize','ssplit','pos','lemma','ner','depparse']) as client:\n        with corenlp.CoreNLPClient(start_server=False,\n                                   output_format='text',\n                                   annotators=['tokenize','ssplit','pos']) as client2:\n            ann = client2.annotate(EN_DOC)\n\nexpected_codepoints = ((0, 1), (2, 4), (5, 8), (9, 15), (16, 20))\nexpected_characters = ((0, 1), (2, 4), (5, 10), (11, 17), (18, 22))\ncodepoint_doc = \"I am 𝒚̂𝒊 random text\"\n\ndef test_codepoints():\n    \"\"\" Test case of asking for codepoints from the English tokenizer \"\"\"\n    with corenlp.CoreNLPClient(annotators=['tokenize','ssplit'], # 'depparse','coref'],\n                               properties={'tokenize.codepoint': 'true'}) as client:\n        ann = 
client.annotate(codepoint_doc)\n        for i, (codepoints, characters) in enumerate(zip(expected_codepoints, expected_characters)):\n            token = ann.sentence[0].token[i]\n            assert token.codepointOffsetBegin == codepoints[0]\n            assert token.codepointOffsetEnd == codepoints[1]\n            assert token.beginChar == characters[0]\n            assert token.endChar == characters[1]\n\ndef test_codepoint_text():\n    \"\"\" Test case of extracting the correct sentence text using codepoints \"\"\"\n\n    text = 'Unban mox opal 🐱.  This is a second sentence.'\n\n    with corenlp.CoreNLPClient(annotators=[\"tokenize\",\"ssplit\"],\n                               properties={'tokenize.codepoint': 'true'}) as client:\n        ann = client.annotate(text)\n\n        text_start = ann.sentence[0].token[0].codepointOffsetBegin\n        text_end = ann.sentence[0].token[-1].codepointOffsetEnd\n        sentence_text = text[text_start:text_end]\n        assert sentence_text == 'Unban mox opal 🐱.'\n\n        text_start = ann.sentence[1].token[0].codepointOffsetBegin\n        text_end = ann.sentence[1].token[-1].codepointOffsetEnd\n        sentence_text = text[text_start:text_end]\n        assert sentence_text == 'This is a second sentence.'\n"
  },
  {
    "path": "stanza/tests/server/test_server_pretokenized.py",
    "content": "\"\"\"\nMisc tests for the server\n\"\"\"\n\nimport pytest\nimport re\n\nfrom stanza.server import CoreNLPClient\n\npytestmark = pytest.mark.client\n\ntokens = {}\ntags = {}\n\n# Italian examples\ntokens[\"italian\"] = [\n    \"È vero , tutti possiamo essere sostituiti .\\n Alcune chiamate partirono da il Quirinale .\"\n]\ntags[\"italian\"] = [\n    [\n        [\"AUX\", \"ADJ\", \"PUNCT\", \"PRON\", \"AUX\", \"AUX\", \"VERB\", \"PUNCT\"],\n        [\"DET\", \"NOUN\", \"VERB\", \"ADP\", \"DET\", \"PROPN\", \"PUNCT\"],\n    ],\n]\n\n\n# French examples\ntokens[\"french\"] = [\n    (\n     \"Les études durent six ans mais leur contenu diffère donc selon les Facultés .\\n\"\n     \"Il est fêté le 22 mai .\"\n    )\n]\ntags[\"french\"] = [\n    [\n        [\"DET\", \"NOUN\", \"VERB\", \"NUM\", \"NOUN\", \"CCONJ\", \"DET\", \"NOUN\", \"VERB\", \"ADV\", \"ADP\", \"DET\", \"PROPN\", \"PUNCT\"],\n        [\"PRON\", \"AUX\", \"VERB\", \"DET\", \"NUM\", \"NOUN\", \"PUNCT\"]\n    ],\n]\n\n\n# English examples\ntokens[\"english\"] = [\"This shouldn't be split .\\n I hope it's not .\"]\ntags[\"english\"] = [\n    [\n        [\"DT\", \"NN\", \"VB\", \"VBN\", \".\"],\n        [\"PRP\", \"VBP\", \"PRP$\", \"RB\", \".\"],\n    ],\n]\n\n\ndef pretokenized_test(lang):\n    \"\"\"Test submitting pretokenized French text.\"\"\"\n    with CoreNLPClient(\n        properties=lang,\n        annotators=\"pos\",\n        pretokenized=True,\n        be_quiet=True,\n    ) as client:\n        for input_text, gold_tags in zip(tokens[lang], tags[lang]):\n            ann = client.annotate(input_text)\n            for sentence_tags, sentence in zip(gold_tags, ann.sentence):\n                result_tags = [tok.pos for tok in sentence.token]\n                assert sentence_tags == result_tags\n\n\ndef test_english_pretokenized():\n    pretokenized_test(\"english\")\n\n\ndef test_italian_pretokenized():\n    pretokenized_test(\"italian\")\n\n\ndef test_french_pretokenized():\n    
pretokenized_test(\"french\")\n"
  },
  {
    "path": "stanza/tests/server/test_server_request.py",
    "content": "\"\"\"\nTests for setting request properties of servers\n\"\"\"\n\nimport json\nimport pytest\nimport stanza.server as corenlp\n\nfrom stanza.protobuf import Document\nfrom stanza.tests import TEST_WORKING_DIR, compare_ignoring_whitespace\n\npytestmark = pytest.mark.client\n\nEN_DOC = \"Joe Smith lives in California.\"\n\n# results with an example properties file\nEN_DOC_GOLD = \"\"\"\nSentence #1 (6 tokens):\nJoe Smith lives in California.\n\nTokens:\n[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP]\n[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP]\n[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ]\n[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN]\n[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP]\n[Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]\n\"\"\"\n\nGERMAN_DOC = \"Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.\"\n\nGERMAN_DOC_GOLD = \"\"\"\nSentence #1 (10 tokens):\nAngela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.\n\nTokens:\n[Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN]\n[Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN]\n[Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17 PartOfSpeech=AUX]\n[Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22 PartOfSpeech=ADP]\n[Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27 PartOfSpeech=NUM]\n[Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43 PartOfSpeech=NOUN]\n[Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47 PartOfSpeech=DET]\n[Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62 PartOfSpeech=PROPN]\n[Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74 PartOfSpeech=PROPN]\n[Text=. 
CharacterOffsetBegin=74 CharacterOffsetEnd=75 PartOfSpeech=PUNCT]\n\"\"\"\n\nFRENCH_CUSTOM_PROPS = {'annotators': 'tokenize,ssplit,mwt,pos,parse',\n                       'tokenize.language': 'fr',\n                       'pos.model': 'edu/stanford/nlp/models/pos-tagger/french-ud.tagger',\n                       'parse.model': 'edu/stanford/nlp/models/srparser/frenchSR.ser.gz',\n                       'mwt.mappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt.tsv',\n                       'mwt.pos.model': 'edu/stanford/nlp/models/mwt/french/french-mwt.tagger',\n                       'mwt.statisticalMappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt-statistical.tsv',\n                       'mwt.preserveCasing': 'false',\n                       'outputFormat': 'text'}\n\nFRENCH_EXTRA_PROPS = {'annotators': 'tokenize,ssplit,mwt,pos,depparse',\n                      'tokenize.language': 'fr',\n                      'pos.model': 'edu/stanford/nlp/models/pos-tagger/french-ud.tagger',\n                      'mwt.mappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt.tsv',\n                      'mwt.pos.model': 'edu/stanford/nlp/models/mwt/french/french-mwt.tagger',\n                      'mwt.statisticalMappingFile': 'edu/stanford/nlp/models/mwt/french/french-mwt-statistical.tsv',\n                      'mwt.preserveCasing': 'false',\n                      'depparse.model': 'edu/stanford/nlp/models/parser/nndep/UD_French.gz'}\n\nFRENCH_DOC = \"Cette enquête préliminaire fait suite aux révélations de l’hebdomadaire quelques jours plus tôt.\"\n\nFRENCH_CUSTOM_GOLD = \"\"\"\nSentence #1 (16 tokens):\nCette enquête préliminaire fait suite aux révélations de l’hebdomadaire quelques jours plus tôt.\n\nTokens:\n[Text=Cette CharacterOffsetBegin=0 CharacterOffsetEnd=5 PartOfSpeech=DET]\n[Text=enquête CharacterOffsetBegin=6 CharacterOffsetEnd=13 PartOfSpeech=NOUN]\n[Text=préliminaire CharacterOffsetBegin=14 CharacterOffsetEnd=26 
PartOfSpeech=ADJ]\n[Text=fait CharacterOffsetBegin=27 CharacterOffsetEnd=31 PartOfSpeech=VERB]\n[Text=suite CharacterOffsetBegin=32 CharacterOffsetEnd=37 PartOfSpeech=NOUN]\n[Text=à CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=ADP]\n[Text=les CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=DET]\n[Text=révélations CharacterOffsetBegin=42 CharacterOffsetEnd=53 PartOfSpeech=NOUN]\n[Text=de CharacterOffsetBegin=54 CharacterOffsetEnd=56 PartOfSpeech=ADP]\n[Text=l’ CharacterOffsetBegin=57 CharacterOffsetEnd=59 PartOfSpeech=NOUN]\n[Text=hebdomadaire CharacterOffsetBegin=59 CharacterOffsetEnd=71 PartOfSpeech=ADJ]\n[Text=quelques CharacterOffsetBegin=72 CharacterOffsetEnd=80 PartOfSpeech=DET]\n[Text=jours CharacterOffsetBegin=81 CharacterOffsetEnd=86 PartOfSpeech=NOUN]\n[Text=plus CharacterOffsetBegin=87 CharacterOffsetEnd=91 PartOfSpeech=ADV]\n[Text=tôt CharacterOffsetBegin=92 CharacterOffsetEnd=95 PartOfSpeech=ADV]\n[Text=. CharacterOffsetBegin=95 CharacterOffsetEnd=96 PartOfSpeech=PUNCT]\n\nConstituency parse: \n(ROOT\n  (SENT\n    (NP (DET Cette)\n      (MWN (NOUN enquête) (ADJ préliminaire)))\n    (VN\n      (MWV (VERB fait) (NOUN suite)))\n    (PP (ADP à)\n      (NP (DET les) (NOUN révélations)\n        (PP (ADP de)\n          (NP (NOUN l’)\n            (AP (ADJ hebdomadaire))))))\n    (NP (DET quelques) (NOUN jours))\n    (AdP (ADV plus) (ADV tôt))\n    (PUNCT .)))\n\"\"\"\n\nFRENCH_EXTRA_GOLD = \"\"\"\nSentence #1 (16 tokens):\nCette enquête préliminaire fait suite aux révélations de l’hebdomadaire quelques jours plus tôt.\n\nTokens:\n[Text=Cette CharacterOffsetBegin=0 CharacterOffsetEnd=5 PartOfSpeech=DET]\n[Text=enquête CharacterOffsetBegin=6 CharacterOffsetEnd=13 PartOfSpeech=NOUN]\n[Text=préliminaire CharacterOffsetBegin=14 CharacterOffsetEnd=26 PartOfSpeech=ADJ]\n[Text=fait CharacterOffsetBegin=27 CharacterOffsetEnd=31 PartOfSpeech=VERB]\n[Text=suite CharacterOffsetBegin=32 CharacterOffsetEnd=37 PartOfSpeech=NOUN]\n[Text=à 
CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=ADP]\n[Text=les CharacterOffsetBegin=38 CharacterOffsetEnd=41 PartOfSpeech=DET]\n[Text=révélations CharacterOffsetBegin=42 CharacterOffsetEnd=53 PartOfSpeech=NOUN]\n[Text=de CharacterOffsetBegin=54 CharacterOffsetEnd=56 PartOfSpeech=ADP]\n[Text=l’ CharacterOffsetBegin=57 CharacterOffsetEnd=59 PartOfSpeech=NOUN]\n[Text=hebdomadaire CharacterOffsetBegin=59 CharacterOffsetEnd=71 PartOfSpeech=ADJ]\n[Text=quelques CharacterOffsetBegin=72 CharacterOffsetEnd=80 PartOfSpeech=DET]\n[Text=jours CharacterOffsetBegin=81 CharacterOffsetEnd=86 PartOfSpeech=NOUN]\n[Text=plus CharacterOffsetBegin=87 CharacterOffsetEnd=91 PartOfSpeech=ADV]\n[Text=tôt CharacterOffsetBegin=92 CharacterOffsetEnd=95 PartOfSpeech=ADV]\n[Text=. CharacterOffsetBegin=95 CharacterOffsetEnd=96 PartOfSpeech=PUNCT]\n\nDependency Parse (enhanced plus plus dependencies):\nroot(ROOT-0, fait-4)\ndet(enquête-2, Cette-1)\nnsubj(fait-4, enquête-2)\namod(enquête-2, préliminaire-3)\nobj(fait-4, suite-5)\ncase(révélations-8, à-6)\ndet(révélations-8, les-7)\nobl:à(fait-4, révélations-8)\ncase(l’-10, de-9)\nnmod:de(révélations-8, l’-10)\namod(révélations-8, hebdomadaire-11)\ndet(jours-13, quelques-12)\nobl(fait-4, jours-13)\nadvmod(tôt-15, plus-14)\nadvmod(jours-13, tôt-15)\npunct(fait-4, .-16)\n\"\"\"\n\nFRENCH_JSON_GOLD = json.loads(open(f'{TEST_WORKING_DIR}/out/example_french.json', encoding=\"utf-8\").read())\n\nES_DOC = 'Andrés Manuel López Obrador es el presidente de México.'\n\nES_PROPS = {'annotators': 'tokenize,ssplit,mwt,pos,depparse', 'tokenize.language': 'es',\n            'pos.model': 'edu/stanford/nlp/models/pos-tagger/spanish-ud.tagger',\n            'mwt.mappingFile': 'edu/stanford/nlp/models/mwt/spanish/spanish-mwt.tsv',\n            'depparse.model': 'edu/stanford/nlp/models/parser/nndep/UD_Spanish.gz'}\n\nES_PROPS_GOLD = \"\"\"\nSentence #1 (10 tokens):\nAndrés Manuel López Obrador es el presidente de México.\n\nTokens:\n[Text=Andrés 
CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN]\n[Text=Manuel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN]\n[Text=López CharacterOffsetBegin=14 CharacterOffsetEnd=19 PartOfSpeech=PROPN]\n[Text=Obrador CharacterOffsetBegin=20 CharacterOffsetEnd=27 PartOfSpeech=PROPN]\n[Text=es CharacterOffsetBegin=28 CharacterOffsetEnd=30 PartOfSpeech=AUX]\n[Text=el CharacterOffsetBegin=31 CharacterOffsetEnd=33 PartOfSpeech=DET]\n[Text=presidente CharacterOffsetBegin=34 CharacterOffsetEnd=44 PartOfSpeech=NOUN]\n[Text=de CharacterOffsetBegin=45 CharacterOffsetEnd=47 PartOfSpeech=ADP]\n[Text=México CharacterOffsetBegin=48 CharacterOffsetEnd=54 PartOfSpeech=PROPN]\n[Text=. CharacterOffsetBegin=54 CharacterOffsetEnd=55 PartOfSpeech=PUNCT]\n\nDependency Parse (enhanced plus plus dependencies):\nroot(ROOT-0, presidente-7)\nnsubj(presidente-7, Andrés-1)\nflat(Andrés-1, Manuel-2)\nflat(Andrés-1, López-3)\nflat(Andrés-1, Obrador-4)\ncop(presidente-7, es-5)\ndet(presidente-7, el-6)\ncase(México-9, de-8)\nnmod:de(presidente-7, México-9)\npunct(presidente-7, .-10)\n\"\"\"\n\nclass TestServerRequest:\n    @pytest.fixture(scope=\"class\")\n    def corenlp_client(self):\n        \"\"\" Client to run tests on \"\"\"\n        client = corenlp.CoreNLPClient(annotators='tokenize,ssplit,pos', server_id='stanza_request_tests_server')\n        yield client\n        client.stop()\n\n\n    def test_basic(self, corenlp_client):\n        \"\"\" Basic test of making a request, test default output format is a Document \"\"\"\n        ann = corenlp_client.annotate(EN_DOC, output_format=\"text\")\n        compare_ignoring_whitespace(ann, EN_DOC_GOLD)\n        ann = corenlp_client.annotate(EN_DOC)\n        assert isinstance(ann, Document)\n\n\n    def test_python_dict(self, corenlp_client):\n        \"\"\" Test using a Python dictionary to specify all request properties \"\"\"\n        ann = corenlp_client.annotate(ES_DOC, properties=ES_PROPS, output_format=\"text\")\n        
compare_ignoring_whitespace(ann, ES_PROPS_GOLD)\n        ann = corenlp_client.annotate(FRENCH_DOC, properties=FRENCH_CUSTOM_PROPS)\n        compare_ignoring_whitespace(ann, FRENCH_CUSTOM_GOLD)\n\n\n    def test_lang_setting(self, corenlp_client):\n        \"\"\" Test using a Stanford CoreNLP supported language name as the properties key \"\"\"\n        ann = corenlp_client.annotate(GERMAN_DOC, properties=\"german\", output_format=\"text\")\n        compare_ignoring_whitespace(ann, GERMAN_DOC_GOLD)\n\n\n    def test_annotators_and_output_format(self, corenlp_client):\n        \"\"\" Test setting the annotators and output_format \"\"\"\n        ann = corenlp_client.annotate(FRENCH_DOC, properties=FRENCH_EXTRA_PROPS,\n                                      annotators=\"tokenize,ssplit,mwt,pos\", output_format=\"json\")\n        assert ann == FRENCH_JSON_GOLD\n"
  },
  {
    "path": "stanza/tests/server/test_server_start.py",
    "content": "\"\"\"\nTests for starting a server in Python code\n\"\"\"\n\nimport pytest\nimport stanza.server as corenlp\nfrom stanza.server.client import AnnotationException\nimport time\n\nfrom stanza.tests import *\n\npytestmark = pytest.mark.client\n\nEN_DOC = \"Joe Smith lives in California.\"\n\n# results on EN_DOC with standard StanfordCoreNLP defaults\nEN_PRELOAD_GOLD = \"\"\"\nSentence #1 (6 tokens):\nJoe Smith lives in California.\n\nTokens:\n[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP Lemma=Joe NamedEntityTag=PERSON]\n[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP Lemma=Smith NamedEntityTag=PERSON]\n[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ Lemma=live NamedEntityTag=O]\n[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN Lemma=in NamedEntityTag=O]\n[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP Lemma=California NamedEntityTag=STATE_OR_PROVINCE]\n[Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=. Lemma=. 
NamedEntityTag=O]\n\nDependency Parse (enhanced plus plus dependencies):\nroot(ROOT-0, lives-3)\ncompound(Smith-2, Joe-1)\nnsubj(lives-3, Smith-2)\ncase(California-5, in-4)\nobl:in(lives-3, California-5)\npunct(lives-3, .-6)\n\nExtracted the following NER entity mentions:\nJoe Smith       PERSON              PERSON:0.9972202681743931\nCalifornia      STATE_OR_PROVINCE   LOCATION:0.9990868267559281\n\nExtracted the following KBP triples:\n1.0 Joe Smith per:statesorprovinces_of_residence California\n\"\"\"\n\n# results with an example properties file\nEN_PROPS_FILE_GOLD = \"\"\"\nSentence #1 (6 tokens):\nJoe Smith lives in California.\n\nTokens:\n[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP]\n[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP]\n[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ]\n[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN]\n[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP]\n[Text=. 
CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]\n\"\"\"\n\nGERMAN_DOC = \"Angela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.\"\n\n# results with standard German properties\nGERMAN_FULL_PROPS_GOLD = \"\"\"\nSentence #1 (10 tokens):\nAngela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.\n\nTokens:\n[Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN Lemma=angela NamedEntityTag=PERSON]\n[Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN Lemma=merkel NamedEntityTag=PERSON]\n[Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17 PartOfSpeech=AUX Lemma=ist NamedEntityTag=O]\n[Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22 PartOfSpeech=ADP Lemma=seit NamedEntityTag=O]\n[Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27 PartOfSpeech=NUM Lemma=2005 NamedEntityTag=O]\n[Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43 PartOfSpeech=NOUN Lemma=bundeskanzlerin NamedEntityTag=O]\n[Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47 PartOfSpeech=DET Lemma=der NamedEntityTag=O]\n[Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62 PartOfSpeech=PROPN Lemma=bundesrepublik NamedEntityTag=LOCATION]\n[Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74 PartOfSpeech=PROPN Lemma=deutschland NamedEntityTag=LOCATION]\n[Text=. CharacterOffsetBegin=74 CharacterOffsetEnd=75 PartOfSpeech=PUNCT Lemma=. 
NamedEntityTag=O]\n\nDependency Parse (enhanced plus plus dependencies):\nroot(ROOT-0, Bundeskanzlerin-6)\nnsubj(Bundeskanzlerin-6, Angela-1)\nflat(Angela-1, Merkel-2)\ncop(Bundeskanzlerin-6, ist-3)\ncase(2005-5, seit-4)\nnmod:seit(Bundeskanzlerin-6, 2005-5)\ndet(Bundesrepublik-8, der-7)\nnmod(Bundeskanzlerin-6, Bundesrepublik-8)\nappos(Bundesrepublik-8, Deutschland-9)\npunct(Bundeskanzlerin-6, .-10)\n\nExtracted the following NER entity mentions:\nAngela Merkel              PERSON   PERSON:0.9999981583351504\nBundesrepublik Deutschland LOCATION LOCATION:0.9682902289749544\n\"\"\"\n\n\nGERMAN_SMALL_PROPS = {'annotators': 'tokenize,ssplit,pos', 'tokenize.language': 'de',\n                      'pos.model': 'edu/stanford/nlp/models/pos-tagger/german-ud.tagger'}\n\n# results with custom Python dictionary set properties\nGERMAN_SMALL_PROPS_GOLD = \"\"\"\nSentence #1 (10 tokens):\nAngela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.\n\nTokens:\n[Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6 PartOfSpeech=PROPN]\n[Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13 PartOfSpeech=PROPN]\n[Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17 PartOfSpeech=AUX]\n[Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22 PartOfSpeech=ADP]\n[Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27 PartOfSpeech=NUM]\n[Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43 PartOfSpeech=NOUN]\n[Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47 PartOfSpeech=DET]\n[Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62 PartOfSpeech=PROPN]\n[Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74 PartOfSpeech=PROPN]\n[Text=. 
CharacterOffsetBegin=74 CharacterOffsetEnd=75 PartOfSpeech=PUNCT]\n\"\"\"\n\n# results with custom Python dictionary set properties and annotators=tokenize,ssplit\nGERMAN_SMALL_PROPS_W_ANNOTATORS_GOLD = \"\"\"\nSentence #1 (10 tokens):\nAngela Merkel ist seit 2005 Bundeskanzlerin der Bundesrepublik Deutschland.\n\nTokens:\n[Text=Angela CharacterOffsetBegin=0 CharacterOffsetEnd=6]\n[Text=Merkel CharacterOffsetBegin=7 CharacterOffsetEnd=13]\n[Text=ist CharacterOffsetBegin=14 CharacterOffsetEnd=17]\n[Text=seit CharacterOffsetBegin=18 CharacterOffsetEnd=22]\n[Text=2005 CharacterOffsetBegin=23 CharacterOffsetEnd=27]\n[Text=Bundeskanzlerin CharacterOffsetBegin=28 CharacterOffsetEnd=43]\n[Text=der CharacterOffsetBegin=44 CharacterOffsetEnd=47]\n[Text=Bundesrepublik CharacterOffsetBegin=48 CharacterOffsetEnd=62]\n[Text=Deutschland CharacterOffsetBegin=63 CharacterOffsetEnd=74]\n[Text=. CharacterOffsetBegin=74 CharacterOffsetEnd=75]\n\"\"\"\n\n# properties for username/password example\nUSERNAME_PASS_PROPS = {'annotators': 'tokenize,ssplit,pos'}\n\nUSERNAME_PASS_GOLD = \"\"\"\nSentence #1 (6 tokens):\nJoe Smith lives in California.\n\nTokens:\n[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP]\n[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP]\n[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ]\n[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN]\n[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP]\n[Text=. 
CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]\n\"\"\"\n\n\ndef annotate_and_time(client, text, properties={}):\n    \"\"\" Submit an annotation request and return how long it took \"\"\"\n    start = time.time()\n    ann = client.annotate(text, properties=properties, output_format=\"text\")\n    end = time.time()\n    return {'annotation': ann, 'start_time': start, 'end_time': end}\n\ndef test_preload():\n    \"\"\" Test that the default annotators load fully immediately upon server start \"\"\"\n    with corenlp.CoreNLPClient(server_id='test_server_start_preload') as client:\n        # wait for annotators to load\n        time.sleep(140)\n        results = annotate_and_time(client, EN_DOC)\n        compare_ignoring_whitespace(results['annotation'], EN_PRELOAD_GOLD)\n        assert results['end_time'] - results['start_time'] < 3\n\n\ndef test_props_file():\n    \"\"\" Test starting the server with a props file \"\"\"\n    with corenlp.CoreNLPClient(properties=SERVER_TEST_PROPS, server_id='test_server_start_props_file') as client:\n        ann = client.annotate(EN_DOC, output_format=\"text\")\n        assert ann.strip() == EN_PROPS_FILE_GOLD.strip()\n\n\ndef test_lang_start():\n    \"\"\" Test starting the server with a Stanford CoreNLP language name \"\"\"\n    with corenlp.CoreNLPClient(properties='german', server_id='test_server_start_lang_name') as client:\n        ann = client.annotate(GERMAN_DOC, output_format='text')\n        compare_ignoring_whitespace(ann, GERMAN_FULL_PROPS_GOLD)\n\n\ndef test_python_dict():\n    \"\"\" Test starting the server with a Python dictionary as default properties \"\"\"\n    with corenlp.CoreNLPClient(properties=GERMAN_SMALL_PROPS, server_id='test_server_start_python_dict') as client:\n        ann = client.annotate(GERMAN_DOC, output_format='text')\n        assert ann.strip() == GERMAN_SMALL_PROPS_GOLD.strip()\n\n\ndef test_python_dict_w_annotators():\n    \"\"\" Test starting the server with a Python dictionary 
as default properties, override annotators \"\"\"\n    with corenlp.CoreNLPClient(properties=GERMAN_SMALL_PROPS, annotators=\"tokenize,ssplit\",\n                               server_id='test_server_start_python_dict_w_annotators') as client:\n        ann = client.annotate(GERMAN_DOC, output_format='text')\n        assert ann.strip() == GERMAN_SMALL_PROPS_W_ANNOTATORS_GOLD.strip()\n\n\ndef test_username_password():\n    \"\"\" Test starting a server with a username and password \"\"\"\n    with corenlp.CoreNLPClient(properties=USERNAME_PASS_PROPS, username='user-1234', password='1234',\n                               server_id=\"test_server_username_pass\") as client:\n        # check with correct password\n        ann = client.annotate(EN_DOC, output_format='text', username='user-1234', password='1234')\n        assert ann.strip() == USERNAME_PASS_GOLD.strip()\n        # check with incorrect password: the request must raise AnnotationException;\n        # pytest.raises fails the test if nothing or a different exception is raised\n        with pytest.raises(AnnotationException):\n            client.annotate(EN_DOC, output_format='text', username='user-1234', password='12345')\n\n\n"
  },
  {
    "path": "stanza/tests/server/test_ssurgeon.py",
    "content": "import pytest\n\nfrom stanza.tests import compare_ignoring_whitespace\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\nfrom stanza.utils.conll import CoNLL\nimport stanza.server.ssurgeon as ssurgeon\n\nSAMPLE_DOC_INPUT = \"\"\"\n# sent_id = 271\n# text = Hers is easy to clean.\n# previous = What did the dealer like about Alex's car?\n# comment = extraction/raising via \"tough extraction\" and clausal subject\n1\tHers\thers\tPRON\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnsubj\t_\t_\n2\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tcop\t_\t_\n3\teasy\teasy\tADJ\tJJ\tDegree=Pos\t0\troot\t_\t_\n4\tto\tto\tPART\tTO\t_\t5\tmark\t_\t_\n5\tclean\tclean\tVERB\tVB\tVerbForm=Inf\t3\tcsubj\t_\tSpaceAfter=No\n6\t.\t.\tPUNCT\t.\t_\t5\tpunct\t_\t_\n\"\"\"\n\nSAMPLE_DOC_EXPECTED = \"\"\"\n# sent_id = 271\n# text = Hers is easy to clean.\n# previous = What did the dealer like about Alex's car?\n# comment = extraction/raising via \"tough extraction\" and clausal subject\n1\tHers\thers\tPRON\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnsubj\t_\t_\n2\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tcop\t_\t_\n3\teasy\teasy\tADJ\tJJ\tDegree=Pos\t0\troot\t_\t_\n4\tto\tto\tPART\tTO\t_\t5\tmark\t_\t_\n5\tclean\tclean\tVERB\tVB\tVerbForm=Inf\t3\tadvcl\t_\tSpaceAfter=No\n6\t.\t.\tPUNCT\t.\t_\t5\tpunct\t_\t_\n\"\"\"\n\n\ndef test_ssurgeon_same_length():\n    semgrex_pattern = \"{}=source >nsubj {} >csubj=bad {}\"\n    ssurgeon_edits = [\"relabelNamedEdge -edge bad -reln advcl\"]\n\n    doc = CoNLL.conll2doc(input_str=SAMPLE_DOC_INPUT)\n\n    ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits)\n    updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False)\n\n    result = \"{:C}\".format(updated_doc)\n    #print(result)\n    #print(SAMPLE_DOC_EXPECTED)\n    
compare_ignoring_whitespace(result, SAMPLE_DOC_EXPECTED)\n\n\nADD_WORD_DOC_INPUT = \"\"\"\n# text = Jennifer has lovely antennae.\n# sent_id = 12\n# comment = if you're in to that kind of thing\n1\tJennifer\tJennifer\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t_\tstart_char=0|end_char=8|ner=S-PERSON\n2\thas\thave\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t_\tstart_char=9|end_char=12|ner=O\n3\tlovely\tlovely\tADJ\tJJ\tDegree=Pos\t4\tamod\t_\tstart_char=13|end_char=19|ner=O\n4\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t2\tobj\t_\tstart_char=20|end_char=28|ner=O|SpaceAfter=No\n5\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\tstart_char=28|end_char=29|ner=O\n\"\"\"\n\nADD_WORD_DOC_EXPECTED = \"\"\"\n# text = Jennifer has lovely blue antennae.\n# sent_id = 12\n# comment = if you're in to that kind of thing\n1\tJennifer\tJennifer\tPROPN\tNNP\tNumber=Sing\t2\tnsubj\t_\tner=S-PERSON\n2\thas\thave\tVERB\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t_\tner=O\n3\tlovely\tlovely\tADJ\tJJ\tDegree=Pos\t5\tamod\t_\tner=O\n4\tblue\tblue\tADJ\tJJ\t_\t5\tamod\t_\tner=O\n5\tantennae\tantenna\tNOUN\tNNS\tNumber=Plur\t2\tobj\t_\tSpaceAfter=No|ner=O\n6\t.\t.\tPUNCT\t.\t_\t2\tpunct\t_\tner=O\n\"\"\"\n\n\ndef test_ssurgeon_different_length():\n    semgrex_pattern = \"{word:antennae}=antennae !> {word:blue}\"\n    ssurgeon_edits = [\"addDep -gov antennae -reln amod -word blue -lemma blue -cpos ADJ -pos JJ -ner O -position -antennae -after \\\" \\\"\"]\n\n    doc = CoNLL.conll2doc(input_str=ADD_WORD_DOC_INPUT)\n    #print()\n    #print(\"{:C}\".format(doc))\n\n    ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits)\n    updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False)\n\n    result = \"{:C}\".format(updated_doc)\n    #print(result)\n    #print(ADD_WORD_DOC_EXPECTED)\n\n    compare_ignoring_whitespace(result, ADD_WORD_DOC_EXPECTED)\n\nBECOME_MWT_DOC_INPUT = \"\"\"\n# 
sent_id = 25\n# text = It's not yours!\n# comment = negation \n1\tIt\tit\tPRON\tPRP\tNumber=Sing|Person=2|PronType=Prs\t4\tnsubj\t_\tSpaceAfter=No\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\t_\n3\tnot\tnot\tPART\tRB\tPolarity=Neg\t4\tadvmod\t_\t_\n4\tyours\tyours\tPRON\tPRP\tGender=Neut|Number=Sing|Person=2|Poss=Yes|PronType=Prs\t0\troot\t_\tSpaceAfter=No\n5\t!\t!\tPUNCT\t.\t_\t4\tpunct\t_\t_\n\"\"\"\n\nBECOME_MWT_DOC_EXPECTED = \"\"\"\n# sent_id = 25\n# text = It's not yours!\n# comment = negation\n1-2\tIt's\t_\t_\t_\t_\t_\t_\t_\t_\n1\tIt\tit\tPRON\tPRP\tNumber=Sing|Person=2|PronType=Prs\t4\tnsubj\t_\t_\n2\t's\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\t_\n3\tnot\tnot\tPART\tRB\tPolarity=Neg\t4\tadvmod\t_\t_\n4\tyours\tyours\tPRON\tPRP\tGender=Neut|Number=Sing|Person=2|Poss=Yes|PronType=Prs\t0\troot\t_\tSpaceAfter=No\n5\t!\t!\tPUNCT\t.\t_\t4\tpunct\t_\t_\n\"\"\"\n\ndef test_ssurgeon_become_mwt():\n    \"\"\"\n    Test that converting a document, adding a new MWT, works as expected\n    \"\"\"\n    semgrex_pattern = \"{word:It}=it . 
{word:/'s/}=s\"\n    ssurgeon_edits = [\"EditNode -node it -is_mwt true  -is_first_mwt true  -mwt_text It's\",\n                      \"EditNode -node s  -is_mwt true  -is_first_mwt false -mwt_text It's\"]\n\n    doc = CoNLL.conll2doc(input_str=BECOME_MWT_DOC_INPUT)\n\n    ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits)\n    updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False)\n\n    result = \"{:C}\".format(updated_doc)\n    compare_ignoring_whitespace(result, BECOME_MWT_DOC_EXPECTED)\n\nEXISTING_MWT_DOC_INPUT = \"\"\"\n# sent_id = newsgroup-groups.google.com_GayMarriage_0ccbb50b41a5830b_ENG_20050321_181500-0005\n# text = One of “NCRC4ME’s”\n1\tOne\tone\tNUM\tCD\tNumType=Card\t0\troot\t0:root\t_\n2\tof\tof\tADP\tIN\t_\t4\tcase\t4:case\t_\n3\t“\t\"\tPUNCT\t``\t_\t4\tpunct\t4:punct\tSpaceAfter=No\n4-5\tNCRC4ME’s\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No\n4\tNCRC4ME\tNCRC4ME\tPROPN\tNNP\tNumber=Sing\t1\tcompound\t1:compound\t_\n5\t’s\t's\tPART\tPOS\t_\t4\tcase\t4:case\t_\n6\t”\t\"\tPUNCT\t''\t_\t4\tpunct\t4:punct\t_\n\"\"\"\n\n# TODO: also, we shouldn't lose the enhanced dependencies...\nEXISTING_MWT_DOC_EXPECTED = \"\"\"\n# sent_id = newsgroup-groups.google.com_GayMarriage_0ccbb50b41a5830b_ENG_20050321_181500-0005\n# text = One of “NCRC4ME’s”\n1\tOne\tone\tNUM\tCD\tNumType=Card\t0\troot\t_\t_\n2\tof\tof\tADP\tIN\t_\t4\tcase\t_\t_\n3\t“\t\"\tPUNCT\t``\t_\t4\tpunct\t_\tSpaceAfter=No\n4-5\tNCRC4ME’s\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No\n4\tNCRC4ME\tNCRC4ME\tPROPN\tNNP\tNumber=Sing\t1\tcompound\t_\t_\n5\t’s\t's\tPART\tPOS\t_\t4\tcase\t_\t_\n6\t”\t\"\tPUNCT\t''\t_\t4\tpunct\t_\t_\n\"\"\"\n\ndef test_ssurgeon_existing_mwt_no_change():\n    \"\"\"\n    Test that converting a document with an MWT works as expected\n\n    Note regarding this test:\n    Currently it works because ssurgeon.py doesn't look at the\n      \"changed\" flag because of a bug in EditNode in CoreNLP 4.5.3\n    If that is 
fixed, but the enhanced dependencies aren't fixed,\n      this test will fail because the enhanced dependencies *aren't*\n      removed.  Fixing the enhanced dependencies as well will fix\n      that, though.\n    \"\"\"\n    semgrex_pattern = \"{word:It}=it . {word:/'s/}=s\"\n    ssurgeon_edits = [\"EditNode -node it -is_mwt true  -is_first_mwt true  -mwt_text It's\",\n                      \"EditNode -node s  -is_mwt true  -is_first_mwt false -mwt_text It's\"]\n\n    doc = CoNLL.conll2doc(input_str=EXISTING_MWT_DOC_INPUT)\n\n    ssurgeon_response = ssurgeon.process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits)\n    updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False)\n\n    result = \"{:C}\".format(updated_doc)\n    compare_ignoring_whitespace(result, EXISTING_MWT_DOC_EXPECTED)\n\ndef check_empty_test(input_text, expected=None, echo=False):\n    if expected is None:\n        expected = input_text\n\n    doc = CoNLL.conll2doc(input_str=input_text)\n\n    # we don't want to edit this, just test the to/from conversion\n    ssurgeon_response = ssurgeon.process_doc(doc, [])\n    updated_doc = ssurgeon.convert_response_to_doc(doc, ssurgeon_response, add_missing_text=False)\n\n    result = \"{:C}\".format(updated_doc)\n    if echo:\n        print(\"INPUT\")\n        print(input_text)\n        print(\"EXPECTED\")\n        print(expected)\n        print(\"RESULT\")\n        print(result)\n    compare_ignoring_whitespace(result, expected)\n\nITALIAN_MWT_INPUT = \"\"\"\n# sent_id = train_78\n# text = @user dovrebbe fare pace col cervello\n# twittiro = 
IMPLICIT\tANALOGY\n1\t@user\t@user\tSYM\tSYM\t_\t3\tnsubj\t_\t_\n2\tdovrebbe\tdovere\tAUX\tVM\tMood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\taux\t_\t_\n3\tfare\tfare\tVERB\tV\tVerbForm=Inf\t0\troot\t_\t_\n4\tpace\tpace\tNOUN\tS\tGender=Fem|Number=Sing\t3\tobj\t_\t_\n5-6\tcol\t_\t_\t_\t_\t_\t_\t_\t_\n5\tcon\tcon\tADP\tE\t_\t7\tcase\t_\t_\n6\til\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t7\tdet\t_\t_\n7\tcervello\tcervello\tNOUN\tS\tGender=Masc|Number=Sing\t3\tobl\t_\t_\n\"\"\"\n\ndef test_ssurgeon_mwt_text():\n    \"\"\"\n    Test that an MWT which is split into pieces which don't make up\n    the original token results in a correct #text annotation\n\n    For example, in Italian, \"col\" splits into \"con il\", and we want\n    the #text to contain \"col\"\n    \"\"\"\n    check_empty_test(ITALIAN_MWT_INPUT)\n\nITALIAN_SPACES_AFTER_INPUT=\"\"\"\n# sent_id = train_1114\n# text = ““““ buona scuola ““““\n# twittiro = EXPLICIT\tOTHER\n1\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n2\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n3\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n4\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\t_\n5\tbuona\tbuono\tADJ\tA\tGender=Fem|Number=Sing\t6\tamod\t_\t_\n6\tscuola\tscuola\tNOUN\tS\tGender=Fem|Number=Sing\t0\troot\t_\t_\n7\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n8\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n9\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n10\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpacesAfter=\\\\n\n\"\"\"\n\nITALIAN_SPACES_AFTER_YES_INPUT=\"\"\"\n# sent_id = train_1114\n# text = ““““ buona scuola ““““\n# twittiro = 
EXPLICIT\tOTHER\n1\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n2\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n3\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n4\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=Yes\n5\tbuona\tbuono\tADJ\tA\tGender=Fem|Number=Sing\t6\tamod\t_\t_\n6\tscuola\tscuola\tNOUN\tS\tGender=Fem|Number=Sing\t0\troot\t_\t_\n7\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n8\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n9\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpaceAfter=No\n10\t“\t“\tPUNCT\tFB\t_\t6\tpunct\t_\tSpacesAfter=\\\\n\n\"\"\"\n\n\ndef test_ssurgeon_spaces_after_text():\n    \"\"\"\n    Test that SpacesAfter goes and comes back the same way\n\n    Tested using some random example from the UD_Italian-TWITTIRO dataset\n    \"\"\"\n    check_empty_test(ITALIAN_SPACES_AFTER_INPUT)\n\ndef test_ssurgeon_spaces_after_yes():\n    \"\"\"\n    Test that an unnecessary SpaceAfter=Yes is eliminated\n    \"\"\"\n    check_empty_test(ITALIAN_SPACES_AFTER_YES_INPUT, ITALIAN_SPACES_AFTER_INPUT)\n\nEMPTY_VALUES_INPUT = \"\"\"\n# text = Jennifer has lovely antennae.\n# sent_id = 12\n# comment = if you're in to that kind of thing\n1\tJennifer\t_\t_\t_\tNumber=Sing\t2\tnsubj\t_\tner=S-PERSON\n2\thas\t_\t_\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t0\troot\t_\tner=O\n3\tlovely\t_\t_\t_\tDegree=Pos\t4\tamod\t_\tner=O\n4\tantennae\t_\t_\t_\tNumber=Plur\t2\tobj\t_\tSpaceAfter=No|ner=O\n5\t.\t_\t_\t_\t_\t2\tpunct\t_\tner=O\n\"\"\"\n\ndef test_ssurgeon_blank_values():\n    \"\"\"\n    Check that various None fields such as lemma & xpos are not turned into blanks\n\n    Tests, like regulations, are often written in blood\n    \"\"\"\n    check_empty_test(EMPTY_VALUES_INPUT)\n\n# first couple sentences of UD_Cantonese-HK\n# we change the order of the misc column in word 3 to make sure the\n# pieces don't get unnecessarily reordered by ssurgeon\nCANTONESE_MISC_WORDS_INPUT = \"\"\"\n# sent_id = 1\n# text = 
你喺度搵乜嘢呀？\n1\t你\t你\tPRON\t_\t_\t3\tnsubj\t_\tTranslit=nei5|Gloss=2SG|SpaceAfter=No\n2\t喺度\t喺度\tADV\t_\t_\t3\tadvmod\t_\tTranslit=hai2dou6|Gloss=PROG|SpaceAfter=No\n3\t搵\t搵\tVERB\t_\t_\t0\troot\t_\tTranslit=wan2|Gloss=find|SpaceAfter=No\n4\t乜嘢\t乜嘢\tPRON\t_\t_\t3\tobj\t_\tTranslit=mat1je5|Gloss=what|SpaceAfter=No\n5\t呀\t呀\tPART\t_\t_\t3\tdiscourse:sp\t_\tTranslit=aa3|Gloss=SFP|SpaceAfter=No\n6\t？\t？\tPUNCT\t_\t_\t3\tpunct\t_\tSpaceAfter=No\n\n# sent_id = 2\n# text = 咪執返啲嘢去阿哥個新屋度囖。\n1\t咪\t咪\tADV\t_\t_\t2\tadvmod\t_\tSpaceAfter=No\n2\t執\t執\tVERB\t_\t_\t0\troot\t_\tSpaceAfter=No\n3\t返\t返\tVERB\t_\t_\t2\tcompound:dir\t_\tSpaceAfter=No\n4\t啲\t啲\tNOUN\t_\tNounType=Clf\t5\tclf:det\t_\tSpaceAfter=No\n5\t嘢\t嘢\tNOUN\t_\t_\t3\tobj\t_\tSpaceAfter=No\n6\t去\t去\tVERB\t_\t_\t2\tconj\t_\tSpaceAfter=No\n7\t阿哥\t阿哥\tNOUN\t_\t_\t10\tnmod\t_\tSpaceAfter=No\n8\t個\t個\tNOUN\t_\tNounType=Clf\t10\tclf:det\t_\tSpaceAfter=No\n9\t新\t新\tADJ\t_\t_\t10\tamod\t_\tSpaceAfter=No\n10\t屋\t屋\tNOUN\t_\t_\t6\tobj\t_\tSpaceAfter=No\n11\t度\t度\tADP\t_\t_\t10\tcase:loc\t_\tSpaceAfter=No\n12\t囖\t囖\tPART\t_\t_\t2\tdiscourse:sp\t_\tSpaceAfter=No\n13\t。\t。\tPUNCT\t_\t_\t2\tpunct\t_\tSpaceAfter=No\n\"\"\"\n\ndef test_ssurgeon_misc_words():\n    \"\"\"\n    Check that the pieces of the misc column are not unnecessarily reordered by ssurgeon\n\n    Tests, like regulations, are often written in blood\n    \"\"\"\n    check_empty_test(CANTONESE_MISC_WORDS_INPUT)\n\nITALIAN_MWT_SPACE_AFTER_INPUT = \"\"\"\n# sent_id = train_78\n# text = @user dovrebbe fare pace colcervello\n# twittiro = 
IMPLICIT\tANALOGY\n1\t@user\t@user\tSYM\tSYM\t_\t3\tnsubj\t_\t_\n2\tdovrebbe\tdovere\tAUX\tVM\tMood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\taux\t_\t_\n3\tfare\tfare\tVERB\tV\tVerbForm=Inf\t0\troot\t_\t_\n4\tpace\tpace\tNOUN\tS\tGender=Fem|Number=Sing\t3\tobj\t_\t_\n5-6\tcol\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No\n5\tcon\tcon\tADP\tE\t_\t7\tcase\t_\t_\n6\til\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t7\tdet\t_\t_\n7\tcervello\tcervello\tNOUN\tS\tGender=Masc|Number=Sing\t3\tobl\t_\tRandomFeature=foo\n\"\"\"\n\ndef test_ssurgeon_mwt_space_after():\n    \"\"\"\n    Check the SpaceAfter=No on an MWT (rather than a word)\n\n    the RandomFeature=foo is on account of a silly bug in the initial\n    version of passing in MWT misc features\n    \"\"\"\n    check_empty_test(ITALIAN_MWT_SPACE_AFTER_INPUT)\n\nITALIAN_MWT_MISC_INPUT = \"\"\"\n# sent_id = train_78\n# text = @user dovrebbe farepacecolcervello\n# twittiro = IMPLICIT\tANALOGY\n1\t@user\t@user\tSYM\tSYM\t_\t3\tnsubj\t_\t_\n2\tdovrebbe\tdovere\tAUX\tVM\tMood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\taux\t_\t_\n3-4\tfarepace\t_\t_\t_\t_\t_\t_\t_\tPlayers=GonnaPlay|SpaceAfter=No\n3\tfare\tfare\tVERB\tV\tVerbForm=Inf\t0\troot\t_\t_\n4\tpace\tpace\tNOUN\tS\tGender=Fem|Number=Sing\t3\tobj\t_\t_\n5-6\tcol\t_\t_\t_\t_\t_\t_\t_\tHaters=GonnaHate|SpaceAfter=No\n5\tcon\tcon\tADP\tE\t_\t7\tcase\t_\t_\n6\til\til\tDET\tRD\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t7\tdet\t_\t_\n7\tcervello\tcervello\tNOUN\tS\tGender=Masc|Number=Sing\t3\tobl\t_\tRandomFeature=foo\n\"\"\"\n\ndef test_ssurgeon_mwt_misc():\n    \"\"\"\n    Check the SpaceAfter=No on an MWT (rather than a word)\n\n    the RandomFeature=foo is on account of a silly bug in the initial\n    version of passing in MWT misc features\n    \"\"\"\n    check_empty_test(ITALIAN_MWT_MISC_INPUT)\n\nSINDHI_ROOT_EXAMPLE = \"\"\"\n# sent_id = 1\n# text = غلام رهڻ سان ماڻهو منافق ٿئي ٿو 
.\n1\tغلام\tغلام\tNOUN\tNN__اسم\tCase=Acc|Gender=Masc|Number=Sing|Person=3\t2\tcompound\t_\t_\n2\tرهڻ\tره\tVERB\tVB__فعل\tNumber=Sing\t6\tadvcl\t_\t_\n3\tسان\tسان\tADP\tIN__حرفِ_جر\tNumber=Sing\t2\tmark\t_\t_\n4\tماڻهو\tماڻهو\tNOUN\tNN__اسم\tCase=Nom|Gender=Masc|Number=Sing|Person=3\t6\tnsubj\t_\t_\n5\tمنافق\tمنافق\tADJ\tJJ__صفت\tCase=Acc|Number=Sing|Person=3\t6\txcomp\t_\t_\n6\tٿئي\tٿي\tVERB\tVB__فعل\tNumber=Sing\t_\t_\t_\t_\n7\tٿو\tٿو\tAUX\tVB__فعل\tNumber=Sing\t6\taux\t_\t_\n8\t.\t.\tPUNCT\t-__پورو_دم\t_\t6\tpunct\t_\t_\n\"\"\".lstrip()\n\nSINDHI_ROOT_EXPECTED = \"\"\"\n# sent_id = 1\n# text = غلام رهڻ سان ماڻهو منافق ٿئي ٿو .\n1\tغلام\tغلام\tNOUN\tNN__اسم\tCase=Acc|Gender=Masc|Number=Sing|Person=3\t2\tcompound\t_\t_\n2\tرهڻ\tره\tVERB\tVB__فعل\tNumber=Sing\t6\tadvcl\t_\t_\n3\tسان\tسان\tADP\tIN__حرفِ_جر\tNumber=Sing\t2\tmark\t_\t_\n4\tماڻهو\tماڻهو\tNOUN\tNN__اسم\tCase=Nom|Gender=Masc|Number=Sing|Person=3\t6\tnsubj\t_\t_\n5\tمنافق\tمنافق\tADJ\tJJ__صفت\tCase=Acc|Number=Sing|Person=3\t6\txcomp\t_\t_\n6\tٿئي\tٿي\tVERB\tVB__فعل\tNumber=Sing\t0\troot\t_\t_\n7\tٿو\tٿو\tAUX\tVB__فعل\tNumber=Sing\t6\taux\t_\t_\n8\t.\t.\tPUNCT\t-__پورو_دم\t_\t6\tpunct\t_\t_\n\"\"\".strip()\n\nSINDHI_EDIT = \"\"\"\n{}=root !< {}\nsetRoots root\n\"\"\"\n\ndef test_ssurgeon_rewrite_sindhi_roots():\n    \"\"\"\n    A user / contributor sent a dependency file with blank roots\n    \"\"\"\n    edits = ssurgeon.parse_ssurgeon_edits(SINDHI_EDIT)\n    expected_edits = [ssurgeon.SsurgeonEdit(semgrex_pattern='{}=root !< {}',\n                                            ssurgeon_edits=['setRoots root'],\n                                            ssurgeon_id='1', notes='', language='UniversalEnglish')]\n    assert edits == expected_edits\n\n    blank_dep_doc = CoNLL.conll2doc(input_str=SINDHI_ROOT_EXAMPLE)\n    # test that the conversion will work w/o crashing, such as because of a missing root edge\n    request = ssurgeon.build_request(blank_dep_doc, edits)\n\n    response = 
ssurgeon.process_doc(blank_dep_doc, edits)\n    updated_doc = ssurgeon.convert_response_to_doc(blank_dep_doc, response, add_missing_text=False)\n\n    result = \"{:C}\".format(updated_doc)\n    assert result == SINDHI_ROOT_EXPECTED\n"
  },
  {
    "path": "stanza/tests/server/test_tokensregex.py",
    "content": "import pytest\nfrom stanza.tests import *\n\nfrom stanza.models.common.doc import Document\nimport stanza.server.tokensregex as tokensregex\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\nfrom stanza.tests.server.test_semgrex import ONE_SENTENCE_DOC, TWO_SENTENCE_DOC\n\ndef test_single_sentence():\n    #expected:\n    #match {\n    #  sentence: 0\n    #  match {\n    #    text: \"Opal\"\n    #    begin: 2\n    #    end: 3\n    #  }\n    #}\n\n    response = tokensregex.process_doc(ONE_SENTENCE_DOC, \"Opal\")\n    assert len(response.match) == 1\n    assert len(response.match[0].match) == 1\n    assert response.match[0].match[0].sentence == 0\n    assert response.match[0].match[0].match.text == \"Opal\"\n    assert response.match[0].match[0].match.begin == 2\n    assert response.match[0].match[0].match.end == 3\n\n\ndef test_ner_sentence():\n    #expected:\n    #match {\n    #  sentence: 0\n    #  match {\n    #    text: \"Opal\"\n    #    begin: 2\n    #    end: 3\n    #  }\n    #}\n\n    response = tokensregex.process_doc(ONE_SENTENCE_DOC, \"[ner: GEM]\")\n    assert len(response.match) == 1\n    assert len(response.match[0].match) == 1\n    assert response.match[0].match[0].sentence == 0\n    assert response.match[0].match[0].match.text == \"Opal\"\n    assert response.match[0].match[0].match.begin == 2\n    assert response.match[0].match[0].match.end == 3\n"
  },
  {
    "path": "stanza/tests/server/test_tsurgeon.py",
    "content": "\"\"\"\nTest the semgrex interface\n\"\"\"\n\nimport pytest\nimport stanza\nfrom stanza.models.constituency import tree_reader\nfrom stanza.server.tsurgeon import process_trees, Tsurgeon\n\nfrom stanza.tests import *\n\npytestmark = [pytest.mark.travis, pytest.mark.client]\n\n\n\ndef test_simple():\n    text=\"( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n    trees = tree_reader.read_trees(text)\n\n    tregex = \"WP=wp\"\n    tsurgeon = \"relabel wp WWWPPP\"\n    result = process_trees(trees, (tregex, tsurgeon))\n    assert len(result) == 1\n    assert str(result[0]) == \"(ROOT (SBARQ (WHNP (WWWPPP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n\ndef test_context():\n    \"\"\"\n    Processing the same thing twice should work twice...\n    \"\"\"\n    with Tsurgeon() as processor:\n        text=\"( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n        trees = tree_reader.read_trees(text)\n\n        tregex = \"WP=wp\"\n        tsurgeon = \"relabel wp WWWPPP\"\n        result = processor.process(trees, (tregex, tsurgeon))\n        assert len(result) == 1\n        assert str(result[0]) == \"(ROOT (SBARQ (WHNP (WWWPPP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))\"\n\n        result = processor.process(trees, (tregex, tsurgeon))\n        assert len(result) == 1\n        assert str(result[0]) == \"(ROOT (SBARQ (WHNP (WWWPPP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))\"\n\n\ndef test_arboretum():\n    \"\"\"\n    Test a couple expressions used when processing the Arboretum treebank\n\n    That particular treebank was the original inspiration for adding the Tsurgeon interface\n    \"\"\"\n    with Tsurgeon() as processor:\n        text = \"(s (par (fcl (n s1_1) (vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))\"\n        expected = \"(s (par (fcl (n s1_1) (vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4)) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))\"\n        trees = tree_reader.read_trees(text)\n\n        tregex = \"s1_4 > (__=home > (__=parent > __=grandparent)) . 
(s1_3 > (__=move > =grandparent))\"\n        tsurgeon = \"move move $+ home\"\n        result = processor.process(trees, (tregex, tsurgeon))\n        assert len(result) == 1\n        assert str(result[0]) == expected\n\n\n        text = \"(s (par (fcl (n s1_1) (vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))\"\n        expected = \"(s (par (fcl (n s1_1) (vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4)) (np (pron-poss s1_5) (n s1_6) (pp (prp s1_7) (n s1_8)))) (pu s1_9) (conj-c s1_10) (fcl (adv s1_11) (v-fin s1_12) (np (prop s1_13) (pp (prp s1_14) (prop s1_15))) (np (art s1_16) (adjp (adv s1_17) (adj s1_18)) (n s1_19) (pp (prp s1_20) (np (pron-poss s1_21) (adj s1_22) (n s1_23) (prop s1_24))))) (pu s1_25)))\"\n        trees = tree_reader.read_trees(text)\n\n        tregex = \"s1_4 > (__=home > (__=parent $+ (__=move <<, s1_3 <<- s1_3)))\"\n        tsurgeon = \"move move $+ home\"\n        result = processor.process(trees, (tregex, tsurgeon))\n        assert len(result) == 1\n        assert str(result[0]) == expected\n"
  },
  {
    "path": "stanza/tests/server/test_ud_enhancer.py",
    "content": "import pytest\nimport stanza\nfrom stanza.tests import *\n\nfrom stanza.models.common.doc import Document\nimport stanza.server.ud_enhancer as ud_enhancer\n\npytestmark = [pytest.mark.pipeline]\n\ndef check_edges(graph, source, target, num, isExtra=None):\n    edges = [edge for edge in graph.edge if edge.source == source and edge.target == target]\n    assert len(edges) == num\n    if num == 1:\n        assert edges[0].isExtra == isExtra\n\ndef test_one_sentence():\n    nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors=\"tokenize,pos,lemma,depparse\")\n    doc = nlp(\"This is the car that I bought\")\n    result = ud_enhancer.process_doc(doc, language=\"en\", pronouns_pattern=None)\n\n    assert len(result.sentence) == 1\n    sentence = result.sentence[0]\n\n    basic = sentence.basicDependencies\n    assert len(basic.node) == 7\n    assert len(basic.edge) == 6\n    check_edges(basic, 4, 7, 1, False)\n    check_edges(basic, 7, 4, 0)\n\n    enhanced = sentence.enhancedDependencies\n    assert len(enhanced.node) == 7\n    assert len(enhanced.edge) == 7\n    check_edges(enhanced, 4, 7, 1, False)\n    # this is the new edge\n    check_edges(enhanced, 7, 4, 1, True)\n"
  },
  {
    "path": "stanza/tests/setup.py",
    "content": "import glob\nimport logging\nimport os\nimport shutil\nimport stanza\nfrom stanza.resources import installation\nfrom stanza.tests import TEST_HOME_VAR, TEST_WORKING_DIR\n\nlogger = logging.getLogger('stanza')\n\ntest_dir = os.getenv(TEST_HOME_VAR, None)\nif not test_dir:\n    test_dir = TEST_WORKING_DIR\n    logger.info(\"STANZA_TEST_HOME not set.  Will assume %s\", test_dir)\n    logger.info(\"To use a different directory, export or set STANZA_TEST_HOME=...\")\n\nin_dir = os.path.join(test_dir, \"in\")\nout_dir = os.path.join(test_dir, \"out\")\nscripts_dir = os.path.join(test_dir, \"scripts\")\nmodels_dir=os.path.join(test_dir, \"models\")\ncorenlp_dir=os.path.join(test_dir, \"corenlp_dir\")\n\nos.makedirs(test_dir, exist_ok=True)\nos.makedirs(in_dir, exist_ok=True)\nos.makedirs(out_dir, exist_ok=True)\nos.makedirs(scripts_dir, exist_ok=True)\nos.makedirs(models_dir, exist_ok=True)\nos.makedirs(corenlp_dir, exist_ok=True)\n\nlogger.info(\"COPYING FILES\")\n\nshutil.copy(\"stanza/tests/data/external_server.properties\", scripts_dir)\nshutil.copy(\"stanza/tests/data/example_french.json\", out_dir)\nshutil.copy(\"stanza/tests/data/aws_annotations.zip\", in_dir)\nfor emb_file in glob.glob(\"stanza/tests/data/tiny_emb.*\"):\n    shutil.copy(emb_file, in_dir)\n\nlogger.info(\"DOWNLOADING MODELS\")\n\nstanza.download(lang='en', model_dir=models_dir, logging_level='info')\nstanza.download(lang=\"en\", model_dir=models_dir, package=None, processors={\"ner\":\"ncbi_disease\"})\nstanza.download(lang='fr', model_dir=models_dir, logging_level='info')\n# Latin ITTB has no case information for the lemmatizer\nstanza.download(lang='he', model_dir=models_dir, processors='tokenize', logging_level='info')\nstanza.download(lang='la', model_dir=models_dir, package='ittb', logging_level='info')\nstanza.download(lang='zh', model_dir=models_dir, logging_level='info')\n# useful not just for verifying RtL, but because the default Arabic has a unique style of xpos 
tags\nstanza.download(lang='ar', model_dir=models_dir, logging_level='info')\nstanza.download(lang='multilingual', model_dir=models_dir, logging_level='info')\n\nlogger.info(\"DOWNLOADING STANZA TOKENIZERS FOR MORPHSEG TESTS\")\n\nmorphseg_langs = ['en', 'es', 'ru', 'fr', 'it', 'cs', 'hu', 'la']\nfor lang in morphseg_langs:\n    stanza.download(lang=lang, model_dir=models_dir, processors='tokenize', logging_level='info')\n    logger.info(f\"Downloaded {lang} tokenizer for morphseg tests\")\n\nlogger.info(\"DOWNLOADING CORENLP\")\n\ninstallation.install_corenlp(dir=corenlp_dir)\ninstallation.download_corenlp_models(model=\"french\", version=\"main\", dir=corenlp_dir)\ninstallation.download_corenlp_models(model=\"german\", version=\"main\", dir=corenlp_dir)\ninstallation.download_corenlp_models(model=\"italian\", version=\"main\", dir=corenlp_dir)\ninstallation.download_corenlp_models(model=\"spanish\", version=\"main\", dir=corenlp_dir)\n\nlogger.info(\"Test setup completed.\")\n"
  },
  {
    "path": "stanza/tests/tokenization/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/tests/tokenization/test_prepare_tokenizer_treebank.py",
    "content": "import pytest\nimport stanza\nfrom stanza.tests import *\n\nfrom stanza.utils.datasets import prepare_tokenizer_treebank\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_has_space_after_no():\n    assert prepare_tokenizer_treebank.has_space_after_no(\"SpaceAfter=No\")\n    assert prepare_tokenizer_treebank.has_space_after_no(\"UnbanMoxOpal=Yes|SpaceAfter=No\")\n    assert prepare_tokenizer_treebank.has_space_after_no(\"SpaceAfter=No|UnbanMoxOpal=Yes\")\n    assert not prepare_tokenizer_treebank.has_space_after_no(\"SpaceAfter=Yes\")\n    assert not prepare_tokenizer_treebank.has_space_after_no(\"CorrectSpaceAfter=No\")\n    assert not prepare_tokenizer_treebank.has_space_after_no(\"_\")\n\n\ndef test_add_space_after_no():\n    assert prepare_tokenizer_treebank.add_space_after_no(\"_\") == \"SpaceAfter=No\"\n    assert prepare_tokenizer_treebank.add_space_after_no(\"MoxOpal=Unban\") == \"MoxOpal=Unban|SpaceAfter=No\"\n    with pytest.raises(ValueError):\n        prepare_tokenizer_treebank.add_space_after_no(\"SpaceAfter=No\")\n\ndef test_remove_space_after_no():\n    assert prepare_tokenizer_treebank.remove_space_after_no(\"SpaceAfter=No\") == \"_\"\n    assert prepare_tokenizer_treebank.remove_space_after_no(\"SpaceAfter=No|MoxOpal=Unban\") == \"MoxOpal=Unban\"\n    assert prepare_tokenizer_treebank.remove_space_after_no(\"MoxOpal=Unban|SpaceAfter=No\") == \"MoxOpal=Unban\"\n    with pytest.raises(ValueError):\n        prepare_tokenizer_treebank.remove_space_after_no(\"_\")\n\ndef read_test_doc(doc):\n    sentences = [x.strip().split(\"\\n\") for x in doc.split(\"\\n\\n\")]\n    return sentences\n\n\nSPANISH_QM_TEST_CASE = \"\"\"\n# sent_id = train-s7914\n# text = ¿Cómo explicarles entonces que el mar tiene varios dueños y que a partir de la frontera de aquella ola el pescado ya no es tuyo?.\n# orig_file_sentence 080#14\n# this sentence will have the intiial ¿ removed.  
an MWT should be preserved\n1\t¿\t¿\tPUNCT\t_\tPunctSide=Ini|PunctType=Qest\t3\tpunct\t_\tSpaceAfter=No\n2\tCómo\tcómo\tPRON\t_\tPronType=Ind\t3\tobl\t_\t_\n3-4\texplicarles\t_\t_\t_\t_\t_\t_\t_\t_\n3\texplicar\texplicar\tVERB\t_\tVerbForm=Inf\t0\troot\t_\t_\n4\tles\tél\tPRON\t_\tCase=Dat|Number=Plur|Person=3|PronType=Prs\t3\tobj\t_\t_\n5\tentonces\tentonces\tADV\t_\t_\t3\tadvmod\t_\t_\n6\tque\tque\tSCONJ\t_\t_\t9\tmark\t_\t_\n7\tel\tel\tDET\t_\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t8\tdet\t_\t_\n8\tmar\tmar\tNOUN\t_\tNumber=Sing\t9\tnsubj\t_\t_\n9\ttiene\ttener\tVERB\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tccomp\t_\t_\n10\tvarios\tvarios\tDET\t_\tGender=Masc|Number=Plur|PronType=Ind\t11\tdet\t_\t_\n11\tdueños\tdueño\tNOUN\t_\tGender=Masc|Number=Plur\t9\tobj\t_\t_\n12\ty\ty\tCCONJ\t_\t_\t27\tcc\t_\t_\n13\tque\tque\tSCONJ\t_\t_\t27\tmark\t_\t_\n14\ta\ta\tADP\t_\t_\t18\tcase\t_\tMWE=a_partir_de|MWEPOS=ADP\n15\tpartir\tpartir\tNOUN\t_\t_\t14\tfixed\t_\t_\n16\tde\tde\tADP\t_\t_\t14\tfixed\t_\t_\n17\tla\tel\tDET\t_\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t18\tdet\t_\t_\n18\tfrontera\tfrontera\tNOUN\t_\tGender=Fem|Number=Sing\t27\tobl\t_\t_\n19\tde\tde\tADP\t_\t_\t21\tcase\t_\t_\n20\taquella\taquel\tDET\t_\tGender=Fem|Number=Sing|PronType=Dem\t21\tdet\t_\t_\n21\tola\tola\tNOUN\t_\tGender=Fem|Number=Sing\t18\tnmod\t_\t_\n22\tel\tel\tDET\t_\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t23\tdet\t_\t_\n23\tpescado\tpescado\tNOUN\t_\tGender=Masc|Number=Sing\t27\tnsubj\t_\t_\n24\tya\tya\tADV\t_\t_\t27\tadvmod\t_\t_\n25\tno\tno\tADV\t_\tPolarity=Neg\t27\tadvmod\t_\t_\n26\tes\tser\tAUX\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t27\tcop\t_\t_\n27\ttuyo\ttuyo\tPRON\t_\tGender=Masc|Number=Sing|Number[psor]=Sing|Person=2|Poss=Yes|PronType=Ind\t9\tconj\t_\tSpaceAfter=No\n28\t?\t?\tPUNCT\t_\tPunctSide=Fin|PunctType=Qest\t3\tpunct\t_\tSpaceAfter=No\n29\t.\t.\tPUNCT\t_\tPunctType=Peri\t3\tpunct\t_\t_\n\n# sent_id = 
train-s8516\n# text = ¿ Pero es divertido en la vida real? - -.\n# orig_file_sentence 086#16\n# this sentence will have the ¿ removed even with no SpaceAfter=No\n1\t¿\t¿\tPUNCT\t_\tPunctSide=Ini|PunctType=Qest\t4\tpunct\t_\t_\n2\tPero\tpero\tCCONJ\t_\t_\t4\tadvmod\t_\t_\n3\tes\tser\tAUX\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t4\tcop\t_\t_\n4\tdivertido\tdivertido\tADJ\t_\tGender=Masc|Number=Sing|VerbForm=Part\t0\troot\t_\t_\n5\ten\ten\tADP\t_\t_\t7\tcase\t_\t_\n6\tla\tel\tDET\t_\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t7\tdet\t_\t_\n7\tvida\tvida\tNOUN\t_\tGender=Fem|Number=Sing\t4\tobl\t_\t_\n8\treal\treal\tADJ\t_\tNumber=Sing\t7\tamod\t_\tSpaceAfter=No\n9\t?\t?\tPUNCT\t_\tPunctSide=Fin|PunctType=Qest\t4\tpunct\t_\t_\n10\t-\t-\tPUNCT\t_\tPunctType=Dash\t4\tpunct\t_\t_\n11\t-\t-\tPUNCT\t_\tPunctType=Dash\t4\tpunct\t_\tSpaceAfter=No\n12\t.\t.\tPUNCT\t_\tPunctType=Peri\t4\tpunct\t_\t_\n\n# sent_id = train-s2337\n# text = Es imposible.\n# orig_file_sentence 024#37\n# Also included is a sentence which should be skipped (note that it does not show up in the expected result)\n1\tEs\tser\tAUX\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t2\tcop\t_\t_\n2\timposible\timposible\tADJ\t_\tNumber=Sing\t0\troot\t_\tSpaceAfter=No\n3\t.\t.\tPUNCT\t_\tPunctType=Peri\t2\tpunct\t_\t_\n\n# sent_id = 3LB-CAST-a1-2-s6\n# text = ¿Para qué seguir?\n# orig_file_sentence 006#22\n# The treebank now includes basic dependencies in the additional dependencies column\n1\t¿\t¿\tPUNCT\tfia\tPunctSide=Ini|PunctType=Qest\t4\tpunct\t4:punct\tSpaceAfter=No\n2\tPara\tpara\tADP\tsps00\t_\t3\tcase\t3:case\t_\n3\tqué\tqué\tPRON\tpt0cs000\tNumber=Sing|PronType=Int,Rel\t4\tobl\t4:obl\t_\n4\tseguir\tseguir\tVERB\tvmn0000\tVerbForm=Inf\t0\troot\t0:root\tSpaceAfter=No\n5\t?\t?\tPUNCT\tfit\tPunctSide=Fin|PunctType=Qest\t4\tpunct\t4:punct\t_\n\n# sent_id = CESS-CAST-P-19990901-16-s19\n# text = ¿Estará fingiendo?.\n# orig_file_sentence 097#24\n# also it includes some copy 
nodes\n1\t¿\t¿\tPUNCT\tfia\tPunctSide=Ini|PunctType=Qest\t3\tpunct\t3:punct\tSpaceAfter=No\n2\tEstará\testar\tAUX\tvmif3s0\tMood=Ind|Number=Sing|Person=3|Tense=Fut|VerbForm=Fin\t3\taux\t3:aux\t_\n3\tfingiendo\tfingir\tVERB\tvmg0000\tVerbForm=Ger\t0\troot\t0:root\tSpaceAfter=No\n3.1\t_\t_\tPRON\tp\t_\t_\t_\t3:nsubj\tEntity=(CESSCASTP1999090116c2-person-1-CorefType:ident,gstype:spec)\n4\t?\t?\tPUNCT\tfit\tPunctSide=Fin|PunctType=Qest\t3\tpunct\t3:punct\tSpaceAfter=No\n5\t.\t.\tPUNCT\tfp\tPunctType=Peri\t3\tpunct\t3:punct\t_\n\n# sent_id = CESS-CAST-P-20000401-126-s31\n# text = ¿Qué pensó cuando se quedó\n# orig_file_sentence 087#37\n# this one has a colon in the dependency name\n1\t¿\t¿\tPUNCT\tfia\tPunctSide=Ini|PunctType=Qest\t3\tpunct\t3:punct\tSpaceAfter=No|Entity=(CESSCASTP20000401126c27--3\n2\tQué\tqué\tPRON\tpt0cs000\tNumber=Sing|PronType=Int,Rel\t3\tobj\t3:obj\t_\n3\tpensó\tpensar\tVERB\tvmis3s0\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n3.1\t_\t_\tPRON\tp\t_\t_\t_\t3:nsubj\tEntity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec)\n4\tcuando\tcuando\tSCONJ\tcs\t_\t6\tmark\t6:mark\t_\n4.1\t_\t_\tPRON\tp\t_\t_\t_\t6:nsubj\tEntity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec)\n5\tse\tél\tPRON\tp0300000\tCase=Acc|Person=3|PrepCase=Npr|PronType=Prs|Reflex=Yes\t6\texpl:pv\t6:expl:pv\t_\n6\tquedó\tquedar\tVERB\tvmis3s0\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t3\tadvcl\t3:advcl\t_\n\"\"\"\n\nSPANISH_QM_RESULT = \"\"\"\n# sent_id = train-s7914\n# text = Cómo explicarles entonces que el mar tiene varios dueños y que a partir de la frontera de aquella ola el pescado ya no es tuyo?.\n# orig_file_sentence 080#14\n# this sentence will have the intiial ¿ removed.  
an MWT should be preserved\n1\tCómo\tcómo\tPRON\t_\tPronType=Ind\t2\tobl\t_\t_\n2-3\texplicarles\t_\t_\t_\t_\t_\t_\t_\t_\n2\texplicar\texplicar\tVERB\t_\tVerbForm=Inf\t0\troot\t_\t_\n3\tles\tél\tPRON\t_\tCase=Dat|Number=Plur|Person=3|PronType=Prs\t2\tobj\t_\t_\n4\tentonces\tentonces\tADV\t_\t_\t2\tadvmod\t_\t_\n5\tque\tque\tSCONJ\t_\t_\t8\tmark\t_\t_\n6\tel\tel\tDET\t_\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t7\tdet\t_\t_\n7\tmar\tmar\tNOUN\t_\tNumber=Sing\t8\tnsubj\t_\t_\n8\ttiene\ttener\tVERB\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t2\tccomp\t_\t_\n9\tvarios\tvarios\tDET\t_\tGender=Masc|Number=Plur|PronType=Ind\t10\tdet\t_\t_\n10\tdueños\tdueño\tNOUN\t_\tGender=Masc|Number=Plur\t8\tobj\t_\t_\n11\ty\ty\tCCONJ\t_\t_\t26\tcc\t_\t_\n12\tque\tque\tSCONJ\t_\t_\t26\tmark\t_\t_\n13\ta\ta\tADP\t_\t_\t17\tcase\t_\tMWE=a_partir_de|MWEPOS=ADP\n14\tpartir\tpartir\tNOUN\t_\t_\t13\tfixed\t_\t_\n15\tde\tde\tADP\t_\t_\t13\tfixed\t_\t_\n16\tla\tel\tDET\t_\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t17\tdet\t_\t_\n17\tfrontera\tfrontera\tNOUN\t_\tGender=Fem|Number=Sing\t26\tobl\t_\t_\n18\tde\tde\tADP\t_\t_\t20\tcase\t_\t_\n19\taquella\taquel\tDET\t_\tGender=Fem|Number=Sing|PronType=Dem\t20\tdet\t_\t_\n20\tola\tola\tNOUN\t_\tGender=Fem|Number=Sing\t17\tnmod\t_\t_\n21\tel\tel\tDET\t_\tDefinite=Def|Gender=Masc|Number=Sing|PronType=Art\t22\tdet\t_\t_\n22\tpescado\tpescado\tNOUN\t_\tGender=Masc|Number=Sing\t26\tnsubj\t_\t_\n23\tya\tya\tADV\t_\t_\t26\tadvmod\t_\t_\n24\tno\tno\tADV\t_\tPolarity=Neg\t26\tadvmod\t_\t_\n25\tes\tser\tAUX\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t26\tcop\t_\t_\n26\ttuyo\ttuyo\tPRON\t_\tGender=Masc|Number=Sing|Number[psor]=Sing|Person=2|Poss=Yes|PronType=Ind\t8\tconj\t_\tSpaceAfter=No\n27\t?\t?\tPUNCT\t_\tPunctSide=Fin|PunctType=Qest\t2\tpunct\t_\tSpaceAfter=No\n28\t.\t.\tPUNCT\t_\tPunctType=Peri\t2\tpunct\t_\t_\n\n# sent_id = train-s8516\n# text = Pero es divertido en la vida real? 
- -.\n# orig_file_sentence 086#16\n# this sentence will have the ¿ removed even with no SpaceAfter=No\n1\tPero\tpero\tCCONJ\t_\t_\t3\tadvmod\t_\t_\n2\tes\tser\tAUX\t_\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tcop\t_\t_\n3\tdivertido\tdivertido\tADJ\t_\tGender=Masc|Number=Sing|VerbForm=Part\t0\troot\t_\t_\n4\ten\ten\tADP\t_\t_\t6\tcase\t_\t_\n5\tla\tel\tDET\t_\tDefinite=Def|Gender=Fem|Number=Sing|PronType=Art\t6\tdet\t_\t_\n6\tvida\tvida\tNOUN\t_\tGender=Fem|Number=Sing\t3\tobl\t_\t_\n7\treal\treal\tADJ\t_\tNumber=Sing\t6\tamod\t_\tSpaceAfter=No\n8\t?\t?\tPUNCT\t_\tPunctSide=Fin|PunctType=Qest\t3\tpunct\t_\t_\n9\t-\t-\tPUNCT\t_\tPunctType=Dash\t3\tpunct\t_\t_\n10\t-\t-\tPUNCT\t_\tPunctType=Dash\t3\tpunct\t_\tSpaceAfter=No\n11\t.\t.\tPUNCT\t_\tPunctType=Peri\t3\tpunct\t_\t_\n\n# sent_id = 3LB-CAST-a1-2-s6\n# text = Para qué seguir?\n# orig_file_sentence 006#22\n# The treebank now includes basic dependencies in the additional dependencies column\n1\tPara\tpara\tADP\tsps00\t_\t2\tcase\t2:case\t_\n2\tqué\tqué\tPRON\tpt0cs000\tNumber=Sing|PronType=Int,Rel\t3\tobl\t3:obl\t_\n3\tseguir\tseguir\tVERB\tvmn0000\tVerbForm=Inf\t0\troot\t0:root\tSpaceAfter=No\n4\t?\t?\tPUNCT\tfit\tPunctSide=Fin|PunctType=Qest\t3\tpunct\t3:punct\t_\n\n# sent_id = CESS-CAST-P-19990901-16-s19\n# text = Estará fingiendo?.\n# orig_file_sentence 097#24\n# also it includes some copy nodes\n1\tEstará\testar\tAUX\tvmif3s0\tMood=Ind|Number=Sing|Person=3|Tense=Fut|VerbForm=Fin\t2\taux\t2:aux\t_\n2\tfingiendo\tfingir\tVERB\tvmg0000\tVerbForm=Ger\t0\troot\t0:root\tSpaceAfter=No\n2.1\t_\t_\tPRON\tp\t_\t_\t_\t2:nsubj\tEntity=(CESSCASTP1999090116c2-person-1-CorefType:ident,gstype:spec)\n3\t?\t?\tPUNCT\tfit\tPunctSide=Fin|PunctType=Qest\t2\tpunct\t2:punct\tSpaceAfter=No\n4\t.\t.\tPUNCT\tfp\tPunctType=Peri\t2\tpunct\t2:punct\t_\n\n# sent_id = CESS-CAST-P-20000401-126-s31\n# text = Qué pensó cuando se quedó\n# orig_file_sentence 087#37\n# this one has a colon in the dependency 
name\n1\tQué\tqué\tPRON\tpt0cs000\tNumber=Sing|PronType=Int,Rel\t2\tobj\t2:obj\t_\n2\tpensó\tpensar\tVERB\tvmis3s0\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n2.1\t_\t_\tPRON\tp\t_\t_\t_\t2:nsubj\tEntity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec)\n3\tcuando\tcuando\tSCONJ\tcs\t_\t5\tmark\t5:mark\t_\n3.1\t_\t_\tPRON\tp\t_\t_\t_\t5:nsubj\tEntity=(CESSCASTP20000401126c1-person-1-CorefType:ident,gstype:spec)\n4\tse\tél\tPRON\tp0300000\tCase=Acc|Person=3|PrepCase=Npr|PronType=Prs|Reflex=Yes\t5\texpl:pv\t5:expl:pv\t_\n5\tquedó\tquedar\tVERB\tvmis3s0\tMood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin\t2\tadvcl\t2:advcl\t_\n\"\"\"\n\ndef test_augment_initial_punct():\n    doc = read_test_doc(SPANISH_QM_TEST_CASE)\n    doc2 = prepare_tokenizer_treebank.augment_initial_punct(doc, ratio=1.0)\n    expected = doc + read_test_doc(SPANISH_QM_RESULT)\n    assert doc2 == expected\n\nSPANISH_SHOULD_THROW = \"\"\"\n# sent_id = 3LB-CAST-a1-2-s6\n# text = ¿Para qué seguir?\n# orig_file_sentence 006#22\n# multiple heads are not handled yet in the augmented dependencies column\n1\t¿\t¿\tPUNCT\tfia\tPunctSide=Ini|PunctType=Qest\t4\tpunct\t4:punct\tSpaceAfter=No\n2\tPara\tpara\tADP\tsps00\t_\t3\tcase\t3:case\t_\n3\tqué\tqué\tPRON\tpt0cs000\tNumber=Sing|PronType=Int,Rel\t4\tobl\t4:obl,3:foo\t_\n4\tseguir\tseguir\tVERB\tvmn0000\tVerbForm=Inf\t0\troot\t0:root\tSpaceAfter=No\n5\t?\t?\tPUNCT\tfit\tPunctSide=Fin|PunctType=Qest\t4\tpunct\t4:punct\t_\n\"\"\"\n\ndef test_augment_initial_punct_error():\n    \"\"\"\n    The augment script should protect against the single dependency assumption changing in the future\n    \"\"\"\n    doc = read_test_doc(SPANISH_SHOULD_THROW)\n    with pytest.raises(NotImplementedError):\n        doc2 = prepare_tokenizer_treebank.augment_initial_punct(doc, ratio=1.0)\n\n# first sentence should have the space added\n# second sentence should be unchanged\nARABIC_SPACE_AFTER_TEST_CASE = \"\"\"\n# newpar id = 
afp.20000815.0079:p6\n# sent_id = afp.20000815.0079:p6u1\n# text = وتتميز امسية الاربعاء الدولية باقامة 16 مباراة ودية.\n# orig_file_sentence AFP_ARB_20000815.0079#6\n1-2\tوتتميز\t_\t_\t_\t_\t_\t_\t_\t_\n1\tو\tوَ\tCCONJ\tC---------\t_\t0\troot\t0:root\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n2\tتتميز\tتَمَيَّز\tVERB\tVIIA-3FS--\tAspect=Imp|Gender=Fem|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act\t1\tparataxis\t1:parataxis\tVform=تَتَمَيَّزُ|Gloss=be_distinguished,stand_out,discern,distinguish|Root=m_y_z|Translit=tatamayyazu|LTranslit=tamayyaz\n3\tامسية\tأُمسِيَّة\tNOUN\tN------S1R\tCase=Nom|Definite=Cons|Number=Sing\t2\tnsubj\t2:nsubj\tVform=أُمسِيَّةُ|Gloss=evening,soiree|Root=m_s_w|Translit=ʾumsīyatu|LTranslit=ʾumsīyat\n4\tالاربعاء\tأَربِعَاء\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t3\tnmod\t3:nmod:gen\tVform=اَلأَربِعَاءِ|Gloss=Wednesday|Root=r_b_`|Translit=al-ʾarbiʿāʾi|LTranslit=ʾarbiʿāʾ\n5\tالدولية\tدُوَلِيّ\tADJ\tA-----FS1D\tCase=Nom|Definite=Def|Gender=Fem|Number=Sing\t3\tamod\t3:amod\tVform=اَلدُّوَلِيَّةُ|Gloss=international,world|Root=d_w_l|Translit=ad-duwalīyatu|LTranslit=duwalīy\n6-7\tباقامة\t_\t_\t_\t_\t_\t_\t_\t_\n6\tب\tبِ\tADP\tP---------\tAdpType=Prep\t7\tcase\t7:case\tVform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi\n7\tإقامة\tإِقَامَة\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t2\tobl\t2:obl:بِ:gen\tVform=إِقَامَةِ|Gloss=residency,setting_up|Root=q_w_m|Translit=ʾiqāmati|LTranslit=ʾiqāmat\n8\t16\t16\tNUM\tQ---------\tNumForm=Digit\t7\tnummod\t7:nummod\tVform=١٦|Translit=16\n9\tمباراة\tمُبَارَاة\tNOUN\tN------S4I\tCase=Acc|Definite=Ind|Number=Sing\t8\tnmod\t8:nmod:acc\tVform=مُبَارَاةً|Gloss=match,game,competition|Root=b_r_y|Translit=mubārātan|LTranslit=mubārāt\n10\tودية\tوُدِّيّ\tADJ\tA-----FS4I\tCase=Acc|Definite=Ind|Gender=Fem|Number=Sing\t9\tamod\t9:amod\tSpaceAfter=No|Vform=وُدِّيَّةً|Gloss=friendly,amicable|Root=w_d_d|Translit=wuddīyatan|LTranslit=wuddīy\n11\t.\t.\tPUNCT\tG---------\t_\t1\tpunc
t\t1:punct\tVform=.|Translit=.\n\n# newdoc id = afp.20000715.0075\n# newpar id = afp.20000715.0075:p1\n# sent_id = afp.20000715.0075:p1u1\n# text = برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة \"ليوبارد\" الالمانية\n# orig_file_sentence AFP_ARB_20000715.0075#1\n1\tبرلين\tبَرلِين\tX\tX---------\tForeign=Yes\t2\tnsubj\t2:nsubj\tVform=بَرلِين|Gloss=Berlin|Root=barlIn|Translit=barlīn|LTranslit=barlīn\n2\tترفض\tرَفَض\tVERB\tVIIA-3FS--\tAspect=Imp|Gender=Fem|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act\t0\troot\t0:root\tVform=تَرفُضُ|Gloss=reject,refuse|Root=r_f_.d|Translit=tarfuḍu|LTranslit=rafaḍ\n3\tحصول\tحُصُول\tNOUN\tN------S4R\tCase=Acc|Definite=Cons|Number=Sing\t2\tobj\t2:obj\tVform=حُصُولَ|Gloss=acquisition,obtaining,occurrence,happening|Root=.h_.s_l|Translit=ḥuṣūla|LTranslit=ḥuṣūl\n4\tشركة\tشَرِكَة\tNOUN\tN------S2I\tCase=Gen|Definite=Ind|Number=Sing\t3\tnmod\t3:nmod:gen\tVform=شَرِكَةٍ|Gloss=company,corporation|Root=^s_r_k|Translit=šarikatin|LTranslit=šarikat\n5\tاميركية\tأَمِيرِكِيّ\tADJ\tA-----FS2I\tCase=Gen|Definite=Ind|Gender=Fem|Number=Sing\t4\tamod\t4:amod\tVform=أَمِيرِكِيَّةٍ|Gloss=American|Root='amIrik|Translit=ʾamīrikīyatin|LTranslit=ʾamīrikīy\n6\tعلى\tعَلَى\tADP\tP---------\tAdpType=Prep\t7\tcase\t7:case\tVform=عَلَى|Gloss=on,above|Root=`_l_w|Translit=ʿalā|LTranslit=ʿalā\n7\tرخصة\tرُخصَة\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t3\tobl:arg\t3:obl:arg:عَلَى:gen\tVform=رُخصَةِ|Gloss=license,permit|Root=r__h_.s|Translit=ruḫṣati|LTranslit=ruḫṣat\n8\tتصنيع\tتَصنِيع\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t7\tnmod\t7:nmod:gen\tVform=تَصنِيعِ|Gloss=fabrication,industrialization,processing|Root=.s_n_`|Translit=taṣnīʿi|LTranslit=taṣnīʿ\n9\tدبابة\tدَبَّابَة\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t8\tnmod\t8:nmod:gen\tVform=دَبَّابَةِ|Gloss=tank|Root=d_b_b|Translit=dabbābati|LTranslit=dabbābat\n10\t\"\t\"\tPUNCT\tG---------\t_\t11\tpunct\t11:punct\tSpaceAfter=No|Vform=\"|Translit=\"\n11\tليوبارد\tلِيُوبَارد\t
X\tX---------\tForeign=Yes\t9\tnmod\t9:nmod\tSpaceAfter=No|Vform=لِيُوبَارد|Gloss=Leopard|Root=liyUbArd|Translit=liyūbārd|LTranslit=liyūbārd\n12\t\"\t\"\tPUNCT\tG---------\t_\t11\tpunct\t11:punct\tVform=\"|Translit=\"\n13\tالالمانية\tأَلمَانِيّ\tADJ\tA-----FS2D\tCase=Gen|Definite=Def|Gender=Fem|Number=Sing\t9\tamod\t9:amod\tVform=اَلأَلمَانِيَّةِ|Gloss=German|Root='almAn|Translit=al-ʾalmānīyati|LTranslit=ʾalmānīy\n\"\"\"\n\nARABIC_SPACE_AFTER_RESULT = \"\"\"\n# newpar id = afp.20000815.0079:p6\n# sent_id = afp.20000815.0079:p6u1\n# text = وتتميز امسية الاربعاء الدولية باقامة 16 مباراة ودية .\n# orig_file_sentence AFP_ARB_20000815.0079#6\n1-2\tوتتميز\t_\t_\t_\t_\t_\t_\t_\t_\n1\tو\tوَ\tCCONJ\tC---------\t_\t0\troot\t0:root\tVform=وَ|Gloss=and|Root=wa|Translit=wa|LTranslit=wa\n2\tتتميز\tتَمَيَّز\tVERB\tVIIA-3FS--\tAspect=Imp|Gender=Fem|Mood=Ind|Number=Sing|Person=3|VerbForm=Fin|Voice=Act\t1\tparataxis\t1:parataxis\tVform=تَتَمَيَّزُ|Gloss=be_distinguished,stand_out,discern,distinguish|Root=m_y_z|Translit=tatamayyazu|LTranslit=tamayyaz\n3\tامسية\tأُمسِيَّة\tNOUN\tN------S1R\tCase=Nom|Definite=Cons|Number=Sing\t2\tnsubj\t2:nsubj\tVform=أُمسِيَّةُ|Gloss=evening,soiree|Root=m_s_w|Translit=ʾumsīyatu|LTranslit=ʾumsīyat\n4\tالاربعاء\tأَربِعَاء\tNOUN\tN------S2D\tCase=Gen|Definite=Def|Number=Sing\t3\tnmod\t3:nmod:gen\tVform=اَلأَربِعَاءِ|Gloss=Wednesday|Root=r_b_`|Translit=al-ʾarbiʿāʾi|LTranslit=ʾarbiʿāʾ\n5\tالدولية\tدُوَلِيّ\tADJ\tA-----FS1D\tCase=Nom|Definite=Def|Gender=Fem|Number=Sing\t3\tamod\t3:amod\tVform=اَلدُّوَلِيَّةُ|Gloss=international,world|Root=d_w_l|Translit=ad-duwalīyatu|LTranslit=duwalīy\n6-7\tباقامة\t_\t_\t_\t_\t_\t_\t_\t_\n6\tب\tبِ\tADP\tP---------\tAdpType=Prep\t7\tcase\t7:case\tVform=بِ|Gloss=by,with|Root=bi|Translit=bi|LTranslit=bi\n7\tإقامة\tإِقَامَة\tNOUN\tN------S2R\tCase=Gen|Definite=Cons|Number=Sing\t2\tobl\t2:obl:بِ:gen\tVform=إِقَامَةِ|Gloss=residency,setting_up|Root=q_w_m|Translit=ʾiqāmati|LTranslit=ʾiqāmat\n8\t16\t16\tNUM\tQ---------\tNumForm=Digi
t\t7\tnummod\t7:nummod\tVform=١٦|Translit=16\n9\tمباراة\tمُبَارَاة\tNOUN\tN------S4I\tCase=Acc|Definite=Ind|Number=Sing\t8\tnmod\t8:nmod:acc\tVform=مُبَارَاةً|Gloss=match,game,competition|Root=b_r_y|Translit=mubārātan|LTranslit=mubārāt\n10\tودية\tوُدِّيّ\tADJ\tA-----FS4I\tCase=Acc|Definite=Ind|Gender=Fem|Number=Sing\t9\tamod\t9:amod\tVform=وُدِّيَّةً|Gloss=friendly,amicable|Root=w_d_d|Translit=wuddīyatan|LTranslit=wuddīy\n11\t.\t.\tPUNCT\tG---------\t_\t1\tpunct\t1:punct\tVform=.|Translit=.\n\"\"\"\n\ndef test_augment_space_final_punct():\n    doc = read_test_doc(ARABIC_SPACE_AFTER_TEST_CASE)\n    doc2 = prepare_tokenizer_treebank.augment_arabic_padt(doc, ratio=1.0)\n    expected = doc + read_test_doc(ARABIC_SPACE_AFTER_RESULT)\n    assert doc2 == expected\n\nENGLISH_COMMA_SWAP_TEST_CASE=\"\"\"\n# sent_id = reviews-086839-0004\n# text = Approx 4 months later, the compressor went out.\n1\tApprox\tapprox\tADV\tRB\t_\t3\tadvmod\t3:advmod\t_\n2\t4\t4\tNUM\tCD\tNumType=Card\t3\tnummod\t3:nummod\t_\n3\tmonths\tmonth\tNOUN\tNNS\tNumber=Plur\t4\tobl:npmod\t4:obl:npmod\t_\n4\tlater\tlate\tADV\tRBR\tDegree=Cmp\t8\tadvmod\t8:advmod\tSpaceAfter=No\n5\t,\t,\tPUNCT\t,\t_\t8\tpunct\t8:punct\t_\n6\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t7\tdet\t7:det\t_\n7\tcompressor\tcompressor\tNOUN\tNN\tNumber=Sing\t8\tnsubj\t8:nsubj\t_\n8\twent\tgo\tVERB\tVBD\tMood=Ind|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n9\tout\tout\tADP\tRP\t_\t8\tcompound:prt\t8:compound:prt\tSpaceAfter=No\n10\t.\t.\tPUNCT\t.\t_\t8\tpunct\t8:punct\t_\n\n# sent_id = reviews-086839-0004b\n# text = Approx 4 months later , the compressor went 
out.\n1\tApprox\tapprox\tADV\tRB\t_\t3\tadvmod\t3:advmod\t_\n2\t4\t4\tNUM\tCD\tNumType=Card\t3\tnummod\t3:nummod\t_\n3\tmonths\tmonth\tNOUN\tNNS\tNumber=Plur\t4\tobl:npmod\t4:obl:npmod\t_\n4\tlater\tlate\tADV\tRBR\tDegree=Cmp\t8\tadvmod\t8:advmod\t_\n5\t,\t,\tPUNCT\t,\t_\t8\tpunct\t8:punct\t_\n6\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t7\tdet\t7:det\t_\n7\tcompressor\tcompressor\tNOUN\tNN\tNumber=Sing\t8\tnsubj\t8:nsubj\t_\n8\twent\tgo\tVERB\tVBD\tMood=Ind|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n9\tout\tout\tADP\tRP\t_\t8\tcompound:prt\t8:compound:prt\tSpaceAfter=No\n10\t.\t.\tPUNCT\t.\t_\t8\tpunct\t8:punct\t_\n\"\"\"\n\nENGLISH_COMMA_SWAP_RESULT=\"\"\"\n# sent_id = reviews-086839-0004\n# text = Approx 4 months later ,the compressor went out.\n1\tApprox\tapprox\tADV\tRB\t_\t3\tadvmod\t3:advmod\t_\n2\t4\t4\tNUM\tCD\tNumType=Card\t3\tnummod\t3:nummod\t_\n3\tmonths\tmonth\tNOUN\tNNS\tNumber=Plur\t4\tobl:npmod\t4:obl:npmod\t_\n4\tlater\tlate\tADV\tRBR\tDegree=Cmp\t8\tadvmod\t8:advmod\t_\n5\t,\t,\tPUNCT\t,\t_\t8\tpunct\t8:punct\tSpaceAfter=No\n6\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t7\tdet\t7:det\t_\n7\tcompressor\tcompressor\tNOUN\tNN\tNumber=Sing\t8\tnsubj\t8:nsubj\t_\n8\twent\tgo\tVERB\tVBD\tMood=Ind|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n9\tout\tout\tADP\tRP\t_\t8\tcompound:prt\t8:compound:prt\tSpaceAfter=No\n10\t.\t.\tPUNCT\t.\t_\t8\tpunct\t8:punct\t_\n\n# sent_id = reviews-086839-0004b\n# text = Approx 4 months later , the compressor went 
out.\n1\tApprox\tapprox\tADV\tRB\t_\t3\tadvmod\t3:advmod\t_\n2\t4\t4\tNUM\tCD\tNumType=Card\t3\tnummod\t3:nummod\t_\n3\tmonths\tmonth\tNOUN\tNNS\tNumber=Plur\t4\tobl:npmod\t4:obl:npmod\t_\n4\tlater\tlate\tADV\tRBR\tDegree=Cmp\t8\tadvmod\t8:advmod\t_\n5\t,\t,\tPUNCT\t,\t_\t8\tpunct\t8:punct\t_\n6\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t7\tdet\t7:det\t_\n7\tcompressor\tcompressor\tNOUN\tNN\tNumber=Sing\t8\tnsubj\t8:nsubj\t_\n8\twent\tgo\tVERB\tVBD\tMood=Ind|Tense=Past|VerbForm=Fin\t0\troot\t0:root\t_\n9\tout\tout\tADP\tRP\t_\t8\tcompound:prt\t8:compound:prt\tSpaceAfter=No\n10\t.\t.\tPUNCT\t.\t_\t8\tpunct\t8:punct\t_\n\"\"\"\n\ndef test_augment_move_comma():\n    doc = read_test_doc(ENGLISH_COMMA_SWAP_TEST_CASE)\n    doc2 = prepare_tokenizer_treebank.augment_move_comma(doc, ratio=1.0)\n    expected = read_test_doc(ENGLISH_COMMA_SWAP_RESULT)\n    assert doc2 == expected\n\nCOMMA_SEP_TEST_CASE = \"\"\"\n# text = Fuzzy people, floating people\n1\tFuzzy\tfuzzy\tADJ\tJJ\tDegree=Pos\t2\tamod\t2:amod\t_\n2\tpeople\tpeople\tNOUN\tNNS\tNumber=Plur\t0\troot\t0:root\tSpaceAfter=No\n3\t,\t,\tPUNCT\t,\t_\t2\tpunct\t2:punct\t_\n4\tfloating\tfloat\tVERB\tVBG\tVerbForm=Ger\t5\tamod\t5:amod\t_\n5\tpeople\tpeople\tNOUN\tNNS\tNumber=Plur\t2\tappos\t2:appos\t_\n\"\"\"\n\nCOMMA_SEP_TEST_EXPECTED = \"\"\"\n# text = Fuzzy people,floating people\n1\tFuzzy\tfuzzy\tADJ\tJJ\tDegree=Pos\t2\tamod\t2:amod\t_\n2\tpeople\tpeople\tNOUN\tNNS\tNumber=Plur\t0\troot\t0:root\tSpaceAfter=No\n3\t,\t,\tPUNCT\t,\t_\t2\tpunct\t2:punct\tSpaceAfter=No\n4\tfloating\tfloat\tVERB\tVBG\tVerbForm=Ger\t5\tamod\t5:amod\t_\n5\tpeople\tpeople\tNOUN\tNNS\tNumber=Plur\t2\tappos\t2:appos\t_\n\"\"\"\n\ndef test_augment_comma_separations():\n    doc = read_test_doc(COMMA_SEP_TEST_CASE)\n    doc2 = prepare_tokenizer_treebank.augment_comma_separations(doc, ratio=1.0)\n\n    assert len(doc2) == 2\n    expected = read_test_doc(COMMA_SEP_TEST_EXPECTED)\n    assert doc2[1] == expected[0]\n"
  },
  {
    "path": "stanza/tests/tokenization/test_replace_long_tokens.py",
    "content": "\"\"\"\nCheck to make sure long tokens are replaced with \"UNK\" by the tokenization processor\n\"\"\"\nimport pytest\nimport stanza\n\nfrom stanza.pipeline import tokenize_processor\n\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\ndef test_replace_long_tokens():\n    nlp = stanza.Pipeline(lang=\"en\", download_method=None, model_dir=TEST_MODELS_DIR, processors=\"tokenize\")\n\n    test_str = \"foo \" + \"x\" * 10000 + \" bar\"\n\n    res = nlp(test_str)\n\n    assert res.sentences[0].words[1].text == tokenize_processor.TOKEN_TOO_LONG_REPLACEMENT\n\ndef test_set_max_len():\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR,\n                             'lang': 'en',\n                             'download_method': None,\n                             'tokenize_max_seqlen': 20})\n    doc = nlp(\"This is a doc withaverylongtokenthatshouldbereplaced\")\n    assert len(doc.sentences) == 1\n    assert len(doc.sentences[0].words) == 5\n    assert doc.sentences[0].words[-1].text == tokenize_processor.TOKEN_TOO_LONG_REPLACEMENT\n"
  },
  {
    "path": "stanza/tests/tokenization/test_spaces.py",
    "content": "\"\"\"\nTest that when tokenizing a document, the Space annotations get set the way we expect\n\"\"\"\n\nimport stanza\nfrom stanza.tests import TEST_MODELS_DIR\n\nEXPECTED_NO_MWT = \"\"\"\n# text = Jennifer has nice antennae.\n# sent_id = 0\n1\tJennifer\t_\t_\t_\t_\t0\t_\t_\tSpacesBefore=\\\\s\\\\s|start_char=2|end_char=10\n2\thas\t_\t_\t_\t_\t1\t_\t_\tstart_char=11|end_char=14\n3\tnice\t_\t_\t_\t_\t2\t_\t_\tstart_char=15|end_char=19\n4\tantennae\t_\t_\t_\t_\t3\t_\t_\tSpaceAfter=No|start_char=20|end_char=28\n5\t.\t_\t_\t_\t_\t4\t_\t_\tSpacesAfter=\\\\s\\\\s|start_char=28|end_char=29\n\n# text = Not very nice person, though.\n# sent_id = 1\n1\tNot\t_\t_\t_\t_\t0\t_\t_\tstart_char=31|end_char=34\n2\tvery\t_\t_\t_\t_\t1\t_\t_\tstart_char=35|end_char=39\n3\tnice\t_\t_\t_\t_\t2\t_\t_\tstart_char=40|end_char=44\n4\tperson\t_\t_\t_\t_\t3\t_\t_\tSpaceAfter=No|start_char=45|end_char=51\n5\t,\t_\t_\t_\t_\t4\t_\t_\tstart_char=51|end_char=52\n6\tthough\t_\t_\t_\t_\t5\t_\t_\tSpaceAfter=No|start_char=53|end_char=59\n7\t.\t_\t_\t_\t_\t6\t_\t_\tSpacesAfter=\\\\s\\\\s|start_char=59|end_char=60\n\"\"\".strip()\n\ndef test_spaces_no_mwt():\n    \"\"\"\n    Test what happens if the words in a document have SpacesBefore and/or After\n    \"\"\"\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'download_method': None, 'dir': TEST_MODELS_DIR, 'lang': 'en'})\n    doc = nlp(\"  Jennifer has nice antennae.  Not very nice person, though.  
\")\n    result = \"{:C}\".format(doc)\n    result = result.strip()\n    assert EXPECTED_NO_MWT == result\n\nEXPECTED_MWT = \"\"\"\n# text = She's not a nice person.\n# sent_id = 0\n1-2\tShe's\t_\t_\t_\t_\t_\t_\t_\tSpacesBefore=\\\\s\\\\s|start_char=2|end_char=7\n1\tShe\t_\t_\t_\t_\t0\t_\t_\tstart_char=2|end_char=5\n2\t's\t_\t_\t_\t_\t1\t_\t_\tstart_char=5|end_char=7\n3\tnot\t_\t_\t_\t_\t2\t_\t_\tstart_char=8|end_char=11\n4\ta\t_\t_\t_\t_\t3\t_\t_\tstart_char=12|end_char=13\n5\tnice\t_\t_\t_\t_\t4\t_\t_\tstart_char=14|end_char=18\n6\tperson\t_\t_\t_\t_\t5\t_\t_\tSpaceAfter=No|start_char=19|end_char=25\n7\t.\t_\t_\t_\t_\t6\t_\t_\tSpacesAfter=\\\\s\\\\s|start_char=25|end_char=26\n\n# text = However, the best antennae on the Cerritos are Jennifer's.\n# sent_id = 1\n1\tHowever\t_\t_\t_\t_\t0\t_\t_\tSpaceAfter=No|start_char=28|end_char=35\n2\t,\t_\t_\t_\t_\t1\t_\t_\tstart_char=35|end_char=36\n3\tthe\t_\t_\t_\t_\t2\t_\t_\tstart_char=37|end_char=40\n4\tbest\t_\t_\t_\t_\t3\t_\t_\tstart_char=41|end_char=45\n5\tantennae\t_\t_\t_\t_\t4\t_\t_\tstart_char=46|end_char=54\n6\ton\t_\t_\t_\t_\t5\t_\t_\tstart_char=55|end_char=57\n7\tthe\t_\t_\t_\t_\t6\t_\t_\tstart_char=58|end_char=61\n8\tCerritos\t_\t_\t_\t_\t7\t_\t_\tstart_char=62|end_char=70\n9\tare\t_\t_\t_\t_\t8\t_\t_\tstart_char=71|end_char=74\n10-11\tJennifer's\t_\t_\t_\t_\t_\t_\t_\tSpaceAfter=No|start_char=75|end_char=85\n10\tJennifer\t_\t_\t_\t_\t9\t_\t_\tstart_char=75|end_char=83\n11\t's\t_\t_\t_\t_\t10\t_\t_\tstart_char=83|end_char=85\n12\t.\t_\t_\t_\t_\t11\t_\t_\tSpacesAfter=\\\\s\\\\s|start_char=85|end_char=86\n\"\"\".strip()\n\ndef test_spaces_mwt():\n    \"\"\"\n    Similar to the above test, but now we test it with MWT\n    \"\"\"\n    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'download_method': None, 'dir': TEST_MODELS_DIR, 'lang': 'en'})\n    doc = nlp(\"  She's not a nice person.  However, the best antennae on the Cerritos are Jennifer's.  
\")\n    result = \"{:C}\".format(doc)\n    result = result.strip()\n    assert EXPECTED_MWT == result\n"
  },
  {
    "path": "stanza/tests/tokenization/test_tokenization_lst20.py",
    "content": "import os\nimport tempfile\n\nimport pytest\n\nimport stanza\nfrom stanza.tests import *\n\nfrom stanza.utils.datasets.common import convert_conllu_to_txt\nfrom stanza.utils.datasets.tokenization.convert_th_lst20 import read_document\nfrom stanza.utils.datasets.tokenization.process_thai_tokenization import write_section\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\nSMALL_LST_SAMPLE=\"\"\"\nสุรยุทธ์\tNN\tB_PER\tB_CLS\nยัน\tVV\tO\tI_CLS\nปฏิเสธ\tVV\tO\tI_CLS\nลงนาม\tVV\tO\tI_CLS\n_\tPU\tO\tI_CLS\nMOU\tNN\tO\tI_CLS\n_\tPU\tO\tI_CLS\nกับ\tPS\tO\tI_CLS\nอียู\tNN\tB_ORG\tI_CLS\nไม่\tNG\tO\tI_CLS\nกระทบ\tVV\tO\tI_CLS\nสัมพันธ์\tNN\tO\tE_CLS\n\n1\tNU\tB_DTM\tB_CLS\n_\tPU\tI_DTM\tI_CLS\nกันยายน\tNN\tI_DTM\tI_CLS\n_\tPU\tI_DTM\tI_CLS\n2550\tNU\tE_DTM\tI_CLS\n_\tPU\tO\tI_CLS\n12:21\tNU\tB_DTM\tI_CLS\n_\tPU\tI_DTM\tI_CLS\nน.\tCL\tE_DTM\tE_CLS\n\nผู้สื่อข่าว\tNN\tO\tB_CLS\nรายงาน\tVV\tO\tI_CLS\nเพิ่มเติม\tVV\tO\tI_CLS\nว่า\tCC\tO\tE_CLS\n_\tPU\tO\tO\nจาก\tPS\tO\tB_CLS\nการ\tFX\tO\tI_CLS\nลง\tVV\tO\tI_CLS\nพื้นที่\tNN\tO\tI_CLS\nพบ\tVV\tO\tI_CLS\nว่า\tCC\tO\tE_CLS\n\"\"\".strip()\n\nEXPECTED_CONLLU=\"\"\"\n1\tสุรยุทธ์\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No|NewPar=Yes\n2\tยัน\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tปฏิเสธ\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tลงนาม\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tMOU\t_\t_\t_\t_\t4\tdep\t4:dep\t_\n6\tกับ\t_\t_\t_\t_\t5\tdep\t5:dep\tSpaceAfter=No\n7\tอียู\t_\t_\t_\t_\t6\tdep\t6:dep\tSpaceAfter=No\n8\tไม่\t_\t_\t_\t_\t7\tdep\t7:dep\tSpaceAfter=No\n9\tกระทบ\t_\t_\t_\t_\t8\tdep\t8:dep\tSpaceAfter=No\n10\tสัมพันธ์\t_\t_\t_\t_\t9\tdep\t9:dep\tSpaceAfter=No\n\n1\t1\t_\t_\t_\t_\t0\troot\t0:root\t_\n2\tกันยายน\t_\t_\t_\t_\t1\tdep\t1:dep\t_\n3\t2550\t_\t_\t_\t_\t2\tdep\t2:dep\t_\n4\t12:21\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tน.\t_\t_\t_\t_\t4\tdep\t4:dep\tSpaceAfter=No\n\n1\tผู้สื่อข่าว\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No\n2\tรายงาน\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tเพิ่มเติม\t_\t_\t_\t_\
t2\tdep\t2:dep\tSpaceAfter=No\n4\tว่า\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tจาก\t_\t_\t_\t_\t4\tdep\t4:dep\tSpaceAfter=No\n6\tการ\t_\t_\t_\t_\t5\tdep\t5:dep\tSpaceAfter=No\n7\tลง\t_\t_\t_\t_\t6\tdep\t6:dep\tSpaceAfter=No\n8\tพื้นที่\t_\t_\t_\t_\t7\tdep\t7:dep\tSpaceAfter=No\n9\tพบ\t_\t_\t_\t_\t8\tdep\t8:dep\tSpaceAfter=No\n10\tว่า\t_\t_\t_\t_\t9\tdep\t9:dep\tSpaceAfter=No\n\"\"\".strip()\n\n# Note: these DO NOT line up perfectly (in an emacs window, at least)\n# because Thai characters have a length greater than 1.\n# The lengths of the words are:\n#   สุรยุทธ์    8\n#      ยัน    3\n#   ปฏิเสธ    6\n#   ลงนาม    5\n#     MOU    3\n#      กับ    3\n#      อียู    4\n#      ไม่    3\n#   กระทบ    5\n#   สัมพันธ์    8\n#       1    1\n#  กันยายน    7\n#    2550    4\n#   12:21    5\n#      น.    2\n#  ผู้สื่อข่าว   11\n#  รายงาน    6\n#  เพิ่มเติม    9\n#      ว่า    3\n#     จาก    3\n#     การ    3\n#      ลง    2\n#     พื้นที่    7\n#      พบ    2\n#      ว่า    3\nEXPECTED_TXT    =   \"สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\\n\\n\"\nEXPECTED_LABELS =   \"000000010010000010000100010001000100100001000000021000000010000100000100200000000001000001000000001001000100101000000101002\\n\\n\"\n# counting spaces    1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12123456789AB123456123456789123_12312312123456712123\n\n# note that the word splits go on the final letter of the word in the\n# UD conllu datasets, so that is what we mimic here\n# for example, from EWT:\n# Al-Zaman : American forces killed Shaikh Abdullah\n# 0110000101000000001000000100000010000001000000001\n\ndef check_results(documents, expected_conllu, expected_txt, expected_labels):\n    with tempfile.TemporaryDirectory() as output_dir:\n        write_section(output_dir, \"lst20\", \"train\", documents)\n        with open(os.path.join(output_dir, \"th_lst20.train.gold.conllu\")) as fin:\n            conllu 
= fin.read().strip()\n        with open(os.path.join(output_dir, \"th_lst20.train.txt\")) as fin:\n            txt = fin.read()\n        with open(os.path.join(output_dir, \"th_lst20-ud-train.toklabels\")) as fin:\n            labels = fin.read()\n        assert conllu == expected_conllu\n        assert txt == expected_txt\n        assert labels == expected_labels\n\n        assert len(txt) == len(labels)\n\n\ndef test_small():\n    \"\"\"\n    A small test just to verify that the output is being produced as we want\n\n    Note that there currently are no spaces after the first sentence.\n    Apparently this is wrong, but weirdly, doing that makes the model even worse.\n    \"\"\"\n    lines = SMALL_LST_SAMPLE.strip().split(\"\\n\")\n    documents = read_document(lines, spaces_after=False, split_clauses=False)\n    check_results(documents, EXPECTED_CONLLU, EXPECTED_TXT, EXPECTED_LABELS)\n\nEXPECTED_SPACE_CONLLU=\"\"\"\n1\tสุรยุทธ์\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No|NewPar=Yes\n2\tยัน\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tปฏิเสธ\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tลงนาม\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tMOU\t_\t_\t_\t_\t4\tdep\t4:dep\t_\n6\tกับ\t_\t_\t_\t_\t5\tdep\t5:dep\tSpaceAfter=No\n7\tอียู\t_\t_\t_\t_\t6\tdep\t6:dep\tSpaceAfter=No\n8\tไม่\t_\t_\t_\t_\t7\tdep\t7:dep\tSpaceAfter=No\n9\tกระทบ\t_\t_\t_\t_\t8\tdep\t8:dep\tSpaceAfter=No\n10\tสัมพันธ์\t_\t_\t_\t_\t9\tdep\t9:dep\t_\n\n1\t1\t_\t_\t_\t_\t0\troot\t0:root\t_\n2\tกันยายน\t_\t_\t_\t_\t1\tdep\t1:dep\t_\n3\t2550\t_\t_\t_\t_\t2\tdep\t2:dep\t_\n4\t12:21\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tน.\t_\t_\t_\t_\t4\tdep\t4:dep\t_\n\n1\tผู้สื่อข่าว\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No\n2\tรายงาน\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tเพิ่มเติม\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tว่า\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tจาก\t_\t_\t_\t_\t4\tdep\t4:dep\tSpaceAfter=No\n6\tการ\t_\t_\t_\t_\t5\tdep\t5:dep\tSpaceAfter=No\n7\tลง\t_\t_\t_\t_\t6\tdep\t6:dep\tSpaceAfter=No\n8\tพื้นที่\t_\t
_\t_\t_\t7\tdep\t7:dep\tSpaceAfter=No\n9\tพบ\t_\t_\t_\t_\t8\tdep\t8:dep\tSpaceAfter=No\n10\tว่า\t_\t_\t_\t_\t9\tdep\t9:dep\t_\n\"\"\".strip()\n\nEXPECTED_SPACE_TXT    =   \"สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\\n\\n\"\nEXPECTED_SPACE_LABELS =   \"00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001001000100101000000101002\\n\\n\"\n\ndef test_space_after():\n    \"\"\"\n    This version of the test adds the space after attribute\n    \"\"\"\n    lines = SMALL_LST_SAMPLE.strip().split(\"\\n\")\n    documents = read_document(lines, spaces_after=True, split_clauses=False)\n    check_results(documents, EXPECTED_SPACE_CONLLU, EXPECTED_SPACE_TXT, EXPECTED_SPACE_LABELS)\n\n\nEXPECTED_CLAUSE_CONLLU=\"\"\"\n1\tสุรยุทธ์\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No|NewPar=Yes\n2\tยัน\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tปฏิเสธ\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tลงนาม\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tMOU\t_\t_\t_\t_\t4\tdep\t4:dep\t_\n6\tกับ\t_\t_\t_\t_\t5\tdep\t5:dep\tSpaceAfter=No\n7\tอียู\t_\t_\t_\t_\t6\tdep\t6:dep\tSpaceAfter=No\n8\tไม่\t_\t_\t_\t_\t7\tdep\t7:dep\tSpaceAfter=No\n9\tกระทบ\t_\t_\t_\t_\t8\tdep\t8:dep\tSpaceAfter=No\n10\tสัมพันธ์\t_\t_\t_\t_\t9\tdep\t9:dep\t_\n\n1\t1\t_\t_\t_\t_\t0\troot\t0:root\t_\n2\tกันยายน\t_\t_\t_\t_\t1\tdep\t1:dep\t_\n3\t2550\t_\t_\t_\t_\t2\tdep\t2:dep\t_\n4\t12:21\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tน.\t_\t_\t_\t_\t4\tdep\t4:dep\t_\n\n1\tผู้สื่อข่าว\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No\n2\tรายงาน\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tเพิ่มเติม\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tว่า\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n\n1\tจาก\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No\n2\tการ\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tลง\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tพื้นที่\t_\t_\t_\t_\t3\tdep\t3:dep\tSpaceAfter=No\n5\tพบ\t_\t_\t_\t_\t4\tdep\t4:dep\tSpaceAfter=
No\n6\tว่า\t_\t_\t_\t_\t5\tdep\t5:dep\t_\n\"\"\".strip()\n\nEXPECTED_CLAUSE_TXT    =   \"สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\\n\\n\"\nEXPECTED_CLAUSE_LABELS =   \"00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001002000100101000000101002\\n\\n\"\n\n\ndef test_split_clause():\n    \"\"\"\n    This version of the test also resplits on spaces between clauses\n    \"\"\"\n    lines = SMALL_LST_SAMPLE.strip().split(\"\\n\")\n    documents = read_document(lines, spaces_after=True, split_clauses=True)\n    check_results(documents, EXPECTED_CLAUSE_CONLLU, EXPECTED_CLAUSE_TXT, EXPECTED_CLAUSE_LABELS)\n\nif __name__ == \"__main__\":\n    lines = SMALL_LST_SAMPLE.strip().split(\"\\n\")\n    documents = read_document(lines, spaces_after=False, split_clauses=False)\n\n    write_section(\"foo\", \"lst20\", \"train\", documents)\n"
  },
  {
    "path": "stanza/tests/tokenization/test_tokenization_orchid.py",
    "content": "import os\nimport tempfile\n\nimport pytest\n\nimport xml.etree.ElementTree as ET\n\nimport stanza\nfrom stanza.tests import *\n\nfrom stanza.utils.datasets.common import convert_conllu_to_txt\nfrom stanza.utils.datasets.tokenization.convert_th_orchid import parse_xml\nfrom stanza.utils.datasets.tokenization.process_thai_tokenization import write_section\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\n\nSMALL_DOC=\"\"\"\n<corpus>\n<document TPublisher=\"ศูนย์เทคโนโลยีอิเล็กทรอนิกส์และคอมพิวเตอร์แห่งชาติ, กระทรวงวิทยาศาสตร์ เทคโนโลยีและการพลังงาน\" EPublisher=\"National Electronics and Computer Technology Center, Ministry of Science, Technology and Energy\" TInbook=\"การประชุมทางวิชาการ ครั้งที่ 1, โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์, ปีงบประมาณ 2531, เล่ม 1\" TTitle=\"การประชุมทางวิชาการ ครั้งที่ 1\" Year=\"1989\" EInbook=\"The 1st Annual Conference, Electronics and Computer Research and Development Project, Fiscal Year 1988, Book 1\" ETitle=\"[1st Annual Conference]\">\n<paragraph id=\"1\" line_num=\"12\">\n<sentence id=\"1\" line_num = \"13\" raw_txt = \"การประชุมทางวิชาการ ครั้งที่ 1\">\n<word surface=\"การ\" pos=\"FIXN\"/>\n<word surface=\"ประชุม\" pos=\"VACT\"/>\n<word surface=\"ทาง\" pos=\"NCMN\"/>\n<word surface=\"วิชาการ\" pos=\"NCMN\"/>\n<word surface=\"&lt;space&gt;\" pos=\"PUNC\"/>\n<word surface=\"ครั้ง\" pos=\"CFQC\"/>\n<word surface=\"ที่ 1\" pos=\"DONM\"/>\n</sentence>\n<sentence id=\"2\" line_num = \"23\" raw_txt = \"โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์\">\n<word surface=\"โครงการวิจัยและพัฒนา\" pos=\"NCMN\"/>\n<word surface=\"อิเล็กทรอนิกส์\" pos=\"NCMN\"/>\n<word surface=\"และ\" pos=\"JCRG\"/>\n<word surface=\"คอมพิวเตอร์\" pos=\"NCMN\"/>\n</sentence>\n</paragraph>\n<paragraph id=\"3\" line_num=\"51\">\n<sentence id=\"1\" line_num = \"52\" raw_txt = \"วันที่ 15-16 สิงหาคม 2532\">\n<word surface=\"วัน\" pos=\"NCMN\"/>\n<word surface=\"ที่ 15\" pos=\"DONM\"/>\n<word surface=\"&lt;minus&gt;\" 
pos=\"PUNC\"/>\n<word surface=\"16\" pos=\"DONM\"/>\n<word surface=\"&lt;space&gt;\" pos=\"PUNC\"/>\n<word surface=\"สิงหาคม\" pos=\"NCMN\"/>\n<word surface=\"&lt;space&gt;\" pos=\"PUNC\"/>\n<word surface=\"2532\" pos=\"NCNM\"/>\n</sentence>\n</paragraph>\n</document>\n</corpus>\n\"\"\"\n\n\nEXPECTED_RESULTS=\"\"\"\n1\tการ\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No|NewPar=Yes\n2\tประชุม\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tทาง\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tวิชาการ\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tครั้ง\t_\t_\t_\t_\t4\tdep\t4:dep\tSpaceAfter=No\n6\tที่ 1\t_\t_\t_\t_\t5\tdep\t5:dep\t_\n\n1\tโครงการวิจัยและพัฒนา\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No\n2\tอิเล็กทรอนิกส์\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\tและ\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\tคอมพิวเตอร์\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n\n1\tวัน\t_\t_\t_\t_\t0\troot\t0:root\tSpaceAfter=No|NewPar=Yes\n2\tที่ 15\t_\t_\t_\t_\t1\tdep\t1:dep\tSpaceAfter=No\n3\t-\t_\t_\t_\t_\t2\tdep\t2:dep\tSpaceAfter=No\n4\t16\t_\t_\t_\t_\t3\tdep\t3:dep\t_\n5\tสิงหาคม\t_\t_\t_\t_\t4\tdep\t4:dep\t_\n6\t2532\t_\t_\t_\t_\t5\tdep\t5:dep\t_\n\"\"\".strip()\n\nEXPECTED_TEXT=\"\"\"การประชุมทางวิชาการ ครั้งที่ 1 โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์\n\nวันที่ 15-16 สิงหาคม 2532\n\n\"\"\"\n\nEXPECTED_LABELS=\"\"\"0010000010010000001000001000020000000000000000000010000000000000100100000000002\n\n0010000011010000000100002\n\n\"\"\"\n\ndef check_results(documents, expected_conllu, expected_txt, expected_labels):\n    with tempfile.TemporaryDirectory() as output_dir:\n        write_section(output_dir, \"orchid\", \"train\", documents)\n        with open(os.path.join(output_dir, \"th_orchid.train.gold.conllu\")) as fin:\n            conllu = fin.read().strip()\n        with open(os.path.join(output_dir, \"th_orchid.train.txt\")) as fin:\n            txt = fin.read()\n        with open(os.path.join(output_dir, \"th_orchid-ud-train.toklabels\")) as fin:\n            labels = fin.read()\n        
assert conllu == expected_conllu\n        assert txt == expected_txt\n        assert labels == expected_labels\n\n        assert len(txt) == len(labels)\n\ndef test_orchid():\n    tree = ET.ElementTree(ET.fromstring(SMALL_DOC))\n    documents = parse_xml(tree)\n    check_results(documents, EXPECTED_RESULTS, EXPECTED_TEXT, EXPECTED_LABELS)\n\n"
  },
  {
    "path": "stanza/tests/tokenization/test_tokenize_data.py",
    "content": "\"\"\"\nVery simple test of the mwt counting functionality in tokenization/data.py\n\nTODO: could add a bunch more simple tests, including tests of reading\nthe data from a temp file, for example\n\"\"\"\n\nimport pytest\nimport tempfile\nimport numpy as np\n\nimport stanza\n\nfrom stanza import Pipeline\nfrom stanza.tests import *\nfrom stanza.models.tokenization.data import DataLoader, NUMERIC_RE\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef write_tokenizer_input(test_dir, raw_text, labels):\n    \"\"\"\n    Writes raw_text and labels to randomly named files in test_dir\n\n    Note that the tempfiles are not set to automatically clean up.\n    This will not be a problem if you put them in a tempdir.\n    \"\"\"\n    with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', dir=test_dir, delete=False) as fout:\n        txt_file = fout.name\n        fout.write(raw_text)\n\n    with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', dir=test_dir, delete=False) as fout:\n        label_file = fout.name\n        fout.write(labels)\n\n    return txt_file, label_file\n\n# A single slice of the German tokenization data with no MWT in it\nNO_MWT_TEXT   = \"Sehr gute Beratung, schnelle Behebung der Probleme\"\nNO_MWT_LABELS = \"00010000100000000110000000010000000010001000000002\"\n\n# A single slice of the German tokenization data with an MWT in it\nMWT_TEXT =   \" Die Kosten sind definitiv auch im Rahmen.\"\nMWT_LABELS = \"000100000010000100000000010000100300000012\"\n\nFAKE_PROPERTIES = {\n    \"lang\":\"de\",\n    'feat_funcs': (\"space_before\",\"capitalized\"),\n    'max_seqlen': 300,\n    'use_dictionary': False,\n}\n\ndef test_has_mwt():\n    \"\"\"\n    One dataset has no mwt, the other does\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        txt_file, label_file = write_tokenizer_input(test_dir, NO_MWT_TEXT, NO_MWT_LABELS)\n        data = DataLoader(args=FAKE_PROPERTIES, 
input_files={'txt': txt_file, 'label': label_file})\n        assert not data.has_mwt()\n\n        txt_file, label_file = write_tokenizer_input(test_dir, MWT_TEXT, MWT_LABELS)\n        data = DataLoader(args=FAKE_PROPERTIES, input_files={'txt': txt_file, 'label': label_file})\n        assert data.has_mwt()\n\n@pytest.fixture(scope=\"module\")\ndef tokenizer():\n    pipeline = Pipeline(\"en\", dir=TEST_MODELS_DIR, download_method=None, processors=\"tokenize\")\n    tokenizer = pipeline.processors['tokenize']\n    return tokenizer\n\n@pytest.fixture(scope=\"module\")\ndef zhtok():\n    pipeline = Pipeline(\"zh-hans\", dir=TEST_MODELS_DIR, download_method=None, processors=\"tokenize\")\n    tokenizer = pipeline.processors['tokenize']\n    return tokenizer\n\nEXPECTED_TWO_NL_RAW = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0)], [('f', 0), ('o', 0), ('o', 0)]]\n# in this test, the newline after test becomes a space labeled 0\nEXPECTED_ONE_NL_RAW = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), (' ', 0), ('f', 0), ('o', 0), ('o', 0)]]\nEXPECTED_SKIP_NL_RAW = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('f', 0), ('o', 0), ('o', 0)]]\n\ndef test_convert_units_raw_text(tokenizer):\n    \"\"\"\n    Tests converting a couple small segments to units\n    \"\"\"\n    raw_text = \"This is a      test\\n\\nfoo\"\n    batches = DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n    assert batches.data == EXPECTED_TWO_NL_RAW\n\n    raw_text = \"This is a      test\\nfoo\"\n    batches = DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n    
assert batches.data == EXPECTED_ONE_NL_RAW\n\n    skip_newline_config = dict(tokenizer.config)\n    skip_newline_config['skip_newline'] = True\n    batches = DataLoader(skip_newline_config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n    assert batches.data == EXPECTED_SKIP_NL_RAW\n\n\nEXPECTED_TWO_NL_FILE = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('.', 1)], [('f', 0), ('o', 0), ('o', 0)]]\nEXPECTED_TWO_NL_FILE_LABELS = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=np.int32),\n                               np.array([0, 0, 0], dtype=np.int32)]\n\n# in this test, the newline after test becomes a space labeled 0\nEXPECTED_ONE_NL_FILE = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('.', 1), (' ', 0), ('f', 0), ('o', 0), ('o', 0)]]\nEXPECTED_ONE_NL_FILE_LABELS = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int32)]\n\nEXPECTED_SKIP_NL_FILE = [[('T', 0), ('h', 0), ('i', 0), ('s', 0), (' ', 0), ('i', 0), ('s', 0), (' ', 0), ('a', 0), (' ', 0), ('t', 0), ('e', 0), ('s', 0), ('t', 0), ('.', 1), ('f', 0), ('o', 0), ('o', 0)]]\nEXPECTED_SKIP_NL_FILE_LABELS = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], dtype=np.int32)]\n\ndef check_labels(labels, expected_labels):\n    assert len(labels) == len(expected_labels)\n    for label, expected in zip(labels, expected_labels):\n        assert np.array_equiv(label, expected)\n\ndef test_convert_units_file(tokenizer):\n    \"\"\"\n    Tests reading some text from a file and converting that to units\n    \"\"\"\n    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:\n        # two nl test case, read from file\n        labels   = \"00000000000000000001\\n\\n000\\n\\n\"\n        raw_text = \"This is 
a      test.\\n\\nfoo\\n\\n\"\n        txt_file, label_file = write_tokenizer_input(test_dir, raw_text, labels)\n\n        batches = DataLoader(tokenizer.config, input_files={'txt': txt_file, 'label': label_file}, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n        assert batches.data == EXPECTED_TWO_NL_FILE\n        check_labels(batches.labels(), EXPECTED_TWO_NL_FILE_LABELS)\n\n        # one nl test case, read from file\n        labels   = \"000000000000000000010000\\n\\n\"\n        raw_text = \"This is a      test.\\nfoo\\n\\n\"\n        txt_file, label_file = write_tokenizer_input(test_dir, raw_text, labels)\n\n        batches = DataLoader(tokenizer.config, input_files={'txt': txt_file, 'label': label_file}, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n        assert batches.data == EXPECTED_ONE_NL_FILE\n        check_labels(batches.labels(), EXPECTED_ONE_NL_FILE_LABELS)\n\n        skip_newline_config = dict(tokenizer.config)\n        skip_newline_config['skip_newline'] = True\n        labels   = \"000000000000000000010000\\n\\n\"\n        raw_text = \"This is a      test.\\nfoo\\n\\n\"\n        txt_file, label_file = write_tokenizer_input(test_dir, raw_text, labels)\n\n        batches = DataLoader(skip_newline_config, input_files={'txt': txt_file, 'label': label_file}, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n        assert batches.data == EXPECTED_SKIP_NL_FILE\n        check_labels(batches.labels(), EXPECTED_SKIP_NL_FILE_LABELS)\n\n\ndef test_dictionary(zhtok):\n    \"\"\"\n    Tests some features of the zh tokenizer dictionary\n\n    The expectation is that the Chinese tokenizer will be serialized with a dictionary\n    (if it ever gets serialized without, this test will warn us!)\n    \"\"\"\n    assert zhtok.trainer.lexicon is not None\n    assert zhtok.trainer.dictionary is not None\n\n    assert \"老师\" in zhtok.trainer.lexicon\n    # 
egg-white-stuff, eg protein\n    assert \"蛋白质\" in zhtok.trainer.lexicon\n    # egg-white\n    assert \"蛋白\" in zhtok.trainer.dictionary['prefixes']\n    # egg\n    assert \"蛋\" in zhtok.trainer.dictionary['prefixes']\n    # white-stuff\n    assert \"白质\" in zhtok.trainer.dictionary['suffixes']\n    # stuff\n    assert \"质\" in zhtok.trainer.dictionary['suffixes']\n\ndef test_dictionary_feats(zhtok):\n    \"\"\"\n    Test the results of running a sentence into the dictionary featurizer\n    \"\"\"\n    raw_text = \"我想吃蛋白质\"\n    batches = DataLoader(zhtok.config, input_text=raw_text, vocab=zhtok.vocab, evaluation=True, dictionary=zhtok.trainer.dictionary)\n    data = batches.data\n    assert len(data) == 1\n    assert len(data[0]) == 6\n\n    expected_features = [\n        # in our example, the 2-grams made by the one character words at the start\n        # don't form any prefixes or suffixes\n        [0, 0, 0, 0, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0, 0],\n        [1, 1, 0, 0, 0, 0, 0, 0],\n        [0, 0, 0, 0, 1, 0, 0, 0],\n        [0, 0, 0, 0, 0, 1, 0, 0],\n    ]\n\n    for i, expected in enumerate(expected_features):\n        dict_features = batches.extract_dict_feat(data[0], i)\n        assert dict_features == expected\n\n\ndef test_numeric_re():\n    \"\"\"\n    Test the \"is numeric\" function\n\n    This function is entirely based on an RE in data.py\n    \"\"\"\n    # the last one is Thai\n    matches = [\"57\", \"135245345\", \"12535.\", \"852358.458345\", \"435345...345345\", \"111,,,111,,,111,,,111\", \"5318008\", \"５\", \"๕\"]\n\n    # note that we might want to consider .4 a numeric token after all\n    # however, changing that means retraining all the models\n    # the really long one only works if NUMERIC_RE avoids catastrophic backtracking\n    not_matches = [\".4\", \"54353a\", \"5453 35345\", \"aaa143234\", \"a,a,a,a\", \"sh'reyan\", \"asdaf786876asdfasdf\", \"\",\n                   
\"11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111a\"]\n\n    for x in matches:\n        assert NUMERIC_RE.match(x) is not None\n    for x in not_matches:\n        assert NUMERIC_RE.match(x) is None\n"
  },
  {
    "path": "stanza/tests/tokenization/test_tokenize_files.py",
    "content": "import pytest\n\nfrom stanza.models.tokenization import tokenize_files\nfrom stanza.tests import TEST_MODELS_DIR\n\npytestmark = [pytest.mark.pipeline, pytest.mark.travis]\n\nEXPECTED = \"\"\"\nThis is a test . This is a second sentence .\nI took my daughter ice skating\n\"\"\".lstrip()\n\ndef test_tokenize_files(tmp_path):\n    input_file = tmp_path / \"input.txt\"\n    with open(input_file, \"w\") as fout:\n        fout.write(\"This is a test.  This is a second sentence.\\n\\nI took my daughter ice skating\")\n\n    output_file = tmp_path / \"output.txt\"\n    tokenize_files.main([str(input_file), \"--lang\", \"en\", \"--output_file\", str(output_file), \"--model_dir\", TEST_MODELS_DIR])\n\n    with open(output_file) as fin:\n        text = fin.read()\n\n    assert EXPECTED == text\n"
  },
  {
    "path": "stanza/tests/tokenization/test_tokenize_utils.py",
    "content": "\"\"\"\nVery simple test of the sentence slicing by <PAD> tags\n\nTODO: could add a bunch more simple tests for the tokenization utils\n\"\"\"\n\nimport pytest\nimport stanza\n\nfrom stanza import Pipeline\nfrom stanza.tests import *\nfrom stanza.models.common import doc\nfrom stanza.models.tokenization import data\nfrom stanza.models.tokenization import utils\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_find_spans():\n    \"\"\"\n    Test various raw -> span manipulations\n    \"\"\"\n    raw = ['u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l']\n    assert utils.find_spans(raw) == [(0, 14)]\n\n    raw = ['u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l', '<PAD>']\n    assert utils.find_spans(raw) == [(0, 14)]\n\n    raw = ['<PAD>', 'u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l', '<PAD>']\n    assert utils.find_spans(raw) == [(1, 15)]\n\n    raw = ['<PAD>', 'u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l']\n    assert utils.find_spans(raw) == [(1, 15)]\n\n    raw = ['<PAD>', 'u', 'n', 'b', 'a', 'n', '<PAD>', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l']\n    assert utils.find_spans(raw) == [(1, 6), (7, 15)]\n\ndef check_offsets(doc, expected_offsets):\n    \"\"\"\n    Compare the start_char and end_char of the tokens in the doc with the given list of list of offsets\n    \"\"\"\n    assert len(doc.sentences) == len(expected_offsets)\n    for sentence, offsets in zip(doc.sentences, expected_offsets):\n        assert len(sentence.tokens) == len(offsets)\n        for token, offset in zip(sentence.tokens, offsets):\n            assert token.start_char == offset[0]\n            assert token.end_char == offset[1]\n\ndef test_match_tokens_with_text():\n    \"\"\"\n    Test the conversion of pretokenized text to Document\n    \"\"\"\n    doc = utils.match_tokens_with_text([[\"This\", \"is\", \"a\", \"test\"]], \"Thisisatest\")\n    expected_offsets = [[(0, 
4), (4, 6), (6, 7), (7, 11)]]\n    check_offsets(doc, expected_offsets)\n\n    doc = utils.match_tokens_with_text([[\"This\", \"is\", \"a\", \"test\"], [\"unban\", \"mox\", \"opal\", \"!\"]], \"Thisisatest  unban mox  opal!\")\n    expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)],\n                        [(13, 18), (19, 22), (24, 28), (28, 29)]]\n    check_offsets(doc, expected_offsets)\n\n    with pytest.raises(ValueError):\n        doc = utils.match_tokens_with_text([[\"This\", \"is\", \"a\", \"test\"]], \"Thisisatestttt\")\n\n    with pytest.raises(ValueError):\n        doc = utils.match_tokens_with_text([[\"This\", \"is\", \"a\", \"test\"]], \"Thisisates\")\n\n    with pytest.raises(ValueError):\n        doc = utils.match_tokens_with_text([[\"This\", \"iz\", \"a\", \"test\"]], \"Thisisatest\")\n\ndef test_long_paragraph():\n    \"\"\"\n    Test the tokenizer's capacity to break text up into smaller chunks\n    \"\"\"\n    pipeline = Pipeline(\"en\", dir=TEST_MODELS_DIR, processors=\"tokenize\")\n    tokenizer = pipeline.processors['tokenize']\n\n    raw_text = \"TIL not to ask a date to dress up as Smurfette on a first date.  
\" * 100\n\n    # run a test to make sure the chunk operation is called\n    # if not, the test isn't actually testing what we need to test\n    batches = data.TokenizationDataset(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n    batches.advance_old_batch = None\n    with pytest.raises(TypeError):\n        _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000,\n                                                     orig_text=raw_text,\n                                                     no_ssplit=tokenizer.config.get('no_ssplit', False))\n\n    # a new DataLoader should not be crippled as the above one was\n    batches = data.TokenizationDataset(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)\n    _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000,\n                                                 orig_text=raw_text,\n                                                 no_ssplit=tokenizer.config.get('no_ssplit', False))\n\n    document = doc.Document(document, raw_text)\n    assert len(document.sentences) == 100\n\ndef test_postprocessor_application():\n    \"\"\"\n    Check that the postprocessor behaves correctly by applying the identity postprocessor and hoping that it does indeed return correctly.\n    \"\"\"\n\n    good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], [\"I'm\", 'a', 'chicken', '.']]\n    text = \"I am Joe. ⭆⊱⇞ Hi. 
I'm a chicken.\"\n\n    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': \"I'm\", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]\n\n    def postprocesor(_):\n        return good_tokenization\n\n    res = utils.postprocess_doc(target_doc, postprocesor, text)\n\n    assert res == target_doc\n\ndef test_reassembly_indexing():\n    \"\"\"\n    Check that the reassembly code counts the indicies correctly, and including OOV chars.\n    \"\"\"\n\n    good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], [\"I'm\", 'a', 'chicken', '.']]\n    good_mwts = [[False for _ in range(len(i))] for i in good_tokenization]\n    good_expansions = [[None for _ in range(len(i))] for i in good_tokenization]\n\n    text = \"I am Joe. ⭆⊱⇞ Hi. 
I'm a chicken.\"\n\n    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': \"I'm\", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]\n\n    res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)\n\n    assert res == target_doc\n\ndef test_reassembly_reference_failures():\n    \"\"\"\n    Check that the reassembly code complains correctly when the user adds tokens that doesn't exist\n    \"\"\"\n\n    bad_addition_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Southern', 'California', '.']]\n    bad_addition_mwts = [[False for _ in range(len(bad_addition_tokenization[0]))]]\n    bad_addition_expansions = [[None for _ in range(len(bad_addition_tokenization[0]))]]\n\n    bad_inline_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Californiaa', '.']]\n    bad_inline_mwts = [[False for _ in range(len(bad_inline_tokenization[0]))]]\n    bad_inline_expansions = [[None for _ in range(len(bad_inline_tokenization[0]))]]\n\n    good_tokenization = [['Joe', 'Smith', 'lives', 'in', 'California', '.']]\n    good_mwts = [[False for _ in range(len(good_tokenization[0]))]]\n    good_expansions = [[None for _ in range(len(good_tokenization[0]))]]\n\n    text = \"Joe Smith lives in California.\"\n\n    with pytest.raises(ValueError):\n        utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, bad_addition_expansions, text)\n\n    with 
pytest.raises(ValueError):\n        utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, bad_inline_expansions, text)\n\n    utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)\n\n\n\nTRAIN_DATA = \"\"\"\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003\n# text = DPA: Iraqi authorities announced that they'd busted up three terrorist cells operating in Baghdad.\n1\tDPA\tDPA\tPROPN\tNNP\tNumber=Sing\t0\troot\t0:root\tSpaceAfter=No\n2\t:\t:\tPUNCT\t:\t_\t1\tpunct\t1:punct\t_\n3\tIraqi\tIraqi\tADJ\tJJ\tDegree=Pos\t4\tamod\t4:amod\t_\n4\tauthorities\tauthority\tNOUN\tNNS\tNumber=Plur\t5\tnsubj\t5:nsubj\t_\n5\tannounced\tannounce\tVERB\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t1\tparataxis\t1:parataxis\t_\n6\tthat\tthat\tSCONJ\tIN\t_\t9\tmark\t9:mark\t_\n7-8\tthey'd\t_\t_\t_\t_\t_\t_\t_\t_\n7\tthey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur|Person=3|PronType=Prs\t9\tnsubj\t9:nsubj\t_\n8\t'd\thave\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t9\taux\t9:aux\t_\n9\tbusted\tbust\tVERB\tVBN\tTense=Past|VerbForm=Part\t5\tccomp\t5:ccomp\t_\n10\tup\tup\tADP\tRP\t_\t9\tcompound:prt\t9:compound:prt\t_\n11\tthree\tthree\tNUM\tCD\tNumForm=Digit|NumType=Card\t13\tnummod\t13:nummod\t_\n12\tterrorist\tterrorist\tADJ\tJJ\tDegree=Pos\t13\tamod\t13:amod\t_\n13\tcells\tcell\tNOUN\tNNS\tNumber=Plur\t9\tobj\t9:obj\t_\n14\toperating\toperate\tVERB\tVBG\tVerbForm=Ger\t13\tacl\t13:acl\t_\n15\tin\tin\tADP\tIN\t_\t16\tcase\t16:case\t_\n16\tBaghdad\tBaghdad\tPROPN\tNNP\tNumber=Sing\t14\tobl\t14:obl:in\tSpaceAfter=No\n17\t.\t.\tPUNCT\t.\t_\t1\tpunct\t1:punct\t_\n\n# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004\n# text = Two of them were being run by 2 officials of the Ministry of the 
Interior!\n1\tTwo\ttwo\tNUM\tCD\tNumForm=Word|NumType=Card\t6\tnsubj:pass\t6:nsubj:pass\t_\n2\tof\tof\tADP\tIN\t_\t3\tcase\t3:case\t_\n3\tthem\tthey\tPRON\tPRP\tCase=Acc|Number=Plur|Person=3|PronType=Prs\t1\tnmod\t1:nmod:of\t_\n4\twere\tbe\tAUX\tVBD\tMood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin\t6\taux\t6:aux\t_\n5\tbeing\tbe\tAUX\tVBG\tVerbForm=Ger\t6\taux:pass\t6:aux:pass\t_\n6\trun\trun\tVERB\tVBN\tTense=Past|VerbForm=Part|Voice=Pass\t0\troot\t0:root\t_\n7\tby\tby\tADP\tIN\t_\t9\tcase\t9:case\t_\n8\t2\t2\tNUM\tCD\tNumForm=Digit|NumType=Card\t9\tnummod\t9:nummod\t_\n9\tofficials\tofficial\tNOUN\tNNS\tNumber=Plur\t6\tobl\t6:obl:by\t_\n10\tof\tof\tADP\tIN\t_\t12\tcase\t12:case\t_\n11\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t12\tdet\t12:det\t_\n12\tMinistry\tMinistry\tPROPN\tNNP\tNumber=Sing\t9\tnmod\t9:nmod:of\t_\n13\tof\tof\tADP\tIN\t_\t15\tcase\t15:case\t_\n14\tthe\tthe\tDET\tDT\tDefinite=Def|PronType=Art\t15\tdet\t15:det\t_\n15\tInterior\tInterior\tPROPN\tNNP\tNumber=Sing\t12\tnmod\t12:nmod:of\tSpaceAfter=No\n16\t!\t!\tPUNCT\t.\t_\t6\tpunct\t6:punct\t_\n\n\"\"\".lstrip()\n\ndef test_lexicon_from_training_data(tmp_path):\n    \"\"\"\n    Test a couple aspects of building a lexicon from training data\n\n    expected number of words eliminated for being too long\n    duplicate words counted once\n    numbers eliminated\n    \"\"\"\n    conllu_file = str(tmp_path / \"train.conllu\")\n    with open(conllu_file, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(TRAIN_DATA)\n\n    lexicon, num_dict_feat = utils.create_lexicon(\"en_test\", conllu_file)\n    lexicon = sorted(lexicon)\n    expected_lexicon = [\"'d\", 'announced', 'baghdad', 'being', 'busted', 'by', 'cells', 'dpa', 'in', 'interior', 'iraqi', 'ministry', 'of', 'officials', 'operating', 'run', 'terrorist', 'that', 'the', 'them', 'they', \"they'd\", 'three', 'two', 'up', 'were']\n    assert lexicon == expected_lexicon\n    assert num_dict_feat == max(len(x) for x in lexicon)\n\n"
  },
  {
    "path": "stanza/tests/tokenization/test_vocab.py",
    "content": "import pytest\n\nfrom stanza.models.common.vocab import UNK, PAD\nfrom stanza.models.tokenization.vocab import Vocab\n\npytestmark = [pytest.mark.travis, pytest.mark.pipeline]\n\ndef test_build():\n    \"\"\"\n    Test that building a vocab out of a text produces the expected units and ids in the vocab\n    \"\"\"\n    text = [\"this is a test\"]\n    vocab = Vocab(data=text, lang=\"en\")\n    expected = {'<PAD>', '<UNK>', 't', 's', ' ', 'i', 'h', 'a', 'e'}\n    assert expected == set(vocab._id2unit)\n    for unit in vocab._id2unit:\n        assert vocab.id2unit(vocab.unit2id(unit)) == unit\n\n\ndef test_append():\n    text = [\"this is a test\"]\n    vocab = Vocab(data=text, lang=\"en\")\n\n    assert 'z' not in vocab\n    vocab.append('z')\n    expected = {'<PAD>', '<UNK>', 't', 's', ' ', 'i', 'h', 'a', 'e', 'z'}\n    assert expected == set(vocab._id2unit)\n    for unit in vocab._id2unit:\n        assert vocab.id2unit(vocab.unit2id(unit)) == unit\n"
  },
  {
    "path": "stanza/utils/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/avg_sent_len.py",
    "content": "import sys\nimport json\n\ndef avg_sent_len(toklabels):\n    if toklabels.endswith('.json'):\n        with open(toklabels, 'r') as f:\n            l = json.load(f)\n\n        l = [''.join([str(x[1]) for x in para]) for para in l]\n    else:\n        with open(toklabels, 'r') as f:\n            l = ''.join(f.readlines())\n\n        l = l.split('\\n\\n')\n\n    sentlen = [len(x) + 1 for para in l for x in para.split('2')]\n    return sum(sentlen) / len(sentlen)\n\nif __name__ == '__main__':\n    print(avg_sent_len(sys.args[1]))\n"
  },
  {
    "path": "stanza/utils/charlm/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/charlm/conll17_to_text.py",
    "content": "\"\"\"\nTurns a directory of conllu files from the conll 2017 shared task to a text file\n\nPart of the process for building a charlm dataset\n\npython conll17_to_text.py <directory>\n\nThis is an extension of the original script:\n  https://github.com/stanfordnlp/stanza-scripts/blob/master/charlm/conll17/conll2txt.py\n\nTo build a new charlm for a new language from a conll17 dataset:\n- look for conll17 shared task data, possibly here:\n  https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-1989\n- python3 stanza/utils/charlm/conll17_to_text.py ~/extern_data/conll17/Bulgarian --output_directory extern_data/charlm_raw/bg/conll17\n- python3 stanza/utils/charlm/make_lm_data.py --langs bg extern_data/charlm_raw extern_data/charlm/\n\"\"\"\n\nimport argparse\nimport lzma\nimport sys\nimport os\n\ndef process_file(input_filename, output_directory, compress):\n    if not input_filename.endswith('.conllu') and not input_filename.endswith(\".conllu.xz\"):\n        print(\"Skipping {}\".format(input_filename))\n        return\n\n    if input_filename.endswith(\".xz\"):\n        open_fn = lambda x: lzma.open(x, mode='rt')\n        output_filename = input_filename[:-3].replace(\".conllu\", \".txt\")\n    else:\n        open_fn = lambda x: open(x)\n        output_filename = input_filename.replace('.conllu', '.txt')\n\n    if output_directory:\n        output_filename = os.path.join(output_directory, os.path.split(output_filename)[1])\n\n    if compress:\n        output_filename = output_filename + \".xz\"\n        output_fn = lambda x: lzma.open(x, mode='wt')\n    else:\n        output_fn = lambda x: open(x, mode='w')\n\n    if os.path.exists(output_filename):\n        print(\"Cowardly refusing to overwrite %s\" % output_filename)\n        return\n\n    print(\"Converting %s to %s\" % (input_filename, output_filename))\n    with open_fn(input_filename) as fin:\n        sentences = []\n        sentence = []\n        for line in fin:\n            line = 
line.strip()\n            if len(line) == 0: # new sentence\n                sentences.append(sentence)\n                sentence = []\n                continue\n            if line[0] == '#': # comment\n                continue\n            splitline = line.split('\\t')\n            assert(len(splitline) == 10) # correct conllu\n            id, word = splitline[0], splitline[1]\n            if '-' not in id: # not mwt token\n                sentence.append(word)\n\n    if sentence:\n        sentences.append(sentence)\n\n    print(\"  Read in {} sentences\".format(len(sentences)))\n    with output_fn(output_filename) as fout:\n        fout.write('\\n'.join([' '.join(sentence) for sentence in sentences]))\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"input_directory\", help=\"Root directory with conllu or conllu.xz files.\")\n    parser.add_argument(\"--output_directory\", default=None, help=\"Directory to output to.  Will output to input_directory if None\")\n    parser.add_argument(\"--no_xz_output\", default=True, dest=\"xz_output\", action=\"store_false\", help=\"Do not compress the output.  By default, output files are xz compressed\")\n    args = parser.parse_args()\n    return args\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    directory = args.input_directory\n    filenames = sorted(os.listdir(directory))\n    print(\"Files to process in {}: {}\".format(directory, filenames))\n    print(\"Processing to .xz files: {}\".format(args.xz_output))\n\n    if args.output_directory:\n        os.makedirs(args.output_directory, exist_ok=True)\n    for filename in filenames:\n        process_file(os.path.join(directory, filename), args.output_directory, args.xz_output)\n\n"
  },
  {
    "path": "stanza/utils/charlm/dump_oscar.py",
    "content": "\"\"\"\nThis script downloads and extracts the text from an Oscar crawl on HuggingFace\n\nTo use, just run\n\ndump_oscar.py <lang>\n\nIt will download the dataset and output all of the text to the --output directory.\nFiles will be broken into pieces to avoid having one giant file.\nBy default, files will also be compressed with xz (although this can be turned off)\n\"\"\"\n\nimport argparse\nimport lzma\nimport math\nimport os\n\nfrom tqdm import tqdm\n\nfrom datasets import get_dataset_split_names\nfrom datasets import load_dataset\n\nfrom stanza.models.common.constant import lang_to_langcode\n\ndef parse_args():\n    \"\"\"\n    A few specific arguments for the dump program\n\n    Uses lang_to_langcode to process args.language, hopefully converting\n    a variety of possible formats to the short code used by HuggingFace\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"language\", help=\"Language to download\")\n    parser.add_argument(\"--output\", default=\"oscar_dump\", help=\"Path for saving files\")\n    parser.add_argument(\"--no_xz\", dest=\"xz\", default=True, action='store_false', help=\"Don't xz the files - default is to compress while writing\")\n    parser.add_argument(\"--prefix\", default=\"oscar_dump\", help=\"Prefix to use for the pieces of the dataset\")\n    parser.add_argument(\"--version\", choices=[\"2019\", \"2023\"], default=\"2023\", help=\"Which version of the Oscar dataset to download\")\n\n    args = parser.parse_args()\n    args.language = lang_to_langcode(args.language)\n    return args\n\ndef download_2023(args):\n    dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')\n    split_names = list(dataset.keys())\n\n\ndef main():\n    args = parse_args()\n\n    # this is the 2019 version.  
for 2023, you can do\n    # dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')\n    language = args.language\n    if args.version == \"2019\":\n        dataset_name = \"unshuffled_deduplicated_%s\" % language\n        try:\n            split_names = get_dataset_split_names(\"oscar\", dataset_name)\n        except ValueError as e:\n            raise ValueError(\"Language %s not available in HuggingFace Oscar\" % language) from e\n\n        if len(split_names) > 1:\n            raise ValueError(\"Unexpected split_names: {}\".format(split_names))\n\n        dataset = load_dataset(\"oscar\", dataset_name)\n        dataset = dataset[split_names[0]]\n        size_in_bytes = dataset.info.size_in_bytes\n        process_item = lambda x: x['text']\n    elif args.version == \"2023\":\n        dataset = load_dataset(\"oscar-corpus/OSCAR-2301\", language)\n        split_names = list(dataset.keys())\n        if len(split_names) > 1:\n            raise ValueError(\"Unexpected split_names: {}\".format(split_names))\n        # it's not clear if some languages don't support size_in_bytes,\n        # or if there was an update to datasets which now allows that\n        #\n        # previously we did:\n        #  dataset = dataset[split_names[0]]['text']\n        #  size_in_bytes = sum(len(x) for x in dataset)\n        #  process_item = lambda x: x\n        dataset = dataset[split_names[0]]\n        size_in_bytes = dataset.info.size_in_bytes\n        process_item = lambda x: x['text']\n    else:\n        raise AssertionError(\"Unknown version: %s\" % args.version)\n\n    chunks = max(1.0, size_in_bytes // 1e8) # an overestimate\n    id_len = max(3, math.floor(math.log10(chunks)) + 1)\n\n    if args.xz:\n        format_str = \"%s_%%0%dd.txt.xz\" % (args.prefix, id_len)\n        fopen = lambda file_idx: lzma.open(os.path.join(args.output, format_str % file_idx), \"wt\")\n    else:\n        format_str = \"%s_%%0%dd.txt\" % (args.prefix, id_len)\n        fopen = lambda file_idx: 
open(os.path.join(args.output, format_str % file_idx), \"w\")\n\n    print(\"Writing dataset to %s\" % args.output)\n    print(\"Dataset length: {}\".format(size_in_bytes))\n    os.makedirs(args.output, exist_ok=True)\n\n    file_idx = 0\n    file_len = 0\n    total_len = 0\n    fout = fopen(file_idx)\n\n    for item in tqdm(dataset):\n        text = process_item(item)\n        fout.write(text)\n        fout.write(\"\\n\")\n        file_len += len(text)\n        file_len += 1\n        if file_len > 1e8:\n            file_len = 0\n            fout.close()\n            file_idx = file_idx + 1\n            fout = fopen(file_idx)\n\n    fout.close()\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/charlm/make_lm_data.py",
    "content": "\"\"\"\nCreate Stanza character LM train/dev/test data, by reading from txt files in each source corpus directory,\nshuffling, splitting and saving into multiple smaller files (50MB by default) in a target directory.\n\nThis script assumes the following source directory structures:\n    - {src_dir}/{language}/{corpus}/*.txt\nIt will read from all source .txt files and create the following target directory structures:\n    - {tgt_dir}/{language}/{corpus}\nand within each target directory, it will create the following files:\n    - train/*.txt\n    - dev.txt\n    - test.txt\nArgs:\n    - src_root: root directory of the source.\n    - tgt_root: root directory of the target.\n    - langs: a list of language codes to process; if specified, languages not in this list will be ignored.\nNote: edit the {EXCLUDED_FOLDERS} variable to exclude more folders in the source directory.\n\"\"\"\n\nimport argparse\nimport glob\nimport os\nfrom pathlib import Path\nimport shutil\nimport subprocess\nimport tempfile\n\nfrom tqdm import tqdm\n\nEXCLUDED_FOLDERS = ['raw_corpus']\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"src_root\", default=\"src\", help=\"Root directory with all source files.  Expected structure is root dir -> language dirs -> package dirs -> text files to process\")\n    parser.add_argument(\"tgt_root\", default=\"tgt\", help=\"Root directory with all target files.\")\n    parser.add_argument(\"--langs\", default=\"\", help=\"A list of language codes to process.  If not set, all languages under src_root will be processed.\")\n    parser.add_argument(\"--packages\", default=\"\", help=\"A list of packages to process.  
If not set, all packages under the languages found will be processed.\")\n    parser.add_argument(\"--no_xz_output\", default=True, dest=\"xz_output\", action=\"store_false\", help=\"Output compressed xz files\")\n    parser.add_argument(\"--split_size\", default=50, type=int, help=\"How large to make each split, in MB\")\n    parser.add_argument(\"--no_make_test_file\", default=True, dest=\"make_test_file\", action=\"store_false\", help=\"Don't save a test file.  Honestly, we never even use it.  Best for low resource languages where every bit helps\")\n    args = parser.parse_args()\n\n    print(\"Processing files:\")\n    print(f\"source root: {args.src_root}\")\n    print(f\"target root: {args.tgt_root}\")\n    print(\"\")\n\n    langs = []\n    if len(args.langs) > 0:\n        langs = args.langs.split(',')\n        print(\"Only processing the following languages: \" + str(langs))\n\n    packages = []\n    if len(args.packages) > 0:\n        packages = args.packages.split(',')\n        print(\"Only processing the following packages: \" + str(packages))\n\n    src_root = Path(args.src_root)\n    tgt_root = Path(args.tgt_root)\n\n    lang_dirs = os.listdir(src_root)\n    lang_dirs = [l for l in lang_dirs if l not in EXCLUDED_FOLDERS]    # skip excluded\n    lang_dirs = [l for l in lang_dirs if os.path.isdir(src_root / l)]  # skip non-directory\n    if len(langs) > 0: # filter languages if specified\n        lang_dirs = [l for l in lang_dirs if l in langs]\n    print(f\"{len(lang_dirs)} total languages found:\")\n    print(lang_dirs)\n    print(\"\")\n\n    split_size = int(args.split_size * 1024 * 1024)\n\n    for lang in lang_dirs:\n        lang_root = src_root / lang\n        data_dirs = os.listdir(lang_root)\n        if len(packages) > 0:\n            data_dirs = [d for d in data_dirs if d in packages]\n        data_dirs = [d for d in data_dirs if os.path.isdir(lang_root / d)]\n        print(f\"{len(data_dirs)} total corpus found for language {lang}.\")\n       
 print(data_dirs)\n        print(\"\")\n\n        for dataset_name in data_dirs:\n            src_dir = lang_root / dataset_name\n            tgt_dir = tgt_root / lang / dataset_name\n\n            if not os.path.exists(tgt_dir):\n                os.makedirs(tgt_dir)\n            print(f\"-> Processing {lang}-{dataset_name}\")\n            prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, args.xz_output, split_size, args.make_test_file)\n\n        print(\"\")\n\ndef prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, compress, split_size, make_test_file):\n    \"\"\"\n    Combine, shuffle and split data into smaller files, following a naming convention.\n    \"\"\"\n    assert isinstance(src_dir, Path)\n    assert isinstance(tgt_dir, Path)\n    with tempfile.TemporaryDirectory(dir=tgt_dir) as tempdir:\n        tgt_tmp = os.path.join(tempdir, f\"{lang}-{dataset_name}.tmp\")\n        print(f\"--> Copying files into {tgt_tmp}...\")\n        # TODO: we can do this without the shell commands\n        input_files = glob.glob(str(src_dir) + '/*.txt') + glob.glob(str(src_dir) + '/*.txt.xz') + glob.glob(str(src_dir) + '/*.txt.gz')\n        for src_fn in tqdm(input_files):\n            if src_fn.endswith(\".txt\"):\n                cmd = f\"cat {src_fn} >> {tgt_tmp}\"\n                subprocess.run(cmd, shell=True)\n            elif src_fn.endswith(\".txt.xz\"):\n                cmd = f\"xzcat {src_fn} >> {tgt_tmp}\"\n                subprocess.run(cmd, shell=True)\n            elif src_fn.endswith(\".txt.gz\"):\n                cmd = f\"zcat {src_fn} >> {tgt_tmp}\"\n                subprocess.run(cmd, shell=True)\n            else:\n                raise AssertionError(\"should not have found %s\" % src_fn)\n        tgt_tmp_shuffled = os.path.join(tempdir, f\"{lang}-{dataset_name}.tmp.shuffled\")\n\n        print(f\"--> Shuffling files into {tgt_tmp_shuffled}...\")\n        cmd = f\"cat {tgt_tmp} | shuf > {tgt_tmp_shuffled}\"\n        result = subprocess.run(cmd, 
shell=True)\n        if result.returncode != 0:\n            raise RuntimeError(\"Failed to shuffle files!\")\n        size = os.path.getsize(tgt_tmp_shuffled) / 1024 / 1024 / 1024\n        print(f\"--> Shuffled file size: {size:.4f} GB\")\n        if size < 0.1:\n            raise RuntimeError(\"Not enough data found to build a charlm.  At least 100MB data expected\")\n\n        print(f\"--> Splitting into smaller files of size {split_size} ...\")\n        train_dir = tgt_dir / 'train'\n        if not os.path.exists(train_dir): # make training dir\n            os.makedirs(train_dir)\n        cmd = f\"split -C {split_size} -a 4 -d --additional-suffix .txt {tgt_tmp_shuffled} {train_dir}/{lang}-{dataset_name}-\"\n        result = subprocess.run(cmd, shell=True)\n        if result.returncode != 0:\n            raise RuntimeError(\"Failed to split files!\")\n        total = len(glob.glob(f'{train_dir}/*.txt'))\n        print(f\"--> {total} total files generated.\")\n        if total < 3:\n            raise RuntimeError(\"Something went wrong!  
%d file(s) produced by shuffle and split, expected at least 3\" % total)\n\n        dev_file = f\"{tgt_dir}/dev.txt\"\n        test_file = f\"{tgt_dir}/test.txt\"\n        if make_test_file:\n            print(\"--> Creating dev and test files...\")\n            shutil.move(f\"{train_dir}/{lang}-{dataset_name}-0000.txt\", dev_file)\n            shutil.move(f\"{train_dir}/{lang}-{dataset_name}-0001.txt\", test_file)\n            txt_files = [dev_file, test_file] + glob.glob(f'{train_dir}/*.txt')\n        else:\n            print(\"--> Creating dev file...\")\n            shutil.move(f\"{train_dir}/{lang}-{dataset_name}-0000.txt\", dev_file)\n            txt_files = [dev_file] + glob.glob(f'{train_dir}/*.txt')\n\n        if compress:\n            print(\"--> Compressing files...\")\n            for txt_file in tqdm(txt_files):\n                subprocess.run(['xz', txt_file])\n\n        print(\"--> Cleaning up...\")\n    print(f\"--> All done for {lang}-{dataset_name}.\\n\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/charlm/oscar_to_text.py",
    "content": "\"\"\"\nTurns an Oscar 2022 jsonl file to text\n\nYOU DO NOT NEED THIS if you use the oscar extractor which reads from\nHuggingFace, dump_oscar.py\n\nto run:\npython3 -m stanza.utils.charlm.oscar_to_text <path> ...\n\neach path can be a file or a directory with multiple .jsonl files in it\n\"\"\"\n\nimport argparse\nimport glob\nimport json\nimport lzma\nimport os\nimport sys\nfrom stanza.models.common.utils import open_read_text\n\ndef extract_file(output_directory, input_filename, use_xz):\n    print(\"Extracting %s\" % input_filename)\n    if output_directory is None:\n        output_directory, output_filename = os.path.split(input_filename)\n    else:\n        _, output_filename = os.path.split(input_filename)\n\n    json_idx = output_filename.rfind(\".jsonl\")\n    if json_idx < 0:\n        output_filename = output_filename + \".txt\"\n    else:\n        output_filename = output_filename[:json_idx] + \".txt\"\n    if use_xz:\n        output_filename += \".xz\"\n        open_file = lambda x: lzma.open(x, \"wt\", encoding=\"utf-8\")\n    else:\n        open_file = lambda x: open(x, \"w\", encoding=\"utf-8\")\n\n    output_filename = os.path.join(output_directory, output_filename)\n    print(\"Writing content to %s\" % output_filename)\n    with open_read_text(input_filename) as fin:\n        with open_file(output_filename) as fout:\n            for line in fin:\n                content = json.loads(line)\n                content = content['content']\n\n                fout.write(content)\n                fout.write(\"\\n\\n\")\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--output\", default=None, help=\"Output directory for saving files.  
If None, will write to the original directory\")\n    parser.add_argument(\"--no_xz\", default=True, dest=\"xz\", action=\"store_false\", help=\"Don't use xz to compress the output files\")\n    parser.add_argument(\"filenames\", nargs=\"+\", help=\"Filenames or directories to process\")\n    args = parser.parse_args()\n    return args\n\ndef main():\n    \"\"\"\n    Go through each of the given filenames or directories, convert json to .txt.xz\n    \"\"\"\n    args = parse_args()\n    if args.output is not None:\n        os.makedirs(args.output, exist_ok=True)\n    for filename in args.filenames:\n        if os.path.isfile(filename):\n            extract_file(args.output, filename, args.xz)\n        elif os.path.isdir(filename):\n            files = glob.glob(os.path.join(filename, \"*jsonl*\"))\n            files = sorted([x for x in files if os.path.isfile(x)])\n            print(\"Found %d files:\" % len(files))\n            if len(files) > 0:\n                print(\"  %s\" % \"\\n  \".join(files))\n            for json_filename in files:\n                extract_file(args.output, json_filename, args.xz)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/confusion.py",
    "content": "\nfrom collections import defaultdict, namedtuple\n\nF1Result = namedtuple(\"F1Result\", ['precision', 'recall', 'f1'])\n\ndef condense_ner_labels(confusion, gold_labels, pred_labels):\n    new_confusion = defaultdict(lambda: defaultdict(int))\n    new_gold_labels = []\n    new_pred_labels = []\n    for l1 in gold_labels:\n        if l1.find(\"-\") >= 0:\n            new_l1 = l1.split(\"-\", 1)[1]\n        else:\n            new_l1 = l1\n        if new_l1 not in new_gold_labels:\n            new_gold_labels.append(new_l1)\n        for l2 in pred_labels:\n            if l2.find(\"-\") >= 0:\n                new_l2 = l2.split(\"-\", 1)[1]\n            else:\n                new_l2 = l2\n            if new_l2 not in new_pred_labels:\n                new_pred_labels.append(new_l2)\n\n            old_value = confusion.get(l1, {}).get(l2, 0)\n            new_confusion[new_l1][new_l2] = new_confusion[new_l1][new_l2] + old_value\n    return new_confusion, new_gold_labels, new_pred_labels\n\n\ndef format_confusion(confusion, labels=None, hide_zeroes=False, hide_blank=False, transpose=False):\n    \"\"\"\n    pretty print for confusion matrixes\n    adapted from https://gist.github.com/zachguo/10296432\n\n    The matrix should look like this:\n      confusion[gold][pred]\n    \"\"\"\n    def sort_labels(labels):\n        \"\"\"\n        Sorts the labels in the list, respecting BIES if all labels are BIES, putting O at the front\n        \"\"\"\n        labels = set(labels)\n        if 'O' in labels:\n            had_O = True\n            labels.remove('O')\n        else:\n            had_O = False\n\n        if not all(isinstance(x, str) and len(x) > 2 and x[0] in ('B', 'I', 'E', 'S') and x[1] in ('-', '_') for x in labels):\n            labels = sorted(labels)\n        else:\n            # sort first by the body of the lable, then by BEIS\n            labels = sorted(labels, key=lambda x: (x[2:], x[0]))\n        if had_O:\n            labels = ['O'] + 
labels\n        return labels\n\n    if transpose:\n        new_confusion = defaultdict(lambda: defaultdict(int))\n        for label1 in confusion.keys():\n            for label2 in confusion[label1].keys():\n                new_confusion[label2][label1] = confusion[label1][label2]\n        confusion = new_confusion\n\n    if labels is None:\n        gold_labels = set(confusion.keys())\n        if hide_blank:\n            gold_labels = set(x for x in gold_labels if any(confusion[x][key] != 0 for key in confusion[x].keys()))\n\n        pred_labels = set()\n        for key in confusion.keys():\n            if hide_blank:\n                new_pred_labels = set(x for x in confusion[key].keys() if confusion[key][x] != 0)\n            else:\n                new_pred_labels = confusion[key].keys()\n            pred_labels = pred_labels.union(new_pred_labels)\n\n        if not hide_blank:\n            gold_labels = gold_labels.union(pred_labels)\n            pred_labels = gold_labels\n\n        gold_labels = sort_labels(gold_labels)\n        pred_labels = sort_labels(pred_labels)\n    else:\n        gold_labels = labels\n        pred_labels = labels\n\n    columnwidth = max([len(str(x)) for x in pred_labels] + [5])  # 5 is value length\n    empty_cell = \" \" * columnwidth\n\n    # If the numbers are all ints, no need to include the .0 at the end of each entry\n    all_ints = True\n    for i, label1 in enumerate(gold_labels):\n        for j, label2 in enumerate(pred_labels):\n            if not isinstance(confusion.get(label1, {}).get(label2, 0), int):\n                all_ints = False\n                break\n        if not all_ints:\n            break\n\n    if all_ints:\n        format_cell = lambda confusion_cell: \"%{0}d\".format(columnwidth) % confusion_cell\n    else:\n        format_cell = lambda confusion_cell: \"%{0}.1f\".format(columnwidth) % confusion_cell\n\n    # make sure the columnwidth can handle long numbers\n    for i, label1 in enumerate(gold_labels):\n  
      for j, label2 in enumerate(pred_labels):\n            cell = confusion.get(label1, {}).get(label2, 0)\n            columnwidth = max(columnwidth, len(format_cell(cell)))\n\n    # if this is an NER confusion matrix (well, if it has - in the labels)\n    # try to drop a bunch of labels to make the matrix easier to display\n    if columnwidth * len(pred_labels) > 150:\n        confusion, gold_labels, pred_labels = condense_ner_labels(confusion, gold_labels, pred_labels)\n\n    # Print header\n    if transpose:\n        corner_label = \"p\\\\t\"\n    else:\n        corner_label = \"t\\\\p\"\n    fst_empty_cell = (columnwidth-3)//2 * \" \" + corner_label + (columnwidth-3)//2 * \" \"\n    if len(fst_empty_cell) < len(empty_cell):\n        fst_empty_cell = \" \" * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell\n    header = \"    \" + fst_empty_cell + \" \"\n    for label in pred_labels:\n        header = header + \"%{0}s \".format(columnwidth) % str(label)\n    text = [header.rstrip()]\n\n    # Print rows\n    for i, label1 in enumerate(gold_labels):\n        row = \"    %{0}s \".format(columnwidth) % str(label1)\n        for j, label2 in enumerate(pred_labels):\n            confusion_cell = confusion.get(label1, {}).get(label2, 0)\n            cell = format_cell(confusion_cell)\n            if hide_zeroes:\n                cell = cell if confusion_cell else empty_cell\n            row = row + cell + \" \"\n        text.append(row.rstrip())\n    return \"\\n\".join(text)\n\n\ndef confusion_to_accuracy(confusion_matrix):\n    \"\"\"\n    Given a confusion dictionary, return correct, total\n    \"\"\"\n    correct = 0\n    total = 0\n    for l1 in confusion_matrix.keys():\n        for l2 in confusion_matrix[l1].keys():\n            if l1 == l2:\n                correct = correct + confusion_matrix[l1][l2]\n            else:\n                total = total + confusion_matrix[l1][l2]\n    return correct, (correct + total)\n\ndef 
confusion_to_f1(confusion_matrix):\n    \"\"\"\n    Return a dict mapping each label to an F1Result(precision, recall, f1)\n    \"\"\"\n    results = {}\n\n    keys = set()\n    for k in confusion_matrix.keys():\n        keys.add(k)\n        for k2 in confusion_matrix.get(k).keys():\n            keys.add(k2)\n\n    for k in keys:\n        tp = 0\n        fn = 0\n        fp = 0\n        for k2 in keys:\n            if k == k2:\n                tp = confusion_matrix.get(k, {}).get(k, 0)\n            else:\n                fn = fn + confusion_matrix.get(k, {}).get(k2, 0)\n                fp = fp + confusion_matrix.get(k2, {}).get(k, 0)\n        if tp + fp == 0:\n            precision = 0.0\n        else:\n            precision = tp / (tp + fp)\n        if tp + fn == 0:\n            recall = 0.0\n        else:\n            recall = tp / (tp + fn)\n        if precision + recall == 0.0:\n            f1 = 0.0\n        else:\n            f1 = 2 * (precision * recall) / (precision + recall)\n\n        results[k] = F1Result(precision, recall, f1)\n\n    return results\n\ndef confusion_to_macro_f1(confusion_matrix):\n    \"\"\"\n    Return the macro f1 for a confusion matrix.\n    \"\"\"\n    sum_f1 = 0.0\n    results = confusion_to_f1(confusion_matrix)\n    for k in results.keys():\n        sum_f1 = sum_f1 + results[k].f1\n\n    return sum_f1 / len(results)\n\ndef confusion_to_weighted_f1(confusion_matrix, exclude=None):\n    \"\"\"\n    Return the f1 weighted by the number of gold items for each label, skipping labels in exclude.\n    \"\"\"\n    results = confusion_to_f1(confusion_matrix)\n\n    sum_f1 = 0.0\n    total_items = 0\n    for k in results.keys():\n        if exclude is not None and k in exclude:\n            continue\n        k_items = sum(confusion_matrix.get(k, {}).values())\n        total_items += k_items\n        sum_f1 += results[k].f1 * k_items\n    return sum_f1 / total_items\n"
  },
  {
    "path": "stanza/utils/conll.py",
    "content": "\"\"\"\nUtility functions for the loading and conversion of CoNLL-format files.\n\"\"\"\nimport os\nimport io\nfrom zipfile import ZipFile\n\nfrom stanza.models.common.doc import Document\nfrom stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, NER, START_CHAR, END_CHAR\nfrom stanza.models.common.doc import FIELD_TO_IDX, FIELD_NUM\nfrom stanza.models.common.doc import LINE_NUMBER\n\nclass CoNLLError(ValueError):\n    pass\n\nclass CoNLL:\n\n    @staticmethod\n    def load_conll(f, ignore_gapping=True, keep_line_numbers=False):\n        \"\"\" Load the file or string into the CoNLL-U format data.\n        Input: file or string reader, where the data is in CoNLL-U format.\n        Output: a tuple whose first element is a list of list of list for each token in each sentence in the data,\n        where the innermost list represents all fields of a token; and whose second element is a list of lists for each\n        comment in each sentence in the data.\n        \"\"\"\n        # f is open() or io.StringIO()\n        doc, sent = [], []\n        doc_comments, sent_comments = [], []\n        for line_idx, line in enumerate(f):\n            # leave whitespace such as NBSP, in case it is meaningful in the conll-u doc\n            line = line.lstrip().rstrip(' \\n\\r\\t')\n            if len(line) == 0:\n                if len(sent) > 0:\n                    doc.append(sent)\n                    sent = []\n                    doc_comments.append(sent_comments)\n                    sent_comments = []\n            else:\n                if line.startswith('#'): # read comment line\n                    sent_comments.append(line)\n                    continue\n                array = line.split('\\t')\n                if ignore_gapping and '.' 
in array[0]:\n                    continue\n                if len(array) != FIELD_NUM:\n                    raise CoNLLError(f\"Cannot parse CoNLL line {line_idx+1}: expecting {FIELD_NUM} fields, {len(array)} found\\n  {array}\")\n                if keep_line_numbers:\n                    if array[-1] == \"_\" or array[-1] is None:\n                        array[-1] = \"%s=%d\" % (LINE_NUMBER, line_idx)\n                    else:\n                        array[-1] = \"%s|%s=%d\" % (array[-1], LINE_NUMBER, line_idx)\n                sent += [array]\n        if len(sent) > 0:\n            doc.append(sent)\n            doc_comments.append(sent_comments)\n        return doc, doc_comments\n\n    @staticmethod\n    def convert_conll(doc_conll):\n        \"\"\" Convert the CoNLL-U format input data to a dictionary format output data.\n        Input: list of token fields loaded from the CoNLL-U format data, where the outermost list represents a list of sentences, and the innermost list represents all fields of a token.\n        Output: a tuple whose first element is a list of list of dictionaries for each token in each sentence in the document,\n        and whose second element is a parallel list holding the empty tokens (IDs containing '.') of each sentence.\n        \"\"\"\n        doc_dict = []\n        doc_empty = []\n        for sent_idx, sent_conll in enumerate(doc_conll):\n            sent_dict = []\n            sent_empty = []\n            for token_idx, token_conll in enumerate(sent_conll):\n                try:\n                    token_dict = CoNLL.convert_conll_token(token_conll)\n                except ValueError as e:\n                    raise CoNLLError(\"Could not process sentence %d token %d:\\n%s\\n%s\" % (sent_idx, token_idx, token_conll, str(e))) from e\n                if '.' 
in token_dict[ID]:\n                    token_dict[ID] = tuple(int(x) for x in token_dict[ID].split(\".\", maxsplit=1))\n                    sent_empty.append(token_dict)\n                else:\n                    try:\n                        token_dict[ID] = tuple(int(x) for x in token_dict[ID].split(\"-\", maxsplit=1))\n                    except ValueError as e:\n                        raise CoNLLError(\"Could not process ID %s at sent_idx %d, token_idx %d\\nEntire token dict:\\n%s\" % (token_dict[ID], sent_idx, token_idx, token_dict)) from e\n                    sent_dict.append(token_dict)\n            doc_dict.append(sent_dict)\n            doc_empty.append(sent_empty)\n        return doc_dict, doc_empty\n\n    @staticmethod\n    def convert_dict(doc_dict):\n        \"\"\" Convert the dictionary format input data to the CoNLL-U format output data.\n\n        This is the reverse function of `convert_conll`, but does not include sentence level annotations or comments.\n\n        Can call this on a Document using `CoNLL.convert_dict(doc.to_dict())`\n\n        Input: dictionary format data, which is a list of list of dictionaries for each token in each sentence in the data.\n        Output: CoNLL-U format data as a list of list of list for each token in each sentence in the data.\n        \"\"\"\n        doc = Document(doc_dict)\n        text = \"{:c}\".format(doc)\n        sentences = text.split(\"\\n\\n\")\n        doc_conll = [[x.split(\"\\t\") for x in sentence.split(\"\\n\")] for sentence in sentences]\n        return doc_conll\n\n    @staticmethod\n    def convert_conll_token(token_conll):\n        \"\"\" Convert the CoNLL-U format input token to the dictionary format output token.\n        Input: a list of all CoNLL-U fields for the token.\n        Output: a dictionary that maps from field name to value.\n        \"\"\"\n        token_dict = {}\n        for field, field_idx in FIELD_TO_IDX.items():\n            value = token_conll[field_idx]\n           
 if value == '' and field is FEATS:\n                continue\n            elif value != '_':\n                if field is HEAD:\n                    token_dict[field] = int(value)\n                else:\n                    token_dict[field] = value\n        # special case if text is '_'\n        if token_conll[FIELD_TO_IDX[TEXT]] == '_':\n            token_dict[TEXT] = token_conll[FIELD_TO_IDX[TEXT]]\n            token_dict[LEMMA] = token_conll[FIELD_TO_IDX[LEMMA]]\n        return token_dict\n\n    @staticmethod\n    def conll2dict(input_file=None, input_str=None, ignore_gapping=True, zip_file=None, keep_line_numbers=False):\n        \"\"\" Load the CoNLL-U format data from file or string into lists of dictionaries.\n        \"\"\"\n        assert any([input_file, input_str]) and not all([input_file, input_str]), 'either use input file or input string'\n        if zip_file: assert input_file, 'must provide input_file if zip_file is set'\n\n        if input_str:\n            infile = io.StringIO(input_str)\n            doc_conll, doc_comments = CoNLL.load_conll(infile, ignore_gapping, keep_line_numbers)\n        elif zip_file:\n            with ZipFile(zip_file) as zin:\n                with zin.open(input_file) as fin:\n                    doc_conll, doc_comments = CoNLL.load_conll(io.TextIOWrapper(fin, encoding=\"utf-8\"), ignore_gapping, keep_line_numbers)\n        else:\n            with open(input_file, encoding='utf-8') as fin:\n                doc_conll, doc_comments = CoNLL.load_conll(fin, ignore_gapping, keep_line_numbers)\n\n        doc_dict, doc_empty = CoNLL.convert_conll(doc_conll)\n        return doc_dict, doc_comments, doc_empty\n\n    @staticmethod\n    def conll2doc(input_file=None, input_str=None, ignore_gapping=True, zip_file=None, keep_line_numbers=False):\n        doc_dict, doc_comments, doc_empty = CoNLL.conll2dict(input_file, input_str, ignore_gapping, zip_file=zip_file, keep_line_numbers=keep_line_numbers)\n        return Document(doc_dict, 
text=None, comments=doc_comments, empty_sentences=doc_empty)\n\n    @staticmethod\n    def conll2multi_docs(input_file=None, input_str=None, ignore_gapping=True, zip_file=None):\n        doc_dict, doc_comments, doc_empty = CoNLL.conll2dict(input_file, input_str, ignore_gapping, zip_file=zip_file)\n\n        docs = []\n        current_doc = []\n        current_comments = []\n        current_empty = []\n        current_doc_id = None\n        for doc, comments, empty in zip(doc_dict, doc_comments, doc_empty):\n            for comment in comments:\n                if comment.startswith(\"# doc_id =\") or comment.startswith(\"# newdoc id =\"):\n                    doc_id = comment.split(\"=\", maxsplit=1)[1]\n                    if len(current_doc) == 0:\n                        current_doc_id = doc_id\n                    elif doc_id != current_doc_id:\n                        new_doc = Document(current_doc, text=None, comments=current_comments, empty_sentences=current_empty)\n                        if current_doc_id != None:\n                            for i in new_doc.sentences:\n                                i.doc_id = current_doc_id.strip()\n                        docs.append(new_doc)\n                        current_doc_id = doc_id\n                    else:\n                        continue\n                    current_doc = [doc]\n                    current_comments = [comments]\n                    current_empty = [empty]\n                    break\n            else: # no comments defined a new doc_id, so just add it to the current document\n                current_doc.append(doc)\n                current_comments.append(comments)\n                current_empty.append(empty)\n        if len(current_doc) > 0:\n            new_doc = Document(current_doc, text=None, comments=current_comments, empty_sentences=current_empty)\n            if current_doc_id != None:\n                for i in new_doc.sentences:\n                    i.doc_id = 
current_doc_id.strip()\n            docs.append(new_doc)\n\n        return docs\n\n    @staticmethod\n    def dict2conll(doc_dict, filename):\n        \"\"\"\n        Convert the dictionary format input data to the CoNLL-U format output data and write to a file.\n        \"\"\"\n        doc = Document(doc_dict)\n        CoNLL.write_doc2conll(doc, filename)\n\n\n    @staticmethod\n    def write_doc2conll(doc, filename, mode='w', encoding='utf-8'):\n        \"\"\"\n        Writes the doc as a conll file to the given file.\n\n        If passed a string, that filename will be opened.  Otherwise, filename.write() will be called.\n\n        Note that the output needs an extra \\n\\n at the end to be a legal output file\n        \"\"\"\n        if hasattr(filename, \"write\"):\n            filename.write(\"{:C}\\n\\n\".format(doc))\n        else:\n            with open(filename, mode, encoding=encoding) as outfile:\n                outfile.write(\"{:C}\\n\\n\".format(doc))\n"
  },
  {
    "path": "stanza/utils/constituency/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/constituency/check_transitions.py",
    "content": "import argparse\n\nfrom stanza.models.constituency import transition_sequence\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.parse_transitions import TransitionScheme\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.models.constituency.utils import verify_transitions\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--train_file', type=str, default=\"data/constituency/en_ptb3_train.mrg\", help='Input file for data loader.')\n    parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],\n                        help='Transition scheme to use.  {}'.format(\", \".join(x.name for x in TransitionScheme)))\n    parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed')\n    parser.add_argument('--iterations', default=30, type=int, help='How many times to iterate, such as if doing a cProfile')\n    args = parser.parse_args()\n    args = vars(args)\n\n    train_trees = tree_reader.read_treebank(args['train_file'])\n    unary_limit = max(t.count_unary_depth() for t in train_trees) + 1\n    train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, \"training\", args['transition_scheme'], args['reversed'])\n    root_labels = Tree.get_root_labels(train_trees)\n    for i in range(args['iterations']):\n        verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], \"train\", root_labels)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/constituency/grep_dev_logs.py",
    "content": "import subprocess\nimport sys\n\niteration = sys.argv[1]\nfilenames = sys.argv[2:]\n\ntotal_score = 0.0\nnum_scores = 0\n\nfor filename in filenames:\n    grep_cmd = [\"grep\", \"Dev score.* %s[)]\" % iteration, \"-A1\", filename]\n    grep_result = subprocess.run(grep_cmd, stdout=subprocess.PIPE, encoding=\"utf-8\")\n    grep_result = grep_result.stdout.strip()\n    if not grep_result:\n        max_cmd = [\"grep\", \"Dev score\", filename]\n        max_result = subprocess.run(max_cmd, stdout=subprocess.PIPE, encoding=\"utf-8\")\n        max_result = max_result.stdout.strip()\n        if not max_result:\n            print(\"{}: no result\".format(filename))\n        else:\n            max_it = max_result.split(\"\\n\")[-1]\n            max_it = int(max_it.split(\":\")[0].split(\"(\")[-1][:-1])\n            epoch_finished_string = \"Epoch %d finished\" % max_it\n            finish_cmd = [\"grep\", epoch_finished_string, filename]\n            finish_result = subprocess.run(finish_cmd, stdout=subprocess.PIPE, encoding=\"utf-8\")\n            finish_result = finish_result.stdout.strip()\n            finish_time = finish_result.split(\" INFO\")[0]\n            print(\"{}: no result.  max iteration: {}   finished at {}\".format(filename, max_it, finish_time))\n    else:\n        grep_result = grep_result.split(\"\\n\")[-1]\n        score = float(grep_result.split(\":\")[-1])\n        best_iteration = int(grep_result.split(\":\")[-2][-6:-1])\n        print(\"{}: {}  ({})\".format(filename, score, best_iteration))\n        total_score += score\n        num_scores += 1\n\nif num_scores > 0:\n    avg = total_score / num_scores\n    print(\"Avg: {}\".format(avg))\n\n"
  },
  {
    "path": "stanza/utils/constituency/grep_test_logs.py",
    "content": "import subprocess\nimport sys\n\nfilenames = sys.argv[1:]\n\ntotal_score = 0.0\nnum_scores = 0\n\nfor filename in filenames:\n    grep_cmd = [\"grep\", \"F1 score.*test.*\", filename]\n    grep_result = subprocess.run(grep_cmd, stdout=subprocess.PIPE, encoding=\"utf-8\")\n    grep_result = grep_result.stdout.strip()\n    if not grep_result:\n        print(\"{}: no result\".format(filename))\n        continue\n\n    score = float(grep_result.split()[-1])\n    print(\"{}: {}\".format(filename, score))\n    total_score += score\n    num_scores += 1\n\nif num_scores > 0:\n    avg = total_score / num_scores\n    print(\"Avg: {}\".format(avg))\n"
  },
  {
    "path": "stanza/utils/constituency/list_tensors.py",
    "content": "\"\"\"\nLists all the tensors in a constituency model.\n\nCurrently useful in combination with torchshow for displaying a series of tensors as they change.\n\"\"\"\n\nimport sys\n\nfrom stanza.models.constituency.trainer import Trainer\n\n\ntrainer = Trainer.load(sys.argv[1])\nmodel = trainer.model\n\nfor name, param in model.named_parameters():\n    print(name, param.requires_grad)\n"
  },
  {
    "path": "stanza/utils/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/common.py",
    "content": "\nimport argparse\nfrom enum import Enum\nimport glob\nimport logging\nimport os\nimport re\nimport subprocess\nimport sys\nimport unicodedata\n\nfrom stanza.models.common.short_name_to_treebank import canonical_treebank_name\nimport stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data\nimport stanza.utils.datasets.conllu_to_text as conllu_to_text\nimport stanza.utils.default_paths as default_paths\n\nlogger = logging.getLogger('stanza')\n\n# RE to see if the index of a conllu line represents an MWT\nMWT_RE = re.compile(\"^[0-9]+[-][0-9]+\")\n\n# RE to see if the index of a conllu line represents an MWT or copy node\nMWT_OR_COPY_RE = re.compile(\"^[0-9]+[-.][0-9]+\")\n\n# more restrictive than an actual int as we expect certain formats in the conllu files\nINT_RE = re.compile(\"^[0-9]+$\")\n\nclass ModelType(Enum):\n    TOKENIZER        = 1\n    MWT              = 2\n    POS              = 3\n    LEMMA            = 4\n    DEPPARSE         = 5\n\nclass UnknownDatasetError(ValueError):\n    def __init__(self, dataset, text):\n        super().__init__(text)\n        self.dataset = dataset\n\ndef convert_conllu_to_txt(tokenizer_dir, short_name, shards=(\"train\", \"dev\", \"test\")):\n    \"\"\"\n    Convert the conllu documents for this dataset to a .txt format\n\n    This follows the old conllu_to_text.pl script, except we never\n    used the ZH option anyway, so we didn't reimplement it here\n    \"\"\"\n    for dataset in shards:\n        output_conllu = f\"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu\"\n        output_txt = f\"{tokenizer_dir}/{short_name}.{dataset}.txt\"\n\n        if not os.path.exists(output_conllu):\n            raise FileNotFoundError(\"Cannot convert %s as the file cannot be found\" % output_conllu)\n        conllu_to_text.main([output_conllu, output_txt])\n\ndef strip_accents(word):\n    \"\"\"\n    Remove diacritics from words such as in the UD GRC datasets\n    \"\"\"\n    converted = ''.join(c for c 
in unicodedata.normalize('NFD', word)\n                        if unicodedata.category(c) not in ('Mn'))\n    if len(converted) == 0:\n        return word\n    return converted\n\ndef mwt_name(base_dir, short_name, dataset):\n    return os.path.join(base_dir, f\"{short_name}-ud-{dataset}-mwt.json\")\n\ndef tokenizer_conllu_name(base_dir, short_name, dataset):\n    return os.path.join(base_dir, f\"{short_name}.{dataset}.gold.conllu\")\n\ndef prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):\n    labels_filename = f\"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels\"\n    mwt_filename = mwt_name(tokenizer_dir, short_name, dataset)\n    prepare_tokenizer_data.main([input_txt,\n                                 input_conllu,\n                                 \"-o\", labels_filename,\n                                 \"-m\", mwt_filename])\n\ndef prepare_tokenizer_treebank_labels(tokenizer_dir, short_name):\n    \"\"\"\n    Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test\n    \"\"\"\n    for dataset in (\"train\", \"dev\", \"test\"):\n        output_txt = f\"{tokenizer_dir}/{short_name}.{dataset}.txt\"\n        output_conllu = f\"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu\"\n        try:\n            prepare_tokenizer_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)\n        except (KeyboardInterrupt, SystemExit):\n            raise\n        except:\n            print(\"Failed to convert %s to %s\" % (output_txt, output_conllu))\n            raise\n\ndef read_sentences_from_conllu(filename):\n    \"\"\"\n    Reads a conllu file as a list of list of strings\n\n    Finding a blank line separates the lists\n    \"\"\"\n    sents = []\n    cache = []\n    with open(filename, encoding=\"utf-8\") as infile:\n        for line in infile:\n            line = line.strip()\n            if len(line) == 0:\n                if len(cache) > 0:\n                    
sents.append(cache)\n                    cache = []\n                continue\n            cache.append(line)\n        if len(cache) > 0:\n            sents.append(cache)\n    return sents\n\ndef maybe_add_fake_dependencies(lines):\n    \"\"\"\n    Possibly add fake dependencies in columns 6 and 7 (counting from 0)\n\n    The conllu scripts need the dependencies column filled out, so in\n    the case of models we build without dependency data, we need to\n    add those fake dependencies in order to use the eval script etc\n\n    lines: a list of strings with 10 tab separated columns\n      comments are allowed (they will be skipped)\n\n    returns: the same strings, but with fake dependencies added\n      if columns 6 and 7 were empty\n    \"\"\"\n    new_lines = []\n    root_idx = None\n    first_idx = None\n    for line_idx, line in enumerate(lines):\n        if line.startswith(\"#\"):\n            new_lines.append(line)\n            continue\n\n        pieces = line.split(\"\\t\")\n        if MWT_OR_COPY_RE.match(pieces[0]):\n            new_lines.append(line)\n            continue\n\n        token_idx = int(pieces[0])\n        if pieces[6] != '_':\n            if pieces[6] == '0':\n                root_idx = token_idx\n            new_lines.append(line)\n        elif token_idx == 1:\n            # note that the comments might make this not the first line\n            # we keep track of this separately so we can either make this the root,\n            # or set this to be the root later\n            first_idx = line_idx\n            new_lines.append(pieces)\n        else:\n            pieces[6] = \"1\"\n            pieces[7] = \"dep\"\n            new_lines.append(\"\\t\".join(pieces))\n    if first_idx is not None:\n        if root_idx is None:\n            new_lines[first_idx][6] = \"0\"\n            new_lines[first_idx][7] = \"root\"\n        else:\n            new_lines[first_idx][6] = str(root_idx)\n            new_lines[first_idx][7] = \"dep\"\n        
new_lines[first_idx] = \"\\t\".join(new_lines[first_idx])\n    return new_lines\n\ndef write_sentences_to_file(outfile, sents):\n    for lines in sents:\n        lines = maybe_add_fake_dependencies(lines)\n        for line in lines:\n            print(line, file=outfile)\n        print(\"\", file=outfile)\n\ndef write_sentences_to_conllu(filename, sents):\n    with open(filename, 'w', encoding=\"utf-8\") as outfile:\n        write_sentences_to_file(outfile, sents)\n\ndef find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False, env_var=\"UDBASE\"):\n    \"\"\"\n    For a given treebank, dataset, extension, look for the exact filename to use.\n\n    Sometimes the short name we use is different from the short name\n    used by UD.  For example, Norwegian or Chinese.  Hence the reason\n    to not hardcode it based on treebank\n\n    set fail=True to fail if the file is not found\n    \"\"\"\n    if treebank.startswith(\"UD_Korean\") and treebank.endswith(\"_seg\"):\n        treebank = treebank[:-4]\n    if treebank.startswith(\"UD_Ancient_Greek-\") and (treebank.endswith(\"-Diacritics\") or treebank.endswith(\"-diacritics\")):\n        treebank = treebank[:-11]\n    filename = os.path.join(udbase_dir, treebank, f\"*-ud-{dataset}.{extension}\")\n    files = glob.glob(filename)\n    if len(files) == 0:\n        if fail:\n            raise FileNotFoundError(\"Could not find any treebank files which matched {}\\nIf you have the data elsewhere, you can change the base directory for the search by changing the {} environment variable\".format(filename, env_var))\n        else:\n            return None\n    elif len(files) == 1:\n        return files[0]\n    else:\n        raise RuntimeError(f\"Unexpected number of files matched '{udbase_dir}/{treebank}/*-ud-{dataset}.{extension}'\")\n\ndef mostly_underscores(filename):\n    \"\"\"\n    Certain treebanks have proprietary data, so the text is hidden\n\n    For example:\n      UD_Arabic-NYUAD\n      
UD_English-ESL\n      UD_English-GUMReddit\n      UD_Hindi_English-HIENCS\n      UD_Japanese-BCCWJ\n    \"\"\"\n    underscore_count = 0\n    total_count = 0\n    with open(filename, encoding=\"utf-8\") as fin:\n        for line in fin:\n            line = line.strip()\n            if not line:\n                continue\n            if line.startswith(\"#\"):\n                continue\n            total_count = total_count + 1\n            pieces = line.split(\"\\t\")\n            if pieces[1] in (\"_\", \"-\"):\n                underscore_count = underscore_count + 1\n    if total_count == 0:\n        # a file with no token lines has no hidden text to detect\n        return False\n    return underscore_count / total_count > 0.5\n\ndef num_words_in_file(conllu_file):\n    \"\"\"\n    Count the number of non-blank, non-comment lines in a conllu file\n    \"\"\"\n    count = 0\n    with open(conllu_file, encoding=\"utf-8\") as fin:\n        for line in fin:\n            line = line.strip()\n            if not line:\n                continue\n            if line.startswith(\"#\"):\n                continue\n            count = count + 1\n    return count\n\n\ndef get_test_only_ud_treebanks(udbase_dir, filtered=True):\n    \"\"\"\n    Looks in udbase_dir for all the treebanks which are *only* test sets, but might be big enough\n\n    Filters out:\n    - less than 10000 words\n    - the language already has a larger treebank we can use\n\n    The second filter takes quite some time, as there is a check that\n    goes through all the text in the treebank\n    \"\"\"\n    treebanks = sorted(glob.glob(udbase_dir + \"/UD_*\"))\n    # skip UD_English-GUMReddit as it is usually incorporated into UD_English-GUM\n    treebanks = [os.path.split(t)[1] for t in treebanks]\n    treebanks = [t for t in treebanks if t != \"UD_English-GUMReddit\"]\n\n    # only take the ones which do have test, but don't have train\n    treebanks = [t for t in treebanks if not find_treebank_dataset_file(t, udbase_dir, \"train\", \"conllu\")]\n    treebanks = [t for t in treebanks if find_treebank_dataset_file(t, udbase_dir, \"test\", \"conllu\")]\n\n    treebanks = [t for t in treebanks if not 
mostly_underscores(find_treebank_dataset_file(t, udbase_dir, \"test\", \"conllu\"))]\n    if any(find_treebank_dataset_file(t, udbase_dir, \"dev\", \"conllu\") for t in treebanks):\n        raise AssertionError(\"Found a treebank with dev and test, but no train.  This violates our expectations\")\n    treebanks = [t for t in treebanks if num_words_in_file(find_treebank_dataset_file(t, udbase_dir, \"test\", \"conllu\")) > 10000]\n\n    if filtered:\n        treebanks = [t for t in treebanks\n                     if len(get_ud_treebanks(udbase_dir, lang=t.split(\"-\")[0])) == 0]\n    #for t in treebanks:\n    #    print(t,\n    #          num_words_in_file(find_treebank_dataset_file(t, udbase_dir, \"test\", \"conllu\")))\n    return treebanks\n\ndef get_ud_treebanks(udbase_dir, lang=None, filtered=True):\n    \"\"\"\n    Looks in udbase_dir for all the treebanks which have both train, dev, and test\n\n    If specified, lang should be exactly UD_English or however the treebanks appear in the UD release\n    \"\"\"\n    if lang is None:\n        treebanks = sorted(glob.glob(udbase_dir + \"/UD_*\"))\n    else:\n        treebanks = sorted(glob.glob(\"%s/%s*\" % (udbase_dir, lang)))\n    # skip UD_English-GUMReddit as it is usually incorporated into UD_English-GUM\n    treebanks = [os.path.split(t)[1] for t in treebanks]\n    treebanks = [t for t in treebanks if t != \"UD_English-GUMReddit\"]\n    if filtered:\n        treebanks = [t for t in treebanks\n                     if (find_treebank_dataset_file(t, udbase_dir, \"train\", \"conllu\") and\n                         # this will be fixed using XV\n                         #find_treebank_dataset_file(t, udbase_dir, \"dev\", \"conllu\") and\n                         find_treebank_dataset_file(t, udbase_dir, \"test\", \"conllu\"))]\n        treebanks = [t for t in treebanks\n                     if not mostly_underscores(find_treebank_dataset_file(t, udbase_dir, \"train\", \"conllu\"))]\n        # eliminate partial 
treebanks (fixed with XV) for which we only have 1000 words or less\n        # if the train set is small and the test set is large enough, we'll flip them\n        treebanks = [t for t in treebanks\n                     if (find_treebank_dataset_file(t, udbase_dir, \"dev\", \"conllu\") or\n                         num_words_in_file(find_treebank_dataset_file(t, udbase_dir, \"train\", \"conllu\")) > 1000 or\n                         num_words_in_file(find_treebank_dataset_file(t, udbase_dir, \"test\", \"conllu\")) > 5000)]\n    return treebanks\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on.  Use all_ud or ud_all for all UD treebanks')\n\n    return parser\n\n\ndef main(process_treebank, model_type, add_specific_args=None):\n    logger.info(\"Datasets program called with:\\n\" + \" \".join(sys.argv))\n\n    parser = build_argparse()\n    if add_specific_args is not None:\n        add_specific_args(parser)\n    args = parser.parse_args()\n\n    paths = default_paths.get_default_paths()\n\n    treebanks = []\n    for treebank in args.treebanks:\n        if treebank.lower() in ('ud_all', 'all_ud'):\n            ud_treebanks = get_ud_treebanks(paths[\"UDBASE\"])\n            treebanks.extend(ud_treebanks)\n        else:\n            # If this is a known UD short name, use the official name (we need it for the paths)\n            treebank = canonical_treebank_name(treebank)\n            treebanks.append(treebank)\n\n    for treebank in treebanks:\n        process_treebank(treebank, model_type, paths, args)\n"
  },
  {
    "path": "stanza/utils/datasets/conllu_to_text.py",
    "content": "\nimport argparse\nimport re\n\nTEXT_RE = re.compile(\"^#\\\\s*text\")\nNEWPAR_RE = re.compile(\"^#\\\\s*newpar\")\nNEWDOC_RE = re.compile(\"^#\\\\s*newdoc\")\n\nMWT_RE = re.compile(\"^\\\\d+-(\\\\d+)\\t\")\nWORD_RE = re.compile(\"^(\\\\d)+\\t\")\n\nWORD_NEWPAR_RE = re.compile(\"NewPar=Yes\")\nSPACEAFTER_RE = re.compile(\"SpaceAfter=No\")\n\ndef print_new_paragraph_if_needed(fout, start, newdoc, newpar, output_buffer):\n    if not start and (newdoc or newpar):\n        if output_buffer:\n            fout.write(output_buffer)\n            fout.write(\"\\n\")\n        fout.write(\"\\n\")\n        return \"\"\n    return output_buffer\n\ndef print_lines_from_buffer(fout, output_buffer, max_len):\n    while len(output_buffer) >= max_len:\n        split_idx = None\n        for idx in range(len(output_buffer)):\n            if idx > max_len and split_idx is not None:\n                break\n            if output_buffer[idx].isspace():\n                split_idx = idx\n        if split_idx is not None:\n            fout.write(output_buffer[:split_idx])\n            fout.write(\"\\n\")\n            output_buffer = output_buffer[split_idx+1:]\n        else:\n            fout.write(output_buffer)\n            fout.write(\"\\n\")\n            output_buffer = \"\"\n    return output_buffer\n\ndef convert_text(conllu_file, output_file):\n    with open(conllu_file, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n\n    with open(output_file, \"w\", encoding=\"utf-8\") as fout:\n        newpar = False\n        newdoc = False\n        start = True\n\n        in_mwt = False\n        mwt_last = None\n\n        def print_and_reset(output_buffer, incoming_buffer):\n            nonlocal start, newpar, newdoc, in_mwt\n\n            output_buffer = print_new_paragraph_if_needed(fout, start, newdoc, newpar, output_buffer)\n            output_buffer += incoming_buffer\n            output_buffer = print_lines_from_buffer(fout, output_buffer, 80)\n            
start = False\n            newpar = False\n            newdoc = False\n            in_mwt = False\n            return output_buffer\n\n        output_buffer = \"\"\n        incoming_buffer = \"\"\n\n        for line in lines:\n            line = line.strip()\n\n            if not line:\n                output_buffer = print_and_reset(output_buffer, incoming_buffer)\n                incoming_buffer = \"\"\n\n            if TEXT_RE.match(line):\n                # we ignore the #text and extract the text from the tokens\n                continue\n\n            if NEWPAR_RE.match(line):\n                newpar = True\n                continue\n\n            if NEWDOC_RE.match(line):\n                newdoc = True\n                continue\n\n            match = MWT_RE.match(line)\n            if match:\n                in_mwt = True\n                mwt_last = int(match.group(1))\n                pieces = line.split(\"\\t\")\n\n                if WORD_NEWPAR_RE.search(pieces[9]):\n                    output_buffer = print_and_reset(output_buffer, incoming_buffer)\n                    incoming_buffer = \"\"\n                    fout.write(output_buffer)\n                    fout.write(\"\\n\\n\")\n                    output_buffer = \"\"\n\n                incoming_buffer += pieces[1]\n                if not SPACEAFTER_RE.search(pieces[9]):\n                    incoming_buffer += \" \"\n                continue\n\n            match = WORD_RE.match(line)\n            if match:\n                pieces = line.split(\"\\t\")\n                word_id = int(pieces[0])\n                if in_mwt and word_id <= mwt_last:\n                    continue\n                in_mwt = False\n\n                if WORD_NEWPAR_RE.search(pieces[9]):\n                    output_buffer = print_and_reset(output_buffer, incoming_buffer)\n                    incoming_buffer = \"\"\n                    fout.write(output_buffer)\n                    fout.write(\"\\n\\n\")\n                    
output_buffer = \"\"\n\n                incoming_buffer += pieces[1]\n                if not SPACEAFTER_RE.search(pieces[9]):\n                    incoming_buffer += \" \"\n                continue\n        if output_buffer != \"\":\n            fout.write(output_buffer)\n            fout.write(\"\\n\")\n\ndef main(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument('conllu_file', type=str, help=\"CoNLL-U file containing tokens and sentence breaks\")\n    parser.add_argument('output_file', type=str, help=\"Plaintext file containing the raw input\")\n    args = parser.parse_args(args)\n\n    convert_text(args.conllu_file, args.output_file)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/constituency/build_silver_dataset.py",
    "content": "\"\"\"\nGiven two ensembles and a tokenized file, output the trees for which those ensembles agree and report how many of the sub-models agree on those trees.\n\nFor example:\n\npython3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_AA.txt --lang it --output_file asdf.out --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt\n\nfor i in `echo f g h i j k l m n o p q r s t`; do nlprun -d a6000 \"python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tok_6M_a$i.txt --lang it --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.trees --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt\" -o /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.out; done\n\nfor i in `echo a b c d`; do nlprun -d a6000 \"python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/english/en_wiki_2023/shuf_1M.a$i --lang en --output_file /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.trees --e1 saved_models/constituency/en_ptb3_electra-large_100?_in_constituency.pt --e2 saved_models/constituency/en_ptb3_electra-large_100?_top_constituency.pt\" -o /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.out; done\n\"\"\"\n\nimport argparse\nimport json\n\nimport logging\n\nfrom stanza.models.common import utils\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.constituency import retagging\nfrom stanza.models.constituency import text_processing\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.ensemble 
import Ensemble\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nlogger = logging.getLogger('stanza.constituency.trainer')\n\ndef parse_args(args=None):\n    parser = argparse.ArgumentParser(description=\"Script that uses multiple ensembles to find trees where both ensembles agree\")\n\n    input_group = parser.add_mutually_exclusive_group(required=True)\n    input_group.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')\n    input_group.add_argument('--tree_file', type=str, default=None, help='Input file of already parsed text for reparsing with parse_text.')\n    parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')\n\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n\n    utils.add_device_args(parser)\n\n    parser.add_argument('--lang', default='en', help='Language to use')\n\n    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')\n    parser.add_argument('--e1', type=str, nargs='+', default=None, help=\"Which model(s) to load in the first ensemble\")\n    parser.add_argument('--e2', type=str, nargs='+', default=None, help=\"Which model(s) to load in the second ensemble\")\n\n    parser.add_argument('--mode', default='predict', choices=['parse_text', 'predict'])\n\n    # another option would be to include the tree idx in each entry in an existing saved file\n    # the processing could then pick up at exactly the last known idx\n    parser.add_argument('--start_tree', type=int, default=0, help='Where to start... 
most useful if the previous incarnation crashed')\n    parser.add_argument('--end_tree', type=int, default=None, help='Where to end.  If unset, will process to the end of the file')\n\n    retagging.add_retag_args(parser)\n\n    # pass args through so that callers can supply their own argument list\n    args = vars(parser.parse_args(args))\n\n    retagging.postprocess_args(args)\n    args['num_generate'] = 0\n\n    return args\n\ndef main():\n    args = parse_args()\n    utils.log_training_args(args, logger, name=\"ensemble\")\n\n    retag_pipeline = retagging.build_retag_pipeline(args)\n    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()\n\n    logger.info(\"Building ensemble #1 out of %s\", args['e1'])\n    e1 = Ensemble(args, filenames=args['e1'], foundation_cache=foundation_cache)\n    e1.to(args.get('device', None))\n    logger.info(\"Building ensemble #2 out of %s\", args['e2'])\n    e2 = Ensemble(args, filenames=args['e2'], foundation_cache=foundation_cache)\n    e2.to(args.get('device', None))\n\n    if args['tokenized_file']:\n        tokenized_sentences, _ = text_processing.read_tokenized_file(args['tokenized_file'])\n    elif args['tree_file']:\n        treebank = tree_reader.read_treebank(args['tree_file'])\n        tokenized_sentences = [x.leaf_labels() for x in treebank]\n        if args['lang'] == 'vi':\n            tokenized_sentences = [[x.replace(\"_\", \" \") for x in sentence] for sentence in tokenized_sentences]\n    logger.info(\"Read %d tokenized sentences\", len(tokenized_sentences))\n\n    all_models = e1.models + e2.models\n\n    chunk_size = 1000\n    with open(args['output_file'], 'w', encoding='utf-8') as fout:\n        end_tree = len(tokenized_sentences) if args['end_tree'] is None else args['end_tree']\n        for chunk_start in tqdm(range(args['start_tree'], end_tree, chunk_size)):\n            # do not read past end_tree in the final chunk\n            chunk = tokenized_sentences[chunk_start:min(chunk_start+chunk_size, end_tree)]\n            logger.info(\"Processing trees %d to %d\", chunk_start, chunk_start+len(chunk))\n            
parsed1 = text_processing.parse_tokenized_sentences(args, e1, retag_pipeline, chunk)\n            parsed1 = [x.predictions[0].tree for x in parsed1]\n            parsed2 = text_processing.parse_tokenized_sentences(args, e2, retag_pipeline, chunk)\n            parsed2 = [x.predictions[0].tree for x in parsed2]\n            matching = [t for t, t2 in zip(parsed1, parsed2) if t == t2]\n            logger.info(\"%d trees matched\", len(matching))\n            model_counts = [0] * len(matching)\n            for model in all_models:\n                model_chunk = model.parse_sentences_no_grad(iter(matching), model.build_batch_from_trees, args['eval_batch_size'], model.predict)\n                model_chunk = [x.predictions[0].tree for x in model_chunk]\n                for idx, (t1, t2) in enumerate(zip(matching, model_chunk)):\n                    if t1 == t2:\n                        model_counts[idx] += 1\n            for count, tree in zip(model_counts, matching):\n                line = {\"tree\": \"%s\" % tree, \"count\": count}\n                fout.write(json.dumps(line))\n                fout.write(\"\\n\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/common_trees.py",
    "content": "\"\"\"\nLook through 2 files, only output the common trees\n\npretty basic - could use some more options\n\"\"\"\n\nimport sys\n\ndef main():\n    in1 = sys.argv[1]\n    with open(in1, encoding=\"utf-8\") as fin:\n        lines1 = fin.readlines()\n    in2 = sys.argv[2]\n    with open(in2, encoding=\"utf-8\") as fin:\n        lines2 = fin.readlines()\n\n    common = [l1 for l1, l2 in zip(lines1, lines2) if l1 == l2]\n    for l in common:\n        print(l.strip())\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_alt.py",
    "content": "\"\"\"\nRead files of parses and the files which define the train/dev/test splits\n\nWrite out the files after splitting them\n\nSequence of operations:\n  - read the raw lines from the input files\n  - read the recommended splits, as per the ALT description page\n  - separate the trees using the recommended split files\n  - write back the trees\n\"\"\"\n\ndef read_split_file(split_file):\n    \"\"\"\n    Read a split file for ALT\n\n    The format of the file is expected to be a list of lines such as\n    URL.1234    <url>\n    Here, we only care about the id\n\n    return: a set of the ids\n    \"\"\"\n    with open(split_file, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n    lines = [x.strip() for x in lines]\n    lines = [x.split()[0] for x in lines if x]\n    if any(not x.startswith(\"URL.\") for x in lines):\n        raise ValueError(\"Unexpected line in %s: %s\" % (split_file, x))\n    split = set(int(x.split(\".\", 1)[1]) for x in lines)\n    return split\n\ndef split_trees(all_lines, splits):\n    \"\"\"\n    Splits lines of the form\n    SNT.17873.4049\t(S ...\n    then assigns them to a list based on the file id in\n    SNT.<file>.<sent>\n    \"\"\"\n    trees = [list() for _ in splits]\n    for line in all_lines:\n        tree_id, tree_text = line.split(maxsplit=1)\n        tree_id = int(tree_id.split(\".\", 2)[1])\n        for split_idx, split in enumerate(splits):\n            if tree_id in split:\n                trees[split_idx].append(tree_text)\n                break\n        else:\n            # couldn't figure out which split to put this in\n            raise ValueError(\"Couldn't find which split this line goes in:\\n%s\" % line)\n    return trees\n\ndef read_alt_lines(input_files):\n    \"\"\"\n    Read the trees from the given file(s)\n\n    Any trees with wide spaces are eliminated.  
The parse tree\n    handling doesn't handle it well and the tokenizer won't produce\n    tokens which are entirely wide spaces anyway\n\n    The tree lines are not processed into trees, though\n    \"\"\"\n    all_lines = []\n    for input_file in input_files:\n        with open(input_file, encoding=\"utf-8\") as fin:\n            all_lines.extend(fin.readlines())\n    all_lines = [x.strip() for x in all_lines]\n    all_lines = [x for x in all_lines if x]\n    original_count = len(all_lines)\n    # there is 1 tree with wide space as an entire token, and 4 with wide spaces at the end of a token\n    all_lines = [x for x in all_lines if not \"　\" in x]\n    new_count = len(all_lines)\n    if new_count < original_count:\n        print(\"Eliminated %d trees for having wide spaces in it\" % ((original_count - new_count)))\n        original_count = new_count\n    all_lines = [x for x in all_lines if not \"\\\\x\" in x]\n    new_count = len(all_lines)\n    if new_count < original_count:\n        print(\"Eliminated %d trees for not being correctly encoded\" % ((original_count - new_count)))\n        original_count = new_count\n    return all_lines\n\ndef convert_alt(input_files, split_files, output_files):\n    \"\"\"\n    Convert the ALT treebank into train/dev/test splits\n\n    input_files: paths to read trees\n    split_files: recommended splits from the ALT page\n    output_files: where to write train/dev/test\n    \"\"\"\n    all_lines = read_alt_lines(input_files)\n\n    splits = [read_split_file(split_file) for split_file in split_files]\n    trees = split_trees(all_lines, splits)\n\n    for chunk, output_file in zip(trees, output_files):\n        print(\"Writing %d trees to %s\" % (len(chunk), output_file))\n        with open(output_file, \"w\", encoding=\"utf-8\") as fout:\n            for tree in chunk:\n                # the extra ROOT is because the ALT doesn't have this at the top of its trees\n                fout.write(\"(ROOT {})\\n\".format(tree))\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_arboretum.py",
    "content": "\"\"\"\nParses a Tiger dataset to PTB\n\nAlso handles problems specific for the Arboretum treebank.\n\n- validation errors in the XML: \n  -- there is a \"&\" instead of an \"&amp;\" early on\n  -- there are tags \"<{note}>\" and \"<{parentes-udeladt}>\" which may or may not be relevant,\n     but are definitely not properly xml encoded\n- trees with stranded nodes.  5 trees have links to words in a different tree.\n  those trees are skipped\n- trees with empty nodes.  58 trees have phrase nodes with no leaves.\n  those trees are skipped\n- trees with missing words.  134 trees have words in the text which aren't in the tree\n  those trees are also skipped\n- trees with categories not in the category directory\n  for example, intj... replaced with fcl?\n  most of these are replaced with what might be a sensible replacement\n- trees with labels that don't have an obvious replacement\n  these trees are eliminated, 4 total\n- underscores in words.  those words are split into multiple words\n  the tagging is not going to be ideal, but the first step of training\n  a parser is usually to retag the words anyway, so this should be okay\n- tree 14729 is really weirdly annotated.  skipped\n- 5373 trees total have non-projective constituents.  These don't work\n  with the stanza parser...  in order to work around this, we rearrange\n  them when possible.\n    ((X Z) Y1 Y2 ...) -> (X Y1 Y2 Z)          this rearranges 3021 trees\n    ((X Z1 ...) Y1 Y2 ...) -> (X Y1 Y2 Z)     this rearranges  403 trees\n    ((X Z1 ...) (tag Y1) ...) 
-> (X (Y1) Z)   this rearranges 1258 trees\n\n  A couple examples of things which get rearranged\n  (limited in scope and without the words to avoid breaking our license):\n\n(vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7)\n-->\n(vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9))\n\n(vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3)\n-->\n(vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4))\n\n  This process leaves behind 691 trees.  In some cases, the\n  non-projective structure is at a higher level than the attachment.\n  In others, there are nested non-projectivities that are not\n  rearranged by the above pattern.  A couple examples:\n\nhere, the 3-7 nonprojectivity has the 7 in a nested structure\n(s\n (par\n  (n s206_1)\n  (pu s206_2)\n  (fcl\n   (fcl\n    (pron-pers s206_3)\n    (fcl (pron-pers s206_7) (adv s206_8) (v-fin s206_9)))\n   (vp (v-fin s206_4) (v-inf s206_6))\n   (pron-pers s206_5))\n  (pu s206_10)))\n\nhere, 11 is attached at a higher level than 12 & 13\n(s\n (fcl\n  (icl\n   (np\n    (adv s223_1)\n    (np\n     (n s223_2)\n     (pp\n      (prp s223_3)\n      (par\n       (adv s223_4)\n       (prop s223_5)\n       (pu s223_6)\n       (prop s223_7)\n       (conj-c s223_8)\n       (np (adv s223_9) (prop s223_10))))))\n   (vp (infm s223_12) (v-inf s223_13)))\n  (v-fin s223_11)\n  (pu s223_14)))\n\neven if we moved _6 between 2 and 7, we'd then have a completely flat\nstructure when moving 3..5 inside\n(s\n (fcl\n  (xx s499_1)\n  (np\n   (pp (pron-pers s499_2) (prp s499_7))\n   (n s499_6))\n  (v-fin s499_3) (adv s499_4) (adv s499_5) (pu s499_8)))\n\n\"\"\"\n\n\nfrom collections import namedtuple\nimport io\nimport xml.etree.ElementTree as ET\n\nfrom tqdm import tqdm\n\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.server import tsurgeon\n\ndef read_xml_file(input_filename):\n    \"\"\"\n    Convert an XML file into a list of trees - each <s> becomes its own object\n    \"\"\"\n    print(\"Reading 
{}\".format(input_filename))\n    with open(input_filename, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n\n    sentences = []\n    current_sentence = []\n    in_sentence = False\n    for line_idx, line in enumerate(lines):\n        if line.startswith(\"<s \"):\n            if len(current_sentence) > 0:\n                raise ValueError(\"Found the start of a sentence inside an existing sentence, line {}\".format(line_idx))\n            in_sentence = True\n\n        if in_sentence:\n            current_sentence.append(line)\n\n        if line.startswith(\"</s>\"):\n            assert in_sentence\n            current_sentence = [x.replace(\"<{parentes-udeladt}>\", \"\") for x in current_sentence]\n            current_sentence = [x.replace(\"<{note}>\", \"\") for x in current_sentence]\n            sentences.append(\"\".join(current_sentence))\n            current_sentence = []\n            in_sentence = False\n\n    assert len(current_sentence) == 0\n\n    xml_sentences = []\n    for sent_idx, text in enumerate(sentences):\n        sentence = io.StringIO(text)\n        try:\n            tree = ET.parse(sentence)\n            xml_sentences.append(tree)\n        except ET.ParseError as e:\n            raise ValueError(\"Failed to parse sentence {}\".format(sent_idx))\n\n    return xml_sentences\n\nWord = namedtuple('Word', ['word', 'tag'])\nNode = namedtuple('Node', ['label', 'children'])\n\nclass BrokenLinkError(ValueError):\n    def __init__(self, error):\n        super(BrokenLinkError, self).__init__(error)\n\ndef process_nodes(root_id, words, nodes, visited):\n    \"\"\"\n    Given a root_id, a map of words, and a map of nodes, construct a Tree\n\n    visited is a set of string ids and mutates over the course of the recursive call\n    \"\"\"\n    if root_id in visited:\n        raise ValueError(\"Loop in the tree!\")\n    visited.add(root_id)\n\n    if root_id in words:\n        word = words[root_id]\n        # big brain move: put the root_id here 
so we can use that to\n        # check the sorted order when we are done\n        word_node = Tree(label=root_id)\n        tag_node = Tree(label=word.tag, children=word_node)\n        return tag_node\n    elif root_id in nodes:\n        node = nodes[root_id]\n        children = [process_nodes(child, words, nodes, visited) for child in node.children]\n        return Tree(label=node.label, children=children)\n    else:\n        raise BrokenLinkError(\"Unknown id! {}\".format(root_id))\n\ndef check_words(tree, tsurgeon_processor):\n    \"\"\"\n    Check that the words of a sentence are in order\n\n    If they are not, this applies a tsurgeon to rearrange simple cases\n    The tsurgeon looks at the gap between words, eg _3 to _7, and looks\n    for the words between, such as _4 _5 _6.  if those words are under\n    a node at the same level as the 3-7 node and does not include any\n    other nodes (such as _8), that subtree is moved to between _3 and _7\n\n    Example:\n\n    (vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7)\n    -->\n    (vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9))\n    \"\"\"\n    while True:\n        words = tree.leaf_labels()\n        indices = [int(w.split(\"_\", 1)[1]) for w in words]\n        for word_idx, word_label in enumerate(indices):\n            if word_idx != word_label - 1:\n                break\n        else:\n            # if there are no weird indices, keep the tree\n            return tree\n\n        sorted_indices = sorted(indices)\n        if indices == sorted_indices:\n            raise ValueError(\"Skipped index!  
This should already be accounted for  {}\".format(tree))\n\n        if word_idx == 0:\n            return None\n\n        prefix = words[0].split(\"_\", 1)[0]\n        prev_idx = word_idx - 1\n        prev_label = indices[prev_idx]\n        missing_words = [\"%s_%d\" % (prefix, x) for x in range(prev_label + 1, word_label)]\n        missing_words = \"|\".join(missing_words)\n        #move_tregex = \"%s > (__=home > (__=parent > __=grandparent)) . (%s > (__=move > =grandparent))\" % (words[word_idx], \"|\".join(missing_words))\n        move_tregex = \"%s > (__=home > (__=parent << %s $+ (__=move <<, %s <<- %s)))\" % (words[word_idx], words[prev_idx], missing_words, missing_words)\n        move_tsurgeon = \"move move $+ home\"\n        modified = tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0]\n        if modified == tree:\n            # this only happens if the desired fix didn't happen\n            #print(\"Failed to process:\\n  {}\\n  {} {}\".format(tree, prev_label, word_label))\n            return None\n\n        tree = modified\n\ndef replace_words(tree, words):\n    \"\"\"\n    Remap the leaf words given a map of the labels we expect in the leaves\n    \"\"\"\n    leaves = tree.leaf_labels()\n    new_words = [words[w].word for w in leaves]\n    new_tree = tree.replace_words(new_words)\n    return new_tree\n\ndef process_tree(sentence):\n    \"\"\"\n    Convert a single ET element representing a Tiger tree to a parse tree\n    \"\"\"\n    sentence = sentence.getroot()\n    sent_id = sentence.get(\"id\")\n    if sent_id is None:\n        raise ValueError(\"Tree {} does not have an id\".format(sent_id))\n    if len(sentence) > 1:\n        raise ValueError(\"Longer than expected number of items in {}\".format(sent_id))\n    graph = sentence.find(\"graph\")\n    if graph is None:\n        raise ValueError(\"Unexpected tree structure in {} : top tag is not 'graph'\".format(sent_id))\n\n    root_id = graph.get(\"root\")\n    if root_id is None:\n     
   raise ValueError(\"Tree has no root id in {}\".format(sent_id))\n\n    terminals = graph.find(\"terminals\")\n    if terminals is None:\n        raise ValueError(\"No terminals in tree {}\".format(sent_id))\n    # some Arboretum graphs have two sets of nonterminals,\n    # apparently intentionally, so we ignore that possible error\n    nonterminals = graph.find(\"nonterminals\")\n    if nonterminals is None:\n        raise ValueError(\"No nonterminals in tree {}\".format(sent_id))\n\n    # read the words.  the words have ids, text, and tags which we care about\n    words = {}\n    for word in terminals:\n        if word.tag == 'parentes-udeladt' or word.tag == 'note':\n            continue\n        if word.tag != \"t\":\n            raise ValueError(\"Unexpected tree structure in {} : word with tag other than t\".format(sent_id))\n        word_id = word.get(\"id\")\n        if not word_id:\n            raise ValueError(\"Word had no id in {}\".format(sent_id))\n        word_text = word.get(\"word\")\n        if not word_text:\n            raise ValueError(\"Word had no text in {}\".format(sent_id))\n        word_pos = word.get(\"pos\")\n        if not word_pos:\n            raise ValueError(\"Word had no pos in {}\".format(sent_id))\n        words[word_id] = Word(word_text, word_pos)\n\n    # read the nodes.  the nodes have ids, labels, and children\n    # some of the edges are labeled \"secedge\".  
we ignore those\n    nodes = {}\n    for nt in nonterminals:\n        if nt.tag != \"nt\":\n            raise ValueError(\"Unexpected tree structure in {} : node with tag other than nt\".format(sent_id))\n        nt_id = nt.get(\"id\")\n        if not nt_id:\n            raise ValueError(\"NT has no id in {}\".format(sent_id))\n        nt_label = nt.get(\"cat\")\n        if not nt_label:\n            raise ValueError(\"NT has no label in {}\".format(sent_id))\n\n        children = []\n        for child in nt:\n            if child.tag != \"edge\" and child.tag != \"secedge\":\n                raise ValueError(\"NT has unexpected child in {} : {}\".format(sent_id, child.tag))\n            if child.tag == \"edge\":\n                child_id = child.get(\"idref\")\n                if not child_id:\n                    raise ValueError(\"Child is missing an id in {}\".format(sent_id))\n                children.append(child_id)\n        nodes[nt_id] = Node(nt_label, children)\n\n    if root_id not in nodes:\n        raise ValueError(\"Could not find root in nodes in {}\".format(sent_id))\n\n    tree = process_nodes(root_id, words, nodes, set())\n    return tree, words\n\ndef word_sequence_missing_words(tree):\n    \"\"\"\n    Check if the word sequence is missing words\n\n    Some trees skip labels, such as\n      (s (fcl (pron-pers s16817_1) (v-fin s16817_2) (prp s16817_3) (pp (prp s16817_5) (par (n s16817_6) (conj-c s16817_7) (n s16817_8))) (pu s16817_9)))\n    but in these cases, the word is present in the original text and simply not attached to the tree\n    \"\"\"\n    words = tree.leaf_labels()\n    indices = [int(w.split(\"_\")[1]) for w in words]\n    indices = sorted(indices)\n    for idx, label in enumerate(indices):\n        if label != idx + 1:\n            return True\n    return False\n\nWORD_TO_PHRASE = {\n    \"art\": \"advp\",    # \"en smule\" is the one time this happens. 
it is used as an advp elsewhere\n    \"adj\": \"adjp\",\n    \"adv\": \"advp\",\n    \"conj\": \"cp\",\n    \"intj\": \"fcl\",    # not sure?  seems to match \"hold kæft\" when it shows up\n    \"n\": \"np\",\n    \"num\": \"np\",      # would prefer something like QP from PTB\n    \"pron\": \"np\",     # ??\n    \"prop\": \"np\",\n    \"prp\": \"pp\",\n    \"v\": \"vp\",\n}\n\ndef split_underscores(tree):\n    assert not tree.is_leaf(), \"Should never reach a leaf in this code path\"\n\n    if tree.is_preterminal():\n        return tree\n\n    children = tree.children\n    new_children = []\n    for child in children:\n        if child.is_preterminal():\n            if '_' not in child.children[0].label:\n                new_children.append(child)\n                continue\n\n            if child.label.split(\"-\")[0] not in WORD_TO_PHRASE:\n                raise ValueError(\"SPLITTING {}\".format(child))\n            pieces = []\n            for piece in child.children[0].label.split(\"_\"):\n                # This may not be accurate, but we already retag the treebank anyway\n                if len(piece) == 0:\n                    raise ValueError(\"A word started or ended with _\")\n                pieces.append(Tree(child.label, Tree(piece)))\n            new_children.append(Tree(WORD_TO_PHRASE[child.label.split(\"-\")[0]], pieces))\n        else:\n            new_children.append(split_underscores(child))\n\n    return Tree(tree.label, new_children)\n\nREMAP_LABELS = {\n    \"adj\": \"adjp\",\n    \"adv\": \"advp\",\n    \"intj\": \"fcl\",\n    \"n\": \"np\",\n    \"num\": \"np\",     # again, a dedicated number node would be better, but there are only a few \"num\" labeled\n    \"prp\": \"pp\",\n}\n\n\ndef has_weird_constituents(tree):\n    \"\"\"\n    Eliminate a few trees with weird labels\n\n    Eliminate p?  
there are only 3 and they have varying structure underneath\n    Also cl, since I have no idea how to label it and it only excludes 1 tree\n    \"\"\"\n    labels = Tree.get_unique_constituent_labels(tree)\n    if \"p\" in labels or \"cl\" in labels:\n        return True\n    return False\n\ndef convert_tiger_treebank(input_filename):\n    sentences = read_xml_file(input_filename)\n\n    unfixable = 0\n    dangling = 0\n    broken_links = 0\n    missing_words = 0\n    weird_constituents = 0\n    trees = []\n\n    with tsurgeon.Tsurgeon() as tsurgeon_processor:\n        for sent_idx, sentence in enumerate(tqdm(sentences)):\n            try:\n                tree, words = process_tree(sentence)\n\n                if not tree.all_leaves_are_preterminals():\n                    dangling += 1\n                    continue\n\n                if word_sequence_missing_words(tree):\n                    missing_words += 1\n                    continue\n\n                tree = check_words(tree, tsurgeon_processor)\n                if tree is None:\n                    unfixable += 1\n                    continue\n\n                if has_weird_constituents(tree):\n                    weird_constituents += 1\n                    continue\n\n                tree = replace_words(tree, words)\n                tree = split_underscores(tree)\n                tree = tree.remap_constituent_labels(REMAP_LABELS)\n                trees.append(tree)\n            except BrokenLinkError as e:\n                # the get(\"id\") would have failed as a different error type if missing,\n                # so we can safely use it directly like this\n                broken_links += 1\n                # print(\"Unable to process {} because of broken links: {}\".format(sentence.getroot().get(\"id\"), e))\n\n    print(\"Found {} trees with empty nodes\".format(dangling))\n    print(\"Found {} trees with unattached words\".format(missing_words))\n    print(\"Found {} trees with confusing constituent 
labels\".format(weird_constituents))\n    print(\"Not able to rearrange {} nodes\".format(unfixable))\n    print(\"Unable to handle {} trees because of broken links, eg names in another tree\".format(broken_links))\n    print(\"Parsed {} trees from {}\".format(len(trees), input_filename))\n    return trees\n\ndef main():\n    treebank = convert_tiger_treebank(\"extern_data/constituency/danish/W0084/arboretum.tiger/arboretum.tiger\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_cintil.py",
    "content": "import xml.etree.ElementTree as ET\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.utils.datasets.constituency import utils\n\ndef read_xml_file(input_filename):\n    \"\"\"\n    Convert the CINTIL xml file to id & test\n\n    Returns a list of tuples: (id, text)\n    \"\"\"\n    with open(input_filename, encoding=\"utf-8\") as fin:\n        dataset = ET.parse(fin)\n    dataset = dataset.getroot()\n    corpus = dataset.find(\"{http://nlx.di.fc.ul.pt}corpus\")\n    if not corpus:\n        raise ValueError(\"Unexpected dataset structure : no 'corpus'\")\n    trees = []\n    for sentence in corpus:\n        if sentence.tag != \"{http://nlx.di.fc.ul.pt}sentence\":\n            raise ValueError(\"Unexpected sentence tag: {}\".format(sentence.tag))\n        id_node = None\n        raw_node = None\n        tree_nodde = None\n        for node in sentence:\n            if node.tag == '{http://nlx.di.fc.ul.pt}id':\n                id_node = node\n            elif node.tag == '{http://nlx.di.fc.ul.pt}raw':\n                raw_node = node\n            elif node.tag == '{http://nlx.di.fc.ul.pt}tree':\n                tree_node = node\n            else:\n                raise ValueError(\"Unexpected tag in sentence {}: {}\".format(sentence, node.tag))\n        if id_node is None or raw_node is None or tree_node is None:\n            raise ValueError(\"Missing node in sentence {}\".format(sentence))\n        tree_id = \"\".join(id_node.itertext())\n        tree_text = \"\".join(tree_node.itertext())\n        trees.append((tree_id, tree_text))\n    return trees\n\ndef convert_cintil_treebank(input_filename, train_size=0.8, dev_size=0.1):\n    \"\"\"\n    dev_size is the size for splitting train & dev\n    \"\"\"\n    trees = read_xml_file(input_filename)\n\n    synthetic_trees = []\n    natural_trees = []\n    for tree_id, tree_text in trees:\n        if tree_text.find(\" _\") >= 0:\n            raise ValueError(\"Unexpected underscore\")\n      
  tree_text = tree_text.replace(\"_)\", \")\")\n        tree_text = tree_text.replace(\"(A (\", \"(A' (\")\n        # trees don't have ROOT, but we typically use a ROOT label at the top\n        tree_text = \"(ROOT %s)\" % tree_text\n        trees = tree_reader.read_trees(tree_text)\n        if len(trees) != 1:\n            raise ValueError(\"Unexpectedly found %d trees in %s\" % (len(trees), tree_id))\n        tree = trees[0]\n        if tree_id.startswith(\"aTSTS\"):\n            synthetic_trees.append(tree)\n        elif tree_id.find(\"TSTS\") >= 0:\n            raise ValueError(\"Unexpected TSTS\")\n        else:\n            natural_trees.append(tree)\n\n    print(\"Read %d synthetic trees\" % len(synthetic_trees))\n    print(\"Read %d natural trees\" % len(natural_trees))\n    train_trees, dev_trees, test_trees = utils.split_treebank(natural_trees, train_size, dev_size)\n    print(\"Split %d trees into %d train %d dev %d test\" % (len(natural_trees), len(train_trees), len(dev_trees), len(test_trees)))\n    train_trees = synthetic_trees + train_trees\n    print(\"Total lengths %d train %d dev %d test\" % (len(train_trees), len(dev_trees), len(test_trees)))\n    return train_trees, dev_trees, test_trees\n\n\ndef main():\n    treebank = convert_cintil_treebank(\"extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_ctb.py",
    "content": "from enum import Enum\nimport glob\nimport os\nimport re\n\nimport xml.etree.ElementTree as ET\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.utils.datasets.constituency.utils import write_dataset\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nclass Version(Enum):\n    V51   = 1\n    V51b  = 2\n    V90   = 3\n\ndef filenum_to_shard_51(filenum):\n    if filenum >= 1 and filenum <= 815:\n        return 0\n    if filenum >= 1001 and filenum <= 1136:\n        return 0\n\n    if filenum >= 886 and filenum <= 931:\n        return 1\n    if filenum >= 1148 and filenum <= 1151:\n        return 1\n\n    if filenum >= 816 and filenum <= 885:\n        return 2\n    if filenum >= 1137 and filenum <= 1147:\n        return 2\n\n    raise ValueError(\"Unhandled filenum %d\" % filenum)\n\ndef filenum_to_shard_51_basic(filenum):\n    if filenum >= 1 and filenum <= 270:\n        return 0\n    if filenum >= 440 and filenum <= 1151:\n        return 0\n\n    if filenum >= 301 and filenum <= 325:\n        return 1\n\n    if filenum >= 271 and filenum <= 300:\n        return 2\n\n    if filenum >= 400 and filenum <= 439:\n        return None\n\n    raise ValueError(\"Unhandled filenum %d\" % filenum)\n\ndef filenum_to_shard_90(filenum):\n    if filenum >= 1 and filenum <= 40:\n        return 2\n    if filenum >= 900 and filenum <= 931:\n        return 2\n    if filenum in (1018, 1020, 1036, 1044, 1060, 1061, 1072, 1118, 1119, 1132, 1141, 1142, 1148):\n        return 2\n    if filenum >= 2165 and filenum <= 2180:\n        return 2\n    if filenum >= 2295 and filenum <= 2310:\n        return 2\n    if filenum >= 2570 and filenum <= 2602:\n        return 2\n    if filenum >= 2800 and filenum <= 2819:\n        return 2\n    if filenum >= 3110 and filenum <= 3145:\n        return 2\n\n\n    if filenum >= 41 and filenum <= 80:\n        return 1\n    if filenum >= 1120 and filenum <= 1129:\n        return 1\n    if filenum >= 2140 and 
filenum <= 2159:\n        return 1\n    if filenum >= 2280 and filenum <= 2294:\n        return 1\n    if filenum >= 2550 and filenum <= 2569:\n        return 1\n    if filenum >= 2775 and filenum <= 2799:\n        return 1\n    if filenum >= 3080 and filenum <= 3109:\n        return 1\n\n    if filenum >= 81 and filenum <= 900:\n        return 0\n    if filenum >= 1001 and filenum <= 1017:\n        return 0\n    if filenum in (1019, 1130, 1131):\n        return 0\n    if filenum >= 1021 and filenum <= 1035:\n        return 0\n    if filenum >= 1037 and filenum <= 1043:\n        return 0\n    if filenum >= 1045 and filenum <= 1059:\n        return 0\n    if filenum >= 1062 and filenum <= 1071:\n        return 0\n    if filenum >= 1073 and filenum <= 1117:\n        return 0\n    if filenum >= 1133 and filenum <= 1140:\n        return 0\n    if filenum >= 1143 and filenum <= 1147:\n        return 0\n    if filenum >= 1149 and filenum <= 2139:\n        return 0\n    if filenum >= 2160 and filenum <= 2164:\n        return 0\n    if filenum >= 2181 and filenum <= 2279:\n        return 0\n    if filenum >= 2311 and filenum <= 2549:\n        return 0\n    if filenum >= 2603 and filenum <= 2774:\n        return 0\n    if filenum >= 2820 and filenum <= 3079:\n        return 0\n    if filenum >= 4000 and filenum <= 7017:\n        return 0\n\n\ndef collect_trees_s(root):\n    if root.tag == 'S':\n        yield root.text, root.attrib['ID']\n\n    for child in root:\n        for tree in collect_trees_s(child):\n            yield tree\n\ndef collect_trees_text(root):\n    if root.tag == 'TEXT' and len(root.text.strip()) > 0:\n        yield root.text, None\n\n    if root.tag == 'TURN' and len(root.text.strip()) > 0:\n        yield root.text, None\n\n    for child in root:\n        for tree in collect_trees_text(child):\n            yield tree\n\n\nid_re = re.compile(\"<S ID=([0-9a-z]+)>\")\nsu_re = re.compile(\"<(su|msg) id=([0-9a-zA-Z_=]+)>\")\n\ndef convert_ctb(input_dir, 
output_dir, dataset_name, version):\n    input_files = glob.glob(os.path.join(input_dir, \"*\"))\n\n    # train, dev, test\n    datasets = [[], [], []]\n\n    sorted_filenames = []\n    for input_filename in input_files:\n        base_filename = os.path.split(input_filename)[1]\n        filenum = int(os.path.splitext(base_filename)[0].split(\"_\")[1])\n        sorted_filenames.append((filenum, input_filename))\n    sorted_filenames.sort()\n\n    for filenum, filename in tqdm(sorted_filenames):\n        if version in (Version.V51, Version.V51b):\n            with open(filename, errors='ignore', encoding=\"gb2312\") as fin:\n                text = fin.read()\n        elif version is Version.V90:\n            with open(filename, encoding=\"utf-8\") as fin:\n                text = fin.read()\n            if text.find(\"<TURN>\") >= 0 and text.find(\"</TURN>\") < 0:\n                text = text.replace(\"<TURN>\", \"\")\n            if filenum in (4205, 4208, 4289):\n                text = text.replace(\"<)\", \"&lt;)\").replace(\">)\", \"&gt;)\")\n            if filenum >= 4000 and filenum <= 4411:\n                if text.find(\"<segment\") >= 0:\n                    text = text.replace(\"<segment id=\", \"<S ID=\").replace(\"</segment>\", \"</S>\")\n                elif text.find(\"<seg\") < 0:\n                    text = \"<TEXT>\\n%s</TEXT>\\n\" % text\n                else:\n                    text = text.replace(\"<seg id=\", \"<S ID=\").replace(\"</seg>\", \"</S>\")\n                text = \"<foo>\\n%s</foo>\\n\" % text\n            if filenum >= 5000 and filenum <= 5558 or filenum >= 6000 and filenum <= 6700 or filenum >= 7000 and filenum <= 7017:\n                text = su_re.sub(\"\", text)\n                if filenum in (6066, 6453):\n                    text = text.replace(\"<\", \"&lt;\").replace(\">\", \"&gt;\")\n                text = \"<foo><TEXT>\\n%s</TEXT></foo>\\n\" % text\n        else:\n            raise ValueError(\"Unknown CTB version %s\" % 
version)\n        text = id_re.sub(r'<S ID=\"\\1\">', text)\n        text = text.replace(\"&\", \"&amp;\")\n\n        try:\n            xml_root = ET.fromstring(text)\n        except Exception as e:\n            print(text[:1000])\n            raise RuntimeError(\"Cannot xml process %s\" % filename) from e\n        trees = [x for x in collect_trees_s(xml_root)]\n        if version is Version.V90 and len(trees) == 0:\n            trees = [x for x in collect_trees_text(xml_root)]\n\n        if version in (Version.V51, Version.V51b):\n            trees = [x[0] for x in trees if filenum != 414 or x[1] != \"4366\"]\n        else:\n            trees = [x[0] for x in trees]\n\n        trees = \"\\n\".join(trees)\n        try:\n            trees = tree_reader.read_trees(trees, use_tqdm=False)\n        except ValueError as e:\n            print(text[:300])\n            raise RuntimeError(\"Could not process the tree text in %s\" % filename) from e\n        trees = [t.prune_none().simplify_labels() for t in trees]\n\n        assert len(trees) > 0, \"No trees in %s\" % filename\n\n        if version is Version.V51:\n            shard = filenum_to_shard_51(filenum)\n        elif version is Version.V51b:\n            shard = filenum_to_shard_51_basic(filenum)\n        else:\n            shard = filenum_to_shard_90(filenum)\n        if shard is None:\n            continue\n        datasets[shard].extend(trees)\n\n\n    write_dataset(datasets, output_dir, dataset_name)\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_icepahc.py",
    "content": "\"\"\"\nCurrently this doesn't function\n\nThe goal is simply to demonstrate how to use tsurgeon\n\"\"\"\n\nfrom stanza.models.constituency.tree_reader import read_trees, read_treebank\nfrom stanza.server import tsurgeon\n\nTREEBANK = \"\"\"\n( (IP-MAT (NP-SBJ (PRO-N Það-það))\n          (BEPI er-vera)\n          (ADVP (ADV eiginlega-eiginlega))\n          (ADJP (NEG ekki-ekki) (ADJ-N hægt-hægur))\n          (IP-INF (TO að-að) (VB lýsa-lýsa))\n          (NP-OB1 (N-D tilfinningu$-tilfinning) (D-D $nni-hinn))\n          (IP-INF (TO að-að) (VB fá-fá))\n          (IP-INF (TO að-að) (VB taka-taka))\n          (NP-OB1 (N-A þátt-þáttur))\n          (PP (P í-í)\n              (NP (D-D þessu-þessi)))\n          (, ,-,)\n          (VBPI segir-segja)\n          (NP-SBJ (NPR-N Sverrir-sverrir) (NPR-N Ingi-ingi))\n          (. .-.)))\n\"\"\"\n\n# Output of the first tsurgeon:\n#(ROOT\n#  (IP-MAT\n#    (NP-SBJ (PRO-N Það))\n#    (BEPI er)\n#    (ADVP (ADV eiginlega))\n#    (ADJP (NEG ekki) (ADJ-N hægt))\n#    (IP-INF (TO að) (VB lýsa))\n#    (NP-OB1 (N-D tilfinningu$) (D-D $nni))\n#    (IP-INF (TO að) (VB fá))\n#    (IP-INF (TO að) (VB taka))\n#    (NP-OB1 (N-A þátt))\n#    (PP\n#      (P í)\n#      (NP (D-D þessu)))\n#    (, ,)\n#    (VBPI segir)\n#    (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi))\n#    (. .)))\n\n# Output of the second operation\n#(ROOT\n#  (IP-MAT\n#    (NP-SBJ (PRO-N Það))\n#    (BEPI er)\n#    (ADVP (ADV eiginlega))\n#    (ADJP (NEG ekki) (ADJ-N hægt))\n#    (IP-INF (TO að) (VB lýsa))\n#    (NP-OB1 (N-D tilfinningunni))\n#    (IP-INF (TO að) (VB fá))\n#    (IP-INF (TO að) (VB taka))\n#    (NP-OB1 (N-A þátt))\n#    (PP\n#      (P í)\n#      (NP (D-D þessu)))\n#    (, ,)\n#    (VBPI segir)\n#    (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi))\n#    (. 
.)))\n\n\ntreebank = read_trees(TREEBANK)\n\nwith tsurgeon.Tsurgeon(classpath=\"$CLASSPATH\") as tsurgeon_processor:\n    form_tregex = \"/^(.+)-.+$/#1%form=word !< __\"\n    form_tsurgeon = \"relabel word /^.+$/%{form}/\"\n\n    noun_det_tregex = \"/^N-/ < /^([^$]+)[$]$/#1%noun=noun $+ (/^D-/ < /^[$]([^$]+)$/#1%det=det)\"\n    noun_det_relabel = \"relabel noun /^.+$/%{noun}%{det}/\"\n    noun_det_prune = \"prune det\"\n\n    for tree in treebank:\n        updated_tree = tsurgeon_processor.process(tree, (form_tregex, form_tsurgeon))[0]\n        print(\"{:P}\".format(updated_tree))\n        updated_tree = tsurgeon_processor.process(updated_tree, (noun_det_tregex, noun_det_relabel, noun_det_prune))[0]\n        print(\"{:P}\".format(updated_tree))\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_it_turin.py",
    "content": "\"\"\"\nConverts Turin's constituency dataset\n\nTurin University put out a freely available constituency dataset in 2011.\nIt is not as large as VIT or ISST, but it is free, which is nice.\n\nThe 2011 parsing task combines trees from several sources:\nhttp://www.di.unito.it/~tutreeb/evalita-parsingtask-11.html\n\nThere is another site for Turin treebanks:\nhttp://www.di.unito.it/~tutreeb/treebanks.html\n\nWeirdly, the most recent versions of the Evalita trees are not there.\nThe most relevant parts are the ParTUT downloads.  As of Sep. 2021:\n\nhttp://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/JRCAcquis_It.pen\nhttp://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/UDHR_It.pen\nhttp://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/CC_It.pen\nhttp://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/FB_It.pen\nhttp://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/WIT3_It.pen\n\nWe can't simply cat all these files together as there are a bunch of\nasterisks as comments and the files may have some duplicates.  For\nexample, the JRCAcquis piece has many duplicates.  Also, some don't\npass validation for one reason or another.\n\nOne oddity of these data files is that the MWT are denoted by doubling\nthe token.  The token is not split as would be expected, though.  We try\nto use stanza's MWT tokenizer for IT to split the tokens, with some\nrules added by hand in BIWORD_SPLITS.  
Two are still unsplit, though...\n\"\"\"\n\nimport glob\nimport os\nimport re\nimport sys\n\nimport stanza\nfrom stanza.models.constituency import parse_tree\nfrom stanza.models.constituency import tree_reader\n\ndef load_without_asterisks(in_file, encoding='utf-8'):\n    with open(in_file, encoding=encoding) as fin:\n        new_lines = [x if x.find(\"********\") < 0 else \"\\n\" for x in fin.readlines()]\n    if len(new_lines) > 0 and not new_lines[-1].endswith(\"\\n\"):\n        new_lines[-1] = new_lines[-1] + \"\\n\"\n    return new_lines\n\nCONSTITUENT_SPLIT = re.compile(\"[-=#+0-9]\")\n\n# JRCA is almost entirely duplicates\n# WIT3 follows a different annotation scheme\nFILES_TO_ELIMINATE = [\"JRCAcquis_It.pen\", \"WIT3_It.pen\"]\n\n# assuming this is a typo\nREMAP_NODES = { \"Sbar\" : \"SBAR\" }\n\nREMAP_WORDS = { \"-LSB-\": \"[\", \"-RSB-\": \"]\" }\n\n# these mostly seem to be mistakes\n# maybe Vbar and ADVbar should be converted to something else?\nNODES_TO_ELIMINATE = [\"C\", \"PHRASP\", \"PRDT\", \"Vbar\", \"parte\", \"ADVbar\"]\n\nUNKNOWN_SPLITS = set()\n\n# a map of splits that the tokenizer or MWT doesn't handle well\nBIWORD_SPLITS = { \"offertogli\": (\"offerto\", \"gli\"),\n                  \"offertegli\": (\"offerte\", \"gli\"),\n                  \"formatasi\": (\"formata\", \"si\"),\n                  \"formatosi\": (\"formato\", \"si\"),\n                  \"multiplexarlo\": (\"multiplexar\", \"lo\"),\n                  \"esibirsi\": (\"esibir\", \"si\"),\n                  \"pagarne\": (\"pagar\", \"ne\"),\n                  \"recarsi\": (\"recar\", \"si\"),\n                  \"trarne\": (\"trar\", \"ne\"),\n                  \"esserci\": (\"esser\", \"ci\"),\n                  \"aprirne\": (\"aprir\", \"ne\"),\n                  \"farle\": (\"far\", \"le\"),\n                  \"disporne\": (\"dispor\", \"ne\"),\n                  \"andargli\": (\"andar\", \"gli\"),\n                  \"CONSIDERARSI\": (\"CONSIDERAR\", \"SI\"),\n            
      \"conferitegli\": (\"conferite\", \"gli\"),\n                  \"formatasi\": (\"formata\", \"si\"),\n                  \"formatosi\": (\"formato\", \"si\"),\n                  \"Formatisi\": (\"Formati\", \"si\"),\n                  \"multiplexarlo\": (\"multiplexar\", \"lo\"),\n                  \"esibirsi\": (\"esibir\", \"si\"),\n                  \"pagarne\": (\"pagar\", \"ne\"),\n                  \"recarsi\": (\"recar\", \"si\"),\n                  \"trarne\": (\"trar\", \"ne\"),\n                  \"temerne\": (\"temer\", \"ne\"),\n                  \"esserci\": (\"esser\", \"ci\"),\n                  \"esservi\": (\"esser\", \"vi\"),\n                  \"restituirne\": (\"restituir\", \"ne\"),\n                  \"col\": (\"con\", \"il\"),\n                  \"cogli\": (\"con\", \"gli\"),\n                  \"dirgli\": (\"dir\", \"gli\"),\n                  \"opporgli\": (\"oppor\", \"gli\"),\n                  \"eccolo\": (\"ecco\", \"lo\"),\n                  \"Eccolo\": (\"Ecco\", \"lo\"),\n                  \"Eccole\": (\"Ecco\", \"le\"),\n                  \"farci\": (\"far\", \"ci\"),\n                  \"farli\": (\"far\", \"li\"),\n                  \"farne\": (\"far\", \"ne\"),\n                  \"farsi\": (\"far\", \"si\"),\n                  \"farvi\": (\"far\", \"vi\"),\n                  \"Connettiti\": (\"Connetti\", \"ti\"),\n                  \"APPLICARSI\": (\"APPLICAR\", \"SI\"),\n                  # This is not always two words, but if it IS two words,\n                  # it gets split like this\n                  \"assicurati\": (\"assicura\", \"ti\"),\n                  \"Fatti\": (\"Fai\", \"te\"),\n                  \"ai\": (\"a\", \"i\"),\n                  \"Ai\": (\"A\", \"i\"),\n                  \"AI\": (\"A\", \"I\"),\n                  \"al\": (\"a\", \"il\"),\n                  \"Al\": (\"A\", \"il\"),\n                  \"AL\": (\"A\", \"IL\"),\n                  \"coi\": (\"con\", \"i\"),\n                  
\"colla\": (\"con\", \"la\"),\n                  \"colle\": (\"con\", \"le\"),\n                  \"dal\": (\"da\", \"il\"),\n                  \"Dal\": (\"Da\", \"il\"),\n                  \"DAL\": (\"DA\", \"IL\"),\n                  \"dei\": (\"di\", \"i\"),\n                  \"Dei\": (\"Di\", \"i\"),\n                  \"DEI\": (\"DI\", \"I\"),\n                  \"del\": (\"di\", \"il\"),\n                  \"Del\": (\"Di\", \"il\"),\n                  \"DEL\": (\"DI\", \"IL\"),\n                  \"nei\": (\"in\", \"i\"),\n                  \"NEI\": (\"IN\", \"I\"),\n                  \"nel\": (\"in\", \"il\"),\n                  \"Nel\": (\"In\", \"il\"),\n                  \"NEL\": (\"IN\", \"IL\"),\n                  \"pel\": (\"per\", \"il\"),\n                  \"sui\": (\"su\", \"i\"),\n                  \"Sui\": (\"Su\", \"i\"),\n                  \"sul\": (\"su\", \"il\"),\n                  \"Sul\": (\"Su\", \"il\"),\n                  \",\": (\",\", \",\"),\n                  \".\": (\".\", \".\"),\n                  '\"': ('\"', '\"'),\n                  '-': ('-', '-'),\n                  '-LRB-': ('-LRB-', '-LRB-'),\n                  \"garantirne\": (\"garantir\", \"ne\"),\n                  \"aprirvi\": (\"aprir\", \"vi\"),\n                  \"esimersi\": (\"esimer\", \"si\"),\n                  \"opporsi\": (\"oppor\", \"si\"),\n}\n\nCAP_BIWORD = re.compile(\"[A-Z]+_[A-Z]+\")\n\ndef split_mwe(tree, pipeline):\n    words = list(tree.leaf_labels())\n    found = False\n    for idx, word in enumerate(words[:-3]):\n        if word == words[idx+1] and word == words[idx+2] and word == words[idx+3]:\n            raise ValueError(\"Oh no, 4 consecutive words\")\n\n    for idx, word in enumerate(words[:-2]):\n        if word == words[idx+1] and word == words[idx+2]:\n            doc = pipeline(word)\n            assert len(doc.sentences) == 1\n            if len(doc.sentences[0].words) != 3:\n                raise RuntimeError(\"Word {} not tokenized 
into 3 parts... thought all 3 part words were handled!\".format(word))\n            words[idx] = doc.sentences[0].words[0].text\n            words[idx+1] = doc.sentences[0].words[1].text\n            words[idx+2] = doc.sentences[0].words[2].text\n            found = True\n\n    for idx, word in enumerate(words[:-1]):\n        if word == words[idx+1]:\n            if word in BIWORD_SPLITS:\n                first_word = BIWORD_SPLITS[word][0]\n                second_word = BIWORD_SPLITS[word][1]\n            elif CAP_BIWORD.match(word):\n                first_word, second_word = word.split(\"_\")\n            else:\n                doc = pipeline(word)\n                assert len(doc.sentences) == 1\n                if len(doc.sentences[0].words) == 2:\n                    first_word = doc.sentences[0].words[0].text\n                    second_word = doc.sentences[0].words[1].text\n                else:\n                    if word not in UNKNOWN_SPLITS:\n                        UNKNOWN_SPLITS.add(word)\n                        print(\"Could not figure out how to split {}\\n  {}\\n  {}\".format(word, \" \".join(words), tree))\n                    continue\n\n            words[idx] = first_word\n            words[idx+1] = second_word\n            found = True\n\n    if found:\n        tree = tree.replace_words(words)\n    return tree\n\n\ndef load_trees(filename, pipeline):\n    # some of the files are in latin-1 encoding rather than utf-8\n    try:\n        raw_text = load_without_asterisks(filename, \"utf-8\")\n    except UnicodeDecodeError:\n        raw_text = load_without_asterisks(filename, \"latin-1\")\n\n    # also, some have messed up validation (it will be logged)\n    # hence the broken_ok=True argument\n    trees = tree_reader.read_trees(\"\".join(raw_text), broken_ok=True)\n\n    filtered_trees = []\n    for tree in trees:\n        if tree.children[0].label is None:\n            print(\"Skipping a broken tree (missing label) in {}: {}\".format(filename, 
tree))\n            continue\n\n        try:\n            words = tuple(tree.leaf_labels())\n        except ValueError:\n            print(\"Skipping a broken tree (missing preterminal) in {}: {}\".format(filename, tree))\n            continue\n\n        if any('www.facebook' in pt.label for pt in tree.preterminals()):\n            print(\"Skipping a tree with a weird preterminal label in {}: {}\".format(filename, tree))\n            continue\n\n        tree = tree.prune_none().simplify_labels(CONSTITUENT_SPLIT)\n\n        if len(tree.children) > 1:\n            print(\"Found a tree with a non-unary root!  {}: {}\".format(filename, tree))\n            continue\n        if tree.children[0].is_preterminal():\n            print(\"Found a tree with a single preterminal node!  {}: {}\".format(filename, tree))\n            continue\n\n        # The expectation is that the retagging will handle this anyway\n        for pt in tree.preterminals():\n            if not pt.label:\n                pt.label = \"UNK\"\n                print(\"Found a tree with a blank preterminal label.  Setting it to UNK.  
{}: {}\".format(filename, tree))\n\n        tree = tree.remap_constituent_labels(REMAP_NODES)\n        tree = tree.remap_words(REMAP_WORDS)\n\n        tree = split_mwe(tree, pipeline)\n        if tree is None:\n            continue\n\n        constituents = set(parse_tree.Tree.get_unique_constituent_labels(tree))\n        for weird_label in NODES_TO_ELIMINATE:\n            if weird_label in constituents:\n                break\n        else:\n            weird_label = None\n        if weird_label is not None:\n            print(\"Skipping a tree with a weird label {} in {}: {}\".format(weird_label, filename, tree))\n            continue\n\n        filtered_trees.append(tree)\n\n    return filtered_trees\n\ndef save_trees(out_file, trees):\n    print(\"Saving {} trees to {}\".format(len(trees), out_file))\n    with open(out_file, \"w\", encoding=\"utf-8\") as fout:\n        for tree in trees:\n            fout.write(str(tree))\n            fout.write(\"\\n\")\n\ndef convert_it_turin(input_path, output_path):\n    pipeline = stanza.Pipeline(\"it\", processors=\"tokenize, mwt\", tokenize_no_ssplit=True)\n\n    os.makedirs(output_path, exist_ok=True)\n\n    evalita_dir = os.path.join(input_path, \"evalita\")\n\n    evalita_test = os.path.join(evalita_dir, \"evalita11_TESTgold_CONPARSE.penn\")\n    it_test = os.path.join(output_path, \"it_turin_test.mrg\")\n    test_trees = load_trees(evalita_test, pipeline)\n    save_trees(it_test, test_trees)\n\n    known_text = set()\n    for tree in test_trees:\n        words = tuple(tree.leaf_labels())\n        assert words not in known_text\n        known_text.add(words)\n\n    evalita_train = os.path.join(output_path, \"it_turin_train.mrg\")\n    evalita_files = glob.glob(os.path.join(evalita_dir, \"*2011*penn\"))\n    turin_files = glob.glob(os.path.join(input_path, \"turin\", \"*pen\"))\n    filenames = evalita_files + turin_files\n    filtered_trees = []\n    for filename in filenames:\n        if os.path.split(filename)[1] in 
FILES_TO_ELIMINATE:\n            continue\n\n        trees = load_trees(filename, pipeline)\n        file_trees = []\n\n        for tree in trees:\n            words = tuple(tree.leaf_labels())\n            if words in known_text:\n                print(\"Skipping a duplicate in {}: {}\".format(filename, tree))\n                continue\n\n            known_text.add(words)\n\n            file_trees.append(tree)\n\n        filtered_trees.append((filename, file_trees))\n\n    print(\"{} contains {} usable trees\".format(evalita_test, len(test_trees)))\n    print(\"  Unique constituents in {}: {}\".format(evalita_test, parse_tree.Tree.get_unique_constituent_labels(test_trees)))\n\n    train_trees = []\n    dev_trees = []\n    for filename, file_trees in filtered_trees:\n        print(\"{} contains {} usable trees\".format(filename, len(file_trees)))\n        print(\"  Unique constituents in {}: {}\".format(filename, parse_tree.Tree.get_unique_constituent_labels(file_trees)))\n        for tree in file_trees:\n            if len(train_trees) <= len(dev_trees) * 9:\n                train_trees.append(tree)\n            else:\n                dev_trees.append(tree)\n\n    it_train = os.path.join(output_path, \"it_turin_train.mrg\")\n    save_trees(it_train, train_trees)\n\n    it_dev = os.path.join(output_path, \"it_turin_dev.mrg\")\n    save_trees(it_dev, dev_trees)\n\ndef main():\n    input_path = sys.argv[1]\n    output_path = sys.argv[2]\n\n    convert_it_turin(input_path, output_path)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_it_vit.py",
    "content": "\"\"\"Converts the proprietary VIT dataset to a format suitable for stanza\n\nThere are multiple corrections in the UD version of VIT, along with\nrecommended splits for the MWT, along with recommended splits of\nthe sentences into train/dev/test\n\nAccordingly, it is necessary to use the UD dataset as a reference\n\nHere is a sample line of the text file we use:\n\n#ID=sent_00002  cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]]\n\nHere you can already see multiple issues when parsing:\n- the first word is \"negli\", which is split into In_ADP gli_DET in the UD version\n- also the first word is capitalized in the UD version\n- comma looks like a tempting split target, but there is a ',' in this sentence\n  punt-','\n- not shown here is '-' which is different from the - used for denoting POS\n  par-'-'\n\nFortunately, -[ is always an open and ] is always a close\n\nAs of April 2022, the UD version of the dataset has some minor edits\nwhich are necessary for the proper functioning of this script.\nOtherwise, the MWT won't align correctly, some typos won't be\ncorrected, etc.  
These edits are released in UD 2.10\n\nThe data itself is available from ELRA:\n\nhttp://catalog.elra.info/en-us/repository/browse/ELRA-W0040/\n\nInternally at Stanford you can contact Chris Manning or John Bauer.\n\nThe processing goes as follows:\n- read in UD and con trees\n  some of the con trees have broken brackets and are discarded\n  in other cases, abbreviations were turned into single tokens in UD\n- extract the MWT expansions of Italian contractions,\n  such as Negli -> In gli\n- attempt to align the trees from the two datasets using ngrams\n  some trees had the sentence splitting updated\n  sentences which can't be matched are discarded\n- use CoreNLP tsurgeon to update tokens in the con trees\n  based on the information in the UD dataset\n  - split contractions\n  - rearrange clitics which are occasionally non-projective\n- replace the words in the con tree with the dep tree's words\n  this takes advantage of spelling & capitalization fixes\n\nIn 2022, there was an update to the dataset from Prof. Delmonte.\nThis update is hopefully in current ELRA distributions now.\nIf not, please contact ELRA to specifically ask for the updated version.\nInternally to Stanford, feel free to ask Chris or John for the updates.\nLook for the line below \"original version with more errors\"\n\nIn August 2022, Prof. Delmonte made a slight update in a zip file\n`john.zip`.  If/when that gets updated to ELRA, we will update it\nhere.  Contact Chris or John for a copy if not updated yet, or go\nback in git history to get the older version of the code which\nworks with the 2022 ELRA update.\n\nLater, in September 2022, there is yet another update,\nNew version of VIT.zip\nUnzip the contents into a folder\n$CONSTITUENCY_BASE/italian/it_vit\nso there should be a file\n$CONSTITUENCY_BASE/italian/it_vit/VITwritten/VITconstsyntNumb\n\nThere are a few other updates needed to improve the annotations,\nbut all the nagging seemed to give Prof. 
Delmonte a headache,\nso at this point we include those fixes in this script instead.\nSee the first few tsurgeon operations in update_mwts_and_special_cases\n\"\"\"\n\nfrom collections import defaultdict, deque, namedtuple\nimport itertools\nimport os\nimport re\nimport sys\n\nfrom tqdm import tqdm\n\nfrom stanza.models.constituency.tree_reader import read_trees, UnclosedTreeError, ExtraCloseTreeError\nfrom stanza.server import tsurgeon\nfrom stanza.utils.conll import CoNLL\nfrom stanza.utils.datasets.constituency.utils import SHARDS, write_dataset\nimport stanza.utils.default_paths as default_paths\n\ndef read_constituency_sentences(fin):\n    \"\"\"\n    Reads the lines from the constituency treebank and splits into ID, text\n\n    No further processing is done on the trees yet\n    \"\"\"\n    sentences = []\n    for line in fin:\n        line = line.strip()\n        # WTF why doesn't strip() remove this\n        line = line.replace(u'\\ufeff', '')\n        if not line:\n            continue\n        sent_id, sent_text = line.split(maxsplit=1)\n        # we have seen a couple different versions of this sentence header\n        # although one file is always consistent with itself, at least\n        if not sent_id.startswith(\"#ID=sent\") and not sent_id.startswith(\"ID#sent\"):\n            raise ValueError(\"Unexpected start of sentence: |{}|\".format(sent_id))\n        if not sent_text:\n            raise ValueError(\"Empty text for |{}|\".format(sent_id))\n        sentences.append((sent_id, sent_text))\n    return sentences\n\ndef read_constituency_file(filename):\n    print(\"Reading raw constituencies from %s\" % filename)\n    with open(filename, encoding='utf-8') as fin:\n        return read_constituency_sentences(fin)\n\nOPEN = \"-[\"\nCLOSE = \"]\"\n\nDATE_RE = re.compile(\"^([0-9]{1,2})[_]([0-9]{2})$\")\nINTEGER_PERCENT_RE = re.compile(r\"^((?:min|plus)?[0-9]{1,3})[%]$\")\nDECIMAL_PERCENT_RE = 
re.compile(r\"^((?:min|plus)?[0-9]{1,3})[/_]([0-9]{1,3})[%]$\")\nRANGE_PERCENT_RE = re.compile(r\"^([0-9]{1,2}[/_][0-9]{1,2})[/]([0-9]{1,2}[/_][0-9]{1,2})[%]$\")\nDECIMAL_RE = re.compile(r\"^([0-9])[_]([0-9])$\")\n\nProcessedTree = namedtuple('ProcessedTree', ['con_id', 'dep_id', 'tree'])\n\ndef raw_tree(text):\n    \"\"\"\n    A sentence will look like this:\n       #ID=sent_00001  fc-[f3-[sn-[art-le, n-infrastrutture, sc-[ccom-come, sn-[n-fattore, spd-[pd-di,\n                       sn-[n-competitività]]]]]], f3-[spd-[pd-di, sn-[mw-Angela, nh-Airoldi]]], punto-.]\n    Non-preterminal nodes have tags, followed by the stuff under the node, -[\n    The node is closed by the ]\n    \"\"\"\n    pieces = []\n    open_pieces = text.split(OPEN)\n    for open_idx, open_piece in enumerate(open_pieces):\n        if open_idx > 0:\n            pieces[-1] = pieces[-1] + OPEN\n        open_piece = open_piece.strip()\n        if not open_piece:\n            raise ValueError(\"Unexpected empty node!\")\n        close_pieces = open_piece.split(CLOSE)\n        for close_idx, close_piece in enumerate(close_pieces):\n            if close_idx > 0:\n                pieces.append(CLOSE)\n            close_piece = close_piece.strip()\n            if not close_piece:\n                # this is okay - multiple closes at the end of a deep bracket\n                continue\n            word_pieces = close_piece.split(\", \")\n            pieces.extend([x.strip() for x in word_pieces if x.strip()])\n\n    # at this point, pieces is a list with:\n    #   tag-[     for opens\n    #   tag-word  for words\n    #   ]         for closes\n    # this structure converts pretty well to reading using the tree reader\n\n    PIECE_MAPPING = {\n        \"agn-/ter'\":               \"(agn ter)\",\n        \"cong-'&'\":                \"(cong &)\",\n        \"da_riempire-'...'\":       \"(da_riempire ...)\",\n        \"date-1992_1993\":          \"(date 1992/1993)\",\n        \"date-'31-12-95'\":         
\"(date 31-12-95)\",\n        \"date-'novantaquattro-95'\":\"(date novantaquattro-95)\",\n        \"date-'novantaquattro-95\": \"(date novantaquattro-95)\",\n        \"date-'novantaquattro-novantacinque'\": \"(date novantaquattro-novantacinque)\",\n        \"dirs-':'\":                \"(dirs :)\",\n        \"dirs-'\\\"'\":               \"(dirs \\\")\",\n        \"mw-'&'\":                  \"(mw &)\",\n        \"mw-'Presunto'\":           \"(mw Presunto)\",\n        \"nh-'Alain-Gauze'\":        \"(nh Alain-Gauze)\",\n        \"np-'porto_Marghera'\":     \"(np Porto) (np Marghera)\",\n        \"np-'roma-l_aquila'\":      \"(np Roma-L'Aquila)\",\n        \"np-'L_Aquila-Villa_Vomano'\": \"(np L'Aquila) (np -) (np Villa) (np Vomano)\",\n        \"npro-'Avanti_!'\":         \"(npro Avanti_!)\",\n        \"npro-'Viacom-Paramount'\": \"(npro Viacom-Paramount)\",\n        \"npro-'Rhone-Poulenc'\":    \"(npro Rhone-Poulenc)\",\n        \"npro-'Itar-Tass'\":        \"(npro Itar-Tass)\",\n        \"par-(-)\":                 \"(par -)\",\n        \"par-','\":                 \"(par ,)\",\n        \"par-'<'\":                 \"(par <)\",\n        \"par-'>'\":                 \"(par >)\",\n        \"par-'-'\":                 \"(par -)\",\n        \"par-'\\\"'\":                \"(par \\\")\",\n        \"par-'('\":                 \"(par -LRB-)\",\n        \"par-')'\":                 \"(par -RRB-)\",\n        \"par-'&&'\":                \"(par &&)\",\n        \"punt-','\":                \"(punt ,)\",\n        \"punt-'-'\":                \"(punt -)\",\n        \"punt-';'\":                \"(punt ;)\",\n        \"punto-':'\":               \"(punto :)\",\n        \"punto-';'\":               \"(punto ;)\",\n        \"puntint-'!'\":             \"(puntint !)\",\n        \"puntint-'?'\":             \"(puntint !)\",\n        \"num-'2plus2'\":            \"(num 2+2)\",\n        \"num-/bis'\":               \"(num bis)\",\n        \"num-/ter'\":               \"(num ter)\",\n 
       \"num-18_00/1_00\":          \"(num 18:00/1:00)\",\n        \"num-1/500_2/000\":         \"(num 1.500-2.000)\",\n        \"num-16_1\":                \"(num 16,1)\",\n        \"num-0_1\":                 \"(num 0,1)\",\n        \"num-0_3\":                 \"(num 0,3)\",\n        \"num-2_7\":                 \"(num 2,7)\",\n        \"num-455_68\":              \"(num 455/68)\",\n        \"num-437_5\":               \"(num 437,5)\",\n        \"num-4708_82\":             \"(num 4708,82)\",\n        \"num-16EQ517_7\":           \"(num 16EQ517/7)\",\n        \"num-2=184_90\":            \"(num 2=184/90)\",\n        \"num-3EQ429_20\":           \"(num 3eq429/20)\",\n        \"num-'1990-EQU-100'\":      \"(num 1990-EQU-100)\",\n        \"num-'500-EQU-250'\":       \"(num 500-EQU-250)\",\n        \"num-0_39%minus\":          \"(num 0,39) (num %%) (num -)\",\n        \"num-1_88/76\":             \"(num 1-88/76)\",\n        \"num-'70/80'\":             \"(num 70,80)\",\n        \"num-'18/20'\":             \"(num 18:20)\",\n        \"num-295/mila'\":           \"(num 295mila)\",\n        \"num-'295/mila'\":          \"(num 295mila)\",\n        \"num-0/07%plus\":           \"(num 0,07) (num %%) (num plus)\",\n        \"num-0/69%minus\":          \"(num 0,69) (num %%) (num minus)\",\n        \"num-0_39%minus\":          \"(num 0,39) (num %%) (num minus)\",\n        \"num-9_11/16\":             \"(num 9-11,16)\",\n        \"num-2/184_90\":            \"(num 2=184/90)\",\n        \"num-3/429_20\":            \"(num 3eq429/20)\",\n        # TODO: remove the following num conversions if possible\n        # this would require editing either constituency or UD\n        \"num-1:28_124\":            \"(num 1=8/1242)\",\n        \"num-1:28_397\":            \"(num 1=8/3972)\",\n        \"num-1:28_947\":            \"(num 1=8/9472)\",\n        \"num-1:29_657\":            \"(num 1=9/6572)\",\n        \"num-1:29_867\":            \"(num 1=9/8672)\",\n        \"num-1:29_874\":     
       \"(num 1=9/8742)\",\n        \"num-1:30_083\":            \"(num 1=0/0833)\",\n        \"num-1:30_140\":            \"(num 1=0/1403)\",\n        \"num-1:30_354\":            \"(num 1=0/3543)\",\n        \"num-1:30_453\":            \"(num 1=0/4533)\",\n        \"num-1:30_946\":            \"(num 1=0/9463)\",\n        \"num-1:31_602\":            \"(num 1=1/6023)\",\n        \"num-1:31_842\":            \"(num 1=1/8423)\",\n        \"num-1:32_087\":            \"(num 1=2/0873)\",\n        \"num-1:32_259\":            \"(num 1=2/2593)\",\n        \"num-1:33_166\":            \"(num 1=3/1663)\",\n        \"num-1:34_154\":            \"(num 1=4/1543)\",\n        \"num-1:34_556\":            \"(num 1=4/5563)\",\n        \"num-1:35_323\":            \"(num 1=5/3233)\",\n        \"num-1:36_023\":            \"(num 1=6/0233)\",\n        \"num-1:36_076\":            \"(num 1=6/0763)\",\n        \"num-1:36_651\":            \"(num 1=6/6513)\",\n        \"n-giga_flop/s\":           \"(n giga_flop/s)\",\n        \"sect-'g-1'\":              \"(sect g-1)\",\n        \"sect-'h-1'\":              \"(sect h-1)\",\n        \"sect-'h-2'\":              \"(sect h-2)\",\n        \"sect-'h-3'\":              \"(sect h-3)\",\n        \"abbr-'a-b-c'\":            \"(abbr a-b-c)\",\n        \"abbr-d_o_a_\":             \"(abbr DOA)\",\n        \"abbr-d_l_\":               \"(abbr DL)\",\n        \"abbr-i_s_e_f_\":           \"(abbr ISEF)\",\n        \"abbr-d_p_r_\":             \"(abbr DPR)\",\n        \"abbr-D_P_R_\":             \"(abbr DPR)\",\n        \"abbr-d_m_\":               \"(abbr dm)\",\n        \"abbr-T_U_\":               \"(abbr TU)\",\n        \"abbr-F_A_M_E_\":           \"(abbr Fame)\",\n        \"dots-'...'\":              \"(dots ...)\",\n    }\n    new_pieces = [\"(ROOT \"]\n    for piece in pieces:\n        if piece.endswith(OPEN):\n            new_pieces.append(\"(\" + piece[:-2])\n        elif piece == CLOSE:\n            new_pieces.append(\")\")\n        
elif piece in PIECE_MAPPING:\n            new_pieces.append(PIECE_MAPPING[piece])\n        else:\n            # maxsplit=1 because of words like 1990-EQU-100\n            tag, word = piece.split(\"-\", maxsplit=1)\n            if word.find(\"'\") >= 0 or word.find(\"(\") >= 0 or word.find(\")\") >= 0:\n                raise ValueError(\"Unhandled weird node: {}\".format(piece))\n            if word.endswith(\"_\"):\n                word = word[:-1] + \"'\"\n            date_match = DATE_RE.match(word)\n            if date_match:\n                # 10_30 special case sent_07072\n                # 16_30 special case sent_07098\n                # 21_15 special case sent_07099 and others\n                word = date_match.group(1) + \":\" + date_match.group(2)\n            integer_percent = INTEGER_PERCENT_RE.match(word)\n            if integer_percent:\n                word = integer_percent.group(1) + \"_%%\"\n            range_percent = RANGE_PERCENT_RE.match(word)\n            if range_percent:\n                word = range_percent.group(1) + \",\" + range_percent.group(2) + \"_%%\"\n            percent = DECIMAL_PERCENT_RE.match(word)\n            if percent:\n                word = percent.group(1) + \",\" + percent.group(2) + \"_%%\"\n            decimal = DECIMAL_RE.match(word)\n            if decimal:\n                word = decimal.group(1) + \",\" + decimal.group(2)\n            # there are words which are multiple words mashed together\n            # with _ for some reason\n            # also, words which end in ' are replaced with _\n            # fortunately, no words seem to have both\n            # splitting like this means the tags are likely wrong,\n            # but the conparser needs to retag anyway, so it shouldn't matter\n            word_pieces = word.split(\"_\")\n            for word_piece in word_pieces:\n                new_pieces.append(\"(%s %s)\" % (tag, word_piece))\n    new_pieces.append(\")\")\n\n    text = \" \".join(new_pieces)\n    
trees = read_trees(text)\n    if len(trees) > 1:\n        raise ValueError(\"Unexpected number of trees!\")\n    return trees[0]\n\ndef extract_ngrams(sentence, process_func, ngram_len=4):\n    leaf_words = [x for x in process_func(sentence)]\n    leaf_words = [\"l'\" if x == \"l\" else x for x in leaf_words]\n    if len(leaf_words) <= ngram_len:\n        return [tuple(leaf_words)]\n    its = [leaf_words[i:i+len(leaf_words)-ngram_len+1] for i in range(ngram_len)]\n    return [words for words in itertools.zip_longest(*its)]\n\ndef build_ngrams(sentences, process_func, id_func, ngram_len=4):\n    \"\"\"\n    Turn the list of processed trees into a bunch of ngrams\n\n    The returned map is from tuple to set of ids\n\n    The idea being that this map can be used to search for trees to\n    match datasets\n    \"\"\"\n    ngram_map = defaultdict(set)\n    for sentence in tqdm(sentences, postfix=\"Extracting ngrams\"):\n        sentence_id = id_func(sentence)\n        ngrams = extract_ngrams(sentence, process_func, ngram_len)\n        for ngram in ngrams:\n            ngram_map[ngram].add(sentence_id)\n    return ngram_map\n\n# just the tokens (maybe use words?  
depends on MWT in the con dataset)\nDEP_PROCESS_FUNC = lambda x: [t.text.lower() for t in x.tokens]\n# find the comment with \"sent_id\" in it, take just the id itself\nDEP_ID_FUNC = lambda x: [c for c in x.comments if c.startswith(\"# sent_id\")][0].split()[-1]\n\nCON_PROCESS_FUNC = lambda x: [y.lower() for y in x.leaf_labels()]\n\ndef match_ngrams(sentence_ngrams, ngram_map, debug=False):\n    \"\"\"\n    Check if there is a SINGLE matching sentence in the ngram_map for these ngrams\n\n    If an ngram shows up in multiple sentences, that is okay, but we ignore that info\n    If an ngram shows up in just one sentence, that is considered the match\n    If a different ngram then shows up in a different sentence, that is a problem\n    TODO: taking the intersection of all non-empty matches might be better\n    \"\"\"\n    if debug:\n        print(\"NGRAMS FOR DEBUG SENTENCE:\")\n    potential_match = None\n    unknown_ngram = 0\n    for ngram in sentence_ngrams:\n        con_matches = ngram_map[ngram]\n        if debug:\n            print(\"{} matched {}\".format(ngram, len(con_matches)))\n        if len(con_matches) == 0:\n            unknown_ngram += 1\n            continue\n        if len(con_matches) > 1:\n            continue\n        # get the one & only element from the set\n        con_match = next(iter(con_matches))\n        if debug:\n            print(\"  {}\".format(con_match))\n        if potential_match is None:\n            potential_match = con_match\n        elif potential_match != con_match:\n            return None\n    if unknown_ngram > len(sentence_ngrams) / 2:\n        return None\n    return potential_match\n\ndef match_sentences(con_tree_map, con_vit_ngrams, dep_sentences, split_name, debug_sentence=None):\n    \"\"\"\n    Match ngrams in the dependency sentences to the constituency sentences\n\n    Then, to make sure the constituency sentence wasn't split into two\n    in the UD dataset, this checks the ngrams in the reverse direction\n\n    
Some examples of things which don't match:\n      VIT-4769 Insegnanti non vedenti, insegnanti non autosufficienti con protesi agli arti inferiori.\n      this is duplicated in the original dataset, so the matching algorithm can't possibly work\n\n      VIT-4796 I posti istituiti con attività di sostegno dei docenti che ottengono il trasferimento su classi di concorso;\n      the correct con match should be sent_04829 but the brackets on that tree are broken\n    \"\"\"\n    con_to_dep_matches = {}\n    dep_ngram_map = build_ngrams(dep_sentences, DEP_PROCESS_FUNC, DEP_ID_FUNC)\n    unmatched = 0\n    bad_match = 0\n    for sentence in dep_sentences:\n        sentence_ngrams = extract_ngrams(sentence, DEP_PROCESS_FUNC)\n        potential_match = match_ngrams(sentence_ngrams, con_vit_ngrams, debug_sentence is not None and DEP_ID_FUNC(sentence) == debug_sentence)\n        if potential_match is None:\n            if unmatched < 5:\n                print(\"Could not match the following sentence: {} {}\".format(DEP_ID_FUNC(sentence), sentence.text))\n            unmatched += 1\n            continue\n        if potential_match not in con_tree_map:\n            raise ValueError(\"Matched constituency tree id {} is not in con_tree_map\".format(potential_match))\n        con_ngrams = extract_ngrams(con_tree_map[potential_match], CON_PROCESS_FUNC)\n        reverse_match = match_ngrams(con_ngrams, dep_ngram_map)\n        if reverse_match is None:\n            #print(\"Matched sentence {} to sentence {} but the reverse match failed\".format(sentence.text, \" \".join(con_tree_map[potential_match].leaf_labels())))\n            bad_match += 1\n            continue\n        con_to_dep_matches[potential_match] = reverse_match\n    print(\"Failed to match %d sentences and found %d spurious matches in the %s section\" % (unmatched, bad_match, split_name))\n    return con_to_dep_matches\n\nEXCEPTIONS = [\"gliene\", \"glielo\", \"gliela\", \"eccoci\"]\n\ndef get_mwt(*dep_datasets):\n    \"\"\"\n    Get the ADP/DET MWTs from the UD dataset\n\n    This class of 
MWT are expanded in the UD but not the constituencies\n    \"\"\"\n    mwt_map = {}\n    for dataset in dep_datasets:\n        for sentence in dataset.sentences:\n            for token in sentence.tokens:\n                if len(token.words) == 1:\n                    continue\n                # words such as \"accorgermene\" we just skip over\n                # those are already expanded in the constituency dataset\n                # TODO: the clitics are actually expanded weirdly, maybe need to compensate for that\n                if token.words[0].upos in ('VERB', 'AUX') and all(word.upos == 'PRON' for word in token.words[1:]):\n                    continue\n                if token.text.lower() in EXCEPTIONS:\n                    continue\n                if len(token.words) != 2 or token.words[0].upos != 'ADP' or token.words[1].upos != 'DET':\n                    raise ValueError(\"Not sure how to handle this: {}\".format(token))\n                expansion = (token.words[0].text, token.words[1].text)\n                if token.text in mwt_map:\n                    if mwt_map[token.text] != expansion:\n                        raise ValueError(\"Inconsistent MWT: {} -> {} or {}\".format(token.text, expansion, mwt_map[token.text]))\n                    continue\n                #print(\"Expanding {} to {}\".format(token.text, expansion))\n                mwt_map[token.text] = expansion\n    return mwt_map\n\ndef update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor):\n    \"\"\"\n    Replace MWT structures with their UD equivalents, along with some other minor tsurgeon based edits\n\n    original_tree: the tree as read from VIT\n    dep_sentence: the UD dependency dataset version of this sentence\n    \"\"\"\n    updated_tree = original_tree\n\n    operations = []\n\n    # first, remove titles or testo from the start of a sentence\n    con_words = updated_tree.leaf_labels()\n    if con_words[0] == \"Tit'\":\n        
operations.append([\"/^Tit'$/=prune !, __\", \"prune prune\"])\n    elif con_words[0] == \"TESTO\":\n        operations.append([\"/^TESTO$/=prune !, __\", \"prune prune\"])\n    elif con_words[0] == \"testo\":\n        operations.append([\"/^testo$/ !, __ . /^:$/=prune\", \"prune prune\"])\n        operations.append([\"/^testo$/=prune !, __\", \"prune prune\"])\n\n    if len(con_words) >= 2 and con_words[-2] == '...' and con_words[-1] == '.':\n        # the most recent VIT constituency has some sentence final . after a ...\n        # the UD dataset has a more typical ... ending instead\n        # these lines used to say \"riempire\" which was rather odd\n        operations.append([\"/^[.][.][.]$/ . /^[.]$/=prune\", \"prune prune\"])\n\n    # a few constituent tags are simply errors which need to be fixed\n    if original_tree.children[0].label == 'p':\n        # 'p' shouldn't be at root\n        operations.append([\"_ROOT_ < p=p\", \"relabel p cp\"])\n    # fix one specific tree if it has an s_top in it\n    operations.append([\"s_top=stop < (in=in < più=piu)\", \"replace piu (q più)\", \"relabel in sq\", \"relabel stop sa\"])\n    # sect doesn't exist as a constituent.  
replace it with sa\n    operations.append([\"sect=sect < num\", \"relabel sect sa\"])\n    # ppas as an internal node gets removed\n    operations.append([\"ppas=ppas < (__ < __)\", \"excise ppas ppas\"])\n\n    # now assemble a bunch of regex to split and otherwise manipulate\n    # the MWT in the trees\n    for token in dep_sentence.tokens:\n        if len(token.words) == 1:\n            continue\n        if token.text in mwt_map:\n            mwt_pieces = mwt_map[token.text]\n            if len(mwt_pieces) != 2:\n                raise NotImplementedError(\"Expected exactly 2 pieces of mwt for %s\" % token.text)\n            # the MWT words in the UD version will have ' when needed,\n            # but the corresponding ' is skipped in the con version of VIT,\n            # hence the replace(\"'\", \"\")\n            # however, all' has the ' included, because this is a\n            # constituent treebank, not a consistent treebank\n            search_regex = \"/^(?i:%s(?:')?)$/\" % token.text.replace(\"'\", \"\")\n            # tags which seem to be relevant:\n            # avvl|ccom|php|part|partd|partda\n            tregex = \"__ !> __ <<<%d (%s=child > (__=parent $+ sn=sn))\" % (token.id[0], search_regex)\n            tsurgeons = [\"insert (art %s) >0 sn\" % mwt_pieces[1], \"relabel child %s\" % mwt_pieces[0]]\n            operations.append([tregex] + tsurgeons)\n\n            tregex = \"__ !> __ <<<%d (%s=child > (__=parent !$+ sn !$+ (art < %s)))\" % (token.id[0], search_regex, mwt_pieces[1])\n            tsurgeons = [\"insert (art %s) $- parent\" % mwt_pieces[1], \"relabel child %s\" % mwt_pieces[0]]\n            operations.append([tregex] + tsurgeons)\n        elif len(token.words) == 2:\n            #print(\"{} not in mwt_map\".format(token.text))\n            # apparently some trees like sent_00381 and sent_05070\n            # have the clitic in a non-projective manner\n            #   [vcl-essersi, vppin-sparato, compt-[clitdat-si\n            #   
intj-figurarsi, fs-[cosu-quando, f-[ibar-[clit-si\n            # and before you ask, there are also clitics which are\n            # simply not there at all, rather than always attached\n            # in a non-projective manner\n            tregex = \"__=parent < (/^(?i:%s)$/=child . (__=np !< __ . (/^clit/=clit < %s)))\" % (token.text, token.words[1].text)\n            tsurgeon = \"moveprune clit $- parent\"\n            operations.append([tregex, tsurgeon])\n\n            # there are also some trees which don't have clitics\n            # for example, trees should look like this:\n            #   [ibar-[vsup-poteva, vcl-rivelarsi], compc-[clit-si, sn-[...]]]\n            # however, at least one such example for rivelarsi instead\n            # looks like this, with no corresponding clit\n            #   [... vcl-rivelarsi], compc-[sn-[in-ancora]]\n            # note that is the actual tag, not just me being pissed off\n            # breaking down the tregex:\n            # the child is the original MWT, not split\n            # !. clit verifies that it is not split (and stops the tsurgeon once fixed)\n            # !$+ checks that the parent of the MWT is the last element under parent\n            # note that !. can leave the immediate parent to touch the clit\n            # neighbor will be the place the new clit will be sticking out\n            tregex = \"__=parent < (/^(?i:%s)$/=child !. /^clit/) !$+ __ > (__=gp $+ __=neighbor)\" % token.text\n            tsurgeon = \"insert (clit %s) >0 neighbor\" % token.words[1].text\n            operations.append([tregex, tsurgeon])\n\n            # secondary option: while most trees are like the above,\n            # with an outer bracket around the MWT and another verb,\n            # some go straight into the next phrase\n            #   sent_05076\n            #   sv5-[vcl-adeguandosi, compin-[sp-[part-alle, ...\n            tregex = \"__=parent < (/^(?i:%s)$/=child !. 
/^clit/) $+ __\" % token.text\n            tsurgeon = \"insert (clit %s) $- parent\" % token.words[1].text\n            operations.append([tregex, tsurgeon])\n        else:\n            pass\n    if len(operations) > 0:\n        updated_tree = tsurgeon_processor.process(updated_tree, *operations)[0]\n    return updated_tree, operations\n\ndef update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor):\n    \"\"\"\n    Update a tree using the mwt_map and tsurgeon to expand some MWTs\n\n    Then replace the words in the con tree with the words in the dep tree\n    \"\"\"\n    ud_words = [x.text for x in dep_sentence.words]\n\n    updated_tree, operations = update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor)\n\n    # this checks number of words\n    try:\n        updated_tree = updated_tree.replace_words(ud_words)\n    except ValueError as e:\n        raise ValueError(\"Failed to process {} {}:\\nORIGINAL TREE\\n{}\\nUPDATED TREE\\n{}\\nUPDATED LEAVES\\n{}\\nUD TEXT\\n{}\\nTsurgeons applied:\\n{}\\n\".format(con_id, dep_id, original_tree, updated_tree, updated_tree.leaf_labels(), ud_words, \"\\n\".join(\"{}\".format(op) for op in operations))) from e\n    return updated_tree\n\n# train set:\n#  858: missing close parens in the UD conversion\n# 1169: 'che', 'poi', 'tutti', 'i', 'Paesi', 'ue', '.' 
-> 'per', 'tutti', 'i', 'paesi', 'Ue', '.'\n# 2375: the problem is inconsistent treatment of s_p_a_\n# 05052: the heuristic to fill in a missing \"si\" doesn't work because there's\n#   already another \"si\" immediately after\n#\n# test set:\n# 09764: weird punct at end\n# 10058: weird punct at end\nIGNORE_IDS = [\"sent_00867\", \"sent_01169\", \"sent_02375\", \"sent_05052\", \"sent_09764\", \"sent_10058\"]\n\ndef extract_updated_dataset(con_tree_map, dep_sentence_map, split_ids, mwt_map, tsurgeon_processor):\n    \"\"\"\n    Update constituency trees using the information in the dependency treebank\n    \"\"\"\n    trees = []\n    for con_id, dep_id in tqdm(split_ids.items()):\n        # skip a few trees which have non-MWT word modifications\n        if con_id in IGNORE_IDS:\n            continue\n        original_tree = con_tree_map[con_id]\n        dep_sentence = dep_sentence_map[dep_id]\n        updated_tree = update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor)\n\n        trees.append(ProcessedTree(con_id, dep_id, updated_tree))\n    return trees\n\ndef read_updated_trees(paths, debug_sentence=None):\n    # original version with more errors\n    #con_filename = os.path.join(con_directory, \"2011-12-20\", \"Archive\", \"VIT_newconstsynt.txt\")\n    # this is the April 2022 version\n    #con_filename = os.path.join(con_directory, \"VIT_newconstsynt.txt\")\n    # the most recent update from ELRA may look like this?\n    # it's what we got, at least\n    # con_filename = os.path.join(con_directory, \"italian\", \"VITwritten\", \"VITconstsyntNumb\")\n\n    # needs at least UD 2.11 or this will not work\n    con_directory = paths[\"CONSTITUENCY_BASE\"]\n    ud_directory = os.path.join(paths[\"UDBASE\"], \"UD_Italian-VIT\")\n\n    con_filename = os.path.join(con_directory, \"italian\", \"it_vit\", \"VITwritten\", \"VITconstsyntNumb\")\n    ud_vit_train = os.path.join(ud_directory, \"it_vit-ud-train.conllu\")\n    ud_vit_dev   = 
os.path.join(ud_directory, \"it_vit-ud-dev.conllu\")\n    ud_vit_test  = os.path.join(ud_directory, \"it_vit-ud-test.conllu\")\n\n    print(\"Reading UD train/dev/test from %s\" % ud_directory)\n    ud_train_data = CoNLL.conll2doc(input_file=ud_vit_train)\n    ud_dev_data   = CoNLL.conll2doc(input_file=ud_vit_dev)\n    ud_test_data  = CoNLL.conll2doc(input_file=ud_vit_test)\n\n    ud_vit_train_map = { DEP_ID_FUNC(x) : x for x in ud_train_data.sentences }\n    ud_vit_dev_map   = { DEP_ID_FUNC(x) : x for x in ud_dev_data.sentences }\n    ud_vit_test_map  = { DEP_ID_FUNC(x) : x for x in ud_test_data.sentences }\n\n    print(\"Getting ADP/DET expansions from UD data\")\n    mwt_map = get_mwt(ud_train_data, ud_dev_data, ud_test_data)\n\n    con_sentences = read_constituency_file(con_filename)\n    num_discarded = 0\n    con_tree_map = {}\n    for idx, sentence in enumerate(tqdm(con_sentences, postfix=\"Processing\")):\n        try:\n            tree = raw_tree(sentence[1])\n            if sentence[0].startswith(\"#ID=\"):\n                tree_id = sentence[0].split(\"=\")[-1]\n            else:\n                tree_id = sentence[0].split(\"#\")[-1]\n            # don't care about the raw text?\n            con_tree_map[tree_id] = tree\n        except UnclosedTreeError as e:\n            num_discarded = num_discarded + 1\n            print(\"Discarding {} because of reading error:\\n  {}: {}\\n  {}\".format(sentence[0], type(e), e, sentence[1]))\n        except ExtraCloseTreeError as e:\n            num_discarded = num_discarded + 1\n            print(\"Discarding {} because of reading error:\\n  {}: {}\\n  {}\".format(sentence[0], type(e), e, sentence[1]))\n        except ValueError as e:\n            print(\"Discarding {} because of reading error:\\n  {}: {}\\n  {}\".format(sentence[0], type(e), e, sentence[1]))\n            num_discarded = num_discarded + 1\n            #raise ValueError(\"Could not process line %d\" % idx) from e\n\n    print(\"Discarded %d trees.  
Have %d trees left\" % (num_discarded, len(con_tree_map)))\n    if num_discarded > 0:\n        raise ValueError(\"Oops!  We thought all of the VIT trees were properly bracketed now\")\n    con_vit_ngrams = build_ngrams(con_tree_map.items(), lambda x: CON_PROCESS_FUNC(x[1]), lambda x: x[0])\n\n    # TODO: match more sentences.  some are probably missing because of MWT\n    train_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_train_data.sentences, \"train\", debug_sentence)\n    dev_ids   = match_sentences(con_tree_map, con_vit_ngrams, ud_dev_data.sentences,   \"dev\",   debug_sentence)\n    test_ids  = match_sentences(con_tree_map, con_vit_ngrams, ud_test_data.sentences,  \"test\",  debug_sentence)\n    print(\"Remaining total trees: %d\" % (len(train_ids) + len(dev_ids) + len(test_ids)))\n    print(\"  {} train {} dev {} test\".format(len(train_ids), len(dev_ids), len(test_ids)))\n    print(\"Updating trees with MWT and newer tokens from UD...\")\n\n    # the moveprune feature requires a new corenlp release after 4.4.0\n    with tsurgeon.Tsurgeon(classpath=\"$CLASSPATH\") as tsurgeon_processor:\n        train_trees = extract_updated_dataset(con_tree_map, ud_vit_train_map, train_ids, mwt_map, tsurgeon_processor)\n        dev_trees   = extract_updated_dataset(con_tree_map, ud_vit_dev_map,   dev_ids,   mwt_map, tsurgeon_processor)\n        test_trees  = extract_updated_dataset(con_tree_map, ud_vit_test_map,  test_ids,  mwt_map, tsurgeon_processor)\n\n    return train_trees, dev_trees, test_trees\n\ndef convert_it_vit(paths, dataset_name, debug_sentence=None):\n    \"\"\"\n    Read the trees, then write them out to the expected output_directory\n    \"\"\"\n    train_trees, dev_trees, test_trees = read_updated_trees(paths, debug_sentence)\n\n    train_trees = [x.tree for x in train_trees]\n    dev_trees   = [x.tree for x in dev_trees]\n    test_trees  = [x.tree for x in test_trees]\n\n    output_directory = paths[\"CONSTITUENCY_DATA_DIR\"]\n    
write_dataset([train_trees, dev_trees, test_trees], output_directory, dataset_name)\n\ndef main():\n    paths = default_paths.get_default_paths()\n    dataset_name = \"it_vit\"\n\n    debug_sentence = sys.argv[1] if len(sys.argv) > 1 else None\n\n    convert_it_vit(paths, dataset_name, debug_sentence)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_spmrl.py",
    "content": "import os\n\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.models.constituency.tree_reader import read_treebank\nfrom stanza.utils.default_paths import get_default_paths\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef add_root(tree):\n    if tree.label.startswith(\"NN\"):\n        tree = Tree(\"NP\", tree)\n    if tree.label.startswith(\"NE\"):\n        tree = Tree(\"PN\", tree)\n    elif tree.label.startswith(\"XY\"):\n        tree = Tree(\"VROOT\", tree)\n    return Tree(\"ROOT\", tree)\n\ndef convert_spmrl(input_directory, output_directory, short_name):\n    for shard in SHARDS:\n        tree_filename = os.path.join(input_directory, shard, shard + \".German.gold.ptb\")\n        trees = read_treebank(tree_filename, tree_callback=add_root)\n        output_filename = os.path.join(output_directory, \"%s_%s.mrg\" % (short_name, shard))\n        with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n            for tree in trees:\n                fout.write(str(tree))\n                fout.write(\"\\n\")\n        print(\"Wrote %d trees to %s\" % (len(trees), output_filename))\n\nif __name__ == '__main__':\n    paths = get_default_paths()\n    output_directory = paths[\"CONSTITUENCY_DATA_DIR\"]\n    input_directory = \"extern_data/constituency/spmrl/SPMRL_SHARED_2014/GERMAN_SPMRL/gold/ptb\"\n    convert_spmrl(input_directory, output_directory, \"de_spmrl\")\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/convert_starlang.py",
    "content": "\nimport os\nimport re\n\nfrom tqdm import tqdm\n\nfrom stanza.models.constituency import parse_tree\nfrom stanza.models.constituency import tree_reader\n\nTURKISH_RE = re.compile(r\"[{]turkish=([^}]+)[}]\")\n\nDISALLOWED_LABELS = ('DT', 'DET', 's', 'vp', 'AFVP', 'CONJ', 'INTJ', '-XXX-')\n\ndef read_tree(text):\n    \"\"\"\n    Reads in a tree, then extracts specifically the word from the specific format used\n\n    Also converts LCB/RCB as needed\n    \"\"\"\n    trees = tree_reader.read_trees(text)\n    if len(trees) > 1:\n        raise ValueError(\"Tree file had two trees!\")\n    tree = trees[0]\n    labels = tree.leaf_labels()\n    new_labels = []\n    for label in labels:\n        match = TURKISH_RE.search(label)\n        if match is None:\n            raise ValueError(\"Could not find word in |{}|\".format(label))\n        word = match.group(1)\n        word = word.replace(\"-LCB-\", \"{\").replace(\"-RCB-\", \"}\")\n        new_labels.append(word)\n\n    tree = tree.replace_words(new_labels)\n    #tree = tree.remap_constituent_labels(LABEL_MAP)\n    con_labels = tree.get_unique_constituent_labels([tree])\n    if any(label in DISALLOWED_LABELS for label in con_labels):\n        raise ValueError(\"found an unexpected phrasal node {}\".format(label))\n    return tree\n\ndef read_files(filenames, conversion, log):\n    trees = []\n    for filename in filenames:\n        with open(filename, encoding=\"utf-8\") as fin:\n            text = fin.read()\n        try:\n            tree = conversion(text)\n            if tree is not None:\n                trees.append(tree)\n        except ValueError as e:\n            if log:\n                print(\"-----------------\\nFound an error in {}: {} Original text: {}\".format(filename, e, text))\n    return trees\n\ndef read_starlang(paths, conversion=read_tree, log=True):\n    \"\"\"\n    Read the starlang trees, converting them using the given method.\n\n    read_tree or any other conversion turns one 
file at a time to a sentence.\n    log is whether or not to log a ValueError - the NER division has many missing labels\n    \"\"\"\n    if isinstance(paths, str):\n        paths = (paths,)\n\n    train_files = []\n    dev_files = []\n    test_files = []\n\n    for path in paths:\n        tree_files = [os.path.join(path, x) for x in os.listdir(path)]\n        train_files.extend([x for x in tree_files if x.endswith(\".train\")])\n        dev_files.extend([x for x in tree_files if x.endswith(\".dev\")])\n        test_files.extend([x for x in tree_files if x.endswith(\".test\")])\n\n    print(\"Reading %d total files\" % (len(train_files) + len(dev_files) + len(test_files)))\n    train_treebank = read_files(tqdm(train_files), conversion=conversion, log=log)\n    dev_treebank = read_files(tqdm(dev_files), conversion=conversion, log=log)\n    test_treebank = read_files(tqdm(test_files), conversion=conversion, log=log)\n\n    return train_treebank, dev_treebank, test_treebank\n\ndef main(conversion=read_tree, log=True):\n    paths = [\"extern_data/constituency/turkish/TurkishAnnotatedTreeBank-15\",\n             \"extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-15\",\n             \"extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-20\"]\n    train_treebank, dev_treebank, test_treebank = read_starlang(paths, conversion=conversion, log=log)\n\n    print(\"Train: %d\" % len(train_treebank))\n    print(\"Dev: %d\" % len(dev_treebank))\n    print(\"Test: %d\" % len(test_treebank))\n\n    print(train_treebank[0])\n    return train_treebank, dev_treebank, test_treebank\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/count_common_words.py",
    "content": "import sys\n\nfrom collections import Counter\n\nfrom stanza.models.constituency import parse_tree\nfrom stanza.models.constituency import tree_reader\n\nword_counter = Counter()\ncount_words = lambda x: word_counter.update(x.leaf_labels())\n\ntree_reader.read_tree_file(sys.argv[1], tree_callback=count_words)\nprint(word_counter.most_common()[:100])\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/extract_all_silver_dataset.py",
    "content": "\"\"\"\nAfter running build_silver_dataset.py, this extracts the trees of all match levels at once\n\nFor example\n\npython stanza/utils/datasets/constituency/extract_all_silver_dataset.py --output_prefix /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_ --parsed_trees /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_wiki_a*trees\n\ncat /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_[012345678].mrg | sort | uniq | shuf > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg\n\nshuf /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg | head -n 200000 > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_200K.mrg\n\"\"\"\n\nimport argparse\nfrom collections import defaultdict\nimport json\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy\")\n    parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.')\n    parser.add_argument('--output_prefix', type=str, default=None, help='Prefix to use for outputting trees')\n    parser.add_argument('--output_suffix', type=str, default=\".mrg\", help='Suffix to use for outputting trees')\n    args = parser.parse_args()\n\n    return args\n\ndef main():\n    args = parse_args()\n\n    trees = defaultdict(list)\n    for filename in args.parsed_trees:\n        with open(filename, encoding='utf-8') as fin:\n            for line in fin.readlines():\n                tree = json.loads(line)\n                trees[tree['count']].append(tree['tree'])\n\n    for score, tree_list in trees.items():\n        filename = \"%s%s%s\" % (args.output_prefix, score, args.output_suffix)\n        with open(filename, 'w', encoding='utf-8') as fout:\n            for tree in tree_list:\n                fout.write(tree)\n    
            fout.write('\\n')\n\nif __name__ == '__main__':\n    main()\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/extract_silver_dataset.py",
    "content": "\"\"\"\nAfter running build_silver_dataset.py, this extracts the trees of a certain match level\n\nFor example\n\npython3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score 0 --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg\n\nfor i in `echo 0 1 2 3 4 5 6 7 8 9 10`; do python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score $i --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_$i.mrg; done\n\"\"\"\n\nimport argparse\nimport json\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy\")\n    parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.')\n    parser.add_argument('--keep_score', type=int, default=None, help='Which agreement level to keep.  
None keeps all')\n    parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')\n    args = parser.parse_args()\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    trees = []\n    for filename in args.parsed_trees:\n        with open(filename, encoding='utf-8') as fin:\n            for line in fin:\n                tree = json.loads(line)\n                if args.keep_score is None or tree['count'] == args.keep_score:\n                    tree = tree['tree']\n                    trees.append(tree)\n\n    if args.output_file is None:\n        for tree in trees:\n            print(tree)\n    else:\n        with open(args.output_file, 'w', encoding='utf-8') as fout:\n            for tree in trees:\n                fout.write(tree)\n                fout.write('\\n')\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/prepare_con_dataset.py",
    "content": "\"\"\"Converts raw data files from their original format (dataset dependent) into PTB trees.\n\nThe operation of this script depends heavily on the dataset in question.\nThe common result is that the data files go to data/constituency and are in PTB format.\n\nda_arboretum\n  Ekhard Bick\n    Arboretum, a Hybrid Treebank for Danish\n    https://www.researchgate.net/publication/251202293_Arboretum_a_Hybrid_Treebank_for_Danish\n  Available here for a license fee:\n    http://catalog.elra.info/en-us/repository/browse/ELRA-W0084/\n  Internal to Stanford, please contact Chris Manning and/or John Bauer\n  The file processed is the tiger xml, although there are some edits\n    needed in order to make it functional for our parser\n  The treebank comes as a tar.gz file, W0084.tar.gz\n  untar this file in $CONSTITUENCY_BASE/danish\n  then move the extracted folder to \"arboretum\"\n    $CONSTITUENCY_BASE/danish/W0084/... becomes\n    $CONSTITUENCY_BASE/danish/arboretum/...\n\nen_ptb3-revised is an updated version of PTB with NML and stuff\n  put LDC2015T13 in $CONSTITUENCY_BASE/english\n  the directory name may look like LDC2015T13_eng_news_txt_tbnk-ptb_revised\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset en_ptb3-revised\n\n  All this needs to do is concatenate the various pieces\n\n  @article{ptb_revised,\n    title= {Penn Treebank Revised: English News Text Treebank LDC2015T13},\n    journal= {},\n    author= {Ann Bies and Justin Mott and Colin Warner},\n    year= {2015},\n    url= {https://doi.org/10.35111/xpjy-at91},\n    doi= {10.35111/xpjy-at91},\n    isbn= {1-58563-724-6},\n    dcmi= {text},\n    languages= {english},\n    language= {english},\n    ldc= {LDC2015T13},\n  }\n\nid_icon\n  ICON: Building a Large-Scale Benchmark Constituency Treebank\n    for the Indonesian Language\n    Ee Suan Lim, Wei Qi Leong, Ngan Thanh Nguyen, Dea Adhista,\n    Wei Ming Kng, William Chandra Tjhi, Ayu Purwarianti\n    
https://aclanthology.org/2023.tlt-1.5.pdf\n  Available at https://github.com/aisingapore/seacorenlp-data\n  git clone the repo in $CONSTITUENCY_BASE/seacorenlp\n  so there is now a directory\n    $CONSTITUENCY_BASE/seacorenlp/seacorenlp-data\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset id_icon\n\nit_turin\n  A combination of Evalita competition from 2011 and the ParTUT trees\n  More information is available in convert_it_turin\n\nit_vit\n  The original for the VIT UD Dataset\n  The UD version has a lot of corrections, so we try to apply those as much as possible\n  In fact, we applied some corrections of our own back to UD based on this treebank.\n    The first version which had those corrections is UD 2.10\n    Versions of UD before that won't work\n    Hopefully versions after that work\n    Set UDBASE to a path such that $UDBASE/UD_Italian-VIT is the UD version\n  The constituency labels are generally not very understandable, unfortunately\n    Some documentation is available here:\n      https://core.ac.uk/download/pdf/223148096.pdf\n      https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.423.5538&rep=rep1&type=pdf\n  Available from ELRA:\n    http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/\n\nja_alt\n  Asian Language Treebank produced a treebank for Japanese:\n    Ye Kyaw Thu, Win Pa Pa, Masao Utiyama, Andrew Finch, Eiichiro Sumita\n    Introducing the Asian Language Treebank\n    http://www.lrec-conf.org/proceedings/lrec2016/pdf/435_Paper.pdf\n  Download\n    https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/Japanese-ALT-20210218.zip\n  unzip this in $CONSTITUENCY_BASE/japanese\n  this should create a directory $CONSTITUENCY_BASE/japanese/Japanese-ALT-20210218\n  In this directory, also download the following:\n    https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-train.txt\n    https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-dev.txt\n    
https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-test.txt\n  In particular, there are two files with a bunch of bracketed parses,\n    Japanese-ALT-Draft.txt and Japanese-ALT-Reviewed.txt\n  The first word of each of these lines is SNT.80188.1 or something like that\n  This correlates with the three URL-... files, telling us whether the\n    sentence belongs in train/dev/test\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset ja_alt\n\npt_cintil\n  CINTIL treebank for Portuguese, available at ELRA:\n    https://catalogue.elra.info/en-us/repository/browse/ELRA-W0055/\n  It can also be obtained from here:\n    https://hdl.handle.net/21.11129/0000-000B-D2FE-A\n  Produced at U Lisbon\n    António Branco; João Silva; Francisco Costa; Sérgio Castro\n      CINTIL TreeBank Handbook: Design options for the representation of syntactic constituency\n    Silva, João; António Branco; Sérgio Castro; Ruben Reis\n      Out-of-the-Box Robust Parsing of Portuguese\n    https://portulanclarin.net/repository/extradocs/CINTIL-Treebank.pdf\n    http://www.di.fc.ul.pt/~ahb/pubs/2011bBrancoSilvaCostaEtAl.pdf\n  If at Stanford, ask John Bauer or Chris Manning for the data\n  Otherwise, purchase it from ELRA or find it elsewhere if possible\n  Either way, unzip it in\n    $CONSTITUENCY_BASE/portuguese to the CINTIL directory\n    so for example, the final result might be\n    extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset pt_cintil\n\ntr_starlang\n  A dataset in three parts from the Starlang group in Turkey:\n  Neslihan Kara, Büşra Marşan, et al\n    Creating A Syntactically Felicitous Constituency Treebank For Turkish\n    https://ieeexplore.ieee.org/document/9259873\n  git clone the following three repos\n    https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15\n    https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15\n    
https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20\n  Put them in\n    $CONSTITUENCY_BASE/turkish\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset tr_starlang\n\nvlsp09 is the 2009 constituency treebank:\n  Nguyen Phuong Thai, Vu Xuan Luong, Nguyen Thi Minh Huyen, Nguyen Van Hiep, Le Hong Phuong\n    Building a Large Syntactically-Annotated Corpus of Vietnamese\n    Proceedings of The Third Linguistic Annotation Workshop\n    In conjunction with ACL-IJCNLP 2009, Suntec City, Singapore, 2009\n  This can be obtained by contacting vlsp.resources@gmail.com\n\nvlsp22 is the 2022 constituency treebank from the VLSP bakeoff\n  there is an official test set as well\n  you may be able to obtain both of these by contacting vlsp.resources@gmail.com\n  NGUYEN Thi Minh Huyen, HA My Linh, VU Xuan Luong, PHAN Thi Hue,\n  LE Van Cuong, NGUYEN Thi Luong, NGO The Quyen\n    VLSP 2022 Challenge: Vietnamese Constituency Parsing\n    to appear in Journal of Computer Science and Cybernetics.\n\nvlsp23 is the 2023 update to the constituency treebank from the VLSP bakeoff\n  the vlsp22 code also works for the new dataset,\n    although some effort may be needed to update the tags\n  As of late 2024, the test set is available on request at vlsp.resources@gmail.com\n  Organize the directory\n    $CONSTITUENCY_BASE/vietnamese/VLSP_2023\n      $CONSTITUENCY_BASE/vietnamese/VLSP_2023/Trainingset\n      $CONSTITUENCY_BASE/vietnamese/VLSP_2023/test\n\nzh_ctb-51 is the 5.1 version of CTB\n  put LDC2005T01U01_ChineseTreebank5.1 in $CONSTITUENCY_BASE/chinese\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh_ctb-51\n\n  @article{xue_xia_chiou_palmer_2005,\n           title={The Penn Chinese TreeBank: Phrase structure annotation of a large corpus},\n           volume={11},\n           DOI={10.1017/S135132490400364X},\n           number={2},\n           journal={Natural Language Engineering},\n           publisher={Cambridge University Press},\n        
   author={XUE, NAIWEN and XIA, FEI and CHIOU, FU-DONG and PALMER, MARTA},\n           year={2005},\n           pages={207–238}}\n\nzh_ctb-51b is the same dataset, but using a smaller dev/test set\n  in our experiments, this is substantially easier\n\nzh_ctb-90 is the 9.0 version of CTB\n  put LDC2016T13 in $CONSTITUENCY_BASE/chinese\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh_ctb-90\n\n  the splits used are the ones from the file docs/ctb9.0-file-list.txt\n    included in the CTB 9.0 release\n\nSPMRL adds several treebanks\n  https://www.spmrl.org/\n  https://www.spmrl.org/sancl-posters2014.html\n  Currently only German is converted, the German version being a\n    version of the Tiger Treebank\n  python3 -m stanza.utils.datasets.constituency.prepare_con_dataset de_spmrl  \n\nen_mctb is a multidomain test set covering five domains other than newswire\n  https://github.com/RingoS/multi-domain-parsing-analysis\n  Challenges to Open-Domain Constituency Parsing\n\n  @inproceedings{yang-etal-2022-challenges,\n    title = \"Challenges to Open-Domain Constituency Parsing\",\n    author = \"Yang, Sen  and\n      Cui, Leyang and\n      Ning, Ruoxi and\n      Wu, Di and\n      Zhang, Yue\",\n    booktitle = \"Findings of the Association for Computational Linguistics: ACL 2022\",\n    month = may,\n    year = \"2022\",\n    address = \"Dublin, Ireland\",\n    publisher = \"Association for Computational Linguistics\",\n    url = \"https://aclanthology.org/2022.findings-acl.11\",\n    doi = \"10.18653/v1/2022.findings-acl.11\",\n    pages = \"112--127\",\n  }\n\n  This conversion replaces the top bracket from top -> ROOT and puts an extra S\n    bracket on any roots with more than one node.\n\"\"\"\n\nimport argparse\nimport os\nimport random\nimport sys\nimport tempfile\n\nfrom tqdm import tqdm\n\nfrom stanza.models.constituency import parse_tree\nimport stanza.utils.default_paths as default_paths\nfrom stanza.models.constituency import 
tree_reader\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.server import tsurgeon\nfrom stanza.utils.datasets.common import UnknownDatasetError\nfrom stanza.utils.datasets.constituency import utils\nfrom stanza.utils.datasets.constituency.convert_alt import convert_alt\nfrom stanza.utils.datasets.constituency.convert_arboretum import convert_tiger_treebank\nfrom stanza.utils.datasets.constituency.convert_cintil import convert_cintil_treebank\nimport stanza.utils.datasets.constituency.convert_ctb as convert_ctb\nfrom stanza.utils.datasets.constituency.convert_it_turin import convert_it_turin\nfrom stanza.utils.datasets.constituency.convert_it_vit import convert_it_vit\nfrom stanza.utils.datasets.constituency.convert_spmrl import convert_spmrl\nfrom stanza.utils.datasets.constituency.convert_starlang import read_starlang\nfrom stanza.utils.datasets.constituency.utils import SHARDS, write_dataset\nimport stanza.utils.datasets.constituency.vtb_convert as vtb_convert\nimport stanza.utils.datasets.constituency.vtb_split as vtb_split\n\ndef process_it_turin(paths, dataset_name, *args):\n    \"\"\"\n    Convert the it_turin dataset\n    \"\"\"\n    assert dataset_name == 'it_turin'\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"italian\")\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    convert_it_turin(input_dir, output_dir)\n\ndef process_it_vit(paths, dataset_name, *args):\n    # needs at least UD 2.11 or this will not work\n    # in the meantime, the git version of VIT will suffice\n    assert dataset_name == 'it_vit'\n    convert_it_vit(paths, dataset_name)\n\ndef process_vlsp09(paths, dataset_name, *args):\n    \"\"\"\n    Processes the VLSP 2009 dataset, discarding or fixing trees when needed\n    \"\"\"\n    assert dataset_name == 'vi_vlsp09'\n    vlsp_path = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"vietnamese\", \"VietTreebank_VLSP_SP73\", \"Kho ngu lieu 10000 cay cu phap\")\n    with tempfile.TemporaryDirectory() as 
tmp_output_path:\n        vtb_convert.convert_dir(vlsp_path, tmp_output_path)\n        vtb_split.split_files(tmp_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], dataset_name)\n\ndef process_vlsp21(paths, dataset_name, *args):\n    \"\"\"\n    Processes the VLSP 2021 dataset, which is just a single file\n    \"\"\"\n    assert dataset_name == 'vi_vlsp21'\n    vlsp_file = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"vietnamese\", \"VLSP_2021\", \"VTB_VLSP21_tree.txt\")\n    if not os.path.exists(vlsp_file):\n        raise FileNotFoundError(\"Could not find the 2021 dataset in the expected location of {} - CONSTITUENCY_BASE == {}\".format(vlsp_file, paths[\"CONSTITUENCY_BASE\"]))\n    with tempfile.TemporaryDirectory() as tmp_output_path:\n        vtb_convert.convert_files([vlsp_file], tmp_output_path)\n        # This produces a 0 length test set, just as a placeholder until the actual test set is released\n        vtb_split.split_files(tmp_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], dataset_name, train_size=0.9, dev_size=0.1)\n    _, _, test_file = vtb_split.create_paths(paths[\"CONSTITUENCY_DATA_DIR\"], dataset_name)\n    with open(test_file, \"w\"):\n        # create an empty test file - currently we don't have actual test data for VLSP 21\n        pass\n\ndef process_vlsp22(paths, dataset_name, *args):\n    \"\"\"\n    Processes the VLSP 2022 dataset, which is four separate files for some reason\n    \"\"\"\n    assert dataset_name == 'vi_vlsp22' or dataset_name == 'vi_vlsp23'\n\n    if dataset_name == 'vi_vlsp22':\n        default_subdir = 'VLSP_2022'\n        default_make_test_split = False\n        updated_tagset = False\n    elif dataset_name == 'vi_vlsp23':\n        default_subdir = os.path.join('VLSP_2023', 'Trainingdataset')\n        default_make_test_split = False\n        updated_tagset = True\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--subdir', default=default_subdir, type=str, help='Where to find the data - allows for using 
previous versions, if needed')\n    parser.add_argument('--no_convert_brackets', default=True, action='store_false', dest='convert_brackets', help=\"Don't convert the VLSP parens RKBT & LKBT to PTB parens\")\n    parser.add_argument('--n_splits', default=None, type=int, help='Split the data into this many pieces.  Relevant as there is no set training/dev split, so this allows for N models on N different dev sets')\n    parser.add_argument('--test_split', default=default_make_test_split, action='store_true', help='Split 1/10th of the data as a test split as well.  Useful for experimental results.  Less relevant since there is now an official test set')\n    parser.add_argument('--no_test_split', dest='test_split', action='store_false', help='Do not split off a test set from the training data.  The official test files are converted and used as the test set instead')\n    parser.add_argument('--seed', default=1234, type=int, help='Random seed to use when splitting')\n    args = parser.parse_args(args=list(*args))\n\n    if os.path.exists(args.subdir):\n        vlsp_dir = args.subdir\n    else:\n        vlsp_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"vietnamese\", args.subdir)\n    if not os.path.exists(vlsp_dir):\n        raise FileNotFoundError(\"Could not find the {} dataset in the expected location of {} - CONSTITUENCY_BASE == {}\".format(dataset_name, vlsp_dir, paths[\"CONSTITUENCY_BASE\"]))\n    vlsp_files = os.listdir(vlsp_dir)\n    vlsp_train_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith(\"file\") and not x.endswith(\".zip\")]\n    vlsp_train_files.sort()\n        \n    if dataset_name == 'vi_vlsp22':\n        vlsp_test_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith(\"private\") and not x.endswith(\".zip\")]\n    elif dataset_name == 'vi_vlsp23':\n        vlsp_test_dir = os.path.abspath(os.path.join(vlsp_dir, os.pardir, \"test\"))\n        vlsp_test_files = os.listdir(vlsp_test_dir)\n       
 vlsp_test_files = [os.path.join(vlsp_test_dir, x) for x in vlsp_test_files if x.endswith(\".csv\")]\n\n    if len(vlsp_train_files) == 0:\n        raise FileNotFoundError(\"No train files (files starting with 'file') found in {}\".format(vlsp_dir))\n    if not args.test_split and len(vlsp_test_files) == 0:\n        raise FileNotFoundError(\"No test files found in {}\".format(vlsp_dir))\n    print(\"Loading training files from {}\".format(vlsp_dir))\n    print(\"Processing training files:\\n  {}\".format(\"\\n  \".join(vlsp_train_files)))\n    with tempfile.TemporaryDirectory() as train_output_path:\n        vtb_convert.convert_files(vlsp_train_files, train_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)\n        # This produces a 0 length test set, just as a placeholder until the actual test set is released\n        if args.n_splits:\n            test_size = 0.1 if args.test_split else 0.0\n            dev_size = (1.0 - test_size) / args.n_splits\n            train_size = 1.0 - test_size - dev_size\n            for rotation in range(args.n_splits):\n                # there is a shuffle inside the split routine,\n                # so we need to reset the random seed each time\n                random.seed(args.seed)\n                rotation_name = \"%s-%d-%d\" % (dataset_name, rotation, args.n_splits)\n                if args.test_split:\n                    rotation_name = rotation_name + \"t\"\n                vtb_split.split_files(train_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], rotation_name, train_size=train_size, dev_size=dev_size, rotation=(rotation, args.n_splits))\n        else:\n            test_size = 0.1 if args.test_split else 0.0\n            dev_size = 0.1\n            train_size = 1.0 - test_size - dev_size\n            if args.test_split:\n                dataset_name = dataset_name + \"t\"\n            vtb_split.split_files(train_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], 
dataset_name, train_size=train_size, dev_size=dev_size)\n\n    if not args.test_split:\n        print(\"Processing test files:\\n  {}\".format(\"\\n  \".join(vlsp_test_files)))\n        with tempfile.TemporaryDirectory() as test_output_path:\n            vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)\n            if args.n_splits:\n                for rotation in range(args.n_splits):\n                    rotation_name = \"%s-%d-%d\" % (dataset_name, rotation, args.n_splits)\n                    vtb_split.split_files(test_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], rotation_name, train_size=0, dev_size=0)\n            else:\n                vtb_split.split_files(test_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], dataset_name, train_size=0, dev_size=0)\n    if not args.test_split and not args.n_splits and dataset_name == 'vi_vlsp23':\n        print(\"Processing test files and keeping ids:\\n  {}\".format(\"\\n  \".join(vlsp_test_files)))\n        with tempfile.TemporaryDirectory() as test_output_path:\n            vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset, write_ids=True)\n            vtb_split.split_files(test_output_path, paths[\"CONSTITUENCY_DATA_DIR\"], dataset_name + \"-ids\", train_size=0, dev_size=0)\n\ndef process_arboretum(paths, dataset_name, *args):\n    \"\"\"\n    Processes the Danish dataset, Arboretum\n    \"\"\"\n    assert dataset_name == 'da_arboretum'\n\n    arboretum_file = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"danish\", \"arboretum\", \"arboretum.tiger\", \"arboretum.tiger\")\n    if not os.path.exists(arboretum_file):\n        raise FileNotFoundError(\"Unable to find input file for Arboretum.  
Expected in {}\".format(arboretum_file))\n\n    treebank = convert_tiger_treebank(arboretum_file)\n    datasets = utils.split_treebank(treebank, 0.8, 0.1)\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n\n    output_filename = os.path.join(output_dir, \"%s.mrg\" % dataset_name)\n    print(\"Writing {} trees to {}\".format(len(treebank), output_filename))\n    parse_tree.Tree.write_treebank(treebank, output_filename)\n\n    write_dataset(datasets, output_dir, dataset_name)\n\n\ndef process_starlang(paths, dataset_name, *args):\n    \"\"\"\n    Convert the Turkish Starlang dataset to brackets\n    \"\"\"\n    assert dataset_name == 'tr_starlang'\n\n    PIECES = [\"TurkishAnnotatedTreeBank-15\",\n              \"TurkishAnnotatedTreeBank2-15\",\n              \"TurkishAnnotatedTreeBank2-20\"]\n\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    chunk_paths = [os.path.join(paths[\"CONSTITUENCY_BASE\"], \"turkish\", piece) for piece in PIECES]\n    datasets = read_starlang(chunk_paths)\n\n    write_dataset(datasets, output_dir, dataset_name)\n\ndef process_ja_alt(paths, dataset_name, *args):\n    \"\"\"\n    Convert and split the ALT dataset\n\n    TODO: could theoretically extend this to MY or any other similar dataset from ALT\n    \"\"\"\n    lang, source = dataset_name.split(\"_\", 1)\n    assert lang == 'ja'\n    assert source == 'alt'\n\n    PIECES = [\"Japanese-ALT-Draft.txt\", \"Japanese-ALT-Reviewed.txt\"]\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"japanese\", \"Japanese-ALT-20210218\")\n    input_files = [os.path.join(input_dir, input_file) for input_file in PIECES]\n    split_files = [os.path.join(input_dir, \"URL-%s.txt\" % shard) for shard in SHARDS]\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    output_files = [os.path.join(output_dir, \"%s_%s.mrg\" % (dataset_name, shard)) for shard in SHARDS]\n    convert_alt(input_files, split_files, output_files)\n\ndef process_pt_cintil(paths, dataset_name, *args):\n    \"\"\"\n    
Convert and split the PT Cintil dataset\n    \"\"\"\n    lang, source = dataset_name.split(\"_\", 1)\n    assert lang == 'pt'\n    assert source == 'cintil'\n\n    input_file = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"portuguese\", \"CINTIL\", \"CINTIL-Treebank.xml\")\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    datasets = convert_cintil_treebank(input_file)\n\n    write_dataset(datasets, output_dir, dataset_name)\n\ndef process_id_icon(paths, dataset_name, *args):\n    lang, source = dataset_name.split(\"_\", 1)\n    assert lang == 'id'\n    assert source == 'icon'\n\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"seacorenlp\", \"seacorenlp-data\", \"id\", \"constituency\")\n    input_files = [os.path.join(input_dir, x) for x in (\"train.txt\", \"dev.txt\", \"test.txt\")]\n    datasets = []\n    for input_file in input_files:\n        trees = tree_reader.read_tree_file(input_file)\n        trees = [Tree(\"ROOT\", tree) for tree in trees]\n        datasets.append(trees)\n\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    write_dataset(datasets, output_dir, dataset_name)\n\ndef process_ctb_51(paths, dataset_name, *args):\n    lang, source = dataset_name.split(\"_\", 1)\n    assert lang == 'zh-hans'\n    assert source == 'ctb-51'\n\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"chinese\", \"LDC2005T01U01_ChineseTreebank5.1\", \"bracketed\")\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51)\n\ndef process_ctb_51b(paths, dataset_name, *args):\n    lang, source = dataset_name.split(\"_\", 1)\n    assert lang == 'zh-hans'\n    assert source == 'ctb-51b'\n\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"chinese\", \"LDC2005T01U01_ChineseTreebank5.1\", \"bracketed\")\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    if not os.path.exists(input_dir):\n        raise FileNotFoundError(\"CTB 5.1 location not found: 
%s\" % input_dir)\n    print(\"Loading trees from %s\" % input_dir)\n    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51b)\n\ndef process_ctb_90(paths, dataset_name, *args):\n    lang, source = dataset_name.split(\"_\", 1)\n    assert lang == 'zh-hans'\n    assert source == 'ctb-90'\n\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"chinese\", \"LDC2016T13\", \"ctb9.0\", \"data\", \"bracketed\")\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V90)\n\n\ndef process_ptb3_revised(paths, dataset_name, *args):\n    input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"english\", \"LDC2015T13_eng_news_txt_tbnk-ptb_revised\")\n    if not os.path.exists(input_dir):\n        backup_input_dir = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"english\", \"LDC2015T13\")\n        if not os.path.exists(backup_input_dir):\n            raise FileNotFoundError(\"Could not find ptb3-revised in either %s or %s\" % (input_dir, backup_input_dir))\n        input_dir = backup_input_dir\n\n    bracket_dir = os.path.join(input_dir, \"data\", \"penntree\")\n    output_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n\n    # compensate for a weird mislabeling in the original dataset\n    label_map = {\"ADJ-PRD\": \"ADJP-PRD\"}\n\n    train_trees = []\n    for i in tqdm(range(2, 22)):\n        new_trees = tree_reader.read_directory(os.path.join(bracket_dir, \"%02d\" % i))\n        new_trees = [t.remap_constituent_labels(label_map) for t in new_trees]\n        train_trees.extend(new_trees)\n\n    move_tregex = \"_ROOT_ <1 __=home <2 /^[.]$/=move\"\n    move_tsurgeon = \"move move >-1 home\"\n\n    print(\"Moving sentence final punctuation if necessary\")\n    with tsurgeon.Tsurgeon() as tsurgeon_processor:\n        train_trees = [tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0] for tree in tqdm(train_trees)]\n\n    dev_trees = 
tree_reader.read_directory(os.path.join(bracket_dir, \"22\"))\n    dev_trees = [t.remap_constituent_labels(label_map) for t in dev_trees]\n\n    test_trees = tree_reader.read_directory(os.path.join(bracket_dir, \"23\"))\n    test_trees = [t.remap_constituent_labels(label_map) for t in test_trees]\n    print(\"Read %d train trees, %d dev trees, and %d test trees\" % (len(train_trees), len(dev_trees), len(test_trees)))\n    datasets = [train_trees, dev_trees, test_trees]\n    write_dataset(datasets, output_dir, dataset_name)\n\ndef process_en_mctb(paths, dataset_name, *args):\n    \"\"\"\n    Converts the following blocks:\n\n    dialogue.cleaned.txt  forum.cleaned.txt  law.cleaned.txt  literature.cleaned.txt  review.cleaned.txt\n    \"\"\"\n    base_path = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"english\", \"multi-domain-parsing-analysis\", \"data\", \"MCTB_en\")\n    if not os.path.exists(base_path):\n        raise FileNotFoundError(\"Please download multi-domain-parsing-analysis to %s\" % base_path)\n    def tree_callback(tree):\n        if len(tree.children) > 1:\n            tree = parse_tree.Tree(\"S\", tree.children)\n            return parse_tree.Tree(\"ROOT\", [tree])\n        return parse_tree.Tree(\"ROOT\", tree.children)\n\n    filenames = [\"dialogue.cleaned.txt\", \"forum.cleaned.txt\", \"law.cleaned.txt\", \"literature.cleaned.txt\", \"review.cleaned.txt\"]\n    for filename in filenames:\n        trees = tree_reader.read_tree_file(os.path.join(base_path, filename), tree_callback=tree_callback)\n        print(\"%d trees in %s\" % (len(trees), filename))\n        output_filename = \"%s-%s_test.mrg\" % (dataset_name, filename.split(\".\")[0])\n        output_filename = os.path.join(paths[\"CONSTITUENCY_DATA_DIR\"], output_filename)\n        print(\"Writing trees to %s\" % output_filename)\n        parse_tree.Tree.write_treebank(trees, output_filename)\n\ndef process_spmrl(paths, dataset_name, *args):\n    if dataset_name != 'de_spmrl':\n        
raise ValueError(\"SPMRL dataset %s currently not supported\" % dataset_name)\n\n    output_directory = paths[\"CONSTITUENCY_DATA_DIR\"]\n    input_directory = os.path.join(paths[\"CONSTITUENCY_BASE\"], \"spmrl\", \"SPMRL_SHARED_2014\", \"GERMAN_SPMRL\", \"gold\", \"ptb\")\n\n    convert_spmrl(input_directory, output_directory, dataset_name)\n\nDATASET_MAPPING = {\n    'da_arboretum': process_arboretum,\n\n    'de_spmrl':     process_spmrl,\n\n    'en_ptb3-revised': process_ptb3_revised,\n    'en_mctb':      process_en_mctb,\n\n    'id_icon':      process_id_icon,\n\n    'it_turin':     process_it_turin,\n    'it_vit':       process_it_vit,\n\n    'ja_alt':       process_ja_alt,\n\n    'pt_cintil':    process_pt_cintil,\n\n    'tr_starlang':  process_starlang,\n\n    'vi_vlsp09':    process_vlsp09,\n    'vi_vlsp21':    process_vlsp21,\n    'vi_vlsp22':    process_vlsp22,\n    'vi_vlsp23':    process_vlsp22,  # options allow for this\n\n    'zh-hans_ctb-51':   process_ctb_51,\n    'zh-hans_ctb-51b':  process_ctb_51b,\n    'zh-hans_ctb-90':   process_ctb_90,\n}\n\ndef main(dataset_name, *args):\n    paths = default_paths.get_default_paths()\n\n    random.seed(1234)\n\n    if dataset_name in DATASET_MAPPING:\n        DATASET_MAPPING[dataset_name](paths, dataset_name, *args)\n    else:\n        raise UnknownDatasetError(dataset_name, f\"dataset {dataset_name} currently not handled by prepare_con_dataset\")\n\nif __name__ == '__main__':\n    if len(sys.argv) == 1:\n        print(\"Known datasets:\")\n        for key in DATASET_MAPPING:\n            print(\"  %s\" % key)\n    else:\n        main(sys.argv[1], sys.argv[2:])\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/reduce_dataset.py",
    "content": "\"\"\"\nCut short the training portion of a constituency dataset.\n\nOne could think this script isn't necessary, as shuf | head would work,\nbut some treebanks use multiple lines for representing trees.\nThus it is necessary to actually intelligently read the trees.\n\nRun with\n\npython3 stanza/utils/datasets/constituency/reduce_dataset.py --input zh-hans_ctb-51b --output zh-hans_ctb5k\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nfrom stanza.models.constituency import tree_reader\nimport stanza.utils.default_paths as default_paths\nfrom stanza.utils.datasets.constituency.utils import SHARDS, write_dataset\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Script that cuts a treebank down to size\")\n    parser.add_argument('--input', type=str, default=None, help='Input treebank')\n    parser.add_argument('--output', type=str, default=None, help='Output treebank')\n    parser.add_argument('--size', type=int, default=5000, help='How many trees')\n    args = parser.parse_args()\n\n    random.seed(1234)\n\n    paths = default_paths.get_default_paths()\n    output_directory = paths[\"CONSTITUENCY_DATA_DIR\"]\n\n    # for example, data/constituency/en_ptb3_train.mrg\n    input_filenames = [os.path.join(output_directory, \"%s_%s.mrg\" % (args.input, shard)) for shard in SHARDS]\n    # only the train shard is cut down; dev and test are copied unchanged\n    shrink_datasets = [True, False, False]\n\n    datasets = []\n    for input_filename, shrink in zip(input_filenames, shrink_datasets):\n        treebank = tree_reader.read_treebank(input_filename)\n        if shrink:\n            random.shuffle(treebank)\n            treebank = treebank[:args.size]\n        datasets.append(treebank)\n    write_dataset(datasets, output_directory, args.output)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/relabel_tags.py",
    "content": "\"\"\"\nRetag an S-expression tree with a new set of POS tags\n\nAlso includes an option to write the new trees as bracket_labels\n(essentially, skipping the treebank_to_labeled_brackets step)\n\"\"\"\n\nimport argparse\nimport logging\n\nfrom stanza import Pipeline\nfrom stanza.models.constituency import retagging\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.utils import retag_trees\n\nlogger = logging.getLogger('stanza')\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Script that retags a tree file\")\n    parser.add_argument('--lang', default='vi', type=str, help='Language')\n    parser.add_argument('--input_file', default='data/constituency/vi_vlsp21_train.mrg', help='File to retag')\n    parser.add_argument('--output_file', default='vi_vlsp21_train_retagged.mrg', help='Where to write the retagged trees')\n    retagging.add_retag_args(parser)\n\n    parser.add_argument('--bracket_labels', action='store_true', help='Write the trees as bracket labels instead of S-expressions')\n\n    args = parser.parse_args()\n    args = vars(args)\n    retagging.postprocess_args(args)\n\n    return args\n\ndef main():\n    args = parse_args()\n\n    retag_pipeline = retagging.build_retag_pipeline(args)\n\n    train_trees = tree_reader.read_treebank(args['input_file'])\n    logger.info(\"Retagging %d trees using %s\", len(train_trees), args['retag_package'])\n    train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])\n    tree_format = \"{:L}\" if args['bracket_labels'] else \"{}\"\n    with open(args['output_file'], \"w\") as fout:\n        for tree in train_trees:\n            fout.write(tree_format.format(tree))\n            fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/selftrain.py",
    "content": "\"\"\"\nCommon methods for the various self-training data collection scripts\n\"\"\"\n\nimport logging\nimport os\nimport random\nimport re\n\nimport stanza\nfrom stanza.models.common import utils\nfrom stanza.models.common.bert_embedding import TextTooLongError\nfrom stanza.utils.get_tqdm import get_tqdm\n\nlogger = logging.getLogger('stanza')\ntqdm = get_tqdm()\n\ndef common_args(parser):\n    parser.add_argument(\n        '--output_file',\n        default='data/constituency/vi_silver.mrg',\n        help='Where to write the silver trees'\n    )\n    parser.add_argument(\n        '--lang',\n        default='vi',\n        help='Which language tools to use for tokenization and POS'\n    )\n    parser.add_argument(\n        '--num_sentences',\n        type=int,\n        default=-1,\n        help='How many sentences to get per file (max)'\n    )\n    parser.add_argument(\n        '--models',\n        default='saved_models/constituency/vi_vlsp21_inorder.pt',\n        help='What models to use for parsing.  comma-separated'\n    )\n    parser.add_argument(\n        '--package',\n        default='default',\n        help='Which package to load pretrain & charlm from for the parsers'\n    )\n    parser.add_argument(\n        '--output_ptb',\n        default=False,\n        action='store_true',\n        help='Output trees in PTB brackets (default is a bracket language format)'\n    )\n\ndef add_length_args(parser):\n    parser.add_argument(\n        '--min_len',\n        default=5,\n        type=int,\n        help='Minimum length sentence to keep.  None = unlimited'\n    )\n    parser.add_argument(\n        '--no_min_len',\n        dest='min_len',\n        action='store_const',\n        const=None,\n        help='No minimum length'\n    )\n    parser.add_argument(\n        '--max_len',\n        default=100,\n        type=int,\n        help='Maximum length sentence to keep.  
None = unlimited'\n    )\n    parser.add_argument(\n        '--no_max_len',\n        dest='max_len',\n        action='store_const',\n        const=None,\n        help='No maximum length'\n    )\n\ndef build_ssplit_pipe(ssplit, lang):\n    if ssplit:\n        return stanza.Pipeline(lang, processors=\"tokenize\")\n    else:\n        return stanza.Pipeline(lang, processors=\"tokenize\", tokenize_no_ssplit=True)\n\ndef build_tag_pipe(ssplit, lang, foundation_cache=None):\n    if ssplit:\n        return stanza.Pipeline(lang, processors=\"tokenize,pos\", foundation_cache=foundation_cache)\n    else:\n        return stanza.Pipeline(lang, processors=\"tokenize,pos\", tokenize_no_ssplit=True, foundation_cache=foundation_cache)\n\ndef build_parser_pipes(lang, models, package=\"default\", foundation_cache=None):\n    \"\"\"\n    Build separate pipelines for each parser model we want to use\n\n    It is highly recommended to pass in a FoundationCache to reuse bottom layers\n    \"\"\"\n    parser_pipes = []\n    for model_name in models.split(\",\"):\n        if os.path.exists(model_name):\n            # if the model name exists as a file, treat it as the path to the model\n            pipe = stanza.Pipeline(lang, processors=\"constituency\", package=package, constituency_model_path=model_name, constituency_pretagged=True, foundation_cache=foundation_cache)\n        else:\n            # otherwise, assume it is a package name?\n            pipe = stanza.Pipeline(lang, processors={\"constituency\": model_name}, constituency_pretagged=True, package=None, foundation_cache=foundation_cache)\n        parser_pipes.append(pipe)\n    return parser_pipes\n\ndef split_docs(docs, ssplit_pipe, max_len=140, max_word_len=50, chunk_size=2000):\n    \"\"\"\n    Using the ssplit pipeline, break up the documents into sentences\n\n    Filters out sentences which are too long or have words too long.\n\n    This step is necessary because some web text has unstructured\n    sentences which overwhelm 
the tagger, or even text with no\n    whitespace which breaks the charlm in the tokenizer or tagger\n    \"\"\"\n    raw_sentences = 0\n    filtered_sentences = 0\n    new_docs = []\n\n    logger.info(\"Splitting raw docs into sentences: %d\", len(docs))\n    for chunk_start in tqdm(range(0, len(docs), chunk_size)):\n        chunk = docs[chunk_start:chunk_start+chunk_size]\n        chunk = [stanza.Document([], text=t) for t in chunk]\n        chunk = ssplit_pipe(chunk)\n        sentences = [s for d in chunk for s in d.sentences]\n        raw_sentences += len(sentences)\n        sentences = [s for s in sentences if len(s.words) < max_len]\n        sentences = [s for s in sentences if max(len(w.text) for w in s.words) < max_word_len]\n        filtered_sentences += len(sentences)\n        new_docs.extend([s.text for s in sentences])\n\n    logger.info(\"Split sentences: %d\", raw_sentences)\n    logger.info(\"Sentences filtered for length: %d\", filtered_sentences)\n    return new_docs\n\n# from https://stackoverflow.com/questions/2718196/find-all-chinese-text-in-a-string-using-python-and-regex\nZH_RE = re.compile(u'[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]', re.UNICODE)\n# https://stackoverflow.com/questions/6787716/regular-expression-for-japanese-characters\nJA_RE = re.compile(u'[一-龠ぁ-ゔァ-ヴー々〆〤ヶ]', re.UNICODE)\nDEV_RE = re.compile(u'[\\u0900-\\u097f]', re.UNICODE)\n\ndef tokenize_docs(docs, pipe, min_len, max_len):\n    \"\"\"\n    Turn the text in docs into a list of whitespace separated sentences\n\n    docs: a list of strings\n    pipe: a Stanza pipeline for tokenizing\n    min_len, max_len: can be None to not filter by this attribute\n    \"\"\"\n    results = []\n    docs = [stanza.Document([], text=t) for t in docs]\n    if len(docs) == 0:\n        return results\n    pipe(docs)\n    is_zh = pipe.lang and pipe.lang.startswith(\"zh\")\n    is_ja = pipe.lang and pipe.lang.startswith(\"ja\")\n    is_vi = pipe.lang and pipe.lang.startswith(\"vi\")\n    for doc in docs:\n 
       for sentence in doc.sentences:\n            if min_len and len(sentence.words) < min_len:\n                continue\n            if max_len and len(sentence.words) > max_len:\n                continue\n            text = sentence.text\n            if (text.find(\"|\") >= 0 or text.find(\"_\") >= 0 or\n                text.find(\"<\") >= 0 or text.find(\">\") >= 0 or\n                text.find(\"[\") >= 0 or text.find(\"]\") >= 0 or\n                text.find('—') >= 0):   # an em dash, seems to be part of lists\n                continue\n            # the VI tokenizer in particular doesn't split these well\n            if any(any(w.text.find(c) >= 0 and len(w.text) > 1 for w in sentence.words)\n                   for c in '\"()'):\n                continue\n            text = [w.text.replace(\" \", \"_\") for w in sentence.words]\n            text = \" \".join(text)\n            if any(len(w.text) >= 50 for w in sentence.words):\n                # skip sentences where some of the words are unreasonably long\n                # could make this an argument\n                continue\n            if not is_zh and len(ZH_RE.findall(text)) > 250:\n                # some Chinese sentences show up in VI Wikipedia\n                # we want to eliminate ones which will choke the bert models\n                continue\n            if not is_ja and len(JA_RE.findall(text)) > 150:\n                # some Japanese sentences also show up in VI Wikipedia\n                # we want to eliminate ones which will choke the bert models\n                continue\n            if is_vi and len(DEV_RE.findall(text)) > 100:\n                # would need some list of languages that use\n                # Devanagari to eliminate sentences from all datasets.\n                # Otherwise we might accidentally throw away all the\n                # text from a language we need (although that would be obvious)\n                continue\n            results.append(text)\n    return 
results\n\ndef find_matching_trees(docs, num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=True, chunk_size=10, max_len=140, min_len=10, output_ptb=False):\n    \"\"\"\n    Find trees where all the parsers in parser_pipes agree\n\n    docs should be a list of strings.\n      one sentence per string or a whole block of text as long as the tag_pipe can break it into sentences\n\n    num_sentences > 0 gives an upper limit on how many sentences to extract.\n      If < 0, all possible sentences are extracted\n\n    accepted_trees is a running tally of all the trees already built,\n      so that we don't reuse the same sentence if we see it again\n    \"\"\"\n    if num_sentences < 0:\n        tqdm_total = len(docs)\n    else:\n        tqdm_total = num_sentences\n\n    if output_ptb:\n        output_format = \"{}\"\n    else:\n        output_format = \"{:L}\"\n\n    with tqdm(total=tqdm_total, leave=False) as pbar:\n        if shuffle:\n            random.shuffle(docs)\n        new_trees = set()\n        for chunk_start in range(0, len(docs), chunk_size):\n            chunk = docs[chunk_start:chunk_start+chunk_size]\n            chunk = [stanza.Document([], text=t) for t in chunk]\n\n            if num_sentences < 0:\n                pbar.update(len(chunk))\n\n            # first, retag the sentences\n            tag_pipe(chunk)\n\n            chunk = [d for d in chunk if len(d.sentences) > 0]\n            if max_len is not None:\n                # for now, we don't have a good way to deal with sentences longer than the bert maxlen\n                chunk = [d for d in chunk if max(len(s.words) for s in d.sentences) < max_len]\n            if len(chunk) == 0:\n                continue\n\n            parses = []\n            try:\n                for pipe in parser_pipes:\n                    pipe(chunk)\n                    trees = [output_format.format(sent.constituency) for doc in chunk for sent in doc.sentences if len(sent.words) >= min_len]\n              
      parses.append(trees)\n            except TextTooLongError as e:\n                # easiest is to skip this chunk - could theoretically save the other sentences\n                continue\n\n            for tree in zip(*parses):\n                if len(set(tree)) != 1:\n                    continue\n                tree = tree[0]\n                if tree in accepted_trees:\n                    continue\n                if tree not in new_trees:\n                    new_trees.add(tree)\n                    if num_sentences >= 0:\n                        pbar.update(1)\n                if num_sentences >= 0 and len(new_trees) >= num_sentences:\n                    return new_trees\n\n    return new_trees\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/selftrain_it.py",
    "content": "\"\"\"Builds a self-training dataset from an Italian data source and two models\n\nThe idea is that the top down and the inorder parsers should make\nsomewhat different errors, so hopefully the sum of an 86 f1 parser and\nan 85.5 f1 parser will produce some half-decent silver trees which can\nbe used as self-training so that a new model can do better than either.\n\nOne dataset used is PaCCSS, which has 63000 pairs of sentences:\n\nhttp://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/\n\nPaCCSS-IT: A Parallel Corpus of Complex-Simple Sentences for Automatic Text Simplification\n  Brunato, Dominique et al, 2016\n  https://aclanthology.org/D16-1034\n\nEven larger is the IT section of Europarl, which has 1900000 lines\n\nhttps://www.statmt.org/europarl/\n\nEuroparl: A Parallel Corpus for Statistical Machine Translation\n  Philipp Koehn\n  https://homepages.inf.ed.ac.uk/pkoehn/publications/europarl-mtsummit05.pdf\n\"\"\"\n\nimport argparse\nimport logging\nimport os\nimport random\n\nimport stanza\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.utils.datasets.constituency import selftrain\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\nlogger = logging.getLogger('stanza')\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts a public IT dataset to silver standard trees\"\n    )\n    selftrain.common_args(parser)\n    parser.add_argument(\n        '--input_dir',\n        default='extern_data/italian',\n        help='Path to the PaCCSS corpus and europarl corpus'\n    )\n\n    parser.add_argument(\n        '--no_europarl',\n        default=True,\n        action='store_false',\n        dest='europarl',\n        help='Use the europarl dataset.  
Turning this off makes the script a lot faster'\n    )\n\n    parser.set_defaults(lang=\"it\")\n    parser.set_defaults(package=\"vit\")\n    parser.set_defaults(models=\"saved_models/constituency/it_best/it_vit_inorder_best.pt,saved_models/constituency/it_best/it_vit_topdown.pt\")\n    parser.set_defaults(output_file=\"data/constituency/it_silver.mrg\")\n\n    args = parser.parse_args()\n    return args\n\ndef get_paccss(input_dir):\n    \"\"\"\n    Read the paccss dataset, which is two sentences per line\n    \"\"\"\n    input_file = os.path.join(input_dir, \"PaCCSS/data-set/PACCSS-IT.txt\")\n    with open(input_file) as fin:\n        # the first line is a header line\n        lines = fin.readlines()[1:]\n    lines = [x.strip() for x in lines]\n    lines = [x.split(\"\\t\")[:2] for x in lines if x]\n    text = [y for x in lines for y in x]\n    logger.info(\"Read %d sentences from %s\", len(text), input_file)\n    return text\n\ndef get_europarl(input_dir, ssplit_pipe):\n    \"\"\"\n    Read the Europarl dataset\n\n    This dataset needs to be tokenized and split into lines\n    \"\"\"\n    input_file = os.path.join(input_dir, \"europarl/europarl-v7.it-en.it\")\n    with open(input_file) as fin:\n        # the first line is a header line\n        lines = fin.readlines()[1:]\n    lines = [x.strip() for x in lines]\n    lines = [x for x in lines if x]\n    logger.info(\"Read %d docs from %s\", len(lines), input_file)\n    lines = selftrain.split_docs(lines, ssplit_pipe)\n    return lines\n\ndef main():\n    \"\"\"\n    Combine the two datasets, parse them, and write out the results\n    \"\"\"\n    args = parse_args()\n\n    foundation_cache = FoundationCache()\n    ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang)\n    tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang, foundation_cache=foundation_cache)\n    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, package=args.package, 
foundation_cache=foundation_cache)\n\n    docs = get_paccss(args.input_dir)\n    if args.europarl:\n        docs.extend(get_europarl(args.input_dir, ssplit_pipe))\n\n    logger.info(\"Processing %d docs\", len(docs))\n    new_trees = selftrain.find_matching_trees(docs, args.num_sentences, set(), tag_pipe, parser_pipes, shuffle=False, chunk_size=100, output_ptb=args.output_ptb)\n    logger.info(\"Found %d unique trees which are the same between models\", len(new_trees))\n    with open(args.output_file, \"w\") as fout:\n        for tree in sorted(new_trees):\n            fout.write(tree)\n            fout.write(\"\\n\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/selftrain_single_file.py",
    "content": "\"\"\"\nBuilds a self-training dataset from a single file.\n\nDefault is to assume one document of text per line.  If a line has\nmultiple sentences, they will be split using the stanza tokenizer.\n\"\"\"\n\nimport argparse\nimport io\nimport logging\nimport os\n\nimport numpy as np\n\nimport stanza\nfrom stanza.utils.datasets.constituency import selftrain\nfrom stanza.utils.get_tqdm import get_tqdm\n\nlogger = logging.getLogger('stanza')\ntqdm = get_tqdm()\n\ndef parse_args():\n    \"\"\"\n    Only specific argument for this script is the file to process\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts a single file of text to silver standard trees\"\n    )\n    selftrain.common_args(parser)\n    parser.add_argument(\n        '--input_file',\n        default=\"vi_part_1.aa\",\n        help='Path to the file to read'\n    )\n\n    args = parser.parse_args()\n    return args\n\n\ndef read_file(input_file):\n    \"\"\"\n    Read lines from an input file\n\n    Takes care to avoid encoding errors at the end of Oscar files.\n    The Oscar splits sometimes break a utf-8 character in half.\n    \"\"\"\n    with open(input_file, \"rb\") as fin:\n        text = fin.read()\n    text = text.decode(\"utf-8\", errors=\"replace\")\n    with io.StringIO(text) as fin:\n        lines = fin.readlines()\n    return lines\n\n\ndef main():\n    args = parse_args()\n\n    # TODO: make ssplit an argument\n    ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang)\n    tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang)\n    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models)\n\n    # create a blank file.  
we will append to this file so that partial results can be used\n    with open(args.output_file, \"w\") as fout:\n        pass\n\n    docs = read_file(args.input_file)\n    logger.info(\"Read %d lines from %s\", len(docs), args.input_file)\n    docs = selftrain.split_docs(docs, ssplit_pipe)\n\n    # breaking into chunks lets us output partial results and see the\n    # progress in log files\n    accepted_trees = set()\n    if len(docs) > 10000:\n        chunks = tqdm(np.array_split(docs, 100), disable=False)\n    else:\n        chunks = [docs]\n    for chunk in chunks:\n        new_trees = selftrain.find_matching_trees(chunk, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100)\n        accepted_trees.update(new_trees)\n\n        with open(args.output_file, \"a\") as fout:\n            for tree in sorted(new_trees):\n                fout.write(tree)\n                fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/selftrain_vi_quad.py",
    "content": "\"\"\"\nProcesses the train section of VI QuAD into trees suitable for use in the conparser lm\n\"\"\"\n\nimport argparse\nimport json\nimport logging\n\nimport stanza\nfrom stanza.utils.datasets.constituency import selftrain\n\nlogger = logging.getLogger('stanza')\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts vi quad to silver standard trees\"\n    )\n    selftrain.common_args(parser)\n    selftrain.add_length_args(parser)\n    parser.add_argument(\n        '--input_file',\n        default=\"extern_data/vietnamese/ViQuAD/train_ViQuAD.json\",\n        help='Path to the ViQuAD train file'\n    )\n    parser.add_argument(\n        '--tokenize_only',\n        default=False,\n        action='store_true',\n        help='Tokenize instead of writing trees'\n    )\n\n    args = parser.parse_args()\n    return args\n\ndef parse_quad(text):\n    \"\"\"\n    Read in a file from the VI quad dataset\n\n    The train file has a specific format:\n    the doc has a 'data' section\n    each block in the data is a separate document (138 in the train file, for example)\n    each block has a 'paragraphs' section\n    each paragraph has 'qas' and 'context'.  
we care about the qas\n    each piece of qas has 'question', which is what we actually want\n    \"\"\"\n    doc = json.loads(text)\n\n    questions = []\n\n    for block in doc['data']:\n        paragraphs = block['paragraphs']\n        for paragraph in paragraphs:\n            qas = paragraph['qas']\n            for question in qas:\n                questions.append(question['question'])\n\n    return questions\n\n\ndef read_quad(train_file):\n    with open(train_file) as fin:\n        text = fin.read()\n\n    return parse_quad(text)\n\ndef main():\n    \"\"\"\n    Turn the train section of VI quad into a list of trees\n    \"\"\"\n    args = parse_args()\n\n    docs = read_quad(args.input_file)\n    logger.info(\"Read %d lines from %s\", len(docs), args.input_file)\n    if args.tokenize_only:\n        pipe = stanza.Pipeline(args.lang, processors=\"tokenize\")\n        text = selftrain.tokenize_docs(docs, pipe, args.min_len, args.max_len)\n        with open(args.output_file, \"w\", encoding=\"utf-8\") as fout:\n            for line in text:\n                fout.write(line)\n                fout.write(\"\\n\")\n    else:\n        tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang)\n        parser_pipes = selftrain.build_parser_pipes(args.lang, args.models)\n\n        # create a blank file.  we will append to this file so that partial results can be used\n        with open(args.output_file, \"w\") as fout:\n            pass\n\n        accepted_trees = set()\n        new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100)\n        new_trees = [tree for tree in new_trees if tree.find(\"(_SQ\") >= 0]\n        with open(args.output_file, \"a\") as fout:\n            for tree in sorted(new_trees):\n                fout.write(tree)\n                fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/selftrain_wiki.py",
    "content": "\"\"\"Builds a self-training dataset from a Wikipedia dump and multiple parser models\n\nThe idea is that the top down and the inorder parsers should make\nsomewhat different errors, so hopefully the trees on which the\nparsers agree will be some half-decent silver trees which can\nbe used as self-training so that a new model can do better than any of them.\n\nThe input is a Wikipedia dump which has been processed by wikiextractor:\n\nhttps://github.com/attardi/wikiextractor\n\"\"\"\n\nimport argparse\nfrom collections import deque\nimport glob\nimport os\nimport random\n\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.utils.datasets.constituency import selftrain\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts part of a wikipedia dump to silver standard trees\"\n    )\n    selftrain.common_args(parser)\n    parser.add_argument(\n        '--input_dir',\n        default='extern_data/vietnamese/wikipedia/text',\n        help='Path to the wikipedia dump after processing by wikiextractor'\n    )\n    parser.add_argument(\n        '--no_shuffle',\n        dest='shuffle',\n        action='store_false',\n        help=\"Don't shuffle files when processing the directory\"\n    )\n\n    parser.set_defaults(num_sentences=10000)\n\n    args = parser.parse_args()\n    return args\n\ndef list_wikipedia_files(input_dir):\n    \"\"\"\n    Get a list of wiki files under the input_dir\n\n    Recursively traverse the directory, then sort\n    \"\"\"\n    if not os.path.isdir(input_dir) and os.path.split(input_dir)[1].startswith(\"wiki_\"):\n        return [input_dir]\n\n    wiki_files = []\n\n    recursive_files = deque()\n    recursive_files.extend(glob.glob(os.path.join(input_dir, \"*\")))\n    while len(recursive_files) > 0:\n        next_file = 
recursive_files.pop()\n        if os.path.isdir(next_file):\n            recursive_files.extend(glob.glob(os.path.join(next_file, \"*\")))\n        elif os.path.split(next_file)[1].startswith(\"wiki_\"):\n            wiki_files.append(next_file)\n\n    wiki_files.sort()\n    return wiki_files\n\ndef read_wiki_file(filename):\n    \"\"\"\n    Read the text from a wiki file as a list of paragraphs.\n\n    Each <doc> </doc> is its own item in the list.\n    Lines are separated by \\n\\n to give hints to the stanza tokenizer.\n    The first line after <doc> is skipped as it is usually the document title.\n    \"\"\"\n    with open(filename) as fin:\n        lines = fin.readlines()\n    docs = []\n    current_doc = []\n    line_iterator = iter(lines)\n    line = next(line_iterator, None)\n    while line is not None:\n        if line.startswith(\"<doc\"):\n            # skip the next line, as it is usually the title\n            line = next(line_iterator, None)\n        elif line.startswith(\"</doc\"):\n            if current_doc:\n                if len(current_doc) > 2:\n                    # a lot of very short documents are links to related documents\n                    # a single wikipedia can have tens of thousands of useless almost-duplicates\n                    docs.append(\"\\n\\n\".join(current_doc))\n                current_doc = []\n        else:\n            # not the start or end of a doc\n            # hopefully this is valid text\n            line = line.replace(\"()\", \" \")\n            line = line.replace(\"( )\", \" \")\n            line = line.strip()\n            if line.find(\"&lt;\") >= 0 or line.find(\"&gt;\") >= 0:\n                line = \"\"\n            if line:\n                current_doc.append(line)\n        line = next(line_iterator, None)\n\n    if current_doc:\n        docs.append(\"\\n\\n\".join(current_doc))\n    return docs\n\ndef main():\n    args = parse_args()\n\n    random.seed(1234)\n\n    wiki_files = 
list_wikipedia_files(args.input_dir)\n    if args.shuffle:\n        random.shuffle(wiki_files)\n\n    foundation_cache = FoundationCache()\n    tag_pipe = selftrain.build_tag_pipe(ssplit=True, lang=args.lang, foundation_cache=foundation_cache)\n    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, foundation_cache=foundation_cache)\n\n    # create a blank file.  we will append to this file so that partial results can be used\n    with open(args.output_file, \"w\") as fout:\n        pass\n\n    accepted_trees = set()\n    for filename in tqdm(wiki_files, disable=False):\n        docs = read_wiki_file(filename)\n        new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=args.shuffle)\n        accepted_trees.update(new_trees)\n\n        with open(args.output_file, \"a\") as fout:\n            for tree in sorted(new_trees):\n                fout.write(tree)\n                fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/silver_variance.py",
    "content": "\"\"\"\nUse the concepts in \"Dataset Cartography\" and \"Mind Your Outliers\" to find trees with the least variance over a training run\n\nhttps://arxiv.org/pdf/2009.10795.pdf\nhttps://arxiv.org/abs/2107.02331\n\nThe idea here is that high variance trees are more likely to be wrong in the first place.  Using this will filter a silver dataset to have better trees.\n\nfor example:\n\nnlprun -d a6000 -p high \"export CLASSPATH=/sailhome/horatio/CoreNLP/classes:/sailhome/horatio/CoreNLP/lib/*:$CLASSPATH; python3 stanza/utils/datasets/constituency/silver_variance.py --eval_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg saved_models/constituency/it_vit.top.each.silver0.constituency_0*0.pt --output_file filtered_silver0.mrg\" -o filter.out\n\"\"\"\n\nimport argparse\n\nimport logging\n\nimport numpy\n\nfrom stanza.models.common import utils\nfrom stanza.models.common.foundation_cache import FoundationCache\nfrom stanza.models.constituency import retagging\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.parser_training import run_dev_set\nfrom stanza.models.constituency.trainer import Trainer\nfrom stanza.models.constituency.utils import retag_trees\nfrom stanza.server.parser_eval import EvaluateParser\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nlogger = logging.getLogger('stanza.constituency.trainer')\n\ndef parse_args(args=None):\n    parser = argparse.ArgumentParser(description=\"Script to filter trees by how much variance they show over multiple checkpoints of a parser training run.\")\n\n    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')\n    parser.add_argument('--output_file', type=str, default=None, help='Output file after sorting by variance.')\n\n    parser.add_argument('--charlm_forward_file', type=str, default=None, help=\"Exact path to use for forward charlm\")\n    
parser.add_argument('--charlm_backward_file', type=str, default=None, help=\"Exact path to use for backward charlm\")\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n\n    utils.add_device_args(parser)\n\n    # TODO: use the training scripts to pick the charlm & pretrain if needed\n    parser.add_argument('--lang', default='it', help='Language to use')\n\n    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')\n    parser.add_argument('models', type=str, nargs='+', default=None, help=\"Which model(s) to load\")\n\n    parser.add_argument('--keep', type=float, default=0.5, help=\"Fraction of trees to keep after sorting by variance\")\n    parser.add_argument('--reverse', default=False, action='store_true', help='Actually, keep the high variance trees')\n\n    retagging.add_retag_args(parser)\n\n    args = vars(parser.parse_args(args))\n\n    retagging.postprocess_args(args)\n\n    return args\n\ndef main():\n    args = parse_args()\n    retag_pipeline = retagging.build_retag_pipeline(args)\n    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()\n\n    print(\"Analyzing with the following models:\\n  \" + \"\\n  \".join(args['models']))\n\n    treebank = tree_reader.read_treebank(args['eval_file'])\n    logger.info(\"Read %d trees for analysis\", len(treebank))\n\n    f1_history = []\n    retagged_treebank = None\n\n    chunk_size = 5000\n    with EvaluateParser() as evaluator:\n        for model_filename in args['models']:\n            print(\"Starting processing with %s\" % model_filename)\n            trainer = Trainer.load(model_filename, args=args, foundation_cache=foundation_cache)\n            if retag_pipeline is not None and retagged_treebank is None:\n                retag_method = trainer.model.args['retag_method']\n                retag_xpos = trainer.model.args['retag_xpos']\n       
         logger.info(\"Retagging trees using the %s tags from the %s package...\", retag_method, args['retag_package'])\n                retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos)\n                logger.info(\"Retagging finished\")\n\n            current_history = []\n            for chunk_start in range(0, len(treebank), chunk_size):\n                chunk = treebank[chunk_start:chunk_start+chunk_size]\n                retagged_chunk = retagged_treebank[chunk_start:chunk_start+chunk_size] if retagged_treebank else None\n                f1, kbestF1, treeF1 = run_dev_set(trainer.model, retagged_chunk, chunk, args, evaluator)\n                current_history.extend(treeF1)\n\n            f1_history.append(current_history)\n\n    f1_history = numpy.array(f1_history)\n    f1_variance = numpy.var(f1_history, axis=0)\n    f1_sorted = sorted([(x, idx) for idx, x in enumerate(f1_variance)], reverse=args['reverse'])\n\n    num_keep = int(len(f1_sorted) * args['keep'])\n    with open(args['output_file'], \"w\", encoding=\"utf-8\") as fout:\n        for _, idx in f1_sorted[:num_keep]:\n            fout.write(str(treebank[idx]))\n            fout.write(\"\\n\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/split_holdout.py",
    "content": "\"\"\"\nSplit a constituency dataset randomly into 90/10 splits\n\nTODO: add a function to rotate the pieces of the split so that each\ntraining instance gets seen once\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.utils.datasets.constituency.utils import copy_dev_test\nfrom stanza.utils.default_paths import get_default_paths\n\ndef write_trees(base_path, dataset_name, trees):\n    output_path = os.path.join(base_path, \"%s_train.mrg\" % dataset_name)\n    with open(output_path, \"w\", encoding=\"utf-8\") as fout:\n        for tree in trees:\n            fout.write(\"%s\\n\" % tree)\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Split a standard dataset into 90/10 proportions of train so there is held out training data\")\n    parser.add_argument('--dataset', type=str, default=\"id_icon\", help='dataset to split')\n    parser.add_argument('--base_dataset', type=str, default=None, help='output name for base dataset')\n    parser.add_argument('--holdout_dataset', type=str, default=None, help='output name for holdout dataset')\n    parser.add_argument('--ratio', type=float, default=0.1, help='Number of trees to hold out')\n    parser.add_argument('--seed', type=int, default=1234, help='Random seed')\n    args = parser.parse_args()\n\n    if args.base_dataset is None:\n        args.base_dataset = args.dataset + \"-base\"\n        print(\"--base_dataset not set, using %s\" % args.base_dataset)\n\n    if args.holdout_dataset is None:\n        args.holdout_dataset = args.dataset + \"-holdout\"\n        print(\"--holdout_dataset not set, using %s\" % args.holdout_dataset)\n\n    base_path = get_default_paths()[\"CONSTITUENCY_DATA_DIR\"]\n    copy_dev_test(base_path, args.dataset, args.base_dataset)\n    copy_dev_test(base_path, args.dataset, args.holdout_dataset)\n\n    train_file = os.path.join(base_path, \"%s_train.mrg\" % args.dataset)\n    print(\"Reading 
%s\" % train_file)\n    trees = tree_reader.read_tree_file(train_file)\n\n    base_train = []\n    holdout_train = []\n\n    random.seed(args.seed)\n\n    for tree in trees:\n        if random.random() < args.ratio:\n            holdout_train.append(tree)\n        else:\n            base_train.append(tree)\n\n    write_trees(base_path, args.base_dataset, base_train)\n    write_trees(base_path, args.holdout_dataset, holdout_train)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/split_weighted_ensemble.py",
    "content": "\"\"\"\nRead in a dataset and split the train portion into pieces\n\nOne chunk of the train will be the original dataset.\n\nOthers will be a sampling from the original dataset of the same size,\nbut sampled with replacement, with the goal being to get a random\ndistribution of trees with some reweighting of the original trees.\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.models.constituency.parse_tree import Tree\nfrom stanza.utils.datasets.constituency.utils import copy_dev_test\nfrom stanza.utils.default_paths import get_default_paths\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Split a standard dataset into 1 base section and N-1 random redraws of training data\")\n    parser.add_argument('--dataset', type=str, default=\"id_icon\", help='dataset to split')\n    parser.add_argument('--seed', type=int, default=1234, help='Random seed')\n    parser.add_argument('--num_splits', type=int, default=5, help='Number of splits')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n\n    base_path = get_default_paths()[\"CONSTITUENCY_DATA_DIR\"]\n    train_file = os.path.join(base_path, \"%s_train.mrg\" % args.dataset)\n    print(\"Reading %s\" % train_file)\n    train_trees = tree_reader.read_tree_file(train_file)\n\n    # For datasets with low numbers of certain constituents in the train set,\n    # we could easily find ourselves in a situation where all of the trees\n    # with a specific constituent have been randomly shuffled away from\n    # a random shuffle\n    # An example of this is there are 3 total trees with SQ in id_icon\n    # Therefore, we have to take a little care to guarantee at least one tree\n    # for each constituent type is in a random slice\n    # TODO: this doesn't compensate for transition schemes with compound transitions,\n    # such as in_order_compound.  
could do a similar boosting with one per transition type\n    constituents = sorted(Tree.get_unique_constituent_labels(train_trees))\n    con_to_trees = {con: list() for con in constituents}\n    for tree in train_trees:\n        tree_cons = Tree.get_unique_constituent_labels(tree)\n        for con in tree_cons:\n            con_to_trees[con].append(tree)\n    for con in constituents:\n        print(\"%d trees with %s\" % (len(con_to_trees[con]), con))\n\n    for i in range(args.num_splits):\n        dataset_name = \"%s-random-%d\" % (args.dataset, i)\n\n        copy_dev_test(base_path, args.dataset, dataset_name)\n        if i == 0:\n            train_dataset = train_trees\n        else:\n            train_dataset = []\n            for con in constituents:\n                train_dataset.extend(random.choices(con_to_trees[con], k=2))\n            needed_trees = len(train_trees) - len(train_dataset)\n            if needed_trees > 0:\n                print(\"%d trees already chosen.  Adding %d more\" % (len(train_dataset), needed_trees))\n                train_dataset.extend(random.choices(train_trees, k=needed_trees))\n        output_filename = os.path.join(base_path, \"%s_train.mrg\" % dataset_name)\n        print(\"Writing {} trees to {}\".format(len(train_dataset), output_filename))\n        Tree.write_treebank(train_dataset, output_filename)\n\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/tokenize_wiki.py",
    "content": "\"\"\"\nA short script to use a Stanza tokenizer to extract tokenized sentences from Wikipedia\n\nThe first step is to convert a Wikipedia dataset using Prof. Attardi's wikiextractor:\nhttps://github.com/attardi/wikiextractor\n\nThis script then writes out sentences, one per line, whitespace separated\nSome common issues with the tokenizer are accounted for by discarding those lines.\n\nAlso, to account for languages such as VI where whitespace occurs within words,\nspaces are replaced with _  This should not cause any confusion, as any line with\na natural _ in has already been discarded.\n\nfor i in `echo A B C D E F G H I J K`; do nlprun \"python3 stanza/utils/datasets/constituency/tokenize_wiki.py --output_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.txt --lang it --max_len 120 --input_dir /u/nlp/data/Wikipedia/itwiki/B$i --tokenizer_model saved_models/tokenize/it_combined_tokenizer.pt --download_method None\" -o /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.out; done\n\"\"\"\n\nimport argparse\nimport logging\n\nimport stanza\nfrom stanza.models.common.bert_embedding import load_tokenizer, filter_data\nfrom stanza.utils.datasets.constituency import selftrain_wiki\nfrom stanza.utils.datasets.constituency.selftrain import add_length_args, tokenize_docs\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts part of a wikipedia dump to silver standard trees\"\n    )\n    parser.add_argument(\n        '--output_file',\n        default='vi_wiki_tokenized.txt',\n        help='Where to write the tokenized lines'\n    )\n    parser.add_argument(\n        '--lang',\n        default='vi',\n        help='Which language tools to use for tokenization and POS'\n    )\n\n    input_group = parser.add_mutually_exclusive_group(required=True)\n    
input_group.add_argument(\n        '--input_dir',\n        default=None,\n        help='Path to the wikipedia dump after processing by wikiextractor'\n    )\n    input_group.add_argument(\n        '--input_file',\n        default=None,\n        help='Path to a single file of the wikipedia dump after processing by wikiextractor'\n    )\n    parser.add_argument(\n        '--bert_tokenizer',\n        default=None,\n        help='Which bert tokenizer (if any) to use to filter long sentences'\n    )\n    parser.add_argument(\n        '--tokenizer_model',\n        default=None,\n        help='Use this model instead of the current Stanza tokenizer for this language'\n    )\n    parser.add_argument(\n        '--download_method',\n        default=None,\n        help='Download pipeline models using this method (defaults to downloading updates from HF)'\n    )\n    add_length_args(parser)\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = parse_args()\n    if args.input_dir is not None:\n        files = selftrain_wiki.list_wikipedia_files(args.input_dir)\n    elif args.input_file is not None:\n        files = [args.input_file]\n    else:\n        raise ValueError(\"Need to specify at least one file or directory!\")\n\n    if args.bert_tokenizer:\n        tokenizer = load_tokenizer(args.bert_tokenizer)\n        print(\"Max model length: %d\" % tokenizer.model_max_length)\n    pipeline_args = {}\n    if args.tokenizer_model:\n        pipeline_args[\"tokenize_model_path\"] = args.tokenizer_model\n    if args.download_method:\n        pipeline_args[\"download_method\"] = args.download_method\n    pipe = stanza.Pipeline(args.lang, processors=\"tokenize\", **pipeline_args)\n\n    with open(args.output_file, \"w\", encoding=\"utf-8\") as fout:\n        for filename in tqdm(files):\n            docs = selftrain_wiki.read_wiki_file(filename)\n            text = tokenize_docs(docs, pipe, args.min_len, args.max_len)\n            if args.bert_tokenizer:\n         
       filtered = filter_data(args.bert_tokenizer, [x.split() for x in text], tokenizer, logging.DEBUG)\n                text = [\" \".join(x) for x in filtered]\n            for line in text:\n                fout.write(line)\n                fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/treebank_to_labeled_brackets.py",
    "content": "\"\"\"\nConverts a PTB file to a format where all the brackets have labels on the start and end bracket.\n\nSuch a file should be suitable for training an LM\n\"\"\"\n\nimport argparse\nimport logging\nimport sys\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nlogger = logging.getLogger('stanza.constituency')\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts a PTB treebank into a labeled bracketed file suitable for LM training\"\n    )\n\n    parser.add_argument(\n        'ptb_file',\n        help='Where to get the original PTB format treebank'\n    )\n    parser.add_argument(\n        'label_file',\n        help='Where to write the labeled bracketed file'\n    )\n    parser.add_argument(\n        '--separator',\n        default=\"_\",\n        help='What separator to use in place of spaces',\n    )\n    parser.add_argument(\n        '--no_separator',\n        dest='separator',\n        action='store_const',\n        const=None,\n        help=\"Don't use a separator\"\n    )\n\n    args = parser.parse_args()\n\n    treebank = tree_reader.read_treebank(args.ptb_file)\n    logger.info(\"Writing %d trees to %s\", len(treebank), args.label_file)\n\n    tree_format = \"{:%sL}\\n\" % args.separator if args.separator else \"{:L}\\n\"\n    with open(args.label_file, \"w\", encoding=\"utf-8\") as fout:\n        for tree in tqdm(treebank):\n            fout.write(tree_format.format(tree))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/utils.py",
    "content": "\"\"\"\nUtilities for the processing of constituency treebanks\n\"\"\"\n\nimport os\nimport shutil\n\nfrom stanza.models.constituency import parse_tree\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef copy_dev_test(base_path, input_dataset, output_dataset):\n    shutil.copy2(os.path.join(base_path, \"%s_dev.mrg\" % input_dataset),\n                 os.path.join(base_path, \"%s_dev.mrg\" % output_dataset))\n    shutil.copy2(os.path.join(base_path, \"%s_test.mrg\" % input_dataset),\n                 os.path.join(base_path, \"%s_test.mrg\" % output_dataset))\n\ndef write_dataset(datasets, output_dir, dataset_name):\n    for dataset, shard in zip(datasets, SHARDS):\n        output_filename = os.path.join(output_dir, \"%s_%s.mrg\" % (dataset_name, shard))\n        print(\"Writing {} trees to {}\".format(len(dataset), output_filename))\n        parse_tree.Tree.write_treebank(dataset, output_filename)\n\ndef split_treebank(treebank, train_size, dev_size):\n    \"\"\"\n    Split a treebank deterministically\n    \"\"\"\n    train_end = int(len(treebank) * train_size)\n    dev_end = int(len(treebank) * (train_size + dev_size))\n    return treebank[:train_end], treebank[train_end:dev_end], treebank[dev_end:]\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/vtb_convert.py",
    "content": "\"\"\"\nScript for processing the VTB files and turning their trees into the desired tree syntax\n\nThe VTB original trees are stored in the directory:\nVietTreebank_VLSP_SP73/Kho ngu lieu 10000 cay cu phap\nThe script requires two arguments:\n1. Original directory storing the original trees\n2. New directory storing the converted trees\n\"\"\"\n\nimport argparse\nimport os\n\nfrom collections import defaultdict\n\nfrom stanza.models.constituency.tree_reader import read_trees, MixedTreeError, UnlabeledTreeError\n\nREMAPPING = {\n    '(ADV-MDP': '(RP-MDP',\n    '(MPD':     '(MDP',\n    '(MP ':     '(NP ',\n    '(MP(':     '(NP(',\n    '(Np(':     '(NP(',\n    '(Np (':    '(NP (',\n    '(NLOC':    '(NP-LOC',\n    '(N-P-LOC': '(NP-LOC',\n    '(N-p-loc': '(NP-LOC',\n    '(NPDOB':   '(NP-DOB',\n    '(NPSUB':   '(NP-SUB',\n    '(NPTMP':   '(NP-TMP',\n    '(PPLOC':   '(PP-LOC',\n    '(SBA ':    '(SBAR ',\n    '(SBA-':    '(SBAR-',\n    '(SBA(':    '(SBAR(',\n    '(SBAS':    '(SBAR',\n    '(SABR':    '(SBAR',\n    '(SE-SPL':  '(S-SPL',\n    '(SBARR':   '(SBAR',\n    'PPADV':    'PP-ADV',\n    '(PR (':    '(PP (',\n    '(PPP':     '(PP',\n    'VP0ADV':   'VP-ADV',\n    '(S1':      '(S',\n    '(S2':      '(S',\n    '(S3':      '(S',\n    'BP-SUB':   'NP-SUB',\n    'APPPD':    'AP-PPD',\n    'APPRD':    'AP-PPD',\n    'Np--H':    'Np-H',\n    '(WPNP':    '(WHNP',\n    '(WHRPP':   '(WHRP',\n    # the one mistagged PV is on a prepositional phrase\n    # (the subtree there maybe needs an SBAR as well, but who's counting)\n    '(PV':      '(PP',\n    '(Mpd':     '(MDP',\n    # this only occurs on \"bao giờ\", \"when\"\n    # that seems to be WHNP when under an SBAR, but WHRP otherwise\n    '(Whadv ':  '(WHRP ',\n    # Whpr Occurs in two places: on \"sao\" in a context which is always WHRP,\n    # and on \"nào\", which Vy says is more like a preposition\n    '(Whpr (Pro-h nào))': '(WHPP (Pro-h nào))',\n    '(Whpr (Pro-h Sao))': '(WHRP (Pro-h Sao))',\n    # This is 
very clearly an NP: (Tp-tmp (N-h hiện nay))\n    # which is only ever in NP-TMP contexts\n    '(Tp-tmp':  '(NP-TMP',\n    # This occurs once, in the context of (Yp (SYM @))\n    # The other times (SYM @) shows up, it's always NP\n    '(Yp':      '(NP',\n}\n\ndef unify_label(tree):\n    for old, new in REMAPPING.items():\n        tree = tree.replace(old, new)\n\n    return tree\n\n\ndef count_paren_parity(tree):\n    \"\"\"\n    Counts the net number of unclosed parentheses in the tree\n    :param tree: tree as a string\n    :return: 0 if the parens are balanced, positive if there are unclosed '(', negative if there are extra ')'\n    \"\"\"\n    count = 0\n    for char in tree:\n        if char == '(':\n            count += 1\n        elif char == ')':\n            count -= 1\n    return count\n\n\ndef is_valid_line(line):\n    \"\"\"\n    Check if a line being read is a valid constituent\n\n    The idea is that some \"trees\" are just a long list of words with\n    no tree structure and need to be eliminated.\n\n    :param line: constituent being read\n    :return: True if the line starts with '(' OR ends with ')'.\n    \"\"\"\n    if line.startswith('(') or line.endswith(')'):\n        return True\n\n    return False\n\n# not clear if TP is supposed to be NP or PP - needs a native speaker to decode\nWEIRD_LABELS = sorted(set([\"WP\", \"YP\", \"SNP\", \"STC\", \"UPC\", \"(TP\", \"Xp\", \"XP\", \"WHVP\", \"WHPR\", \"NO\", \"WHADV\", \"(SC (\", \"(VOC (\", \"(Adv (\", \"(SP (\", \"ADV-MDP\", \"(SPL\", \"(ADV (\", \"(V-MWE (\"] + list(REMAPPING.keys())))\n# the 2023 dataset has TP and WHADV as actual labels\n# furthermore, trees with NO were cleaned up and one of the test trees has NORD as a word\nWEIRD_LABELS_2023 = sorted(set([\"WP\", \"YP\", \"SNP\", \"STC\", \"UPC\", \"Xp\", \"XP\", \"WHVP\", \"WHPR\", \"(SC (\", \"(VOC (\", \"(Adv (\", \"(SP (\", \"ADV-MDP\", \"(SPL\", \"(ADV (\", \"(V-MWE (\"] + list(REMAPPING.keys())))\n\ndef convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False, updated_tagset=False, write_ids=False):\n    \"\"\"\n 
   :param orig_file: original directory storing original trees\n    :param new_file: new directory storing formatted constituency trees\n    This function writes new trees to the corresponding files in new_file\n    \"\"\"\n    if updated_tagset:\n        weird_labels = WEIRD_LABELS_2023\n    else:\n        weird_labels = WEIRD_LABELS\n    errors = defaultdict(list)\n    with open(orig_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer:\n        content = reader.readlines()\n        # Tree string will only be written if the currently read\n        # tree is a valid tree. It will not be written if it\n        # does not have a '(' that signifies the presence of constituents\n        tree = \"\"\n        tree_id = None\n        reading_tree = False\n        for line_idx, line in enumerate(content):\n            line = ' '.join(line.split())\n            if line == '':\n                continue\n            elif line == '<s>' or line.startswith(\"<s id=\"):\n                tree = \"\"\n                tree += '(ROOT '\n                reading_tree = True\n                if line.startswith(\"<s id=\"):\n                    tree_id = line.split(\"=\")[1]\n                    assert tree_id.endswith(\">\")\n                    tree_id = int(tree_id[:-1])\n            elif line == '</s>' and reading_tree:\n                # one tree in 25432.prd is not valid because\n                # it is just a bunch of blank lines\n                if tree.strip() == '(ROOT':\n                    errors[\"empty\"].append(\"Empty tree in {} line {}\".format(orig_file, line_idx))\n                    continue\n                tree += ')\\n'\n                parity = count_paren_parity(tree)\n                if parity > 0:\n                    errors[\"unclosed\"].append(\"Unclosed tree from {} line {}: |{}|\".format(orig_file, line_idx, tree))\n                    continue\n                if parity < 0:\n                    
errors[\"extra_parens\"].append(\"Extra parens at end of tree from {} line {} for having extra parens: {}\".format(orig_file, line_idx, tree))\n                    continue\n                if convert_brackets:\n                    tree = tree.replace(\"RBKT\", \"-RRB-\").replace(\"LBKT\", \"-LRB-\")\n                try:\n                    # test that the tree can be read in properly\n                    processed_trees = read_trees(tree)\n                    if len(processed_trees) > 1:\n                        errors[\"multiple\"].append(\"Multiple trees in one xml annotation from {} line {}\".format(orig_file, line_idx))\n                        continue\n                    if len(processed_trees) == 0:\n                        errors[\"empty\"].append(\"Empty tree in {} line {}\".format(orig_file, line_idx))\n                        continue\n                    if not processed_trees[0].all_leaves_are_preterminals():\n                        errors[\"untagged_leaf\"].append(\"Tree with non-preterminal leaves in {} line {}: {}\".format(orig_file, line_idx, tree))\n                        continue\n                    # Unify the labels\n                    if fix_errors:\n                        tree = unify_label(tree)\n\n                    # TODO: this block eliminates 3 trees from VLSP-22\n                    # maybe those trees can be salvaged?\n                    bad_label = False\n                    for weird_label in weird_labels:\n                        if tree.find(weird_label) >= 0:\n                            bad_label = True\n                            errors[weird_label].append(\"Weird label {} from {} line {}: {}\".format(weird_label, orig_file, line_idx, tree))\n                            break\n                    if bad_label:\n                        continue\n\n                    if write_ids:\n                        if tree_id is None:\n                            errors[\"missing_id\"].append(\"Missing ID from {} at line 
{}\".format(orig_file, line_idx))\n                            writer.write(\"<s>\\n\")\n                        else:\n                            writer.write(\"<s id=%d>\\n\" % tree_id)\n                    writer.write(tree)\n                    if write_ids:\n                        writer.write(\"</s>\\n\")\n                    reading_tree = False\n                    tree = \"\"\n                    tree_id = None\n                except MixedTreeError:\n                    errors[\"mixed\"].append(\"Mixed leaves and constituents from {} line {}: {}\".format(orig_file, line_idx, tree))\n                except UnlabeledTreeError:\n                    errors[\"unlabeled\"].append(\"Unlabeled nodes in tree from {} line {}: {}\".format(orig_file, line_idx, tree))\n            else:  # content line\n                if is_valid_line(line) and reading_tree:\n                    tree += line\n                elif reading_tree:\n                    errors[\"invalid\"].append(\"Invalid tree error in {} line {}: |{}|, rejected because of line |{}|\".format(orig_file, line_idx, tree, line))\n                    reading_tree = False\n\n    return errors\n\ndef convert_files(file_list, new_dir, verbose=False, fix_errors=True, convert_brackets=False, updated_tagset=False, write_ids=False):\n    errors = defaultdict(list)\n    for filename in file_list:\n        base_name, _ = os.path.splitext(os.path.split(filename)[-1])\n        new_path = os.path.join(new_dir, base_name)\n        new_file_path = f'{new_path}.mrg'\n        # Convert the tree and write to new_file_path\n        new_errors = convert_file(filename, new_file_path, fix_errors, convert_brackets, updated_tagset, write_ids)\n        for e in new_errors:\n            errors[e].extend(new_errors[e])\n\n    if len(errors.keys()) == 0:\n        print(\"All errors were fixed!\")\n    else:\n        print(\"Found the following errors:\")\n        keys = sorted(errors.keys())\n        if verbose:\n            for e in 
keys:\n                print(\"--------- %10s -------------\" % e)\n                print(\"\\n\\n\".join(errors[e]))\n                print()\n            print()\n        for e in keys:\n            print(\"%s: %d\" % (e, len(errors[e])))\n\ndef convert_dir(orig_dir, new_dir):\n    file_list = os.listdir(orig_dir)\n    # Only convert .prd files, skip the .raw files from VLSP 2009\n    file_list = [os.path.join(orig_dir, f) for f in file_list if os.path.splitext(f)[1] != '.raw']\n    convert_files(file_list, new_dir)\n\ndef main():\n    \"\"\"\n    Converts files from the 2009 version of VLSP to .mrg files\n    \n    Process args, loop through each file in the directory and convert\n    to the desired tree format\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Script that converts a VTB Tree into the desired format\",\n    )\n    parser.add_argument(\n        'orig_dir',\n        help='The location of the original directory storing original trees '\n    )\n    parser.add_argument(\n        'new_dir',\n        help='The location of new directory storing the new formatted trees'\n    )\n    args = parser.parse_args()\n\n    # argparse stores the positional under its declared name, 'orig_dir'\n    org_dir = args.orig_dir\n    new_dir = args.new_dir\n\n    convert_dir(org_dir, new_dir)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/constituency/vtb_split.py",
    "content": "\"\"\"\nFrom a directory of files with VTB Trees, split into train/dev/test set\nwith a split of 70/15/15\n\nThe script requires two arguments\n1. org_dir: the original directory obtainable from running vtb_convert.py\n2. split_dir: the directory where the train/dev/test splits will be stored\n\"\"\"\n\nimport os\nimport argparse\nimport random\n\n\ndef create_shuffle_list(org_dir):\n    \"\"\"\n    This function creates the random order with which we use to loop through the files\n\n    :param org_dir: original directory storing the files that store the trees\n    :return: list of file names randomly shuffled\n    \"\"\"\n    file_names = sorted(os.listdir(org_dir))\n    random.shuffle(file_names)\n\n    return file_names\n\n\ndef create_paths(split_dir, short_name):\n    \"\"\"\n    This function creates the necessary paths for the train/dev/test splits\n\n    :param split_dir: directory that stores the splits\n    :return: train path, dev path, test path\n    \"\"\"\n    if not short_name:\n        short_name = \"\"\n    elif not short_name.endswith(\"_\"):\n        short_name = short_name + \"_\"\n\n    train_path = os.path.join(split_dir, '%strain.mrg' % short_name)\n    dev_path = os.path.join(split_dir, '%sdev.mrg' % short_name)\n    test_path = os.path.join(split_dir, '%stest.mrg' % short_name)\n\n    return train_path, dev_path, test_path\n\n\ndef get_num_samples(org_dir, file_names):\n    \"\"\"\n    Function for obtaining the number of samples\n\n    :param org_dir: original directory storing the tree files\n    :param file_names: list of file names in the directory\n    :return: number of samples\n    \"\"\"\n    count = 0\n    # Loop through the files, which then loop through the trees\n    for filename in file_names:\n        # Skip files that are not .mrg\n        if not filename.endswith('.mrg'):\n            continue\n        # File is .mrg. 
Start processing\n        file_dir = os.path.join(org_dir, filename)\n        with open(file_dir, 'r', encoding='utf-8') as reader:\n            content = reader.readlines()\n            for line in content:\n                count += 1\n\n    return count\n\ndef split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None):\n    os.makedirs(split_dir, exist_ok=True)\n\n    if train_size + dev_size >= 1.0:\n        print(\"Not making a test slice with the given ratios: train {} dev {}\".format(train_size, dev_size))\n\n    # Create a random shuffle list of the file names in the original directory\n    file_names = create_shuffle_list(org_dir)\n\n    # Create train_path, dev_path, test_path\n    train_path, dev_path, test_path = create_paths(split_dir, short_name)\n\n    # Set up the number of samples for each train/dev/test set\n    # TODO: if we ever wanted to split files with <s> </s> in them,\n    # this particular code would need some updating to pay attention to the ids\n    num_samples = get_num_samples(org_dir, file_names)\n    print(\"Found {} total lines in {}\".format(num_samples, org_dir))\n\n    stop_train = int(num_samples * train_size)\n    if train_size + dev_size >= 1.0:\n        stop_dev = num_samples\n        output_limits = (stop_train, stop_dev)\n        output_names = (train_path, dev_path)\n        print(\"Splitting {} train, {} dev\".format(stop_train, stop_dev - stop_train))\n    elif train_size + dev_size > 0.0:\n        stop_dev = int(num_samples * (train_size + dev_size))\n        output_limits = (stop_train, stop_dev, num_samples)\n        output_names = (train_path, dev_path, test_path)\n        print(\"Splitting {} train, {} dev, {} test\".format(stop_train, stop_dev - stop_train, num_samples - stop_dev))\n    else:\n        stop_dev = 0\n        output_limits = (num_samples,)\n        output_names = (test_path,)\n        print(\"Copying all {} lines to test\".format(num_samples))\n\n    # Count how much 
stuff we've written.\n    # We will switch to the next output file when we've written enough\n    count = 0\n\n    trees = []\n    for filename in file_names:\n        if not filename.endswith('.mrg'):\n            continue\n        with open(os.path.join(org_dir, filename), encoding='utf-8') as reader:\n            new_trees = reader.readlines()\n            new_trees = [x.strip() for x in new_trees]\n            new_trees = [x for x in new_trees if x]\n            trees.extend(new_trees)\n    # rotate the train & dev sections, leave the test section the same\n    if rotation is not None and rotation[0] > 0:\n        rotation_start = len(trees) * rotation[0] // rotation[1]\n        rotation_end = stop_dev\n        # if there are no test trees, trees[rotation_end:] will be empty anyway\n        trees = trees[rotation_start:rotation_end] + trees[:rotation_start] + trees[rotation_end:]\n    tree_iter = iter(trees)\n    for write_path, count_limit in zip(output_names, output_limits):\n        with open(write_path, 'w', encoding='utf-8') as writer:\n            # Pull trees from the iterator and write them to write_path\n            # until the cumulative limit for this split is reached\n            while count < count_limit:\n                next_tree = next(tree_iter, None)\n                if next_tree is None:\n                    raise RuntimeError(\"Ran out of trees before reading all of the expected trees\")\n                # Write the tree to write_path\n                writer.write(next_tree)\n                writer.write(\"\\n\")\n                count += 1\n\ndef main():\n    \"\"\"\n    Main function for the script\n\n    Process args, loop through each tree in each file in the directory\n    and write the trees to the train/dev/test split with a split of\n    70/15/15\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Script that splits a list of files of vtb trees into train/dev/test sets\",\n    )\n    parser.add_argument(\n        'org_dir',\n        help='The location of the original 
directory storing correctly formatted vtb trees '\n    )\n    parser.add_argument(\n        'split_dir',\n        help='The location of new directory storing the train/dev/test set'\n    )\n\n    args = parser.parse_args()\n\n    org_dir = args.org_dir\n    split_dir = args.split_dir\n\n    random.seed(1234)\n\n    split_files(org_dir, split_dir)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/contract_mwt.py",
    "content": "import sys\n\ndef contract_mwt(infile, outfile, ignore_gapping=True):\n    \"\"\"\n    Simplify the gold tokenizer data for use as MWT processor test files\n\n    The simplifications are to remove the expanded MWTs, and in the\n    case of ignore_gapping=True, remove any copy words for the dependencies\n    \"\"\"\n    with open(outfile, 'w', encoding='utf-8') as fout:\n        with open(infile, 'r', encoding='utf-8') as fin:\n            idx = 0\n            mwt_begin = 0\n            mwt_end = -1\n            for line in fin:\n                line = line.strip()\n    \n                if line.startswith('#'):\n                    print(line, file=fout)\n                    continue\n                elif len(line) <= 0:\n                    print(line, file=fout)\n                    idx = 0\n                    mwt_begin = 0\n                    mwt_end = -1\n                    continue\n    \n                line = line.split('\\t')\n\n                # ignore gapping word\n                if ignore_gapping and '.' in line[0]:\n                    continue\n\n                idx += 1\n                if '-' in line[0]:\n                    mwt_begin, mwt_end = [int(x) for x in line[0].split('-')]\n                    print(\"{}\\t{}\\t{}\".format(idx, \"\\t\".join(line[1:-1]), \"MWT=Yes\" if line[-1] == '_' else line[-1] + \"|MWT=Yes\"), file=fout)\n                    idx -= 1\n                elif mwt_begin <= idx <= mwt_end:\n                    continue\n                else:\n                    print(\"{}\\t{}\".format(idx, \"\\t\".join(line[1:])), file=fout)\n\nif __name__ == '__main__':\n    contract_mwt(sys.argv[1], sys.argv[2])\n\n"
  },
  {
    "path": "stanza/utils/datasets/coref/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/coref/balance_languages.py",
    "content": "\"\"\"\nbalance_concat.py\ncreate a test set from a dev set which is language balanced\n\"\"\"\n\nimport json\nfrom collections import defaultdict\n\nfrom random import Random\n\n# fix random seed for reproducability\nR = Random(42)\n\nwith open(\"./corefud_concat_v1_0_langid.train.json\", 'r') as df:\n    raw = json.load(df)\n\n# calculate type of each class; then, we will select the one\n# which has the LOWEST counts as the sample rate\nlang_counts = defaultdict(int)\nfor i in raw:\n    lang_counts[i[\"lang\"]] += 1\n\nmin_lang_count = min(lang_counts.values())\n\n# sample 20% of the smallest amount for test set\n# this will look like an absurdly small number, but\n# remember this is DOCUMENTS not TOKENS or UTTERANCES\n# so its actually decent\n# also its per language\ntest_set_size = int(0.1*min_lang_count)\n\n# sampling input by language\nraw_by_language = defaultdict(list)\nfor i in raw:\n    raw_by_language[i[\"lang\"]].append(i)\nlanguages = list(set(raw_by_language.keys()))\n\ntrain_set = []\ntest_set = []\nfor i in languages:\n    length = list(range(len(raw_by_language[i])))\n    choices = R.sample(length, test_set_size)\n\n    for indx,i in enumerate(raw_by_language[i]):\n        if indx in choices:\n            test_set.append(i)\n        else:\n            train_set.append(i)\n\nwith open(\"./corefud_concat_v1_0_langid-bal.train.json\", 'w') as df:\n    json.dump(train_set, df, indent=2)\n\nwith open(\"./corefud_concat_v1_0_langid-bal.test.json\", 'w') as df:\n    json.dump(test_set, df, indent=2)\n\n\n\n# raw_by_language[\"en\"]\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_hebrew_iahlt.py",
    "content": "\"\"\"Convert the coref annotation of IAHLT to the Stanza coref format\n\nThis dataset is available at\n\nhttps://github.com/IAHLT/coref\n\nDownload it via git clone to $COREF_BASE/hebrew, so for example on the cluster:\n\ncd /u/nlp/data/coref/\nmkdir hebrew\ncd hebrew\ngit clone git@github.com:IAHLT/coref.git\n\nThen run\n\npython3 stanza/utils/datasets/coref/convert_hebrew_iahlt.py\n\nThe scores for models built from the dataset are pretty lousy in\ngeneral, but seem to be in line with the scores obtained by other\npeople working on this data.  For example, the authors said they had a\n52 F1, whereas if we use roberta-xlm, we get 50.\n\"\"\"\n\nimport argparse\nfrom collections import defaultdict, namedtuple\nimport json\nimport os\n\nimport stanza\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\nfrom stanza.utils.datasets.coref.utils import process_document\n\ntqdm = get_tqdm()\n\nCorefDoc = namedtuple(\"CorefDoc\", ['doc_id', 'sentences', 'coref_spans'])\n\n# TODO: binary search for speed?\ndef search_mention_start(doc, mention_start):\n    for sent_idx, sentence in enumerate(doc.sentences):\n        if mention_start < doc.sentences[sent_idx].tokens[-1].end_char:\n            break\n    else:\n        raise ValueError\n    for word_idx, word in enumerate(sentence.words):\n        if word.end_char is None:\n            print(\"Found weirdness on sentence:\\n|%s|\" % sentence.text)\n            print(word.parent)\n            return None, None\n        if mention_start < word.end_char:\n            break\n    else:\n        raise ValueError\n    return sent_idx, word_idx\n\ndef search_mention_end(doc, mention_end):\n    for sent_idx, sentence in enumerate(doc.sentences):\n        if sent_idx + 1 == len(doc.sentences) or mention_end < doc.sentences[sent_idx+1].tokens[0].start_char:\n            break\n    for word_idx, word in enumerate(sentence.words):\n        if word_idx + 1 == 
len(sentence.words) or mention_end < sentence.words[word_idx+1].start_char:\n            break\n    return sent_idx, word_idx\n\ndef extract_doc(tokenizer, lines):\n    # 16, 1, 5 for the train, dev, test sets\n    broken = 0\n    tok_error = 0\n    singletons = 0\n    one_words = 0\n    processed_docs = []\n    for line_idx, line in enumerate(tqdm(lines)):\n        all_clusters = defaultdict(list)\n        doc_id = line['metadata']['doc_id']\n        text = line['text']\n        clusters = line['clusters']\n        doc = tokenizer(text)\n        for cluster_idx, cluster in enumerate(clusters):\n            found_mentions = []\n            for mention_idx, mention in enumerate(cluster['mentions']):\n                mention_start = mention[0]\n                mention_end = mention[1]\n                start_sent, start_word = search_mention_start(doc, mention_start)\n                if start_sent is None or start_word is None:\n                    tok_error += 1\n                    continue\n                end_sent, end_word = search_mention_end(doc, mention_end)\n                assert end_sent >= start_sent\n                if start_sent != end_sent:\n                    broken += 1\n                    continue\n\n                assert end_word >= start_word\n                if end_word == start_word:\n                    one_words += 1\n                found_mentions.append((start_sent, start_word, end_word))\n\n                #if cluster_idx == 0 and line_idx == 0:\n                #    expanded_start = max(0, mention_start - 10)\n                #    expanded_end = min(len(text), mention_end + 10)\n                #    print(\"EXTRACTING MENTION: %d %d\" % (mention[0], mention[1]))\n                #    print(\" context: |%s|\" % text[expanded_start:expanded_end])\n                #    print(\" mention[0]:mention[1]: |%s|\" % text[mention[0]:mention[1]])\n                #    print(\" search text: |%s|\" % text[mention_start:mention_end])\n                # 
   extracted_words = doc.sentences[start_sent].words[start_word:end_word+1]\n                #    extracted_text = \" \".join([x.text for x in extracted_words])\n                #    print(\" extracted words: |%s|\" % extracted_text)\n                #    print(\" endpoints: %d %d\" % (mention_start, mention_end))\n                #    print(\" number of extracted words: %d\" % len(extracted_words))\n                #    print(\" first word endpoints: %d %d\" % (extracted_words[0].start_char, extracted_words[0].end_char))\n                #    print(\" last word endpoints: %d %d\" % (extracted_words[-1].start_char, extracted_words[-1].end_char))\n            if len(found_mentions) == 0:\n                continue\n            elif len(found_mentions) == 1:\n                # the number of singletons, after discarding mentions that\n                # crossed a sentence boundary according to Stanza, is\n                # 5, 0, 1\n                # so clearly the dataset does not intentionally have\n                # (many?) 
singletons in it\n                singletons += 1\n                continue\n            else:\n                all_clusters[cluster_idx] = found_mentions\n        # maybe we need to update the interface - there can be MWT in Hebrew\n        sentences = [[word.text for word in sent.words] for sent in doc.sentences]\n        coref_spans = defaultdict(list)\n        for cluster_idx in all_clusters:\n            for sent_idx, start_word, end_word in all_clusters[cluster_idx]:\n                coref_spans[sent_idx].append((cluster_idx, start_word, end_word))\n        processed_docs.append(CorefDoc(doc_id, sentences, coref_spans))\n    print(\"Found %d broken across two sentences, %d tok errors, %d singleton mentions, %d one_word mentions\" % (broken, tok_error, singletons, one_words))\n    return processed_docs\n\ndef read_doc(tokenizer, filename):\n    with open(filename, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n    lines = [json.loads(line) for line in lines]\n    return extract_doc(tokenizer, lines)\n\ndef write_json_file(output_filename, dataset):\n    with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(dataset, fout, indent=2, ensure_ascii=False)\n\ndef main(args=None):\n    paths = get_default_paths()\n    parser = argparse.ArgumentParser(\n        prog='Convert Hebrew IAHLT data',\n    )\n    parser.add_argument('--output_directory', default=None, type=str, help='Where to output the data (defaults to %s)' % paths['COREF_DATA_DIR'])\n    args = parser.parse_args(args=args)\n    coref_output_path = args.output_directory if args.output_directory else paths['COREF_DATA_DIR']\n    print(\"Will write IAHLT dataset to %s\" % coref_output_path)\n\n    coref_input_path = paths[\"COREF_BASE\"]\n    hebrew_base_path = os.path.join(coref_input_path, \"hebrew\", \"coref\", \"train_val_test\")\n\n    tokenizer = stanza.Pipeline(\"he\", processors=\"tokenize\", package=\"default_accurate\")\n    pipe = stanza.Pipeline(\"he\", 
processors=\"tokenize,pos,lemma,depparse\", package=\"default_accurate\", tokenize_pretokenized=True)\n\n    input_files = [\"coref-5-heb_train.jsonl\", \"coref-5-heb_val.jsonl\", \"coref-5-heb_test.jsonl\"]\n    output_files = [\"he_iahlt.train.json\", \"he_iahlt.dev.json\", \"he_iahlt.test.json\"]\n    for input_filename, output_filename in zip(input_files, output_files):\n        input_filename = os.path.join(hebrew_base_path, input_filename)\n        assert os.path.exists(input_filename)\n        docs = read_doc(tokenizer, input_filename)\n        dataset = [process_document(pipe, doc.doc_id, \"\", doc.sentences, doc.coref_spans, None, lang=\"he\") for doc in tqdm(docs)]\n\n        output_filename = os.path.join(coref_output_path, output_filename)\n        write_json_file(output_filename, dataset)\n\n    return output_files\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_hebrew_mixed.py",
    "content": "\"\"\"\nBuild a dataset mixed with IAHLT Hebrew and UD Coref\n\nWe find that the IAHLT dataset by itself, trained using Stanza 1.11\nwith xlm-roberta-large and a lora finetuning layer, gets 49.7 F1.\nThis is a bit lower than the value the IAHLT group originally had, as\nthey reported 52.  Interestingly, we find that mixing in the 1.3 UD\nCoref improves results, getting 51.7 under the same parameters\n\nThis script runs the IAHLT conversion and the UD Coref conversion,\nthen combines the files into one big training file\n\"\"\"\n\nimport json\nimport os\nimport shutil\nimport tempfile\n\nfrom stanza.utils.datasets.coref import convert_hebrew_iahlt\nfrom stanza.utils.datasets.coref import convert_udcoref\nfrom stanza.utils.default_paths import get_default_paths\n\ndef main():\n    paths = get_default_paths()\n    coref_output_path = paths['COREF_DATA_DIR']\n    with tempfile.TemporaryDirectory() as temp_dir_path:\n        hebrew_filenames = convert_hebrew_iahlt.main([\"--output_directory\", temp_dir_path])\n        udcoref_filenames = convert_udcoref.main([\"--project\", \"gerrom\", \"--output_directory\", temp_dir_path])\n\n        with open(os.path.join(temp_dir_path, hebrew_filenames[0]), encoding=\"utf-8\") as fin:\n            hebrew_train = json.load(fin)\n        udcoref_train_filename = os.path.join(temp_dir_path, udcoref_filenames[0])\n        with open(udcoref_train_filename, encoding=\"utf-8\") as fin:\n            print(\"Reading extra udcoref json data from %s\" % udcoref_train_filename)\n            udcoref_train = json.load(fin)\n        mixed_train = hebrew_train + udcoref_train\n        with open(os.path.join(coref_output_path, \"he_mixed.train.json\"), \"w\", encoding=\"utf-8\") as fout:\n            json.dump(mixed_train, fout, indent=2, ensure_ascii=False)\n\n        shutil.copyfile(os.path.join(temp_dir_path, hebrew_filenames[1]),\n                        os.path.join(coref_output_path, \"he_mixed.dev.json\"))\n        
shutil.copyfile(os.path.join(temp_dir_path, hebrew_filenames[2]),\n                        os.path.join(coref_output_path, \"he_mixed.test.json\"))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_hindi.py",
    "content": "import argparse\nimport json\nfrom operator import itemgetter\nimport os\n\nimport stanza\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\nfrom stanza.utils.datasets.coref.utils import process_document\n\ntqdm = get_tqdm()\n\ndef flatten_spans(coref_spans):\n    \"\"\"\n    Put span IDs on each span, then flatten them into a single list sorted by first word\n    \"\"\"\n    # put span indices on the spans\n    #   [[[38, 39], [42, 43], [41, 41], [180, 180], [300, 300]], [[60, 68],\n    #   -->\n    #   [[[0, 38, 39], [0, 42, 43], [0, 41, 41], [0, 180, 180], [0, 300, 300]], [[1, 60, 68], ...\n    coref_spans = [[[span_idx, x, y] for x, y in span] for span_idx, span in enumerate(coref_spans)]\n    # flatten list\n    #   -->\n    #   [[0, 38, 39], [0, 42, 43], [0, 41, 41], [0, 180, 180], [0, 300, 300], [1, 60, 68], ...\n    coref_spans = [y for x in coref_spans for y in x]\n    # sort by the first word index\n    #   -->\n    #   [[0, 38, 39], [0, 41, 41], [0, 42, 43], [1, 60, 68], [0, 180, 180], [0, 300, 300], ...\n    coref_spans = sorted(coref_spans, key=itemgetter(1))\n    return coref_spans\n\ndef remove_nulls(coref_spans, sentences):\n    \"\"\"\n    Removes the \"\" and \"NULL\" words from the sentences\n\n    Also, reindex the spans by the number of words removed.\n    So, we might get something like\n      [[0, 2], [31, 33], [134, 136], [161, 162]]\n      ->\n      [[0, 2], [30, 32], [129, 131], [155, 156]]\n    \"\"\"\n    word_map = []\n    word_idx = 0\n    map_idx = 0\n    new_sentences = []\n    for sentence in sentences:\n        new_sentence = []\n        for word in sentence:\n            word_map.append(map_idx)\n            word_idx += 1\n            if word != '' and word != 'NULL':\n                new_sentence.append(word)\n                map_idx += 1\n        new_sentences.append(new_sentence)\n\n    new_spans = []\n    for mention in coref_spans:\n        new_mention = 
[]\n        for span in mention:\n            span = [word_map[x] for x in span]\n            new_mention.append(span)\n        new_spans.append(new_mention)\n    return new_spans, new_sentences\n\ndef arrange_spans_by_sentence(coref_spans, sentences):\n    sentence_spans = []\n\n    current_index = 0\n    span_idx = 0\n    for sentence in sentences:\n        current_sentence_spans = []\n        end_index = current_index + len(sentence)\n        while span_idx < len(coref_spans) and coref_spans[span_idx][1] < end_index:\n            new_span = [coref_spans[span_idx][0], coref_spans[span_idx][1] - current_index, coref_spans[span_idx][2] - current_index]\n            current_sentence_spans.append(new_span)\n            span_idx += 1\n        sentence_spans.append(current_sentence_spans)\n        current_index = end_index\n    return sentence_spans\n\ndef convert_dataset_section(pipe, section, use_cconj_heads):\n    \"\"\"\n    Reprocess the original data into a format compatible with previous conversion utilities\n\n    - remove blank and NULL words\n    - rearrange the spans into spans per sentence instead of a list of indices for each span\n    - process the document using a Hindi pipeline\n    \"\"\"\n    processed_section = []\n\n    for idx, doc in enumerate(tqdm(section)):\n        doc_id = doc['doc_key']\n        part_id = \"\"\n        sentences = doc['sentences']\n        sentence_speakers = doc['speakers']\n\n        coref_spans = doc['clusters']\n        coref_spans, sentences = remove_nulls(coref_spans, sentences)\n        coref_spans = flatten_spans(coref_spans)\n        coref_spans = arrange_spans_by_sentence(coref_spans, sentences)\n\n        processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=use_cconj_heads)\n        processed_section.append(processed)\n    return processed_section\n\ndef remove_nulls_dataset_section(section):\n    processed_section = []\n    for doc in section:\n        
sentences = doc['sentences']\n        coref_spans = doc['clusters']\n        coref_spans, sentences = remove_nulls(coref_spans, sentences)\n        doc['sentences'] = sentences\n        doc['clusters'] = coref_spans\n        processed_section.append(doc)\n    return processed_section\n\n\ndef read_json_file(filename):\n    with open(filename, encoding=\"utf-8\") as fin:\n        dataset = []\n        for line in fin:\n            line = line.strip()\n            if not line:\n                continue\n            dataset.append(json.loads(line))\n    return dataset\n\ndef write_json_file(output_filename, converted_section):\n    with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(converted_section, fout, indent=2)\n\ndef main():\n    parser = argparse.ArgumentParser(\n        prog='Convert Hindi Coref Data',\n    )\n    parser.add_argument('--no_use_cconj_heads', dest='use_cconj_heads', action='store_false', help=\"Don't use the conjunction-aware transformation\")\n    parser.add_argument('--remove_nulls', action='store_true', help=\"The only action is to remove the NULLs and blank tokens\")\n    args = parser.parse_args()\n\n    paths = get_default_paths()\n    coref_input_path = paths[\"COREF_BASE\"]\n    hindi_base_path = os.path.join(coref_input_path, \"hindi\", \"dataset\")\n\n    sections = (\"train\", \"dev\", \"test\")\n    if args.remove_nulls:\n        for section in sections:\n            input_filename = os.path.join(hindi_base_path, \"%s.hindi.jsonlines\" % section)\n            dataset = read_json_file(input_filename)\n            dataset = remove_nulls_dataset_section(dataset)\n            output_filename = os.path.join(hindi_base_path, \"hi_deeph.%s.nonulls.json\" % section)\n            with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n                for doc in dataset:\n                    json.dump(doc, fout, ensure_ascii=False)\n                    fout.write(\"\\n\")\n    else:\n        pipe = 
stanza.Pipeline(\"hi\", processors=\"tokenize,pos,lemma,depparse\", package=\"default_accurate\", tokenize_pretokenized=True)\n\n        os.makedirs(paths[\"COREF_DATA_DIR\"], exist_ok=True)\n\n        for section in sections:\n            input_filename = os.path.join(hindi_base_path, \"%s.hindi.jsonlines\" % section)\n            dataset = read_json_file(input_filename)\n\n            output_filename = os.path.join(paths[\"COREF_DATA_DIR\"], \"hi_deeph.%s.json\" % section)\n            converted_section = convert_dataset_section(pipe, dataset, use_cconj_heads=args.use_cconj_heads)\n            write_json_file(output_filename, converted_section)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_ontonotes.py",
    "content": "\"\"\"\nconvert_ontonotes.py\n\nThis script is used to convert the OntoNotes dataset into a format that can be used by Stanza's coreference resolution model. The script uses the datasets package to download the OntoNotes dataset and then processes the dataset using Stanza's coreference resolution pipeline. The processed dataset is then saved in a JSON file.\n\nIf you want to simply process the official OntoNotes dataset...\n1. install the `datasets` package: `pip install datasets`\n2. make folders! (or those adjusted to taste through scripts/config.sh)\n   - extern_data/coref/english/en_ontonotes\n   - data/coref\n2. run this script: python -m stanza.utils.datasets.coref.convert_ontonotes\n\nIf you happen to have singleton annotated coref chains...\n1. install the `datasets` package: `pip install datasets`\n2. make folders! (or those adjusted to taste through scripts/config.sh)\n   - extern_data/coref/english/en_ontonotes\n   - data/coref\n3. get the singletons annotated coref chains in conll format from the Splice repo\n    https://github.com/yilunzhu/splice/raw/refs/heads/main/data/ontonotes5_mentions.zip\n4. place the singleton annotated coref chains in the folder `extern_data/coref/english/en_ontonotes`\n   $ ls ./extern_data/coref/english/en_ontonotes\n        dev_sg_pred.english.v4_gold_conll\n        test_sg_pred.english.v4_gold_conll\n        train_sg.english.v4_gold_conll\n5. 
run this script: python -m stanza.utils.datasets.coref.convert_ontonotes --use_singletons\n\nYour results will appear in ./data/coref/, and you can be off to the races with training!\nNote that this script invokes Stanza itself to run some tagging.\n\"\"\"\n\nimport json\nimport os\n\nfrom pathlib import Path\n\nimport argparse\nimport stanza\n\nfrom stanza.models.constituency import tree_reader\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\nfrom stanza.utils.datasets.coref.utils import process_document\n\nfrom stanza.utils.conll import CoNLL\nfrom collections import defaultdict\n\ntqdm = get_tqdm()\n\ndef read_paragraphs(section):\n    for doc in section:\n        part_id = None\n        paragraph = []\n        for sentence in doc['sentences']:\n            if part_id is None:\n                part_id = sentence['part_id']\n            elif part_id != sentence['part_id']:\n                yield doc['document_id'], part_id, paragraph\n                paragraph = []\n                part_id = sentence['part_id']\n            paragraph.append(sentence)\n        if paragraph != []:\n            yield doc['document_id'], part_id, paragraph\n\n\ndef convert_dataset_section(pipe, section, override_singleton_chains=None):\n    processed_section = []\n    section = list(x for x in read_paragraphs(section))\n\n    # we need to do this because apparently the singleton annotations\n    # don't use the same numbering scheme as the ontonotes annotations\n    # so there will be chain id conflicts\n    max_chain_id = sorted([\n        chain_id \n        for i in section \n        for j in i[2] \n        for chain_id, _, _ in j[\"coref_spans\"]\n    ])[-1]\n    # this dictionary will map singleton chains' \"special\" ids\n    # to the OntoNotes IDs\n    sg_to_ontonotes_cluster_id_map = defaultdict(\n        lambda: len(sg_to_ontonotes_cluster_id_map)+max_chain_id+1\n    )\n\n    for idx, (doc_id, part_id, paragraph) in 
enumerate(tqdm(section)):\n        sentences = [x['words'] for x in paragraph]\n        truly_coref_spans = [x['coref_spans'] for x in paragraph]\n        # the problem to solve here is that the singleton chains'\n        # IDs don't match the coref chains' ids\n        # \n        # and, what the labels call a \"singleton\" may not actually\n        # be one because the \"singleton\" seems like it includes all\n        # NPs which may or may not be a singleton\n        coref_spans = []\n        if override_singleton_chains:\n            singleton_chains = override_singleton_chains[doc_id][part_id]\n            for singleton_pred, coref_pred in zip(singleton_chains, truly_coref_spans):\n                sentence_coref_preds = []\n                # these are sentence level predictions, which we will\n                # disambiguate: if a subspan of \"singleton\" exists in the \n                # truly coref sets, we realise it's not a singleton and\n                # then ignore it\n                coref_pred_locs = set([tuple(i[1:]) for i in coref_pred])\n                for id,start,end in singleton_pred:\n                    if (start,end) not in coref_pred_locs:\n                        # this is truly a singleton\n                        sentence_coref_preds.append([\n                            sg_to_ontonotes_cluster_id_map[id],\n                            start,\n                            end\n                        ])\n                sentence_coref_preds += coref_pred\n                coref_spans.append(sentence_coref_preds)\n        else:\n            coref_spans = truly_coref_spans\n\n        sentence_speakers = [x['speaker'] for x in paragraph]\n\n        processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers)\n        processed_section.append(processed)\n    return processed_section\n\ndef extract_chains_from_chunk(chunk):\n    \"\"\"given a chunk of the gold conll, extract the coref chains\n\n    remember, the 
indices are front and back *inclusive*, zero indexed\n    and a span that takes one word only is annotated [id, n, n] (i.e. we\n    don't fencepost by +1)\n\n    Arguments\n    ---------\n        chunk : List[str]\n            list of strings, each string is a line in the conll file\n\n    Returns\n    -------\n        final_chains : List[Tuple[int, int, int ]]\n            list of chains, each chain is a list of [id, open_location, close_location]\n    \"\"\"\n\n    chains = [sentence.split(\"    \")[-1].strip()\n            for sentence in chunk]\n    chains = [[] if i == '-' else i.split(\"|\")\n            for i in chains]\n\n    opens = defaultdict(list)\n    closes = defaultdict(list)\n\n    for indx, elem in enumerate(chains):\n\n        # for each one, check if it's an open, close, or both \n        for i in elem:\n            id = int(i.strip(\"(\").strip(\")\"))\n            if (i[0]==\"(\"):\n                opens[id].append(indx)\n            if (i[-1]==\")\"):\n                closes[id].append(indx)\n\n    # and now, we chain the ids' opens and closes together\n    # into the shape of [id, open_location, close_location]\n    opens = dict(opens)\n    closes = dict(closes)\n\n    final_chains = []\n    for key, open_indx in opens.items():\n        for o,c in zip(sorted(open_indx), sorted(closes[key])):\n            final_chains.append([key, o,c])\n\n    return final_chains\n\ndef extract_chains_from_conll(gold_coref_conll):\n    \"\"\"extract the coref chains from the gold conll file\n\n    Arguments\n    --------\n        gold_coref_conll : str\n            path to the gold conll file, with coreference chains\n    Returns\n    -------\n        final_chunks : Dict[str, List[List[List[Tuple[int, int, int]]]]]\n            dictionary of document_id to list of paragraphs into\n            list of coref chains in OntoNotes style, keyed by document ID\n    \"\"\"\n    with open(gold_coref_conll, 'r') as df:\n        gold_coref_conll = df.readlines()\n    # 
we want to first separate the document into sentence-level\n    # chunks; we assume that the ordering of the sentences are correct in the\n    # gold document\n    sections = []\n    section = []\n    chunk = []\n    for i in gold_coref_conll:\n        if len(i.split(\"    \")) < 10:\n            if len(chunk) > 0:\n                section.append(chunk)\n            elif i.startswith(\"#end document\"): # this is a new paragraph\n                sections.append(section)\n                section = []\n            chunk = []\n        else:\n            chunk.append(i)\n\n    # finally, we process each chunk and *index them by ID*\n    final_chunks = defaultdict(list)\n    for section in sections:\n        section_chains = []\n        for chunk in section:\n            section_chains.append(extract_chains_from_chunk(chunk))\n        final_chunks[chunk[0].split(\"    \")[0]].append(section_chains)\n    final_chunks = dict(final_chunks)\n\n    return final_chunks\n\nSECTION_NAMES = {\"train\": \"train\",\n                 \"dev\": \"validation\",\n                 \"test\": \"test\"}\nOVERRIDE_CONLL_PATHS = {\"en_ontonotes\": {\n    \"train\": \"train_sg.english.v4_gold_conll\",\n    \"validation\": \"dev_sg_pred.english.v4_gold_conll\",\n    \"test\": \"test_sg_pred.english.v4_gold_conll\"\n}}\n\ndef process_dataset(short_name, ontonotes_path, coref_output_path, use_singletons=False):\n    try:\n        from datasets import load_dataset\n    except ImportError as e:\n        raise ImportError(\"Please install the datasets package to process OntoNotes coref with Stanza\")\n\n    if short_name == 'en_ontonotes':\n        config_name = 'english_v4'\n    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):\n        config_name = 'chinese_v4'\n    elif short_name == 'ar_ontonotes':\n        config_name = 'arabic_v4'\n    else:\n        raise ValueError(\"Unknown short name for downloading ontonotes: %s\" % short_name)\n\n    pipe = stanza.Pipeline(\"en\", 
processors=\"tokenize,pos,lemma,depparse\", package=\"default_accurate\", tokenize_pretokenized=True)\n\n    # if the cache directory doesn't yet exist, we make it\n    # we store the cache in a separate subfolder to distinguish from the\n    # possible Singleton conlls that may be in the folder\n    (Path(ontonotes_path) / \"cache\").mkdir(exist_ok=True)\n\n    dataset = load_dataset(\"conll2012_ontonotesv5\", config_name, cache_dir=str(Path(ontonotes_path) / \"cache\"), trust_remote_code=True)\n    for section, hf_name in SECTION_NAMES.items():\n    # for section, hf_name in [(\"test\", \"test\")]:\n        print(\"Processing %s\" % section)\n        if use_singletons:\n            singletons_path = (Path(ontonotes_path) / OVERRIDE_CONLL_PATHS[short_name][hf_name])\n            if not singletons_path.exists():\n                raise FileNotFoundError(\n                    \"Could not find singleton annotated coref chains \"\n                    \"in conll format\\nensure you have placed them in the folder %s\" % singletons_path\n                )\n            # if, for instance, Amir has given us some singleton annotated coref chains in conll files,\n            # we will use those instead of the ones that OntoNotes has\n            converted_section = convert_dataset_section(pipe, dataset[hf_name], extract_chains_from_conll(\n                str(singletons_path)\n            ))\n        else:\n            converted_section = convert_dataset_section(pipe, dataset[hf_name])\n        output_filename = os.path.join(coref_output_path, \"%s.%s.json\" % (short_name, section))\n        with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n            json.dump(converted_section, fout, indent=2)\n\n\ndef main():\n    parser = argparse.ArgumentParser(prog=\"convert_ontonotes.py\",\n                                     description=\"Convert OntoNotes dataset to Stanza's coreference format\")\n    parser.add_argument(\"--use_singletons\", default=False,\n          
              action=\"store_true\", help=\"Use singleton annotated coref chains\")\n    args = parser.parse_args()\n\n    paths = get_default_paths()\n    coref_input_path = paths['COREF_BASE']\n    ontonotes_path = os.path.join(coref_input_path, \"english\", \"en_ontonotes\")\n    coref_output_path = paths['COREF_DATA_DIR']\n    process_dataset(\"en_ontonotes\", ontonotes_path, coref_output_path, args.use_singletons)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_tamil.py",
    "content": "\"\"\"\nConvert the AU-KBC coreference dataset from Prof. Sobha\n\nhttps://aclanthology.org/2020.wildre-1.4/\n\nLocated in /u/nlp/data/coref/tamil on the Stanford cluster\n\"\"\"\n\nimport argparse\nimport glob\nimport json\nfrom operator import itemgetter\nimport os\nimport random\nimport re\n\nimport stanza\n\nfrom stanza.utils.datasets.coref.utils import process_document\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nbegin_re = re.compile(r\"B-([0-9]+)\")\nin_re =  re.compile(r\"I-([0-9]+)\")\n\ndef write_json_file(output_filename, converted_section):\n    with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(converted_section, fout, indent=2)\n\ndef read_doc(filename):\n    \"\"\"\n    Returns the sentences and the coref markings from this filename\n\n    sentences: a list of list of words\n    corefs: a list of list of clusters, which were tagged B-num and I-num in the dataset\n    \"\"\"\n    with open(filename, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n\n    all_words = []\n    all_coref = []\n    current_words = []\n    current_coref = []\n    for line in lines:\n        line = line.strip()\n        if not line:\n            all_words.append(current_words)\n            all_coref.append(current_coref)\n            current_words = []\n            current_coref = []\n            continue\n        pieces = line.split(\"\\t\")\n        current_words.append(pieces[3])\n        current_coref.append(pieces[-1])\n\n    if current_words:\n        all_words.append(current_words)\n        all_coref.append(current_coref)\n\n    return all_words, all_coref\n\ndef convert_clusters(filename, corefs):\n    sentence_clusters = []\n    # current_clusters will be a list of (cluster id, start idx)\n    for sent_idx, sentence_coref in enumerate(corefs):\n        current_clusters = []\n        processed = []\n        for word_idx, 
word_coref in enumerate(sentence_coref):\n            new_clusters = []\n            if word_coref == '-':\n                pieces = []\n            else:\n                pieces = word_coref.split(\";\")\n            for piece in pieces:\n                if not piece.startswith(\"I-\") and not piece.startswith(\"B-\"):\n                    raise ValueError(\"Unexpected coref format %s in document %s\" % (word_coref, filename))\n                if piece.startswith(\"B-\"):\n                    new_clusters.append((int(piece[2:]), word_idx))\n                else:\n                    assert piece.startswith(\"I-\")\n                    cluster_id = int(piece[2:])\n                    # this will keep the first cluster found\n                    # the effect of this is that when two clusters overlap,\n                    # and they happen to be the same cluster id,\n                    # they will be nested instead of overlapping past each other\n                    for idx, previous_cluster in enumerate(current_clusters):\n                        if previous_cluster[0] == cluster_id:\n                            break\n                    else:\n                        raise ValueError(\"Cluster %s does not continue an existing cluster in %s\" % (piece, filename))\n                    new_clusters.append(previous_cluster)\n                    del current_clusters[idx]\n\n            for cluster, start_idx in current_clusters:\n                processed.append((cluster, start_idx, word_idx-1))\n            current_clusters = new_clusters\n        for cluster, start_idx in current_clusters:\n            processed.append((cluster, start_idx, len(sentence_coref)-1))\n        # sort by the first word index\n        processed = sorted(processed, key=itemgetter(1))\n        # TODO: cluster IDs are starting at 1, not 0.\n        # that may or may not be relevant\n        sentence_clusters.append(processed)\n    return sentence_clusters\n\ndef main():\n    parser = 
argparse.ArgumentParser(\n        prog='Convert Tamil Coref Data',\n    )\n    parser.add_argument('--no_use_cconj_heads', dest='use_cconj_heads', action='store_false', help=\"Don't use the conjunction-aware transformation\")\n    args = parser.parse_args()\n\n    random.seed(1234)\n\n    paths = get_default_paths()\n    coref_input_path = paths[\"COREF_BASE\"]\n    tamil_base_path = os.path.join(coref_input_path, \"tamil\", \"coref_ta_corrected\")\n    tamil_glob = os.path.join(tamil_base_path, \"*txt\")\n\n    filenames = sorted(glob.glob(tamil_glob))\n    docs = [read_doc(x) for x in filenames]\n    raw_sentences = [doc[0] for doc in docs]\n    sentence_clusters = [convert_clusters(filename, doc[1]) for filename, doc in zip(filenames, docs)]\n\n    pipe = stanza.Pipeline(\"ta\", processors=\"tokenize,pos,lemma,depparse\", package=\"default_accurate\", tokenize_pretokenized=True)\n\n    train, dev, test = [], [], []\n    for filename, sentences, coref_spans in tqdm(zip(filenames, raw_sentences, sentence_clusters), total=len(filenames)):\n        doc_id = filename\n        part_id = \" \"\n        sentence_speakers = [[\"\"] * len(sent) for sent in sentences]\n\n        processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=args.use_cconj_heads)\n        location = random.choices((train, dev, test), weights = (0.8, 0.1, 0.1))[0]\n        location.append(processed)\n\n    output_filename = os.path.join(paths[\"COREF_DATA_DIR\"], \"ta_kbc.train.json\")\n    write_json_file(output_filename, train)\n\n    output_filename = os.path.join(paths[\"COREF_DATA_DIR\"], \"ta_kbc.dev.json\")\n    write_json_file(output_filename, dev)\n\n    output_filename = os.path.join(paths[\"COREF_DATA_DIR\"], \"ta_kbc.test.json\")\n    write_json_file(output_filename, test)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_udcoref.py",
    "content": "from collections import defaultdict\nimport json\nimport os\nimport re\nimport glob\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\nfrom stanza.utils.datasets.coref.utils import find_cconj_head\n\nfrom stanza.utils.conll import CoNLL\n\nimport warnings\nfrom random import Random\n\nimport argparse\n\naugment_random = Random(7)\nsplit_random = Random(8)\n\ntqdm = get_tqdm()\nIS_UDCOREF_FORMAT = True\nUDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1\n\ndef process_documents(docs, augment=False):\n    # docs = sections\n    processed_section = []\n\n    for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):\n        # drop the last token 10% of the time\n        if augment:\n            for i in doc.sentences:\n                if len(i.words) > 1:\n                    if augment_random.random() < 0.1:\n                        i.tokens = i.tokens[:-1]\n                        i.words = i.words[:-1]\n\n        # extract the entities\n        # get sentence words and lengths\n        sentences = [[j.text for j in i.all_words]\n                    for i in doc.sentences]\n        sentence_lens = [len(x.all_words) for x in doc.sentences]\n\n        cased_words = [] \n        for x in sentences:\n            if augment:\n                # modify case of the first word with 50% chance\n                if augment_random.random() < 0.5:\n                    x[0] = x[0].lower()\n\n            for y in x:\n                cased_words.append(y)\n\n        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]\n\n        word_total = 0\n        heads = []\n        # TODO: does SD vs UD matter?\n        deprel = []\n        for sentence in doc.sentences:\n            for word in sentence.all_words:\n                deprel.append(word.deprel)\n                if not word.head or word.head == 0:\n                    heads.append(\"null\")\n                else:\n                    
heads.append(word.head - 1 + word_total)\n            word_total += len(sentence.all_words)\n\n        span_clusters = defaultdict(list)\n        word_clusters = defaultdict(list)\n        head2span = []\n        is_zero = []\n        word_total = 0\n        SPANS = re.compile(r\"(\\(\\w+|[%\\w]+\\))\")\n        do_ctn = False # if we broke in the loop\n        for parsed_sentence in doc.sentences:\n            # spans regex\n            # parse the misc column, leaving only \"Entity\" entries\n            misc = [[k.split(\"=\")\n                    for k in j\n                    if k.split(\"=\")[0] == \"Entity\"]\n                    for i in parsed_sentence.all_words\n                    for j in [i.misc.split(\"|\") if i.misc else []]]\n            # and extract the Entity entry values\n            entities = [i[0][1] if len(i) > 0 else None for i in misc]\n            # extract reference information\n            refs = [SPANS.findall(i) if i else [] for i in entities]\n            # and calculate spans: the basic rule is (e... 
begins a reference\n            # and ) without e before ends the most recent reference\n            # every single time we get a closing element, we pop it off\n            # the refdict and insert the pair to final_refs\n            refdict = defaultdict(list)\n            final_refs = defaultdict(list)\n            last_ref = None\n            for indx, i in enumerate(refs):\n                for j in i:\n                    # this is the beginning of a reference\n                    if j[0] == \"(\":\n                        refdict[j[1+UDCOREF_ADDN:]].append(indx)\n                        last_ref = j[1+UDCOREF_ADDN:]\n                    # at the end of a reference, if we got exxxxx, that ends\n                    # a particular reference; otherwise, it ends the last reference\n                    elif j[-1] == \")\" and j[UDCOREF_ADDN:-1].isnumeric():\n                        if (not UDCOREF_ADDN) or j[0] == \"e\":\n                            try:\n                                final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx))\n                            except IndexError:\n                                # this is probably zero anaphora\n                                continue\n                    elif j[-1] == \")\":\n                        final_refs[last_ref].append((refdict[last_ref].pop(-1), indx))\n                        last_ref = None\n            final_refs = dict(final_refs)\n            # convert it to the right format (specifically, in (ref, start, end) tuples)\n            coref_spans = []\n            for k, v in final_refs.items():\n                for i in v:\n                    coref_spans.append([int(k), i[0], i[1]])\n            sentence_upos = [x.upos for x in parsed_sentence.all_words]\n            sentence_heads = [x.head - 1 if x.head and x.head > 0 else None for x in parsed_sentence.all_words]\n            sentence_text = [x.text for x in parsed_sentence.all_words]\n\n            # if \"_\" in 
sentence_text and sentence_text.index(\"_\") in [j for i in coref_spans for j in i]:\n            #     import ipdb\n            #     ipdb.set_trace()\n\n            for span in coref_spans:\n                zero = False\n                if sentence_text[span[1]] == \"_\" and span[1] == span[2]:\n                    is_zero.append([span[0], True])\n                    zero = True\n                    # oo! that's a zero coref, we should merge it forwards\n                    # i.e. we pick the next word as the head!\n                    span = [span[0], span[1]+1, span[2]+1]\n                    # crap! there's two zeros right next to each other\n                    # we are sad and confused so we give up in this case\n                    if len(sentence_text) > span[1] and sentence_text[span[1]] == \"_\":\n                        warnings.warn(\"Found two zeros next to each other in sequence; we are confused and therefore giving up.\")\n                        do_ctn = True\n                        break\n                else:\n                    is_zero.append([span[0], False])\n\n                # input is expected to be start word, end word + 1\n                # counting from 0\n                # whereas the OntoNotes coref_span is [start_word, end_word] inclusive\n                span_start = span[1] + word_total\n                span_end = span[2] + word_total + 1\n                # if its a zero coref (i.e. coref, but the head in None), we call\n                # the beginning of the span (i.e. 
the zero itself) the head\n\n                if zero:\n                    candidate_head = span[1]\n                else:\n                    try:\n                        candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)\n                    except RecursionError:\n                        candidate_head = span[1]\n                    \n                if candidate_head is None:\n                    for candidate_head in range(span[1], span[2] + 1):\n                        # stanza uses 0 to mark the head, whereas OntoNotes is counting\n                        # words from 0, so we have to subtract 1 from the stanza heads\n                        #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)\n                        # treat the head of the phrase as the first word that has a head outside the phrase\n                        if (parsed_sentence.all_words[candidate_head].head is not None) and (\n                                parsed_sentence.all_words[candidate_head].head - 1 < span[1] or\n                                parsed_sentence.all_words[candidate_head].head - 1 > span[2]\n                        ):\n                            break\n                    else:\n                        # if none have a head outside the phrase (circular??)\n                        # then just take the first word\n                        candidate_head = span[1]\n                #print(\"----> %d\" % candidate_head)\n                candidate_head += word_total\n                span_clusters[span[0]].append((span_start, span_end))\n                word_clusters[span[0]].append(candidate_head)\n                head2span.append((candidate_head, span_start, span_end))\n            if do_ctn:\n                break\n            word_total += len(parsed_sentence.all_words)\n        if do_ctn:\n            continue\n        span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])\n        
word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])\n        head2span = sorted(head2span)\n        is_zero = [i for _,i in sorted(is_zero)]\n\n        # remove zero tokens \"_\" from cased_words and adjust indices accordingly\n        zero_positions = [i for i, w in enumerate(cased_words) if w == \"_\"]\n        if zero_positions:\n            old_to_new = {}\n            new_idx = 0\n            for old_idx, w in enumerate(cased_words):\n                if w != \"_\":\n                    old_to_new[old_idx] = new_idx\n                    new_idx += 1\n            cased_words = [w for w in cased_words if w != \"_\"]\n            sent_id = [sent_id[i] for i in sorted(old_to_new.keys())]\n            deprel = [deprel[i] for i in sorted(old_to_new.keys())]\n            heads = [heads[i] for i in sorted(old_to_new.keys())]\n            try:\n                span_clusters = [\n                    [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster]\n                    for cluster in span_clusters\n                ]\n            except (KeyError, TypeError) as _: # two errors, either end-1 = -1, or start/end is None\n                warnings.warn(\"Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. 
We are giving up.\")\n                continue\n            word_clusters = [\n                [old_to_new[h] for h in cluster]\n                for cluster in word_clusters\n            ]\n            head2span = [\n                (old_to_new[h], old_to_new[s], old_to_new[e - 1] + 1)\n                for h, s, e in head2span\n            ]\n\n        processed = {\n            \"document_id\": doc_id,\n            \"cased_words\": cased_words,\n            \"sent_id\": sent_id,\n            \"part_id\": idx,\n            # \"pos\": pos,\n            \"deprel\": deprel,\n            \"head\": heads,\n            \"span_clusters\": span_clusters,\n            \"word_clusters\": word_clusters,\n            \"head2span\": head2span,\n            \"lang\": lang,\n            \"is_zero\": is_zero\n        }\n        processed_section.append(processed)\n    return processed_section\n\ndef process_dataset(short_name, coref_output_path, split_test, train_files, dev_files):\n    section_names = ('train', 'dev')\n    section_filenames = [train_files, dev_files]\n    sections = []\n\n    test_sections = []\n\n    for section, filenames in zip(section_names, section_filenames):\n        input_file = []\n        for load in filenames:\n            lang = load.split(\"/\")[-1].split(\"_\")[0]\n            print(\"Ingesting %s from %s of lang %s\" % (section, load, lang))\n            docs = CoNLL.conll2multi_docs(load, ignore_gapping=False)\n            # sections = docs[:10]\n            print(\"  Ingested %d documents\" % len(docs))\n            if split_test and section == 'train':\n                test_section = []\n                train_section = []\n                for i in docs:\n                    # reseed for each doc so that we can attempt to keep things stable in the event\n                    # of different file orderings or some change to the number of documents\n                    split_random = Random(i.sentences[0].doc_id + i.sentences[0].text)\n               
     if split_random.random() < split_test:\n                        test_section.append((i, i.sentences[0].doc_id, lang))\n                    else:\n                        train_section.append((i, i.sentences[0].doc_id, lang))\n                if len(test_section) == 0 and len(train_section) >= 2:\n                    idx = split_random.randint(0, len(train_section) - 1)\n                    test_section = [train_section[idx]]\n                    train_section = train_section[:idx] + train_section[idx+1:]\n                print(\"  Splitting %d documents from %s for test\" % (len(test_section), load))\n                input_file.extend(train_section)\n                test_sections.append(test_section)\n            else:\n                for i in docs:\n                    input_file.append((i, i.sentences[0].doc_id, lang))\n        print(\"Ingested %d total documents\" % len(input_file))\n        sections.append(input_file)\n\n    if split_test:\n        section_names = ('train', 'dev', 'test')\n        full_test_section = []\n        for filename, test_section in zip(train_files, test_sections):\n            # TODO: could write dataset-specific test sections as well\n            full_test_section.extend(test_section)\n        sections.append(full_test_section)\n\n\n    output_filenames = []\n    for section_data, section_name in zip(sections, section_names):\n        converted_section = process_documents(section_data, augment=(section_name==\"train\"))\n\n        os.makedirs(coref_output_path, exist_ok=True)\n        output_filenames.append(\"%s.%s.json\" % (short_name, section_name))\n        output_filename = os.path.join(coref_output_path, output_filenames[-1])\n        with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n            json.dump(converted_section, fout, indent=2)\n    return output_filenames\n\ndef get_dataset_by_language(coref_input_path, langs):\n    conll_path = os.path.join(coref_input_path, \"CorefUD-1.3-public\", \"data\")\n    
train_filenames = []\n    dev_filenames = []\n    for lang in langs:\n        train_filenames.extend(glob.glob(os.path.join(conll_path, \"*%s*\" % lang, \"*train.conllu\")))\n        dev_filenames.extend(glob.glob(os.path.join(conll_path, \"*%s*\" % lang, \"*dev.conllu\")))\n    train_filenames = sorted(train_filenames)\n    dev_filenames = sorted(dev_filenames)\n    return train_filenames, dev_filenames\n\ndef main(args=None):\n    paths = get_default_paths()\n    parser = argparse.ArgumentParser(\n        prog='Convert UDCoref Data',\n    )\n    parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set')\n    parser.add_argument('--output_directory', default=None, type=str, help='Where to output the data (defaults to %s)' % paths['COREF_DATA_DIR'])\n\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument('--directory', type=str, help=\"the name of the subfolder for data conversion\")\n    group.add_argument('--project', type=str, help=\"Look for and use a set of datasets for data conversion - Slavic or Hungarian\")\n    group.add_argument('--languages', type=str, help=\"Only use these specific languages from the coref directory\")\n\n    args = parser.parse_args(args=args)\n    coref_input_path = paths['COREF_BASE']\n    coref_output_path = args.output_directory if args.output_directory else paths['COREF_DATA_DIR']\n\n    if args.languages:\n        langs = args.languages.split(\",\")\n        project = \"_\".join(langs)\n        train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n    elif args.project:\n        if args.project == 'baltoslavic':\n            project = \"baltoslavic_udcoref\"\n            langs = ('Polish', 'Russian', 'Czech', 'Old_Church_Slavonic', 'Lithuanian')\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'hungarian':\n        
    project = \"hu_udcoref\"\n            langs = ('Hungarian',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'gerrom':\n            project = \"gerrom_udcoref\"\n            langs = ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish')\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'germanic':\n            project = \"germanic_udcoref\"\n            langs = ('English', 'German', 'Norwegian')\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'norwegian':\n            project = \"norwegian_udcoref\"\n            langs = ('Norwegian',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'turkish':\n            project = \"turkish_udcoref\"\n            langs = ('Turkish',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'korean':\n            project = \"korean_udcoref\"\n            langs = ('Korean',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'hindi':\n            project = \"hindi_udcoref\"\n            langs = ('Hindi',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'ancient_greek':\n            project = \"ancient_greek_udcoref\"\n            langs = ('Ancient_Greek',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'ancient_hebrew':\n            project = \"ancient_hebrew_udcoref\"\n            langs = ('Ancient_Hebrew',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n    else:\n        
project = args.directory\n        conll_path = os.path.join(coref_input_path, project)\n        if not os.path.exists(conll_path) and os.path.exists(project):\n            conll_path = args.directory\n        train_filenames = sorted(glob.glob(os.path.join(conll_path, f\"*train.conllu\")))\n        dev_filenames = sorted(glob.glob(os.path.join(conll_path, f\"*dev.conllu\")))\n    return process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/coref/convert_udcoref_1.2.py",
    "content": "from collections import defaultdict\nimport json\nimport os\nimport re\nimport glob\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\nfrom stanza.utils.datasets.coref.utils import find_cconj_head\n\nfrom stanza.utils.conll import CoNLL\n\nfrom random import Random\n\nimport argparse\n\naugment_random = Random(7)\nsplit_random = Random(8)\n\ntqdm = get_tqdm()\nIS_UDCOREF_FORMAT = True\nUDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1\n\ndef process_documents(docs, augment=False):\n    processed_section = []\n\n    for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):\n        # drop the last token 10% of the time\n        if augment:\n            for i in doc.sentences:\n                if len(i.words) > 1:\n                    if augment_random.random() < 0.1:\n                        i.tokens = i.tokens[:-1]\n                        i.words = i.words[:-1]\n\n        # extract the entities\n        # get sentence words and lengths\n        sentences = [[j.text for j in i.words]\n                    for i in doc.sentences]\n        sentence_lens = [len(x.words) for x in doc.sentences]\n\n        cased_words = [] \n        for x in sentences:\n            if augment:\n                # modify case of the first word with 50% chance\n                if augment_random.random() < 0.5:\n                    x[0] = x[0].lower()\n\n            for y in x:\n                cased_words.append(y)\n\n        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]\n\n        word_total = 0\n        heads = []\n        # TODO: does SD vs UD matter?\n        deprel = []\n        for sentence in doc.sentences:\n            for word in sentence.words:\n                deprel.append(word.deprel)\n                if word.head == 0:\n                    heads.append(\"null\")\n                else:\n                    heads.append(word.head - 1 + word_total)\n            word_total += 
len(sentence.words)\n\n        span_clusters = defaultdict(list)\n        word_clusters = defaultdict(list)\n        head2span = []\n        word_total = 0\n        SPANS = re.compile(r\"(\\(\\w+|[%\\w]+\\))\")\n        for parsed_sentence in doc.sentences:\n            # spans regex\n            # parse the misc column, leaving only \"Entity\" entries\n            misc = [[k.split(\"=\")\n                    for k in j\n                    if k.split(\"=\")[0] == \"Entity\"]\n                    for i in parsed_sentence.words\n                    for j in [i.misc.split(\"|\") if i.misc else []]]\n            # and extract the Entity entry values\n            entities = [i[0][1] if len(i) > 0 else None for i in misc]\n            # extract reference information\n            refs = [SPANS.findall(i) if i else [] for i in entities]\n            # and calculate spans: the basic rule is (e... begins a reference\n            # and ) without e before ends the most recent reference\n            # every single time we get a closing element, we pop it off\n            # the refdict and insert the pair to final_refs\n            refdict = defaultdict(list)\n            final_refs = defaultdict(list)\n            last_ref = None\n            for indx, i in enumerate(refs):\n                for j in i:\n                    # this is the beginning of a reference\n                    if j[0] == \"(\":\n                        refdict[j[1+UDCOREF_ADDN:]].append(indx)\n                        last_ref = j[1+UDCOREF_ADDN:]\n                    # at the end of a reference, if we got exxxxx, that ends\n                    # a particular reference; otherwise, it ends the last reference\n                    elif j[-1] == \")\" and j[UDCOREF_ADDN:-1].isnumeric():\n                        if (not UDCOREF_ADDN) or j[0] == \"e\":\n                            try:\n                                final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx))\n              
              except IndexError:\n                                # this is probably zero anaphora\n                                continue\n                    elif j[-1] == \")\":\n                        final_refs[last_ref].append((refdict[last_ref].pop(-1), indx))\n                        last_ref = None\n            final_refs = dict(final_refs)\n            # convert it to the right format (specifically, in (ref, start, end) tuples)\n            coref_spans = []\n            for k, v in final_refs.items():\n                for i in v:\n                    coref_spans.append([int(k), i[0], i[1]])\n            sentence_upos = [x.upos for x in parsed_sentence.words]\n            sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]\n            for span in coref_spans:\n                # input is expected to be start word, end word + 1\n                # counting from 0\n                # whereas the OntoNotes coref_span is [start_word, end_word] inclusive\n                span_start = span[1] + word_total\n                span_end = span[2] + word_total + 1\n                candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)\n                if candidate_head is None:\n                    for candidate_head in range(span[1], span[2] + 1):\n                        # stanza uses 0 to mark the head, whereas OntoNotes is counting\n                        # words from 0, so we have to subtract 1 from the stanza heads\n                        #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)\n                        # treat the head of the phrase as the first word that has a head outside the phrase\n                        if (parsed_sentence.words[candidate_head].head - 1 < span[1] or\n                            parsed_sentence.words[candidate_head].head - 1 > span[2]):\n                            break\n                    else:\n                        # if none have a 
head outside the phrase (circular??)\n                        # then just take the first word\n                        candidate_head = span[1]\n                #print(\"----> %d\" % candidate_head)\n                candidate_head += word_total\n                span_clusters[span[0]].append((span_start, span_end))\n                word_clusters[span[0]].append(candidate_head)\n                head2span.append((candidate_head, span_start, span_end))\n            word_total += len(parsed_sentence.words)\n        span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])\n        word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])\n        head2span = sorted(head2span)\n\n        processed = {\n            \"document_id\": doc_id,\n            \"cased_words\": cased_words,\n            \"sent_id\": sent_id,\n            \"part_id\": idx,\n            # \"pos\": pos,\n            \"deprel\": deprel,\n            \"head\": heads,\n            \"span_clusters\": span_clusters,\n            \"word_clusters\": word_clusters,\n            \"head2span\": head2span,\n            \"lang\": lang\n        }\n        processed_section.append(processed)\n    return processed_section\n\ndef process_dataset(short_name, coref_output_path, split_test, train_files, dev_files):\n    section_names = ('train', 'dev')\n    section_filenames = [train_files, dev_files]\n    sections = []\n\n    test_sections = []\n\n    for section, filenames in zip(section_names, section_filenames):\n        input_file = []\n        for load in filenames:\n            lang = load.split(\"/\")[-1].split(\"_\")[0]\n            print(\"Ingesting %s from %s of lang %s\" % (section, load, lang))\n            docs = CoNLL.conll2multi_docs(load)\n            print(\"  Ingested %d documents\" % len(docs))\n            if split_test and section == 'train':\n                test_section = []\n                train_section = []\n                for i in docs:\n          
          # reseed for each doc so that we can attempt to keep things stable in the event\n                    # of different file orderings or some change to the number of documents\n                    split_random = Random(i.sentences[0].doc_id + i.sentences[0].text)\n                    if split_random.random() < split_test:\n                        test_section.append((i, i.sentences[0].doc_id, lang))\n                    else:\n                        train_section.append((i, i.sentences[0].doc_id, lang))\n                if len(test_section) == 0 and len(train_section) >= 2:\n                    idx = split_random.randint(0, len(train_section) - 1)\n                    test_section = [train_section[idx]]\n                    train_section = train_section[:idx] + train_section[idx+1:]\n                print(\"  Splitting %d documents from %s for test\" % (len(test_section), load))\n                input_file.extend(train_section)\n                test_sections.append(test_section)\n            else:\n                for i in docs:\n                    input_file.append((i, i.sentences[0].doc_id, lang))\n        print(\"Ingested %d total documents\" % len(input_file))\n        sections.append(input_file)\n\n    if split_test:\n        section_names = ('train', 'dev', 'test')\n        full_test_section = []\n        for filename, test_section in zip(train_files, test_sections):\n            # TODO: could write dataset-specific test sections as well\n            full_test_section.extend(test_section)\n        sections.append(full_test_section)\n\n\n    for section_data, section_name in zip(sections, section_names):\n        converted_section = process_documents(section_data, augment=(section_name==\"train\"))\n\n        os.makedirs(coref_output_path, exist_ok=True)\n        output_filename = os.path.join(coref_output_path, \"%s.%s.json\" % (short_name, section_name))\n        with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n            
json.dump(converted_section, fout, indent=2)\n\ndef get_dataset_by_language(coref_input_path, langs):\n    conll_path = os.path.join(coref_input_path, \"CorefUD-1.2-public\", \"data\")\n    train_filenames = []\n    dev_filenames = []\n    for lang in langs:\n        train_filenames.extend(glob.glob(os.path.join(conll_path, \"*%s*\" % lang, \"*train.conllu\")))\n        dev_filenames.extend(glob.glob(os.path.join(conll_path, \"*%s*\" % lang, \"*dev.conllu\")))\n    train_filenames = sorted(train_filenames)\n    dev_filenames = sorted(dev_filenames)\n    return train_filenames, dev_filenames\n\ndef main():\n    paths = get_default_paths()\n    parser = argparse.ArgumentParser(\n        prog='Convert UDCoref Data',\n    )\n    parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set')\n\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument('--directory', type=str, help=\"the name of the subfolder for data conversion\")\n    group.add_argument('--project', type=str, help=\"Look for and use a set of datasets for data conversion - Slavic or Hungarian\")\n\n    args = parser.parse_args()\n    coref_input_path = paths['COREF_BASE']\n    coref_output_path = paths['COREF_DATA_DIR']\n\n    if args.project:\n        if args.project == 'slavic':\n            project = \"slavic_udcoref\"\n            langs = ('Polish', 'Russian', 'Czech')\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'hungarian':\n            project = \"hu_udcoref\"\n            langs = ('Hungarian',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'gerrom':\n            project = \"gerrom_udcoref\"\n            langs = ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish')\n            train_filenames, dev_filenames = 
get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'germanic':\n            project = \"germanic_udcoref\"\n            langs = ('English', 'German', 'Norwegian')\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n        elif args.project == 'norwegian':\n            project = \"norwegian_udcoref\"\n            langs = ('Norwegian',)\n            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)\n    else:\n        project = args.directory\n        conll_path = os.path.join(coref_input_path, project)\n        if not os.path.exists(conll_path) and os.path.exists(project):\n            conll_path = args.directory\n        train_filenames = sorted(glob.glob(os.path.join(conll_path, f\"*train.conllu\")))\n        dev_filenames = sorted(glob.glob(os.path.join(conll_path, f\"*dev.conllu\")))\n    process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/coref/utils.py",
    "content": "from collections import defaultdict\nfrom functools import lru_cache\n\nclass DynamicDepth():\n    \"\"\"\n    Implements a cache + dynamic programming to find the relative depth of every word in a subphrase given the head word for every word.\n    \"\"\"\n    def get_parse_depths(self, heads, start, end):\n        \"\"\"Return the relative depth for every word\n\n        Args:\n            heads (list): List where each entry is the index of that entry's head word in the dependency parse\n            start (int): starting index of the heads for the subphrase\n            end (int): ending index of the heads for the subphrase\n\n        Returns:\n            list: Relative depth in the dependency parse for every word\n        \"\"\"\n        self.heads = heads[start:end]\n        # -100 marks words with no head (None).  compare against None explicitly, since a head index of 0 is valid\n        self.relative_heads = [h - start if h is not None else -100 for h in self.heads]\n\n        depths = [self._get_depth_recursive(h) for h in range(len(self.relative_heads))]\n\n        return depths\n\n    @lru_cache(maxsize=None)\n    def _get_depth_recursive(self, index):\n        \"\"\"Recursively get the depths of every index using a cache and recursion\n\n        Args:\n            index (int): Index of the word for which to calculate the relative depth\n\n        Returns:\n            int: Relative depth of the word at the index\n        \"\"\"\n        # if the head for the current index is outside the scope, this index is a relative root\n        if self.relative_heads[index] >= len(self.relative_heads) or self.relative_heads[index] < 0:\n            return 0\n        return self._get_depth_recursive(self.relative_heads[index]) + 1\n\ndef find_cconj_head(heads, upos, start, end):\n    \"\"\"\n    Finds how far each word is from the head of a span, then uses the closest CCONJ to the head as the new head\n\n    If no CCONJ is present, returns None\n    \"\"\"\n    # use head information to extract parse depth\n    dynamicDepth = DynamicDepth()\n 
   depth = dynamicDepth.get_parse_depths(heads, start, end)\n    depth_limit = 2\n\n    # return first 'CCONJ' token above depth limit, if exists\n    # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC\n    cc_indexes = [i for i in range(end - start) if upos[i+start] == 'CCONJ' and depth[i] < depth_limit]\n    if cc_indexes:\n        return cc_indexes[0] + start\n    return None\n\ndef process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True, lang=None):\n    \"\"\"\n    doc_id: a string naming the document\n    part_id: if the document has a particular subpart (can be blank)\n    sentences: a list of list of string representing the raw text\n\n    coref_spans: a list of lists\n    one list per sentence\n    each sentence has a list of spans, where each span is (span_index, span_start, span_end)\n    the indices are relative to 0 for that particular sentence, and if the span is exactly 1 word long, span_start == span_end\n\n    sentence_speakers: a list of list of string representing who said each word.  
can all be blank if there are no known speakers\n    \"\"\"\n    sentence_lens = [len(x) for x in sentences]\n    if sentence_speakers is None:\n        sentence_speakers = [\" \" for _ in sentences]\n    if all(isinstance(x, list) for x in sentence_speakers):\n        speaker = [y for x in sentence_speakers for y in x]\n    else:\n        speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len]\n\n    cased_words = [y for x in sentences for y in x]\n    sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]\n\n    # use the trees to get the xpos tags\n    # alternatively, could translate the pos_tags field,\n    # but those have numbers, which is annoying\n    #tree_text = \"\\n\".join(x['parse_tree'] for x in paragraph)\n    #trees = tree_reader.read_trees(tree_text)\n    #pos = [x.label for tree in trees for x in tree.yield_preterminals()]\n    # actually, the downstream code doesn't use pos at all.  maybe we can skip?\n\n    doc = pipe(sentences)\n    word_total = 0\n    heads = []\n    # TODO: does SD vs UD matter?\n    deprel = []\n    for sentence in doc.sentences:\n        for word in sentence.words:\n            deprel.append(word.deprel)\n            if word.head == 0:\n                heads.append(\"null\")\n            else:\n                heads.append(word.head - 1 + word_total)\n        word_total += len(sentence.words)\n\n    span_clusters = defaultdict(list)\n    word_clusters = defaultdict(list)\n    head2span = []\n    word_total = 0\n    for sent_idx, (parsed_sentence, ontonotes_words) in enumerate(zip(doc.sentences, sentences)):\n        sentence_upos = [x.upos for x in parsed_sentence.words]\n        sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]\n        for span in coref_spans[sent_idx]:\n            # input is expected to be start word, end word + 1\n            # counting from 0\n            # whereas the OntoNotes coref_span is 
[start_word, end_word] inclusive\n            span_start = span[1] + word_total\n            span_end = span[2] + word_total + 1\n            candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None\n            if candidate_head is None:\n                for candidate_head in range(span[1], span[2] + 1):\n                    # stanza uses 0 to mark the head, whereas OntoNotes is counting\n                    # words from 0, so we have to subtract 1 from the stanza heads\n                    #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)\n                    # treat the head of the phrase as the first word that has a head outside the phrase\n                    if (parsed_sentence.words[candidate_head].head - 1 < span[1] or\n                        parsed_sentence.words[candidate_head].head - 1 > span[2]):\n                        break\n                else:\n                    # if none have a head outside the phrase (circular??)\n                    # then just take the first word\n                    candidate_head = span[1]\n            #print(\"----> %d\" % candidate_head)\n            candidate_head += word_total\n            span_clusters[span[0]].append((span_start, span_end))\n            word_clusters[span[0]].append(candidate_head)\n            head2span.append((candidate_head, span_start, span_end))\n        word_total += len(ontonotes_words)\n    span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])\n    word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])\n    head2span = sorted(head2span)\n\n    processed = {\n        \"document_id\": doc_id,\n        \"part_id\": part_id,\n        \"cased_words\": cased_words,\n        \"sent_id\": sent_id,\n        \"speaker\": speaker,\n        #\"pos\": pos,\n        \"deprel\": deprel,\n        \"head\": heads,\n        \"span_clusters\": span_clusters,\n        
\"word_clusters\": word_clusters,\n        \"head2span\": head2span,\n    }\n    if part_id is not None:\n        processed[\"part_id\"] = part_id\n    if lang is not None:\n        processed[\"lang\"] = lang\n    return processed\n"
  },
  {
    "path": "stanza/utils/datasets/corenlp_segmenter_dataset.py",
    "content": "\"\"\"\nOutput a treebank's sentences in a form that can be processed by the CoreNLP CRF Segmenter\n\nRun it as\n  python3 -m stanza.utils.datasets.corenlp_segmenter_dataset <treebank>\nsuch as\n  python3 -m stanza.utils.datasets.corenlp_segmenter_dataset UD_Chinese-GSDSimp --output_dir $CHINESE_SEGMENTER_HOME\n\"\"\"\n\nimport argparse\nimport os\nimport sys\nimport tempfile\n\nimport stanza.utils.datasets.common as common\nimport stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank\nimport stanza.utils.default_paths as default_paths\n\nfrom stanza.models.common.constant import treebank_to_short_name\n\ndef build_argparse():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('treebanks', type=str, nargs='*', default=[\"UD_Chinese-GSDSimp\"], help='Which treebanks to run on')\n    parser.add_argument('--output_dir', type=str, default='.', help='Where to put the results')\n    return parser\n\n\ndef write_segmenter_file(output_filename, dataset):\n    with open(output_filename, \"w\") as fout:\n        for sentence in dataset:\n            sentence = [x for x in sentence if not x.startswith(\"#\")]\n            sentence = [x for x in [y.strip() for y in sentence] if x]\n            # eliminate MWE, although Chinese currently doesn't have any\n            sentence = [x for x in sentence if x.split(\"\\t\")[0].find(\"-\") < 0]\n\n            text = \" \".join(x.split(\"\\t\")[1] for x in sentence)\n            fout.write(text)\n            fout.write(\"\\n\")\n\ndef process_treebank(treebank, model_type, paths, output_dir):\n    with tempfile.TemporaryDirectory() as tokenizer_dir:\n        paths = dict(paths)\n        paths[\"TOKENIZE_DATA_DIR\"] = tokenizer_dir\n\n        short_name = treebank_to_short_name(treebank)\n        \n        # first we process the tokenization data\n        args = argparse.Namespace()\n        args.augment = False\n        args.prepare_labels = False\n        
prepare_tokenizer_treebank.process_treebank(treebank, model_type, paths, args)\n\n        # TODO: these names should be refactored\n        train_file = f\"{tokenizer_dir}/{short_name}.train.gold.conllu\"\n        dev_file = f\"{tokenizer_dir}/{short_name}.dev.gold.conllu\"\n        test_file = f\"{tokenizer_dir}/{short_name}.test.gold.conllu\"\n\n        train_set = common.read_sentences_from_conllu(train_file)\n        dev_set = common.read_sentences_from_conllu(dev_file)\n        test_set = common.read_sentences_from_conllu(test_file)\n\n        train_out = os.path.join(output_dir, f\"{short_name}.train.seg.txt\")\n        test_out = os.path.join(output_dir, f\"{short_name}.test.seg.txt\")\n\n        write_segmenter_file(train_out, train_set + dev_set)\n        write_segmenter_file(test_out, test_set)\n\ndef main():\n    parser = build_argparse()\n    args = parser.parse_args()\n\n    paths = default_paths.get_default_paths()\n    for treebank in args.treebanks:\n        process_treebank(treebank, common.ModelType.TOKENIZER, paths, args.output_dir)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/depparse/check_results.py",
    "content": "\"\"\"\nA small script to report the dev/test scores from a depparse run, along with averaging multiple runs at once.\n\nUses the expected log format from the depparse.  Will not work otherwise.\n\"\"\"\n\nimport argparse\nimport re\nimport sys\n\ndev_re = re.compile(\".*INFO: step ([0-9]+).*dev_score = ([.0-9]+).*\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Grep through a list of files looking for the final results or best results up to a point\")\n    parser.add_argument(\"filenames\", nargs=\"+\", help=\"Files to check\")\n    parser.add_argument(\"--step\", default=None, type=int, help=\"If set, stop checking at this step\")\n    args = parser.parse_args()\n\n    filenames = args.filenames\n    if len(filenames) == 0:\n        return\n\n    dev_scores = []\n    test_scores = []\n\n    best_step = None\n    for filename in filenames:\n        with open(filename, encoding=\"utf-8\") as fin:\n            lines = fin.readlines()\n            dev_score = None\n            test_score = None\n            for line in lines:\n                if line.find(\"Parser score\") >= 0:\n                    score = float(line.strip().split()[-1])\n                    if \"dev\" in line:\n                        dev_score = score\n                    elif \"test\" in line:\n                        test_score = score\n                    else:\n                        raise AssertionError(\"Did the parser score layout change?  
Got an unexpected score line in %s\" % filename)\n                    best_step = None\n                dev_match = dev_re.match(line)\n                if dev_match:\n                    step = int(dev_match.groups()[0])\n                    if args.step is not None and step > args.step:\n                        break\n                    score = float(dev_match.groups()[1]) * 100\n                    if dev_score is None or score > dev_score:\n                        dev_score = score\n                        best_step = step\n            if dev_score is None:\n                dev_score = \"N/A\"\n            else:\n                dev_scores.append(dev_score)\n                dev_score = \"%.2f\" % dev_score\n            if test_score is None:\n                test_score = \"N/A\"\n            else:\n                test_scores.append(test_score)\n                test_score = \"%.2f\" % test_score\n            if best_step is not None:\n                print(\"%s     %s  (%d)\" % (filename, dev_score, best_step))\n            else:\n                print(\"%s     %s  %s\" % (filename, dev_score, test_score))\n\n    if len(dev_scores) > 0:\n        dev_score = sum(dev_scores) / len(dev_scores)\n        print(\"Avg dev score:  %.2f\" % dev_score)\n    if len(test_scores) > 0:\n        test_score = sum(test_scores) / len(test_scores)\n        print(\"Avg test score: %.2f\" % test_score)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/ner/build_en_combined.py",
    "content": "\"\"\"\nBuilds a combined dataset out of OntoNotes, WW, and CoNLL.\n\nThis is done with three layers in the multi_ner column:\n\nFirst layer is OntoNotes only.  Other datasets have that left as blank.\n\nSecond layer is the 9 class WW dataset.  OntoNotes is reduced to 9 classes for this layer.\n\nThird layer is the CoNLL dataset.  OntoNotes and WW are both projected to this.\n\"\"\"\n\nimport json\nimport os\nimport shutil\n\nfrom stanza.utils import default_paths\nfrom stanza.utils.datasets.ner.simplify_en_worldwide import process_label\nfrom stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide\nfrom stanza.utils.datasets.ner.utils import combine_files\n\ndef convert_ontonotes_file(filename, short_name):\n    assert \"en_ontonotes.\" in filename\n    if not os.path.exists(filename):\n        raise FileNotFoundError(\"Cannot convert missing file %s\" % filename)\n    new_filename = filename.replace(\"en_ontonotes.\", short_name + \".ontonotes.\")\n\n    with open(filename) as fin:\n        doc = json.load(fin)\n\n    for sentence in doc:\n        is_start = False\n        for word in sentence:\n            text = word['text']\n            ner = word['ner']\n            s9 = simplify_ontonotes_to_worldwide(ner)\n            _, s4, is_start = process_label((text, s9), is_start)\n            word['multi_ner'] = (ner, s9, s4)\n\n    with open(new_filename, \"w\") as fout:\n        json.dump(doc, fout, indent=2)\n\ndef convert_worldwide_file(filename, short_name):\n    assert \"en_worldwide-9class.\" in filename\n    if not os.path.exists(filename):\n        raise FileNotFoundError(\"Cannot convert missing file %s\" % filename)\n    new_filename = filename.replace(\"en_worldwide-9class.\", short_name + \".worldwide-9class.\")\n\n    with open(filename) as fin:\n        doc = json.load(fin)\n\n    for sentence in doc:\n        is_start = False\n        for word in sentence:\n            text = word['text']\n     
       ner = word['ner']\n            _, s4, is_start = process_label((text, ner), is_start)\n            word['multi_ner'] = (\"-\", ner, s4)\n\n    with open(new_filename, \"w\") as fout:\n        json.dump(doc, fout, indent=2)\n\ndef convert_conll03_file(filename, short_name):\n    assert \"en_conll03.\" in filename\n    if not os.path.exists(filename):\n        raise FileNotFoundError(\"Cannot convert missing file %s\" % filename)\n    new_filename = filename.replace(\"en_conll03.\", short_name + \".conll03.\")\n\n    with open(filename) as fin:\n        doc = json.load(fin)\n\n    for sentence in doc:\n        for word in sentence:\n            ner = word['ner']\n            word['multi_ner'] = (\"-\", \"-\", ner)\n\n    with open(new_filename, \"w\") as fout:\n        json.dump(doc, fout, indent=2)\n\ndef build_combined_dataset(base_output_path, short_name):\n    convert_ontonotes_file(os.path.join(base_output_path, \"en_ontonotes.train.json\"), short_name)\n    convert_ontonotes_file(os.path.join(base_output_path, \"en_ontonotes.dev.json\"), short_name)\n    convert_ontonotes_file(os.path.join(base_output_path, \"en_ontonotes.test.json\"), short_name)\n\n    convert_worldwide_file(os.path.join(base_output_path, \"en_worldwide-9class.train.json\"), short_name)\n    convert_conll03_file(os.path.join(base_output_path, \"en_conll03.train.json\"), short_name)\n\n    # the intermediate files are named after short_name by the convert_* functions above\n    combine_files(os.path.join(base_output_path, \"%s.train.json\" % short_name),\n                  os.path.join(base_output_path, \"%s.ontonotes.train.json\" % short_name),\n                  os.path.join(base_output_path, \"%s.worldwide-9class.train.json\" % short_name),\n                  os.path.join(base_output_path, \"%s.conll03.train.json\" % short_name))\n    shutil.copyfile(os.path.join(base_output_path, \"%s.ontonotes.dev.json\" % short_name),\n                    os.path.join(base_output_path, \"%s.dev.json\" % short_name))\n    shutil.copyfile(os.path.join(base_output_path, 
\"%s.ontonotes.test.json\" % short_name),\n                    os.path.join(base_output_path, \"%s.test.json\" % short_name))\n\n\ndef main():\n    paths = default_paths.get_default_paths()\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    build_combined_dataset(base_output_path, \"en_combined\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/check_for_duplicates.py",
    "content": "\"\"\"\nA simple tool to check if there are duplicates in a set of NER files\n\nIt's surprising how many datasets have a bunch of duplicates...\n\"\"\"\n\ndef read_sentences(filename):\n    \"\"\"\n    Read the sentences (without tags) from a BIO file\n    \"\"\"\n    sentences = []\n    with open(filename) as fin:\n        lines = fin.readlines()\n    current_sentence = []\n    for line in lines:\n        line = line.strip()\n        if not line:\n            if current_sentence:\n                sentences.append(tuple(current_sentence))\n            current_sentence = []\n            continue\n        word = line.split(\"\\t\")[0]\n        current_sentence.append(word)\n    if len(current_sentence) > 0:\n        sentences.append(tuple(current_sentence))\n    return sentences\n\ndef check_for_duplicates(output_filenames, fail=False, check_self=False, print_all=False):\n    \"\"\"\n    Checks for exact duplicates in a list of NER files\n\n    fail: raise a ValueError on the first duplicate instead of printing it\n    check_self: also report sentences duplicated within a single file\n    print_all: print every duplicate instead of just the first one per file\n    \"\"\"\n    sentence_map = {}\n    for output_filename in output_filenames:\n        duplicates = 0\n        sentences = read_sentences(output_filename)\n        for sentence in sentences:\n            other_file = sentence_map.get(sentence, None)\n            if other_file is not None and (check_self or other_file != output_filename):\n                if fail:\n                    raise ValueError(\"Duplicate sentence '{}', first in {}, also in {}\".format(\" \".join(sentence), sentence_map[sentence], output_filename))\n                else:\n                    if duplicates == 0 and not print_all:\n                        print(\"First duplicate:\")\n                    if duplicates == 0 or print_all:\n                        print(\"{}\\nFound in {} and {}\".format(sentence, other_file, output_filename))\n                    duplicates = duplicates + 1\n            sentence_map[sentence] = output_filename\n        if duplicates > 0:\n            print(\"%d duplicates found in %s\" % 
(duplicates, output_filename))\n"
  },
  {
    "path": "stanza/utils/datasets/ner/combine_ner_datasets.py",
    "content": "import argparse\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.datasets.ner.utils import combine_dataset\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef main(args=None):\n    ner_data_dir = get_default_paths()['NER_DATA_DIR']\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--output_dataset', type=str, help='What dataset to output')\n    parser.add_argument('input_datasets', type=str, nargs='+', help='Which datasets to input')\n\n    parser.add_argument('--input_dir', type=str, default=ner_data_dir, help='Which directory to find the datasets')\n    parser.add_argument('--output_dir', type=str, default=ner_data_dir, help='Which directory to write the dataset')\n    args = parser.parse_args(args)\n\n    input_dir = args.input_dir\n    output_dir = args.output_dir\n\n    combine_dataset(input_dir, output_dir, args.input_datasets, args.output_dataset)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/compare_entities.py",
    "content": "\"\"\"\nReport the fraction of NER entities in one file which are present in another.\n\nPurpose: show the coverage of one file on another, such as reporting\nthe fraction of entities in a test set which were already seen in a train set\n\"\"\"\n\n\nimport argparse\n\nfrom stanza.utils.datasets.ner.utils import read_json_entities\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Report the coverage of one NER file on another.\")\n    parser.add_argument('--train', type=str, nargs=\"+\", required=True, help='Files to use to collect the known entities (not necessarily train).')\n    parser.add_argument('--test', type=str, nargs=\"+\", required=True, help='Files for which we want to know the ratio of known entities')\n    args = parser.parse_args()\n    return args\n\ndef report_known_entities(train_file, test_file):\n    train_entities = read_json_entities(train_file)\n    test_entities = read_json_entities(test_file)\n\n    if len(test_entities) == 0:\n        # avoid dividing by zero when the test file has no entities\n        print(train_file, test_file, \"N/A\")\n        return\n\n    train_entities = set(x[0] for x in train_entities)\n    total_score = sum(1 for x in test_entities if x[0] in train_entities)\n    print(train_file, test_file, total_score / len(test_entities))\n\ndef main():\n    args = parse_args()\n\n    for train_idx, train_file in enumerate(args.train):\n        if train_idx > 0:\n            print()\n        for test_file in args.test:\n            report_known_entities(train_file, test_file)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/conll_to_iob.py",
    "content": "\"\"\"\nProcess a conll file into BIO\n\nIncludes the ability to process a file from a text file\nor a text file within a zip\n\nMain program extracts a piece of the zip file from the Danish DDT dataset\n\"\"\"\n\nimport io\nimport zipfile\nfrom zipfile import ZipFile\nfrom stanza.utils.conll import CoNLL\n\ndef process_conll(input_file, output_file, zip_file=None, conversion=None, attr_prefix=\"name\", allow_empty=False):\n    \"\"\"\n    Process a single conll file, such as a piece of DDT, into BIO\n\n    input_file: which piece to read (a file within the zip if zip_file is set)\n    output_file: where to write the result\n    zip_file: path to a zip such as ddt.zip, or None to read input_file directly\n\n    conversion: optional dict of label -> new label, or a callable applied to each non-O tag\n    attr_prefix: which attribute to get from the misc field\n    allow_empty: if True, tokens without that attribute are labeled O instead of raising an error\n    \"\"\"\n    if not attr_prefix.endswith(\"=\"):\n        attr_prefix = attr_prefix + \"=\"\n\n    doc = CoNLL.conll2doc(input_file=input_file, zip_file=zip_file)\n\n    with open(output_file, \"w\", encoding=\"utf-8\") as fout:\n        for sentence_idx, sentence in enumerate(doc.sentences):\n            for token_idx, token in enumerate(sentence.tokens):\n                misc = token.misc.split(\"|\")\n                for attr in misc:\n                    if attr.startswith(attr_prefix):\n                        ner = attr.split(\"=\", 1)[1]\n                        break\n                else: # name= not found\n                    if allow_empty:\n                        ner = \"O\"\n                    else:\n                        raise ValueError(\"Could not find ner tag in document {}, sentence {}, token {}\".format(input_file, sentence_idx, token_idx))\n\n                if ner != \"O\" and conversion is not None:\n                    if isinstance(conversion, dict):\n                        bio, label = ner.split(\"-\", 1)\n                        if label in conversion:\n                            label = conversion[label]\n                        ner = \"%s-%s\" % (bio, label)\n                    else:\n                        ner = conversion(ner)\n                fout.write(\"%s\\t%s\\n\" 
% (token.text, ner))\n            fout.write(\"\\n\")\n\ndef main():\n    process_conll(zip_file=\"extern_data/ner/da_ddt/ddt.zip\", input_file=\"ddt.train.conllu\", output_file=\"data/ner/da_ddt.train.bio\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_amt.py",
    "content": "\"\"\"\nConverts a .json file from AMT to a .bio format and then a .json file\n\nTo ignore Facility and Product, turn NORP into miscellaneous:\n\n python3 stanza/utils/datasets/ner/convert_amt.py --input_path /u/nlp/data/ner/stanza/en_amt/output.manifest --ignore Product,Facility --remap NORP=Miscellaneous\n\nTo turn all labels into the 4 class used in conll03:\n\n  python3 stanza/utils/datasets/ner/convert_amt.py --input_path /u/nlp/data/ner/stanza/en_amt/output.manifest --ignore Product,Facility --remap NORP=MISC,Miscellaneous=MISC,Location=LOC,Person=PER,Organization=ORG\n\"\"\"\n\nimport argparse\nimport copy\nimport json\nfrom operator import itemgetter\nimport sys\n\nfrom tqdm import tqdm\n\nimport stanza\nfrom stanza.utils.datasets.ner.utils import write_sentences\nimport stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file\n\ndef read_json(input_filename):\n    \"\"\"\n    Read the json file and extract the NER labels\n\n    Will not return lines which are not labeled\n\n    Return format is a list of lines\n    where each line is a tuple: (text, labels)\n    labels is a list of maps, {'label':..., 'startOffset':..., 'endOffset':...}\n    \"\"\"\n    docs = []\n    blank = 0\n    unlabeled = 0\n    broken = 0\n    with open(input_filename, encoding=\"utf-8\") as fin:\n        for line_idx, line in enumerate(fin):\n            doc = json.loads(line)\n            if sorted(doc.keys()) == ['source']:\n                unlabeled += 1\n                continue\n            if 'source' not in doc:\n                blank += 1\n                continue\n            source = doc['source']\n            entities = None\n            for k in doc.keys():\n                if k == 'source' or k.endswith('metadata'):\n                    continue\n                if 'annotations' not in doc[k]:\n                    continue\n                annotations = doc[k]['annotations']\n                if 'entities' not in annotations:\n                    
continue\n                if 'entities' in annotations:\n                    if entities is not None:\n                        raise ValueError(\"Found a map with multiple annotations at line %d\" % line_idx)\n                    entities = annotations['entities']\n                # entities is now a map such as\n                # [{'label': 'Location', 'startOffset': 0, 'endOffset': 6},\n                #  {'label': 'Location', 'startOffset': 11, 'endOffset': 23},\n                #  {'label': 'NORP', 'startOffset': 66, 'endOffset': 74},\n                #  {'label': 'NORP', 'startOffset': 191, 'endOffset': 214}]\n            if entities is None:\n                unlabeled += 1\n                continue\n            is_broken = any(any(x not in entity for x in ('label', 'startOffset', 'endOffset'))\n                            for entity in entities)\n            if is_broken:\n                broken += 1\n                if broken == 1:\n                    print(\"Found an entity which was missing either label, startOffset, or endOffset\")\n                    print(entities)\n            docs.append((source, entities))\n\n    print(\"Found %d labeled lines.  
%d lines were blank, %d lines were broken, and %d lines were unlabeled\" % (len(docs), blank, broken, unlabeled))\n    return docs\n\ndef remove_ignored_labels(docs, ignored):\n    if not ignored:\n        return docs\n\n    ignored = set(ignored.split(\",\"))\n    # drop all labels which match something in ignored\n    # otherwise leave everything the same\n    new_docs = [(doc[0], [x for x in doc[1] if x['label'] not in ignored])\n                for doc in docs]\n    return new_docs\n\ndef remap_labels(docs, remap):\n    if not remap:\n        return docs\n\n    remappings = {}\n    for remapping in remap.split(\",\"):\n        pieces = remapping.split(\"=\")\n        remappings[pieces[0]] = pieces[1]\n\n    print(remappings)\n\n    new_docs = []\n    for doc in docs:\n        entities = copy.deepcopy(doc[1])\n        for entity in entities:\n            entity['label'] = remappings.get(entity['label'], entity['label'])\n        new_doc = (doc[0], entities)\n        new_docs.append(new_doc)\n    return new_docs\n\ndef remove_nesting(docs):\n    \"\"\"\n    Currently the NER tool does not handle nesting, so we just throw away nested entities\n\n    In the event of entities which exactly overlap, the first one in the list wins\n    \"\"\"\n    new_docs = []\n    nested = 0\n    exact = 0\n    total = 0\n    for doc in docs:\n        source, labels = doc\n        # sort by startOffset, -endOffset\n        labels = sorted(labels, key=lambda x: (x['startOffset'], -x['endOffset']))\n        new_labels = []\n        for label in labels:\n            total += 1\n            # note that this works trivially for an empty list\n            for other in reversed(new_labels):\n                if label['startOffset'] == other['startOffset'] and label['endOffset'] == other['endOffset']:\n                    exact += 1\n                    break\n                if label['startOffset'] < other['endOffset']:\n                    #print(\"Ignoring nested entity: {} |{}| vs {} 
|{}|\".format(label, source[label['startOffset']:label['endOffset']], other, source[other['startOffset']:other['endOffset']]))\n                    nested += 1\n                    break\n            else: # yes, this is meant to be a for-else\n                new_labels.append(label)\n\n        new_docs.append((source, new_labels))\n    print(\"Ignored %d exact and %d nested labels out of %d entries\" % (exact, nested, total))\n    return new_docs\n\ndef process_doc(source, labels, pipe):\n    \"\"\"\n    Given a source text and a list of labels, tokenize the text, then assign labels based on the spans defined\n    \"\"\"\n    doc = pipe(source)\n    sentences = doc.sentences\n    for sentence in sentences:\n        for token in sentence.tokens:\n            token.ner = \"O\"\n\n    for label in labels:\n        ner = label['label']\n        start_offset = label['startOffset']\n        end_offset = label['endOffset']\n        for sentence in sentences:\n            if (sentence.tokens[0].start_char <= start_offset and\n                sentence.tokens[-1].end_char >= end_offset):\n                # found the sentence!\n                break\n        else: # for-else again!  
deal with it\n            continue\n\n        start_token = None\n        end_token = None\n        for token_idx, token in enumerate(sentence.tokens):\n            if token.start_char <= start_offset and token.end_char > start_offset:\n                # ideally we'd have start_char == start_offset, but maybe our\n                # tokenization doesn't match the tokenization of the annotators\n                start_token = token\n                start_token.ner = \"B-\" + ner\n            elif start_token is not None:\n                if token.start_char >= end_offset and token_idx > 0:\n                    end_token = sentence.tokens[token_idx-1]\n                    break\n                if token.end_char == end_offset and token_idx > 0 and token.text in (',', '.'):\n                    end_token = sentence.tokens[token_idx-1]\n                    break\n                token.ner = \"I-\" + ner\n            if token.end_char >= end_offset and end_token is None:\n                end_token = token\n                break\n        if start_token is None or end_token is None:\n            raise AssertionError(\"This should not happen\")\n\n    return [[(token.text, token.ner) for token in sentence.tokens] for sentence in sentences]\n\n\n\ndef main(args):\n    \"\"\"\n    Read in a .json file of labeled data from AMT, write out a converted .bio file\n\n    Enforces that there is only one set of labels on a sentence\n    (TODO: add an option to skip certain sets of labels)\n    \"\"\"\n    docs = read_json(args.input_path)\n\n    if len(docs) == 0:\n        print(\"Error: no documents found in the input file!\")\n        return\n\n    docs = remove_ignored_labels(docs, args.ignore)\n    docs = remap_labels(docs, args.remap)\n    docs = remove_nesting(docs)\n\n    pipe = stanza.Pipeline(args.language, processors=\"tokenize\")\n    sentences = []\n    for doc in tqdm(docs):\n        sentences.extend(process_doc(*doc, pipe))\n    print(\"Found %d total sentences (may be 
more than #docs if a doc has more than one sentence)\" % len(sentences))\n    bio_filename = args.output_path\n    write_sentences(args.output_path, sentences)\n    print(\"Sentences written to %s\" % args.output_path)\n    # use --json_output_path if given, otherwise guess a name from the .bio filename\n    if args.json_output_path is not None:\n        json_filename = args.json_output_path\n    elif bio_filename.endswith(\".bio\"):\n        json_filename = bio_filename[:-4] + \".json\"\n    else:\n        json_filename = bio_filename + \".json\"\n    prepare_ner_file.process_dataset(bio_filename, json_filename)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--language', type=str, default=\"en\", help=\"Language to process\")\n    parser.add_argument('--input_path', type=str, default=\"output.manifest\", help=\"Where to find the files\")\n    parser.add_argument('--output_path', type=str, default=\"data/ner/en_amt.test.bio\", help=\"Where to output the results\")\n    parser.add_argument('--json_output_path', type=str, default=None, help=\"Where to output .json.  Best guess will be made from --output_path if this is not set\")\n    parser.add_argument('--ignore', type=str, default=None, help=\"Ignore these labels: comma separated list without B- or I-\")\n    parser.add_argument('--remap', type=str, default=None, help=\"Remap labels: comma separated list of X=Y\")\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_ar_aqmar.py",
    "content": "\"\"\"\nA script to randomly shuffle the input files in the AQMAR dataset and produce train/dev/test for stanza\n\nThe sentences themselves are shuffled, not the data files\n\nThis script reads the input files directly from the .zip\n\"\"\"\n\n\nfrom collections import Counter\nimport random\nimport zipfile\n\nfrom stanza.utils.datasets.ner.utils import write_dataset\n\ndef read_sentences(infile):\n    \"\"\"\n    Read sentences from an open file\n    \"\"\"\n    sents = []\n    cache = []\n    for line in infile:\n        if isinstance(line, bytes):\n            line = line.decode()\n        line = line.rstrip()\n        if len(line) == 0:\n            if len(cache) > 0:\n                sents.append(cache)\n                cache = []\n            continue\n        array = line.split()\n        assert len(array) == 2\n        w, t = array\n        cache.append([w, t])\n    if len(cache) > 0:\n        sents.append(cache)\n        cache = []\n    return sents\n\n\ndef normalize_tags(sents):\n    new_sents = []\n    # normalize tags\n    for sent in sents:\n        new_sentence = []\n        for i, pair in enumerate(sent):\n            w, t = pair\n            if t.startswith('O'):\n                new_t = 'O'\n            elif t.startswith('I-'):\n                type = t[2:]\n                if type.startswith('MIS'):\n                    new_t = 'I-MISC'\n                elif type.startswith('-'): # handle I--ORG\n                    new_t = 'I-' + type[1:]\n                else:\n                    new_t = t\n            elif t.startswith('B-'):\n                type = t[2:]\n                if type.startswith('MIS'):\n                    new_t = 'B-MISC'\n                elif type.startswith('ENGLISH') or type.startswith('SPANISH'):\n                    new_t = 'O'\n                else:\n                    new_t = t\n            else:\n                new_t = 'O'\n            # modify original tag\n            new_sentence.append((sent[i][0], 
new_t))\n        new_sents.append(new_sentence)\n    return new_sents\n\n\ndef convert_shuffle(base_input_path, base_output_path, short_name):\n    \"\"\"\n    Convert AQMAR to a randomly shuffled dataset\n\n    base_input_path is the zip file.  base_output_path is the output directory\n    \"\"\"\n    if not zipfile.is_zipfile(base_input_path):\n        raise FileNotFoundError(\"Expected %s to be the zipfile with AQMAR in it\" % base_input_path)\n\n    with zipfile.ZipFile(base_input_path) as zin:\n        namelist = zin.namelist()\n        annotation_files = [x for x in namelist if x.endswith(\".txt\") and not \"/\" in x]\n        annotation_files = sorted(annotation_files)\n\n        # although not necessary for good results, this does put\n        # things in the same order the shell was alphabetizing files\n        # when the original models were created for Stanza\n        assert annotation_files[2] == 'Computer.txt'\n        assert annotation_files[3] == 'Computer_Software.txt'\n        annotation_files[2], annotation_files[3] = annotation_files[3], annotation_files[2]\n\n        if len(annotation_files) != 28:\n            raise RuntimeError(\"Expected exactly 28 labeled .txt files in %s but got %d\" % (base_input_path, len(annotation_files)))\n\n        sentences = []\n        for in_filename in annotation_files:\n            with zin.open(in_filename) as infile:\n                new_sentences = read_sentences(infile)\n            print(f\"{len(new_sentences)} sentences read from {in_filename}\")\n\n            new_sentences = normalize_tags(new_sentences)\n            sentences.extend(new_sentences)\n\n    all_tags = Counter([p[1] for sent in sentences for p in sent])\n    print(\"All tags after normalization:\")\n    print(list(all_tags.keys()))\n\n    num = len(sentences)\n    train_num = int(num*0.7)\n    dev_num = int(num*0.15)\n\n    random.seed(1234)\n\n    random.shuffle(sentences)\n\n    train_sents = sentences[:train_num]\n    dev_sents = 
sentences[train_num:train_num+dev_num]\n    test_sents = sentences[train_num+dev_num:]\n\n    shuffled_dataset = [train_sents, dev_sents, test_sents]\n\n    write_dataset(shuffled_dataset, base_output_path, short_name)\n\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_bn_daffodil.py",
    "content": "\"\"\"\nConvert a Bengali NER dataset to our internal .json format\n\nThe dataset is here:\n\nhttps://github.com/Rifat1493/Bengali-NER/tree/master/Input\n\"\"\"\n\nimport argparse\nimport os\nimport random\nimport tempfile\n\nfrom stanza.utils.datasets.ner.utils import read_tsv, write_dataset\n\ndef redo_time_tags(sentences):\n    \"\"\"\n    Replace all TIM, TIM with B-TIM, I-TIM\n\n    A brief use of Google Translate suggests the time phrases are\n    generally one phrase, so we don't want to turn this into B-TIM, B-TIM\n    \"\"\"\n    new_sentences = []\n\n    for sentence in sentences:\n        new_sentence = []\n        prev_time = False\n        for word, tag in sentence:\n            if tag == 'TIM':\n                if prev_time:\n                    new_sentence.append((word, \"I-TIM\"))\n                else:\n                    prev_time = True\n                    new_sentence.append((word, \"B-TIM\"))\n            else:\n                prev_time = False\n                new_sentence.append((word, tag))\n        new_sentences.append(new_sentence)\n\n    return new_sentences\n\ndef strip_words(dataset):\n    return [[(x[0].strip().replace('\\ufeff', ''), x[1]) for x in sentence] for sentence in dataset]\n\ndef filter_blank_words(train_file, train_filtered_file):\n    \"\"\"\n    As of July 2022, this dataset has blank words with O labels, which is not ideal\n\n    This method removes those lines\n    \"\"\"\n    with open(train_file, encoding=\"utf-8\") as fin:\n        with open(train_filtered_file, \"w\", encoding=\"utf-8\") as fout:\n            for line in fin:\n                if line.strip() == 'O':\n                    continue\n                fout.write(line)\n\ndef filter_broken_tags(train_sentences):\n    \"\"\"\n    Eliminate any sentences where any of the tags were empty\n    \"\"\"\n    return [x for x in train_sentences if not any(y[1] is None for y in x)]\n\ndef filter_bad_words(train_sentences):\n    \"\"\"\n    Not 
bad words like poop, but characters that don't exist\n\n    These characters look like n and l in emacs, but they are really\n    0xF06C and 0xF06E\n    \"\"\"\n    return [[x for x in sentence if x[0] not in (\"\\uf06c\", \"\\uf06e\")] for sentence in train_sentences]\n\ndef read_datasets(in_directory):\n    \"\"\"\n    Reads & splits the train data, reads the test data\n\n    There is no validation data, so we split the training data into\n    two pieces and use the smaller piece as the dev set\n\n    Also performed is a conversion of TIM -> B-TIM, I-TIM\n    \"\"\"\n    # make sure we always get the same shuffle & split\n    random.seed(1234)\n\n    train_file = os.path.join(in_directory, \"Input\", \"train_data.txt\")\n    with tempfile.TemporaryDirectory() as tempdir:\n        train_filtered_file = os.path.join(tempdir, \"train.txt\")\n        filter_blank_words(train_file, train_filtered_file)\n        train_sentences = read_tsv(train_filtered_file, text_column=0, annotation_column=1, keep_broken_tags=True)\n    train_sentences = filter_broken_tags(train_sentences)\n    train_sentences = filter_bad_words(train_sentences)\n    train_sentences = redo_time_tags(train_sentences)\n    train_sentences = strip_words(train_sentences)\n\n    test_file = os.path.join(in_directory, \"Input\", \"test_data.txt\")\n    test_sentences = read_tsv(test_file, text_column=0, annotation_column=1, keep_broken_tags=True)\n    test_sentences = filter_broken_tags(test_sentences)\n    test_sentences = filter_bad_words(test_sentences)\n    test_sentences = redo_time_tags(test_sentences)\n    test_sentences = strip_words(test_sentences)\n\n    random.shuffle(train_sentences)\n    split_len = len(train_sentences) * 9 // 10\n    dev_sentences = train_sentences[split_len:]\n    train_sentences = train_sentences[:split_len]\n\n    datasets = (train_sentences, dev_sentences, test_sentences)\n    return datasets\n\ndef convert_dataset(in_directory, out_directory):\n    \"\"\"\n    Reads the datasets 
using read_datasets, then write them back out\n    \"\"\"\n    datasets = read_datasets(in_directory)\n    write_dataset(datasets, out_directory, \"bn_daffodil\")\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_path', type=str, default=\"/home/john/extern_data/ner/bangla/Bengali-NER\", help=\"Where to find the files\")\n    parser.add_argument('--output_path', type=str, default=\"/home/john/stanza/data/ner\", help=\"Where to output the results\")\n    args = parser.parse_args()\n\n    convert_dataset(args.input_path, args.output_path)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_bsf_to_beios.py",
    "content": "import argparse\nimport logging\nimport os\nimport glob\nfrom collections import namedtuple\nimport re\nfrom typing import Tuple\nfrom tqdm import tqdm\nfrom random import choices, shuffle\n\nBsfInfo = namedtuple('BsfInfo', 'id, tag, start_idx, end_idx, token')\n\nlog = logging.getLogger(__name__)\nlog.setLevel(logging.INFO)\n\n\ndef format_token_as_beios(token: str, tag: str) -> list:\n    t_words = token.split()\n    res = []\n    if len(t_words) == 1:\n        res.append(token + ' S-' + tag)\n    else:\n        res.append(t_words[0] + ' B-' + tag)\n        for t_word in t_words[1: -1]:\n            res.append(t_word + ' I-' + tag)\n        res.append(t_words[-1] + ' E-' + tag)\n    return res\n\n\ndef format_token_as_iob(token: str, tag: str) -> list:\n    t_words = token.split()\n    res = []\n    if len(t_words) == 1:\n        res.append(token + ' B-' + tag)\n    else:\n        res.append(t_words[0] + ' B-' + tag)\n        for t_word in t_words[1:]:\n            res.append(t_word + ' I-' + tag)\n    return res\n\n\ndef convert_bsf(data: str, bsf_markup: str, converter: str = 'beios') -> str:\n    \"\"\"\n    Convert data file with NER markup in Brat Standoff Format to BEIOS or IOB format.\n\n    :param converter: iob or beios converter to use for document\n    :param data: tokenized data to be converted. 
Each token separated with a space\n    :param bsf_markup: Brat Standoff Format markup\n    :return: data in BEIOS or IOB format https://en.wikipedia.org/wiki/Inside–outside–beginning_(tagging)\n    \"\"\"\n\n    def join_simple_chunk(chunk: str) -> list:\n        if len(chunk.strip()) == 0:\n            return []\n        # keep the newlines, but discard the non-newline whitespace\n        tokens = re.split(r'(\\n)|\\s', chunk.strip())\n        # the re will return None for splits which were not caught in a group\n        tokens = [x for x in tokens if x is not None]\n        return [token + ' O' if len(token.strip()) > 0 else token for token in tokens]\n\n    converters = {'beios': format_token_as_beios, 'iob': format_token_as_iob}\n    res = []\n    markup = parse_bsf(bsf_markup)\n\n    prev_idx = 0\n    m_ln: BsfInfo\n    for m_ln in markup:\n        res += join_simple_chunk(data[prev_idx:m_ln.start_idx])\n\n        convert_f = converters[converter]\n        res.extend(convert_f(m_ln.token, m_ln.tag))\n        prev_idx = m_ln.end_idx\n\n    if prev_idx < len(data) - 1:\n        res += join_simple_chunk(data[prev_idx:])\n\n    return '\\n'.join(res)\n\n\ndef parse_bsf(bsf_data: str) -> list:\n    \"\"\"\n    Convert textual bsf representation to a list of named entities.\n\n    :param bsf_data: data in the format 'T9\tPERS 778 783    токен'\n    :return: list of named tuples for each line of the data representing a single named entity token\n    \"\"\"\n    if len(bsf_data.strip()) == 0:\n        return []\n\n    ln_ptrn = re.compile(r'(T\\d+)\\s(\\w+)\\s(\\d+)\\s(\\d+)\\s(.+?)(?=T\\d+\\s\\w+\\s\\d+\\s\\d+|$)', flags=re.DOTALL)\n    result = []\n    for m in ln_ptrn.finditer(bsf_data.strip()):\n        bsf = BsfInfo(m.group(1), m.group(2), int(m.group(3)), int(m.group(4)), m.group(5).strip())\n        result.append(bsf)\n    return result\n\n\nCORPUS_NAME = 'Ukrainian-languk'\n\n\ndef convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = 
'beios',\n                          doc_delim: str = '\\n', train_test_split_file: str = None) -> None:\n    \"\"\"\n\n    :param doc_delim: delimiter to be used between documents\n    :param src_dir_path: path to directory with BSF marked files\n    :param dst_dir_path: where to save output data\n    :param converter: `beios` or `iob` output formats\n    :param train_test_split_file: path to file containing train/test lists of file names\n    :return:\n    \"\"\"\n    ann_path = os.path.join(src_dir_path, '*.tok.ann')\n    ann_files = glob.glob(ann_path)\n    ann_files.sort()\n\n    tok_path = os.path.join(src_dir_path, '*.tok.txt')\n    tok_files = glob.glob(tok_path)\n    tok_files.sort()\n\n    corpus_folder = os.path.join(dst_dir_path, CORPUS_NAME)\n    if not os.path.exists(corpus_folder):\n        os.makedirs(corpus_folder)\n\n    if len(ann_files) == 0 or len(tok_files) == 0:\n        raise FileNotFoundError(f'Token and annotation files are not found at specified path {ann_path}')\n    if len(ann_files) != len(tok_files):\n        raise RuntimeError(f'Mismatch between Annotation and Token files. 
Ann files: {len(ann_files)}, token files: {len(tok_files)}')\n\n    train_set = []\n    dev_set = []\n    test_set = []\n\n    data_sets = [train_set, dev_set, test_set]\n    split_weights = (8, 1, 1)\n\n    if train_test_split_file is not None:\n        train_names, dev_names, test_names = read_languk_train_test_split(train_test_split_file)\n\n    log.info(f'Found {len(tok_files)} files in data folder \"{src_dir_path}\"')\n    for (tok_fname, ann_fname) in tqdm(zip(tok_files, ann_files), total=len(tok_files), unit='file'):\n        if tok_fname[:-3] != ann_fname[:-3]:\n            tqdm.write(f'Token and Annotation file names do not match ann={ann_fname}, tok={tok_fname}')\n            continue\n\n        with open(tok_fname) as tok_file, open(ann_fname) as ann_file:\n            token_data = tok_file.read()\n            ann_data = ann_file.read()\n            out_data = convert_bsf(token_data, ann_data, converter)\n\n            if train_test_split_file is None:\n                target_dataset = choices(data_sets, split_weights)[0]\n            else:\n                target_dataset = train_set\n                fkey = os.path.basename(tok_fname)[:-4]\n                if fkey in dev_names:\n                    target_dataset = dev_set\n                elif fkey in test_names:\n                    target_dataset = test_set\n\n            target_dataset.append(out_data)\n    log.info(f'Data is split as following: train={len(train_set)}, dev={len(dev_set)}, test={len(test_set)}')\n\n    # writing data to {train/dev/test}.bio files\n    names = ['train', 'dev', 'test']\n    if doc_delim != '\\n':\n        doc_delim = '\\n' + doc_delim + '\\n'\n    for idx, name in enumerate(names):\n        fname = os.path.join(corpus_folder, name + '.bio')\n        with open(fname, 'w') as f:\n            f.write(doc_delim.join(data_sets[idx]))\n        log.info('Writing to ' + fname)\n\n    log.info('All done')\n\n\ndef read_languk_train_test_split(file_path: str, dev_split: float = 
0.1) -> Tuple:\n    \"\"\"\n    Read predefined split of train and test files in data set. \n    Originally located under doc/dev-test-split.txt\n    :param file_path: path to dev-test-split.txt file (should include file name with extension)\n    :param dev_split: 0 to 1 float value defining how much to allocate to dev split\n    :return: tuple of (train, dev, test) each containing list of files to be used for respective data sets\n    \"\"\"\n    log.info(f'Trying to read train/dev/test split from file \"{file_path}\". Dev allocation = {dev_split}')\n    train_files, test_files, dev_files = [], [], []\n    container = test_files\n    with open(file_path, 'r') as f:\n        for ln in f:\n            ln = ln.strip()\n            if ln == 'DEV':\n                container = train_files\n            elif ln == 'TEST':\n                container = test_files\n            elif ln == '':\n                pass\n            else:\n                container.append(ln)\n\n    # split in file only contains train and test split. 
\n    # For Stanza training we need train, dev, test\n    # We will take part of train as dev set\n    # This way anyone using test set outside of this code base can be sure that there was no data set pollution\n    shuffle(train_files)\n    dev_files = train_files[: int(len(train_files) * dev_split)]\n    train_files = train_files[int(len(train_files) * dev_split):]\n\n    assert len(set(train_files).intersection(set(dev_files))) == 0\n\n    log.info(f'Files in each set: train={len(train_files)}, dev={len(dev_files)}, test={len(test_files)}')\n    return train_files, dev_files, test_files\n\n\nif __name__ == '__main__':\n    logging.basicConfig()\n\n    parser = argparse.ArgumentParser(description='Convert lang-uk NER data set from BSF format to BEIOS format compatible with Stanza NER model training requirements.\\n'\n                                                 'Original data set should be downloaded from https://github.com/lang-uk/ner-uk\\n'\n                                                 'For example, create a directory extern_data/ner/lang-uk, then run \"git clone git@github.com:lang-uk/ner-uk.git\"')\n    parser.add_argument('--src_dataset', type=str, default='extern_data/ner/lang-uk/ner-uk/data', help='Dir with lang-uk dataset \"data\" folder (https://github.com/lang-uk/ner-uk)')\n    parser.add_argument('--dst', type=str, default='data/ner', help='Where to store the converted dataset')\n    parser.add_argument('-c', type=str, default='beios', help='`beios` or `iob` formats to be used for output')\n    parser.add_argument('--doc_delim', type=str, default='\\n', help='Delimiter to be used to separate documents in the output data')\n    parser.add_argument('--split_file', type=str, help='Name of a file containing Train/Test split (files in train and test set)')\n    parser.print_help()\n    args = parser.parse_args()\n\n    convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim, train_test_split_file=args.split_file)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_bsnlp.py",
    "content": "import argparse\nimport glob\nimport os\nimport logging\nimport random\nimport re\n\nimport stanza\n\nlogger = logging.getLogger('stanza')\n\nAVAILABLE_LANGUAGES = (\"bg\", \"cs\", \"pl\", \"ru\")\n\ndef normalize_bg_entity(text, entity, raw):\n    entity = entity.strip()\n    # sanity check that the token is in the original text\n    if text.find(entity) >= 0:\n        return entity\n\n    # some entities have quotes, but the quotes are different from those in the data file\n    # for example:\n    #   training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_458.txt\n    #     'Съвета \"Общи въпроси\"'\n    #   training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1002.txt\n    #     'Съвет \"Общи въпроси\"'\n    if sum(1 for x in entity if x == '\"') == 2:\n        quote_entity = entity.replace('\"', '“')\n        if text.find(quote_entity) >= 0:\n            logger.info(\"searching for '%s' instead of '%s' in %s\" % (quote_entity, entity, raw))\n            return quote_entity\n\n        quote_entity = entity.replace('\"', '„', 1).replace('\"', '“')\n        if text.find(quote_entity) >= 0:\n            logger.info(\"searching for '%s' instead of '%s' in %s\" % (quote_entity, entity, raw))\n            return quote_entity\n\n    if sum(1 for x in entity if x == '\"') == 1:\n        quote_entity = entity.replace('\"', '„', 1)\n        if text.find(quote_entity) >= 0:\n            logger.info(\"searching for '%s' instead of '%s' in %s\" % (quote_entity, entity, raw))\n            return quote_entity\n\n    if entity.find(\"'\") >= 0:\n        quote_entity = entity.replace(\"'\", \"’\")\n        if text.find(quote_entity) >= 0:\n            logger.info(\"searching for '%s' instead of '%s' in %s\" % (quote_entity, entity, raw))\n            return quote_entity\n\n    lower_idx = text.lower().find(entity.lower())\n    if lower_idx >= 0:\n        fixed_entity = text[lower_idx:lower_idx+len(entity)]\n        logger.info(\"lowercase match found.  
Searching for '%s' instead of '%s' in %s\" % (fixed_entity, entity, raw))\n        return fixed_entity\n\n    substitution_pairs = {\n        # this exact error happens in:\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_67.txt\n        'Съвет по общи въпроси': 'Съвета по общи въпроси',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_214.txt\n        'Сумимото Мицуи файненшъл груп': 'Сумитомо Мицуи файненшъл груп',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_245.txt\n        'С и Д': 'С&Д',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_348.txt\n        'законопроекта за излизане на Великобритания за излизане от Европейския съюз': 'законопроекта за излизане на Великобритания от Европейския съюз',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_771.txt\n        'Унивеситета в Есекс': 'Университета в Есекс',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_779.txt\n        'Съвет за сигурност на ООН': 'Съвета за сигурност на ООН',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_787.txt\n        'Федерика Могерини': 'Федереика Могерини',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_938.txt\n        'Уайстейбъл': 'Уайтстейбъл',\n        'Партията за независимост на Обединеното кралство': 'Партията на независимостта на Обединеното кралство',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_972.txt\n        'Европейска банка за възстановяване и развитие': 'Европейската банка за възстановяване и развитие',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1065.txt\n        'Харолд Уилсон': 'Харолд Уилсън',\n        'Манчестърски университет': 'Манчестърския университет',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1096.txt\n        'Обединеното кралство в променящата се Европа': 'Обединеното кралство в променяща се Европа',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1175.txt\n        'The Daily 
Express': 'Daily Express',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1186.txt\n        'демократичната юнионистка партия': 'демократична юнионистка партия',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1192.txt\n        'Европейската агенция за безопасността на полетите': 'Европейската агенция за сигурността на полетите',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1219.txt\n        'пресцентъра на Външно министертво': 'пресцентъра на Външно министерство',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1281.txt\n        'Европейска агенциа за безопасността на полетите': 'Европейската агенция за сигурността на полетите',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1415.txt\n        'Хонк Конг': 'Хонг Конг',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1663.txt\n        'Лейбъристка партия': 'Лейбъристката партия',\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1963.txt\n        'Найджъл Фараж': 'Найджъл Фарадж',\n        'Фараж': 'Фарадж',\n\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1773.txt has an entity which is mixed Cyrillic and Ascii\n        'Tescо': 'Tesco',\n    }\n\n    if entity in substitution_pairs and text.find(substitution_pairs[entity]) >= 0:\n        fixed_entity = substitution_pairs[entity]\n        logger.info(\"searching for '%s' instead of '%s' in %s\" % (fixed_entity, entity, raw))\n        return fixed_entity\n\n    # oops, can't find it anywhere\n    # want to raise ValueError but there are just too many in the train set for BG\n    logger.error(\"Could not find '%s' in %s\" % (entity, raw))\n\ndef fix_bg_typos(text, raw_filename):\n    typo_pairs = {\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_202.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters\n        'brexit_bg.txt_file_202.txt':  ('Вlооmbеrg', 'Bloomberg'),\n        # 
training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_261.txt has a typo: Telegaph instead of Telegraph\n        'brexit_bg.txt_file_261.txt':  ('Telegaph', 'Telegraph'),\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_574.txt has a typo: politicalskrapbook instead of politicalscrapbook\n        'brexit_bg.txt_file_574.txt':  ('politicalskrapbook', 'politicalscrapbook'),\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_861.txt has a mix of cyrillic and ascii\n        'brexit_bg.txt_file_861.txt':  ('Съвета „Общи въпроси“', 'Съветa \"Общи въпроси\"'),\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_992.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters\n        'brexit_bg.txt_file_992.txt':  ('The Guardiаn', 'The Guardian'),\n        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1856.txt has a typo: Southerb instead of Southern\n        'brexit_bg.txt_file_1856.txt': ('Southerb', 'Southern'),\n    }\n\n    filename = os.path.split(raw_filename)[1]\n    if filename in typo_pairs:\n        replacement = typo_pairs.get(filename)\n        text = text.replace(replacement[0], replacement[1])\n\n    return text\n\ndef get_sentences(language, pipeline, annotated, raw):\n    if language == 'bg':\n        normalize_entity = normalize_bg_entity\n        fix_typos = fix_bg_typos\n    else:\n        raise AssertionError(\"Please build a normalize_%s_entity and fix_%s_typos first\" % language)\n\n    annotated_sentences = []\n    with open(raw) as fin:\n        lines = fin.readlines()\n    if len(lines) < 5:\n        raise ValueError(\"Unexpected format in %s\" % raw)\n    text = \"\\n\".join(lines[4:])\n    text = fix_typos(text, raw)\n\n    entities = {}\n    with open(annotated) as fin:\n        # first line\n        header = fin.readline().strip()\n        if len(header.split(\"\\t\")) > 1:\n            raise ValueError(\"Unexpected missing header line in %s\" % annotated)\n        for line in fin:\n     
       pieces = line.strip().split(\"\\t\")\n            if len(pieces) < 3 or len(pieces) > 4:\n                raise ValueError(\"Unexpected annotation format in %s\" % annotated)\n\n            entity = normalize_entity(text, pieces[0], raw)\n            if not entity:\n                continue\n            if entity in entities:\n                if entities[entity] != pieces[2]:\n                    # would like to make this an error, but it actually happens and it's not clear how to fix\n                    # annotated/nord_stream/bg/nord_stream_bg.txt_file_119.out\n                    logger.warn(\"found multiple definitions for %s in %s\" % (pieces[0], annotated))\n                    entities[entity] = pieces[2]\n            else:\n                entities[entity] = pieces[2]\n\n    tokenized = pipeline(text)\n    # The benefit of doing these one at a time, instead of all at once,\n    # is that nested entities won't clobber previously labeled entities.\n    # For example, the file\n    #   training_pl_cs_ru_bg_rc1/annotated/bg/brexit_bg.txt_file_994.out\n    # has each of:\n    #   Северна Ирландия\n    #   Република Ирландия\n    #   Ирландия\n    # By doing the larger ones first, we can detect and skip the ones\n    # we already labeled when we reach the shorter one\n    regexes = [re.compile(re.escape(x)) for x in sorted(entities.keys(), key=len, reverse=True)]\n\n    bad_sentences = set()\n\n    for regex in regexes:\n        for match in regex.finditer(text):\n            start_char, end_char = match.span()\n            # this is inefficient, but for something only run once, it shouldn't matter\n            start_token = None\n            start_sloppy = False\n            end_token = None\n            end_sloppy = False\n            for token in tokenized.iter_tokens():\n                if token.start_char <= start_char and token.end_char > start_char:\n                    start_token = token\n                    if token.start_char != start_char:\n   
                     start_sloppy = True\n                if token.start_char <= end_char and token.end_char >= end_char:\n                    end_token = token\n                    if token.end_char != end_char:\n                        end_sloppy = True\n                    break\n            if start_token is None or end_token is None:\n                raise RuntimeError(\"Match %s did not align with any tokens in %s\" % (match.group(0), raw))\n            if not start_token.sent is end_token.sent:\n                bad_sentences.add(start_token.sent.id)\n                bad_sentences.add(end_token.sent.id)\n                logger.warn(\"match %s spanned sentences %d and %d in document %s\" % (match.group(0), start_token.sent.id, end_token.sent.id, raw))\n                continue\n\n            # ids start at 1, not 0, so we have to subtract 1\n            # then the end token is included, so we add back the 1\n            # TODO: verify that this is correct if the language has MWE - cs, pl, for example\n            tokens = start_token.sent.tokens[start_token.id[0]-1:end_token.id[0]]\n            if all(token.ner for token in tokens):\n                # skip matches which have already been made\n                # this has the nice side effect of not complaining if\n                # a smaller match is found after a larger match\n                # earlier set the NER on those tokens\n                continue\n\n            if start_sloppy and end_sloppy:\n                bad_sentences.add(start_token.sent.id)\n                logger.warn(\"match %s matched in the middle of a token in %s\" % (match.group(0), raw))\n                continue\n            if start_sloppy:\n                bad_sentences.add(end_token.sent.id)\n                logger.warn(\"match %s started matching in the middle of a token in %s\" % (match.group(0), raw))\n                #print(start_token)\n                #print(end_token)\n                #print(start_char, end_char)\n             
   continue\n            if end_sloppy:\n                bad_sentences.add(start_token.sent.id)\n                logger.warn(\"match %s ended matching in the middle of a token in %s\" % (match.group(0), raw))\n                #print(start_token)\n                #print(end_token)\n                #print(start_char, end_char)\n                continue\n            match_text = match.group(0)\n            if match_text not in entities:\n                raise RuntimeError(\"Matched %s, which is not in the entities from %s\" % (match_text, annotated))\n            ner_tag = entities[match_text]\n            tokens[0].ner = \"B-\" + ner_tag\n            for token in tokens[1:]:\n                token.ner = \"I-\" + ner_tag\n\n    for sentence in tokenized.sentences:\n        if not sentence.id in bad_sentences:\n            annotated_sentences.append(sentence)\n\n    return annotated_sentences\n\ndef write_sentences(output_filename, annotated_sentences):\n    logger.info(\"Writing %d sentences to %s\" % (len(annotated_sentences), output_filename))\n    with open(output_filename, \"w\") as fout:\n        for sentence in annotated_sentences:\n            for token in sentence.tokens:\n                ner_tag = token.ner\n                if not ner_tag:\n                    ner_tag = \"O\"\n                fout.write(\"%s\\t%s\\n\" % (token.text, ner_tag))\n            fout.write(\"\\n\")\n\n\ndef convert_bsnlp(language, base_input_path, output_filename, split_filename=None):\n    \"\"\"\n    Converts the BSNLP dataset for the given language.\n\n    If only one output_filename is provided, all of the output goes to that file.\n    If split_filename is provided as well, 15% of the output chosen randomly\n      goes there instead.  
The dataset has no dev set, so this helps\n      divide the data into train/dev/test.\n    Note that the custom error fixes are only done for BG currently.\n    Please manually correct the data as appropriate before using this\n      for another language.\n    \"\"\"\n    if language not in AVAILABLE_LANGUAGES:\n        raise ValueError(\"The current BSNLP datasets only include the following languages: %s\" % \",\".join(AVAILABLE_LANGUAGES))\n    if language != \"bg\":\n        raise ValueError(\"There were quite a few data fixes needed to get the data correct for BG.  Please work on similar fixes before using the model for %s\" % language.upper())\n    pipeline = stanza.Pipeline(language, processors=\"tokenize\")\n    random.seed(1234)\n\n    annotated_path = os.path.join(base_input_path, \"annotated\", \"*\", language, \"*\")\n    annotated_files = sorted(glob.glob(annotated_path))\n    raw_path = os.path.join(base_input_path, \"raw\", \"*\", language, \"*\")\n    raw_files = sorted(glob.glob(raw_path))\n\n    # if the instructions for downloading the data from the\n    # process_ner_dataset script are followed, there will be two test\n    # directories of data and a separate training directory of data.\n    if len(annotated_files) == 0 and len(raw_files) == 0:\n        logger.info(\"Could not find files in %s\" % annotated_path)\n        annotated_path = os.path.join(base_input_path, \"annotated\", language, \"*\")\n        logger.info(\"Trying %s instead\" % annotated_path)\n        annotated_files = sorted(glob.glob(annotated_path))\n        raw_path = os.path.join(base_input_path, \"raw\", language, \"*\")\n        raw_files = sorted(glob.glob(raw_path))\n\n    if len(annotated_files) != len(raw_files):\n        raise ValueError(\"Unexpected differences in the file lists between %s and %s\" % (annotated_files, raw_files))\n\n    for i, j in zip(annotated_files, raw_files):\n        if os.path.split(i)[1][:-4] != os.path.split(j)[1][:-4]:\n            raise 
ValueError(\"Unexpected differences in the file lists: found %s instead of %s\" % (i, j))\n\n    annotated_sentences = []\n    if split_filename:\n        split_sentences = []\n    for annotated, raw in zip(annotated_files, raw_files):\n        new_sentences = get_sentences(language, pipeline, annotated, raw)\n        if not split_filename or random.random() < 0.85:\n            annotated_sentences.extend(new_sentences)\n        else:\n            split_sentences.extend(new_sentences)\n\n    write_sentences(output_filename, annotated_sentences)\n    if split_filename:\n        write_sentences(split_filename, split_sentences)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--language', type=str, default=\"bg\", help=\"Language to process\")\n    parser.add_argument('--input_path', type=str, default=\"/home/john/extern_data/ner/bsnlp2019\", help=\"Where to find the files\")\n    parser.add_argument('--output_path', type=str, default=\"/home/john/stanza/data/ner/bg_bsnlp.test.csv\", help=\"Where to output the results\")\n    parser.add_argument('--dev_path', type=str, default=None, help=\"A secondary output path - 15%% of the data will go here\")\n    args = parser.parse_args()\n\n    convert_bsnlp(args.language, args.input_path, args.output_path, args.dev_path)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_en_conll03.py",
    "content": "\"\"\"\nDownloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json\n\nSome online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF:\nhttps://huggingface.co/datasets/conll2003\n\"\"\"\n\nimport os\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.datasets.ner.utils import write_dataset\n\nTAG_TO_ID = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}\nID_TO_TAG = {y: x for x, y in TAG_TO_ID.items()}\n\ndef convert_dataset_section(section):\n    sentences = []\n    for item in section:\n        words = item['tokens']\n        tags = [ID_TO_TAG[x] for x in item['ner_tags']]\n        sentences.append(list(zip(words, tags)))\n    return sentences\n\ndef process_dataset(short_name, conll_path, ner_output_path):\n    try:\n        from datasets import load_dataset\n    except ImportError as e:\n        raise ImportError(\"Please install the datasets package to process CoNLL03 with Stanza\")\n\n    dataset = load_dataset('conll2003', cache_dir=conll_path)\n    datasets = [convert_dataset_section(x) for x in [dataset['train'], dataset['validation'], dataset['test']]]\n    write_dataset(datasets, ner_output_path, short_name)\n\ndef main():\n    paths = get_default_paths()\n    ner_input_path = paths['NERBASE']\n    conll_path = os.path.join(ner_input_path, \"english\", \"en_conll03\")\n    ner_output_path = paths['NER_DATA_DIR']\n    process_dataset(\"en_conll03\", conll_path, ner_output_path)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_fire_2013.py",
    "content": "\"\"\"\nConverts the FIRE 2013 dataset to TSV\n\nhttp://au-kbc.org/nlp/NER-FIRE2013/index.html\n\nThe dataset is in six tab separated columns.  The columns are\n\nword tag chunk ner1 ner2 ner3\n\nThis script keeps just the word and the ner1.  It is quite possible that using the tag would help\n\"\"\"\n\nimport argparse\nimport glob\nimport os\nimport random\n\ndef normalize(e1, e2, e3):\n    if e1 == 'o':\n        return \"O\"\n\n    if e2 != 'o' and e1[:2] != e2[:2]:\n        raise ValueError(\"Found a token with conflicting position tags %s,%s\" % (e1, e2))\n    if e3 != 'o' and e2 == 'o':\n        raise ValueError(\"Found a token with tertiary label but no secondary label %s,%s,%s\" % (e1, e2, e3))\n    if e3 != 'o' and (e1[:2] != e2[:2] or e1[:2] != e3[:2]):\n        raise ValueError(\"Found a token with conflicting position tags %s,%s,%s\" % (e1, e2, e3))\n\n    if e1[2:] in ('ORGANIZATION', 'FACILITIES'):\n        return e1\n    if e1[2:] == 'ENTERTAINMENT' and e2[2:] != 'SPORTS' and e2[2:] != 'CINEMA':\n        return e1\n    if e1[2:] == 'DISEASE' and e2 == 'o':\n        return e1\n    if e1[2:] == 'PLANTS' and e2[2:] != 'PARTS':\n        return e1\n    if e1[2:] == 'PERSON' and e2[2:] == 'INDIVIDUAL':\n        return e1\n    if e1[2:] == 'LOCATION' and e2[2:] == 'PLACE':\n        return e1\n    if e1[2:] in ('DATE', 'TIME', 'YEAR'):\n        string = e1[:2] + 'DATETIME'\n        return string\n\n    return \"O\"\n\ndef read_fileset(filenames):\n    # first, read the sentences from each data file\n    sentences = []\n    for filename in filenames:\n        with open(filename) as fin:\n            next_sentence = []\n            for line in fin:\n                line = line.strip()\n                if not line:\n                    # lots of single line \"sentences\" in the dataset\n                    if next_sentence:\n                        if len(next_sentence) > 1:\n                            sentences.append(next_sentence)\n           
             next_sentence = []\n                else:\n                    next_sentence.append(line)\n            if next_sentence and len(next_sentence) > 1:\n                sentences.append(next_sentence)\n    return sentences\n\ndef write_fileset(output_csv_file, sentences):\n    with open(output_csv_file, \"w\") as fout:\n        for sentence in sentences:\n            for line in sentence:\n                pieces = line.split(\"\\t\")\n                if len(pieces) != 6:\n                    raise ValueError(\"Found %d pieces instead of the expected 6\" % len(pieces))\n                if pieces[3] == 'o' and (pieces[4] != 'o' or pieces[5] != 'o'):\n                    raise ValueError(\"Inner NER labeled but the top layer was O\")\n                fout.write(\"%s\\t%s\\n\" % (pieces[0], normalize(pieces[3], pieces[4], pieces[5])))\n            fout.write(\"\\n\")\n\ndef convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file):\n    random.seed(1234)\n\n    filenames = glob.glob(os.path.join(input_path, \"*\"))\n\n    # won't be numerically sorted... 
shouldn't matter\n    filenames = sorted(filenames)\n    random.shuffle(filenames)\n\n    sentences = read_fileset(filenames)\n    random.shuffle(sentences)\n\n    train_cutoff = int(0.8 * len(sentences))\n    dev_cutoff   = int(0.9 * len(sentences))\n\n    train_sentences = sentences[:train_cutoff]\n    dev_sentences   = sentences[train_cutoff:dev_cutoff]\n    test_sentences  = sentences[dev_cutoff:]\n\n    random.shuffle(train_sentences)\n    random.shuffle(dev_sentences)\n    random.shuffle(test_sentences)\n\n    assert len(train_sentences) > 0\n    assert len(dev_sentences) > 0\n    assert len(test_sentences) > 0\n\n    write_fileset(train_csv_file, train_sentences)\n    write_fileset(dev_csv_file,   dev_sentences)\n    write_fileset(test_csv_file,  test_sentences)\n    \nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_path', type=str, default=\"/home/john/extern_data/ner/FIRE2013/hindi_train\",  help=\"Directory with raw files to read\")\n    parser.add_argument('--train_file', type=str, default=\"/home/john/stanza/data/ner/hi_fire2013.train.csv\", help=\"Where to put the train file\")\n    parser.add_argument('--dev_file',   type=str, default=\"/home/john/stanza/data/ner/hi_fire2013.dev.csv\",   help=\"Where to put the dev file\")\n    parser.add_argument('--test_file',  type=str, default=\"/home/john/stanza/data/ner/hi_fire2013.test.csv\",  help=\"Where to put the test file\")\n    args = parser.parse_args()\n\n    convert_fire_2013(args.input_path, args.train_file, args.dev_file, args.test_file)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_he_iahlt.py",
    "content": "from collections import defaultdict\nimport os\nimport re\n\nfrom stanza.utils.conll import CoNLL\nimport stanza.utils.default_paths as default_paths\nfrom stanza.utils.datasets.ner.utils import write_dataset\n\ndef output_entities(sentence):\n    for word in sentence.words:\n        misc = word.misc\n        if misc is None:\n            continue\n\n        pieces = misc.split(\"|\")\n        for piece in pieces:\n            if piece.startswith(\"Entity=\"):\n                entity = piece.split(\"=\", maxsplit=1)[1]\n                print(\"  \" + entity)\n                break\n\ndef extract_single_sentence(sentence):\n    current_entity = []\n    words = []\n    for word in sentence.words:\n        text = word.text\n        misc = word.misc\n        if misc is None:\n            pieces = []\n        else:\n            pieces = misc.split(\"|\")\n\n        closes = []\n        first_entity = False\n        for piece in pieces:\n            if piece.startswith(\"Entity=\"):\n                entity = piece.split(\"=\", maxsplit=1)[1]\n                entity_pieces = re.split(r\"([()])\", entity)\n                entity_pieces = [x for x in entity_pieces if x]   # remove blanks from re.split\n                entity_idx = 0\n                while entity_idx < len(entity_pieces):\n                    if entity_pieces[entity_idx] == '(':\n                        assert len(entity_pieces) > entity_idx + 1, \"Opening an unspecified entity\"\n                        if len(current_entity) == 0:\n                            first_entity = True\n                        current_entity.append(entity_pieces[entity_idx + 1])\n                        entity_idx += 2\n                    elif entity_pieces[entity_idx] == ')':\n                        assert entity_idx != 0, \"Closing an unspecified entity\"\n                        closes.append(entity_pieces[entity_idx-1])\n                        entity_idx += 1\n                    else:\n                     
   # the entities themselves get added or removed via the ()\n                        entity_idx += 1\n\n        if len(current_entity) == 0:\n            entity = 'O'\n        else:\n            entity = current_entity[0]\n            entity = \"B-\" + entity if first_entity else \"I-\" + entity\n        words.append((text, entity))\n\n        assert len(current_entity) >= len(closes), \"Too many closes for the current open entities\"\n        for close_entity in closes:\n            # TODO: check the close is closing the right thing\n            assert close_entity == current_entity[-1], \"Closed the wrong entity: %s vs %s\" % (close_entity, current_entity[-1])\n            current_entity = current_entity[:-1]\n    return words\n\ndef extract_sentences(doc):\n    sentences = []\n    for sentence in doc.sentences:\n        try:\n            words = extract_single_sentence(sentence)\n            sentences.append(words)\n        except AssertionError as e:\n            print(\"Skipping sentence %s  ... 
%s\" % (sentence.sent_id, str(e)))\n            output_entities(sentence)\n\n    return sentences\n\ndef convert_iahlt(udbase, output_dir, short_name):\n    shards = (\"train\", \"dev\", \"test\")\n    ud_datasets = [\"UD_Hebrew-IAHLTwiki\", \"UD_Hebrew-IAHLTknesset\"]\n    base_filenames = [\"he_iahltwiki-ud-%s.conllu\", \"he_iahltknesset-ud-%s.conllu\"]\n    datasets = defaultdict(list)\n\n    for ud_dataset, base_filename in zip(ud_datasets, base_filenames):\n        ud_dataset_path = os.path.join(udbase, ud_dataset)\n        for shard in shards:\n            filename = os.path.join(ud_dataset_path, base_filename % shard)\n            doc = CoNLL.conll2doc(filename)\n            sentences = extract_sentences(doc)\n            print(\"Read %d sentences from %s\" % (len(sentences), filename))\n            datasets[shard].extend(sentences)\n\n    datasets = [datasets[x] for x in shards]\n    write_dataset(datasets, output_dir, short_name)\n\ndef main():\n    paths = default_paths.get_default_paths()\n\n    udbase = paths[\"UDBASE_GIT\"]\n    output_directory = paths[\"NER_DATA_DIR\"]\n    convert_iahlt(udbase, output_directory, \"he_iahlt\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_hy_armtdp.py",
    "content": "\"\"\"\nConvert a ArmTDP-NER dataset to BIO format\n\nThe dataset is here:\n\nhttps://github.com/myavrum/ArmTDP-NER.git\n\"\"\"\n\nimport argparse\nimport os\nimport json\nimport re\nimport stanza\nimport random\nfrom tqdm import tqdm\n\nfrom stanza import DownloadMethod, Pipeline\nimport stanza.utils.default_paths as default_paths\n\ndef read_data(path: str) -> list:\n    \"\"\"\n    Reads the Armenian named entity recognition dataset\n\n    Returns a list of dictionaries.\n    Each dictionary contains information\n    about a paragraph (text, labels, etc.)\n    \"\"\"\n    with open(path, 'r') as file:\n        paragraphs = [json.loads(line) for line in file]\n    return paragraphs\n\n\ndef filter_unicode_broken_characters(text: str) -> str:\n    \"\"\"\n    Removes all unicode characters in text\n    \"\"\"\n    return re.sub(r'\\\\u[A-Za-z0-9]{4}', '', text)\n\n\ndef get_label(tok_start_char: int, tok_end_char: int, labels: list) -> list:\n    \"\"\"\n    Returns the label that corresponds to the given token\n    \"\"\"\n    for label in labels:\n        if label[0] <= tok_start_char and label[1] >= tok_end_char:\n            return label\n    return []\n\n\ndef format_sentences(paragraphs: list, nlp_hy: Pipeline) -> list:\n    \"\"\"\n    Takes a list of paragraphs and returns a list of sentences,\n    where each sentence is a list of tokens along with their respective entity tags.\n    \"\"\"\n    sentences = []\n    for paragraph in tqdm(paragraphs):\n        doc = nlp_hy(filter_unicode_broken_characters(paragraph['text']))\n        for sentence in doc.sentences:\n            sentence_ents = []\n            entity = []\n            for token in sentence.tokens:\n                label = get_label(token.start_char, token.end_char, paragraph['labels'])\n                if label:\n                    entity.append(token.text)\n                    if token.end_char == label[1]:\n                        sentence_ents.append({'tokens': entity,\n     
                                         'tag': label[2]})\n                        entity = []\n                else:\n                    sentence_ents.append({'tokens': [token.text],\n                                          'tag': 'O'})\n            sentences.append(sentence_ents)\n    return sentences\n\n\ndef convert_to_bioes(sentences: list) -> list:\n    \"\"\"\n    Returns a list of strings where each string represents a sentence in BIOES format\n    \"\"\"\n    beios_sents = []\n    for sentence in tqdm(sentences):\n        sentence_toc = ''\n        for ent in sentence:\n            if ent['tag'] == 'O':\n                sentence_toc += ent['tokens'][0] + '\\tO' + '\\n'\n            else:\n                if len(ent['tokens']) == 1:\n                    sentence_toc += ent['tokens'][0] + '\\tS-' + ent['tag'] + '\\n'\n                else:\n                    sentence_toc += ent['tokens'][0] + '\\tB-' + ent['tag'] + '\\n'\n                    for token in ent['tokens'][1:-1]:\n                        sentence_toc += token + '\\tI-' + ent['tag'] + '\\n'\n                    sentence_toc += ent['tokens'][-1] + '\\tE-' + ent['tag'] + '\\n'\n        beios_sents.append(sentence_toc)\n    return beios_sents\n\n\ndef write_sentences_to_file(sents, filename):\n    print(f\"Writing {len(sents)} sentences to {filename}\")\n    with open(filename, 'w') as outfile:\n        for sent in sents:\n            outfile.write(sent + '\\n\\n')\n\n\ndef train_test_dev_split(sents, base_output_path, short_name, train_fraction=0.7, dev_fraction=0.15):\n    \"\"\"\n    Splits a list of sentences into training, dev, and test sets,\n    and writes each set to a separate file with write_sentences_to_file\n    \"\"\"\n    num = len(sents)\n    train_num = int(num * train_fraction)\n    dev_num = int(num * dev_fraction)\n    if train_fraction + dev_fraction > 1.0:\n        raise ValueError(\n            \"Train and dev fractions added up to more than 1: {} {} 
{}\".format(train_fraction, dev_fraction, train_fraction + dev_fraction))\n\n    random.shuffle(sents)\n    train_sents = sents[:train_num]\n    dev_sents = sents[train_num:train_num + dev_num]\n    test_sents = sents[train_num + dev_num:]\n    batches = [train_sents, dev_sents, test_sents]\n    filenames = [f'{short_name}.train.tsv', f'{short_name}.dev.tsv', f'{short_name}.test.tsv']\n    for batch, filename in zip(batches, filenames):\n        write_sentences_to_file(batch, os.path.join(base_output_path, filename))\n\n\ndef convert_dataset(base_input_path, base_output_path, short_name, download_method=DownloadMethod.DOWNLOAD_RESOURCES):\n    nlp_hy = stanza.Pipeline(lang='hy', processors='tokenize', download_method=download_method)\n    paragraphs = read_data(os.path.join(base_input_path, 'ArmNER-HY.json1'))\n    tagged_sentences = format_sentences(paragraphs, nlp_hy)\n    beios_sentences = convert_to_bioes(tagged_sentences)\n    train_test_dev_split(beios_sentences, base_output_path, short_name)\n\n\nif __name__ == '__main__':\n    paths = default_paths.get_default_paths()\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_path', type=str, default=os.path.join(paths[\"NERBASE\"], \"armenian\", \"ArmTDP-NER\"), help=\"Path to input file\")\n    parser.add_argument('--output_path', type=str, default=paths[\"NER_DATA_DIR\"], help=\"Path to the output directory\")\n    parser.add_argument('--short_name', type=str, default=\"hy_armtdp\", help=\"Name to identify the dataset and the model\")\n    parser.add_argument('--download_method', type=str, default=DownloadMethod.DOWNLOAD_RESOURCES, help=\"Download method for initializing the Pipeline.  Default downloads the Armenian pipeline, --download_method NONE does not.  Options: %s\" % DownloadMethod._member_names_)\n    args = parser.parse_args()\n\n    convert_dataset(args.input_path, args.output_path, args.short_name, args.download_method)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_ijc.py",
    "content": "import argparse\nimport random\nimport sys\n\n\"\"\"\nConverts IJC data to a TSV format.\n\nSo far, tested on Hindi.  Not checked on any of the other languages.\n\"\"\"\n\ndef convert_tag(tag):\n    \"\"\"\n    Project the classes IJC used to 4 classes with more human-readable names\n\n    The trained result is a pile, as I inadvertently taught my\n    daughter to call horrible things, but leaving them with the\n    original classes is also a pile\n    \"\"\"\n    if not tag:\n        return \"O\"\n    if tag == \"NEP\":\n        return \"PER\"\n    if tag == \"NEO\":\n        return \"ORG\"\n    if tag == \"NEL\":\n        return \"LOC\"\n    return \"MISC\"\n\ndef read_single_file(input_file, bio_format=True):\n    \"\"\"\n    Reads an IJC NER file and returns a list of list of lines\n    \"\"\"\n    sentences = []\n    lineno = 0\n    with open(input_file) as fin:\n        current_sentence = []\n        in_ner = False\n        in_sentence = False\n        printed_first = False\n        nesting = 0\n        for line in fin:\n            lineno = lineno + 1\n            line = line.strip()\n            if not line:\n                continue\n            if line.startswith(\"<Story\") or line.startswith(\"</Story>\"):\n                assert not current_sentence, \"File %s had an unexpected <Story> tag\" % input_file\n                continue\n\n            if line.startswith(\"<Sentence\"):\n                assert not current_sentence, \"File %s has a nested sentence\" % input_file\n                continue\n\n            if line.startswith(\"</Sentence>\"):\n                # Would like to assert that empty sentences don't exist, but alas, they do\n                # assert current_sentence, \"File %s has an empty sentence at %d\" % (input_file, lineno)\n                # AssertionError: File .../hi_ijc/training-hindi/193.naval.utf8 has an empty sentence at 74\n                if current_sentence:\n                    
sentences.append(current_sentence)\n                current_sentence = []\n                continue\n\n            if line == \"))\":\n                assert in_sentence, \"File %s closed a sentence when there was no open sentence at %d\" % (input_file, lineno)\n                nesting = nesting - 1\n                if nesting < 0:\n                    in_sentence = False\n                    nesting = 0\n                elif nesting == 0:\n                    in_ner = False\n                continue\n\n            pieces = line.split(\"\\t\")\n            if pieces[0] == '0':\n                assert pieces[1] == '((', \"File %s has an unexpected first line at %d\" % (input_file, lineno)\n                in_sentence = True\n                continue\n\n            if pieces[1] == '((':\n                nesting = nesting + 1\n                if nesting == 1:\n                    if len(pieces) < 4:\n                        tag = None\n                    else:\n                        assert pieces[3][0] == '<' and pieces[3][-1] == '>', \"File %s has an unexpected tag format at %d: %s\" % (input_file, lineno, pieces[3])\n                        ne, tag = pieces[3][1:-1].split('=', 1)\n                        assert pieces[3] == \"<%s=%s>\" % (ne, tag), \"File %s has an unexpected tag format at %d: %s\" % (input_file, lineno, pieces[3])\n                        in_ner = True\n                        printed_first = False\n                        tag = convert_tag(tag)\n            elif in_ner and tag:\n                if bio_format:\n                    if printed_first:\n                        current_sentence.append((pieces[1], \"I-\" + tag))\n                    else:\n                        current_sentence.append((pieces[1], \"B-\" + tag))\n                        printed_first = True\n                else:\n                    current_sentence.append((pieces[1], tag))\n            else:\n                current_sentence.append((pieces[1], \"O\"))\n    assert 
not current_sentence, \"File %s is unclosed!\" % input_file\n    return sentences\n\ndef read_ijc_files(input_files, bio_format=True):\n    sentences = []\n    for input_file in input_files:\n        sentences.extend(read_single_file(input_file, bio_format))\n    return sentences\n\ndef convert_ijc(input_files, csv_file, bio_format=True):\n    sentences = read_ijc_files(input_files, bio_format)\n    with open(csv_file, \"w\") as fout:\n        for sentence in sentences:\n            for word in sentence:\n                fout.write(\"%s\\t%s\\n\" % word)\n            fout.write(\"\\n\")\n\ndef convert_split_ijc(input_files, train_csv, dev_csv):\n    \"\"\"\n    Randomly splits the given list of input files into a train/dev with 85/15 split\n\n    The original datasets only have train & test\n    \"\"\"\n    random.seed(1234)\n    train_files = []\n    dev_files = []\n    for filename in input_files:\n        if random.random() < 0.85:\n            train_files.append(filename)\n        else:\n            dev_files.append(filename)\n\n    if len(train_files) == 0 or len(dev_files) == 0:\n        raise RuntimeError(\"Not enough files to split into train & dev\")\n\n    convert_ijc(train_files, train_csv)\n    convert_ijc(dev_files, dev_csv)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--output_path', type=str, default=\"/home/john/stanza/data/ner/hi_ijc.test.csv\", help=\"Where to output the results\")\n    parser.add_argument('input_files', metavar='N', nargs='+', help='input files to process')\n    args = parser.parse_args()\n\n    convert_ijc(args.input_files, args.output_path, False)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_kk_kazNERD.py",
    "content": "\"\"\"\nConvert a Kazakh NER dataset to our internal .json format\nThe dataset is here:\n\nhttps://github.com/IS2AI/KazNERD/tree/main/KazNERD\n\"\"\"\n\nimport argparse\nimport os\nimport shutil\n# import random\n\nfrom stanza.utils.datasets.ner.utils import convert_bio_to_json, SHARDS\n\ndef convert_dataset(in_directory, out_directory, short_name):\n    \"\"\"\n    Reads in train, validation, and test data and converts them to .json file\n    \"\"\"\n    filenames = (\"IOB2_train.txt\", \"IOB2_valid.txt\", \"IOB2_test.txt\")\n    for shard, filename in zip(SHARDS, filenames):\n        input_filename = os.path.join(in_directory, filename)\n        output_filename = os.path.join(out_directory, \"%s.%s.bio\" % (short_name, shard))\n        shutil.copy(input_filename, output_filename)\n    convert_bio_to_json(out_directory, out_directory, short_name, \"bio\")\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_path', type=str, default=\"/nlp/scr/aaydin/kazNERD/NER\", help=\"Where to find the files\")\n    parser.add_argument('--output_path', type=str, default=\"/nlp/scr/aaydin/kazNERD/data/ner\", help=\"Where to output the results\")\n    args = parser.parse_args()\n    # in_path = '/nlp/scr/aaydin/kazNERD/NER'\n    # out_path = '/nlp/scr/aaydin/kazNERD/NER/output'\n    # convert_dataset(in_path, out_path)\n    convert_dataset(args.input_path, args.output_path, \"kk_kazNERD\")\n\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_lst20.py",
    "content": "\"\"\"\nConverts the Thai LST20 dataset to a format usable by Stanza's NER model\n\nThe dataset in the original format has a few tag errors which we\nautomatically fix (or at worst cover up)\n\"\"\"\n\nimport os\n\nfrom stanza.utils.datasets.ner.utils import convert_bio_to_json\n\ndef convert_lst20(paths, short_name, include_space_char=True):\n    assert short_name == \"th_lst20\"\n    SHARDS = (\"train\", \"eval\", \"test\")\n    BASE_OUTPUT_PATH = paths[\"NER_DATA_DIR\"]\n\n    input_split = [(os.path.join(paths[\"NERBASE\"], \"thai\", \"LST20_Corpus\", x), x) for x in SHARDS]\n\n    if not include_space_char:\n        short_name = short_name + \"_no_ws\"\n\n    for input_folder, split_type in input_split:\n        text_list = [text for text in os.listdir(input_folder) if text[0] == 'T']\n\n        if split_type == \"eval\":\n            split_type = \"dev\"\n\n        output_path = os.path.join(BASE_OUTPUT_PATH, \"%s.%s.bio\" % (short_name, split_type))\n        print(output_path)\n\n        with open(output_path, 'w', encoding='utf-8') as fout:\n            for text in text_list:\n                lst = []\n                with open(os.path.join(input_folder, text), 'r', encoding='utf-8') as fin:\n                    lines = fin.readlines()\n\n                for line_idx, line in enumerate(lines):\n                    x = line.strip().split('\\t')\n                    if len(x) > 1:\n                        if x[0] == '_' and not include_space_char:\n                            continue\n                        else:\n                            word, tag = x[0], x[2]\n\n                            if tag == \"MEA_BI\":\n                                tag = \"B_MEA\"\n                            if tag == \"OBRN_B\":\n                                tag = \"B_BRN\"\n                            if tag == \"ORG_I\":\n                                tag = \"I_ORG\"\n                            if tag == \"PER_I\":\n                                
tag = \"I_PER\"\n                            if tag == \"LOC_I\":\n                                tag = \"I_LOC\"\n                            if tag == \"B\" and line_idx + 1 < len(lines):\n                                x_next = lines[line_idx+1].strip().split('\\t')\n                                if len(x_next) > 1:\n                                    tag_next = x_next[2]\n                                    if \"I_\" in tag_next or \"E_\" in tag_next:\n                                        tag = tag + tag_next[1:]\n                                    else:\n                                        tag = \"O\"\n                                else:\n                                    tag = \"O\"\n                            if \"_\" in tag:\n                                tag = tag.replace(\"_\", \"-\")\n                            if \"ABB\" in tag or tag == \"DDEM\" or tag == \"I\" or tag == \"__\":\n                                tag = \"O\"\n\n                            fout.write('{}\\t{}'.format(word, tag))\n                            fout.write('\\n')\n                    else:\n                        fout.write('\\n')\n    convert_bio_to_json(BASE_OUTPUT_PATH, BASE_OUTPUT_PATH, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_mr_l3cube.py",
    "content": "\"\"\"\nReads one piece of the MR L3Cube dataset\n\nThe dataset is structured as a long list of words already in IOB format\nThe sentences have an ID which changes when a new sentence starts\nThe tags are labeled BNEM instead of B-NEM, so we update that.\n(Could theoretically remap the tags to names more typical of other datasets as well)\n\"\"\"\n\ndef convert(input_file):\n    \"\"\"\n    Converts one file of the dataset\n\n    Return: a list of list of pairs, (text, tag)\n    \"\"\"\n    with open(input_file, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n\n    sentences = []\n    current_sentence = []\n    prev_sent_id = None\n    for idx, line in enumerate(lines):\n        # first line of each of the segments is the header\n        if idx == 0:\n            continue\n\n        line = line.strip()\n        if not line:\n            continue\n        pieces = line.split(\"\\t\")\n        if len(pieces) != 3:\n            raise ValueError(\"Unexpected number of pieces at line %d of %s\" % (idx, input_file))\n\n        text, ner, sent_id = pieces\n        if ner != 'O':\n            # ner symbols are written as BNEM, BNED, etc in this dataset\n            ner = ner[0] + \"-\" + ner[1:]\n\n        if not prev_sent_id:\n            prev_sent_id = sent_id\n        if sent_id != prev_sent_id:\n            prev_sent_id = sent_id\n            if len(current_sentence) == 0:\n                raise ValueError(\"This should not happen!\")\n            sentences.append(current_sentence)\n            current_sentence = []\n\n        current_sentence.append((text, ner))\n\n    if current_sentence:\n        sentences.append(current_sentence)\n\n    print(\"Read %d sentences in %d lines from %s\" % (len(sentences), len(lines), input_file))\n    return sentences\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_my_ucsy.py",
    "content": "\"\"\"\nProcesses the three pieces of the NER dataset we received from UCSY.\n\nRequires the Myanmar tokenizer to exist, since the text is not already tokenized.\n\nThere are three files sent to us from UCSY, one each for train, dev, test\nThis script expects them to be in the ner directory with the names\n    $NERBASE/my_ucsy/Myanmar_NER_train.txt\n    $NERBASE/my_ucsy/Myanmar_NER_dev.txt\n    $NERBASE/my_ucsy/Myanmar_NER_test.txt\n\nThe files are in the following format:\n  unsegmentedtext@LABEL|unsegmentedtext@LABEL|...\nwith one sentence per line\n\nSolution:\n  - break the text up into fragments by splitting on |\n  - extract the labels\n  - segment each block of text using the MY tokenizer\n\nWe could take two approaches to breaking up the blocks.  One would be\nto combine all chunks, then segment an entire sentence at once.  This\nwould require some logic to re-chunk the resulting pieces.  Instead,\nwe resegment each individual chunk by itself.  This loses the\ninformation from the neighboring chunks, but guarantees there are no\nscrewups where segmentation crosses segment boundaries and is simpler\nto code.\n\nOf course, experimenting with the alternate approach might be better.\n\nThere is one stray label of SB in the training data, so we throw out\nthat entire sentence.\n\"\"\"\n\n\nimport os\n\nfrom tqdm import tqdm\nimport stanza\nfrom stanza.utils.datasets.ner.check_for_duplicates import check_for_duplicates\n\nSPLITS = (\"train\", \"dev\", \"test\")\n\ndef convert_file(input_filename, output_filename, pipe):\n    with open(input_filename) as fin:\n        lines = fin.readlines()\n\n    all_labels = set()\n\n    with open(output_filename, \"w\") as fout:\n        for line in tqdm(lines):\n            pieces = line.split(\"|\")\n            texts = []\n            labels = []\n            skip_sentence = False\n            for piece in pieces:\n                piece = piece.strip()\n                if not piece:\n                    
continue\n                text, label = piece.rsplit(\"@\", maxsplit=1)\n                text = text.strip()\n                if not text:\n                    continue\n                if label == 'SB':\n                    skip_sentence = True\n                    break\n\n                texts.append(text)\n                labels.append(label)\n\n            if skip_sentence:\n                continue\n\n            text = \"\\n\\n\".join(texts)\n            doc = pipe(text)\n            assert len(doc.sentences) == len(texts)\n            for sentence, label in zip(doc.sentences, labels):\n                all_labels.add(label)\n                for word_idx, word in enumerate(sentence.words):\n                    if label == \"O\":\n                        output_label = \"O\"\n                    elif word_idx == 0:\n                        output_label = \"B-\" + label\n                    else:\n                        output_label = \"I-\" + label\n\n                    fout.write(\"%s\\t%s\\n\" % (word.text, output_label))\n            fout.write(\"\\n\\n\")\n\n    print(\"Finished processing {}  Labels found: {}\".format(input_filename, sorted(all_labels)))\n\ndef convert_my_ucsy(base_input_path, base_output_path):\n    os.makedirs(base_output_path, exist_ok=True)\n    pipe = stanza.Pipeline(\"my\", processors=\"tokenize\", tokenize_no_ssplit=True)\n    output_filenames = [os.path.join(base_output_path, \"my_ucsy.%s.bio\" % split) for split in SPLITS]\n\n    for split, output_filename in zip(SPLITS, output_filenames):\n        input_filename = os.path.join(base_input_path, \"Myanmar_NER_%s.txt\" % split)\n        if not os.path.exists(input_filename):\n            raise FileNotFoundError(\"Necessary file for my_ucsy does not exist: %s\" % input_filename)\n\n        convert_file(input_filename, output_filename, pipe)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_nkjp.py",
    "content": "import argparse\nimport json\nimport os\nimport random\nimport tarfile\nimport tempfile\nfrom tqdm import tqdm\n# could import lxml here, but that would involve adding lxml as a\n# dependency to the stanza package\n# another alternative would be to try & catch ImportError\ntry:\n    from lxml import etree\nexcept ImportError:\n    import xml.etree.ElementTree as etree\n\n\nNAMESPACE = \"http://www.tei-c.org/ns/1.0\"\nMORPH_FILE = \"ann_morphosyntax.xml\"\nNER_FILE = \"ann_named.xml\"\nSEGMENTATION_FILE = \"ann_segmentation.xml\"\n\ndef parse_xml(path):\n    if not os.path.exists(path):\n        return None\n    et = etree.parse(path)\n    rt = et.getroot()\n    return rt\n\n\ndef get_node_id(node):\n    # get the id from the xml node\n    return node.get('{http://www.w3.org/XML/1998/namespace}id')\n\n\ndef extract_entities_from_subfolder(subfolder, nkjp_dir):\n    # read the ner annotation from a subfolder, assign it to paragraphs\n    subfolder_entities = extract_unassigned_subfolder_entities(subfolder, nkjp_dir)\n    par_id_to_segs = assign_entities(subfolder, subfolder_entities, nkjp_dir)\n    return par_id_to_segs\n\n\ndef extract_unassigned_subfolder_entities(subfolder, nkjp_dir):\n    \"\"\"\n    Build and return a map from par_id to extracted entities\n    \"\"\"\n    ner_path = os.path.join(nkjp_dir, subfolder, NER_FILE)\n    rt = parse_xml(ner_path)\n    if rt is None:\n        return None\n    subfolder_entities = {}\n    ner_pars = rt.findall(\"{%s}TEI/{%s}text/{%s}body/{%s}p\" % (NAMESPACE, NAMESPACE, NAMESPACE, NAMESPACE))\n    for par in ner_pars:\n        par_entities = {}\n        _, par_id = get_node_id(par).split(\"_\")\n        ner_sents = par.findall(\"{%s}s\" % NAMESPACE)\n        for ner_sent in ner_sents:\n            corresp = ner_sent.get(\"corresp\")\n            _, ner_sent_id  = corresp.split(\"#morph_\")\n            par_entities[ner_sent_id] = extract_entities_from_sentence(ner_sent)\n        subfolder_entities[par_id] 
= par_entities\n    return subfolder_entities\n\ndef extract_entities_from_sentence(ner_sent):\n    # extracts all the entity dicts from the sentence\n    # we assume that an entity cannot span across sentences\n    segs = ner_sent.findall(\"./{%s}seg\" % NAMESPACE)\n    sent_entities = {}\n    for i, seg in enumerate(segs):\n        ent_id = get_node_id(seg)\n        targets = [ptr.get(\"target\") for ptr in seg.findall(\"./{%s}ptr\" % NAMESPACE)]\n        orth = seg.findall(\"./{%s}fs/{%s}f[@name='orth']/{%s}string\" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].text\n        ner_type = seg.findall(\"./{%s}fs/{%s}f[@name='type']/{%s}symbol\" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].get(\"value\")\n        ner_subtype_node = seg.findall(\"./{%s}fs/{%s}f[@name='subtype']/{%s}symbol\" % (NAMESPACE, NAMESPACE, NAMESPACE))\n        if ner_subtype_node:\n            ner_subtype = ner_subtype_node[0].get(\"value\")\n        else:\n            ner_subtype = None\n        entity = {\"ent_id\": ent_id,\n                  \"index\": i,\n                  \"orth\": orth,\n                  \"ner_type\": ner_type,\n                  \"ner_subtype\": ner_subtype,\n                  \"targets\": targets}\n        sent_entities[ent_id] = entity\n    cleared_entities = clear_entities(sent_entities)\n    return cleared_entities\n\n\ndef clear_entities(entities):\n    # eliminates entities which extend beyond our scope\n    resolve_entities(entities)\n    entities_list = sorted(list(entities.values()), key=lambda ent: ent[\"index\"])\n    entities = eliminate_overlapping_entities(entities_list)\n    for entity in entities:\n        targets = entity[\"targets\"]\n        entity[\"targets\"] = [t.split(\"morph_\")[1] for t in targets]\n    return entities\n\n\ndef resolve_entities(entities):\n    # assign morphological level targets to entities\n    resolved_targets = {entity_id: resolve_entity(entity, entities) for entity_id, entity in entities.items()}\n    for entity_id in entities:\n    
    entities[entity_id][\"targets\"] = resolved_targets[entity_id]\n\n\ndef resolve_entity(entity, entities):\n    # translate targets defined in terms of entities, into morphological units\n    # works recursively\n    targets = entity[\"targets\"]\n    resolved = []\n    for target in targets:\n        if target.startswith(\"named_\"):\n            target_entity = entities[target]\n            resolved.extend(resolve_entity(target_entity, entities))\n        else:\n            resolved.append(target)\n    return resolved\n\n\ndef eliminate_overlapping_entities(entities_list):\n    # we eliminate entities which are at least partially contained in one occurring prior to them\n    # this amounts to removing overlap\n    subsumed = set([])\n    for sub_i, sub in enumerate(entities_list):\n        for over in entities_list[:sub_i]:\n            if any([target in over[\"targets\"] for target in sub[\"targets\"]]):\n                subsumed.add(sub[\"ent_id\"])\n    return [entity for entity in entities_list if entity[\"ent_id\"] not in subsumed]\n\n\ndef assign_entities(subfolder, subfolder_entities, nkjp_dir):\n    # recovers all the segments from a subfolder, and annotates them with NER\n    morph_path = os.path.join(nkjp_dir, subfolder, MORPH_FILE)\n    rt = parse_xml(morph_path)\n    morph_pars = rt.findall(\"{%s}TEI/{%s}text/{%s}body/{%s}p\" % (NAMESPACE, NAMESPACE, NAMESPACE, NAMESPACE))\n    par_id_to_segs = {}\n    for par in morph_pars:\n        _, par_id = get_node_id(par).split(\"_\")\n        morph_sents = par.findall(\"{%s}s\" % NAMESPACE)\n        sent_id_to_segs = {}\n        for morph_sent in morph_sents:\n            _, sent_id = get_node_id(morph_sent).split(\"_\")\n            segs = morph_sent.findall(\"{%s}seg\" % NAMESPACE)\n            sent_segs = {}\n            for i, seg in enumerate(segs):\n                _, seg_id = get_node_id(seg).split(\"morph_\")\n                orth = seg.findall(\"{%s}fs/{%s}f[@name='orth']/{%s}string\" % (NAMESPACE, 
NAMESPACE, NAMESPACE))[0].text\n                token = {\"seg_id\": seg_id,\n                          \"i\": i,\n                          \"orth\": orth,\n                          \"text\": orth,\n                          \"tag\": \"_\",\n                          \"ner\": \"O\", # This will be overwritten\n                          \"ner_subtype\": None,\n                          }\n                sent_segs[seg_id] = token\n            sent_id_to_segs[sent_id] = sent_segs\n        par_id_to_segs[par_id] = sent_id_to_segs\n\n    if subfolder_entities is None:\n        return None\n\n    for par_key in subfolder_entities:\n        par_ents = subfolder_entities[par_key]\n        for sent_key in par_ents:\n            sent_entities = par_ents[sent_key]\n            for entity in sent_entities:\n                targets = entity[\"targets\"]\n                iob = \"B\"\n                ner_label = entity[\"ner_type\"]\n                matching_tokens = sorted([par_id_to_segs[par_key][sent_key][target] for target in targets], key=lambda x:x[\"i\"])\n                for token in matching_tokens:\n                    full_label = f\"{iob}-{ner_label}\"\n                    token[\"ner\"] = full_label\n                    token[\"ner_subtype\"] = entity[\"ner_subtype\"]\n                    iob = \"I\"\n    return par_id_to_segs\n\n\ndef load_xml_nkjp(nkjp_dir):\n    subfolder_to_annotations = {}\n    subfolders = sorted(os.listdir(nkjp_dir))\n    for subfolder in tqdm([name for name in subfolders if os.path.isdir(os.path.join(nkjp_dir, name))]):\n        out = extract_entities_from_subfolder(subfolder, nkjp_dir)\n        if out:\n            subfolder_to_annotations[subfolder] = out\n        else:\n            print(subfolder, \"has no ann_named.xml file\")\n\n    return subfolder_to_annotations\n\n\ndef split_dataset(dataset, shuffle=True, train_fraction=0.9, dev_fraction=0.05, test_section=True):\n    random.seed(987654321)\n    if shuffle:\n        
random.shuffle(dataset)\n\n    if not test_section:\n        dev_fraction = 1 - train_fraction\n\n    train_size = int(train_fraction * len(dataset))\n    dev_size = int(dev_fraction * len(dataset))\n    train = dataset[:train_size]\n    dev = dataset[train_size: train_size + dev_size]\n    test = dataset[train_size + dev_size:]\n\n    return {\n        'train': train,\n        'dev': dev,\n        'test': test\n    }\n\n\ndef convert_nkjp(nkjp_path, output_dir):\n    \"\"\"Converts NKJP NER data into IOB json format.\n\n    nkjp_path is the path to the directory, or to the .tar.gz / .tgz archive, where the NKJP files are located.\n    output_dir is the directory where the pl_nkjp json files are written.\n    \"\"\"\n    # Load XML NKJP\n    print(\"Reading data from %s\" % nkjp_path)\n    if os.path.isfile(nkjp_path) and (nkjp_path.endswith(\".tar.gz\") or nkjp_path.endswith(\".tgz\")):\n        with tempfile.TemporaryDirectory() as nkjp_dir:\n            print(\"Temporarily extracting %s to %s\" % (nkjp_path, nkjp_dir))\n            with tarfile.open(nkjp_path, \"r:gz\") as tar_in:\n                tar_in.extractall(nkjp_dir)\n\n            subfolder_to_entities = load_xml_nkjp(nkjp_dir)\n    elif os.path.isdir(nkjp_path):\n        subfolder_to_entities = load_xml_nkjp(nkjp_path)\n    else:\n        raise FileNotFoundError(\"Cannot find either unpacked dataset or gzipped file\")\n    converted = []\n    for subfolder_name, pars in subfolder_to_entities.items():\n        for par_id, par in pars.items():\n            paragraph_identifier = f\"{subfolder_name}|{par_id}\"\n            par_tokens = []\n            for _, sent in par.items():\n                tokens = sent.values()\n                srt = sorted(tokens, key=lambda tok:tok[\"i\"])\n                for token in srt:\n                    _ = token.pop(\"i\")\n                    _ = token.pop(\"seg_id\")\n                    par_tokens.append(token)\n            par_tokens[0][\"paragraph_id\"] = paragraph_identifier\n            converted.append(par_tokens)\n\n    split = split_dataset(converted)\n\n    for split_name, 
split in split.items():\n        if split:\n            with open(os.path.join(output_dir, f\"pl_nkjp.{split_name}.json\"), \"w\", encoding=\"utf-8\") as f:\n                json.dump(split, f, ensure_ascii=False, indent=2)\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_path', type=str, default=\"/u/nlp/data/ner/stanza/polish/NKJP-PodkorpusMilionowy-1.2.tar.gz\", help=\"Where to find the files\")\n    parser.add_argument('--output_path', type=str, default=\"data/ner\", help=\"Where to output the results\")\n    args = parser.parse_args()\n\n    convert_nkjp(args.input_path, args.output_path)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_nner22.py",
    "content": "\"\"\"\nConverts the Thai NNER22 dataset to a format usable by Stanza's NER model\n\nThe dataset is already written in json format, so we will convert into a compatible json format.\n\nThe dataset in the original format has nested NER format which we will only extract the first layer\nof NER tag and write it in the format accepted by current Stanza model\n\"\"\"\n\nimport os\nimport logging\nimport json\n\ndef convert_nner22(paths, short_name, include_space_char=True):\n    assert short_name == \"th_nner22\"\n    SHARDS = (\"train\", \"dev\", \"test\")\n    BASE_INPUT_PATH = os.path.join(paths[\"NERBASE\"], \"thai\", \"Thai-NNER\", \"data\", \"scb-nner-th-2022\", \"postproc\")\n\n    if not include_space_char:\n        short_name = short_name + \"_no_ws\"\n\n    for shard in SHARDS:\n        input_path = os.path.join(BASE_INPUT_PATH, \"%s.json\" % (shard))\n        output_path = os.path.join(paths[\"NER_DATA_DIR\"], \"%s.%s.json\" % (short_name, shard))\n\n        logging.info(\"Output path for %s split at %s\" % (shard, output_path))\n\n        data = json.load(open(input_path))\n\n        documents = []\n\n        for i in range(len(data)):\n            token, entities = data[i][\"tokens\"], data[i][\"entities\"]\n\n            token_length, sofar = len(token), 0\n            document, ner_dict = [], {}\n\n            for entity in entities:\n                start, stop = entity[\"span\"]\n\n                if stop > sofar:\n                    ner = entity[\"entity_type\"].upper()\n                    sofar = stop\n\n                    for j in range(start, stop):\n                        if j == start:\n                            ner_tag = \"B-\" + ner\n                        elif j == stop - 1:\n                            ner_tag = \"E-\" + ner\n                        else:\n                            ner_tag = \"I-\" + ner\n\n                        ner_dict[j] = (ner_tag, token[j])\n\n            for k in range(token_length):\n          
      dict_add = {}\n\n                if k not in ner_dict:\n                    dict_add[\"ner\"], dict_add[\"text\"] = \"O\", token[k]\n                else:\n                    dict_add[\"ner\"], dict_add[\"text\"] = ner_dict[k]\n\n                document.append(dict_add)\n\n            documents.append(document)\n\n        with open(output_path, \"w\") as outfile:\n            json.dump(documents, outfile, indent=1)\n\n        logging.info(\"%s.%s.json file successfully created\" % (short_name, shard))\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_nytk.py",
    "content": "\nimport glob\nimport os\n\ndef convert_nytk(base_input_path, base_output_path, short_name):\n    for shard in ('train', 'dev', 'test'):\n        if shard == 'dev':\n            base_input_subdir = os.path.join(base_input_path, \"data/train-devel-test/devel\")\n        else:\n            base_input_subdir = os.path.join(base_input_path, \"data/train-devel-test\", shard)\n\n        shard_lines = []\n        base_input_glob = base_input_subdir + \"/*/no-morph/*\"\n        subpaths = glob.glob(base_input_glob)\n        print(\"Reading %d input files from %s\" % (len(subpaths), base_input_glob))\n        for input_filename in subpaths:\n            if len(shard_lines) > 0:\n                shard_lines.append(\"\")\n            with open(input_filename) as fin:\n                lines = fin.readlines()\n                if lines[0].strip() != '# global.columns = FORM LEMMA UPOS XPOS FEATS CONLL:NER':\n                    raise ValueError(\"Unexpected format in %s\" % input_filename)\n                lines = [x.strip().split(\"\\t\") for x in lines[1:]]\n                lines = [\"%s\\t%s\" % (x[0], x[5]) if len(x) > 1 else \"\" for x in lines]\n                shard_lines.extend(lines)\n\n        bio_filename = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard))\n        with open(bio_filename, \"w\") as fout:\n            print(\"Writing %d lines to %s\" % (len(shard_lines), bio_filename))\n            for line in shard_lines:\n                fout.write(line)\n                fout.write(\"\\n\")\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_ontonotes.py",
    "content": "\"\"\"\nDownloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json\n\nSome online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF:\nhttps://huggingface.co/datasets/conll2003\n\"\"\"\n\nimport os\n\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.datasets.ner.utils import write_dataset\n\nID_TO_TAG = [\"O\", \"B-PERSON\", \"I-PERSON\", \"B-NORP\", \"I-NORP\", \"B-FAC\", \"I-FAC\", \"B-ORG\", \"I-ORG\", \"B-GPE\", \"I-GPE\", \"B-LOC\", \"I-LOC\", \"B-PRODUCT\", \"I-PRODUCT\", \"B-DATE\", \"I-DATE\", \"B-TIME\", \"I-TIME\", \"B-PERCENT\", \"I-PERCENT\", \"B-MONEY\", \"I-MONEY\", \"B-QUANTITY\", \"I-QUANTITY\", \"B-ORDINAL\", \"I-ORDINAL\", \"B-CARDINAL\", \"I-CARDINAL\", \"B-EVENT\", \"I-EVENT\", \"B-WORK_OF_ART\", \"I-WORK_OF_ART\", \"B-LAW\", \"I-LAW\", \"B-LANGUAGE\", \"I-LANGUAGE\",]\n\ndef convert_dataset_section(config_name, section):\n    sentences = []\n    for doc in section:\n        # the nt_ sentences (New Testament) in the HF version of OntoNotes\n        # have blank named_entities, even though there was no original .name file\n        # that corresponded with these annotations\n        if config_name.startswith(\"english\") and doc['document_id'].startswith(\"pt/nt\"):\n            continue\n        for sentence in doc['sentences']:\n            words = sentence['words']\n            tags = [ID_TO_TAG[x] for x in sentence['named_entities']]\n            sentences.append(list(zip(words, tags)))\n    return sentences\n\ndef process_dataset(short_name, conll_path, ner_output_path):\n    try:\n        from datasets import load_dataset\n    except ImportError as e:\n        raise ImportError(\"Please install the datasets package to process CoNLL03 with Stanza\")\n\n    if short_name == 'en_ontonotes':\n        # there is an english_v12, but it is filled with junk annotations\n        # for example, near the end:\n        #   And John_O, I realize\n        
config_name = 'english_v4'\n    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):\n        config_name = 'chinese_v4'\n    elif short_name == 'ar_ontonotes':\n        config_name = 'arabic_v4'\n    else:\n        raise ValueError(\"Unknown short name for downloading ontonotes: %s\" % short_name)\n    dataset = load_dataset(\"conll2012_ontonotesv5\", config_name, cache_dir=conll_path)\n    datasets = [convert_dataset_section(config_name, x) for x in [dataset['train'], dataset['validation'], dataset['test']]]\n    write_dataset(datasets, ner_output_path, short_name)\n\ndef main():\n    paths = get_default_paths()\n    ner_input_path = paths['NERBASE']\n    conll_path = os.path.join(ner_input_path, \"english\", \"en_ontonotes\")\n    ner_output_path = paths['NER_DATA_DIR']\n    process_dataset(\"en_ontonotes\", conll_path, ner_output_path)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_rgai.py",
    "content": "\"\"\"\nThis script converts the Hungarian files available at u-szeged\n  https://rgai.inf.u-szeged.hu/node/130\n\"\"\"\n\nimport os\nimport tempfile\n\n# we reuse this to split the data randomly\nfrom stanza.utils.datasets.ner.split_wikiner import split_wikiner\n\ndef read_rgai_file(filename, separator):\n    with open(filename, encoding=\"latin-1\") as fin:\n        lines = fin.readlines()\n        lines = [x.strip() for x in lines]\n\n        for idx, line in enumerate(lines):\n            if not line:\n                continue\n            pieces = lines[idx].split(separator)\n            if len(pieces) != 2:\n                raise ValueError(\"Line %d is in an unexpected format!  Expected exactly two pieces when split on %s\" % (idx, separator))\n            # some of the data has '0' (the digit) instead of 'O' (the letter)\n            if pieces[-1] == '0':\n                pieces[-1] = \"O\"\n                lines[idx] = \"\\t\".join(pieces)\n    print(\"Read %d lines from %s\" % (len(lines), filename))\n    return lines\n\ndef get_rgai_data(base_input_path, use_business, use_criminal):\n    assert use_business or use_criminal, \"Must specify one or more sections of the dataset to use\"\n\n    dataset_lines = []\n    if use_business:\n        business_file = os.path.join(base_input_path, \"hun_ner_corpus.txt\")\n\n        lines = read_rgai_file(business_file, \"\\t\")\n        dataset_lines.extend(lines)\n\n    if use_criminal:\n        # There are two different annotation schemes, Context and\n        # NoContext.  
NoContext seems to fit better with the\n        # business_file's annotation scheme, since the scores are much\n        # higher when NoContext and hun_ner are combined\n        criminal_file = os.path.join(base_input_path, \"HVGJavNENoContext\")\n\n        lines = read_rgai_file(criminal_file, \" \")\n        dataset_lines.extend(lines)\n\n    return dataset_lines\n\ndef convert_rgai(base_input_path, base_output_path, short_name, use_business, use_criminal):\n    all_data_file = tempfile.NamedTemporaryFile(delete=False)\n    try:\n        raw_data = get_rgai_data(base_input_path, use_business, use_criminal)\n        for line in raw_data:\n            all_data_file.write(line.encode())\n            all_data_file.write(\"\\n\".encode())\n        all_data_file.close()\n        split_wikiner(base_output_path, all_data_file.name, prefix=short_name)\n    finally:\n        os.unlink(all_data_file.name)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_sindhi_siner.py",
    "content": "\"\"\"\nConverts the raw data from SiNER to .json for the Stanza NER system\n\nhttps://aclanthology.org/2020.lrec-1.361.pdf\n\"\"\"\n\nfrom stanza.utils.datasets.ner.utils import write_dataset\n\ndef fix_sentence(sentence):\n    \"\"\"\n    Fix some of the mistags in the dataset\n\n    This covers 11 sentences: 1 P-PERSON, 2 with line breaks in the middle of the tag, and 8 with no B- or I-\n    \"\"\"\n    new_sentence = []\n    for word_idx, word in enumerate(sentence):\n        if word[1] == 'P-PERSON':\n            new_sentence.append((word[0], 'B-PERSON'))\n        elif word[1] == 'B-OT\"':\n            new_sentence.append((word[0], 'B-OTHERS'))\n        elif word[1] == 'B-T\"':\n            new_sentence.append((word[0], 'B-TITLE'))\n        elif word[1] in ('GPE', 'LOC', 'OTHERS'):\n            if len(new_sentence) > 0 and new_sentence[-1][1][:2] in ('B-', 'I-') and new_sentence[-1][1][2:] == word[1]:\n                # one example... no idea if it should be a break or\n                # not, but the last word translates to \"Corporation\",\n                # so probably not: ميٽرو پوليٽن ڪارپوريشن\n                new_sentence.append((word[0], 'I-' + word[1]))\n            else:\n                new_sentence.append((word[0], 'B-' + word[1]))\n        else:\n            new_sentence.append(word)\n    return new_sentence\n\ndef convert_sindhi_siner(in_filename, out_directory, short_name, train_frac=0.8, dev_frac=0.1):\n    \"\"\"\n    Read lines from the dataset, crudely separate sentences based on . 
or !, and write the dataset\n    \"\"\"\n    with open(in_filename, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n\n    lines = [x.strip().split(\"\\t\") for x in lines]\n    lines = [(x[0].strip(), x[1].strip()) for x in lines if len(x) == 2]\n    print(\"Read %d words from %s\" % (len(lines), in_filename))\n    sentences = []\n    prev_idx = 0\n    for sent_idx, line in enumerate(lines):\n        # maybe also handle line[0] == '،', \"Arabic comma\"?\n        if line[0] in ('.', '!'):\n            sentences.append(lines[prev_idx:sent_idx+1])\n            prev_idx=sent_idx+1\n\n    # in case the file doesn't end with punctuation, grab the last few lines\n    if prev_idx < len(lines):\n        sentences.append(lines[prev_idx:])\n\n    print(\"Found %d sentences before splitting\" % len(sentences))\n    sentences = [fix_sentence(x) for x in sentences]\n    assert not any('\"' in x[1] or x[1].startswith(\"P-\") or x[1] in (\"GPE\", \"LOC\", \"OTHERS\") for sentence in sentences for x in sentence)\n\n    train_len = int(len(sentences) * train_frac)\n    dev_len = int(len(sentences) * (train_frac+dev_frac))\n    train_sentences = sentences[:train_len]\n    dev_sentences = sentences[train_len:dev_len]\n    test_sentences = sentences[dev_len:]\n\n    datasets = (train_sentences, dev_sentences, test_sentences)\n    write_dataset(datasets, out_directory, short_name, suffix=\"bio\")\n\n"
  },
  {
    "path": "stanza/utils/datasets/ner/convert_starlang_ner.py",
    "content": "\"\"\"\nConvert the starlang trees to a NER dataset\n\nHas to hide quite a few trees with missing NER labels\n\"\"\"\n\nimport re\n\nfrom stanza.models.constituency import tree_reader\nimport stanza.utils.datasets.constituency.convert_starlang as convert_starlang\n\nTURKISH_WORD_RE = re.compile(r\"[{]turkish=([^}]+)[}]\")\nTURKISH_LABEL_RE = re.compile(r\"[{]namedEntity=([^}]+)[}]\")\n\n\n\ndef read_tree(text):\n    \"\"\"\n    Reads in a tree, then extracts the word and the NER\n\n    One problem is that it is unknown if there are cases of two separate items occurring consecutively\n\n    Note that this is quite similar to the convert_starlang script for constituency.  \n    \"\"\"\n    trees = tree_reader.read_trees(text)\n    if len(trees) > 1:\n        raise ValueError(\"Tree file had two trees!\")\n    tree = trees[0]\n    words = []\n    for label in tree.leaf_labels():\n        match = TURKISH_WORD_RE.search(label)\n        if match is None:\n            raise ValueError(\"Could not find word in |{}|\".format(label))\n        word = match.group(1)\n        word = word.replace(\"-LCB-\", \"{\").replace(\"-RCB-\", \"}\")\n\n        match = TURKISH_LABEL_RE.search(label)\n        if match is None:\n            raise ValueError(\"Could not find ner in |{}|\".format(label))\n        tag = match.group(1)\n        if tag == 'NONE' or tag == \"null\":\n            tag = 'O'\n        words.append((word, tag))\n\n    return words\n\ndef read_starlang(paths):\n    return convert_starlang.read_starlang(paths, conversion=read_tree, log=False)\n\ndef main():\n    train, dev, test = convert_starlang.main(conversion=read_tree, log=False)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/ner/count_entities.py",
    "content": "\nimport argparse\nfrom collections import defaultdict\nimport json\n\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.datasets.ner.utils import list_doc_entities\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Report the coverage of one NER file on another.\")\n    parser.add_argument('filename', type=str, nargs='+', help='File(s) to count')\n    args = parser.parse_args()\n    return args\n\n\ndef count_entities(*filenames):\n    entity_collection = defaultdict(list)\n\n    for filename in filenames:\n        with open(filename) as fin:\n            doc = Document(json.load(fin))\n            num_tokens = sum(1 for sentence in doc.sentences for token in sentence.tokens)\n            print(\"Number of tokens in %s: %d\" % (filename, num_tokens))\n            entities = list_doc_entities(doc)\n\n        for ent in entities:\n            entity_collection[ent[1]].append(ent[0])\n\n    keys = sorted(entity_collection.keys())\n    for k in keys:\n        print(k, len(entity_collection[k]))\n\ndef main():\n    args = parse_args()\n\n    count_entities(*args.filename)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/json_to_bio.py",
    "content": "\"\"\"\nIf you want to convert .json back to .bio for some reason, this will do it for you\n\"\"\"\n\nimport argparse\nimport json\nimport os\nfrom stanza.models.common.doc import Document\nfrom stanza.models.ner.utils import process_tags\nfrom stanza.utils.default_paths import get_default_paths\n\ndef convert_json_to_bio(input_filename, output_filename):\n    with open(input_filename, encoding=\"utf-8\") as fin:\n        doc = Document(json.load(fin))\n    sentences = [[(word.text, word.ner) for word in sentence.tokens] for sentence in doc.sentences]\n    sentences = process_tags(sentences, \"bioes\")\n    with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n        for sentence in sentences:\n            for word in sentence:\n                fout.write(\"%s\\t%s\\n\" % word)\n            fout.write(\"\\n\")\n\ndef main(args=None):\n    ner_data_dir = get_default_paths()['NER_DATA_DIR']\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_filename', type=str, default=\"data/ner/en_foreign-4class.test.json\", help='Convert an individual file')\n    parser.add_argument('--input_dir', type=str, default=ner_data_dir, help='Which directory to find the dataset, if using --input_dataset')\n    parser.add_argument('--input_dataset', type=str, help='Convert an entire dataset')\n    parser.add_argument('--output_suffix', type=str, default='bioes', help='suffix for output filenames')\n    args = parser.parse_args(args)\n\n    if args.input_dataset:\n        input_filenames = [os.path.join(args.input_dir, \"%s.%s.json\" % (args.input_dataset, shard))\n                           for shard in (\"train\", \"dev\", \"test\")]\n    else:\n        input_filenames = [args.input_filename]\n    for input_filename in input_filenames:\n        output_filename = os.path.splitext(input_filename)[0] + \".\" + args.output_suffix\n        print(\"%s -> %s\" % (input_filename, output_filename))\n        convert_json_to_bio(input_filename, 
output_filename)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/misc_to_date.py",
    "content": "# for the Worldwide dataset, automatically switch the Misc tags to Date when Stanza Ontonotes thinks it's a Date\n# this keeps our annotation scheme for dates (eg, not \"3 months ago\") while hopefully switching them all to Date\n#\n# maybe some got missed\n# also, there are a few with some nested entities.  printed out warnings and edited those by hand\n#\n# just need to run this with the Worldwide dataset in the ner path\n# it will automatically convert as many as it can\n\nimport os\n\nfrom tqdm import tqdm\n\nimport stanza\nfrom stanza.utils.datasets.ner.utils import read_tsv\nfrom stanza.utils.default_paths import get_default_paths\n\npaths = get_default_paths()\nBASE_PATH = os.path.join(paths[\"NERBASE\"], \"en_foreign\")\ninput_dir = os.path.join(BASE_PATH, \"en-foreign-newswire\")\n\npipe = stanza.Pipeline(\"en\", processors=\"tokenize,ner\", tokenize_pretokenized=True, package={\"ner\": \"ontonotes_bert\"})\n\nfilenames = []\n\ndef ner_tags(pipe, sentence):\n    doc = pipe([sentence])\n    tags = [token.ner for sentence in doc.sentences for token in sentence.tokens]\n    return tags\n\nfor root, dirs, files in os.walk(input_dir):\n    if root[-6:] == \"REVIEW\":\n        batch_files = os.listdir(root)\n        for filename in batch_files:\n            file_path = os.path.join(root, filename)\n            filenames.append(file_path)\n\nfor filename in tqdm(filenames):\n    try:\n        data = read_tsv(filename, text_column=0, annotation_column=1, skip_comments=False, keep_all_columns=True)\n\n        with open(filename, 'w', encoding='utf-8') as fout:\n            warned_file = False\n            for sentence in data:  # segments delimited by spaces, effectively sentences\n                tokens = [x[0] for x in sentence]\n                labels = [x[1] for x in sentence]\n\n                if any(x.endswith(\"Misc\") for x in labels):\n                    stanza_tags = ner_tags(pipe, tokens)\n                    in_date = False\n           
         for i, stanza_tag in enumerate(stanza_tags):\n                        if stanza_tag[2:] == \"DATE\" and labels[i] != \"O\":\n                            if len(sentence[i]) > 2:\n                                if not warned_file:\n                                    print(\"Warning: file %s has nested tags being altered\" % filename)\n                                    warned_file = True\n                            # put DATE tags where Stanza thinks there are DATEs\n                            # as long as we already had a MISC (or something else, I suppose)\n                            if in_date and not stanza_tag[0].startswith(\"B\") and not stanza_tag[0].startswith(\"S\"):\n                                sentence[i][1] = \"I-Date\"\n                            else:\n                                sentence[i][1] = \"B-Date\"\n                            in_date = True\n                        elif in_date:\n                            # make sure new tags start with B- instead of I-\n                            # honestly it's not clear if, in these cases,\n                            # we should be switching the following tags to\n                            # DATE as well. will have to experiment some\n                            in_date = False\n                            if labels[i].startswith(\"I-\"):\n                                sentence[i][1] = \"B-\" + labels[i][2:]\n                for word in sentence:\n                    fout.write(\"\\t\".join(word))\n                    fout.write(\"\\n\")\n                fout.write(\"\\n\")\n    except AssertionError:\n        print(\"Could not process %s\" % filename)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/ontonotes_multitag.py",
    "content": "\"\"\"\nCombines OntoNotes and WW into a single dataset with OntoNotes used for dev & test\n\nThe resulting dataset has two layers saved in the multi_ner column.\n\nWW is kept as 9 classes, with the tag put in either the first or\nsecond layer depending on the flags.\n\nOntoNotes is converted to one column for 18 and one column for 9 classes.\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport shutil\n\nfrom stanza.utils import default_paths\nfrom stanza.utils.datasets.ner.utils import combine_files\nfrom stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide\n\ndef convert_ontonotes_file(filename, simplify, bigger_first):\n    assert \"en_ontonotes\" in filename\n    if not os.path.exists(filename):\n        raise FileNotFoundError(\"Cannot convert missing file %s\" % filename)\n    new_filename = filename.replace(\"en_ontonotes\", \"en_ontonotes-multi\")\n\n    with open(filename) as fin:\n        doc = json.load(fin)\n\n    for sentence in doc:\n        for word in sentence:\n            ner = word['ner']\n            if simplify:\n                simplified = simplify_ontonotes_to_worldwide(ner)\n            else:\n                simplified = \"-\"\n            if bigger_first:\n                word['multi_ner'] = (ner, simplified)\n            else:\n                word['multi_ner'] = (simplified, ner)\n\n    with open(new_filename, \"w\") as fout:\n        json.dump(doc, fout, indent=2)\n\ndef convert_worldwide_file(filename, bigger_first):\n    assert \"en_worldwide-9class\" in filename\n    if not os.path.exists(filename):\n        raise FileNotFoundError(\"Cannot convert missing file %s\" % filename)\n\n    new_filename = filename.replace(\"en_worldwide-9class\", \"en_worldwide-9class-multi\")\n\n    with open(filename) as fin:\n        doc = json.load(fin)\n\n    for sentence in doc:\n        for word in sentence:\n            ner = word['ner']\n            if bigger_first:\n               
 word['multi_ner'] = (\"-\", ner)\n            else:\n                word['multi_ner'] = (ner, \"-\")\n\n    with open(new_filename, \"w\") as fout:\n        json.dump(doc, fout, indent=2)\n\ndef build_multitag_dataset(base_output_path, short_name, simplify, bigger_first):\n    convert_ontonotes_file(os.path.join(base_output_path, \"en_ontonotes.train.json\"), simplify, bigger_first)\n    convert_ontonotes_file(os.path.join(base_output_path, \"en_ontonotes.dev.json\"), simplify, bigger_first)\n    convert_ontonotes_file(os.path.join(base_output_path, \"en_ontonotes.test.json\"), simplify, bigger_first)\n\n    convert_worldwide_file(os.path.join(base_output_path, \"en_worldwide-9class.train.json\"), bigger_first)\n    convert_worldwide_file(os.path.join(base_output_path, \"en_worldwide-9class.dev.json\"), bigger_first)\n    convert_worldwide_file(os.path.join(base_output_path, \"en_worldwide-9class.test.json\"), bigger_first)\n\n    combine_files(os.path.join(base_output_path, \"%s.train.json\" % short_name),\n                  os.path.join(base_output_path, \"en_ontonotes-multi.train.json\"),\n                  os.path.join(base_output_path, \"en_worldwide-9class-multi.train.json\"))\n    shutil.copyfile(os.path.join(base_output_path, \"en_ontonotes-multi.dev.json\"),\n                    os.path.join(base_output_path, \"%s.dev.json\" % short_name))\n    shutil.copyfile(os.path.join(base_output_path, \"en_ontonotes-multi.test.json\"),\n                    os.path.join(base_output_path, \"%s.test.json\" % short_name))\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help='By default, this script will simplify the OntoNotes 18 classes to the 8 WorldWide classes in a second column.  Turning that off will leave that column blank.  
Initial experiments with that setting were very bad, though')\n    parser.add_argument('--no_bigger_first', dest='bigger_first', action='store_false', help='By default, this script will put the 18 class tags in the first column and the 8 in the second.  This flips the order')\n    args = parser.parse_args()\n\n    paths = default_paths.get_default_paths()\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    build_multitag_dataset(base_output_path, \"en_ontonotes-ww-multi\", args.simplify, args.bigger_first)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/ner/prepare_ner_dataset.py",
    "content": "\"\"\"Converts raw data files into json files usable by the training script.\n\nCurrently it supports converting WikiNER datasets, available here:\n  https://figshare.com/articles/dataset/Learning_multilingual_named_entity_recognition_from_Wikipedia/5462500\n  - download the language of interest to {Language}-WikiNER\n  - then run\n    prepare_ner_dataset.py French-WikiNER\n\nA gold re-edit of WikiNER for French is here:\n  - https://huggingface.co/datasets/danrun/WikiNER-fr-gold/tree/main\n  - https://arxiv.org/abs/2411.00030\n    Danrun Cao, Nicolas Béchet, Pierre-François Marteau\n  - download to $NERBASE/wikiner-fr-gold/wikiner-fr-gold.conll\n    prepare_ner_dataset.py fr_wikinergold\n\nFrench WikiNER and its gold re-edit can be mixed together with\n    prepare_ner_dataset.py fr_wikinermixed\n  - the data for both WikiNER and WikiNER-fr-gold needs to be in the right place first\n\nAlso, Finnish Turku dataset, available here:\n  - https://turkunlp.org/fin-ner.html\n  - https://github.com/TurkuNLP/turku-ner-corpus\n    git clone the repo into $NERBASE/finnish\n    you will now have a directory\n    $NERBASE/finnish/turku-ner-corpus\n  - prepare_ner_dataset.py fi_turku\n\nFBK in Italy produced an Italian dataset.\n  - KIND: an Italian Multi-Domain Dataset for Named Entity Recognition\n    Paccosi T. 
and Palmero Aprosio A.\n    LREC 2022\n  - https://arxiv.org/abs/2112.15099\n  The processing here is for a combined .tsv file they sent us.\n  - prepare_ner_dataset.py it_fbk\n  There is a newer version of the data available here:\n    https://github.com/dhfbk/KIND\n  TODO: update to the newer version of the data\n\nIJCNLP 2008 produced a few Indian language NER datasets.\n  description:\n    http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=3\n  download:\n    http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=5\n  The models produced from these datasets have extremely low recall, unfortunately.\n  - prepare_ner_dataset.py hi_ijc\n\nFIRE 2013 also produced NER datasets for Indian languages.\n  http://au-kbc.org/nlp/NER-FIRE2013/index.html\n  The datasets are password locked.\n  For Stanford users, contact Chris Manning for license details.\n  For external users, please contact the organizers for more information.\n  - prepare_ner_dataset.py hi-fire2013\n\nHiNER is another Hindi dataset option\n  https://github.com/cfiltnlp/HiNER\n  - HiNER: A Large Hindi Named Entity Recognition Dataset\n    Murthy, Rudra and Bhattacharjee, Pallab and Sharnagat, Rahul and\n    Khatri, Jyotsana and Kanojia, Diptesh and Bhattacharyya, Pushpak\n  There are two versions:\n    hi_hinercollapsed and hi_hiner\n  The collapsed version has just PER, LOC, ORG\n  - convert data as follows:\n    cd $NERBASE\n    mkdir hindi\n    cd hindi\n    git clone git@github.com:cfiltnlp/HiNER.git\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hi_hiner\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hi_hinercollapsed\n\nIL-NER has four datasets: HI, OR, TE, UR\n  https://github.com/ltrc/IL-NER\n  - Fine-tuning Pre-trained Named Entity Recognition Models For Indian Languages\n    Bahad, Sankalp and Mishra, Pruthwik and\n    Krishnamurthy, Parameswari and Sharma, Dipti\n  Convert the data as follows:\n    cd $NERBASE\n    mkdir indic\n    cd indic\n    git clone 
git@github.com:ltrc/IL-NER.git\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset or_ilner\n\nsuralk/multiNER contains three languages, EN, SI, and TA\n  https://github.com/suralk/multiNER\n  https://arxiv.org/abs/2412.02056\n  - Ranathunga, Surangika, et al.\n    A Multi-way Parallel Named Entity Annotated Corpus for English, Tamil and Sinhala\n  The tags are in BIO format, with the same 4 tags as CoNLL\n  Convert the data as follows:\n    cd $NERBASE\n    mkdir mixed\n    cd mixed\n    git clone git@github.com:suralk/multiNER.git\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset ta_suralk\n\nUkrainian NER is provided by lang-uk, available here:\n  https://github.com/lang-uk/ner-uk\n  git clone the repo to $NERBASE/lang-uk\n  There should be a subdirectory $NERBASE/lang-uk/ner-uk/data at that point\n  Conversion script graciously provided by Andrii Garkavyi @gawy\n  - prepare_ner_dataset.py uk_languk\n\nTwo Hungarian datasets are available here:\n  https://rgai.inf.u-szeged.hu/node/130\n  http://www.lrec-conf.org/proceedings/lrec2006/pdf/365_pdf.pdf\n  We combine them and give them the label hu_rgai\n  You can also build individual pieces with hu_rgai_business or hu_rgai_criminal\n  Create a subdirectory of $NERBASE, $NERBASE/hu_rgai, and download both of\n    the pieces and unzip them in that directory.\n  - prepare_ner_dataset.py hu_rgai\n\nAnother Hungarian dataset is here:\n  - https://github.com/nytud/NYTK-NerKor\n  - git clone the entire thing in your $NERBASE directory to operate on it\n  - prepare_ner_dataset.py hu_nytk\n\nThe two Hungarian datasets can be combined with hu_combined\n  TODO: verify that there is no overlap in text\n  - prepare_ner_dataset.py hu_combined\n\nBSNLP publishes NER datasets for Eastern European languages.\n  - In 2019 they published BG, CS, PL, RU.\n  - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html\n  - In 2021 they added some more data, but the test sets\n    were not publicly available 
as of April 2021.\n    Therefore, currently the model is made from 2019.\n    In 2021, the link to the 2021 task is here:\n    http://bsnlp.cs.helsinki.fi/shared-task.html\n  - The below method processes the 2019 version of the corpus.\n    It has specific adjustments for the BG section, which has\n    quite a few typos or mis-annotations in it.  Other languages\n    probably need similar work in order to function optimally.\n  - make a directory $NERBASE/bsnlp2019\n  - download the \"training data are available HERE\" and\n    \"test data are available HERE\" to this subdirectory\n  - unzip those files in that directory\n  - we use the code name \"bg_bsnlp19\".  Other languages from\n    bsnlp 2019 can be supported by adding the appropriate\n    functionality in convert_bsnlp.py.\n  - prepare_ner_dataset.py bg_bsnlp19\n\nNCHLT produced NER datasets for many African languages.\n  Unfortunately, it is difficult to make use of many of these,\n  as there is no corresponding UD data from which to build a\n  tokenizer or other tools.\n  - Afrikaans:  https://repo.sadilar.org/handle/20.500.12185/299\n  - isiNdebele: https://repo.sadilar.org/handle/20.500.12185/306\n  - isiXhosa:   https://repo.sadilar.org/handle/20.500.12185/312\n  - isiZulu:    https://repo.sadilar.org/handle/20.500.12185/319\n  - Sepedi:     https://repo.sadilar.org/handle/20.500.12185/328\n  - Sesotho:    https://repo.sadilar.org/handle/20.500.12185/334\n  - Setswana:   https://repo.sadilar.org/handle/20.500.12185/341\n  - Siswati:    https://repo.sadilar.org/handle/20.500.12185/346\n  - Tsivenda:   https://repo.sadilar.org/handle/20.500.12185/355\n  - Xitsonga:   https://repo.sadilar.org/handle/20.500.12185/362\n  Agree to the license, download the zip, and unzip it in\n  $NERBASE/NCHLT\n\nUCSY built a Myanmar dataset.  They have not made it publicly\n  available, but they did make it available to Stanford for research\n  purposes.  
Contact Chris Manning or John Bauer for the data files if\n  you are Stanford affiliated.\n  - https://arxiv.org/abs/1903.04739\n  - Syllable-based Neural Named Entity Recognition for Myanmar Language\n    by Hsu Myat Mo and Khin Mar Soe\n\nHanieh Poostchi et al produced a Persian NER dataset:\n  - git@github.com:HaniehP/PersianNER.git\n  - https://github.com/HaniehP/PersianNER\n  - Hanieh Poostchi, Ehsan Zare Borzeshi, Mohammad Abdous, and Massimo Piccardi,\n    \"PersoNER: Persian Named-Entity Recognition\"\n  - Hanieh Poostchi, Ehsan Zare Borzeshi, and Massimo Piccardi,\n    \"BiLSTM-CRF for Persian Named-Entity Recognition; ArmanPersoNERCorpus: the First Entity-Annotated Persian Dataset\"\n  - Conveniently, this dataset is already in BIO format.  It does not have a dev split, though.\n    git clone the above repo, unzip ArmanPersoNERCorpus.zip, and this script will split the\n    first train fold into a dev section.\n\nSUC3 is a Swedish NER dataset provided by Språkbanken\n  - https://spraakbanken.gu.se/en/resources/suc3\n  - The splitting tool is generously provided by\n    Emil Stenstrom\n    https://github.com/EmilStenstrom/suc_to_iob\n  - Download the .bz2 file at this URL and put it in $NERBASE/sv_suc3shuffle\n    It is not necessary to unzip it.\n  - Gustafson-Capková, Sophia and Britt Hartmann, 2006, \n    Manual of the Stockholm Umeå Corpus version 2.0.\n    Stockholm University.\n  - Östling, Robert, 2013, Stagger \n    an Open-Source Part of Speech Tagger for Swedish\n    Northern European Journal of Language Technology 3: 1–18\n    DOI 10.3384/nejlt.2000-1533.1331\n  - The shuffled dataset can be converted with dataset code\n    prepare_ner_dataset.py sv_suc3shuffle\n  - If you fill out the license form and get the official data,\n    you can get the official splits by putting the provided zip file\n    in $NERBASE/sv_suc3licensed.  
Again, not necessary to unzip it\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset sv_suc3licensed\n\nDDT is a reformulation of the Danish Dependency Treebank as an NER dataset\n  - https://danlp-alexandra.readthedocs.io/en/latest/docs/datasets.html#dane\n  - direct download link as of late 2021: https://danlp.alexandra.dk/304bd159d5de/datasets/ddt.zip\n  - https://aclanthology.org/2020.lrec-1.565.pdf\n    DaNE: A Named Entity Resource for Danish\n    Rasmus Hvingelby, Amalie Brogaard Pauli, Maria Barrett,\n    Christina Rosted, Lasse Malm Lidegaard, Anders Søgaard\n  - place ddt.zip in $NERBASE/da_ddt/ddt.zip\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset da_ddt\n\nNorNE is the Norwegian Dependency Treebank with NER labels\n  - LREC 2020\n    NorNE: Annotating Named Entities for Norwegian\n    Fredrik Jørgensen, Tobias Aasmoe, Anne-Stine Ruud Husevåg,\n    Lilja Øvrelid, and Erik Velldal\n  - both Bokmål and Nynorsk\n  - This dataset is in a git repo:\n    https://github.com/ltgoslo/norne\n    Clone it into $NERBASE\n    git clone git@github.com:ltgoslo/norne.git\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset nb_norne\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset nn_norne\n\ntr_starlang is a set of constituency trees for Turkish\n  The words in this dataset (usually) have NER labels as well\n\n  A dataset in three parts from the Starlang group in Turkey:\n  Neslihan Kara, Büşra Marşan, et al\n    Creating A Syntactically Felicitous Constituency Treebank For Turkish\n    https://ieeexplore.ieee.org/document/9259873\n  git clone the following three repos\n    https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15\n    https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15\n    https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20\n  Put them in\n    $CONSTITUENCY_HOME/turkish    (yes, the constituency home)\n  python3 -m stanza.utils.datasets.ner.prepare_ner_dataset tr_starlang\n\nGermEval2014 is 
a German NER dataset\n  https://sites.google.com/site/germeval2014ner/data\n  https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J\n  Download the files in that directory\n    NER-de-train.tsv NER-de-dev.tsv NER-de-test.tsv\n  put them in\n    $NERBASE/germeval2014\n  then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset de_germeval2014\n\nThe UD Japanese GSD dataset has a conversion by Megagon Labs\n  https://github.com/megagonlabs/UD_Japanese-GSD\n  https://github.com/megagonlabs/UD_Japanese-GSD/tags\n  - r2.9-NE has the NE tagged files inside a \"spacy\"\n    folder in the download\n  - expected directory for this data:\n    unzip the .zip of the release into\n      $NERBASE/ja_gsd\n    so it should wind up in\n      $NERBASE/ja_gsd/UD_Japanese-GSD-r2.9-NE\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset ja_gsd\n\nL3Cube is a Marathi dataset\n  - https://arxiv.org/abs/2204.06029\n    https://arxiv.org/pdf/2204.06029.pdf\n    https://github.com/l3cube-pune/MarathiNLP\n  - L3Cube-MahaNER: A Marathi Named Entity Recognition Dataset and BERT models\n    Parth Patil, Aparna Ranade, Maithili Sabane, Onkar Litake, Raviraj Joshi\n\n  Clone the repo into $NERBASE/marathi\n    git clone git@github.com:l3cube-pune/MarathiNLP.git\n  Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset mr_l3cube\n\nDaffodil University produced a Bangla NER dataset\n  - https://github.com/Rifat1493/Bengali-NER\n  - https://ieeexplore.ieee.org/document/8944804\n  - Bengali Named Entity Recognition:\n    A survey with deep learning benchmark\n    Md Jamiur Rahman Rifat, Sheikh Abujar, Sheak Rashed Haider Noori,\n    Syed Akhter Hossain\n\n  Clone the repo into a \"bangla\" subdirectory of $NERBASE\n    cd $NERBASE/bangla\n    git clone git@github.com:Rifat1493/Bengali-NER.git\n  Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset bn_daffodil\n\nLST20 is a Thai NER dataset from 2020\n  - 
https://arxiv.org/abs/2008.05055\n    The Annotation Guideline of LST20 Corpus\n    Prachya Boonkwan, Vorapon Luantangsrisuk, Sitthaa Phaholphinyo,\n    Kanyanat Kriengket, Dhanon Leenoi, Charun Phrombut,\n    Monthika Boriboon, Krit Kosawat, Thepchai Supnithi\n  - This script processes a version which can be downloaded here after registration:\n    https://aiforthai.in.th/index.php\n  - There is another version downloadable from HuggingFace\n    The script will likely need some modification to be compatible\n    with the HuggingFace version\n  - Download the data in $NERBASE/thai/LST20_Corpus\n    There should be \"train\", \"eval\", \"test\" directories after downloading\n  - Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset th_lst20\n\nThai-NNER is another Thai NER dataset, from 2022\n  - https://github.com/vistec-AI/Thai-NNER\n  - https://aclanthology.org/2022.findings-acl.116/\n    Thai Nested Named Entity Recognition Corpus\n    Weerayut Buaphet, Can Udomcharoenchaikit, Peerat Limkonchotiwat,\n    Attapol Rutherford, and Sarana Nutanong\n  - git clone the data to $NERBASE/thai\n  - On the git repo, there should be a link to a more complete version\n    of the dataset.  For example, in Sep. 
2023 it is here:\n    https://github.com/vistec-AI/Thai-NNER#dataset\n    The Google drive it goes to has \"postproc\".\n    Put the train.json, dev.json, and test.json in\n    $NERBASE/thai/Thai-NNER/data/scb-nner-th-2022/postproc/\n  - Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset th_nner22\n\n\nNKJP is a Polish NER dataset\n  - http://nkjp.pl/index.php?page=0&lang=1\n    About the Project\n  - http://zil.ipipan.waw.pl/DistrNKJP\n    Wikipedia subcorpus used to train charlm model\n  - http://clip.ipipan.waw.pl/NationalCorpusOfPolish?action=AttachFile&do=view&target=NKJP-PodkorpusMilionowy-1.2.tar.gz\n    Annotated subcorpus to train NER model.\n    Download and extract to $NERBASE/Polish-NKJP or leave the gzip in $NERBASE/polish/...\n\nkk_kazNERD is a Kazakh dataset published in 2021\n  - https://github.com/IS2AI/KazNERD\n  - https://arxiv.org/abs/2111.13419\n    KazNERD: Kazakh Named Entity Recognition Dataset\n    Rustem Yeshpanov, Yerbolat Khassanov, Huseyin Atakan Varol\n  - in $NERBASE, make a \"kazakh\" directory, then git clone the repo there\n    mkdir -p $NERBASE/kazakh\n    cd $NERBASE/kazakh\n    git clone git@github.com:IS2AI/KazNERD.git\n  - Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset kk_kazNERD\n\nMasakhane NER is a set of NER datasets for African languages\n  - MasakhaNER: Named Entity Recognition for African Languages\n    Adelani, David Ifeoluwa; Abbott, Jade; Neubig, Graham;\n    D’souza, Daniel; Kreutzer, Julia; Lignos, Constantine;\n    Palen-Michel, Chester; Buzaaba, Happy; Rijhwani, Shruti;\n    Ruder, Sebastian; Mayhew, Stephen; Azime, Israel Abebe;\n    Muhammad, Shamsuddeen H.; Emezue, Chris Chinenye;\n    Nakatumba-Nabende, Joyce; Ogayo, Perez; Anuoluwapo, Aremu;\n    Gitau, Catherine; Mbaye, Derguene; Alabi, Jesujoba;\n    Yimam, Seid Muhie; Gwadabe, Tajuddeen Rabiu; Ezeani, Ignatius;\n    Niyongabo, Rubungo Andre; Mukiibi, Jonathan; Otiende, Verrah;\n    Orife, Iroro; David, Davis; 
Ngom, Samba; Adewumi, Tosin;\n    Rayson, Paul; Adeyemi, Mofetoluwa; Muriuki, Gerald;\n    Anebi, Emmanuel; Chukwuneke, Chiamaka; Odu, Nkiruka;\n    Wairagala, Eric Peter; Oyerinde, Samuel; Siro, Clemencia;\n    Bateesa, Tobius Saul; Oloyede, Temilola; Wambui, Yvonne;\n    Akinode, Victor; Nabagereka, Deborah; Katusiime, Maurice;\n    Awokoya, Ayodele; MBOUP, Mouhamadane; Gebreyohannes, Dibora;\n    Tilaye, Henok; Nwaike, Kelechi; Wolde, Degaga; Faye, Abdoulaye;\n    Sibanda, Blessing; Ahia, Orevaoghene; Dossou, Bonaventure F. P.;\n    Ogueji, Kelechi; DIOP, Thierno Ibrahima; Diallo, Abdoulaye;\n    Akinfaderin, Adewale; Marengereke, Tendai; Osei, Salomey\n  - https://github.com/masakhane-io/masakhane-ner\n  - git clone the repo to $NERBASE\n  - Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset lcode_masakhane\n  - You can use the full language name, the 3 letter language code,\n    or in the case of languages with a 2 letter language code,\n    the 2 letter code for lcode.  The tool will throw an error\n    if the language is not supported in Masakhane.\n\nSiNER is a Sindhi NER dataset\n  - https://aclanthology.org/2020.lrec-1.361/\n    SiNER: A Large Dataset for Sindhi Named Entity Recognition\n    Wazir Ali, Junyu Lu, Zenglin Xu\n  - It is available via git repository\n    https://github.com/AliWazir/SiNER-dataset\n    As of Nov. 
2022, there were a few changes to the dataset\n    to update a couple instances of broken tags & tokenization\n  - Clone the repo to $NERBASE/sindhi\n    mkdir $NERBASE/sindhi\n    cd $NERBASE/sindhi\n    git clone git@github.com:AliWazir/SiNER-dataset.git\n  - Then, prepare the dataset with this script:\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset sd_siner\n\nen_sample is the toy dataset included with stanza-train\n  https://github.com/stanfordnlp/stanza-train\n  this is not meant for any kind of actual NER use\n\nArmTDP-NER is an Armenian NER dataset\n  - https://github.com/myavrum/ArmTDP-NER.git\n    ArmTDP-NER: The corpus was developed by the ArmTDP team led by Marat M. Yavrumyan\n    at the Yerevan State University by the collaboration of \"Armenia National SDG Innovation Lab\"\n    and \"UC Berkley's Armenian Linguists' network\".\n  - in $NERBASE, make a \"armenian\" directory, then git clone the repo there\n    mkdir -p $NERBASE/armenian\n    cd $NERBASE/armenian\n    git clone https://github.com/myavrum/ArmTDP-NER.git\n  - Then run\n    python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hy_armtdp\n\nen_conll03 is the classic 2003 4 class CoNLL dataset\n  - The version we use is posted on HuggingFace\n  - https://huggingface.co/datasets/conll2003\n  - The prepare script will download from HF\n    using the datasets package, then convert to json\n  - Introduction to the CoNLL-2003 Shared Task:\n    Language-Independent Named Entity Recognition\n    Tjong Kim Sang, Erik F. 
and De Meulder, Fien\n  - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conll03\n\nen_conll03ww is CoNLL 03 with Worldwide added to the training data.\n  - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conll03ww\n\nen_conllpp is a test set from 2020 newswire\n  - https://arxiv.org/abs/2212.09747\n  - https://github.com/ShuhengL/acl2023_conllpp\n  - Do CoNLL-2003 Named Entity Taggers Still Work Well in 2023?\n    Shuheng Liu, Alan Ritter\n  - git clone the repo in $NERBASE\n  - then run\n    python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conllpp\n\nen_ontonotes is the OntoNotes 5 on HuggingFace\n  - https://huggingface.co/datasets/conll2012_ontonotesv5\n  - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_ontonotes\n  - this downloads the \"v12\" version of the data\n\nen_worldwide-4class is an English non-US newswire dataset\n  - annotated by MLTwist and Aya Data, with help from Datasaur,\n    collected at Stanford\n  - work to be published at EMNLP Findings\n  - the 4 class version is converted to the 4 classes in conll,\n    then split into train/dev/test\n  - clone https://github.com/stanfordnlp/en-worldwide-newswire\n    into $NERBASE/en_worldwide\n\nen_worldwide-9class is an English non-US newswire dataset\n  - annotated by MLTwist and Aya Data, with help from Datasaur,\n    collected at Stanford\n  - work to be published at EMNLP Findings\n  - the 9 class version is not edited\n  - clone https://github.com/stanfordnlp/en-worldwide-newswire\n    into $NERBASE/en_worldwide\n\nzh-hans_ontonotes is the ZH split of the OntoNotes dataset\n  - https://catalog.ldc.upenn.edu/LDC2013T19\n  - https://huggingface.co/datasets/conll2012_ontonotesv5\n  - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py zh-hans_ontonotes\n  - this downloads the \"v4\" version of the data\n\n\nAQMAR is a small dataset of Arabic Wikipedia articles\n  - http://www.cs.cmu.edu/~ark/ArabicNER/\n  - Recall-Oriented Learning of Named 
Entities in Arabic Wikipedia\n    Behrang Mohit, Nathan Schneider, Rishav Bhowmick, Kemal Oflazer, and Noah A. Smith.\n    In Proceedings of the 13th Conference of the European Chapter of\n    the Association for Computational Linguistics, Avignon, France,\n    April 2012.\n  - download the .zip file there and put it in\n    $NERBASE/arabic/AQMAR\n  - there is a challenge for it here:\n    https://www.topcoder.com/challenges/f3cf483e-a95c-4a7e-83e8-6bdd83174d38\n  - alternatively, we just randomly split it ourselves\n  - currently, running the following reproduces the random split:\n    python3 stanza/utils/datasets/ner/prepare_ner_dataset.py ar_aqmar\n\nIAHLT contains NER for Hebrew in the knesset treebank\n  - as of UD 2.14, it is only in the git repo\n  - download that git repo to $UDBASE_GIT:\n    https://github.com/UniversalDependencies/UD_Hebrew-IAHLTknesset\n  - change to the dev branch in that repo\n    python3 stanza/utils/datasets/ner/prepare_ner_dataset.py he_iahlt\n\nang_ewt is an Old English dataset available here:\n  https://github.com/dmetola/Old_English-OEDT/tree/main\n  More information, including a citation, will be added here\n  - install in NERBASE:\n    mkdir $NERBASE/ang\n    cd $NERBASE/ang\n    git clone git@github.com:dmetola/Old_English-OEDT.git\n  - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py ang_ewt\n\"\"\"\n\nimport glob\nimport os\nimport json\nimport random\nimport re\nimport shutil\nimport sys\nimport tempfile\n\nfrom stanza.models.common.constant import treebank_to_short_name, lcode2lang, lang_to_langcode, two_to_three_letters\nfrom stanza.models.ner.utils import to_bio2, bio2_to_bioes\nimport stanza.utils.default_paths as default_paths\n\nfrom stanza.utils.datasets.common import UnknownDatasetError\nfrom stanza.utils.datasets.ner.preprocess_wikiner import preprocess_wikiner\nfrom stanza.utils.datasets.ner.split_wikiner import split_wikiner, split_wikiner_data\nimport stanza.utils.datasets.ner.build_en_combined as 
build_en_combined\nimport stanza.utils.datasets.ner.conll_to_iob as conll_to_iob\nimport stanza.utils.datasets.ner.convert_ar_aqmar as convert_ar_aqmar\nimport stanza.utils.datasets.ner.convert_bn_daffodil as convert_bn_daffodil\nimport stanza.utils.datasets.ner.convert_bsf_to_beios as convert_bsf_to_beios\nimport stanza.utils.datasets.ner.convert_bsnlp as convert_bsnlp\nimport stanza.utils.datasets.ner.convert_en_conll03 as convert_en_conll03\nimport stanza.utils.datasets.ner.convert_fire_2013 as convert_fire_2013\nimport stanza.utils.datasets.ner.convert_he_iahlt as convert_he_iahlt\nimport stanza.utils.datasets.ner.convert_ijc as convert_ijc\nimport stanza.utils.datasets.ner.convert_kk_kazNERD as convert_kk_kazNERD\nimport stanza.utils.datasets.ner.convert_lst20 as convert_lst20\nimport stanza.utils.datasets.ner.convert_nner22 as convert_nner22\nimport stanza.utils.datasets.ner.convert_mr_l3cube as convert_mr_l3cube\nimport stanza.utils.datasets.ner.convert_my_ucsy as convert_my_ucsy\nimport stanza.utils.datasets.ner.convert_ontonotes as convert_ontonotes\nimport stanza.utils.datasets.ner.convert_rgai as convert_rgai\nimport stanza.utils.datasets.ner.convert_nytk as convert_nytk\nimport stanza.utils.datasets.ner.convert_starlang_ner as convert_starlang_ner\nimport stanza.utils.datasets.ner.convert_nkjp as convert_nkjp\nimport stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file\nimport stanza.utils.datasets.ner.convert_sindhi_siner as convert_sindhi_siner\nimport stanza.utils.datasets.ner.ontonotes_multitag as ontonotes_multitag\nimport stanza.utils.datasets.ner.simplify_en_worldwide as simplify_en_worldwide\nimport stanza.utils.datasets.ner.suc_to_iob as suc_to_iob\nimport stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob\nimport stanza.utils.datasets.ner.convert_hy_armtdp as convert_hy_armtdp\nfrom stanza.utils.datasets.ner.utils import convert_bioes_to_bio, convert_bio_to_json, get_tags, read_tsv, write_sentences, write_dataset, 
random_shuffle_by_prefixes, read_prefix_file, combine_files\n\nSHARDS = ('train', 'dev', 'test')\n\ndef process_turku(paths, short_name):\n    assert short_name == 'fi_turku'\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"finnish\", \"turku-ner-corpus\", \"data\", \"conll\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    for shard in SHARDS:\n        input_filename = os.path.join(base_input_path, '%s.tsv' % shard)\n        if not os.path.exists(input_filename):\n            raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename))\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        prepare_ner_file.process_dataset(input_filename, output_filename)\n\ndef process_it_fbk(paths, short_name):\n    assert short_name == \"it_fbk\"\n    base_input_path = os.path.join(paths[\"NERBASE\"], short_name)\n    csv_file = os.path.join(base_input_path, \"all-wiki-split.tsv\")\n    if not os.path.exists(csv_file):\n        raise FileNotFoundError(\"Cannot find the FBK dataset in its expected location: {}\".format(csv_file))\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    split_wikiner(base_output_path, csv_file, prefix=short_name, suffix=\"io\", shuffle=False, train_fraction=0.8, dev_fraction=0.1)\n    convert_bio_to_json(base_output_path, base_output_path, short_name, suffix=\"io\")\n\ndef process_suralk_multiner(paths, short_name):\n    lang_filenames = {\n        \"en\": \"Final_English.txt\",\n        \"si\": \"Final_Sinhala.txt\",\n        \"ta\": \"Final_Tamil.txt\",\n    }\n    lang, ending = short_name.split(\"_\")\n    assert ending == \"suralk\"\n    assert lang in lang_filenames, \"suralk/multiNER only supports %s\" % (\", \".join(lang_filenames.keys()))\n    suralk_path = os.path.join(paths[\"NERBASE\"], \"mixed\", \"multiNER\", \"nerannotateddatasets.zip\")\n    if not os.path.exists(suralk_path):\n        raise FileNotFoundError(\"Expected to find the 
suralk/multiNER dataset in %s\" % suralk_path)\n    sentences = read_tsv(lang_filenames[lang], text_column=0, annotation_column=1, separator=None, zip_filename=suralk_path)\n    print(\"Read %d sentences from %s::%s\" % (len(sentences), suralk_path, lang_filenames[lang]))\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    split_wikiner_data(base_output_path, sentences, prefix=short_name, suffix=\"bio\", shuffle=True)\n    convert_bio_to_json(base_output_path, base_output_path, short_name, suffix=\"bio\")\n\ndef process_il_ner(paths, short_name):\n    joiner = chr(0x200c)\n    def fix_tag(tag):\n        if tag == '-':\n            return 'O'\n        if tag.endswith(\"'\"):\n            # not sure the correct fix, but we filed an issue, so hopefully they fix it\n            return \"O\"\n        if tag.endswith(\"NIMI\") or tag.endswith(\"NET\"):\n            return tag[:2] + \"NETI\"\n        tag = tag.replace(joiner, \"\").upper()\n        if tag.startswith(\"-\"):\n            return 'B%s' % tag\n        return tag\n\n    def fix_line(line):\n        if line == 'O':\n            return '-\\tO'\n        return line\n\n    lang_paths = {\n        \"hi\": \"Hindi\",\n        \"or\": \"Odia\",\n        \"te\": \"Telugu\",\n        \"ur\": \"Urdu\",\n    }\n    lang, ending = short_name.split(\"_\")\n    assert ending == \"ilner\"\n    assert lang in lang_paths, \"IL-NER only supports %s\" % (\", \".join(lang_paths.keys()))\n    ilner_path = os.path.join(paths[\"NERBASE\"], \"indic\", \"IL-NER\")\n    if not os.path.exists(ilner_path):\n        raise FileNotFoundError(\"Cannot find the IL-NER dataset in its expected location: {}\".format(ilner_path))\n    ilner_path = os.path.join(ilner_path, \"Datasets\", lang_paths[lang])\n    if not os.path.exists(ilner_path):\n        raise FileNotFoundError(\"IL-NER not in the layout expected: directory not found {}\".format(ilner_path))\n    filenames = os.listdir(ilner_path)\n\n    base_output_path = 
paths[\"NER_DATA_DIR\"]\n    for shard in SHARDS:\n        input_filenames = [x for x in filenames if shard in x]\n        if len(input_filenames) == 0:\n            raise FileNotFoundError(\"No %s file in %s\" % (shard, ilner_path))\n        if len(input_filenames) > 1:\n            raise FileNotFoundError(\"Unexpected multiple files for %s in %s: %s\" % (shard, ilner_path, input_filenames))\n        input_filename = os.path.join(ilner_path, input_filenames[0])\n        int_filename = os.path.join(base_output_path, '%s.%s.tsv' % (short_name, shard))\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        sentences = read_tsv(input_filename, text_column=0, annotation_column=1, remap_tag_fn=fix_tag, remap_line=fix_line)\n        print(\"Loaded %d sentences from %s\" % (len(sentences), input_filename))\n        write_sentences(int_filename, sentences)\n\n        prepare_ner_file.process_dataset(int_filename, output_filename)\n\ndef process_languk(paths, short_name):\n    assert short_name == 'uk_languk'\n    base_input_path = os.path.join(paths[\"NERBASE\"], 'lang-uk', 'ner-uk', 'data')\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    train_test_split_fname = os.path.join(paths[\"NERBASE\"], 'lang-uk', 'ner-uk', 'doc', 'dev-test-split.txt')\n    convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path, train_test_split_file=train_test_split_fname)\n    for shard in SHARDS:\n        input_filename = os.path.join(base_output_path, convert_bsf_to_beios.CORPUS_NAME, \"%s.bio\" % shard)\n        if not os.path.exists(input_filename):\n            raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename))\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        prepare_ner_file.process_dataset(input_filename, output_filename)\n\n\ndef process_ijc(paths, short_name):\n    \"\"\"\n    Splits the ijc Hindi dataset in 
train, dev, test\n\n    The original data had train & test splits, so we randomly divide\n    the files in train to make a dev set.\n\n    The expected location of the IJC data is hi_ijc.  This method\n    should be possible to use for other languages, but we have very\n    little support for the other languages of IJC at the moment.\n    \"\"\"\n    base_input_path = os.path.join(paths[\"NERBASE\"], short_name)\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    test_files = [os.path.join(base_input_path, \"test-data-hindi.txt\")]\n    test_csv_file = os.path.join(base_output_path, short_name + \".test.csv\")\n    print(\"Converting test input %s to space separated file in %s\" % (test_files[0], test_csv_file))\n    convert_ijc.convert_ijc(test_files, test_csv_file)\n\n    train_input_path = os.path.join(base_input_path, \"training-hindi\", \"*utf8\")\n    train_files = glob.glob(train_input_path)\n    train_csv_file = os.path.join(base_output_path, short_name + \".train.csv\")\n    dev_csv_file = os.path.join(base_output_path, short_name + \".dev.csv\")\n    print(\"Converting training input from %s to space separated files in %s and %s\" % (train_input_path, train_csv_file, dev_csv_file))\n    convert_ijc.convert_split_ijc(train_files, train_csv_file, dev_csv_file)\n\n    for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS):\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        prepare_ner_file.process_dataset(csv_file, output_filename)\n\n\ndef process_fire_2013(paths, dataset):\n    \"\"\"\n    Splits the FIRE 2013 dataset into train, dev, test\n\n    The provided datasets are all mixed together at this point, so it\n    is not possible to recreate the original test conditions used in\n    the bakeoff\n    \"\"\"\n    short_name = treebank_to_short_name(dataset)\n    langcode, _ = short_name.split(\"_\")\n    short_name = \"%s_fire2013\" % langcode\n    if not langcode in (\"hi\", 
\"en\", \"ta\", \"bn\", \"mal\"):\n        raise UnknownDatasetError(dataset, \"Language %s not one of the FIRE 2013 languages\" % langcode)\n    language = lcode2lang[langcode].lower()\n    \n    # for example, FIRE2013/hindi_train\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"FIRE2013\", \"%s_train\" % language)\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    train_csv_file = os.path.join(base_output_path, \"%s.train.csv\" % short_name)\n    dev_csv_file   = os.path.join(base_output_path, \"%s.dev.csv\" % short_name)\n    test_csv_file  = os.path.join(base_output_path, \"%s.test.csv\" % short_name)\n\n    convert_fire_2013.convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file)\n\n    for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS):\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        prepare_ner_file.process_dataset(csv_file, output_filename)\n\ndef process_wikiner(paths, dataset):\n    short_name = treebank_to_short_name(dataset)\n\n    base_input_path = os.path.join(paths[\"NERBASE\"], dataset)\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    expected_filename = \"aij*wikiner*\"\n    input_files = [x for x in glob.glob(os.path.join(base_input_path, expected_filename)) if not x.endswith(\"bz2\")]\n    if len(input_files) == 0:\n        raw_input_path = os.path.join(base_input_path, \"raw\")\n        input_files = [x for x in glob.glob(os.path.join(raw_input_path, expected_filename)) if not x.endswith(\"bz2\")]\n        if len(input_files) > 1:\n            raise FileNotFoundError(\"Found too many raw wikiner files in %s: %s\" % (raw_input_path, \", \".join(input_files)))\n    elif len(input_files) > 1:\n        raise FileNotFoundError(\"Found too many raw wikiner files in %s: %s\" % (base_input_path, \", \".join(input_files)))\n\n    if len(input_files) == 0:\n        raise FileNotFoundError(\"Could not find any raw wikiner files 
in %s or %s\" % (base_input_path, raw_input_path))\n\n    csv_file = os.path.join(base_output_path, short_name + \"_csv\")\n    print(\"Converting raw input %s to space separated file in %s\" % (input_files[0], csv_file))\n    try:\n        preprocess_wikiner(input_files[0], csv_file)\n    except UnicodeDecodeError:\n        preprocess_wikiner(input_files[0], csv_file, encoding=\"iso8859-1\")\n\n    # this should create train.bio, dev.bio, and test.bio\n    print(\"Splitting %s to %s\" % (csv_file, base_output_path))\n    split_wikiner(base_output_path, csv_file, prefix=short_name)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_french_wikiner_gold(paths, dataset):\n    short_name = treebank_to_short_name(dataset)\n\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"wikiner-fr-gold\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    input_filename = os.path.join(base_input_path, \"wikiner-fr-gold.conll\")\n    if not os.path.exists(input_filename):\n        raise FileNotFoundError(\"Could not find the expected input file %s for dataset %s\" % (input_filename, base_input_path))\n\n    print(\"Reading %s\" % input_filename)\n    sentences = read_tsv(input_filename, text_column=0, annotation_column=2, separator=\" \")\n    print(\"Read %d sentences\" % len(sentences))\n\n    tags = [y for sentence in sentences for x, y in sentence]\n    tags = sorted(set(tags))\n    print(\"Found the following tags:\\n%s\" % tags)\n    expected_tags = ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER',\n                     'E-LOC', 'E-MISC', 'E-ORG', 'E-PER',\n                     'I-LOC', 'I-MISC', 'I-ORG', 'I-PER',\n                     'O',\n                     'S-LOC', 'S-MISC', 'S-ORG', 'S-PER']\n    assert tags == expected_tags\n\n    output_filename = os.path.join(base_output_path, \"%s.full.bioes\" % short_name)\n    print(\"Writing BIOES to %s\" % output_filename)\n    write_sentences(output_filename, sentences)\n\n    
print(\"Splitting %s to %s\" % (output_filename, base_output_path))\n    split_wikiner(base_output_path, output_filename, prefix=short_name, suffix=\"bioes\")\n    convert_bioes_to_bio(base_output_path, base_output_path, short_name)\n    convert_bio_to_json(base_output_path, base_output_path, short_name, suffix=\"bioes\")\n\ndef process_french_wikiner_mixed(paths, dataset):\n    \"\"\"\n    Build both the original and gold edited versions of WikiNER, then mix them\n\n    First we eliminate any duplicates (with one exception), then we combine the data\n\n    There are two main ways we could have done this:\n      - mix it together without any restrictions\n      - use the multi_ner mechanism to build a dataset which represents two prediction heads\n\n    The second method seems to give slightly better results than the first method,\n    but neither beat just using a transformer on the gold set alone\n\n    On the randomly selected test set, using WV and charlm but not a transformer\n    (this was on a previously published version of the dataset):\n\n    one prediction head:\n      INFO: Score by entity:\n        Prec.   Rec.    F1\n        89.32   89.26   89.29\n      INFO: Score by token:\n        Prec.   Rec.    F1\n        89.43   86.88   88.14\n      INFO: Weighted f1 for non-O tokens: 0.878855\n\n    two prediction heads:\n      INFO: Score by entity:\n        Prec.   Rec.    F1\n        89.83   89.76   89.79\n      INFO: Score by token:\n        Prec.   Rec.    F1\n        89.17   88.15   88.66\n      INFO: Weighted f1 for non-O tokens: 0.885675\n\n    On a randomly selected dev set, using transformer:\n\n    gold:\n      INFO: Score by entity:\n        Prec.   Rec.    F1\n        93.63   93.98   93.81\n      INFO: Score by token:\n        Prec.   Rec.    F1\n        92.80   92.79   92.80\n      INFO: Weighted f1 for non-O tokens: 0.927548\n\n    mixed:\n      INFO: Score by entity:\n        Prec.   Rec.    
F1\n        93.54   93.82   93.68\n      INFO: Score by token:\n        Prec.   Rec.    F1\n        92.99   92.51   92.75\n      INFO: Weighted f1 for non-O tokens: 0.926964\n    \"\"\"\n    short_name = treebank_to_short_name(dataset)\n\n    process_french_wikiner_gold(paths, \"fr_wikinergold\")\n    process_wikiner(paths, \"French-WikiNER\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    with open(os.path.join(base_output_path, \"fr_wikinergold.train.json\")) as fin:\n        gold_train = json.load(fin)\n    with open(os.path.join(base_output_path, \"fr_wikinergold.dev.json\")) as fin:\n        gold_dev = json.load(fin)\n    with open(os.path.join(base_output_path, \"fr_wikinergold.test.json\")) as fin:\n        gold_test = json.load(fin)\n\n    gold = gold_train + gold_dev + gold_test\n    print(\"%d total sentences in the gold relabeled dataset (randomly split)\" % len(gold))\n    gold = {tuple([x[\"text\"] for x in sentence]): sentence for sentence in gold}\n    print(\"  (%d after dedup)\" % len(gold))\n\n    original = (read_tsv(os.path.join(base_output_path, \"fr_wikiner.train.bio\"), text_column=0, annotation_column=1) +\n                read_tsv(os.path.join(base_output_path, \"fr_wikiner.dev.bio\"), text_column=0, annotation_column=1) +\n                read_tsv(os.path.join(base_output_path, \"fr_wikiner.test.bio\"), text_column=0, annotation_column=1))\n    print(\"%d total sentences in the original wiki\" % len(original))\n    original_words = {tuple([x[0] for x in sentence]) for sentence in original}\n    print(\"  (%d after dedup)\" % len(original_words))\n\n    missing = [sentence for sentence in gold if sentence not in original_words]\n    for sentence in missing:\n        # the capitalization of WisiGoths and OstroGoths is different\n        # between the original and the new in some cases\n        goths = tuple([x.replace(\"Goth\", \"goth\") for x in sentence])\n        if goths != sentence and goths in original_words:\n            
original_words.add(sentence)\n    missing = [sentence for sentence in gold if sentence not in original_words]\n    # currently this dataset doesn't find two sentences\n    # one was dropped by the filter for incompletely tagged lines\n    # the other is probably not a huge deal to have one duplicate\n    print(\"Missing %d sentences\" % len(missing))\n    assert len(missing) <= 2\n    for sent in missing:\n        print(sent)\n\n    skipped = 0\n    silver = []\n    silver_used = set()\n    for sentence in original:\n        words = tuple([x[0] for x in sentence])\n        tags = tuple([x[1] for x in sentence])\n        if words in gold or words in silver_used:\n            skipped += 1\n            continue\n        tags = to_bio2(tags)\n        tags = bio2_to_bioes(tags)\n        sentence = [{\"text\": x, \"ner\": y, \"multi_ner\": [\"-\", y]} for x, y in zip(words, tags)]\n        silver.append(sentence)\n        silver_used.add(words)\n    print(\"Using %d sentences from the original wikiner alongside the gold annotated train set\" % len(silver))\n    print(\"Skipped %d sentences\" % skipped)\n\n    gold_train = [[{\"text\": x[\"text\"], \"ner\": x[\"ner\"], \"multi_ner\": [x[\"ner\"], \"-\"]} for x in sentence]\n                  for sentence in gold_train]\n    gold_dev = [[{\"text\": x[\"text\"], \"ner\": x[\"ner\"], \"multi_ner\": [x[\"ner\"], \"-\"]} for x in sentence]\n                  for sentence in gold_dev]\n    gold_test = [[{\"text\": x[\"text\"], \"ner\": x[\"ner\"], \"multi_ner\": [x[\"ner\"], \"-\"]} for x in sentence]\n                  for sentence in gold_test]\n\n    mixed_train = gold_train + silver\n    print(\"Total sentences in the mixed training set: %d\" % len(mixed_train))\n    output_filename = os.path.join(base_output_path, \"%s.train.json\" % short_name)\n    with open(output_filename, 'w', encoding='utf-8') as fout:\n        json.dump(mixed_train, fout, indent=1)\n\n    output_filename = os.path.join(base_output_path, 
\"%s.dev.json\" % short_name)\n    with open(output_filename, 'w', encoding='utf-8') as fout:\n        json.dump(gold_dev, fout, indent=1)\n    output_filename = os.path.join(base_output_path, \"%s.test.json\" % short_name)\n    with open(output_filename, 'w', encoding='utf-8') as fout:\n        json.dump(gold_test, fout, indent=1)\n\n\ndef get_rgai_input_path(paths):\n    return os.path.join(paths[\"NERBASE\"], \"hu_rgai\")\n\ndef process_rgai(paths, short_name):\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    base_input_path = get_rgai_input_path(paths)\n\n    if short_name == 'hu_rgai':\n        use_business = True\n        use_criminal = True\n    elif short_name == 'hu_rgai_business':\n        use_business = True\n        use_criminal = False\n    elif short_name == 'hu_rgai_criminal':\n        use_business = False\n        use_criminal = True\n    else:\n        raise UnknownDatasetError(short_name, \"Unknown subset of hu_rgai data: %s\" % short_name)\n\n    convert_rgai.convert_rgai(base_input_path, base_output_path, short_name, use_business, use_criminal)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef get_nytk_input_path(paths):\n    return os.path.join(paths[\"NERBASE\"], \"NYTK-NerKor\")\n\ndef process_nytk(paths, short_name):\n    \"\"\"\n    Process the NYTK dataset\n    \"\"\"\n    assert short_name == \"hu_nytk\"\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    base_input_path = get_nytk_input_path(paths)\n\n    convert_nytk.convert_nytk(base_input_path, base_output_path, short_name)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef concat_files(output_file, *input_files):\n    input_lines = []\n    for input_file in input_files:\n        with open(input_file) as fin:\n            lines = fin.readlines()\n        if not len(lines):\n            raise ValueError(\"Empty input file: %s\" % input_file)\n        if not lines[-1]:\n            lines[-1] = \"\\n\"\n        elif 
lines[-1].strip():\n            lines.append(\"\\n\")\n        input_lines.append(lines)\n    with open(output_file, \"w\") as fout:\n        for lines in input_lines:\n            for line in lines:\n                fout.write(line)\n\n\ndef process_hu_combined(paths, short_name):\n    assert short_name == \"hu_combined\"\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    rgai_input_path = get_rgai_input_path(paths)\n    nytk_input_path = get_nytk_input_path(paths)\n\n    with tempfile.TemporaryDirectory() as tmp_output_path:\n        convert_rgai.convert_rgai(rgai_input_path, tmp_output_path, \"hu_rgai\", True, True)\n        convert_nytk.convert_nytk(nytk_input_path, tmp_output_path, \"hu_nytk\")\n\n        for shard in SHARDS:\n            rgai_input = os.path.join(tmp_output_path, \"hu_rgai.%s.bio\" % shard)\n            nytk_input = os.path.join(tmp_output_path, \"hu_nytk.%s.bio\" % shard)\n            output_file = os.path.join(base_output_path, \"hu_combined.%s.bio\" % shard)\n            concat_files(output_file, rgai_input, nytk_input)\n\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_bsnlp(paths, short_name):\n    \"\"\"\n    Process files downloaded from http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html\n\n    If you download the training and test data zip files and unzip\n    them without rearranging in any way, the layout is somewhat weird.\n    Training data goes into a specific subdirectory, but the test data\n    goes into the top level directory.\n    \"\"\"\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"bsnlp2019\")\n    base_train_path = os.path.join(base_input_path, \"training_pl_cs_ru_bg_rc1\")\n    base_test_path = base_input_path\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n\n    output_train_filename = os.path.join(base_output_path, \"%s.train.csv\" % short_name)\n    output_dev_filename   = os.path.join(base_output_path, \"%s.dev.csv\" % short_name)\n    
output_test_filename  = os.path.join(base_output_path, \"%s.test.csv\" % short_name)\n\n    language = short_name.split(\"_\")[0]\n\n    convert_bsnlp.convert_bsnlp(language, base_test_path, output_test_filename)\n    convert_bsnlp.convert_bsnlp(language, base_train_path, output_train_filename, output_dev_filename)\n\n    for shard, csv_file in zip(SHARDS, (output_train_filename, output_dev_filename, output_test_filename)):\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        prepare_ner_file.process_dataset(csv_file, output_filename)\n\nNCHLT_LANGUAGE_MAP = {\n    \"af\":  \"NCHLT Afrikaans Named Entity Annotated Corpus\",\n    # none of the following have UD datasets as of 2.8.  Until they\n    # exist, we assume the language codes NCHLT uses are sufficient\n    \"nr\":  \"NCHLT isiNdebele Named Entity Annotated Corpus\",\n    \"nso\": \"NCHLT Sepedi Named Entity Annotated Corpus\",\n    \"ss\":  \"NCHLT Siswati Named Entity Annotated Corpus\",\n    \"st\":  \"NCHLT Sesotho Named Entity Annotated Corpus\",\n    \"tn\":  \"NCHLT Setswana Named Entity Annotated Corpus\",\n    \"ts\":  \"NCHLT Xitsonga Named Entity Annotated Corpus\",\n    \"ve\":  \"NCHLT Tshivenda Named Entity Annotated Corpus\",\n    \"xh\":  \"NCHLT isiXhosa Named Entity Annotated Corpus\",\n    \"zu\":  \"NCHLT isiZulu Named Entity Annotated Corpus\",\n}\n\ndef process_nchlt(paths, short_name):\n    language = short_name.split(\"_\")[0]\n    if not language in NCHLT_LANGUAGE_MAP:\n        raise UnknownDatasetError(short_name, \"Language %s not part of NCHLT\" % language)\n    short_name = \"%s_nchlt\" % language\n\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"NCHLT\", NCHLT_LANGUAGE_MAP[language], \"*Full.txt\")\n    input_files = glob.glob(base_input_path)\n    if len(input_files) == 0:\n        raise FileNotFoundError(\"Cannot find NCHLT dataset in '%s'  Did you remember to download the file?\" % base_input_path)\n\n    if 
len(input_files) > 1:\n        raise ValueError(\"Unexpected number of files matched '%s'  There should only be one\" % base_input_path)\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    split_wikiner(base_output_path, input_files[0], prefix=short_name, remap={\"OUT\": \"O\"})\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_my_ucsy(paths, short_name):\n    assert short_name == \"my_ucsy\"\n    language = \"my\"\n\n    base_input_path = os.path.join(paths[\"NERBASE\"], short_name)\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    convert_my_ucsy.convert_my_ucsy(base_input_path, base_output_path)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_fa_arman(paths, short_name):\n    \"\"\"\n    Converts fa_arman dataset\n\n    The conversion is quite simple, actually.\n    Just need to split the train file and then convert bio -> json\n    \"\"\"\n    assert short_name == \"fa_arman\"\n    language = \"fa\"\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"PersianNER\")\n    train_input_file = os.path.join(base_input_path, \"train_fold1.txt\")\n    test_input_file = os.path.join(base_input_path, \"test_fold1.txt\")\n    if not os.path.exists(train_input_file) or not os.path.exists(test_input_file):\n        full_corpus_file = os.path.join(base_input_path, \"ArmanPersoNERCorpus.zip\")\n        if os.path.exists(full_corpus_file):\n            raise FileNotFoundError(\"Please unzip the file {}\".format(full_corpus_file))\n        raise FileNotFoundError(\"Cannot find the arman corpus in the expected directory: {}\".format(base_input_path))\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    test_output_file = os.path.join(base_output_path, \"%s.test.bio\" % short_name)\n\n    split_wikiner(base_output_path, train_input_file, prefix=short_name, train_fraction=0.8, test_section=False)\n    shutil.copy2(test_input_file, test_output_file)\n    convert_bio_to_json(base_output_path, 
base_output_path, short_name)\n\ndef process_sv_suc3licensed(paths, short_name):\n    \"\"\"\n    The .zip provided for SUC3 includes train/dev/test splits already\n\n    This extracts those splits without needing to unzip the original file\n    \"\"\"\n    assert short_name == \"sv_suc3licensed\"\n    language = \"sv\"\n    train_input_file = os.path.join(paths[\"NERBASE\"], short_name, \"SUC3.0.zip\")\n    if not os.path.exists(train_input_file):\n        raise FileNotFoundError(\"Cannot find the officially licensed SUC3 dataset in %s\" % train_input_file)\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    suc_conll_to_iob.process_suc3(train_input_file, short_name, base_output_path)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_sv_suc3shuffle(paths, short_name):\n    \"\"\"\n    Uses an externally provided script to read the SUC3 XML file, then splits it\n    \"\"\"\n    assert short_name == \"sv_suc3shuffle\"\n    language = \"sv\"\n    train_input_file = os.path.join(paths[\"NERBASE\"], short_name, \"suc3.xml.bz2\")\n    if not os.path.exists(train_input_file):\n        train_input_file = train_input_file[:-4]\n    if not os.path.exists(train_input_file):\n        raise FileNotFoundError(\"Unable to find the SUC3 dataset in {}.bz2\".format(train_input_file))\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    train_output_file = os.path.join(base_output_path, \"sv_suc3shuffle.bio\")\n    suc_to_iob.main([train_input_file, train_output_file])\n    split_wikiner(base_output_path, train_output_file, prefix=short_name)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)    \n    \ndef process_da_ddt(paths, short_name):\n    \"\"\"\n    Processes Danish DDT dataset\n\n    This dataset is in a conll file with the \"name\" attribute in the\n    misc column for the NER tag.  
This function uses a script to\n    convert such CoNLL files to .bio\n    \"\"\"\n    assert short_name == \"da_ddt\"\n    language = \"da\"\n    IN_FILES = (\"ddt.train.conllu\", \"ddt.dev.conllu\", \"ddt.test.conllu\")\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    OUT_FILES = [os.path.join(base_output_path, \"%s.%s.bio\" % (short_name, shard)) for shard in SHARDS]\n\n    zip_file = os.path.join(paths[\"NERBASE\"], \"da_ddt\", \"ddt.zip\")\n    if os.path.exists(zip_file):\n        for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS):\n            conll_to_iob.process_conll(in_filename, out_filename, zip_file)\n    else:\n        for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS):\n            in_filename = os.path.join(paths[\"NERBASE\"], \"da_ddt\", in_filename)\n            if not os.path.exists(in_filename):\n                raise FileNotFoundError(\"Could not find zip in expected location %s and could not find %s file in %s\" % (zip_file, shard, in_filename))\n\n            conll_to_iob.process_conll(in_filename, out_filename)\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\n\ndef process_norne(paths, short_name):\n    \"\"\"\n    Processes Norwegian NorNE\n\n    Can handle either Bokmål or Nynorsk\n\n    Converts GPE_LOC and GPE_ORG to GPE\n    \"\"\"\n    language, name = short_name.split(\"_\", 1)\n    assert language in ('nb', 'nn')\n    assert name == 'norne'\n\n    if language == 'nb':\n        IN_FILES = (\"nob/no_bokmaal-ud-train.conllu\", \"nob/no_bokmaal-ud-dev.conllu\", \"nob/no_bokmaal-ud-test.conllu\")\n    else:\n        IN_FILES = (\"nno/no_nynorsk-ud-train.conllu\", \"nno/no_nynorsk-ud-dev.conllu\", \"nno/no_nynorsk-ud-test.conllu\")\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    OUT_FILES = [os.path.join(base_output_path, \"%s.%s.bio\" % (short_name, shard)) for shard in SHARDS]\n\n    CONVERSION = { \"GPE_LOC\": \"GPE\", \"GPE_ORG\": \"GPE\" }\n\n    for 
in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS):\n        in_filename = os.path.join(paths[\"NERBASE\"], \"norne\", \"ud\", in_filename)\n        if not os.path.exists(in_filename):\n            raise FileNotFoundError(\"Could not find %s file in %s\" % (shard, in_filename))\n\n        conll_to_iob.process_conll(in_filename, out_filename, conversion=CONVERSION)\n\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_ja_gsd(paths, short_name):\n    \"\"\"\n    Convert ja_gsd from MegagonLabs\n\n    for example, can download from https://github.com/megagonlabs/UD_Japanese-GSD/releases/tag/r2.9-NE\n    \"\"\"\n    language, name = short_name.split(\"_\", 1)\n    assert language == 'ja'\n    assert name == 'gsd'\n\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    output_files = [os.path.join(base_output_path, \"%s.%s.bio\" % (short_name, shard)) for shard in SHARDS]\n\n    search_path = os.path.join(paths[\"NERBASE\"], \"ja_gsd\", \"UD_Japanese-GSD-r2.*-NE\")\n    versions = glob.glob(search_path)\n    max_version = None\n    base_input_path = None\n    # escape the dot so it only matches a literal \".\" in the version string\n    version_re = re.compile(r\"GSD-r2\\.([0-9]+)-NE$\")\n\n    for ver in versions:\n        match = version_re.search(ver)\n        if not match:\n            continue\n        ver_num = int(match.group(1))\n        if max_version is None or ver_num > max_version:\n            max_version = ver_num\n            base_input_path = ver\n\n    if base_input_path is None:\n        raise FileNotFoundError(\"Could not find any copies of the NE conversion of ja_gsd here: {}\".format(search_path))\n    print(\"Most recent version found: {}\".format(base_input_path))\n\n    input_files = [\"ja_gsd-ud-train.ne.conllu\", \"ja_gsd-ud-dev.ne.conllu\", \"ja_gsd-ud-test.ne.conllu\"]\n\n    def conversion(x):\n        if x[0] == 'L':\n            return 'E' + x[1:]\n        if x[0] == 'U':\n            return 'S' + x[1:]\n        # B, I unchanged\n        return x\n\n    for 
in_filename, out_filename, shard in zip(input_files, output_files, SHARDS):\n        in_path = os.path.join(base_input_path, in_filename)\n        if not os.path.exists(in_path):\n            in_spacy = os.path.join(base_input_path, \"spacy\", in_filename)\n            if not os.path.exists(in_spacy):\n                raise FileNotFoundError(\"Could not find %s file in %s or %s\" % (shard, in_path, in_spacy))\n            in_path = in_spacy\n\n        conll_to_iob.process_conll(in_path, out_filename, conversion=conversion, allow_empty=True, attr_prefix=\"NE\")\n\n    convert_bio_to_json(base_output_path, base_output_path, short_name)\n\ndef process_starlang(paths, short_name):\n    \"\"\"\n    Process a Turkish dataset from Starlang\n    \"\"\"\n    assert short_name == 'tr_starlang'\n\n    PIECES = [\"TurkishAnnotatedTreeBank-15\",\n              \"TurkishAnnotatedTreeBank2-15\",\n              \"TurkishAnnotatedTreeBank2-20\"]\n\n    chunk_paths = [os.path.join(paths[\"CONSTITUENCY_BASE\"], \"turkish\", piece) for piece in PIECES]\n    datasets = convert_starlang_ner.read_starlang(chunk_paths)\n\n    write_dataset(datasets, paths[\"NER_DATA_DIR\"], short_name)\n\ndef remap_germeval_tag(tag):\n    \"\"\"\n    Simplify tags for GermEval2014 using a simple rubric\n\n    all tags become their parent tag\n    OTH becomes MISC\n    \"\"\"\n    if tag == \"O\":\n        return tag\n    if tag[1:5] == \"-LOC\":\n        return tag[:5]\n    if tag[1:5] == \"-PER\":\n        return tag[:5]\n    if tag[1:5] == \"-ORG\":\n        return tag[:5]\n    if tag[1:5] == \"-OTH\":\n        return tag[0] + \"-MISC\"\n    raise ValueError(\"Unexpected tag: %s\" % tag)\n\ndef process_de_germeval2014(paths, short_name):\n    \"\"\"\n    Process the TSV of the GermEval2014 dataset\n    \"\"\"\n    in_directory = os.path.join(paths[\"NERBASE\"], \"germeval2014\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    datasets = []\n    for shard in SHARDS:\n        in_file = 
os.path.join(in_directory, \"NER-de-%s.tsv\" % shard)\n        sentences = read_tsv(in_file, 1, 2, remap_tag_fn=remap_germeval_tag)\n        datasets.append(sentences)\n    tags = get_tags(datasets)\n    print(\"Found the following tags: {}\".format(sorted(tags)))\n    write_dataset(datasets, base_output_path, short_name)\n\ndef process_hiner(paths, short_name):\n    in_directory = os.path.join(paths[\"NERBASE\"], \"hindi\", \"HiNER\", \"data\", \"original\")\n    convert_bio_to_json(in_directory, paths[\"NER_DATA_DIR\"], short_name, suffix=\"conll\", shard_names=(\"train\", \"validation\", \"test\"))\n\ndef process_hinercollapsed(paths, short_name):\n    in_directory = os.path.join(paths[\"NERBASE\"], \"hindi\", \"HiNER\", \"data\", \"collapsed\")\n    convert_bio_to_json(in_directory, paths[\"NER_DATA_DIR\"], short_name, suffix=\"conll\", shard_names=(\"train\", \"validation\", \"test\"))\n\ndef process_lst20(paths, short_name, include_space_char=True):\n    convert_lst20.convert_lst20(paths, short_name, include_space_char)\n\ndef process_nner22(paths, short_name, include_space_char=True):\n    convert_nner22.convert_nner22(paths, short_name, include_space_char)\n\ndef process_mr_l3cube(paths, short_name):\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    in_directory = os.path.join(paths[\"NERBASE\"], \"marathi\", \"MarathiNLP\", \"L3Cube-MahaNER\", \"IOB\")\n    input_files = [\"train_iob.txt\", \"valid_iob.txt\", \"test_iob.txt\"]\n    input_files = [os.path.join(in_directory, x) for x in input_files]\n    for input_file in input_files:\n        if not os.path.exists(input_file):\n            raise FileNotFoundError(\"Could not find the expected piece of the l3cube dataset %s\" % input_file)\n\n    datasets = [convert_mr_l3cube.convert(input_file) for input_file in input_files]\n    write_dataset(datasets, base_output_path, short_name)\n\ndef process_bn_daffodil(paths, short_name):\n    in_directory = os.path.join(paths[\"NERBASE\"], \"bangla\", 
\"Bengali-NER\")\n    out_directory = paths[\"NER_DATA_DIR\"]\n    convert_bn_daffodil.convert_dataset(in_directory, out_directory)\n\ndef process_pl_nkjp(paths, short_name):\n    out_directory = paths[\"NER_DATA_DIR\"]\n    candidates = [os.path.join(paths[\"NERBASE\"], \"Polish-NKJP\"),\n                  os.path.join(paths[\"NERBASE\"], \"polish\", \"Polish-NKJP\"),\n                  os.path.join(paths[\"NERBASE\"], \"polish\", \"NKJP-PodkorpusMilionowy-1.2.tar.gz\"),]\n    for in_path in candidates:\n        if os.path.exists(in_path):\n            break\n    else:\n        raise FileNotFoundError(\"Could not find %s  Looked in %s\" % (short_name, \" \".join(candidates)))\n    convert_nkjp.convert_nkjp(in_path, out_directory)\n\ndef process_kk_kazNERD(paths, short_name):\n    in_directory = os.path.join(paths[\"NERBASE\"], \"kazakh\", \"KazNERD\", \"KazNERD\")\n    out_directory = paths[\"NER_DATA_DIR\"]\n    convert_kk_kazNERD.convert_dataset(in_directory, out_directory, short_name)\n\ndef process_masakhane(paths, dataset_name):\n    \"\"\"\n    Converts Masakhane NER datasets to Stanza's .json format\n\n    If we let N be the length of the first sentence, the NER files\n    (in version 2, at least) are all of the form\n\n    word tag\n    ...\n    word tag\n      (blank line for sentence break)\n    word tag\n    ...\n\n    Once the dataset is git cloned in $NERBASE, the directory structure is\n\n    $NERBASE/masakhane-ner/MasakhaNER2.0/data/$lcode/{train,dev,test}.txt\n\n    The only tricky thing here is that for some languages, we treat\n    the 2 letter lcode as canonical thanks to UD, but Masakhane NER\n    uses 3 letter lcodes for all languages.\n    \"\"\"\n    language, dataset = dataset_name.split(\"_\")\n    lcode = lang_to_langcode(language)\n    if lcode in two_to_three_letters:\n        masakhane_lcode = two_to_three_letters[lcode]\n    else:\n        masakhane_lcode = lcode\n\n    mn_directory = os.path.join(paths[\"NERBASE\"], 
\"masakhane-ner\")\n    if not os.path.exists(mn_directory):\n        raise FileNotFoundError(\"Cannot find Masakhane NER repo.  Please check the setting of NERBASE or clone the repo to %s\" % mn_directory)\n    data_directory = os.path.join(mn_directory, \"MasakhaNER2.0\", \"data\")\n    if not os.path.exists(data_directory):\n        raise FileNotFoundError(\"Apparently found the repo at %s but the expected directory structure is not there - was looking for %s\" % (mn_directory, data_directory))\n\n    in_directory = os.path.join(data_directory, masakhane_lcode)\n    if not os.path.exists(in_directory):\n        raise UnknownDatasetError(dataset_name, \"Found the Masakhane repo, but there was no %s in the repo at path %s\" % (dataset_name, in_directory))\n    convert_bio_to_json(in_directory, paths[\"NER_DATA_DIR\"], \"%s_masakhane\" % lcode, \"txt\")\n\ndef process_sd_siner(paths, short_name):\n    in_directory = os.path.join(paths[\"NERBASE\"], \"sindhi\", \"SiNER-dataset\")\n    if not os.path.exists(in_directory):\n        raise FileNotFoundError(\"Cannot find SiNER checkout in $NERBASE/sindhi  Please git clone to repo in that directory\")\n    in_filename = os.path.join(in_directory, \"SiNER-dataset.txt\")\n    if not os.path.exists(in_filename):\n        in_filename = os.path.join(in_directory, \"SiNER dataset.txt\")\n        if not os.path.exists(in_filename):\n            raise FileNotFoundError(\"Found an SiNER directory at %s but the directory did not contain the dataset\" % in_directory)\n    convert_sindhi_siner.convert_sindhi_siner(in_filename, paths[\"NER_DATA_DIR\"], short_name)\n\ndef process_en_worldwide_4class(paths, short_name):\n    simplify_en_worldwide.main(args=['--simplify'])\n\n    in_directory = os.path.join(paths[\"NERBASE\"], \"en_worldwide\", \"4class\")\n    out_directory = paths[\"NER_DATA_DIR\"]\n\n    destination_file = os.path.join(paths[\"NERBASE\"], \"en_worldwide\", \"en-worldwide-newswire\", \"regions.txt\")\n    prefix_map = 
read_prefix_file(destination_file)\n\n    random_shuffle_by_prefixes(in_directory, out_directory, short_name, prefix_map)\n\ndef process_en_worldwide_9class(paths, short_name):\n    simplify_en_worldwide.main(args=['--no_simplify'])\n\n    in_directory = os.path.join(paths[\"NERBASE\"], \"en_worldwide\", \"9class\")\n    out_directory = paths[\"NER_DATA_DIR\"]\n\n    destination_file = os.path.join(paths[\"NERBASE\"], \"en_worldwide\", \"en-worldwide-newswire\", \"regions.txt\")\n    prefix_map = read_prefix_file(destination_file)\n\n    random_shuffle_by_prefixes(in_directory, out_directory, short_name, prefix_map)\n\ndef process_en_ontonotes(paths, short_name):\n    ner_input_path = paths['NERBASE']\n    ontonotes_path = os.path.join(ner_input_path, \"english\", \"en_ontonotes\")\n    ner_output_path = paths['NER_DATA_DIR']\n    convert_ontonotes.process_dataset(\"en_ontonotes\", ontonotes_path, ner_output_path)\n\ndef process_zh_ontonotes(paths, short_name):\n    ner_input_path = paths['NERBASE']\n    ontonotes_path = os.path.join(ner_input_path, \"chinese\", \"zh_ontonotes\")\n    ner_output_path = paths['NER_DATA_DIR']\n    convert_ontonotes.process_dataset(short_name, ontonotes_path, ner_output_path)\n\ndef process_en_conll03(paths, short_name):\n    ner_input_path = paths['NERBASE']\n    conll_path = os.path.join(ner_input_path, \"english\", \"en_conll03\")\n    ner_output_path = paths['NER_DATA_DIR']\n    convert_en_conll03.process_dataset(\"en_conll03\", conll_path, ner_output_path)\n\ndef process_en_conll03_worldwide(paths, short_name):\n    \"\"\"\n    Adds the training data for conll03 and worldwide together\n    \"\"\"\n    print(\"============== Preparing CoNLL 2003 ===================\")\n    process_en_conll03(paths, \"en_conll03\")\n    print(\"========== Preparing 4 Class Worldwide ================\")\n    process_en_worldwide_4class(paths, \"en_worldwide-4class\")\n    print(\"============== Combined Train Data ====================\")\n    
input_files = [os.path.join(paths['NER_DATA_DIR'], \"en_conll03.train.json\"),\n                   os.path.join(paths['NER_DATA_DIR'], \"en_worldwide-4class.train.json\")]\n    output_file = os.path.join(paths['NER_DATA_DIR'], \"%s.train.json\" % short_name)\n    combine_files(output_file, *input_files)\n    shutil.copyfile(os.path.join(paths['NER_DATA_DIR'], \"en_conll03.dev.json\"),\n                    os.path.join(paths['NER_DATA_DIR'], \"%s.dev.json\" % short_name))\n    shutil.copyfile(os.path.join(paths['NER_DATA_DIR'], \"en_conll03.test.json\"),\n                    os.path.join(paths['NER_DATA_DIR'], \"%s.test.json\" % short_name))\n\ndef process_en_ontonotes_ww_multi(paths, short_name):\n    \"\"\"\n    Combine the worldwide data with the OntoNotes data in a multi channel format\n    \"\"\"\n    print(\"=============== Preparing OntoNotes ===============\")\n    process_en_ontonotes(paths, \"en_ontonotes\")\n    print(\"========== Preparing 9 Class Worldwide ================\")\n    process_en_worldwide_9class(paths, \"en_worldwide-9class\")\n    # TODO: pass in options?\n    ontonotes_multitag.build_multitag_dataset(paths['NER_DATA_DIR'], short_name, True, True)\n\ndef process_en_combined(paths, short_name):\n    \"\"\"\n    Combine WW, OntoNotes, and CoNLL into a 3 channel dataset\n    \"\"\"\n    print(\"================= Preparing OntoNotes =================\")\n    process_en_ontonotes(paths, \"en_ontonotes\")\n    print(\"========== Preparing 9 Class Worldwide ================\")\n    process_en_worldwide_9class(paths, \"en_worldwide-9class\")\n    print(\"=============== Preparing CoNLL 03 ====================\")\n    process_en_conll03(paths, \"en_conll03\")\n    build_en_combined.build_combined_dataset(paths['NER_DATA_DIR'], short_name)\n\n\ndef process_en_conllpp(paths, short_name):\n    \"\"\"\n    This is ONLY a test set\n\n    the test set has entities start with I- instead of B- unless they\n    are in the middle of a sentence, but that 
should be fine, as\n    process_tags in the NER model converts those to B- in a BIOES\n    conversion\n    \"\"\"\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"acl2023_conllpp\", \"dataset\", \"conllpp.txt\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    sentences = read_tsv(base_input_path, 0, 3, separator=None)\n    sentences = [sent for sent in sentences if len(sent) > 1 or sent[0][0] != '-DOCSTART-']\n    write_dataset([sentences], base_output_path, short_name, shard_names=[\"test\"], shards=[\"test\"])\n\ndef process_armtdp(paths, short_name):\n    assert short_name == 'hy_armtdp'\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"armenian\", \"ArmTDP-NER\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    convert_hy_armtdp.convert_dataset(base_input_path, base_output_path, short_name)\n    for shard in SHARDS:\n        input_filename = os.path.join(base_output_path, f'{short_name}.{shard}.tsv')\n        if not os.path.exists(input_filename):\n            raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename))\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))\n        prepare_ner_file.process_dataset(input_filename, output_filename)\n\ndef process_toy_dataset(paths, short_name):\n    convert_bio_to_json(os.path.join(paths[\"NERBASE\"], \"English-SAMPLE\"), paths[\"NER_DATA_DIR\"], short_name)\n\ndef process_ar_aqmar(paths, short_name):\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"arabic\", \"AQMAR\", \"AQMAR_Arabic_NER_corpus-1.0.zip\")\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    convert_ar_aqmar.convert_shuffle(base_input_path, base_output_path, short_name)\n\ndef process_he_iahlt(paths, short_name):\n    assert short_name == 'he_iahlt'\n    # for now, need to use UDBASE_GIT until IAHLTknesset is added to UD\n    udbase = paths[\"UDBASE_GIT\"]\n    base_output_path = paths[\"NER_DATA_DIR\"]\n    
convert_he_iahlt.convert_iahlt(udbase, base_output_path, \"he_iahlt\")\n\ndef process_ang_ewt(paths, short_name):\n    assert short_name == 'ang_ewt'\n    base_input_path = os.path.join(paths[\"NERBASE\"], \"ang\", \"Old_English-OEDT\")\n    convert_bio_to_json(base_input_path, paths[\"NER_DATA_DIR\"], short_name)\n\nDATASET_MAPPING = {\n    \"ang_ewt\":           process_ang_ewt,\n    \"ar_aqmar\":          process_ar_aqmar,\n    \"bn_daffodil\":       process_bn_daffodil,\n    \"da_ddt\":            process_da_ddt,\n    \"de_germeval2014\":   process_de_germeval2014,\n    \"en_conll03\":        process_en_conll03,\n    \"en_conll03ww\":      process_en_conll03_worldwide,\n    \"en_conllpp\":        process_en_conllpp,\n    \"en_ontonotes\":      process_en_ontonotes,\n    \"en_ontonotes-ww-multi\": process_en_ontonotes_ww_multi,\n    \"en_combined\":       process_en_combined,\n    \"en_worldwide-4class\": process_en_worldwide_4class,\n    \"en_worldwide-9class\": process_en_worldwide_9class,\n    \"fa_arman\":          process_fa_arman,\n    \"fi_turku\":          process_turku,\n    \"fr_wikinergold\":    process_french_wikiner_gold,\n    \"fr_wikinermixed\":   process_french_wikiner_mixed,\n    \"hi_hiner\":          process_hiner,\n    \"hi_hinercollapsed\": process_hinercollapsed,\n    \"hi_ijc\":            process_ijc,\n    \"he_iahlt\":          process_he_iahlt,\n    \"hu_nytk\":           process_nytk,\n    \"hu_combined\":       process_hu_combined,\n    \"hy_armtdp\":         process_armtdp,\n    \"it_fbk\":            process_it_fbk,\n    \"ja_gsd\":            process_ja_gsd,\n    \"kk_kazNERD\":        process_kk_kazNERD,\n    \"mr_l3cube\":         process_mr_l3cube,\n    \"my_ucsy\":           process_my_ucsy,\n    \"pl_nkjp\":           process_pl_nkjp,\n    \"sd_siner\":          process_sd_siner,\n    \"sv_suc3licensed\":   process_sv_suc3licensed,\n    \"sv_suc3shuffle\":    process_sv_suc3shuffle,\n    \"tr_starlang\":       
process_starlang,\n    \"th_lst20\":          process_lst20,\n    \"th_nner22\":         process_nner22,\n    \"zh-hans_ontonotes\": process_zh_ontonotes,\n}\n\nSUFFIX_MAPPING = {\n    \"_ilner\":            process_il_ner,\n    \"_suralk\":           process_suralk_multiner,\n}\n\ndef main(dataset_name):\n    paths = default_paths.get_default_paths()\n    print(\"Processing %s\" % dataset_name)\n\n    random.seed(1234)\n\n    if dataset_name in DATASET_MAPPING:\n        DATASET_MAPPING[dataset_name](paths, dataset_name)\n    elif dataset_name in ('uk_languk', 'Ukranian_languk', 'Ukranian-languk'):\n        process_languk(paths, 'uk_languk')\n    elif dataset_name.endswith(\"FIRE2013\") or dataset_name.endswith(\"fire2013\"):\n        process_fire_2013(paths, dataset_name)\n    elif dataset_name.endswith('WikiNER'):\n        process_wikiner(paths, dataset_name)\n    elif dataset_name.startswith('hu_rgai'):\n        process_rgai(paths, dataset_name)\n    elif dataset_name.endswith(\"_bsnlp19\"):\n        process_bsnlp(paths, dataset_name)\n    elif dataset_name.endswith(\"_nchlt\"):\n        process_nchlt(paths, dataset_name)\n    elif dataset_name in (\"nb_norne\", \"nn_norne\"):\n        process_norne(paths, dataset_name)\n    elif dataset_name == 'en_sample':\n        process_toy_dataset(paths, dataset_name)\n    elif dataset_name.lower().endswith(\"_masakhane\"):\n        process_masakhane(paths, dataset_name)\n    else:\n        for ending in SUFFIX_MAPPING:\n            if dataset_name.endswith(ending):\n                SUFFIX_MAPPING[ending](paths, dataset_name)\n                break\n        else:\n            raise UnknownDatasetError(dataset_name, f\"dataset {dataset_name} currently not handled by prepare_ner_dataset\")\n    print(\"Done processing %s\" % dataset_name)\n\nif __name__ == '__main__':\n    main(sys.argv[1])\n"
  },
  {
    "path": "stanza/utils/datasets/ner/prepare_ner_file.py",
    "content": "\"\"\"\nThis script converts NER data from the CoNLL03 format to the latest CoNLL-U format. The script assumes that in the \ninput column format data, the token is always in the first column, while the NER tag is always in the last column.\n\"\"\"\n\nimport argparse\nimport json\n\nMIN_NUM_FIELD = 2\nMAX_NUM_FIELD = 5\n\nDOC_START_TOKEN = '-DOCSTART-'\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Convert the conll03 format data into conllu format.\")\n    parser.add_argument('input', help='Input conll03 format data filename.')\n    parser.add_argument('output', help='Output json filename.')\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = parse_args()\n    process_dataset(args.input, args.output)\n\ndef process_dataset(input_filename, output_filename):\n    sentences = load_conll03(input_filename)\n    print(\"{} examples loaded from {}\".format(len(sentences), input_filename))\n    \n    document = []\n    for (words, tags) in sentences:\n        sent = []\n        for w, t in zip(words, tags):\n            sent += [{'text': w, 'ner': t}]\n        document += [sent]\n\n    with open(output_filename, 'w', encoding=\"utf-8\") as outfile:\n        json.dump(document, outfile, indent=1)\n    print(\"Generated json file {}\".format(output_filename))\n\n# TODO: make skip_doc_start an argument\ndef load_conll03(filename, skip_doc_start=True):\n    cached_lines = []\n    examples = []\n    with open(filename, encoding=\"utf-8\") as infile:\n        for line in infile:\n            line = line.strip()\n            if skip_doc_start and DOC_START_TOKEN in line:\n                continue\n            if len(line) > 0:\n                array = line.split(\"\\t\")\n                if len(array) < MIN_NUM_FIELD:\n                    array = line.split()\n                if len(array) < MIN_NUM_FIELD:\n                    continue\n                else:\n                    cached_lines.append(line)\n          
  elif len(cached_lines) > 0:\n                example = process_cache(cached_lines)\n                examples.append(example)\n                cached_lines = []\n        if len(cached_lines) > 0:\n            examples.append(process_cache(cached_lines))\n    return examples\n\ndef process_cache(cached_lines):\n    tokens = []\n    ner_tags = []\n    for line in cached_lines:\n        array = line.split(\"\\t\")\n        if len(array) < MIN_NUM_FIELD:\n            array = line.split()\n        assert len(array) >= MIN_NUM_FIELD and len(array) <= MAX_NUM_FIELD, \"Got unexpected line length: {}\".format(array)\n        tokens.append(array[0])\n        ner_tags.append(array[-1])\n    return (tokens, ner_tags)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/preprocess_wikiner.py",
    "content": "\"\"\"\nConverts the WikiNER data format to a format usable by our processing tools\n\npython preprocess_wikiner input output\n\"\"\"\n\nimport sys\n\ndef preprocess_wikiner(input_file, output_file, encoding=\"utf-8\"):\n    with open(input_file, encoding=encoding) as fin:\n        with open(output_file, \"w\", encoding=\"utf-8\") as fout:\n            for line in fin:\n                line = line.strip()\n                if not line:\n                    fout.write(\"-DOCSTART- O\\n\")\n                    fout.write(\"\\n\")\n                    continue\n\n                words = line.split()\n                for word in words:\n                    pieces = word.split(\"|\")\n                    text = pieces[0]\n                    tag = pieces[-1]\n                    # some words look like Daniel_Bernoulli|I-PER\n                    # but the original .pl conversion script didn't take that into account\n                    subtext = text.split(\"_\")\n                    if tag.startswith(\"B-\") and len(subtext) > 1:\n                        fout.write(\"{} {}\\n\".format(subtext[0], tag))\n                        for chunk in subtext[1:]:\n                            fout.write(\"{} I-{}\\n\".format(chunk, tag[2:]))\n                    else:\n                        for chunk in subtext:\n                            fout.write(\"{} {}\\n\".format(chunk, tag))\n                fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    preprocess_wikiner(sys.argv[1], sys.argv[2])\n"
  },
  {
    "path": "stanza/utils/datasets/ner/simplify_en_worldwide.py",
    "content": "import argparse\nimport os\nimport tempfile\n\nimport stanza\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.utils.datasets.ner.utils import read_tsv\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nPUNCTUATION = \"\"\"!\"#%&'()*+, -./:;<=>?@[\\\\]^_`{|}~\"\"\"\nMONEY_WORDS = {\"million\", \"billion\", \"trillion\", \"millions\", \"billions\", \"trillions\", \"hundred\", \"hundreds\",\n               \"lakh\", \"crore\", # south asian english\n               \"tens\", \"of\", \"ten\", \"one\", \"two\", \"three\", \"four\", \"five\", \"six\", \"seven\", \"eight\", \"nine\", \"couple\"}\n\n# Doesn't include Money but this case is handled explicitly for processing\nLABEL_TRANSLATION = {\n    \"Date\":         None,\n    \"Misc\":         \"MISC\",\n    \"Product\":      \"MISC\",\n    \"NORP\":         \"MISC\",\n    \"Facility\":     \"LOC\",\n    \"Location\":     \"LOC\",\n    \"Person\":       \"PER\",\n    \"Organization\": \"ORG\",\n}\n\ndef isfloat(num):\n    try:\n        float(num)\n        return True\n    except ValueError:\n        return False\n\n\ndef process_label(line, is_start=False):\n    \"\"\"\n    Converts our stuff to conll labels\n\n    event, product, work of art, norp -> MISC\n    take out dates - can use Stanza to identify them as dates and eliminate them\n    money requires some special care\n    facility -> location  (there are examples of Bridge and Hospital in the data)\n    the version of conll we used to train CoreNLP NER is here:\n\n    Overall plan:\n    Collapse Product, NORP, Money (extract only the symbols), into misc.\n    Collapse Facilities into LOC\n    Deletes Dates\n\n    Rule for currency is that we take out labels for the numbers that return True for isfloat()\n    Take out words that categorize money (Million, Billion, Trillion, Thousand, Hundred, Ten, Nine, Eight, Seven, Six, Five,\n    Four, Three, Two, One)\n    Take out punctuation characters\n\n    If we remove 
the 'B' tag, then move it to the first remaining tag.\n\n    Replace removed tags with 'O'.\n    The is_start parameter signals whether the current line is the new start of a tag.  This is needed when\n    the previous line was the start of a MONEY tag but was removed because it is not a symbol: the first\n    remaining symbol token then needs the B- tag, even if it was previously tagged I-MONEY.\n    \"\"\"\n    if not line:\n        return []\n    token = line[0]\n    biggest_label = line[1]\n    position, label_name = biggest_label[:2], biggest_label[2:]\n\n    if label_name == \"Money\":\n        if token.lower() in MONEY_WORDS or token in PUNCTUATION or isfloat(token):  # remove this tag\n            label_name = \"O\"\n            is_start = True\n            position = \"\"\n        else:  # keep money tag\n            label_name = \"MISC\"\n            if is_start:\n                position = \"B-\"\n                is_start = False\n\n    elif not label_name or label_name == \"O\":\n        pass\n    elif label_name in LABEL_TRANSLATION:\n        label_name = LABEL_TRANSLATION[label_name]\n        if label_name is None:\n            position = \"\"\n            label_name = \"O\"\n            is_start = False\n    else:\n        raise ValueError(\"Oops, missed a label: %s\" % label_name)\n    return [token, position + label_name, is_start]\n\n\ndef write_new_file(save_dir, input_path, old_file, simplify):\n    starts_b = False\n    with open(input_path, \"r+\", encoding=\"utf-8\") as iob:\n        new_filename = (os.path.splitext(old_file)[0] + \".4class.tsv\") if simplify else old_file\n        with open(os.path.join(save_dir, new_filename), 'w', encoding='utf-8') as fout:\n            for i, line in enumerate(iob):\n                if i == 0 or i == 1:  # skip over the URL and subsequent space line.\n                    continue\n                line = line.strip()\n                if not line:\n                    fout.write(\"\\n\")\n          
          continue\n                label = line.split(\"\\t\")\n                if simplify:\n                    try:\n                        edited = process_label(label, is_start=starts_b)  # processed label line labels\n                    except ValueError as e:\n                        raise ValueError(\"Error in %s at line %d\" % (input_path, i)) from e\n                    assert edited\n                    starts_b = edited[-1]\n                    fout.write(\"\\t\".join(edited[:-1]))\n                    fout.write(\"\\n\")\n                else:\n                    fout.write(\"%s\\t%s\\n\" % (label[0], label[1]))\n\n\ndef copy_and_simplify(base_path, simplify):\n    with tempfile.TemporaryDirectory(dir=base_path) as tempdir:\n        # Condense Labels\n        input_dir = os.path.join(base_path, \"en-worldwide-newswire\")\n        final_dir = os.path.join(base_path, \"4class\" if simplify else \"9class\")\n        os.makedirs(tempdir, exist_ok=True)\n        os.makedirs(final_dir, exist_ok=True)\n        for root, dirs, files in os.walk(input_dir):\n            if root[-6:] == \"REVIEW\":\n                batch_files = os.listdir(root)\n                for filename in batch_files:\n                    file_path = os.path.join(root, filename)\n                    write_new_file(final_dir, file_path, filename, simplify)\n\ndef main(args=None):\n    BASE_PATH = \"C:\\\\Users\\\\SystemAdmin\\\\PycharmProjects\\\\General Code\\\\stanza source code\"\n    if not os.path.exists(BASE_PATH):\n        paths = get_default_paths()\n        BASE_PATH = os.path.join(paths[\"NERBASE\"], \"en_worldwide\")\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--base_path', type=str, default=BASE_PATH, help=\"Where to find the raw data\")\n    parser.add_argument('--simplify', default=False, action='store_true', help='Simplify to 4 classes... 
otherwise, keep all classes')\n    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help=\"Don't simplify to 4 classes\")\n    args = parser.parse_args(args=args)\n\n    copy_and_simplify(args.base_path, args.simplify)\n\nif __name__ == '__main__':\n    main()\n\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/ner/simplify_ontonotes_to_worldwide.py",
    "content": "\"\"\"\nSimplify an existing ner json with the OntoNotes 18 class scheme to the Worldwide scheme\n\nSimplified classes used in the Worldwide dataset are:\n\nDate\nFacility\nLocation\nMisc\nMoney\nNORP\nOrganization\nPerson\nProduct\n\nvs OntoNotes classes:\n\nCARDINAL\nDATE\nEVENT\nFAC\nGPE\nLANGUAGE\nLAW\nLOC\nMONEY\nNORP\nORDINAL\nORG\nPERCENT\nPERSON\nPRODUCT\nQUANTITY\nTIME\nWORK_OF_ART\n\"\"\"\n\nimport argparse\nimport glob\nimport json\nimport os\n\nfrom stanza.utils.default_paths import get_default_paths\n\nWORLDWIDE_ENTITY_MAPPING = {\n    \"CARDINAL\":    None,\n    \"ORDINAL\":     None,\n    \"PERCENT\":     None,\n    \"QUANTITY\":    None,\n    \"TIME\":        None,\n\n    \"DATE\":        \"Date\",\n    \"EVENT\":       \"Misc\",\n    \"FAC\":         \"Facility\",\n    \"GPE\":         \"Location\",\n    \"LANGUAGE\":    \"NORP\",\n    \"LAW\":         \"Misc\",\n    \"LOC\":         \"Location\",\n    \"MONEY\":       \"Money\",\n    \"NORP\":        \"NORP\",\n    \"ORG\":         \"Organization\",\n    \"PERSON\":      \"Person\",\n    \"PRODUCT\":     \"Product\",\n    \"WORK_OF_ART\": \"Misc\",\n\n    # identity map in case this is called on the Worldwide half of the tags\n    \"Date\":        \"Date\",\n    \"Facility\":    \"Facility\",\n    \"Location\":    \"Location\",\n    \"Misc\":        \"Misc\",\n    \"Money\":       \"Money\",\n    \"Organization\":\"Organization\",\n    \"Person\":      \"Person\",\n    \"Product\":     \"Product\",\n}\n\ndef simplify_ontonotes_to_worldwide(entity):\n    if not entity or entity == \"O\":\n        return \"O\"\n\n    ent_iob, ent_type = entity.split(\"-\", maxsplit=1)\n\n    if ent_type in WORLDWIDE_ENTITY_MAPPING:\n        if not WORLDWIDE_ENTITY_MAPPING[ent_type]:\n            return \"O\"\n        return ent_iob + \"-\" + WORLDWIDE_ENTITY_MAPPING[ent_type]\n    raise ValueError(\"Unhandled entity: %s\" % ent_type)\n\ndef convert_file(in_file, out_file):\n    with open(in_file) as 
fin:\n        gold_doc = json.load(fin)\n\n    for sentence in gold_doc:\n        for word in sentence:\n            if 'ner' not in word:\n                continue\n            word['ner'] = simplify_ontonotes_to_worldwide(word['ner'])\n\n    with open(out_file, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(gold_doc, fout, indent=2)\n\ndef main():\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--input_dataset', type=str, default='en_ontonotes', help='which files to convert')\n    parser.add_argument('--output_dataset', type=str, default='en_ontonotes-8class', help='which files to write out')\n    parser.add_argument('--ner_data_dir', type=str, default=get_default_paths()[\"NER_DATA_DIR\"], help='which directory has the data')\n    args = parser.parse_args()\n\n    input_files = glob.glob(os.path.join(args.ner_data_dir, args.input_dataset + \".*\"))\n    for input_file in input_files:\n        output_file = os.path.split(input_file)[1][len(args.input_dataset):]\n        output_file = os.path.join(args.ner_data_dir, args.output_dataset + output_file)\n        print(\"Converting %s to %s\" % (input_file, output_file))\n        convert_file(input_file, output_file)\n\n    \nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/split_wikiner.py",
    "content": "\"\"\"\nPreprocess the WikiNER dataset, by\n1) normalizing tags;\n2) split into train (70%), dev (15%), test (15%) datasets.\n\"\"\"\n\nimport os\nimport random\nimport warnings\nfrom collections import Counter\n\ndef read_sentences(filename, encoding):\n    sents = []\n    cache = []\n    skipped = 0\n    skip = False\n    with open(filename, encoding=encoding) as infile:\n        for i, line in enumerate(infile):\n            line = line.rstrip()\n            if len(line) == 0:\n                if len(cache) > 0:\n                    if not skip:\n                        sents.append(cache)\n                    else:\n                        skipped += 1\n                        skip = False\n                    cache = []\n                continue\n            array = line.split()\n            if len(array) != 2:\n                skip = True\n                warnings.warn(\"Format error at line {}: {}\".format(i+1, line))\n                continue\n            w, t = array\n            cache.append([w, t])\n        if len(cache) > 0:\n            if not skip:\n                sents.append(cache)\n            else:\n                skipped += 1\n            cache = []\n    print(\"Skipped {} examples due to formatting issues.\".format(skipped))\n    return sents\n\ndef write_sentences_to_file(sents, filename):\n    print(f\"Writing {len(sents)} sentences to {filename}\")\n    with open(filename, 'w', encoding='utf-8') as outfile:\n        for sent in sents:\n            for pair in sent:\n                print(f\"{pair[0]}\\t{pair[1]}\", file=outfile)\n            print(\"\", file=outfile)\n\ndef remap_labels(sents, remap):\n    new_sentences = []\n    for sentence in sents:\n        new_sent = []\n        for word in sentence:\n            new_sent.append([word[0], remap.get(word[1], word[1])])\n        new_sentences.append(new_sent)\n    return new_sentences\n\ndef split_wikiner_data(directory, sents, prefix=\"\", suffix=\"bio\", remap=None, 
shuffle=True, train_fraction=0.7, dev_fraction=0.15, test_section=True):\n    random.seed(1234)\n\n    if remap:\n        sents = remap_labels(sents, remap)\n\n    # split\n    num = len(sents)\n    train_num = int(num*train_fraction)\n    if test_section:\n        dev_num = int(num*dev_fraction)\n        if train_fraction + dev_fraction > 1.0:\n            raise ValueError(\"Train and dev fractions added up to more than 1: {} + {} = {}\".format(train_fraction, dev_fraction, train_fraction + dev_fraction))\n    else:\n        dev_num = num - train_num\n\n    if shuffle:\n        random.shuffle(sents)\n    train_sents = sents[:train_num]\n    dev_sents = sents[train_num:train_num+dev_num]\n    if test_section:\n        test_sents = sents[train_num+dev_num:]\n        batches = [train_sents, dev_sents, test_sents]\n        filenames = [f'train.{suffix}', f'dev.{suffix}', f'test.{suffix}']\n    else:\n        batches = [train_sents, dev_sents]\n        filenames = [f'train.{suffix}', f'dev.{suffix}']\n\n    if prefix:\n        filenames = ['%s.%s' % (prefix, f) for f in filenames]\n    for batch, filename in zip(batches, filenames):\n        write_sentences_to_file(batch, os.path.join(directory, filename))\n\ndef split_wikiner(directory, *in_filenames, encoding=\"utf-8\", **kwargs):\n    sents = []\n    for filename in in_filenames:\n        new_sents = read_sentences(filename, encoding)\n        print(f\"{len(new_sents)} sentences read from {filename}.\")\n        sents.extend(new_sents)\n\n    split_wikiner_data(directory, sents, **kwargs)\n\nif __name__ == \"__main__\":\n    in_filename = 'raw/wp2.txt'\n    directory = \".\"\n    split_wikiner(directory, in_filename)\n"
  },
  {
    "path": "stanza/utils/datasets/ner/suc_conll_to_iob.py",
    "content": "\"\"\"\nProcess the licensed version of SUC3 to BIO\n\nThe main program processes the expected location, or you can pass in a\nspecific zip or filename to read\n\"\"\"\n\nfrom io import TextIOWrapper\nfrom zipfile import ZipFile\n\ndef extract(infile, outfile):\n    \"\"\"\n    Convert the infile to an outfile\n\n    Assumes the files are already open (this allows you to pass in a zipfile reader, for example)\n\n    The SUC3 format is like conll, but with the tags in tabs 10 and 11\n    \"\"\"\n    lines = infile.readlines()\n    sentences = []\n    cur_sentence = []\n    for idx, line in enumerate(lines):\n        line = line.strip()\n        if not line:\n            # if we're currently reading a sentence, append it to the list\n            if cur_sentence:\n                sentences.append(cur_sentence)\n                cur_sentence = []\n            continue\n\n        pieces = line.split(\"\\t\")\n        if len(pieces) < 12:\n            raise ValueError(\"Unexpected line length in the SUC3 dataset at %d\" % idx)\n        if pieces[10] == 'O':\n            cur_sentence.append((pieces[1], \"O\"))\n        else:\n            cur_sentence.append((pieces[1], \"%s-%s\" % (pieces[10], pieces[11])))\n    if cur_sentence:\n        sentences.append(cur_sentence)\n\n    for sentence in sentences:\n        for word in sentence:\n            outfile.write(\"%s\\t%s\\n\" % word)\n        outfile.write(\"\\n\")\n\n    return len(sentences)\n\ndef extract_from_zip(zip_filename, in_filename, out_filename):\n    \"\"\"\n    Process a single file from SUC3\n\n    zip_filename: path to SUC3.0.zip\n    in_filename: which piece to read\n    out_filename: where to write the result\n    \"\"\"\n    with ZipFile(zip_filename) as zin:\n        with zin.open(in_filename) as fin:\n            with open(out_filename, \"w\") as fout:\n                num = extract(TextIOWrapper(fin, encoding=\"utf-8\"), fout)\n                print(\"Processed %d sentences from %s:%s to 
%s\" % (num, zip_filename, in_filename, out_filename))\n                return num\n\ndef process_suc3(zip_filename, short_name, out_dir):\n    extract_from_zip(zip_filename, \"SUC3.0/corpus/conll/suc-train.conll\", \"%s/%s.train.bio\" % (out_dir, short_name))\n    extract_from_zip(zip_filename, \"SUC3.0/corpus/conll/suc-dev.conll\", \"%s/%s.dev.bio\" % (out_dir, short_name))\n    extract_from_zip(zip_filename, \"SUC3.0/corpus/conll/suc-test.conll\", \"%s/%s.test.bio\" % (out_dir, short_name))\n\ndef main():\n    # assumption: the short name sv_suc3 is inferred from the input directory name\n    process_suc3(\"extern_data/ner/sv_suc3/SUC3.0.zip\", \"sv_suc3\", \"data/ner\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/suc_to_iob.py",
    "content": "\"\"\"\nConversion tool to transform SUC3's xml format to IOB\n\nCopyright 2017-2022, Emil Stenström\n\nPermission is hereby granted, free of charge, to any person obtaining\na copy of this software and associated documentation files (the\n\"Software\"), to deal in the Software without restriction, including\nwithout limitation the rights to use, copy, modify, merge, publish,\ndistribute, sublicense, and/or sell copies of the Software, and to\npermit persons to whom the Software is furnished to do so, subject to\nthe following conditions:\n\nThe above copyright notice and this permission notice shall be\nincluded in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\nEXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\nMERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\nNONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE\nLIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION\nOF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION\nWITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\"\"\"\n\nfrom bz2 import BZ2File\nfrom xml.etree.ElementTree import iterparse\nimport argparse\nfrom collections import Counter\nimport sys\n\ndef parse(fp, skiptypes=[]):\n    root = None\n    ne_prefix = \"\"\n    ne_type = \"O\"\n    name_prefix = \"\"\n    name_type = \"O\"\n\n    for event, elem in iterparse(fp, events=(\"start\", \"end\")):\n        if root is None:\n            root = elem\n\n        if event == \"start\":\n            if elem.tag == \"name\":\n                _type = name_type_to_label(elem.attrib[\"type\"])\n                if (\n                    _type not in skiptypes and\n                    not (_type == \"ORG\" and ne_type == \"LOC\")\n                ):\n                    name_type = _type\n                    name_prefix = \"B-\"\n\n            elif elem.tag == \"ne\":\n                
_type = ne_type_to_label(elem.attrib[\"type\"])\n                if \"/\" in _type:\n                    _type = ne_type_to_label(_type[_type.index(\"/\") + 1:])\n\n                if _type not in skiptypes:\n                    ne_type = _type\n                    ne_prefix = \"B-\"\n\n            elif elem.tag == \"w\":\n                if name_type == \"PER\" and elem.attrib[\"pos\"] == \"NN\":\n                    name_type = \"O\"\n                    name_prefix = \"\"\n\n        elif event == \"end\":\n            if elem.tag == \"sentence\":\n                yield\n\n            elif elem.tag == \"name\":\n                name_type = \"O\"\n                name_prefix = \"\"\n\n            elif elem.tag == \"ne\":\n                ne_type = \"O\"\n                ne_prefix = \"\"\n\n            elif elem.tag == \"w\":\n                if name_type != \"O\" and name_type != \"OTH\":\n                    yield elem.text, name_prefix, name_type\n                elif ne_type != \"O\":\n                    yield elem.text, ne_prefix, ne_type\n                else:\n                    yield elem.text, \"\", \"O\"\n\n                if ne_type != \"O\":\n                    ne_prefix = \"I-\"\n\n                if name_type != \"O\":\n                    name_prefix = \"I-\"\n\n        root.clear()\n\ndef ne_type_to_label(ne_type):\n    mapping = {\n        \"PRS\": \"PER\",\n    }\n    return mapping.get(ne_type, ne_type)\n\ndef name_type_to_label(name_type):\n    mapping = {\n        \"inst\": \"ORG\",\n        \"product\": \"OBJ\",\n        \"other\": \"OTH\",\n        \"place\": \"LOC\",\n        \"myth\": \"PER\",\n        \"person\": \"PER\",\n        \"event\": \"EVN\",\n        \"work\": \"WRK\",\n        \"animal\": \"PER\",\n    }\n    return mapping.get(name_type)\n\ndef main(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"infile\",\n        help=\"\"\"\n            Input for that contains the full SUC 3.0 
XML.\n            Can be the bz2-zipped version or the xml version.\n        \"\"\"\n    )\n    parser.add_argument(\n        \"outfile\",\n        nargs=\"?\",\n        help=\"\"\"\n             Output file for IOB format.\n             Optional - will print to stdout otherwise\n        \"\"\"\n    )\n    parser.add_argument(\n        \"--skiptypes\",\n        help=\"Entity types that should be skipped in output.\",\n        nargs=\"+\",\n        default=[]\n    )\n    parser.add_argument(\n        \"--stats_only\",\n        help=\"Show statistics of found labels at the end of output.\",\n        action='store_true',\n        default=False\n    )\n    args = parser.parse_args(args)\n\n    MAGIC_BZ2_FILE_START = b\"\\x42\\x5a\\x68\"\n    # peek at the first bytes to detect bz2, closing the probe handle before reopening\n    with open(args.infile, \"rb\") as probe:\n        is_bz2 = (probe.read(len(MAGIC_BZ2_FILE_START)) == MAGIC_BZ2_FILE_START)\n\n    if is_bz2:\n        fp = BZ2File(args.infile, \"rb\")\n    else:\n        fp = open(args.infile, \"rb\")\n\n    if args.outfile is not None:\n        fout = open(args.outfile, \"w\", encoding=\"utf-8\")\n    else:\n        fout = sys.stdout\n\n    type_stats = Counter()\n    for token in parse(fp, skiptypes=args.skiptypes):\n        if not token:\n            if not args.stats_only:\n                fout.write(\"\\n\")\n        else:\n            word, prefix, label = token\n            if args.stats_only:\n                type_stats[label] += 1\n            else:\n                fout.write(\"%s\\t%s%s\\n\" % (word, prefix, label))\n\n    if args.stats_only:\n        fout.write(str(type_stats) + \"\\n\")\n\n    fp.close()\n    if args.outfile is not None:\n        fout.close()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/ner/utils.py",
    "content": "\"\"\"\nUtils for the processing of NER datasets\n\nThese can be invoked from either the specific dataset scripts\nor the entire prepare_ner_dataset.py script\n\"\"\"\n\nfrom collections import defaultdict\nimport io\nimport json\nimport os\nimport random\nimport zipfile\n\nfrom stanza.models.common.doc import Document\nimport stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file\n\nSHARDS = ('train', 'dev', 'test')\n\ndef bioes_to_bio(tags):\n    new_tags = []\n    in_entity = False\n    for tag in tags:\n        if tag == 'O':\n            new_tags.append(tag)\n            in_entity = False\n        elif in_entity and (tag.startswith(\"B-\") or tag.startswith(\"S-\")):\n            # TODO: does the tag have to match the previous tag?\n            # eg, does B-LOC B-PER in BIOES need a B-PER or is I-PER sufficient?\n            new_tags.append('B-' + tag[2:])\n        else:\n            new_tags.append('I-' + tag[2:])\n            in_entity = True\n    return new_tags\n\ndef convert_bioes_to_bio(base_input_path, base_output_path, short_name):\n    \"\"\"\n    Convert BIOES files back to BIO (not BIO2)\n\n    Useful for preparing datasets for CoreNLP, which doesn't do great with the more highly split classes\n    \"\"\"\n    for shard in SHARDS:\n        input_filename = os.path.join(base_input_path, '%s.%s.bioes' % (short_name, shard))\n        output_filename = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard))\n\n        input_sentences = read_tsv(input_filename, text_column=0, annotation_column=1)\n        new_sentences = []\n        for sentence in input_sentences:\n            tags = [x[1] for x in sentence]\n            tags = bioes_to_bio(tags)\n            sentence = [(x[0], y) for x, y in zip(sentence, tags)]\n            new_sentences.append(sentence)\n        write_sentences(output_filename, new_sentences)\n\n\ndef convert_bio_to_json(base_input_path, base_output_path, short_name, suffix=\"bio\", 
shard_names=SHARDS, shards=SHARDS):\n    \"\"\"\n    Convert BIO files to json\n\n    It can often be convenient to put the intermediate BIO files in\n    the same directory as the output files, in which case you can pass\n    in same path for both base_input_path and base_output_path.\n\n    This also will rewrite a BIOES as json\n    \"\"\"\n    for input_shard, output_shard in zip(shard_names, shards):\n        input_filename = os.path.join(base_input_path, '%s.%s.%s' % (short_name, input_shard, suffix))\n        if not os.path.exists(input_filename):\n            alt_filename = os.path.join(base_input_path, '%s.%s' % (input_shard, suffix))\n            if os.path.exists(alt_filename):\n                input_filename = alt_filename\n            else:\n                raise FileNotFoundError('Cannot find %s component of %s in %s or %s' % (output_shard, short_name, input_filename, alt_filename))\n        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, output_shard))\n        print(\"Converting %s to %s\" % (input_filename, output_filename))\n        prepare_ner_file.process_dataset(input_filename, output_filename)\n\ndef get_tags(datasets):\n    \"\"\"\n    return the set of tags used in these datasets\n\n    datasets is expected to be train, dev, test but could be any list\n    \"\"\"\n    tags = set()\n    for dataset in datasets:\n        for sentence in dataset:\n            for word, tag in sentence:\n                tags.add(tag)\n    return tags\n\ndef write_sentences(output_filename, dataset):\n    \"\"\"\n    Write exactly one output file worth of dataset\n    \"\"\"\n    os.makedirs(os.path.split(output_filename)[0], exist_ok=True)\n    with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n        for sent_idx, sentence in enumerate(dataset):\n            for word_idx, word in enumerate(sentence):\n                if len(word) > 2:\n                    word = word[:2]\n                try:\n                    
fout.write(\"%s\\t%s\\n\" % word)\n                except TypeError:\n                    raise TypeError(\"Unable to process sentence %d word %d of file %s\" % (sent_idx, word_idx, output_filename))\n            fout.write(\"\\n\")\n\ndef write_dataset(datasets, output_dir, short_name, suffix=\"bio\", shard_names=SHARDS, shards=SHARDS):\n    \"\"\"\n    write all three pieces of a dataset to output_dir\n\n    datasets should be 3 lists: train, dev, test\n    each list should be a list of sentences\n    each sentence is a list of pairs: word, tag\n\n    after writing to .bio files, the files will be converted to .json\n    \"\"\"\n    for shard, dataset in zip(shard_names, datasets):\n        output_filename = os.path.join(output_dir, \"%s.%s.%s\" % (short_name, shard, suffix))\n        write_sentences(output_filename, dataset)\n\n    convert_bio_to_json(output_dir, output_dir, short_name, suffix, shard_names=shard_names, shards=shards)\n\n\ndef write_multitag_json(output_filename, dataset):\n    json_dataset = []\n    for sentence in dataset:\n        json_sentence = []\n        for word in sentence:\n            word = {'text': word[0],\n                    'ner': word[1],\n                    'multi_ner': word[2]}\n            json_sentence.append(word)\n        json_dataset.append(json_sentence)\n    with open(output_filename, 'w', encoding='utf-8') as fout:\n        json.dump(json_dataset, fout, indent=2)\n\ndef write_multitag_dataset(datasets, output_dir, short_name, suffix=\"bio\", shard_names=SHARDS, shards=SHARDS):\n    for shard, dataset in zip(shard_names, datasets):\n        output_filename = os.path.join(output_dir, \"%s.%s.%s\" % (short_name, shard, suffix))\n        write_sentences(output_filename, dataset)\n\n    for shard, dataset in zip(shard_names, datasets):\n        output_filename = os.path.join(output_dir, \"%s.%s.json\" % (short_name, shard))\n        write_multitag_json(output_filename, dataset)\n\ndef read_tsv(filename, text_column, 
annotation_column, remap_tag_fn=None, remap_line=None, skip_comments=True, keep_broken_tags=False, keep_all_columns=False, separator=\"\\t\", zip_filename=None):\n    \"\"\"\n    Read sentences from a TSV file\n\n    Returns a list of list of (word, tag)\n\n    If keep_broken_tags==True, then None is returned for a missing.  Otherwise, an IndexError is thrown\n    \"\"\"\n    if zip_filename is not None:\n        with zipfile.ZipFile(zip_filename) as zin:\n            with zin.open(filename) as fin:\n                fin = io.TextIOWrapper(fin, encoding='utf-8')\n                lines = fin.readlines()\n    else:\n        with open(filename, encoding=\"utf-8\") as fin:\n            lines = fin.readlines()\n\n    lines = [x.strip() for x in lines]\n\n    sentences = []\n    current_sentence = []\n    for line_idx, line in enumerate(lines):\n        if not line:\n            if current_sentence:\n                sentences.append(current_sentence)\n                current_sentence = []\n            continue\n        if skip_comments and line.startswith(\"#\"):\n            continue\n\n        if remap_line is not None:\n            line = remap_line(line)\n        pieces = line.split(separator)\n        try:\n            word = pieces[text_column]\n        except IndexError as e:\n            raise IndexError(\"Filename %s: could not find word index %d at line %d |%s|\" % (filename, text_column, line_idx, line)) from e\n        if word == '\\x96':\n            # this happens in GermEval2014 for some reason\n            continue\n        try:\n            tag = pieces[annotation_column]\n        except IndexError as e:\n            if keep_broken_tags:\n                tag = None\n            else:\n                raise IndexError(\"Filename %s: could not find tag index %d at line %d |%s|\" % (filename, annotation_column, line_idx, line)) from e\n        if remap_tag_fn is not None:\n            tag = remap_tag_fn(tag)\n\n        if keep_all_columns:\n            
pieces[annotation_column] = tag\n            current_sentence.append(pieces)\n        else:\n            current_sentence.append((word, tag))\n\n    if current_sentence:\n        sentences.append(current_sentence)\n\n    return sentences\n\ndef random_shuffle_directory(input_dir, output_dir, short_name):\n    input_files = os.listdir(input_dir)\n    input_files = sorted(input_files)\n    random_shuffle_files(input_dir, input_files, output_dir, short_name)\n\ndef random_shuffle_files(input_dir, input_files, output_dir, short_name):\n    \"\"\"\n    Shuffle the files into different chunks based on their filename\n\n    The first piece of the filename, split by \".\", is used as a random seed.\n\n    This will make it so that adding new files or using a different\n    annotation scheme (assuming that's encoding in pieces of the\n    filename) won't change the distibution of the files\n    \"\"\"\n    input_keys = {}\n    for f in input_files:\n        seed = f.split(\".\")[0]\n        if seed in input_keys:\n            raise ValueError(\"Multiple files with the same prefix: %s and %s\" % (input_keys[seed], f))\n        input_keys[seed] = f\n    assert len(input_keys) == len(input_files)\n\n    train_files = []\n    dev_files = []\n    test_files = []\n\n    for filename in input_files:\n        seed = filename.split(\".\")[0]\n        # \"salt\" the filenames when using as a seed\n        # definitely not because of a dumb bug in the original implementation\n        seed = seed + \".txt.4class.tsv\"\n        random.seed(seed, 2)\n        location = random.random()\n        if location < 0.7:\n            train_files.append(filename)\n        elif location < 0.8:\n            dev_files.append(filename)\n        else:\n            test_files.append(filename)\n\n    print(\"Train files: %d  Dev files: %d  Test files: %d\" % (len(train_files), len(dev_files), len(test_files)))\n    assert len(train_files) + len(dev_files) + len(test_files) == len(input_files)\n\n    
file_lists = [train_files, dev_files, test_files]\n    datasets = []\n    for files in file_lists:\n        dataset = []\n        for filename in files:\n            dataset.extend(read_tsv(os.path.join(input_dir, filename), 0, 1))\n        datasets.append(dataset)\n\n    write_dataset(datasets, output_dir, short_name)\n    return len(train_files), len(dev_files), len(test_files)\n\ndef random_shuffle_by_prefixes(input_dir, output_dir, short_name, prefix_map):\n    input_files = os.listdir(input_dir)\n    input_files = sorted(input_files)\n\n    file_divisions = defaultdict(list)\n    for filename in input_files:\n        for division in prefix_map.keys():\n            for prefix in prefix_map[division]:\n                if filename.startswith(prefix):\n                    break\n            else: # for/else is intentional\n                continue\n            break\n        else: # yes, stop asking\n            raise ValueError(\"Could not assign %s to any of the divisions in the prefix_map\" % filename)\n        #print(\"Assigning %s to %s because of %s\" % (filename, division, prefix))\n        file_divisions[division].append(filename)\n\n    num_train_files = 0\n    num_dev_files = 0\n    num_test_files = 0\n    for division in file_divisions.keys():\n        print()\n        print(\"Processing %d files from %s\" % (len(file_divisions[division]), division))\n        d_train, d_dev, d_test = random_shuffle_files(input_dir, file_divisions[division], output_dir, \"%s-%s\" % (short_name, division))\n        num_train_files += d_train\n        num_dev_files += d_dev\n        num_test_files += d_test\n\n    print()\n    print(\"After shuffling: Train files: %d  Dev files: %d  Test files: %d\" % (num_train_files, num_dev_files, num_test_files))\n    dataset_divisions = [\"%s-%s\" % (short_name, division) for division in file_divisions]\n    combine_dataset(output_dir, output_dir, dataset_divisions, short_name)\n\ndef combine_dataset(input_dir, output_dir, 
input_datasets, output_dataset):\n    datasets = []\n    for shard in SHARDS:\n        full_dataset = []\n        for input_dataset in input_datasets:\n            input_filename = \"%s.%s.json\" % (input_dataset, shard)\n            input_path = os.path.join(input_dir, input_filename)\n            with open(input_path, encoding=\"utf-8\") as fin:\n                dataset = json.load(fin)\n                converted = [[(word['text'], word['ner']) for word in sentence] for sentence in dataset]\n                full_dataset.extend(converted)\n        datasets.append(full_dataset)\n    write_dataset(datasets, output_dir, output_dataset)\n\ndef read_prefix_file(destination_file):\n    \"\"\"\n    Read a prefix file such as the one for the Worldwide dataset\n\n    the format should be\n\n    africa:\n    af_\n    ...\n\n    asia:\n    cn_\n    ...\n    \"\"\"\n    destination = None\n    known_prefixes = set()\n    prefixes = []\n\n    prefix_map = {}\n    with open(destination_file, encoding=\"utf-8\") as fin:\n        for line in fin:\n            line = line.strip()\n            if line.startswith(\"#\"):\n                continue\n            if not line:\n                continue\n            if line.endswith(\":\"):\n                if destination is not None:\n                    prefix_map[destination] = prefixes\n                prefixes = []\n                destination = line[:-1].strip().lower().replace(\" \", \"_\")\n            else:\n                if not destination:\n                    raise RuntimeError(\"Found a prefix before the first label was assigned when reading %s\" % destination_file)\n                prefixes.append(line)\n                if line in known_prefixes:\n                    raise RuntimeError(\"Found the same prefix twice! 
%s\" % line)\n                known_prefixes.add(line)\n\n        if destination and prefixes:\n            prefix_map[destination] = prefixes\n\n    return prefix_map\n\ndef read_json_entities(filename):\n    \"\"\"\n    Read entities from a file, return a list of (text, label)\n\n    Should work on both BIOES and BIO\n    \"\"\"\n    with open(filename) as fin:\n        doc = Document(json.load(fin))\n\n        return list_doc_entities(doc)\n\ndef list_doc_entities(doc):\n    \"\"\"\n    Return a list of (text, label)\n\n    Should work on both BIOES and BIO\n    \"\"\"\n    entities = []\n    for sentence in doc.sentences:\n        current_entity = []\n        previous_label = None\n        for token in sentence.tokens:\n            if token.ner == 'O' or token.ner.startswith(\"E-\"):\n                if token.ner.startswith(\"E-\"):\n                    current_entity.append(token.text)\n                if current_entity:\n                    assert previous_label is not None\n                    entities.append((current_entity, previous_label))\n                    current_entity = []\n                    previous_label = None\n            elif token.ner.startswith(\"I-\"):\n                if previous_label is not None and previous_label != 'O' and previous_label != token.ner[2:]:\n                    if current_entity:\n                        assert previous_label is not None\n                        entities.append((current_entity, previous_label))\n                        current_entity = []\n                        previous_label = token.ner[2:]\n                current_entity.append(token.text)\n            elif token.ner.startswith(\"B-\") or token.ner.startswith(\"S-\"):\n                if current_entity:\n                    assert previous_label is not None\n                    entities.append((current_entity, previous_label))\n                    current_entity = []\n                    previous_label = None\n                
current_entity.append(token.text)\n                previous_label = token.ner[2:]\n                if token.ner.startswith(\"S-\"):\n                    assert previous_label is not None\n                    entities.append((current_entity, previous_label))\n                    current_entity = []\n                    previous_label = None\n            else:\n                raise RuntimeError(\"Expected BIO(ES) format in the json file!\")\n            previous_label = token.ner[2:]\n        if current_entity:\n            assert previous_label is not None\n            entities.append((current_entity, previous_label))\n    entities = [(tuple(x[0]), x[1]) for x in entities]\n    return entities\n\ndef combine_files(output_filename, *input_filenames):\n    \"\"\"\n    Combine multiple NER json files into one NER file\n    \"\"\"\n    doc = []\n\n    for filename in input_filenames:\n        with open(filename) as fin:\n            new_doc = json.load(fin)\n            doc.extend(new_doc)\n\n    with open(output_filename, \"w\") as fout:\n        json.dump(doc, fout, indent=2)\n\n"
  },
  {
    "path": "stanza/utils/datasets/pos/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/pos/convert_trees_to_pos.py",
    "content": "\"\"\"\nTurns a constituency treebank into a POS dataset with the tags as the upos column\n\nThe constituency treebank first has to be converted from the original\ndata to PTB style trees.  This script converts trees from the\nCONSTITUENCY_DATA_DIR folder to a conllu dataset in the POS_DATA_DIR folder.\n\nNote that this doesn't pay any attention to whether or not the tags actually are upos.\nAlso not possible: using this for tokenization.\n\nTODO: upgrade the POS model to handle xpos datasets with no upos, then make upos/xpos an option here\n\nTo run this:\n  python3 stanza/utils/training/run_pos.py vi_vlsp22\n\n\"\"\"\n\nimport argparse\nimport os\nimport shutil\nimport sys\n\nfrom stanza.models.constituency import tree_reader\nimport stanza.utils.default_paths as default_paths\nfrom stanza.utils.get_tqdm import get_tqdm\n\ntqdm = get_tqdm()\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef convert_file(in_file, out_file, upos):\n    print(\"Reading %s\" % in_file)\n    trees = tree_reader.read_tree_file(in_file)\n    print(\"Writing %s\" % out_file)\n    with open(out_file, \"w\") as fout:\n        for tree in tqdm(trees):\n            tree = tree.simplify_labels()\n            text = \" \".join(tree.leaf_labels())\n            fout.write(\"# text = %s\\n\" % text)\n\n            for pt_idx, pt in enumerate(tree.yield_preterminals()):\n                # word index\n                fout.write(\"%d\\t\" % (pt_idx+1))\n                # word\n                fout.write(\"%s\\t\" % pt.children[0].label)\n                # don't know the lemma\n                fout.write(\"_\\t\")\n                # always put the tag, whatever it is, in the upos (for now)\n                if upos:\n                    fout.write(\"%s\\t_\\t\" % pt.label)\n                else:\n                    fout.write(\"_\\t%s\\t\" % pt.label)\n                # don't have any features\n                fout.write(\"_\\t\")\n                # so word 0 fake dep on root, everyone 
else fake dep on previous word\n                fout.write(\"%d\\t\" % pt_idx)\n                if pt_idx == 0:\n                    fout.write(\"root\")\n                else:\n                    fout.write(\"dep\")\n                fout.write(\"\\t_\\t_\\n\")\n            fout.write(\"\\n\")\n\ndef convert_treebank(short_name, upos, output_name, paths):\n    in_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    in_files = [os.path.join(in_dir, \"%s_%s.mrg\" % (short_name, shard)) for shard in SHARDS]\n    for in_file in in_files:\n        if not os.path.exists(in_file):\n            raise FileNotFoundError(\"Cannot find expected datafile %s\" % in_file)\n\n    out_dir = paths[\"POS_DATA_DIR\"]\n    if not os.path.exists(out_dir):\n        os.makedirs(out_dir)\n    if output_name is None:\n        output_name = short_name\n    out_files = [os.path.join(out_dir, \"%s.%s.in.conllu\" % (output_name, shard)) for shard in SHARDS]\n    gold_files = [os.path.join(out_dir, \"%s.%s.gold.conllu\" % (output_name, shard)) for shard in SHARDS]\n\n    for in_file, out_file in zip(in_files, out_files):\n        convert_file(in_file, out_file, upos)\n    for out_file, gold_file in zip(out_files, gold_files):\n        shutil.copy2(out_file, gold_file)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"dataset\", help=\"Which dataset to process from trees to POS\")\n    parser.add_argument(\"--upos\", action=\"store_true\", default=False, help=\"Store tags on the UPOS\")\n    parser.add_argument(\"--xpos\", dest=\"upos\", action=\"store_false\", help=\"Store tags on the XPOS\")\n    parser.add_argument(\"--output_name\", default=None, help=\"What name to give the output dataset.  If blank, will use the dataset arg\")\n    args = parser.parse_args()\n\n    paths = default_paths.get_default_paths()\n\n    convert_treebank(args.dataset, args.upos, args.output_name, paths)\n"
  },
  {
    "path": "stanza/utils/datasets/pos/remove_columns.py",
    "content": "\"\"\"\nRemove xpos and feats from each file given at the command line.\n\nUseful to strip unwanted tags when combining files of two different\ntypes (or two different stages in the annotation process).\n\nSuper rudimentary right now.  Will be upgraded if needed\n\"\"\"\n\nimport sys\n\nfrom stanza.utils.conll import CoNLL\n\ndef remove_columns(filename):\n    doc = CoNLL.conll2doc(filename)\n\n    for sentence in doc.sentences:\n        for word in sentence.words:\n            word.xpos = None\n            word.feats = None\n\n    CoNLL.write_doc2conll(doc, filename)\n\nif __name__ == '__main__':\n    for filename in sys.argv[1:]:\n        remove_columns(filename)\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_depparse_treebank.py",
    "content": "\"\"\"\nA script to prepare all depparse datasets.\n\nPrepares each of train, dev, test.\n\nExample:\n    python -m stanza.utils.datasets.prepare_depparse_treebank {TREEBANK}\nExample:\n    python -m stanza.utils.datasets.prepare_depparse_treebank UD_English-EWT\n\"\"\"\n\nfrom enum import Enum\nimport glob\nimport logging\nimport os\n\nfrom stanza.models import tagger\nfrom stanza.models.common.constant import treebank_to_short_name\nfrom stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError\nfrom stanza.resources.default_packages import default_charlms, pos_charlms\nimport stanza.utils.datasets.common as common\nimport stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank\nfrom stanza.utils.training.common import build_pos_wordvec_args\nfrom stanza.utils.training.common import add_charlm_args, build_charlm_args, choose_charlm\n\nlogger = logging.getLogger('stanza')\n\n\nclass Tags(Enum):\n    \"\"\"Tags parameter values.\"\"\"\n\n    GOLD = 1\n    PREDICTED = 2\n\n\n# fmt: off\ndef add_specific_args(parser) -> None:\n    \"\"\"Add specific args.\"\"\"\n    parser.add_argument(\"--gold\", dest='tag_method', action='store_const', const=Tags.GOLD, default=Tags.PREDICTED,\n                        help='Use gold tags for building the depparse data')\n    parser.add_argument(\"--predicted\", dest='tag_method', action='store_const', const=Tags.PREDICTED,\n                        help='Use predicted tags for building the depparse data')\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None,\n                        help='Exact name of the pretrain file to read')\n    parser.add_argument('--tagger_model', type=str, default=None,\n                        help='Tagger save file to use.  
If not specified, order searched will be saved/models, then $STANZA_RESOURCES_DIR')\n    parser.add_argument('--save_dir', type=str, default=os.path.join('saved_models', 'pos'),\n                        help='Where to look for recently trained POS models')\n    parser.add_argument('--no_download_tagger', default=True, dest='download_tagger', action='store_false',\n                        help=\"Don't try to automatically download a tagger for retagging the dependencies.  Will fail to make silver tags if there is no tagger model to be found\")\n    add_charlm_args(parser)\n# fmt: on\n\ndef choose_tagger_model(short_language, dataset, tagger_model, args):\n    \"\"\"\n    Preferentially chooses a retrained tagger model, but tries to download one if that doesn't exist\n    \"\"\"\n    logger.debug(\"Looking for tagger for lang |%s| dataset |%s|.  Suggested model |%s|.  Looking first in |%s|.\", short_language, dataset, tagger_model, args.save_dir)\n    if tagger_model:\n        return tagger_model\n\n    candidates = glob.glob(os.path.join(args.save_dir, \"%s_%s_*.pt\" % (short_language, dataset)))\n    if len(candidates) == 1:\n        return candidates[0]\n    if len(candidates) > 1:\n        for ending in (\"_trans_tagger.pt\", \"_charlm_tagger.pt\", \"_nocharlm_tagger.pt\"):\n            best_candidates = [x for x in candidates if x.endswith(ending)]\n            if len(best_candidates) == 1:\n                return best_candidates[0]\n            if len(best_candidates) > 1:\n                raise FileNotFoundError(\"Could not choose among the candidate taggers... please pick one with --tagger_model: {}\".format(best_candidates))\n        raise FileNotFoundError(\"Could not choose among the candidate taggers... 
please pick one with --tagger_model: {}\".format(candidates))\n\n    if not args.download_tagger:\n        return None\n\n    # TODO: just create a Pipeline for the retagging instead?\n    pos_path = os.path.join(DEFAULT_MODEL_DIR, short_language, \"pos\", dataset + \".pt\")\n    if os.path.exists(pos_path):\n        return pos_path\n    try:\n        download_list = download(lang=short_language, package=None, processors={\"pos\": dataset})\n    except UnknownLanguageError as e:\n        raise FileNotFoundError(\"The language %s appears to be a language new to Stanza.  Unfortunately, that means there are no taggers available for retagging the dependency dataset.  Furthermore, there are no tagger models for this language found in %s.  You can specify a different directory for already trained tagger models with --save_dir, specify an exact tagger model name with --tagger_model, or use gold tags with --gold\" % (short_language, args.save_dir)) from e\n    for processor, name in download_list:\n        if processor == 'pos':\n            pos_path = os.path.join(DEFAULT_MODEL_DIR, short_language, \"pos\", name + \".pt\")\n            return pos_path\n    else:\n        raise FileNotFoundError(\"Could not figure out which model file to use for %s.  
Just tried to download to %s the models %s\" % (short_language, args.save_dir, download_list))\n\n\ndef process_treebank(treebank, model_type, paths, args) -> None:\n    \"\"\"Process treebank.\"\"\"\n    if args.tag_method is Tags.GOLD:\n        prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths[\"DEPPARSE_DATA_DIR\"])\n    elif args.tag_method is Tags.PREDICTED:\n        short_name = treebank_to_short_name(treebank)\n        short_language, dataset = short_name.split(\"_\", 1)\n\n        # fmt: off\n        base_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                     \"--lang\", short_language,\n                     \"--shorthand\", short_name,\n                     \"--mode\", \"predict\"]\n        # fmt: on\n\n        # perhaps download a tagger if one doesn't already exist\n        tagger_model = choose_tagger_model(short_language, dataset, args.tagger_model, args)\n        if tagger_model is None:\n            raise FileNotFoundError(\"Cannot find a tagger for language %s, dataset %s - you can specify one with the --tagger_model flag\" % (short_language, dataset))\n        else:\n            logger.info(\"Using tagger model in %s for %s_%s\", tagger_model, short_language, dataset)\n        tagger_dir, tagger_name = os.path.split(tagger_model)\n        base_args = base_args + ['--save_dir', tagger_dir, '--save_name', tagger_name]\n\n        # word vector file for POS\n        if args.wordvec_pretrain_file:\n            base_args += [\"--wordvec_pretrain_file\", args.wordvec_pretrain_file]\n        else:\n            base_args = base_args + build_pos_wordvec_args(short_language, dataset, [])\n\n\n        # charlm for POS\n        charlm = choose_charlm(short_language, dataset, args.charlm, default_charlms, pos_charlms)\n        charlm_args = build_charlm_args(short_language, charlm)\n        base_args = base_args + charlm_args\n\n        def retag_dataset(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):\n            original 
= f\"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu\"\n            retagged = f\"{dest_dir}/{short_name}.{dest_file}.conllu\"\n            # fmt: off\n            tagger_args = [\"--eval_file\", original,\n                           \"--output_file\", retagged]\n            # fmt: on\n            tagger_args = base_args + tagger_args\n            logger.info(\"Running tagger to retag {} to {}\\n  Args: {}\".format(original, retagged, tagger_args))\n            tagger.main(tagger_args)\n\n        prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths[\"DEPPARSE_DATA_DIR\"], retag_dataset)\n    else:\n        raise ValueError(\"Unknown tags method: {}\".format(args.tag_method))\n\n\ndef main() -> None:\n    \"\"\"Call Process Treebank.\"\"\"\n    common.main(process_treebank, common.ModelType.DEPPARSE, add_specific_args)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_lemma_classifier.py",
    "content": "import os\nimport sys\n\nfrom stanza.utils.datasets.common import find_treebank_dataset_file, UnknownDatasetError\nfrom stanza.utils.default_paths import get_default_paths\nfrom stanza.models.lemma_classifier import prepare_dataset\nfrom stanza.models.common.short_name_to_treebank import short_name_to_treebank\nfrom stanza.utils.conll import CoNLL\n\nSECTIONS = (\"train\", \"dev\", \"test\")\n\ndef process_treebank(paths, short_name, word, upos, allowed_lemmas, sections=SECTIONS):\n    treebank = short_name_to_treebank(short_name)\n    udbase_dir = paths[\"UDBASE\"]\n\n    output_dir = paths[\"LEMMA_CLASSIFIER_DATA_DIR\"]\n    os.makedirs(output_dir, exist_ok=True)\n\n    output_filenames = []\n\n    for section in sections:\n        filename = find_treebank_dataset_file(treebank, udbase_dir, section, \"conllu\", fail=True)\n        output_filename = os.path.join(output_dir, \"%s.%s.lemma\" % (short_name, section))\n        args = [\"--conll_path\", filename,\n                \"--target_word\", word,\n                \"--target_upos\", upos,\n                \"--output_path\", output_filename]\n        if allowed_lemmas is not None:\n            args.extend([\"--allowed_lemmas\", allowed_lemmas])\n        prepare_dataset.main(args)\n        output_filenames.append(output_filename)\n\n    return output_filenames\n\ndef process_en_combined(paths, short_name):\n    udbase_dir = paths[\"UDBASE\"]\n    output_dir = paths[\"LEMMA_CLASSIFIER_DATA_DIR\"]\n    os.makedirs(output_dir, exist_ok=True)\n\n    train_treebanks = [\"UD_English-EWT\", \"UD_English-GUM\", \"UD_English-GUMReddit\", \"UD_English-LinES\"]\n    test_treebanks = [\"UD_English-PUD\", \"UD_English-Pronouns\"]\n\n    target_word = \"'s\"\n    target_upos = [\"AUX\"]\n\n    sentences = [ [], [], [] ]\n    for treebank in train_treebanks:\n        for section_idx, section in enumerate(SECTIONS):\n            filename = find_treebank_dataset_file(treebank, udbase_dir, section, \"conllu\", 
fail=True)\n            doc = CoNLL.conll2doc(filename)\n            processor = prepare_dataset.DataProcessor(target_word=target_word, target_upos=target_upos, allowed_lemmas=\".*\")\n            new_sentences = processor.process_document(doc, save_name=None)\n            print(\"Read %d sentences from %s\" % (len(new_sentences), filename))\n            sentences[section_idx].extend(new_sentences)\n    for treebank in test_treebanks:\n        section = \"test\"\n        filename = find_treebank_dataset_file(treebank, udbase_dir, section, \"conllu\", fail=True)\n        doc = CoNLL.conll2doc(filename)\n        processor = prepare_dataset.DataProcessor(target_word=target_word, target_upos=target_upos, allowed_lemmas=\".*\")\n        new_sentences = processor.process_document(doc, save_name=None)\n        print(\"Read %d sentences from %s\" % (len(new_sentences), filename))\n        sentences[2].extend(new_sentences)\n\n    for section, section_sentences in zip(SECTIONS, sentences):\n        output_filename = os.path.join(output_dir, \"%s.%s.lemma\" % (short_name, section))\n        prepare_dataset.DataProcessor.write_output_file(output_filename, target_upos, section_sentences)\n        print(\"Wrote %s sentences to %s\" % (len(section_sentences), output_filename))\n\ndef process_ja_gsd(paths, short_name):\n    # this one looked promising, but only has 10 total dev & test cases\n    # 行っ VERB Counter({'行う': 60, '行く': 38})\n    # could possibly do\n    # ない AUX Counter({'ない': 383, '無い': 99})\n    # なく AUX Counter({'無い': 53, 'ない': 42})\n    # currently this one has enough in the dev & test data\n    # and functions well\n    # だ AUX Counter({'だ': 237, 'た': 67})\n    word = \"だ\"\n    upos = \"AUX\"\n    allowed_lemmas = None\n\n    process_treebank(paths, short_name, word, upos, allowed_lemmas)\n\ndef process_fa_perdt(paths, short_name):\n    word = \"شد\"\n    upos = \"VERB\"\n    allowed_lemmas = \"کرد|شد\"\n\n    process_treebank(paths, short_name, word, upos, 
allowed_lemmas)\n\ndef process_hi_hdtb(paths, short_name):\n    word = \"के\"\n    upos = \"ADP\"\n    allowed_lemmas = \"का|के\"\n\n    process_treebank(paths, short_name, word, upos, allowed_lemmas)\n\ndef process_ar_padt(paths, short_name):\n    word = \"أن\"\n    upos = \"SCONJ\"\n    allowed_lemmas = \"أَن|أَنَّ\"\n\n    process_treebank(paths, short_name, word, upos, allowed_lemmas)\n\ndef process_el_gdt(paths, short_name):\n    \"\"\"\n    All of the Greek lemmas for these words are εγώ or μου\n\n    τους PRON Counter({'μου': 118, 'εγώ': 32})\n    μας PRON Counter({'μου': 89, 'εγώ': 32})\n    του PRON Counter({'μου': 82, 'εγώ': 8})\n    της PRON Counter({'μου': 80, 'εγώ': 2})\n    σας PRON Counter({'μου': 34, 'εγώ': 24})\n    μου PRON Counter({'μου': 45, 'εγώ': 10})\n    \"\"\"\n    word = \"τους|μας|του|της|σας|μου\"\n    upos = \"PRON\"\n    allowed_lemmas = None\n\n    process_treebank(paths, short_name, word, upos, allowed_lemmas)\n\nDATASET_MAPPING = {\n    \"ar_padt\":           process_ar_padt,\n    \"el_gdt\":            process_el_gdt,\n    \"en_combined\":       process_en_combined,\n    \"fa_perdt\":          process_fa_perdt,\n    \"hi_hdtb\":           process_hi_hdtb,\n    \"ja_gsd\":            process_ja_gsd,\n}\n\n\ndef main(dataset_name):\n    paths = get_default_paths()\n    print(\"Processing %s\" % dataset_name)\n\n    # obviously will want to multiplex to multiple languages / datasets\n    if dataset_name in DATASET_MAPPING:\n        DATASET_MAPPING[dataset_name](paths, dataset_name)\n    else:\n        raise UnknownDatasetError(dataset_name, f\"dataset {dataset_name} currently not handled by prepare_lemma_classifier.py\")\n    print(\"Done processing %s\" % dataset_name)\n\nif __name__ == '__main__':\n    main(sys.argv[1])\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_lemma_treebank.py",
    "content": "\"\"\"\nA script to prepare all lemma datasets.\n\nFor example, do\n  python -m stanza.utils.datasets.prepare_lemma_treebank TREEBANK\nsuch as\n  python -m stanza.utils.datasets.prepare_lemma_treebank UD_English-EWT\n\nand it will prepare each of train, dev, test\n\"\"\"\n\nfrom stanza.models.common.constant import treebank_to_short_name\n\nimport stanza.utils.datasets.common as common\nimport stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank\n\nimport stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier\n\ndef add_specific_args(parser) -> None:\n    parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', default=True,\n                        help=\"Don't use the lemma classifier datasets.  Default is to build lemma classifier as part of the original lemmatizer\")\n\ndef check_lemmas(train_file):\n    \"\"\"\n    Check if a treebank has any lemmas in it\n\n    For example, in Vietnamese-VTB, all the words and lemmas are exactly the same\n    in Telugu-MTG, all the lemmas are blank\n    \"\"\"\n    # could eliminate a few languages immediately based on UD 2.7\n    # but what if a later dataset includes lemmas?\n    #if short_language in ('vi', 'fro', 'th'):\n    #    return False\n    with open(train_file, encoding=\"utf-8\") as fin:\n        for line in fin:\n            line = line.strip()\n            if not line or line.startswith(\"#\"):\n                continue\n            pieces = line.split(\"\\t\")\n            word = pieces[1].lower().strip()\n            lemma = pieces[2].lower().strip()\n            if not lemma or lemma == '_' or lemma == '-':\n                continue\n            if word == lemma:\n                continue\n            return True\n    return False\n\ndef process_treebank(treebank, model_type, paths, args):\n    if treebank.startswith(\"UD_\"):\n        udbase_dir = paths[\"UDBASE\"]\n        input_conllu = 
common.find_treebank_dataset_file(treebank, udbase_dir, \"train\", \"conllu\")\n        if not input_conllu:\n            input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, \"test\", \"conllu\", fail=True)\n        augment = check_lemmas(input_conllu)\n        if not augment:\n            print(\"No lemma information found in %s.  Not augmenting the dataset\" % input_conllu)\n    else:\n        # TODO: check the data to see if there are lemmas or not\n        augment = True\n    prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths[\"LEMMA_DATA_DIR\"], augment=augment)\n\n    short_name = treebank_to_short_name(treebank)\n    if args.lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING:\n        prepare_lemma_classifier.main(short_name)\n\ndef main():\n    common.main(process_treebank, common.ModelType.LEMMA, add_specific_args)\n\nif __name__ == '__main__':\n    main()\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_mwt_treebank.py",
    "content": "\"\"\"\nA script to prepare all MWT datasets.\n\nFor example, do\n  python -m stanza.utils.datasets.prepare_mwt_treebank TREEBANK\nsuch as\n  python -m stanza.utils.datasets.prepare_mwt_treebank UD_English-EWT\n\nand it will prepare each of train, dev, test\n\"\"\"\n\nimport argparse\nimport os\nimport shutil\nimport tempfile\n\nfrom stanza.utils.conll import CoNLL\nfrom stanza.models.common.constant import treebank_to_short_name\nimport stanza.utils.datasets.common as common\nimport stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank\n\nfrom stanza.utils.datasets.contract_mwt import contract_mwt\n\n# languages where the MWTs are always a composition of the words themselves\nKNOWN_COMPOSABLE_MWTS = {\"en\"}\n# ... but partut is not put together that way\nMWT_EXCEPTIONS = {\"en_partut\"}\n\ndef copy_conllu(tokenizer_dir, mwt_dir, short_name, dataset, particle):\n    input_conllu_tokenizer = f\"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu\"\n    input_conllu_mwt = f\"{mwt_dir}/{short_name}.{dataset}.{particle}.conllu\"\n    shutil.copyfile(input_conllu_tokenizer, input_conllu_mwt)\n\ndef check_mwt_composition(filename):\n    print(\"Checking the MWTs in %s\" % filename)\n    doc = CoNLL.conll2doc(filename)\n    for sent_idx, sentence in enumerate(doc.sentences):\n        for token_idx, token in enumerate(sentence.tokens):\n            if len(token.words) > 1:\n                expected = \"\".join(x.text for x in token.words)\n                if token.text != expected:\n                    raise ValueError(\"Unexpected token composition in filename %s sentence %d id %s token %d: %s instead of %s\" % (filename, sent_idx, sentence.sent_id, token_idx, token.text, expected))\n\ndef process_treebank(treebank, model_type, paths, args):\n    short_name = treebank_to_short_name(treebank)\n\n    mwt_dir = paths[\"MWT_DATA_DIR\"]\n    os.makedirs(mwt_dir, exist_ok=True)\n\n    with tempfile.TemporaryDirectory() as 
tokenizer_dir:\n        paths = dict(paths)\n        paths[\"TOKENIZE_DATA_DIR\"] = tokenizer_dir\n\n        # first we process the tokenization data\n        tokenizer_args = argparse.Namespace()\n        tokenizer_args.augment = False\n        tokenizer_args.prepare_labels = True\n        prepare_tokenizer_treebank.process_treebank(treebank, model_type, paths, tokenizer_args)\n\n        copy_conllu(tokenizer_dir, mwt_dir, short_name, \"train\", \"in\")\n        copy_conllu(tokenizer_dir, mwt_dir, short_name, \"dev\", \"gold\")\n        copy_conllu(tokenizer_dir, mwt_dir, short_name, \"test\", \"gold\")\n\n        for shard in (\"train\", \"dev\", \"test\"):\n            source_filename = common.mwt_name(tokenizer_dir, short_name, shard)\n            dest_filename = common.mwt_name(mwt_dir, short_name, shard)\n            print(\"Copying from %s to %s\" % (source_filename, dest_filename))\n            shutil.copyfile(source_filename, dest_filename)\n\n        language = short_name.split(\"_\", 1)[0]\n        if language in KNOWN_COMPOSABLE_MWTS and short_name not in MWT_EXCEPTIONS:\n            print(\"Language %s is known to have all MWT composed of exactly its word pieces.  Checking...\" % language)\n            check_mwt_composition(f\"{mwt_dir}/{short_name}.train.in.conllu\")\n            check_mwt_composition(f\"{mwt_dir}/{short_name}.dev.gold.conllu\")\n            check_mwt_composition(f\"{mwt_dir}/{short_name}.test.gold.conllu\")\n\n        contract_mwt(f\"{mwt_dir}/{short_name}.dev.gold.conllu\",\n                     f\"{mwt_dir}/{short_name}.dev.in.conllu\")\n        contract_mwt(f\"{mwt_dir}/{short_name}.test.gold.conllu\",\n                     f\"{mwt_dir}/{short_name}.test.in.conllu\")\n\ndef main():\n    common.main(process_treebank, common.ModelType.MWT)\n\nif __name__ == '__main__':\n    main()\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_pos_treebank.py",
    "content": "\"\"\"\nA script to prepare all pos datasets.\n\nFor example, do\n  python -m stanza.utils.datasets.prepare_pos_treebank TREEBANK\nsuch as\n  python -m stanza.utils.datasets.prepare_pos_treebank UD_English-EWT\n\nand it will prepare each of train, dev, test\n\"\"\"\n\nimport os\nimport shutil\n\nimport stanza.utils.datasets.common as common\nimport stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank\n\ndef copy_conllu_file_or_zip(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):\n    original = f\"{tokenizer_dir}/{short_name}.{tokenizer_file}.zip\"\n    copied = f\"{dest_dir}/{short_name}.{dest_file}.zip\"\n\n    if os.path.exists(original):\n        print(\"Copying from %s to %s\" % (original, copied))\n        shutil.copyfile(original, copied)\n    else:\n        prepare_tokenizer_treebank.copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name)\n\n\ndef process_treebank(treebank, model_type, paths, args):\n    prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths[\"POS_DATA_DIR\"], postprocess=copy_conllu_file_or_zip)\n\ndef main():\n    common.main(process_treebank, common.ModelType.POS)\n\nif __name__ == '__main__':\n    main()\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_tokenizer_data.py",
    "content": "import argparse\nimport json\nimport os\nimport re\nimport sys\n\nfrom collections import Counter\n\n\"\"\"\nData is output in 4 files:\n\na file containing the mwt information\na file containing the words and sentences in conllu format\na file containing the raw text of each paragraph\na file of 0,1,2 indicating word break or sentence break on a character level for the raw text\n  1: end of word\n  2: end of sentence\n\"\"\"\n\nPARAGRAPH_BREAK = re.compile(r'\\n\\s*\\n')\n\ndef is_para_break(index, text):\n    \"\"\" Detect if a paragraph break can be found, and return the length of the paragraph break sequence. \"\"\"\n    if text[index] == '\\n':\n        para_break = PARAGRAPH_BREAK.match(text, index)\n        if para_break:\n            break_len = len(para_break.group(0))\n            return True, break_len\n    return False, 0\n\ndef find_next_word(index, text, word, output):\n    \"\"\"\n    Locate the next word in the text. In case a paragraph break is found, also write paragraph break to labels.\n    \"\"\"\n    idx = 0\n    word_sofar = ''\n    while index < len(text) and idx < len(word):\n        para_break, break_len = is_para_break(index, text)\n        if para_break:\n            # multiple newlines found, paragraph break\n            if len(word_sofar) > 0:\n                assert re.match(r'^\\s+$', word_sofar), 'Found non-empty string at the end of a paragraph that doesn\\'t match any token: |{}|'.format(word_sofar)\n                word_sofar = ''\n\n            output.write('\\n\\n')\n            index += break_len - 1\n        elif re.match(r'^\\s$', text[index]) and not re.match(r'^\\s$', word[idx]):\n            # whitespace found, and whitespace is not part of a word\n            word_sofar += text[index]\n        else:\n            # non-whitespace char, or a whitespace char that's part of a word\n            word_sofar += text[index]\n            assert text[index].replace('\\n', ' ') == word[idx], \"Character mismatch: raw 
text contains |%s| but the next word is |%s|.\" % (word_sofar, word)\n            idx += 1\n        index += 1\n    return index, word_sofar\n\ndef main(args):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('plaintext_file', type=str, help=\"Plaintext file containing the raw input\")\n    parser.add_argument('conllu_file', type=str, help=\"CoNLL-U file containing tokens and sentence breaks\")\n    parser.add_argument('-o', '--output', default=None, type=str, help=\"Output file name; output to the console if not specified (the default)\")\n    parser.add_argument('-m', '--mwt_output', default=None, type=str, help=\"Output file name for MWT expansions; output to the console if not specified (the default)\")\n\n    args = parser.parse_args(args=args)\n\n    with open(args.plaintext_file, 'r', encoding='utf-8') as f:\n        text = ''.join(f.readlines())\n    textlen = len(text)\n\n    if args.output is None:\n        output = sys.stdout\n    else:\n        outdir = os.path.split(args.output)[0]\n        os.makedirs(outdir, exist_ok=True)\n        output = open(args.output, 'w')\n\n    index = 0 # character offset in rawtext\n\n    mwt_expansions = []\n    with open(args.conllu_file, 'r', encoding='utf-8') as f:\n        buf = ''\n        mwtbegin = 0\n        mwtend = -1\n        expanded = []\n        last_comments = \"\"\n        for line in f:\n            line = line.strip()\n            if len(line):\n                if line[0] == \"#\":\n                    # comment, don't do anything\n                    if len(last_comments) == 0:\n                        last_comments = line\n                    continue\n\n                line = line.split('\\t')\n                if '.' 
in line[0]:\n                    # the tokenizer doesn't deal with ellipsis\n                    continue\n\n                word = line[1]\n                if '-' in line[0]:\n                    # multiword token\n                    mwtbegin, mwtend = [int(x) for x in line[0].split('-')]\n                    lastmwt = word\n                    expanded = []\n                elif mwtbegin <= int(line[0]) < mwtend:\n                    expanded += [word]\n                    continue\n                elif int(line[0]) == mwtend:\n                    expanded += [word]\n                    expanded = [x.lower() for x in expanded] # evaluation doesn't care about case\n                    mwt_expansions += [(lastmwt, tuple(expanded))]\n                    if lastmwt[0].islower() and not expanded[0][0].islower():\n                        print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)\n                    mwtbegin = 0\n                    mwtend = -1\n                    lastmwt = None\n                    continue\n\n                if len(buf):\n                    output.write(buf)\n                index, word_found = find_next_word(index, text, word, output)\n                buf = '0' * (len(word_found)-1) + ('1' if '-' not in line[0] else '3')\n            else:\n                # sentence break found\n                if len(buf):\n                    assert int(buf[-1]) >= 1\n                    output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1))\n                    buf = ''\n\n                last_comments = ''\n\n    status_line = \"\"\n    if args.output:\n        output.close()\n        status_line = 'Tokenizer labels written to %s\\n  ' % args.output\n\n    mwts = Counter(mwt_expansions)\n    if args.mwt_output is None:\n        print('MWTs:', mwts)\n    else:\n        with open(args.mwt_output, 'w') as f:\n            json.dump(list(mwts.items()), f, indent=2)\n\n        status_line = status_line + '{} unique 
MWTs found in data.  MWTs written to {}'.format(len(mwts), args.mwt_output)\n        print(status_line)\n\nif __name__ == '__main__':\n    main(sys.argv[1:])\n"
  },
  {
    "path": "stanza/utils/datasets/prepare_tokenizer_treebank.py",
    "content": "\"\"\"\nPrepares train, dev, test for a treebank\n\nFor example, do\n  python -m stanza.utils.datasets.prepare_tokenizer_treebank TREEBANK\nsuch as\n  python -m stanza.utils.datasets.prepare_tokenizer_treebank UD_English-EWT\n\nand it will prepare each of train, dev, test\n\nThere are macros for preparing all of the UD treebanks at once:\n  python -m stanza.utils.datasets.prepare_tokenizer_treebank ud_all\n  python -m stanza.utils.datasets.prepare_tokenizer_treebank all_ud\nBoth are present because I kept forgetting which was the correct one\n\nThere are a few special case handlings of treebanks in this file:\n  - all Vietnamese treebanks have special post-processing to handle\n    some of the difficult spacing issues in Vietnamese text\n  - treebanks with train and test but no dev split have the\n    train data randomly split into two pieces\n  - however, instead of splitting very tiny treebanks, we skip those\n\"\"\"\n\nimport argparse\nimport glob\nimport io\nimport os\nimport random\nimport re\nimport sys\nimport tempfile\nimport zipfile\n\nfrom collections import Counter\n\nfrom stanza.models.common.constant import treebank_to_short_name\nimport stanza.utils.datasets.common as common\nfrom stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu, write_sentences_to_file, INT_RE, MWT_RE, MWT_OR_COPY_RE\nimport stanza.utils.datasets.tokenization.convert_ml_cochin as convert_ml_cochin\nimport stanza.utils.datasets.tokenization.convert_my_alt as convert_my_alt\nimport stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp\nimport stanza.utils.datasets.tokenization.convert_th_best as convert_th_best\nimport stanza.utils.datasets.tokenization.convert_th_lst20 as convert_th_lst20\nimport stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid\n\ndef copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):\n    original = 
f\"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu\"\n    copied = f\"{dest_dir}/{short_name}.{dest_file}.conllu\"\n\n    print(\"Copying from %s to %s\" % (original, copied))\n    # do this instead of shutil.copyfile in case there are manipulations needed\n    # for example, we might need to add fake dependencies (TODO: still needed?)\n    sents = read_sentences_from_conllu(original)\n    write_sentences_to_conllu(copied, sents)\n\ndef copy_conllu_treebank(treebank, model_type, paths, dest_dir, postprocess=None, augment=True):\n    \"\"\"\n    This utility method copies only the conllu files to the given destination directory.\n\n    The POS, lemma, and depparse annotators all need this.\n    \"\"\"\n    os.makedirs(dest_dir, exist_ok=True)\n\n    short_name = treebank_to_short_name(treebank)\n    short_language = short_name.split(\"_\")[0]\n\n    with tempfile.TemporaryDirectory() as tokenizer_dir:\n        paths = dict(paths)\n        paths[\"TOKENIZE_DATA_DIR\"] = tokenizer_dir\n\n        # first we process the tokenization data\n        args = argparse.Namespace()\n        args.augment = augment\n        args.prepare_labels = False\n        process_treebank(treebank, model_type, paths, args)\n\n        os.makedirs(dest_dir, exist_ok=True)\n\n        if postprocess is None:\n            postprocess = copy_conllu_file\n\n        # now we copy the processed conllu data files\n        postprocess(tokenizer_dir, \"train.gold\", dest_dir, \"train.in\", short_name)\n        postprocess(tokenizer_dir, \"dev.gold\", dest_dir, \"dev.in\", short_name)\n        postprocess(tokenizer_dir, \"test.gold\", dest_dir, \"test.in\", short_name)\n\ndef split_conllu_file(treebank, input_conllu, train_output_conllu, dev_output_conllu, test_output_conllu):\n    # set the seed for each data file so that the results are the same\n    # regardless of how many treebanks are processed at once\n    random.seed(1234)\n\n    # read and shuffle conllu data\n    sents = 
read_sentences_from_conllu(input_conllu)\n    random.shuffle(sents)\n    n_dev = int(len(sents) * XV_RATIO)\n    assert n_dev >= 1, \"Dev sentence number less than one.\"\n    n_test = int(len(sents) * XV_RATIO)\n    assert n_test >= 1, \"Test sentence number less than one.\"\n    n_train = len(sents) - n_dev - n_test\n\n    # split conllu data\n    dev_sents = sents[:n_dev]\n    test_sents = sents[n_dev:n_dev+n_test]\n    train_sents = sents[n_dev+n_test:]\n    print(\"Train/dev/test split not present.  Randomly splitting file from %s to %s, %s, %s\" % (input_conllu, train_output_conllu, dev_output_conllu, test_output_conllu))\n    print(f\"{len(sents)} total sentences found: {n_train} in train, {n_dev} in dev, {n_test} in test\")\n\n    # write conllu\n    write_sentences_to_conllu(train_output_conllu, train_sents)\n    write_sentences_to_conllu(dev_output_conllu, dev_sents)\n    write_sentences_to_conllu(test_output_conllu, test_sents)\n\n    return True\n\ndef split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu):\n    # set the seed for each data file so that the results are the same\n    # regardless of how many treebanks are processed at once\n    random.seed(1234)\n\n    # read and shuffle conllu data\n    sents = read_sentences_from_conllu(train_input_conllu)\n    random.shuffle(sents)\n    n_dev = int(len(sents) * XV_RATIO)\n    assert n_dev >= 1, \"Dev sentence number less than one.\"\n    n_train = len(sents) - n_dev\n\n    # split conllu data\n    dev_sents = sents[:n_dev]\n    train_sents = sents[n_dev:]\n    print(\"Train/dev split not present.  
Randomly splitting train file from %s to %s and %s\" % (train_input_conllu, train_output_conllu, dev_output_conllu))\n    print(f\"{len(sents)} total sentences found: {n_train} in train, {n_dev} in dev\")\n\n    # write conllu\n    write_sentences_to_conllu(train_output_conllu, train_sents)\n    write_sentences_to_conllu(dev_output_conllu, dev_sents)\n\n    return True\n\n\ndef has_space_after_no(piece):\n    if not piece or piece == \"_\":\n        return False\n    if piece == \"SpaceAfter=No\":\n        return True\n    tags = piece.split(\"|\")\n    return any(t == \"SpaceAfter=No\" for t in tags)\n\n\ndef remove_space_after_no(piece, fail_if_missing=True):\n    \"\"\"\n    Removes a SpaceAfter=No annotation from a single piece of a single word.\n    In other words, given a list of conll lines, first call split(\"\\t\"), then call this on the -1 column\n    \"\"\"\n    # |SpaceAfter is in UD_Romanian-Nonstandard... seems fitting\n    if piece == \"SpaceAfter=No\" or piece == \"|SpaceAfter=No\":\n        piece = \"_\"\n    elif piece.startswith(\"SpaceAfter=No|\"):\n        piece = piece.replace(\"SpaceAfter=No|\", \"\")\n    elif piece.find(\"|SpaceAfter=No\") > 0:\n        piece = piece.replace(\"|SpaceAfter=No\", \"\")\n    elif fail_if_missing:\n        raise ValueError(\"Could not find SpaceAfter=No in the given notes field\")\n    return piece\n\ndef add_space_after_no(piece, fail_if_found=True):\n    if piece == '_':\n        return \"SpaceAfter=No\"\n    else:\n        if fail_if_found:\n            if has_space_after_no(piece):\n                raise ValueError(\"Given notes field already contained SpaceAfter=No\")\n        return piece + \"|SpaceAfter=No\"\n\n\ndef augment_telugu(sents):\n    \"\"\"\n    Add a few sentences with modified punctuation to Telugu_MTG\n\n    The Telugu-MTG dataset has punctuation separated from the text in\n    almost all cases, which makes the tokenizer not learn how to\n    process that correctly.\n\n    All of the Telugu 
sentences end with their sentence final\n    punctuation being separated.  Furthermore, all commas are\n    separated.  We change that on some subset of the sentences to\n    make the tools more generalizable on wild text.\n    \"\"\"\n    new_sents = []\n    for sentence in sents:\n        if not sentence[1].startswith(\"# text\"):\n            raise ValueError(\"Expected the second line of %s to start with # text\" % sentence[0])\n        if not sentence[2].startswith(\"# translit\"):\n            raise ValueError(\"Expected the third line of %s to start with # translit\" % sentence[0])\n        if sentence[1].endswith(\". . .\") or sentence[1][-1] not in ('.', '?', '!'):\n            continue\n        if sentence[1][-1] in ('.', '?', '!') and sentence[1][-2] != ' ' and sentence[1][-3:] != ' ..' and sentence[1][-4:] != ' ...':\n            raise ValueError(\"Sentence %s does not end with space-punctuation, which is against our assumptions for the te_mtg treebank.  Please check the augment method to see if it is still needed\" % sentence[0])\n        if random.random() < 0.1:\n            new_sentence = list(sentence)\n            new_sentence[1] = new_sentence[1][:-2] + new_sentence[1][-1]\n            new_sentence[2] = new_sentence[2][:-2] + new_sentence[2][-1]\n            new_sentence[-2] = new_sentence[-2] + \"|SpaceAfter=No\"\n            new_sents.append(new_sentence)\n        if sentence[1].find(\",\") > 1 and random.random() < 0.1:\n            new_sentence = list(sentence)\n            index = sentence[1].find(\",\")\n            new_sentence[1] = sentence[1][:index-1] + sentence[1][index:]\n            index = sentence[1].find(\",\")\n            new_sentence[2] = sentence[2][:index-1] + sentence[2][index:]\n            for idx, word in enumerate(new_sentence):\n                if idx < 4:\n                    # skip sent_id, text, transliteration, and the first word\n                    continue\n                if word.split(\"\\t\")[1] == ',':\n     
               new_sentence[idx-1] = new_sentence[idx-1] + \"|SpaceAfter=No\"\n                    break\n            new_sents.append(new_sentence)\n    return sents + new_sents\n\nCOMMA_SEPARATED_RE = re.compile(\" ([a-zA-Z]+)[,] ([a-zA-Z]+) \")\ndef augment_comma_separations(sents, ratio=0.03):\n    \"\"\"Find some fraction of the sentences which match \"asdf, zzzz\" and squish them to \"asdf,zzzz\"\n\n    This leaves the tokens and all of the other data the same.  The\n    only change made is to change SpaceAfter=No for the \",\" token and\n    adjust the #text line, with the assumption that the conllu->txt\n    conversion will correctly handle this change.\n\n    This was particularly an issue for Spanish-AnCora, but it's\n    reasonable to think it could happen to any dataset.  Currently\n    this just operates on commas and ascii letters to avoid\n    accidentally squishing anything that shouldn't be squished.\n\n    UD_Spanish-AnCora 2.7 had a problem is with this sentence:\n    # orig_file_sentence 143#5\n    In this sentence, there was a comma smashed next to a token.\n\n    Fixing just this one sentence is not sufficient to tokenize\n    \"asdf,zzzz\" as desired, so we also augment by some fraction where\n    we have squished \"asdf, zzzz\" into \"asdf,zzzz\".\n\n    This exact example was later fixed in UD 2.8, but it should still\n    potentially be useful for compensating for typos.\n    \"\"\"\n    new_sents = []\n    for sentence in sents:\n        for text_idx, text_line in enumerate(sentence):\n            # look for the line that starts with \"# text\".\n            # keep going until we find it, or silently ignore it\n            # if the dataset isn't in that format\n            if text_line.startswith(\"# text\"):\n                break\n        else:\n            continue\n\n        match = COMMA_SEPARATED_RE.search(sentence[text_idx])\n        if match and random.random() < ratio:\n            for idx, word in enumerate(sentence):\n          
      if word.startswith(\"#\"):\n                    continue\n                # find() doesn't work because we wind up finding substrings\n                if word.split(\"\\t\")[1] != match.group(1):\n                    continue\n                if sentence[idx+1].split(\"\\t\")[1] != ',':\n                    continue\n                if sentence[idx+2].split(\"\\t\")[1] != match.group(2):\n                    continue\n                break\n            if idx == len(sentence) - 1:\n                # this can happen with MWTs.  we may actually just\n                # want to skip MWTs anyway, so no big deal\n                continue\n            # now idx+1 should be the line with the comma in it\n            comma = sentence[idx+1]\n            pieces = comma.split(\"\\t\")\n            assert pieces[1] == ','\n            pieces[-1] = add_space_after_no(pieces[-1])\n            comma = \"\\t\".join(pieces)\n            new_sent = sentence[:idx+1] + [comma] + sentence[idx+2:]\n\n            text_offset = sentence[text_idx].find(match.group(1) + \", \" + match.group(2))\n            text_len = len(match.group(1) + \", \" + match.group(2))\n            new_text = sentence[text_idx][:text_offset] + match.group(1) + \",\" + match.group(2) + sentence[text_idx][text_offset+text_len:]\n            new_sent[text_idx] = new_text\n\n            new_sents.append(new_sent)\n\n    print(\"Added %d new sentences with asdf, zzzz -> asdf,zzzz\" % len(new_sents))\n            \n    return sents + new_sents\n\ndef augment_move_comma(sents, ratio=0.02):\n    \"\"\"\n    Move the comma from after a word to before the next word some fraction of the time\n\n    We looks for this exact pattern:\n      w1, w2\n    and replace it with\n      w1 ,w2\n\n    The idea is that this is a relatively common typo, but the tool\n    won't learn how to tokenize it without some help.\n\n    Note that this modification replaces the original text.\n    \"\"\"\n    new_sents = []\n    
num_operations = 0\n    for sentence in sents:\n        if random.random() > ratio:\n            new_sents.append(sentence)\n            continue\n\n        found = False\n        for word_idx, word in enumerate(sentence):\n            if word.startswith(\"#\"):\n                continue\n            if word_idx == 0 or word_idx >= len(sentence) - 2:\n                continue\n            pieces = word.split(\"\\t\")\n            if pieces[1] == ',' and not has_space_after_no(pieces[-1]):\n                # found a comma with a space after it\n                prev_word = sentence[word_idx-1]\n                if not has_space_after_no(prev_word.split(\"\\t\")[-1]):\n                    # unfortunately, the previous word also had a\n                    # space after it.  does not fit what we are\n                    # looking for\n                    continue\n                # also, want to skip instances near MWT or copy nodes,\n                # since those are harder to rearrange\n                next_word = sentence[word_idx+1]\n                if MWT_OR_COPY_RE.match(next_word.split(\"\\t\")[0]):\n                    continue\n                if MWT_OR_COPY_RE.match(prev_word.split(\"\\t\")[0]):\n                    continue\n                # at this point, the previous word has no space and the comma does\n                found = True\n                break\n\n        if not found:\n            new_sents.append(sentence)\n            continue\n\n        new_sentence = list(sentence)\n\n        pieces = new_sentence[word_idx].split(\"\\t\")\n        pieces[-1] = add_space_after_no(pieces[-1])\n        new_sentence[word_idx] = \"\\t\".join(pieces)\n\n        pieces = new_sentence[word_idx-1].split(\"\\t\")\n        prev_word = pieces[1]\n        pieces[-1] = remove_space_after_no(pieces[-1])\n        new_sentence[word_idx-1] = \"\\t\".join(pieces)\n\n        next_word = new_sentence[word_idx+1].split(\"\\t\")[1]\n\n        for text_idx, text_line in 
enumerate(sentence):\n            # look for the line that starts with \"# text\".\n            # keep going until we find it, or silently ignore it\n            # if the dataset isn't in that format\n            if text_line.startswith(\"# text\"):\n                old_chunk = prev_word + \", \" + next_word\n                new_chunk = prev_word + \" ,\" + next_word\n                word_idx = text_line.find(old_chunk)\n                if word_idx < 0:\n                    raise RuntimeError(\"Unexpected #text line which did not contain the original text to be modified.  Looking for\\n\" + old_chunk + \"\\n\" + text_line)\n                new_text_line = text_line[:word_idx] + new_chunk + text_line[word_idx+len(old_chunk):]\n                new_sentence[text_idx] = new_text_line\n                break\n\n        new_sents.append(new_sentence)\n        num_operations = num_operations + 1\n\n    print(\"Swapped 'w1, w2' for 'w1 ,w2' %d times\" % num_operations)\n    return new_sents\n\ndef augment_apos(sents):\n\n    \"\"\"\n    If there are no instances of ’ in the dataset, but there are instances of ',\n    we replace some fraction of ' with ’ so that the tokenizer will recognize it.\n\n    # TODO: we could do it the other way around as well\n    \"\"\"\n    has_unicode_apos = False\n    has_ascii_apos = False\n    for sent_idx, sent in enumerate(sents):\n        if len(sent) == 0:\n            raise AssertionError(\"Got a blank sentence in position %d!\" % sent_idx)\n        for line in sent:\n            if line.startswith(\"# text\"):\n                if line.find(\"'\") >= 0:\n                    has_ascii_apos = True\n                if line.find(\"’\") >= 0:\n                    has_unicode_apos = True\n                break\n        else:\n            raise ValueError(\"Cannot find '# text' in sentences %d.  
First line: %s\" % (sent_idx, sent[0]))\n\n    if has_unicode_apos or not has_ascii_apos:\n        return sents\n\n    new_sents = []\n    for sent in sents:\n        if random.random() > 0.05:\n            new_sents.append(sent)\n            continue\n        new_sent = []\n        for line in sent:\n            if line.startswith(\"# text\"):\n                new_sent.append(line.replace(\"'\", \"’\"))\n            elif line.startswith(\"#\"):\n                new_sent.append(line)\n            else:\n                pieces = line.split(\"\\t\")\n                pieces[1] = pieces[1].replace(\"'\", \"’\")\n                new_sent.append(\"\\t\".join(pieces))\n        new_sents.append(new_sent)\n\n    return new_sents\n\ndef augment_ellipses(sents):\n    \"\"\"\n    Replaces a fraction of '...' with '…'\n    \"\"\"\n    has_ellipses = False\n    has_unicode_ellipses = False\n    for sent in sents:\n        for line in sent:\n            if line.startswith(\"#\"):\n                continue\n            pieces = line.split(\"\\t\")\n            if pieces[1] == '...':\n                has_ellipses = True\n            elif pieces[1] == '…':\n                has_unicode_ellipses = True\n\n    if has_unicode_ellipses or not has_ellipses:\n        return sents\n\n    new_sents = []\n\n    num_updated = 0\n    for sent in sents:\n        if random.random() > 0.1:\n            new_sents.append(sent)\n            continue\n        found = False\n        new_sent = []\n        for line in sent:\n            if line.startswith(\"#\"):\n                new_sent.append(line)\n            else:\n                pieces = line.split(\"\\t\")\n                if pieces[1] == '...':\n                    pieces[1] = '…'\n                    found = True\n                new_sent.append(\"\\t\".join(pieces))\n        new_sents.append(new_sent)\n        if found:\n            num_updated = num_updated + 1\n\n    print(\"Changed %d sentences to use fancy unicode ellipses\" % 
num_updated)\n    return new_sents\n\n# https://en.wikipedia.org/wiki/Quotation_mark\nQUOTES = ['\"', '“', '”', '«', '»', '「', '」', '《', '》', '„', '″']\nQUOTES_RE = re.compile(\"(.?)[\" + \"\".join(QUOTES) + \"](.+)[\" + \"\".join(QUOTES) + \"](.?)\")\n# Danish does '«' the other way around from most European languages\nSTART_QUOTES = ['\"', '“', '”', '«', '»', '「', '《', '„', '„', '″']\nEND_QUOTES   = ['\"', '“', '”', '»', '«', '」', '》', '”', '“', '″']\n\ndef augment_quotes(sents, ratio=0.15):\n    \"\"\"\n    Go through the sentences and replace a fraction of sentences with alternate quotes\n\n    TODO: for certain languages we may want to make some language-specific changes\n      eg Danish, don't add «...»\n    \"\"\"\n    assert len(START_QUOTES) == len(END_QUOTES)\n\n    counts = Counter()\n    new_sents = []\n    for sent in sents:\n        if random.random() > ratio:\n            new_sents.append(sent)\n            continue\n\n        # count if there are exactly 2 quotes in this sentence\n        # this is for convenience - otherwise we need to figure out which pairs go together\n        count_quotes = sum(1 for x in sent\n                           if (not x.startswith(\"#\") and\n                               x.split(\"\\t\")[1] in QUOTES))\n        if count_quotes != 2:\n            new_sents.append(sent)\n            continue\n\n        # choose a pair of quotes from the candidates\n        quote_idx = random.choice(range(len(START_QUOTES)))\n        start_quote = START_QUOTES[quote_idx]\n        end_quote = END_QUOTES[quote_idx]\n        counts[start_quote + end_quote] = counts[start_quote + end_quote] + 1\n\n        new_sent = []\n        saw_start = False\n        for line in sent:\n            if line.startswith(\"#\"):\n                new_sent.append(line)\n                continue\n            pieces = line.split(\"\\t\")\n            if pieces[1] in QUOTES:\n                if saw_start:\n                    # Note that we don't change the 
lemma.  Presumably it's\n                    # set to the correct lemma for a quote for this treebank\n                    pieces[1] = end_quote\n                else:\n                    pieces[1] = start_quote\n                    saw_start = True\n                new_sent.append(\"\\t\".join(pieces))\n            else:\n                new_sent.append(line)\n\n        for text_idx, text_line in enumerate(new_sent):\n            # look for the line that starts with \"# text\".\n            # keep going until we find it, or silently ignore it\n            # if the dataset isn't in that format\n            if text_line.startswith(\"# text\"):\n                replacement = \"\\\\1%s\\\\2%s\\\\3\" % (start_quote, end_quote)\n                new_text_line = QUOTES_RE.sub(replacement, text_line)\n                new_sent[text_idx] = new_text_line\n\n        new_sents.append(new_sent)\n\n    # we go through this to make it simpler to execute on Windows\n    # rather than nagging the user to set utf-8\n    out = io.TextIOWrapper(sys.stdout.buffer, encoding=\"utf-8\", write_through=True)\n    print(\"Augmented {} quotes: {}\".format(sum(counts.values()), counts), file=out)\n    out.detach()\n    return new_sents\n\ndef find_text_idx(sentence):\n    \"\"\"\n    Return the index of the # text line or -1\n    \"\"\"\n    for idx, line in enumerate(sentence):\n        if line.startswith(\"# text\"):\n            return idx\n    return -1\n\nDIGIT_RE = re.compile(\"[0-9]\")\n\ndef change_indices(line, delta):\n    \"\"\"\n    Adjust all indices in the given sentence by delta.  
Useful when removing a word, for example\n    \"\"\"\n    if line.startswith(\"#\"):\n        return line\n\n    pieces = line.split(\"\\t\")\n    if MWT_RE.match(pieces[0]):\n        indices = pieces[0].split(\"-\")\n        pieces[0] = \"%d-%d\" % (int(indices[0]) + delta, int(indices[1]) + delta)\n        line = \"\\t\".join(pieces)\n        return line\n\n    if MWT_OR_COPY_RE.match(pieces[0]):\n        index_pieces = pieces[0].split(\".\", maxsplit=1)\n        pieces[0] = \"%d.%s\" % (int(index_pieces[0]) + delta, index_pieces[1])\n    elif not INT_RE.match(pieces[0]):\n        raise NotImplementedError(\"Unknown index type: %s\" % pieces[0])\n    else:\n        pieces[0] = str(int(pieces[0]) + delta)\n    if pieces[6] != '_':\n        # copy nodes don't have basic dependencies in the es_ancora treebank\n        dep = int(pieces[6])\n        if dep != 0:\n            pieces[6] = str(int(dep) + delta)\n    if pieces[8] != '_':\n        dep_pieces = pieces[8].split(\":\", maxsplit=1)\n        if DIGIT_RE.search(dep_pieces[1]):\n            raise NotImplementedError(\"Need to handle multiple additional deps:\\n%s\" % line)\n        if int(dep_pieces[0]) != 0:\n            pieces[8] = str(int(dep_pieces[0]) + delta) + \":\" + dep_pieces[1]\n    line = \"\\t\".join(pieces)\n    return line\n\ndef augment_initial_punct(sents, ratio=0.20):\n    \"\"\"\n    If a sentence starts with certain punct marks, occasionally use the same sentence without the initial punct.\n\n    Currently this just handles ¿\n    This helps languages such as CA and ES where the models go awry when the initial ¿ is missing.\n    \"\"\"\n    new_sents = []\n    for sent in sents:\n        if random.random() > ratio:\n            continue\n\n        text_idx = find_text_idx(sent)\n        text_line = sent[text_idx]\n        if text_line.count(\"¿\") != 1:\n            # only handle sentences with exactly one ¿\n            continue\n\n        # find the first line with actual text\n        for 
idx, line in enumerate(sent):\n            if line.startswith(\"#\"):\n                continue\n            break\n        if idx >= len(sent) - 1:\n            raise ValueError(\"Unexpectedly an entire sentence is comments\")\n        pieces = line.split(\"\\t\")\n        if pieces[1] != '¿':\n            continue\n        if has_space_after_no(pieces[-1]):\n            replace_text = \"¿\"\n        else:\n            replace_text = \"¿ \"\n\n        new_sent = sent[:idx] + sent[idx+1:]\n        new_sent[text_idx] = text_line.replace(replace_text, \"\")\n\n        # now need to update all indices\n        new_sent = [change_indices(x, -1) for x in new_sent]\n        new_sents.append(new_sent)\n\n    if len(new_sents) > 0:\n        print(\"Added %d sentences with the leading ¿ removed\" % len(new_sents))\n\n    return sents + new_sents\n\n\ndef augment_brackets(sents, ratio=0.1):\n    \"\"\"\n    If there are no sentences with [], transform some () into []\n    \"\"\"\n    new_sents = []\n    for sent in sents:\n        text_idx = find_text_idx(sent)\n        text_line = sent[text_idx]\n        if text_line.count(\"[\") > 0 or text_line.count(\"]\") > 0:\n            # found a square bracket, so, never mind\n            return sents\n\n    for sent in sents:\n        if random.random() > ratio:\n            continue\n\n        text_idx = find_text_idx(sent)\n        text_line = sent[text_idx]\n        if text_line.count(\"(\") == 0 and text_line.count(\")\") == 0:\n            continue\n\n        text_line = text_line.replace(\"(\", \"[\").replace(\")\", \"]\")\n        new_sent = list(sent)\n        new_sent[text_idx] = text_line\n        for idx, line in enumerate(new_sent):\n            if line.startswith(\"#\"):\n                continue\n            pieces = line.split(\"\\t\")\n            if pieces[1] == '(':\n                pieces[1] = '['\n            elif pieces[1] == ')':\n                pieces[1] = ']'\n            new_sent[idx] = 
\"\\t\".join(pieces)\n        new_sents.append(new_sent)\n\n    if len(new_sents) > 0:\n        print(\"Added %d sentences with parens replaced with square brackets\" % len(new_sents))\n\n    return sents + new_sents\n\n\ndef augment_punct(sents):\n    \"\"\"\n    If there are no instances of ’ in the dataset, but there are instances of ',\n    we replace some fraction of ' with ’ so that the tokenizer will recognize it.\n\n    Also augments with ... / …\n    \"\"\"\n    new_sents = augment_apos(sents)\n    new_sents = augment_quotes(new_sents)\n    new_sents = augment_move_comma(new_sents)\n    new_sents = augment_comma_separations(new_sents)\n    new_sents = augment_initial_punct(new_sents)\n    new_sents = augment_ellipses(new_sents)\n    new_sents = augment_brackets(new_sents)\n\n    return new_sents\n\ndef remove_accents_from_words(sents):\n    new_sents = []\n    for sent in sents:\n        new_sent = []\n        for line in sent:\n            if line.startswith(\"#\"):\n                new_sent.append(line)\n            else:\n                pieces = line.split(\"\\t\")\n                pieces[1] = common.strip_accents(pieces[1])\n                new_sent.append(\"\\t\".join(pieces))\n        new_sents.append(new_sent)\n    return new_sents\n\ndef augment_accents(sents):\n    return sents + remove_accents_from_words(sents)\n\ndef write_augmented_dataset(input_conllu, output_conllu, augment_function):\n    # set the seed for each data file so that the results are the same\n    # regardless of how many treebanks are processed at once\n    random.seed(1234)\n\n    # read and shuffle conllu data\n    sents = read_sentences_from_conllu(input_conllu)\n\n    # the actual meat of the function - produce new sentences\n    new_sents = augment_function(sents)\n\n    write_sentences_to_conllu(output_conllu, new_sents)\n\ndef remove_spaces_from_sentences(sents):\n    \"\"\"\n    Makes sure every word in the list of sentences has SpaceAfter=No.\n\n    Returns a new list 
of sentences\n    \"\"\"\n    new_sents = []\n    for sentence in sents:\n        new_sentence = []\n        for word in sentence:\n            if word.startswith(\"#\"):\n                new_sentence.append(word)\n                continue\n            pieces = word.split(\"\\t\")\n            if pieces[-1] == \"_\":\n                pieces[-1] = \"SpaceAfter=No\"\n            elif pieces[-1].find(\"SpaceAfter=No\") >= 0:\n                pass\n            else:\n                raise ValueError(\"Expected the MISC column to be _ or to contain SpaceAfter=No, but found %s\" % pieces[-1])\n            word = \"\\t\".join(pieces)\n            new_sentence.append(word)\n        new_sents.append(new_sentence)\n    return new_sents\n\ndef remove_spaces(input_conllu, output_conllu):\n    \"\"\"\n    Turns a dataset into something appropriate for building a segmenter.\n\n    For example, this works well on the Korean datasets.\n    \"\"\"\n    sents = read_sentences_from_conllu(input_conllu)\n\n    new_sents = remove_spaces_from_sentences(sents)\n\n    write_sentences_to_conllu(output_conllu, new_sents)\n\n\ndef build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu):\n    \"\"\"\n    Builds a combined dataset out of multiple Korean datasets.\n\n    Currently this uses GSD and Kaist.  
If a segmenter-appropriate\n    dataset was requested, spaces are removed.\n\n    TODO: we need to handle the difference in xpos tags somehow.\n    \"\"\"\n    gsd_conllu = common.find_treebank_dataset_file(\"UD_Korean-GSD\", udbase_dir, dataset, \"conllu\")\n    kaist_conllu = common.find_treebank_dataset_file(\"UD_Korean-Kaist\", udbase_dir, dataset, \"conllu\")\n    sents = read_sentences_from_conllu(gsd_conllu) + read_sentences_from_conllu(kaist_conllu)\n\n    segmenter = short_name.endswith(\"_seg\")\n    if segmenter:\n        sents = remove_spaces_from_sentences(sents)\n\n    write_sentences_to_conllu(output_conllu, sents)\n\ndef build_combined_korean(udbase_dir, tokenizer_dir, short_name):\n    for dataset in (\"train\", \"dev\", \"test\"):\n        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)\n        build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu)\n\ndef build_combined_italian_dataset(paths, model_type, dataset):\n    udbase_dir = paths[\"UDBASE\"]\n    if dataset == 'train':\n        # could maybe add ParTUT, but that dataset has a slightly different xpos set\n        # (no DE or I)\n        # and I didn't feel like sorting through the differences\n        # TODO: for that dataset, can try adding it without the xpos or feats on ParTUT\n        treebanks = [\n            \"UD_Italian-ISDT\",\n            \"UD_Italian-VIT\",\n        ]\n        if model_type is not common.ModelType.TOKENIZER:\n            treebanks.extend([\n                \"UD_Italian-TWITTIRO\",\n                \"UD_Italian-PoSTWITA\"\n            ])\n        print(\"Building {} dataset out of {}\".format(model_type, \" \".join(treebanks)))\n        sents = []\n        for treebank in treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n            sents.extend(read_sentences_from_conllu(conllu_file))\n    else:\n        istd_conllu 
= common.find_treebank_dataset_file(\"UD_Italian-ISDT\", udbase_dir, dataset, \"conllu\")\n        sents = read_sentences_from_conllu(istd_conllu)\n\n    return sents\n\ndef check_gum_ready(udbase_dir):\n    gum_conllu = common.find_treebank_dataset_file(\"UD_English-GUMReddit\", udbase_dir, \"train\", \"conllu\")\n    if common.mostly_underscores(gum_conllu):\n        raise ValueError(\"Cannot process UD_English-GUMReddit in its current form.  There should be a download script available in the directory which will help integrate the missing proprietary values.  Please run that script to update the data, then try again.\")\n\ndef build_combined_english_dataset(paths, model_type, dataset):\n    \"\"\"\n    en_combined is currently EWT, GUM, PUD, Pronouns, and handparsed\n    \"\"\"\n    udbase_dir = paths[\"UDBASE_GIT\"]\n    check_gum_ready(udbase_dir)\n\n    if dataset == 'train':\n        # TODO: include more UD treebanks, possibly with xpos removed\n        #  UD_English-ParTUT - xpos are different\n        # also include \"external\" treebanks such as PTB\n        # NOTE: in order to get the best results, make sure each of these treebanks have the latest edits applied\n        train_treebanks = [\"UD_English-EWT\", \"UD_English-GUM\", \"UD_English-GUMReddit\"]\n        test_treebanks = [\"UD_English-PUD\", \"UD_English-Pronouns\"]\n        sents = []\n        for treebank in train_treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, \"train\", \"conllu\", fail=True)\n            new_sents = read_sentences_from_conllu(conllu_file)\n            print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n            sents.extend(new_sents)\n        for treebank in test_treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, \"test\", \"conllu\", fail=True)\n            new_sents = read_sentences_from_conllu(conllu_file)\n            print(\"Read %d sentences from %s\" % 
(len(new_sents), conllu_file))\n            sents.extend(new_sents)\n    else:\n        ewt_conllu = common.find_treebank_dataset_file(\"UD_English-EWT\", udbase_dir, dataset, \"conllu\")\n        sents = read_sentences_from_conllu(ewt_conllu)\n\n    return sents\n\ndef add_english_sentence_final_punctuation(handparsed_sentences):\n    \"\"\"\n    Add a period to the end of a sentence with no punct at the end.\n\n    The next-to-last word has SpaceAfter=No added as well.\n\n    Possibly English-specific because of the xpos.  Could be upgraded\n    to handle multiple languages by passing in the xpos as an argument\n    \"\"\"\n    new_sents = []\n    for sent in handparsed_sentences:\n        root_id = None\n        max_id = None\n        last_punct = False\n        for line in sent:\n            if line.startswith(\"#\"):\n                continue\n            pieces = line.split(\"\\t\")\n            if MWT_OR_COPY_RE.match(pieces[0]):\n                continue\n            if pieces[6] == '0':\n                root_id = pieces[0]\n            max_id = int(pieces[0])\n            last_punct = pieces[3] == 'PUNCT'\n        if not last_punct:\n            new_sent = list(sent)\n            pieces = new_sent[-1].split(\"\\t\")\n            pieces[-1] = add_space_after_no(pieces[-1])\n            new_sent[-1] = \"\\t\".join(pieces)\n            new_sent.append(\"%d\\t.\\t.\\tPUNCT\\t.\\t_\\t%s\\tpunct\\t%s:punct\\t_\" % (max_id+1, root_id, root_id))\n            new_sents.append(new_sent)\n        else:\n            new_sents.append(sent)\n    return new_sents\n\ndef build_extra_combined_french_dataset(paths, model_type, dataset):\n    \"\"\"\n    Extra sentences we don't want augmented for French - currently, handparsed lemmas\n    \"\"\"\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n    sents = []\n    if dataset == 'train':\n        if model_type is common.ModelType.LEMMA:\n            handparsed_path = os.path.join(handparsed_dir, \"french-lemmas\", 
\"fr_lemmas.conllu\")\n            handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n            print(\"Loaded %d sentences from %s\" % (len(handparsed_sentences), handparsed_path))\n            sents.extend(handparsed_sentences)\n\n            handparsed_path = os.path.join(handparsed_dir, \"french-lemmas\", \"french1st_6thGrade.conllu\")\n            handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n            print(\"Loaded %d sentences from %s\" % (len(handparsed_sentences), handparsed_path))\n            sents.extend(handparsed_sentences)\n    return sents\n\ndef build_extra_combined_german_dataset(paths, model_type, dataset):\n    \"\"\"\n    Extra sentences we don't want augmented for German\n\n    Currently, this is just the lemmas from Wiktionary\n    \"\"\"\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n    sents = []\n    if dataset == 'train':\n        if model_type is common.ModelType.LEMMA:\n            handparsed_path = os.path.join(handparsed_dir, \"german-lemmas-wiki\", \"de_wiki_lemmas.conllu\")\n            handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n            print(\"Loaded %d sentences from %s\" % (len(handparsed_sentences), handparsed_path))\n            sents.extend(handparsed_sentences)\n    return sents\n\n\ndef build_extra_combined_english_dataset(paths, model_type, dataset):\n    \"\"\"\n    Extra sentences we don't want augmented\n    \"\"\"\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n    sents = []\n    if dataset == 'train':\n        handparsed_path = os.path.join(handparsed_dir, \"english-handparsed\", \"english.conll\")\n        handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n        handparsed_sentences = add_english_sentence_final_punctuation(handparsed_sentences)\n        sents.extend(handparsed_sentences)\n        print(\"Loaded %d sentences from %s\" % (len(sents), handparsed_path))\n\n        if model_type is common.ModelType.LEMMA:\n     
       handparsed_path = os.path.join(handparsed_dir, \"english-lemmas\", \"en_lemmas.conllu\")\n            handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n            print(\"Loaded %d sentences from %s\" % (len(handparsed_sentences), handparsed_path))\n            sents.extend(handparsed_sentences)\n\n            handparsed_path = os.path.join(handparsed_dir, \"english-lemmas-verbs\", \"irregularVerbs-noNnoAdj.conllu\")\n            handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n            print(\"Loaded %d sentences from %s\" % (len(handparsed_sentences), handparsed_path))\n            sents.extend(handparsed_sentences)\n\n            handparsed_path = os.path.join(handparsed_dir, \"english-lemmas-adj\", \"en_adj.conllu\")\n            handparsed_sentences = read_sentences_from_conllu(handparsed_path)\n            print(\"Loaded %d sentences from %s\" % (len(handparsed_sentences), handparsed_path))\n            sents.extend(handparsed_sentences)\n    return sents\n\ndef build_extra_combined_italian_dataset(paths, model_type, dataset):\n    \"\"\"\n    Extra data - the MWT data for Italian\n    \"\"\"\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n    if dataset != 'train':\n        return []\n\n    extra_italian = os.path.join(handparsed_dir, \"italian-mwt\", \"italian.mwt\")\n    if not os.path.exists(extra_italian):\n        raise FileNotFoundError(\"Cannot find the extra dataset 'italian.mwt' which includes various multi-words retokenized, expected {}\".format(extra_italian))\n\n    extra_sents = read_sentences_from_conllu(extra_italian)\n    for sentence in extra_sents:\n        if not sentence[2].endswith(\"_\") or not MWT_RE.match(sentence[2]):\n            raise AssertionError(\"Unexpected format of the italian.mwt file.  
Has it already been modified to have SpaceAfter=No everywhere?\")\n        sentence[2] = sentence[2][:-1] + \"SpaceAfter=No\"\n    print(\"Loaded %d sentences from %s\" % (len(extra_sents), extra_italian))\n    return extra_sents\n\ndef replace_semicolons(sentences):\n    \"\"\"\n    Spanish GSD and AnCora have different standards for semicolons.\n\n    GSD has semicolons at the end of sentences, AnCora has them in the middle as clause separators.\n    Consecutive sentences in GSD do not seem to be related, so there is no combining that can be done.\n    The easiest solution is to replace sentence final semicolons with \".\" in GSD\n    \"\"\"\n    new_sents = []\n    count = 0\n    for sentence in sentences:\n        for text_idx, text_line in enumerate(sentence):\n            if text_line.startswith(\"# text\"):\n                break\n        else:\n            raise ValueError(\"Expected every sentence in GSD to have a # text field\")\n        if not text_line.endswith(\";\"):\n            new_sents.append(sentence)\n            continue\n        count = count + 1\n        new_sent = list(sentence)\n        new_sent[text_idx] = text_line[:-1] + \".\"\n        new_sent[-1] = new_sent[-1].replace(\";\", \".\")\n        new_sents.append(new_sent)\n    print(\"Updated %d sentences to replace sentence-final ; with .\" % count)\n    return new_sents\n\ndef strip_column(sents, column):\n    \"\"\"\n    Removes a specified column from the given dataset\n\n    Particularly useful when mixing two different POS formalisms in the same tagger\n    \"\"\"\n    new_sents = []\n    for sentence in sents:\n        new_sent = []\n        for word in sentence:\n            if word.startswith(\"#\"):\n                new_sent.append(word)\n                continue\n            pieces = word.split(\"\\t\")\n            pieces[column] = \"_\"\n            new_sent.append(\"\\t\".join(pieces))\n        new_sents.append(new_sent)\n    return new_sents\n\ndef 
strip_xpos(sents):\n    \"\"\"\n    Removes all xpos from the given dataset\n\n    Particularly useful when mixing two different POS formalisms in the same tagger\n    \"\"\"\n    return strip_column(sents, 4)\n\ndef strip_feats(sents):\n    \"\"\"\n    Removes all features from the given dataset\n\n    Particularly useful when mixing two different POS formalisms in the same tagger\n    \"\"\"\n    return strip_column(sents, 5)\n\ndef build_combined_japanese_dataset(paths, model_type, dataset):\n    \"\"\"\n    GSD with a handparsed dataset of some short verb phrases\n    \"\"\"\n    udbase_dir = paths[\"UDBASE\"]\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n\n    treebank = \"UD_Japanese-GSD\"\n    conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n    gsd_sents = read_sentences_from_conllu(conllu_file)\n    print(\"Read %d sentences from %s\" % (len(gsd_sents), conllu_file))\n\n    if dataset == 'train':\n        extra_japanese = os.path.join(handparsed_dir, \"japanese-handparsed\", \"spaces-ready-checked.conllu\")\n        if not os.path.exists(extra_japanese):\n            raise FileNotFoundError(\"Cannot find the extra dataset which includes various verb patterns, expected {}\".format(extra_japanese))\n        extra_sents = read_sentences_from_conllu(extra_japanese)\n        print(\"Read %d sentences from %s\" % (len(extra_sents), extra_japanese))\n\n        if model_type == common.ModelType.POS:\n            documents = {}\n            documents[treebank] = gsd_sents\n            documents['handparsed'] = extra_sents\n            return documents\n        else:\n            sents = gsd_sents + extra_sents\n            return sents\n    else:\n        return gsd_sents\n\n\ndef build_combined_albanian_dataset(paths, model_type, dataset):\n    \"\"\"\n    sq_combined is STAF as the base, with TSA added for some things\n    \"\"\"\n    udbase_dir = paths[\"UDBASE\"]\n    udbase_git_dir = 
paths[\"UDBASE_GIT\"]\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n\n    treebanks = [\"UD_Albanian-STAF\", \"UD_Albanian-TSA\"]\n\n    if dataset == 'train' and model_type == common.ModelType.POS:\n        documents = {}\n\n        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, \"train\", \"conllu\", fail=True)\n        new_sents = read_sentences_from_conllu(conllu_file)\n        documents[treebanks[0]] = new_sents\n\n        # we use udbase_git_dir for TSA because of an updated MWT scheme\n        conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, \"test\", \"conllu\", fail=True)\n        new_sents = read_sentences_from_conllu(conllu_file)\n        new_sents = strip_xpos(new_sents)\n        new_sents = strip_feats(new_sents)\n        documents[treebanks[1]] = new_sents\n\n        return documents\n\n    if dataset == 'train' and model_type is not common.ModelType.DEPPARSE:\n        sents = []\n\n        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, \"train\", \"conllu\", fail=True)\n        new_sents = read_sentences_from_conllu(conllu_file)\n        print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n        sents.extend(new_sents)\n\n        conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, \"test\", \"conllu\", fail=True)\n        new_sents = read_sentences_from_conllu(conllu_file)\n        print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n        sents.extend(new_sents)\n\n        return sents\n\n    conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, \"conllu\", fail=True)\n    sents = read_sentences_from_conllu(conllu_file)\n    return sents\n\ndef build_combined_german_dataset(paths, model_type, dataset):\n    \"\"\"\n    de_combined is currently GSD, with lemma information from Wiktionary\n\n    the lemma information is added in build_extra_combined_german_dataset\n\n    
TODO: quite a bit of HDT we could possibly use\n    \"\"\"\n    udbase_dir = paths[\"UDBASE\"]\n\n    treebanks = [\"UD_German-GSD\"]\n\n    treebank = treebanks[0]\n    conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n    sents = read_sentences_from_conllu(conllu_file)\n\n    return sents\n\n\ndef build_combined_spanish_dataset(paths, model_type, dataset):\n    \"\"\"\n    es_combined is AnCora and GSD put together\n\n    For POS training, we put the different datasets into a zip file so\n    that we can keep the conllu files separate and remove the xpos\n    from the non-AnCora training files.  It is necessary to remove the\n    xpos because GSD and PUD both use different xpos schemes from\n    AnCora, and the tagger can use additional data files as training\n    data without a specific column if that column is entirely blank\n\n    TODO: consider mixing in PUD?\n    \"\"\"\n    udbase_dir = paths[\"UDBASE\"]\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n\n    treebanks = [\"UD_Spanish-AnCora\", \"UD_Spanish-GSD\"]\n\n    if dataset == 'train' and model_type == common.ModelType.POS:\n        documents = {}\n        for treebank in treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n            new_sents = read_sentences_from_conllu(conllu_file)\n            if not treebank.endswith(\"AnCora\"):\n                new_sents = strip_xpos(new_sents)\n            documents[treebank] = new_sents\n\n        return documents\n\n    if dataset == 'train':\n        sents = []\n        for treebank in treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n            new_sents = read_sentences_from_conllu(conllu_file)\n            print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n            if treebank.endswith(\"GSD\"):\n                new_sents = 
replace_semicolons(new_sents)\n            sents.extend(new_sents)\n\n        if model_type in (common.ModelType.TOKENIZER, common.ModelType.MWT, common.ModelType.LEMMA):\n            extra_spanish = os.path.join(handparsed_dir, \"spanish-mwt\", \"adjectives.conllu\")\n            if not os.path.exists(extra_spanish):\n                raise FileNotFoundError(\"Cannot find the extra dataset 'adjectives.conllu' which includes various multi-words retokenized, expected {}\".format(extra_spanish))\n            extra_sents = read_sentences_from_conllu(extra_spanish)\n            print(\"Read %d sentences from %s\" % (len(extra_sents), extra_spanish))\n            sents.extend(extra_sents)\n    else:\n        conllu_file = common.find_treebank_dataset_file(\"UD_Spanish-AnCora\", udbase_dir, dataset, \"conllu\", fail=True)\n        sents = read_sentences_from_conllu(conllu_file)\n\n    return sents\n\ndef build_combined_french_dataset(paths, model_type, dataset):\n    udbase_dir = paths[\"UDBASE\"]\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n    if dataset == 'train':\n        train_treebanks = [\"UD_French-GSD\", \"UD_French-ParisStories\", \"UD_French-Rhapsodie\", \"UD_French-Sequoia\"]\n        sents = []\n        for treebank in train_treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, \"train\", \"conllu\", fail=True)\n            new_sents = read_sentences_from_conllu(conllu_file)\n            print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n            sents.extend(new_sents)\n\n        extra_french = os.path.join(handparsed_dir, \"french-handparsed\", \"handparsed_deps.conllu\")\n        if not os.path.exists(extra_french):\n            raise FileNotFoundError(\"Cannot find the extra dataset 'handparsed_deps.conllu' which includes various dependency fixes, expected {}\".format(extra_french))\n        extra_sents = read_sentences_from_conllu(extra_french)\n        print(\"Read %d sentences from 
%s\" % (len(extra_sents), extra_french))\n        sents.extend(extra_sents)\n    else:\n        gsd_conllu = common.find_treebank_dataset_file(\"UD_French-GSD\", udbase_dir, dataset, \"conllu\")\n        sents = read_sentences_from_conllu(gsd_conllu)\n\n    return sents\n\ndef build_combined_hebrew_dataset(paths, model_type, dataset):\n    \"\"\"\n    Combines the IAHLT treebank with an updated form of HTB where the annotation style more closely matches IAHLT\n\n    Currently the updated HTB is not in UD, so you will need to clone\n    git@github.com:IAHLT/UD_Hebrew.git to $UDBASE_GIT\n\n    dev and test sets will be those from IAHLT\n    \"\"\"\n    udbase_dir = paths[\"UDBASE\"]\n    udbase_git_dir = paths[\"UDBASE_GIT\"]\n\n    treebanks = [\"UD_Hebrew-IAHLTwiki\", \"UD_Hebrew-IAHLTknesset\"]\n    if dataset == 'train':\n        sents = []\n        for treebank in treebanks:\n            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n            new_sents = read_sentences_from_conllu(conllu_file)\n            print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n            sents.extend(new_sents)\n\n        # if/when this gets ported back to UD, switch to getting both datasets from UD\n        hebrew_git_dir = os.path.join(udbase_git_dir, \"UD_Hebrew\")\n        if not os.path.exists(hebrew_git_dir):\n            raise FileNotFoundError(\"Please download git@github.com:IAHLT/UD_Hebrew.git to %s (based on $UDBASE_GIT)\" % hebrew_git_dir)\n        conllu_file = os.path.join(hebrew_git_dir, \"he_htb-ud-train.conllu\")\n        if not os.path.exists(conllu_file):\n            raise FileNotFoundError(\"Found %s but inexplicably there was no %s\" % (hebrew_git_dir, conllu_file))\n        new_sents = read_sentences_from_conllu(conllu_file)\n        print(\"Read %d sentences from %s\" % (len(new_sents), conllu_file))\n        sents.extend(new_sents)\n    else:\n        conllu_file = 
common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, \"conllu\", fail=True)\n        sents = read_sentences_from_conllu(conllu_file)\n\n    return sents\n\nCOMBINED_FNS = {\n    \"de_combined\": build_combined_german_dataset,\n    \"en_combined\": build_combined_english_dataset,\n    \"es_combined\": build_combined_spanish_dataset,\n    \"fr_combined\": build_combined_french_dataset,\n    \"he_combined\": build_combined_hebrew_dataset,\n    \"it_combined\": build_combined_italian_dataset,\n    \"ja_combined\": build_combined_japanese_dataset,\n    \"sq_combined\": build_combined_albanian_dataset,\n}\n\n# some extra data for the combined models without augmenting\nCOMBINED_EXTRA_FNS = {\n    \"de_combined\": build_extra_combined_german_dataset,\n    \"en_combined\": build_extra_combined_english_dataset,\n    \"fr_combined\": build_extra_combined_french_dataset,\n    \"it_combined\": build_extra_combined_italian_dataset,\n}\n\ndef build_combined_dataset(paths, short_name, model_type, augment):\n    random.seed(1234)\n    tokenizer_dir = paths[\"TOKENIZE_DATA_DIR\"]\n    build_fn = COMBINED_FNS[short_name]\n    extra_fn = COMBINED_EXTRA_FNS.get(short_name, None)\n    for dataset in (\"train\", \"dev\", \"test\"):\n        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)\n        sents = build_fn(paths, model_type, dataset)\n        if isinstance(sents, dict):\n            if dataset == 'train' and augment:\n                for filename in list(sents.keys()):\n                    sents[filename] = augment_punct(sents[filename])\n            output_zip = os.path.splitext(output_conllu)[0] + \".zip\"\n            with zipfile.ZipFile(output_zip, \"w\") as zout:\n                for filename in list(sents.keys()):\n                    with zout.open(filename + \".conllu\", \"w\") as zfout:\n                        with io.TextIOWrapper(zfout, encoding='utf-8', newline='') as fout:\n                            
write_sentences_to_file(fout, sents[filename])\n        else:\n            if dataset == 'train' and augment:\n                sents = augment_punct(sents)\n            if extra_fn is not None:\n                sents.extend(extra_fn(paths, model_type, dataset))\n            write_sentences_to_conllu(output_conllu, sents)\n\nBIO_DATASETS = (\"en_craft\", \"en_genia\", \"en_mimic\")\n\ndef build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, augment):\n    \"\"\"\n    Process the en bio datasets\n\n    Creates a dataset by combining the en_combined data with one of the bio sets\n    \"\"\"\n    random.seed(1234)\n    name, bio_dataset = short_name.split(\"_\")\n    assert name == 'en'\n    for dataset in (\"train\", \"dev\", \"test\"):\n        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)\n        if dataset == 'train':\n            sents = build_combined_english_dataset(paths, model_type, dataset)\n            if dataset == 'train' and augment:\n                sents = augment_punct(sents)\n        else:\n            sents = []\n        bio_file = os.path.join(paths[\"BIO_UD_DIR\"], \"UD_English-%s\" % bio_dataset.upper(), \"en_%s-ud-%s.conllu\" % (bio_dataset.lower(), dataset))\n        new_sents = read_sentences_from_conllu(bio_file)\n        print(\"Read %d sentences from %s\" % (len(new_sents), bio_file))\n        sents.extend(new_sents)\n        write_sentences_to_conllu(output_conllu, sents)\n\ndef build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment):\n    \"\"\"\n    Build the GUM dataset by combining GUMReddit\n\n    It checks to make sure GUMReddit is filled out using the included script\n    \"\"\"\n    check_gum_ready(udbase_dir)\n    random.seed(1234)\n\n    output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)\n\n    treebanks = [\"UD_English-GUM\", \"UD_English-GUMReddit\"]\n    sents = []\n    for treebank in 
treebanks:\n        conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n        sents.extend(read_sentences_from_conllu(conllu_file))\n\n    if dataset == 'train' and augment:\n        sents = augment_punct(sents)\n\n    write_sentences_to_conllu(output_conllu, sents)\n\ndef build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, augment):\n    for dataset in (\"train\", \"dev\", \"test\"):\n        build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment)\n\ndef prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, dataset, augment=True, input_conllu=None, output_conllu=None):\n    if input_conllu is None:\n        input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, \"conllu\", fail=True)\n    if output_conllu is None:\n        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)\n    print(\"Reading from %s and writing to %s\" % (input_conllu, output_conllu))\n\n    if short_name == \"te_mtg\" and dataset == 'train' and augment:\n        write_augmented_dataset(input_conllu, output_conllu, augment_telugu)\n    elif short_name.startswith(\"ko_\") and short_name.endswith(\"_seg\"):\n        remove_spaces(input_conllu, output_conllu)\n    elif short_name.startswith(\"grc_\") and short_name.endswith(\"-diacritics\"):\n        write_augmented_dataset(input_conllu, output_conllu, augment_accents)\n    elif dataset == 'train' and augment:\n        write_augmented_dataset(input_conllu, output_conllu, augment_punct)\n    else:\n        sents = read_sentences_from_conllu(input_conllu)\n        write_sentences_to_conllu(output_conllu, sents)\n\ndef process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, augment=True):\n    \"\"\"\n    Process a normal UD treebank with train/dev/test splits\n\n    SL-SSJ and other datasets with inline modifications all use this code 
path as well.\n    \"\"\"\n    prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, \"train\", augment)\n    prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, \"dev\", augment)\n    prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, \"test\", augment)\n\n\nXV_RATIO = 0.2\n\ndef process_test_only_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language):\n    \"\"\"\n    Process a large UD treebank with only a test set\n\n    Return False if the treebank is too small\n    \"\"\"\n    train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, \"train\", \"conllu\")\n    dev_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, \"dev\", \"conllu\")\n    test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, \"test\", \"conllu\")\n    if train_input_conllu or dev_input_conllu:\n        return False\n\n    train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, \"train\")\n    dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, \"dev\")\n    test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, \"test\")\n\n    if common.num_words_in_file(test_input_conllu) <= 10000:\n        return False\n\n    if not split_conllu_file(treebank=treebank,\n                             input_conllu=test_input_conllu,\n                             train_output_conllu=train_output_conllu,\n                             dev_output_conllu=dev_output_conllu,\n                             test_output_conllu=test_output_conllu):\n        return False\n\n    return True\n\n\ndef process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language):\n    \"\"\"\n    Process a UD treebank with only train/test splits\n\n    For example, in UD 2.7:\n      UD_Buryat-BDT\n      UD_Galician-TreeGal\n      UD_Indonesian-CSUI\n      
UD_Kazakh-KTB\n      UD_Kurmanji-MG\n      UD_Latin-Perseus\n      UD_Livvi-KKPP\n      UD_North_Sami-Giella\n      UD_Old_Russian-RNC\n      UD_Sanskrit-Vedic\n      UD_Slovenian-SST\n      UD_Upper_Sorbian-UFAL\n      UD_Welsh-CCG\n\n    Returns True if successful, False if not\n    \"\"\"\n    train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, \"train\", \"conllu\")\n    test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, \"test\", \"conllu\")\n\n    train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, \"train\")\n    dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, \"dev\")\n    test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, \"test\")\n\n    if (common.num_words_in_file(train_input_conllu) <= 1000 and\n        common.num_words_in_file(test_input_conllu) > 5000):\n        train_input_conllu, test_input_conllu = test_input_conllu, train_input_conllu\n\n    if not split_train_file(treebank=treebank,\n                            train_input_conllu=train_input_conllu,\n                            train_output_conllu=train_output_conllu,\n                            dev_output_conllu=dev_output_conllu):\n        return False\n\n    # the test set is already fine\n    # currently we do not do any augmentation of these partial treebanks\n    prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, \"test\", augment=False, input_conllu=test_input_conllu, output_conllu=test_output_conllu)\n    return True\n\ndef add_specific_args(parser):\n    parser.add_argument('--no_augment', action='store_false', dest='augment', default=True,\n                        help='Augment the dataset in various ways')\n    parser.add_argument('--no_prepare_labels', action='store_false', dest='prepare_labels', default=True,\n                        help='Prepare tokenizer and MWT labels.  
Expensive, but obviously necessary for training those models.')\n    convert_th_lst20.add_lst20_args(parser)\n\n    convert_vi_vlsp.add_vlsp_args(parser)\n\ndef process_treebank(treebank, model_type, paths, args):\n    \"\"\"\n    Processes a single treebank into train, dev, test parts\n\n    Includes processing for a few external tokenization datasets:\n      vi_vlsp, th_orchid, th_best\n\n    Also, there is no specific mechanism for UD_Arabic-NYUAD or\n    similar treebanks, which need integration with LDC datasets\n    \"\"\"\n    udbase_dir = paths[\"UDBASE\"]\n    tokenizer_dir = paths[\"TOKENIZE_DATA_DIR\"]\n    handparsed_dir = paths[\"HANDPARSED_DIR\"]\n\n    short_name = treebank_to_short_name(treebank)\n    short_language = short_name.split(\"_\")[0]\n\n    os.makedirs(tokenizer_dir, exist_ok=True)\n\n    success = False\n    if short_name == \"my_alt\":\n        convert_my_alt.convert_my_alt(paths[\"CONSTITUENCY_BASE\"], tokenizer_dir)\n    elif short_name == \"vi_vlsp\":\n        convert_vi_vlsp.convert_vi_vlsp(paths[\"STANZA_EXTERN_DIR\"], tokenizer_dir, args)\n    elif short_name == \"th_orchid\":\n        convert_th_orchid.main(paths[\"STANZA_EXTERN_DIR\"], tokenizer_dir)\n    elif short_name == \"th_lst20\":\n        convert_th_lst20.convert(paths[\"STANZA_EXTERN_DIR\"], tokenizer_dir, args)\n    elif short_name == \"th_best\":\n        convert_th_best.main(paths[\"STANZA_EXTERN_DIR\"], tokenizer_dir)\n    elif short_name == \"ml_cochin\":\n        convert_ml_cochin.main(paths[\"STANZA_EXTERN_DIR\"], tokenizer_dir)\n    elif short_name.startswith(\"ko_combined\"):\n        build_combined_korean(udbase_dir, tokenizer_dir, short_name)\n    elif short_name in COMBINED_FNS: # eg \"it_combined\", \"en_combined\", etc\n        build_combined_dataset(paths, short_name, model_type, args.augment)\n    elif short_name in BIO_DATASETS:\n        build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, args.augment)\n    elif 
short_name.startswith(\"en_gum\"):\n        # we special case GUM because it should include a filled-out GUMReddit\n        print(\"Preparing data for %s: %s, %s\" % (treebank, short_name, short_language))\n        build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, args.augment)\n    else:\n        # check that we can find the train file where we expect it\n        train_conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, \"train\", \"conllu\", fail=False)\n        if not train_conllu_file:\n            # maybe this dataset has a huge test set we can split?\n            test_conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, \"test\", \"conllu\", fail=True)\n            print(\"Checking data for %s: %s, %s to see if the test dataset is large enough\" % (treebank, short_name, short_language))\n            success = process_test_only_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language)\n        else:\n            print(\"Preparing data for %s: %s, %s\" % (treebank, short_name, short_language))\n\n            if not common.find_treebank_dataset_file(treebank, udbase_dir, \"dev\", \"conllu\", fail=False):\n                success = process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language)\n            else:\n                process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment)\n\n    if success and (model_type is common.ModelType.TOKENIZER or model_type is common.ModelType.MWT):\n        if not short_name in ('th_orchid', 'th_lst20'):\n            common.convert_conllu_to_txt(tokenizer_dir, short_name)\n\n        if args.prepare_labels:\n            common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)\n\n\ndef main():\n    common.main(process_treebank, common.ModelType.TOKENIZER, add_specific_args)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/pretrain/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/pretrain/word_in_pretrain.py",
    "content": "\"\"\"\nSimple tool to query a word vector file to see if certain words are in that file\n\"\"\"\n\nimport argparse\nimport os\n\nfrom stanza.models.common.pretrain import Pretrain\nfrom stanza.resources.common import DEFAULT_MODEL_DIR, download\n\ndef main():\n    parser = argparse.ArgumentParser()\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument(\"--pretrain\", default=None, type=str, help=\"Where to read the converted PT file\")\n    group.add_argument(\"--package\", default=None, type=str, help=\"Use a pretrain package instead\")\n    parser.add_argument(\"--download_json\", default=False, action='store_true', help=\"Download the json even if it already exists\")\n    parser.add_argument(\"words\", type=str, nargs=\"+\", help=\"Which words to search for\")\n    args = parser.parse_args()\n\n    if args.pretrain:\n        pt = Pretrain(args.pretrain)\n    else:\n        lang, package = args.package.split(\"_\", 1)\n        download(lang=lang, package=None, processors={\"pretrain\": package}, download_json=args.download_json)\n        pt_filename = os.path.join(DEFAULT_MODEL_DIR, lang, \"pretrain\", \"%s.pt\" % package)\n        pt = Pretrain(pt_filename)\n\n    for word in args.words:\n        print(\"{}: {}\".format(word, word in pt.vocab))\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/random_split_conllu.py",
    "content": "\"\"\"\nRandomly split a file into train, dev, and test sections\n\nSpecifically used in the case of building a tagger from the initial\nPOS tagging provided by Isra, but obviously can be used to split any\nconllu file\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.conll import CoNLL\nfrom stanza.utils.default_paths import get_default_paths\n\ndef random_split(doc, weights, remove_xpos=False, remove_feats=False):\n    \"\"\"\n    weights: a tuple / list of (train, dev, test) weights\n    \"\"\"\n    train_doc = ([], [])\n    dev_doc = ([], [])\n    test_doc = ([], [])\n    splits = [train_doc, dev_doc, test_doc]\n    for sentence in doc.sentences:\n        sentence_dict = sentence.to_dict()\n        if remove_xpos:\n            for x in sentence_dict:\n                x.pop('xpos', None)\n        if remove_feats:\n            for x in sentence_dict:\n                x.pop('feats', None)\n        split = random.choices(splits, weights)[0]\n        split[0].append(sentence_dict)\n        split[1].append(sentence.comments)\n\n    splits = [Document(split[0], comments=split[1]) for split in splits]\n    return splits\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--filename', default='extern_data/sindhi/upos/sindhi_upos.conllu', help='Which file to split')\n    parser.add_argument('--train', type=float, default=0.8, help='Fraction of the data to use for train')\n    parser.add_argument('--dev', type=float, default=0.1, help='Fraction of the data to use for dev')\n    parser.add_argument('--test', type=float, default=0.1, help='Fraction of the data to use for test')\n    parser.add_argument('--seed', default='1234', help='Random seed to use')\n    parser.add_argument('--short_name', default='sd_isra', help='Dataset name to use when writing output files')\n    parser.add_argument('--no_remove_xpos', default=True, action='store_false', 
dest='remove_xpos', help='By default, we remove the xpos from the dataset')\n    parser.add_argument('--no_remove_feats', default=True, action='store_false', dest='remove_feats', help='By default, we remove the feats from the dataset')\n    parser.add_argument('--output_directory', default=get_default_paths()[\"POS_DATA_DIR\"], help=\"Where to put the split conllu\")\n    args = parser.parse_args()\n\n    weights = (args.train, args.dev, args.test)\n\n    doc = CoNLL.conll2doc(args.filename)\n    random.seed(args.seed)\n\n    splits = random_split(doc, weights, args.remove_xpos, args.remove_feats)\n\n    for split_doc, split_name in zip(splits, (\"train\", \"dev\", \"test\")):\n        filename = os.path.join(args.output_directory, \"%s.%s.in.conllu\" % (args.short_name, split_name))\n        print(\"Outputting %d sentences to %s\" % (len(split_doc.sentences), filename))\n        CoNLL.write_doc2conll(split_doc, filename)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/sentiment/add_constituency.py",
    "content": "\"\"\"\nFor a dataset produced by prepare_sentiment_dataset, add constituency parses.\n\nObviously this will only work on languages that have a constituency parser\n\"\"\"\n\nimport argparse\nimport os\n\nimport stanza\nfrom stanza.models.classifiers.data import read_dataset\nfrom stanza.models.classifiers.utils import WVType\nfrom stanza.models.mwt.utils import resplit_mwt\nfrom stanza.utils.datasets.sentiment import prepare_sentiment_dataset\nfrom stanza.utils.datasets.sentiment import process_utils\nimport stanza.utils.default_paths as default_paths\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef main():\n    parser = argparse.ArgumentParser()\n    # TODO: allow multiple files?\n    parser.add_argument('dataset', type=str, help=\"Dataset (or a single file) to process\")\n    parser.add_argument('--output', type=str, help=\"Write the processed data here instead of clobbering\")\n    parser.add_argument('--constituency_package', type=str, default=None, help=\"Constituency model to use for parsing\")\n    parser.add_argument('--constituency_model', type=str, default=None, help=\"Specific model file to use for parsing\")\n    parser.add_argument('--retag_package', type=str, default=None, help=\"Which tagger to use for retagging\")\n    parser.add_argument('--split_mwt', action='store_true', help=\"Split MWT from the original sentences if the language has MWT\")\n    parser.add_argument('--lang', type=str, default=None, help=\"Which language the dataset/file is in.  
If not specified, will try to use the dataset name\")\n    args = parser.parse_args()\n\n    if os.path.exists(args.dataset):\n        expected_files = [args.dataset]\n        if args.output:\n            output_files = [args.output]\n        else:\n            output_files = expected_files\n        if not args.lang:\n            _, filename = os.path.split(args.dataset)\n            args.lang = filename.split(\"_\")[0]\n            print(\"Guessing lang=%s based on the filename %s\" % (args.lang, filename))\n    else:\n        paths = default_paths.get_default_paths()\n        # TODO: one of the side effects of the tass2020 dataset is to make a bunch of extra files\n        # Perhaps we could have the prepare_sentiment_dataset script return a list of those files\n        expected_files = [os.path.join(paths['SENTIMENT_DATA_DIR'], '%s.%s.json' % (args.dataset, shard)) for shard in SHARDS]\n        if args.output:\n            output_files = [os.path.join(paths['SENTIMENT_DATA_DIR'], '%s.%s.json' % (args.output, shard)) for shard in SHARDS]\n        else:\n            output_files = expected_files\n        for filename in expected_files:\n            if not os.path.exists(filename):\n                print(\"Cannot find expected dataset file %s - rebuilding dataset\" % filename)\n                prepare_sentiment_dataset.main(args.dataset)\n                break\n        if not args.lang:\n            args.lang, _ = args.dataset.split(\"_\", 1)\n            print(\"Guessing lang=%s based on the dataset name\" % args.lang)\n\n\n    pipeline_args = {\"lang\": args.lang,\n                     \"processors\": \"tokenize,pos,constituency\",\n                     \"tokenize_pretokenized\": True,\n                     \"pos_batch_size\": 50,\n                     \"pos_tqdm\": True,\n                     \"constituency_tqdm\": True}\n    package = {}\n    if args.constituency_package is not None:\n        package[\"constituency\"] = args.constituency_package\n    if 
args.retag_package is not None:\n        package[\"pos\"] = args.retag_package\n    if package:\n        pipeline_args[\"package\"] = package\n    if args.constituency_model is not None:\n        pipeline_args[\"constituency_model_path\"] = args.constituency_model\n    pipe = stanza.Pipeline(**pipeline_args)\n\n    if args.split_mwt:\n        # TODO: allow for different tokenize packages\n        mwt_pipe = stanza.Pipeline(lang=args.lang, processors=\"tokenize\")\n        if \"mwt\" in mwt_pipe.processors:\n            print(\"This language has MWT.  Will resplit any MWTs found in the dataset\")\n        else:\n            print(\"--split_mwt was requested, but %s does not support MWT!\" % args.lang)\n            args.split_mwt = False\n\n    for filename, output_filename in zip(expected_files, output_files):\n        dataset = read_dataset(filename, WVType.OTHER, 1)\n        text = [x.text for x in dataset]\n        if args.split_mwt:\n            print(\"Resplitting MWT in %d sentences from %s\" % (len(dataset), filename))\n            doc = resplit_mwt(text, mwt_pipe)\n            print(\"Parsing %d sentences from %s\" % (len(dataset), filename))\n            doc = pipe(doc)\n        else:\n            print(\"Parsing %d sentences from %s\" % (len(dataset), filename))\n            doc = pipe(text)\n\n        assert len(dataset) == len(doc.sentences)\n        for datum, sentence in zip(dataset, doc.sentences):\n            datum.constituency = sentence.constituency\n\n        process_utils.write_list(output_filename, dataset)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/convert_italian_poetry_classification.py",
    "content": "\"\"\"\nA short tool to turn a labeled dataset of the format\nProf. Delmonte provided into a stanza input file for the classifier.\n\nData is expected to be in the sentiment italian subdirectory (see below)\n\nOnly writes a test set.  Use it as an eval file for a trained model.\n\"\"\"\n\nimport os\n\nimport stanza\nfrom stanza.models.classifiers.data import SentimentDatum\nfrom stanza.utils.datasets.sentiment import process_utils\nimport stanza.utils.default_paths as default_paths\n\ndef main():\n    paths = default_paths.get_default_paths()\n\n    dataset_name = \"it_vit_sentences_poetry\"\n\n    poetry_filename = os.path.join(paths[\"SENTIMENT_BASE\"], \"italian\", \"sentence_classification\", \"poetry\", \"testset_300_labeled.txt\")\n    if not os.path.exists(poetry_filename):\n        raise FileNotFoundError(\"Expected to find the labeled file in %s\" % poetry_filename)\n    print(\"Reading the labeled poetry from %s\" % poetry_filename)\n\n    tokenizer = stanza.Pipeline(\"it\", processors=\"tokenize\", tokenize_no_ssplit=True)\n    dataset = []\n    with open(poetry_filename, encoding=\"utf-8\") as fin:\n        for line_num, line in enumerate(fin):\n            line = line.strip()\n            if not line:\n                continue\n\n            line = line.replace(u'\\ufeff', '')\n            pieces = line.split(maxsplit=1)\n            # first column is the label\n            # remainder of the text is the raw text\n            label = pieces[0].strip()\n            if label not in ('0', '1'):\n                if label == \"viene\" and line_num == 257:\n                    print(\"Skipping known missing label at line 257\")\n                    continue\n                assert isinstance(label, str)\n                ords = \",\".join(str(ord(x)) for x in label)\n                raise ValueError(\"Unexpected label |%s| (%s) for line %d\" % (label, ords, line_num))\n\n            # tokenize the text into words\n            # we could make 
this faster by stacking it, but the input file is quite short anyway\n            text = pieces[1]\n            doc = tokenizer(text)\n            words = [x.text for x in doc.sentences[0].words]\n\n            dataset.append(SentimentDatum(label, words))\n\n    print(\"Read %d lines from %s\" % (len(dataset), poetry_filename))\n    output_filename = \"%s.test.json\" % dataset_name\n    output_path = os.path.join(paths[\"SENTIMENT_DATA_DIR\"], output_filename)\n    print(\"Writing output to %s\" % output_path)\n    process_utils.write_list(output_path, dataset)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/convert_italian_sentence_classification.py",
    "content": "\"\"\"\nConverts a file of labels on constituency trees for the it_vit dataset\n\nThe labels are for whether or not a sentence is written in a standard\nS-V-O order.  The intent is to see how much a constituency parser\ncan improve over a regular transformer classifier.\n\nThis file is provided by Prof. Delmonte as part of a classification\nproject.  Contact John Bauer for more details.\n\nTechnically this should be \"classifier\" instead of \"sentiment\"\n\"\"\"\n\nimport os\n\nfrom stanza.models.classifiers.data import SentimentDatum\nfrom stanza.utils.datasets.sentiment import process_utils\nfrom stanza.utils.datasets.constituency.convert_it_vit import read_updated_trees\nimport stanza.utils.default_paths as default_paths\n\ndef label_trees(label_map, trees):\n    new_trees = []\n    for tree in trees:\n        if tree.con_id not in label_map:\n            raise ValueError(\"%s not labeled\" % tree.con_id)\n        label = label_map[tree.con_id]\n        new_trees.append(SentimentDatum(label, tree.tree.leaf_labels(), tree.tree))\n    return new_trees\n\ndef read_label_map(label_filename):\n    with open(label_filename, encoding=\"utf-8\") as fin:\n        lines = fin.readlines()\n    lines = [x.strip() for x in lines]\n    lines = [x.split() for x in lines if x]\n    label_map = {}\n    for line_idx, line in enumerate(lines):\n        k = line[0].split(\"#\")[1]\n        v = line[1]\n\n        # compensate for an off-by-one error in the labels for ids 12 through 129\n        # we went back and forth a few times but i couldn't explain the error,\n        # so whatever, just compensate for it on the conversion side\n        k_idx = int(k.split(\"_\")[1])\n        if k_idx != line_idx + 1:\n            if k_idx >= 12 and k_idx <= 129:\n                k = \"sent_%05d\" % (k_idx - 1)\n            else:\n                raise ValueError(\"Unexpected key offset for line {}: {}\".format(line_idx, line))\n\n        if v == \"neg\":\n            v = 
\"0\"\n        elif v == \"pos\":\n            v = \"1\"\n        else:\n            raise ValueError(\"Unexpected label %s for key %s\" % (v, k))\n\n        if k in label_map:\n            raise ValueError(\"Duplicate key %s: new value %s, old value %s\" % (k, v, label_map[k]))\n        label_map[k] = v\n\n    return label_map\n\ndef main():\n    paths = default_paths.get_default_paths()\n\n    dataset_name = \"it_vit_sentences\"\n\n    label_filename = os.path.join(paths[\"SENTIMENT_BASE\"], \"italian\", \"sentence_classification\", \"classified\")\n    if not os.path.exists(label_filename):\n        raise FileNotFoundError(\"Expected to find the labeled file in %s\" % label_filename)\n\n    label_map = read_label_map(label_filename)\n\n    # this will produce three lists of trees with their con_id attached\n    train_trees, dev_trees, test_trees = read_updated_trees(paths)\n\n    train_trees = label_trees(label_map, train_trees)\n    dev_trees   = label_trees(label_map, dev_trees)\n    test_trees  = label_trees(label_map, test_trees)\n\n    dataset = (train_trees, dev_trees, test_trees)\n    process_utils.write_dataset(dataset, paths[\"SENTIMENT_DATA_DIR\"], dataset_name)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/prepare_sentiment_dataset.py",
    "content": "\"\"\"Prepare a single dataset or a combination dataset for the sentiment project\n\nManipulates various downloads from their original form to a form\nusable by the classifier model\n\nExplanations for the existing datasets are below.\nAfter processing the dataset, you can train with\nthe run_sentiment script\n\npython3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset <dataset>\npython3 -m stanza.utils.training.run_sentiment <dataset>\n\nEnglish\n-------\n\nSST (Stanford Sentiment Treebank)\n  https://nlp.stanford.edu/sentiment/\n  https://github.com/stanfordnlp/sentiment-treebank\n  The git repo includes fixed tokenization and sentence splits, along\n    with a partial conversion to updated PTB tokenization standards.\n\n  The first step is to git clone the SST to here:\n    $SENTIMENT_BASE/sentiment-treebank\n  eg:\n    cd $SENTIMENT_BASE\n    git clone git@github.com:stanfordnlp/sentiment-treebank.git\n\n  There are a few different usages of SST.\n\n  The scores most commonly reported are for SST-2,\n    positive and negative only.\n  To get a version of this:\n\n    python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sst2\n    python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sst2roots\n\n  The model we distribute is a three class model (+, 0, -)\n    with some smaller datasets added for better coverage.\n    See \"sstplus\" below.\n\nMELD\n  https://github.com/SenticNet/MELD/tree/master/data/MELD\n  https://github.com/SenticNet/MELD\n  https://arxiv.org/pdf/1810.02508.pdf\n\n  MELD: A Multimodal Multi-Party Dataset for Emotion Recognition in Conversation. ACL 2019.\n  S. Poria, D. Hazarika, N. Majumder, G. Naik, E. Cambria, R. Mihalcea.\n\n  An Emotion Corpus of Multi-Party Conversations.\n  Chen, S.Y., Hsu, C.C., Kuo, C.C. 
and Ku, L.W.\n\n  Copy the three files in the repo into\n    $SENTIMENT_BASE/MELD\n  TODO: make it so you git clone the repo instead\n\n  There are train/dev/test splits, so you can build a model\n    out of just this corpus.  The first step is to convert\n    to the classifier data format:\n\n    python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_meld\n\n  However, in general we simply include this in the sstplus model\n    rather than releasing a separate model.\n\nArguana\n  http://argumentation.bplaced.net/arguana/data\n  http://argumentation.bplaced.net/arguana-data/arguana-tripadvisor-annotated-v2.zip\n\n  http://argumentation.bplaced.net/arguana-publications/papers/wachsmuth14a-cicling.pdf\n  A Review Corpus for Argumentation Analysis.  CICLing 2014\n  Henning Wachsmuth, Martin Trenkmann, Benno Stein, Gregor Engels, Tsvetomira Palarkarska\n\n  Download the zip file and unzip it in\n    $SENTIMENT_BASE/arguana\n\n  This is included in the sstplus model.\n\nairline\n  A Kaggle corpus for sentiment detection on airline tweets.\n  We include this in sstplus as well.\n\n  https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment\n\n  Download Tweets.csv and put it in\n    $SENTIMENT_BASE/airline\n\nSLSD\n  https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences\n\n  From Group to Individual Labels using Deep Features.  KDD 2015\n  Kotzias et. 
al\n\n  Put the contents of the zip file in\n    $SENTIMENT_BASE/slsd\n\n  The sstplus model includes this as training data\n\nen_sstplus\n  This is a three class model built from SST, along with the additional\n    English data sources above for coverage of additional domains.\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sstplus\n\nen_corona\n  A kaggle covid-19 text classification dataset\n  https://www.kaggle.com/datasets/datatattle/covid-19-nlp-text-classification\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_corona\n\nGerman\n------\n\nde_sb10k\n  This used to be here:\n    https://www.spinningbytes.com/resources/germansentiment/\n  Now it appears to have moved here?\n    https://github.com/oliverguhr/german-sentiment\n\n  https://dl.acm.org/doi/pdf/10.1145/3038912.3052611\n  Leveraging Large Amounts of Weakly Supervised Data for Multi-Language Sentiment Classification\n  WWW '17: Proceedings of the 26th International Conference on World Wide Web\n  Jan Deriu, Aurelien Lucchi, Valeria De Luca, Aliaksei Severyn,\n    Simon Müller, Mark Cieliebak, Thomas Hofmann, Martin Jaggi\n\n  The current prep script works on the old version of the data.\n  TODO: update to work on the git repo\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset de_sb10k\n\nde_scare\n  http://romanklinger.de/scare/\n\n  The Sentiment Corpus of App Reviews with Fine-grained Annotations in German\n  LREC 2016\n  Mario Sänger, Ulf Leser, Steffen Kemmerer, Peter Adolphs, and Roman Klinger\n\n  Download the data and put it in\n    $SENTIMENT_BASE/german/scare\n  There should be two subdirectories once you are done:\n    scare_v1.0.0\n    scare_v1.0.0_text\n\n  We wound up not including this in the default German model.\n  It might be worth revisiting in the future.\n\nde_usage\n  https://www.romanklinger.de/usagecorpus/\n\n  http://www.lrec-conf.org/proceedings/lrec2014/summaries/85.html\n  The USAGE Review Corpus 
for Fine Grained Multi Lingual Opinion Analysis\n  Roman Klinger and Philipp Cimiano\n\n  Again, not included in the default German model\n\nChinese\n-------\n\nzh-hans_ren\n  This used to be here:\n  http://a1-www.is.tokushima-u.ac.jp/member/ren/Ren-CECps1.0/Ren-CECps1.0.html\n\n  That page doesn't seem to respond as of 2022, and I can't find it elsewhere.\n\nThe following will be available starting in 1.4.1:\n\nSpanish\n-------\n\ntass2020\n  - http://tass.sepln.org/2020/?page_id=74\n  - Download the following 5 files:\n      task1.2-test-gold.tsv\n      Task1-train-dev.zip\n      tass2020-test-gold.zip\n      Test1.1.zip\n      test1.2.zip\n    Put them in a directory\n      $SENTIMENT_BASE/spanish/tass2020\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset es_tass2020\n\n\nVietnamese\n----------\n\nvi_vsfc\n  I found a corpus labeled VSFC here:\n  https://drive.google.com/drive/folders/1xclbjHHK58zk2X6iqbvMPS2rcy9y9E0X\n  It doesn't seem to have a license or paper associated with it,\n  but happy to put those details here if relevant.\n\n  Download the files to\n    $SENTIMENT_BASE/vietnamese/_UIT-VSFC\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset vi_vsfc\n\nMarathi\n-------\n\nmr_l3cube\n  https://github.com/l3cube-pune/MarathiNLP\n\n  https://arxiv.org/abs/2103.11408\n  L3CubeMahaSent: A Marathi Tweet-based Sentiment Analysis Dataset\n  Atharva Kulkarni, Meet Mandhane, Manali Likhitkar, Gayatri Kshirsagar, Raviraj Joshi\n\n  git clone the repo in\n    $SENTIMENT_BASE\n\n  cd $SENTIMENT_BASE\n  git clone git@github.com:l3cube-pune/MarathiNLP.git\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset mr_l3cube\n\n\nHindi\n-----\n\nodiagenai\n  https://huggingface.co/datasets/OdiaGenAI/sentiment_analysis_hindi\n\n  Uses datasets package from HF, so that needs to be installed\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset hi_odiagenai\n\n  This dataset has 2497 
sentences in a train section.\n  We randomly split them to make a usable dataset\n\nItalian\n-------\n\nit_sentipolc16\n  from here:\n  http://www.di.unito.it/~tutreeb/sentipolc-evalita16/data.html\n  paper describing the evaluation and the results:\n  http://ceur-ws.org/Vol-1749/paper_026.pdf\n\n  download the training and test zip files to $SENTIMENT_BASE/italian/sentipolc16\n  unzip them there\n\n  so you should have\n    $SENTIMENT_BASE/italian/sentipolc16/test_set_sentipolc16_gold2000.csv\n    $SENTIMENT_BASE/italian/sentipolc16/training_set_sentipolc16.csv\n\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset it_sentipolc16\n\n  this script splits the training data into dev & train, keeps the test the same\n\n  The conversion allows for 4 ways of handling the \"mixed\" class:\n    treat it as the same as neutral, treat it as a separate class,\n    only distinguish positive or not positive,\n    only distinguish negative or not negative\n  for more details:\n  python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset it_sentipolc16 --help\n\nanother option not implemented yet: absita18\n  http://sag.art.uniroma2.it/absita/data/\n\"\"\"\n\nimport os\nimport random\nimport sys\n\nimport stanza\nimport stanza.utils.default_paths as default_paths\n\nfrom stanza.utils.datasets.sentiment import process_airline\nfrom stanza.utils.datasets.sentiment import process_arguana_xml\nfrom stanza.utils.datasets.sentiment import process_corona\nfrom stanza.utils.datasets.sentiment import process_es_tass2020\nfrom stanza.utils.datasets.sentiment import process_it_sentipolc16\nfrom stanza.utils.datasets.sentiment import process_MELD\nfrom stanza.utils.datasets.sentiment import process_ren_chinese\nfrom stanza.utils.datasets.sentiment import process_sb10k\nfrom stanza.utils.datasets.sentiment import process_scare\nfrom stanza.utils.datasets.sentiment import process_slsd\nfrom stanza.utils.datasets.sentiment import process_sst\nfrom 
stanza.utils.datasets.sentiment import process_usage_german\nfrom stanza.utils.datasets.sentiment import process_vsfc_vietnamese\n\nfrom stanza.utils.datasets.sentiment import process_utils\n\nfrom tqdm import tqdm\n\ndef convert_sst_general(paths, dataset_name, version):\n    in_directory = paths['SENTIMENT_BASE']\n    sst_dir = os.path.join(in_directory, \"sentiment-treebank\")\n    train_phrases = process_sst.get_phrases(version, \"train.txt\", sst_dir)\n    dev_phrases = process_sst.get_phrases(version, \"dev.txt\", sst_dir)\n    test_phrases = process_sst.get_phrases(version, \"test.txt\", sst_dir)\n\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    dataset = [train_phrases, dev_phrases, test_phrases]\n    process_utils.write_dataset(dataset, out_directory, dataset_name)\n\ndef convert_sst2(paths, dataset_name, *args):\n    \"\"\"\n    Create a 2 class SST dataset (neutral items are dropped)\n    \"\"\"\n    convert_sst_general(paths, dataset_name, \"binary\")\n\ndef convert_sst2roots(paths, dataset_name, *args):\n    \"\"\"\n    Create a 2 class SST dataset using only the roots\n    \"\"\"\n    convert_sst_general(paths, dataset_name, \"binaryroot\")\n\ndef convert_sst3(paths, dataset_name, *args):\n    \"\"\"\n    Create a 3 class SST dataset\n    \"\"\"\n    convert_sst_general(paths, dataset_name, \"threeclass\")\n\ndef convert_sst3roots(paths, dataset_name, *args):\n    \"\"\"\n    Create a 3 class SST dataset using only the roots\n    \"\"\"\n    convert_sst_general(paths, dataset_name, \"threeclassroot\")\n\ndef convert_sstplus(paths, dataset_name, *args):\n    \"\"\"\n    Create a 3 class SST dataset with a few other small datasets added\n    \"\"\"\n    train_phrases = []\n    in_directory = paths['SENTIMENT_BASE']\n    train_phrases.extend(process_arguana_xml.get_tokenized_phrases(os.path.join(in_directory, \"arguana\")))\n    train_phrases.extend(process_MELD.get_tokenized_phrases(\"train\", os.path.join(in_directory, 
\"MELD\")))\n    train_phrases.extend(process_slsd.get_tokenized_phrases(os.path.join(in_directory, \"slsd\")))\n    train_phrases.extend(process_airline.get_tokenized_phrases(os.path.join(in_directory, \"airline\")))\n\n    sst_dir = os.path.join(in_directory, \"sentiment-treebank\")\n    train_phrases.extend(process_sst.get_phrases(\"threeclass\", \"train.txt\", sst_dir))\n    train_phrases.extend(process_sst.get_phrases(\"threeclass\", \"extra-train.txt\", sst_dir))\n    train_phrases.extend(process_sst.get_phrases(\"threeclass\", \"checked-extra-train.txt\", sst_dir))\n\n    dev_phrases = process_sst.get_phrases(\"threeclass\", \"dev.txt\", sst_dir)\n    test_phrases = process_sst.get_phrases(\"threeclass\", \"test.txt\", sst_dir)\n\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    dataset = [train_phrases, dev_phrases, test_phrases]\n    process_utils.write_dataset(dataset, out_directory, dataset_name)\n\ndef convert_meld(paths, dataset_name, *args):\n    \"\"\"\n    Convert the MELD dataset to train/dev/test files\n    \"\"\"\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"MELD\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    process_MELD.main(in_directory, out_directory, dataset_name)\n\ndef convert_corona(paths, dataset_name, *args):\n    \"\"\"\n    Convert the kaggle covid dataset to train/dev/test files\n    \"\"\"\n    process_corona.main(*args)\n\ndef convert_scare(paths, dataset_name, *args):\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"german\", \"scare\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    process_scare.main(in_directory, out_directory, dataset_name)\n\n\ndef convert_de_usage(paths, dataset_name, *args):\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"USAGE\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    process_usage_german.main(in_directory, out_directory, dataset_name)\n\ndef convert_sb10k(paths, dataset_name, *args):\n    \"\"\"\n    Essentially runs the sb10k script 
twice with different arguments to produce the de_sb10k dataset\n\n    stanza.utils.datasets.sentiment.process_sb10k --csv_filename extern_data/sentiment/german/sb-10k/de_full/de_test.tsv --out_dir $SENTIMENT_DATA_DIR --short_name de_sb10k --split test --sentiment_column 2 --text_column 3\n    stanza.utils.datasets.sentiment.process_sb10k --csv_filename extern_data/sentiment/german/sb-10k/de_full/de_train.tsv --out_dir $SENTIMENT_DATA_DIR --short_name de_sb10k --split train_dev --sentiment_column 2 --text_column 3\n    \"\"\"\n    column_args = [\"--sentiment_column\", \"2\", \"--text_column\", \"3\"]\n\n    process_sb10k.main([\"--csv_filename\", os.path.join(paths['SENTIMENT_BASE'], \"german\", \"sb-10k\", \"de_full\", \"de_test.tsv\"),\n                        \"--out_dir\", paths['SENTIMENT_DATA_DIR'],\n                        \"--short_name\", dataset_name,\n                        \"--split\", \"test\",\n                        *column_args])\n    process_sb10k.main([\"--csv_filename\", os.path.join(paths['SENTIMENT_BASE'], \"german\", \"sb-10k\", \"de_full\", \"de_train.tsv\"),\n                        \"--out_dir\", paths['SENTIMENT_DATA_DIR'],\n                        \"--short_name\", dataset_name,\n                        \"--split\", \"train_dev\",\n                        *column_args])\n\ndef convert_vi_vsfc(paths, dataset_name, *args):\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"vietnamese\", \"_UIT-VSFC\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    process_vsfc_vietnamese.main(in_directory, out_directory, dataset_name)\n\ndef convert_hi_odiagenai(paths, dataset_name, *args):\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    os.makedirs(out_directory, exist_ok=True)\n\n    import datasets\n    ds = datasets.load_dataset(\"OdiaGenAI/sentiment_analysis_hindi\")\n\n    nlp = stanza.Pipeline(\"hi\", processors='tokenize')\n    mapping = {\"pos\": 2, \"neu\": 1, \"neg\": 0}\n\n    train = []\n    dev = []\n    test = []\n    
for datum in tqdm(ds['train']):\n        random_slice = random.randint(0, 9)\n        if random_slice == 0:\n            random_slice = dev\n        elif random_slice == 1:\n            random_slice = test\n        else:\n            random_slice = train\n        datum = process_utils.process_datum(nlp, datum['text'], mapping, datum['label'])\n        random_slice.append(datum)\n\n    dataset = [train, dev, test]\n    process_utils.write_dataset(dataset, out_directory, dataset_name)\n\ndef convert_mr_l3cube(paths, dataset_name, *args):\n    # csv_filename = 'extern_data/sentiment/MarathiNLP/L3CubeMahaSent Dataset/tweets-train.csv'\n    MAPPING = {\"-1\": \"0\", \"0\": \"1\", \"1\": \"2\"}\n\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    os.makedirs(out_directory, exist_ok=True)\n\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"MarathiNLP\", \"L3CubeMahaSent Dataset\")\n    input_files = ['tweets-train.csv', 'tweets-valid.csv', 'tweets-test.csv']\n    input_files = [os.path.join(in_directory, x) for x in input_files]\n    datasets = [process_utils.read_snippets(csv_filename, sentiment_column=1, text_column=0, tokenizer_language=\"mr\", mapping=MAPPING, delimiter=',', quotechar='\"', skip_first_line=True)\n                for csv_filename in input_files]\n\n    process_utils.write_dataset(datasets, out_directory, dataset_name)\n\ndef convert_es_tass2020(paths, dataset_name, *args):\n    process_es_tass2020.convert_tass2020(paths['SENTIMENT_BASE'], paths['SENTIMENT_DATA_DIR'], dataset_name)\n\ndef convert_it_sentipolc16(paths, dataset_name, *args):\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"italian\", \"sentipolc16\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    process_it_sentipolc16.main(in_directory, out_directory, dataset_name, *args)\n\n\ndef convert_ren(paths, dataset_name, *args):\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"chinese\", \"RenCECps\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    
process_ren_chinese.main(in_directory, out_directory, dataset_name)\n\nDATASET_MAPPING = {\n    \"de_sb10k\":     convert_sb10k,\n    \"de_scare\":     convert_scare,\n    \"de_usage\":     convert_de_usage,\n\n    \"en_corona\":    convert_corona,\n    \"en_sst2\":      convert_sst2,\n    \"en_sst2roots\": convert_sst2roots,\n    \"en_sst3\":      convert_sst3,\n    \"en_sst3roots\": convert_sst3roots,\n    \"en_sstplus\":   convert_sstplus,\n    \"en_meld\":      convert_meld,\n\n    \"es_tass2020\":  convert_es_tass2020,\n\n    \"hi_odiagenai\": convert_hi_odiagenai,\n\n    \"it_sentipolc16\": convert_it_sentipolc16,\n\n    \"mr_l3cube\":    convert_mr_l3cube,\n\n    \"vi_vsfc\":      convert_vi_vsfc,\n\n    \"zh-hans_ren\":  convert_ren,\n}\n\ndef main(dataset_name, *args):\n    paths = default_paths.get_default_paths()\n\n    random.seed(1234)\n\n    if dataset_name in DATASET_MAPPING:\n        DATASET_MAPPING[dataset_name](paths, dataset_name, *args)\n    else:\n        raise ValueError(f\"dataset {dataset_name} currently not handled\")\n\nif __name__ == '__main__':\n    main(sys.argv[1], sys.argv[2:])\n\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_MELD.py",
    "content": "\"\"\"\nMELD is a dataset of Friends (the TV show) utterances.  \n\nThe ratings include judgment based on the visuals, so it might be\nharder than expected to directly extract from the text.  However, it\nshould broaden the scope of the model and doesn't seem to hurt\nperformance.\n\nhttps://github.com/SenticNet/MELD/tree/master/data/MELD\n\nhttps://github.com/SenticNet/MELD\n\nhttps://arxiv.org/pdf/1810.02508.pdf\n\nFiles in the MELD repo are csv, with quotes in \"...\" if they contained commas themselves.\n\nAccordingly, we use the csv module to read the files and output them in the format\n<class> <sentence>\n\nRun using \n\npython3 convert_MELD.py MELD/train_sent_emo.csv train.txt\netc\n\n\"\"\"\n\nimport csv\nimport os\nimport sys\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\ndef get_phrases(in_filename):\n    \"\"\"\n    Get the phrases from a single CSV filename\n    \"\"\"\n    with open(in_filename, newline='', encoding='windows-1252') as fin:\n        cin = csv.reader(fin, delimiter=',', quotechar='\"')\n        lines = list(cin)\n\n    phrases = []\n    for line in lines[1:]:\n        sentiment = line[4]\n        if sentiment == 'negative':\n            sentiment = '0'\n        elif sentiment == 'neutral':\n            sentiment = '1'\n        elif sentiment == 'positive':\n            sentiment = '2'\n        else:\n            raise ValueError(\"Unknown sentiment: {}\".format(sentiment))\n        utterance = line[1].replace(\"Â\", \"\")\n        phrases.append(SentimentDatum(sentiment, utterance))\n    return phrases\n\ndef get_tokenized_phrases(split, in_directory):\n    \"\"\"\n    split in train,dev,test\n    \"\"\"\n    in_filename  = os.path.join(in_directory, \"%s_sent_emo.csv\" % split)\n    phrases = get_phrases(in_filename)\n\n    phrases = process_utils.get_ptb_tokenized_phrases(phrases)\n    print(\"Found {} phrases in MELD 
{}\".format(len(phrases), split))\n    return phrases\n\ndef main(in_directory, out_directory, short_name):\n    os.makedirs(out_directory, exist_ok=True)\n    for split in (\"train\", \"dev\", \"test\"):\n        phrases = get_tokenized_phrases(split, in_directory)\n        process_utils.write_list(os.path.join(out_directory, \"%s.%s.json\" % (short_name, split)), phrases)\n\nif __name__ == '__main__':\n    in_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n\n    main(in_directory, out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_airline.py",
    "content": "\"\"\"\nAirline tweets from Kaggle\nfrom https://www.kaggle.com/crowdflower/twitter-airline-sentiment/data#\nSome ratings seem questionable, but it doesn't hurt performance much, if at all\n\nFiles in the airline repo are csv, with quotes in \"...\" if they contained commas themselves.\n\nAccordingly, we use the csv module to read the files and output them in the format\n<class> <sentence>\n\nRun using \n\npython3 convert_airline.py Tweets.csv train.json\n\nIf the first word is an @, it is removed, and after that, leading @ or # are removed.\nFor example:\n\n@AngledLuffa you must hate having Mox Opal #banned\n-> \nyou must hate having Mox Opal banned\n\"\"\"\n\nimport csv\nimport os\nimport sys\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\ndef get_phrases(in_directory):\n    in_filename = os.path.join(in_directory, \"Tweets.csv\")\n    with open(in_filename, newline='') as fin:\n        cin = csv.reader(fin, delimiter=',', quotechar='\"')\n        lines = list(cin)\n\n    phrases = []\n    for line in lines[1:]:\n        sentiment = line[1]\n        if sentiment == 'negative':\n            sentiment = '0'\n        elif sentiment == 'neutral':\n            sentiment = '1'\n        elif sentiment == 'positive':\n            sentiment = '2'\n        else:\n            raise ValueError(\"Unknown sentiment: {}\".format(sentiment))\n        # some of the tweets have \\n in them\n        utterance = line[10].replace(\"\\n\", \" \")\n        phrases.append(SentimentDatum(sentiment, utterance))\n\n    return phrases\n\ndef get_tokenized_phrases(in_directory):\n    phrases = get_phrases(in_directory)\n    phrases = process_utils.get_ptb_tokenized_phrases(phrases)\n    phrases = [SentimentDatum(x.sentiment, process_utils.clean_tokenized_tweet(x.text)) for x in phrases]\n    print(\"Found {} phrases in the airline corpus\".format(len(phrases)))\n    return phrases\n\ndef 
main(in_directory, out_directory, short_name):\n    # leading @United, @American, etc are filtered from the tweets\n    # by clean_tokenized_tweet, called from get_tokenized_phrases\n    phrases = get_tokenized_phrases(in_directory)\n\n    os.makedirs(out_directory, exist_ok=True)\n    out_filename = os.path.join(out_directory, \"%s.train.json\" % short_name)\n    process_utils.write_list(out_filename, phrases)\n\n    # something like this would count @s if you cared enough to count\n    # would need to update for SentimentDatum()\n    #ats = Counter()\n    #for line in lines:\n    #    ats.update([x for x in line.split() if x[0] == '@'])\n\nif __name__ == '__main__':\n    in_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n    main(in_directory, out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_arguana_xml.py",
    "content": "from collections import namedtuple\nimport glob\nimport os\nimport sys\nimport xml.etree.ElementTree as ET\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\nArguanaSentimentDatum = namedtuple('ArguanaSentimentDatum', ['begin', 'end', 'rating'])\n\n\"\"\"\nExtracts positive, neutral, and negative phrases from the ArguAna hotel review corpus\n\nRun as follows:\n\npython3 parse_arguana_xml.py split/training data/sentiment\n\nArguAna can be downloaded here:\n\nhttp://argumentation.bplaced.net/arguana/data\nhttp://argumentation.bplaced.net/arguana-data/arguana-tripadvisor-annotated-v2.zip\n\"\"\"\n\ndef get_phrases(filename):\n    tree = ET.parse(filename)\n    fragments = []\n\n    root = tree.getroot()\n    body = None\n    for child in root:\n        if child.tag == '{http:///uima/cas.ecore}Sofa':\n            body = child.attrib['sofaString']\n        elif child.tag == '{http:///de/aitools/ie/uima/type/arguana.ecore}Fact':\n            fragments.append(ArguanaSentimentDatum(begin=int(child.attrib['begin']),\n                                                   end=int(child.attrib['end']),\n                                                   rating=\"1\"))\n        elif child.tag == '{http:///de/aitools/ie/uima/type/arguana.ecore}Opinion':\n            if child.attrib['polarity'] == 'negative':\n                rating = \"0\"\n            elif child.attrib['polarity'] == 'positive':\n                rating = \"2\"\n            else:\n                raise ValueError(\"Unexpected polarity found in {}\".format(filename))\n            fragments.append(ArguanaSentimentDatum(begin=int(child.attrib['begin']),\n                                                   end=int(child.attrib['end']),\n                                                   rating=rating))\n\n\n    phrases = [SentimentDatum(fragment.rating, body[fragment.begin:fragment.end]) for fragment in fragments]\n    
#phrases = [phrase.replace(\"\\n\", \" \") for phrase in phrases]\n    return phrases\n\ndef get_phrases_from_directory(directory):\n    phrases = []\n    inpath = os.path.join(directory, \"arguana-tripadvisor-annotated-v2\", \"split\", \"training\", \"*\", \"*xmi\")\n    for filename in glob.glob(inpath):\n        phrases.extend(get_phrases(filename))\n    return phrases\n\ndef get_tokenized_phrases(in_directory):\n    phrases = get_phrases_from_directory(in_directory)\n    phrases = process_utils.get_ptb_tokenized_phrases(phrases)\n    print(\"Found {} phrases in arguana\".format(len(phrases)))\n    return phrases\n\ndef main(in_directory, out_directory, short_name):\n    phrases = get_tokenized_phrases(in_directory)\n    # create the output directory if needed, as the other sentiment scripts do\n    os.makedirs(out_directory, exist_ok=True)\n    process_utils.write_list(os.path.join(out_directory, \"%s.train.json\" % short_name), phrases)\n\n\nif __name__ == \"__main__\":\n    in_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n    main(in_directory, out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_corona.py",
    "content": "\"\"\"\nProcesses a kaggle covid-19 text classification dataset\n\nThe original description of the dataset is here:\n\nhttps://www.kaggle.com/datasets/datatattle/covid-19-nlp-text-classification\n\nThere are two files in the archive, Corona_NLP_train.csv and Corona_NLP_test.csv\nUnzip the files in archive.zip to $SENTIMENT_BASE/english/corona/Corona_NLP_train.csv\n\nThere is no dedicated dev set, so we randomly split train/dev\n(using a specific seed, so that the split always comes out the same)\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nimport stanza\n\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\nfrom stanza.utils.default_paths import get_default_paths\n\n# TODO: could give an option to keep the 'extremely'\nMAPPING = {'extremely positive': \"2\",\n           'positive': \"2\",\n           'neutral': \"1\",\n           'negative': \"0\",\n           'extremely negative': \"0\"}\n\ndef main(args=None):\n    default_paths = get_default_paths()\n    sentiment_base_dir = default_paths[\"SENTIMENT_BASE\"]\n    default_in_dir = os.path.join(sentiment_base_dir, \"english\", \"corona\")\n    default_out_dir = default_paths[\"SENTIMENT_DATA_DIR\"]\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--in_dir', type=str, default=default_in_dir, help='Where to get the input files')\n    parser.add_argument('--out_dir', type=str, default=default_out_dir, help='Where to write the output files')\n    parser.add_argument('--short_name', type=str, default=\"en_corona\", help='short name to use when writing files')\n    args = parser.parse_args(args=args)\n\n    TEXT_COLUMN = 4\n    SENTIMENT_COLUMN = 5\n\n    train_csv = os.path.join(args.in_dir, \"Corona_NLP_train.csv\")\n    test_csv = os.path.join(args.in_dir, \"Corona_NLP_test.csv\")\n\n    nlp = stanza.Pipeline(\"en\", processors='tokenize')\n\n    train_snippets = process_utils.read_snippets(train_csv, SENTIMENT_COLUMN, TEXT_COLUMN, 'en', MAPPING, 
delimiter=\",\", quotechar='\"', skip_first_line=True, nlp=nlp, encoding=\"latin1\")\n    test_snippets = process_utils.read_snippets(test_csv, SENTIMENT_COLUMN, TEXT_COLUMN, 'en', MAPPING, delimiter=\",\", quotechar='\"', skip_first_line=True, nlp=nlp, encoding=\"latin1\")\n\n    print(\"Read %d train snippets\" % len(train_snippets))\n    print(\"Read %d test snippets\" % len(test_snippets))\n\n    random.seed(1234)\n    random.shuffle(train_snippets)\n\n    os.makedirs(args.out_dir, exist_ok=True)\n    process_utils.write_splits(args.out_dir,\n                               train_snippets,\n                               (process_utils.Split(\"%s.train.json\" % args.short_name, 0.9),\n                                process_utils.Split(\"%s.dev.json\" % args.short_name, 0.1)))\n    process_utils.write_list(os.path.join(args.out_dir, \"%s.test.json\" % args.short_name), test_snippets)\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_es_tass2020.py",
    "content": "\"\"\"\nConvert the TASS 2020 dataset, available here: http://tass.sepln.org/2020/?page_id=74\n\nThere are two parts to the dataset, but only part 1 has the gold\nannotations available.\n\nDownload:\nTask 1 train & dev sets\nTask 1.1 test set\nTask 1.2 test set\nTask 1.1 test set gold standard\nTask 1.2 test set gold standard   (.tsv, not .zip)\n\nNo need to unzip any of the files.  The extraction script reads the\nexpected paths directly from the zip files.\n\nThere are two subtasks in TASS 2020.  One is split among 5 Spanish\nspeaking countries, and the other is combined across all of the\ncountries.  Here we combine all of the data into one output file.\n\nAlso, each of the subparts are output into their own files, such as\np2.json, p1.mx.json, etc\n\"\"\"\n\nimport os\nimport zipfile\n\nimport stanza\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.default_paths as default_paths\nfrom stanza.utils.datasets.sentiment.process_utils import write_dataset, write_list\n\ndef convert_label(label):\n    \"\"\"\n    N/NEU/P or error\n    \"\"\"\n    if label == \"N\":\n        return 0\n    if label == \"NEU\":\n        return 1\n    if label == \"P\":\n        return 2\n    raise ValueError(\"Unexpected label %s\" % label)\n\ndef read_test_labels(fin):\n    \"\"\"\n    Read a tab (or space) separated list of id/label pairs\n    \"\"\"\n    label_map = {}\n    for line_idx, line in enumerate(fin):\n        if isinstance(line, bytes):\n            line = line.decode(\"utf-8\")\n        pieces = line.split()\n        if len(pieces) < 2:\n            continue\n        if len(pieces) > 2:\n            raise ValueError(\"Unexpected format at line %d: all label lines should be len==2\\n%s\" % (line_idx, line))\n\n        datum_id, label = pieces\n        try:\n            label = convert_label(label)\n        except ValueError:\n            raise ValueError(\"Unexpected test label %s at line %d\\n%s\" % (label, line_idx, 
line))\n\n        label_map[datum_id] = label\n    return label_map\n\ndef open_read_test_labels(filename, zip_filename=None):\n    \"\"\"\n    Open either a text or zip file, then read the labels\n    \"\"\"\n    if zip_filename is None:\n        with open(filename, encoding=\"utf-8\") as fin:\n            test_labels = read_test_labels(fin)\n            print(\"Read %d lines from %s\" % (len(test_labels), filename))\n            return test_labels\n\n    with zipfile.ZipFile(zip_filename) as zin:\n        with zin.open(filename) as fin:\n            test_labels = read_test_labels(fin)\n            print(\"Read %d lines from %s - %s\" % (len(test_labels), zip_filename, filename))\n            return test_labels\n\n\ndef read_sentences(fin):\n    \"\"\"\n    Read ids and text from the given file\n    \"\"\"\n    lines = []\n    for line_idx, line in enumerate(fin):\n        line = line.decode(\"utf-8\")\n        pieces = line.split(maxsplit=1)\n        if len(pieces) < 2:\n            continue\n        lines.append(pieces)\n    return lines\n\ndef open_read_sentences(filename, zip_filename):\n    \"\"\"\n    Opens a file and then reads the sentences\n\n    Only applies to files inside zips, as all of the sentence files in\n    this dataset are inside a zip\n    \"\"\"\n    with zipfile.ZipFile(zip_filename) as zin:\n        with zin.open(filename) as fin:\n            test_sentences = read_sentences(fin)\n            print(\"Read %d texts from %s - %s\" % (len(test_sentences), zip_filename, filename))\n\n    return test_sentences\n\ndef combine_test_set(sentences, labels):\n    \"\"\"\n    Combines the labels and sentences from two pieces of the test set\n\n    Matches the ID from the label files and the text files\n    \"\"\"\n    combined = []\n    if len(sentences) != len(labels):\n        raise ValueError(\"Lengths of sentences and labels should match!\")\n    for sent_id, text in sentences:\n        label = labels.get(sent_id, None)\n        if label is 
None:\n            raise KeyError(\"Cannot find a test label from the ID: %s\" % sent_id)\n        # not tokenized yet - we can do tokenization in batches\n        combined.append(SentimentDatum(label, text))\n    return combined\n\nDATASET_PIECES = (\"cr\", \"es\", \"mx\", \"pe\", \"uy\")\n\ndef tokenize(sentiment_data, pipe):\n    \"\"\"\n    Takes a list of SentimentDatum with untokenized text and returns a list of SentimentDatum with tokenized text\n\n    Only the first 'sentence' is used - ideally the pipe has ssplit turned off\n    \"\"\"\n    docs = [x.text for x in sentiment_data]\n    in_docs = [stanza.Document([], text=d) for d in docs]\n    out_docs = pipe(in_docs)\n\n    sentiment_data = [SentimentDatum(datum.sentiment,\n                                     [y.text for y in doc.sentences[0].tokens]) # list of text tokens for each doc\n                      for datum, doc in zip(sentiment_data, out_docs)]\n\n    return sentiment_data\n\ndef read_test_set(label_zip_filename, label_filename, sentence_zip_filename, sentence_filename, pipe):\n    \"\"\"\n    Read and tokenize an entire test set given the label and sentence filenames\n    \"\"\"\n    test_labels = open_read_test_labels(label_filename, label_zip_filename)\n    test_sentences = open_read_sentences(sentence_filename, sentence_zip_filename)\n    sentiment_data = combine_test_set(test_sentences, test_labels)\n    return tokenize(sentiment_data, pipe)\n\ndef read_train_file(zip_filename, filename, pipe):\n    \"\"\"\n    Read and tokenize a train set\n\n    All of the train data is inside one zip.  
We read it one piece at a time\n    \"\"\"\n    sentiment_data = []\n    with zipfile.ZipFile(zip_filename) as zin:\n        with zin.open(filename) as fin:\n            for line_idx, line in enumerate(fin):\n                if isinstance(line, bytes):\n                    line = line.decode(\"utf-8\")\n                pieces = line.split(maxsplit=1)\n                if len(pieces) < 2:\n                    continue\n                pieces = pieces[1].rsplit(maxsplit=1)\n                if len(pieces) < 2:\n                    continue\n                text, label = pieces\n                try:\n                    label = convert_label(label)\n                except ValueError:\n                    raise ValueError(\"Unexpected train label %s at line %d\\n%s\" % (label, line_idx, line))\n                sentiment_data.append(SentimentDatum(label, text))\n\n    print(\"Read %d texts from %s - %s\" % (len(sentiment_data), zip_filename, filename))\n    sentiment_data = tokenize(sentiment_data, pipe)\n    return sentiment_data\n\ndef convert_tass2020(in_directory, out_directory, dataset_name):\n    \"\"\"\n    Read all of the data from in_directory/spanish/tass2020, write it to out_directory/dataset_name...\n    \"\"\"\n    in_directory = os.path.join(in_directory, \"spanish\", \"tass2020\")\n\n    pipe = stanza.Pipeline(lang=\"es\", processors=\"tokenize\", tokenize_no_ssplit=True)\n\n    test_11 = {}\n    test_11_labels_zip = os.path.join(in_directory, \"tass2020-test-gold.zip\")\n    test_11_sentences_zip = os.path.join(in_directory, \"Test1.1.zip\")\n    for piece in DATASET_PIECES:\n        inner_label_filename = piece + \".tsv\"\n        inner_sentence_filename = os.path.join(\"Test1.1\", piece.upper() + \".tsv\")\n        test_11[piece] = read_test_set(test_11_labels_zip, inner_label_filename,\n                                       test_11_sentences_zip, inner_sentence_filename, pipe)\n\n    test_12_label_filename = os.path.join(in_directory, 
\"task1.2-test-gold.tsv\")\n    test_12_sentences_zip = os.path.join(in_directory, \"test1.2.zip\")\n    test_12_sentences_filename = \"test1.2/task1.2.tsv\"\n    test_12 = read_test_set(None, test_12_label_filename,\n                            test_12_sentences_zip, test_12_sentences_filename, pipe)\n\n    train_dev_zip = os.path.join(in_directory, \"Task1-train-dev.zip\")\n    dev = {}\n    train = {}\n    for piece in DATASET_PIECES:\n        dev_filename = os.path.join(\"dev\", piece + \".tsv\")\n        dev[piece] = read_train_file(train_dev_zip, dev_filename, pipe)\n\n    for piece in DATASET_PIECES:\n        train_filename = os.path.join(\"train\", piece + \".tsv\")\n        train[piece] = read_train_file(train_dev_zip, train_filename, pipe)\n\n    all_test = test_12 + [item for piece in test_11.values() for item in piece]\n    all_dev = [item for piece in dev.values() for item in piece]\n    all_train = [item for piece in train.values() for item in piece]\n\n    print(\"Total train items: %8d\" % len(all_train))\n    print(\"Total dev items:   %8d\" % len(all_dev))\n    print(\"Total test items:  %8d\" % len(all_test))\n\n    write_dataset((all_train, all_dev, all_test), out_directory, dataset_name)\n\n    output_file = os.path.join(out_directory, \"%s.test.p2.json\" % dataset_name)\n    write_list(output_file, test_12)\n\n    for piece in DATASET_PIECES:\n        output_file = os.path.join(out_directory, \"%s.test.p1.%s.json\" % (dataset_name, piece))\n        write_list(output_file, test_11[piece])\n\ndef main(paths):\n    in_directory = paths['SENTIMENT_BASE']\n    out_directory = paths['SENTIMENT_DATA_DIR']\n\n    convert_tass2020(in_directory, out_directory, \"es_tass2020\")\n\n\nif __name__ == '__main__':\n    paths = default_paths.get_default_paths()\n    main(paths)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_it_sentipolc16.py",
    "content": "\"\"\"\nProcess the SentiPolc dataset from Evalita\n\nCan be run as a standalone script or as a module from\nprepare_sentiment_dataset\n\nAn option controls how to split up the positive/negative/neutral/mixed classes\n\"\"\"\n\nimport argparse\nfrom enum import Enum\nimport os\nimport random\nimport sys\n\nimport stanza\nfrom stanza.utils.datasets.sentiment import process_utils\nimport stanza.utils.default_paths as default_paths\n\nclass Mode(Enum):\n    COMBINED = 1\n    SEPARATE = 2\n    POSITIVE = 3\n    NEGATIVE = 4\n\ndef main(in_dir, out_dir, short_name, *args):\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--mode', default=Mode.COMBINED, type=lambda x: Mode[x.upper()],\n                        help='How to handle mixed vs neutral.  {}'.format(\", \".join(x.name for x in Mode)))\n    parser.add_argument('--name', default=None, type=str,\n                        help='Use a different name to save the dataset.  Useful for keeping POSITIVE & NEGATIVE separate')\n    args = parser.parse_args(args=list(*args))\n\n    if args.name is not None:\n        short_name = args.name\n\n    nlp = stanza.Pipeline(\"it\", processors='tokenize')\n\n    if args.mode == Mode.COMBINED:\n        mapping = {\n            ('0', '0'): \"1\", # neither negative nor positive: neutral\n            ('1', '0'): \"2\", # positive, not negative: positive\n            ('0', '1'): \"0\", # negative, not positive: negative\n            ('1', '1'): \"1\", # mixed combined with neutral\n        }\n    elif args.mode == Mode.SEPARATE:\n        mapping = {\n            ('0', '0'): \"1\", # neither negative nor positive: neutral\n            ('1', '0'): \"2\", # positive, not negative: positive\n            ('0', '1'): \"0\", # negative, not positive: negative\n            ('1', '1'): \"3\", # mixed as a different class\n        }\n    elif args.mode == Mode.POSITIVE:\n        mapping = {\n            ('0', '0'): \"0\", # neutral -> not positive\n            
('1', '0'): \"1\", # positive -> positive\n            ('0', '1'): \"0\", # negative -> not positive\n            ('1', '1'): \"1\", # mixed -> positive\n        }\n    elif args.mode == Mode.NEGATIVE:\n        mapping = {\n            ('0', '0'): \"0\", # neutral -> not negative\n            ('1', '0'): \"0\", # positive -> not negative\n            ('0', '1'): \"1\", # negative -> negative\n            ('1', '1'): \"1\", # mixed -> negative\n        }\n\n    print(\"Using {} scheme to handle the 4 values.  Mapping: {}\".format(args.mode, mapping))\n    print(\"Saving to {} using the short name {}\".format(out_dir, short_name))\n\n    test_filename = os.path.join(in_dir, \"test_set_sentipolc16_gold2000.csv\")\n    test_snippets = process_utils.read_snippets(test_filename, (2,3), 8, \"it\", mapping, delimiter=\",\", skip_first_line=False, quotechar='\"', nlp=nlp)\n\n    train_filename = os.path.join(in_dir, \"training_set_sentipolc16.csv\")\n    train_snippets = process_utils.read_snippets(train_filename, (2,3), 8, \"it\", mapping, delimiter=\",\", skip_first_line=True, quotechar='\"', nlp=nlp)\n\n    random.shuffle(train_snippets)\n    dev_len = len(train_snippets) // 10\n    dev_snippets = train_snippets[:dev_len]\n    train_snippets = train_snippets[dev_len:]\n\n    dataset = (train_snippets, dev_snippets, test_snippets)\n\n    process_utils.write_dataset(dataset, out_dir, short_name)\n\nif __name__ == '__main__':\n    paths = default_paths.get_default_paths()\n    random.seed(1234)\n\n    in_directory = os.path.join(paths['SENTIMENT_BASE'], \"italian\", \"sentipolc16\")\n    out_directory = paths['SENTIMENT_DATA_DIR']\n    main(in_directory, out_directory, \"it_sentipolc16\", sys.argv[1:])\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_ren_chinese.py",
    "content": "import glob\nimport os\nimport random\nimport sys\n\nimport xml.etree.ElementTree as ET\n\nfrom collections import namedtuple\n\nimport stanza\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\n\"\"\"\nThis processes a Chinese corpus, hosted here:\n\nhttp://a1-www.is.tokushima-u.ac.jp/member/ren/Ren-CECps1.0/Ren-CECps1.0.html\n\nThe authors want a signed document saying you won't redistribute the corpus.\n\nThe corpus format is a bunch of .xml files, with sentences labeled with various emotions and an overall polarity.  Polarity is labeled as follows:\n\n消极: negative\n中性: neutral\n积极: positive\n\"\"\"\n\ndef get_phrases(filename):\n    tree = ET.parse(filename)\n    fragments = []\n\n    root = tree.getroot()\n    for child in root:\n        if child.tag == 'paragraph':\n            for subchild in child:\n                if subchild.tag == 'sentence':\n                    text = subchild.attrib['S'].strip()\n                    if len(text) <= 2:\n                        continue\n                    polarity = None\n                    for inner in subchild:\n                        if inner.tag == 'Polarity':\n                            polarity = inner\n                            break\n                    if polarity is None:\n                        print(\"Found sentence with no polarity in {}: {}\".format(filename, text))\n                        continue\n                    if polarity.text == '消极':\n                        sentiment = \"0\"\n                    elif polarity.text == '中性':\n                        sentiment = \"1\"\n                    elif polarity.text == '积极':\n                        sentiment = \"2\"\n                    else:\n                        raise ValueError(\"Unknown polarity {} in {}\".format(polarity.text, filename))\n                    fragments.append(SentimentDatum(sentiment, text))\n\n    return fragments\n\ndef 
read_snippets(xml_directory):\n    data = []\n    for filename in glob.glob(xml_directory + '/xml/cet_*xml'):\n        data.extend(get_phrases(filename))\n\n    nlp = stanza.Pipeline('zh', processors='tokenize')\n    snippets = []\n    # datum is the labeled raw text; sentence is reserved for the tokenizer output\n    for datum in data:\n        doc = nlp(datum.text)\n        text = [token.text for sentence in doc.sentences for token in sentence.tokens]\n        snippets.append(SentimentDatum(datum.sentiment, text))\n    random.shuffle(snippets)\n    return snippets\n\ndef main(xml_directory, out_directory, short_name):\n    snippets = read_snippets(xml_directory)\n\n    print(\"Found {} phrases\".format(len(snippets)))\n    os.makedirs(out_directory, exist_ok=True)\n    process_utils.write_splits(out_directory,\n                               snippets,\n                               (process_utils.Split(\"%s.train.json\" % short_name, 0.8),\n                                process_utils.Split(\"%s.dev.json\" % short_name, 0.1),\n                                process_utils.Split(\"%s.test.json\" % short_name, 0.1)))\n\n\nif __name__ == \"__main__\":\n    random.seed(1234)\n    xml_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n    main(xml_directory, out_directory, short_name)\n\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_sb10k.py",
    "content": "\"\"\"\nProcesses the SB10k dataset\n\nThe original description of the dataset and corpus_v1.0.tsv is here:\n\nhttps://www.spinningbytes.com/resources/germansentiment/\n\nDownload script is here:\n\nhttps://github.com/aritter/twitter_download\n\nThe problem with this file is that many of the tweets with labels no\nlonger exist.  Roughly 1/3 as of June 2020.\n\nYou can contact the authors for the complete dataset.\n\nThere is a paper describing some experiments run on the dataset here:\nhttps://dl.acm.org/doi/pdf/10.1145/3038912.3052611\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nfrom enum import Enum\n\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\nclass Split(Enum):\n    TRAIN_DEV_TEST = 1\n    TRAIN_DEV = 2\n    TEST = 3\n\nMAPPING = {'positive': \"2\",\n           'neutral': \"1\",\n           'negative': \"0\"}\n\ndef main(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--csv_filename', type=str, default=None, help='CSV file to read in')\n    parser.add_argument('--out_dir', type=str, default=None, help='Where to write the output files')\n    parser.add_argument('--sentiment_column', type=int, default=2, help='Column with the sentiment')\n    parser.add_argument('--text_column', type=int, default=3, help='Column with the text')\n    parser.add_argument('--short_name', type=str, default=\"sb10k\", help='short name to use when writing files')\n\n    parser.add_argument('--split', type=lambda x: Split[x.upper()], default=Split.TRAIN_DEV_TEST,\n                        help=\"How to split the resulting data\")\n\n    args = parser.parse_args(args=args)\n\n    snippets = process_utils.read_snippets(args.csv_filename, args.sentiment_column, args.text_column, 'de', MAPPING)\n\n    print(len(snippets))\n    random.shuffle(snippets)\n\n    os.makedirs(args.out_dir, exist_ok=True)\n    if args.split is Split.TRAIN_DEV_TEST:\n        process_utils.write_splits(args.out_dir,\n               
                    snippets,\n                                   (process_utils.Split(\"%s.train.json\" % args.short_name, 0.8),\n                                    process_utils.Split(\"%s.dev.json\" % args.short_name, 0.1),\n                                    process_utils.Split(\"%s.test.json\" % args.short_name, 0.1)))\n    elif args.split is Split.TRAIN_DEV:\n        process_utils.write_splits(args.out_dir,\n                                   snippets,\n                                   (process_utils.Split(\"%s.train.json\" % args.short_name, 0.9),\n                                    process_utils.Split(\"%s.dev.json\" % args.short_name, 0.1)))\n    elif args.split is Split.TEST:\n        process_utils.write_list(os.path.join(args.out_dir, \"%s.test.json\" % args.short_name), snippets)\n    else:\n        raise ValueError(\"Unknown split method {}\".format(args.split))\n\nif __name__ == '__main__':\n    random.seed(1234)\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_scare.py",
    "content": "\"\"\"\nSCARE is a dataset of German text with sentiment annotations.\n\nhttp://romanklinger.de/scare/\n\nTo run the script, pass in the directory where scare was unpacked.  It\nshould have subdirectories scare_v1.0.0 and scare_v1.0.0_text\n\nYou need to fill out a license agreement to not redistribute the data\nin order to get the data, but the process is not onerous.\n\nAlthough it sounds interesting, there are unfortunately a lot of very\nshort items.  Not sure the long items will be enough\n\"\"\"\n\n\nimport csv\nimport glob\nimport os\nimport sys\n\nimport stanza\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\ndef get_scare_snippets(nlp, csv_dir_path, text_id_map, filename_pattern=\"*.csv\"):\n    \"\"\"\n    Read snippets from the given CSV directory\n    \"\"\"\n    num_short_items = 0\n\n    snippets = []\n    csv_files = glob.glob(os.path.join(csv_dir_path, filename_pattern))\n    for csv_filename in csv_files:\n        with open(csv_filename, newline='') as fin:\n            cin = csv.reader(fin, delimiter='\\t', quotechar='\"')\n            lines = list(cin)\n\n            for line in lines:\n                ann_id, begin, end, sentiment = [line[i] for i in [1, 2, 3, 6]]\n                begin = int(begin)\n                end = int(end)\n                if sentiment.lower() == 'unknown':\n                    continue\n                elif sentiment.lower() == 'positive':\n                    sentiment = 2\n                elif sentiment.lower() == 'neutral':\n                    sentiment = 1\n                elif sentiment.lower() == 'negative':\n                    sentiment = 0\n                else:\n                    raise ValueError(\"Tell John he screwed up and this is why he can't have Mox Opal: {}\".format(sentiment))\n                if ann_id not in text_id_map:\n                    print(\"Found snippet which can't be found: 
{}-{}\".format(csv_filename, ann_id))\n                    continue\n                snippet = text_id_map[ann_id][begin:end]\n                doc = nlp(snippet)\n                text = [token.text for sentence in doc.sentences for token in sentence.tokens]\n                num_tokens = sum(len(sentence.tokens) for sentence in doc.sentences)\n                if num_tokens < 4:\n                    num_short_items = num_short_items + 1\n                snippets.append(SentimentDatum(sentiment, text))\n    print(\"Number of short items: {}\".format(num_short_items))\n    return snippets\n\n\ndef main(in_directory, out_directory, short_name):\n    os.makedirs(out_directory, exist_ok=True)\n\n    input_path = os.path.join(in_directory, \"scare_v1.0.0_text\", \"annotations\", \"*txt\")\n    text_files = glob.glob(input_path)\n    if len(text_files) == 0:\n        raise FileNotFoundError(\"Did not find any input files in %s\" % input_path)\n    else:\n        print(\"Found %d input files in %s\" % (len(text_files), input_path))\n    text_id_map = {}\n    for filename in text_files:\n        with open(filename) as fin:\n            for line in fin.readlines():\n                line = line.strip()\n                if not line:\n                    continue\n                key, value = line.split(maxsplit=1)\n                if key in text_id_map:\n                    raise ValueError(\"Duplicate key {}\".format(key))\n                text_id_map[key] = value\n\n    print(\"Found %d total sentiment ratings\" % len(text_id_map))\n    nlp = stanza.Pipeline('de', processors='tokenize')\n    snippets = get_scare_snippets(nlp, os.path.join(in_directory, \"scare_v1.0.0\", \"annotations\"), text_id_map)\n\n    print(len(snippets))\n    process_utils.write_list(os.path.join(out_directory, \"%s.train.json\" % short_name), snippets)\n\nif __name__ == '__main__':\n    in_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n\n    main(in_directory, 
out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_slsd.py",
    "content": "\"\"\"\nA small dataset of 1500 positive and 1500 negative sentences.\nSupposedly has no neutral sentences by design\n\nhttps://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences\n\nhttps://archive.ics.uci.edu/ml/machine-learning-databases/00331/\n\nSee the existing readme for citation requirements etc\n\nFiles in the slsd repo were one line per annotation, with labels 0\nfor negative and 1 for positive.  No neutral labels existed.\n\nAccordingly, we rearrange the text and adjust the label to fit the\n0/1/2 paradigm.  Text is retokenized using PTBTokenizer.\n\n<class> <sentence>\n\nprocess_slsd.py <directory> <outputfile>\n\"\"\"\n\nimport os\nimport sys\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\ndef get_phrases(in_directory):\n    in_filenames = [os.path.join(in_directory, 'amazon_cells_labelled.txt'),\n                    os.path.join(in_directory, 'imdb_labelled.txt'),\n                    os.path.join(in_directory, 'yelp_labelled.txt')]\n\n    lines = []\n    for filename in in_filenames:\n        lines.extend(open(filename, newline=''))\n\n    phrases = []\n    for line in lines:\n        line = line.strip()\n        sentiment = line[-1]\n        utterance = line[:-1]\n        utterance = utterance.replace(\"!.\", \"!\")\n        utterance = utterance.replace(\"?.\", \"?\")\n        if sentiment == '0':\n            sentiment = '0'\n        elif sentiment == '1':\n            sentiment = '2'\n        else:\n            raise ValueError(\"Unknown sentiment: {}\".format(sentiment))\n        phrases.append(SentimentDatum(sentiment, utterance))\n\n    return phrases\n\ndef get_tokenized_phrases(in_directory):\n    phrases = get_phrases(in_directory)\n    phrases = process_utils.get_ptb_tokenized_phrases(phrases)\n    print(\"Found %d phrases in slsd\" % len(phrases))\n    return phrases\n\ndef main(in_directory, out_directory, short_name):\n    
phrases = get_tokenized_phrases(in_directory)\n    out_filename = os.path.join(out_directory, \"%s.train.json\" % short_name)\n    os.makedirs(out_directory, exist_ok=True)\n    process_utils.write_list(out_filename, phrases)\n\n\nif __name__ == '__main__':\n    in_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n    main(in_directory, out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_sst.py",
    "content": "import argparse\nimport os\nimport subprocess\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\nimport stanza.utils.default_paths as default_paths\n\nTREEBANK_FILES = [\"train.txt\", \"dev.txt\", \"test.txt\", \"extra-train.txt\", \"checked-extra-train.txt\"]\n\nARGUMENTS = {\n    \"fiveclass\":      [],\n    \"root\":           [\"-root_only\"],\n    \"binary\":         [\"-ignore_labels\", \"2\", \"-remap_labels\", \"1=0,2=-1,3=1,4=1\"],\n    \"binaryroot\":     [\"-root_only\", \"-ignore_labels\", \"2\", \"-remap_labels\", \"1=0,2=-1,3=1,4=1\"],\n    \"threeclass\":     [\"-remap_labels\", \"0=0,1=0,2=1,3=2,4=2\"],\n    \"threeclassroot\": [\"-root_only\", \"-remap_labels\", \"0=0,1=0,2=1,3=2,4=2\"],\n}\n\n\ndef get_subtrees(input_file, *args):\n    \"\"\"\n    Use the CoreNLP OutputSubtrees tool to convert the input file to a bunch of phrases\n\n    Returns a list of the SentimentDatum namedtuple\n    \"\"\"\n    # TODO: maybe can convert this to use the python tree?\n    cmd = [\"java\", \"edu.stanford.nlp.trees.OutputSubtrees\", \"-input\", input_file]\n    if len(args) > 0:\n        cmd = cmd + list(args)\n    print (\" \".join(cmd))\n    results = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding=\"utf-8\")\n    lines = results.stdout.split(\"\\n\")\n    lines = [x.strip() for x in lines]\n    lines = [x for x in lines if x]\n    lines = [x.split(maxsplit=1) for x in lines]\n    phrases = [SentimentDatum(x[0], x[1].split()) for x in lines]\n    return phrases\n\ndef get_phrases(dataset, treebank_file, input_dir):\n    extra_args = ARGUMENTS[dataset]\n\n    input_file = os.path.join(input_dir, \"fiveclass\", treebank_file)\n    if not os.path.exists(input_file):\n        raise FileNotFoundError(input_file)\n    phrases = get_subtrees(input_file, *extra_args)\n    print(\"Found {} phrases in SST {} 
{}\".format(len(phrases), treebank_file, dataset))\n    return phrases\n\ndef convert_version(dataset, treebank_file, input_dir, output_dir):\n    \"\"\"\n    Convert the fiveclass files to a specific format\n\n    Uses the ARGUMENTS specific for the format wanted\n    \"\"\"\n    phrases = get_phrases(dataset, treebank_file, input_dir)\n    output_file = os.path.join(output_dir, \"en_sst.%s.%s.json\" % (dataset, treebank_file.split(\".\")[0]))\n    process_utils.write_list(output_file, phrases)\n\ndef parse_args():\n    \"\"\"\n    Actually, the only argument used right now is the formats to convert\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument('sections', type=str, nargs='*', help='Which transformations to use: {}'.format(\" \".join(ARGUMENTS.keys())))\n    args = parser.parse_args()\n    if not args.sections:\n        args.sections = list(ARGUMENTS.keys())\n    return args\n\ndef main():\n    args = parse_args()\n    paths = default_paths.get_default_paths()\n    input_dir = os.path.join(paths[\"SENTIMENT_BASE\"], \"sentiment-treebank\")\n    output_dir = paths[\"SENTIMENT_DATA_DIR\"]\n\n    os.makedirs(output_dir, exist_ok=True)\n    for section in args.sections:\n        for treebank_file in TREEBANK_FILES:\n            convert_version(section, treebank_file, input_dir, output_dir)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_usage_german.py",
    "content": "\"\"\"\nUSAGE is produced by the same people as SCARE.  \n\nUSAGE has a German and English part.  This script parses the German part.\nRun the script as \n  process_usage_german.py path\n\nHere, path should be where USAGE was unpacked.  It will have the\ndocuments, files, etc subdirectories.\n\nhttps://www.romanklinger.de/usagecorpus/\n\"\"\"\n\nimport csv\nimport glob\nimport os\nimport sys\n\nimport stanza\n\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\ndef main(in_directory, out_directory, short_name):\n    os.makedirs(out_directory, exist_ok=True)\n    nlp = stanza.Pipeline('de', processors='tokenize')\n\n    num_short_items = 0\n    snippets = []\n    csv_files = glob.glob(os.path.join(in_directory, \"files/de*csv\"))\n    for csv_filename in csv_files:\n        with open(csv_filename, newline='') as fin:\n            cin = csv.reader(fin, delimiter='\\t', quotechar=None)\n            lines = list(cin)\n\n            for index, line in enumerate(lines):\n                begin, end, snippet, sentiment = [line[i] for i in [2, 3, 4, 6]]\n                begin = int(begin)\n                end = int(end)\n                if len(snippet) != end - begin:\n                    raise ValueError(\"Error found in {} line {}.  
Expected {} got {}\".format(csv_filename, index, (end-begin), len(snippet)))\n                if sentiment.lower() == 'unknown':\n                    continue\n                elif sentiment.lower() == 'positive':\n                    sentiment = 2\n                elif sentiment.lower() == 'neutral':\n                    sentiment = 1\n                elif sentiment.lower() == 'negative':\n                    sentiment = 0\n                else:\n                    raise ValueError(\"Tell John he screwed up and this is why he can't have Mox Opal: {}\".format(sentiment))\n                doc = nlp(snippet)\n                text = [token.text for sentence in doc.sentences for token in sentence.tokens]\n                num_tokens = sum(len(sentence.tokens) for sentence in doc.sentences)\n                if num_tokens < 4:\n                    num_short_items = num_short_items + 1\n                snippets.append(SentimentDatum(sentiment, text))\n\n    print(\"Total snippets found for USAGE: %d\" % len(snippets))\n\n    process_utils.write_list(os.path.join(out_directory, \"%s.train.json\" % short_name), snippets)\n\nif __name__ == '__main__':\n    in_directory = sys.argv[1]\n    out_directory = sys.argv[2]\n    short_name = sys.argv[3]\n\n    main(in_directory, out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_utils.py",
    "content": "import csv\nimport glob\nimport json\nimport os\nimport tempfile\n\nfrom collections import namedtuple\n\nfrom tqdm import tqdm\n\nimport stanza\nfrom stanza.models.classifiers.data import SentimentDatum\n\nSplit = namedtuple('Split', ['filename', 'weight'])\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef write_list(out_filename, dataset):\n    \"\"\"\n    Write a list of items to the given output file\n\n    Expected: list(SentimentDatum)\n    \"\"\"\n    formatted_dataset = [line._asdict() for line in dataset]\n    # Rather than write the dataset at once, we write one line at a time\n    # Using `indent` puts each word on a separate line, which is rather noisy,\n    # but not formatting at all makes one long line out of an entire dataset,\n    # which is impossible to read\n    #json.dump(formatted_dataset, fout, indent=2, ensure_ascii=False)\n\n    with open(out_filename, 'w') as fout:\n        fout.write(\"[\\n\")\n        for idx, line in enumerate(formatted_dataset):\n            fout.write(\"  \")\n            json.dump(line, fout, ensure_ascii=False)\n            if idx < len(formatted_dataset) - 1:\n                fout.write(\",\")\n            fout.write(\"\\n\")\n        fout.write(\"]\\n\")\n\ndef write_dataset(dataset, out_directory, dataset_name):\n    \"\"\"\n    Write train, dev, test as .json files for a given dataset\n\n    dataset: 3 lists of sentiment tuples\n    \"\"\"\n    for shard, phrases in zip(SHARDS, dataset):\n        output_file = os.path.join(out_directory, \"%s.%s.json\" % (dataset_name, shard))\n        write_list(output_file, phrases)\n\ndef write_splits(out_directory, snippets, splits):\n    \"\"\"\n    Write the given list of items to the split files in the specified output directory\n    \"\"\"\n    total_weight = sum(split.weight for split in splits)\n    divs = []\n    subtotal = 0.0\n    for split in splits:\n        divs.append(int(len(snippets) * subtotal / total_weight))\n        subtotal = subtotal + 
split.weight\n    # the last div will be guaranteed to be the full thing - no math used\n    divs.append(len(snippets))\n\n    for i, split in enumerate(splits):\n        filename = os.path.join(out_directory, split.filename)\n        print(\"Writing {}:{} to {}\".format(divs[i], divs[i+1], filename))\n        write_list(filename, snippets[divs[i]:divs[i+1]])\n\ndef clean_tokenized_tweet(line):\n    line = list(line)\n    if len(line) > 3 and line[0] == 'RT' and line[1][0] == '@' and line[2] == ':':\n        line = line[3:]\n    elif len(line) > 4 and line[0] == 'RT' and line[1] == '@' and line[3] == ':':\n        line = line[4:]\n    elif line[0][0] == '@':\n        line = line[1:]\n    for i in range(len(line)):\n        if line[i][0] == '@' or line[i][0] == '#':\n            line[i] = line[i][1:]\n    line = [x for x in line if x and not x.startswith(\"http:\") and not x.startswith(\"https:\")]\n    return line\n\ndef get_ptb_tokenized_phrases(dataset):\n    \"\"\"\n    Use the PTB tokenizer to retokenize the phrases\n\n    Not clear which is better, \"Nov.\" or \"Nov .\"\n    strictAcronym=true makes it do the latter\n    tokenizePerLine=true should make it only pay attention to one line at a time\n\n    Phrases will be returned as lists of words rather than one string\n    \"\"\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        phrase_filename = os.path.join(tempdir, \"phrases.txt\")\n        #phrase_filename = \"asdf.txt\"\n        with open(phrase_filename, \"w\", encoding=\"utf-8\") as fout:\n            for item in dataset:\n                # extra newlines are so the tokenizer treats the lines\n                # as separate sentences\n                fout.write(\"%s\\n\\n\\n\" % (item.text))\n        tok_filename = os.path.join(tempdir, \"tokenized.txt\")\n        os.system('java edu.stanford.nlp.process.PTBTokenizer -options \"strictAcronym=true,tokenizePerLine=true\" -preserveLines %s > %s' % (phrase_filename, tok_filename))\n        with 
open(tok_filename, encoding=\"utf-8\") as fin:\n            tokenized = fin.readlines()\n\n    tokenized = [x.strip() for x in tokenized]\n    tokenized = [x for x in tokenized if x]\n    phrases = [SentimentDatum(x.sentiment, y.split()) for x, y in zip(dataset, tokenized)]\n    return phrases\n\ndef process_datum(nlp, text, mapping, sentiment):\n    doc = nlp(text.strip())\n\n    converted_sentiment = mapping.get(sentiment, None)\n    if converted_sentiment is None:\n        # the line number and filename are not known here; read_snippets adds them\n        raise ValueError(\"Value {} not in mapping\".format(sentiment))\n\n    text = []\n    for sentence in doc.sentences:\n        text.extend(token.text for token in sentence.tokens)\n    text = clean_tokenized_tweet(text)\n    return SentimentDatum(converted_sentiment, text)\n\ndef read_snippets(csv_filename, sentiment_column, text_column, tokenizer_language, mapping, delimiter='\\t', quotechar=None, skip_first_line=False, nlp=None, encoding=\"utf-8\"):\n    \"\"\"\n    Read in a single CSV file and return a list of SentimentDatums\n    \"\"\"\n    if nlp is None:\n        nlp = stanza.Pipeline(tokenizer_language, processors='tokenize')\n\n    with open(csv_filename, newline='', encoding=encoding) as fin:\n        if skip_first_line:\n            next(fin)\n        cin = csv.reader(fin, delimiter=delimiter, quotechar=quotechar)\n        lines = list(cin)\n\n    # Read in the data and parse it\n    snippets = []\n    for idx, line in enumerate(tqdm(lines)):\n        try:\n            if isinstance(sentiment_column, int):\n                sentiment = line[sentiment_column].lower()\n            else:\n                sentiment = tuple([line[x] for x in sentiment_column])\n        except IndexError as e:\n            raise IndexError(\"Columns {} did not exist at line {}: {}\".format(sentiment_column, idx, line)) from e\n        text = line[text_column]\n        try:\n            datum = process_datum(nlp, text, mapping, sentiment)\n        except ValueError as e:\n            raise ValueError(\"Could not process line {} of {}\".format(idx, csv_filename)) from e\n        snippets.append(datum)\n    return 
snippets\n\n"
  },
  {
    "path": "stanza/utils/datasets/sentiment/process_vsfc_vietnamese.py",
    "content": "\"\"\"\nVSFC sentiment dataset is available at\n  https://drive.google.com/drive/folders/1xclbjHHK58zk2X6iqbvMPS2rcy9y9E0X\n\nThe format is extremely similar to ours - labels are 0,1,2.\nText needs to be tokenized, though.\nAlso, the files are split into two pieces, labels and text.\n\"\"\"\n\nimport os\nimport sys\n\nfrom tqdm import tqdm\n\nimport stanza\nfrom stanza.models.classifiers.data import SentimentDatum\nimport stanza.utils.datasets.sentiment.process_utils as process_utils\n\nimport stanza.utils.default_paths as default_paths\n\ndef combine_columns(in_directory, dataset, nlp):\n    directory = os.path.join(in_directory, dataset)\n\n    sentiment_file = os.path.join(directory, \"sentiments.txt\")\n    with open(sentiment_file) as fin:\n        sentiment = fin.readlines()\n\n    text_file = os.path.join(directory, \"sents.txt\")\n    with open(text_file) as fin:\n        text = fin.readlines()\n\n    text = [[token.text for sentence in nlp(line.strip()).sentences for token in sentence.tokens]\n            for line in tqdm(text)]\n\n    phrases = [SentimentDatum(s.strip(), t) for s, t in zip(sentiment, text)]\n    return phrases\n\ndef main(in_directory, out_directory, short_name):\n    nlp = stanza.Pipeline('vi', processors='tokenize')\n    for shard in (\"train\", \"dev\", \"test\"):\n        phrases = combine_columns(in_directory, shard, nlp)\n        output_file = os.path.join(out_directory, \"%s.%s.json\" % (short_name, shard))\n        process_utils.write_list(output_file, phrases)\n\n\nif __name__ == '__main__':\n    paths = default_paths.get_default_paths()\n\n    if len(sys.argv) <= 1:\n        in_directory = os.path.join(paths['SENTIMENT_BASE'], \"vietnamese\", \"_UIT-VSFC\")\n    else:\n        in_directory = sys.argv[1]\n\n    if len(sys.argv) <= 2:\n        out_directory = paths['SENTIMENT_DATA_DIR']\n    else:\n        out_directory = sys.argv[2]\n\n    if len(sys.argv) <= 3:\n        short_name = 'vi_vsfc'\n    else:\n        
short_name = sys.argv[3]\n\n    main(in_directory, out_directory, short_name)\n"
  },
  {
    "path": "stanza/utils/datasets/thai_syllable_dict_generator.py",
    "content": "import argparse\nimport json\nimport pathlib\nimport re\n\n\ndef create_dictionary(dataset_dir, save_dir):\n    syllables = set()\n\n    for p in pathlib.Path(dataset_dir).rglob(\"*.ssg\"): # iterate through all files\n\n        with open(p) as f: # for each file\n            sentences = f.readlines()\n\n        for i in range(len(sentences)):\n\n            sentences[i] = sentences[i].replace(\"\\n\", \"\")\n            sentences[i] = sentences[i].replace(\"<s/>\", \"~\")\n            sentences[i] = sentences[i].split(\"~\") # create list of all syllables\n\n            syllables = syllables.union(sentences[i])\n\n\n        # running total of distinct syllables after each file\n        print(len(syllables))\n\n    # Keep only non-empty syllables made up entirely of Thai characters\n    # (this drops English words and anything containing a space)\n    a = []\n\n    for s in syllables:\n        if bool(re.match(\"^[\\u0E00-\\u0E7F]*$\", s)) and s != \"\" and \" \" not in s:\n            a.append(s)\n\n    # map each remaining syllable to an integer id\n    a = set(a)\n    a = dict(zip(list(a), range(len(a))))\n\n    print(a)\n    print(len(a))\n    with open(save_dir, \"w\") as fp:\n        json.dump(a, fp)\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--dataset_dir', type=str, default=\"syllable_segmentation_data\", help=\"Directory for syllable dataset\")\n    parser.add_argument('--save_dir', type=str, default=\"thai-syllable.json\", help=\"Path of the generated JSON file\")\n    args = parser.parse_args()\n\n    create_dictionary(args.dataset_dir, args.save_dir)\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_ml_cochin.py",
    "content": "\"\"\"\nConvert a Malayalam NER dataset to a tokenization dataset using\nthe additional labeling provided by TTec's Indian partners\n\nThis is still WIP - ongoing discussion with TTec and the team at UFAL\ndoing the UD Malayalam dataset - but if someone wants the data to\nrecreate it, feel free to contact Prof. Manning or John Bauer\n\nData was annotated through Datasaur by TTec - possibly another team\ninvolved, will double check with the annotators.\n\n#1 current issue with the data is a difference in annotation style\nobserved by the UFAL group.  I believe TTec is working on reannotating\nthis.\n\nDiscussing the first sentence in the first split file:\n\n> I am not sure about the guidelines that the annotators followed, but\n> I would not have split നാമജപത്തോടുകൂടി as നാമ --- ജപത്തോടുകൂടി. Because\n> they are not multiple syntactic words. I would have done it like\n> നാമജപത്തോടു --- കൂടി as കൂടി ('with') can be tagged as ADP. I agree with\n> the second MWT വ്യത്യസ്തം --- കൂടാതെ.\n>\n> In Malayalam, we do have many words which potentially can be treated\n> as compounds and split but sometimes it becomes difficult to make\n> that decision as the etymology or the word formation process is\n> unclear. 
So for the Malayalam UD annotations I stayed away from\n> doing it because I didn't find it necessary and moreover the\n> guidelines say that the words should be split into syntactic words\n> and not into morphemes.\n\nAs for using this script, create a directory extern_data/malayalam/cochin_ner/\nThe original NER dataset from Cochin University going there:\nextern_data/malayalam/cochin_ner/final_ner.txt\nThe relabeled data from TTEC goes in\nextern_data/malayalam/cochin_ner/relabeled_tsv/malayalam_File_1.txt.tsv etc etc\n\nThis can be invoked from the command line, or it can be used as part of\nstanza/utils/datasets/prepare_tokenizer_treebank.py ml_cochin\nin which case the conll splits will be turned into tokenizer labels as well\n\"\"\"\n\nfrom difflib import SequenceMatcher\nimport os\nimport random\nimport sys\n\nimport stanza.utils.default_paths as default_paths\n\ndef read_words(filename):\n    with open(filename, encoding=\"utf-8\") as fin:\n        text = fin.readlines()\n        text = [x.strip().split()[0] if x.strip() else \"\" for x in text]\n        return text\n\ndef read_original_text(input_dir):\n    original_file = os.path.join(input_dir, \"final_ner.txt\")\n    return read_words(original_file)\n\ndef list_relabeled_files(relabeled_dir):\n    tsv_files = os.listdir(relabeled_dir)\n    assert all(x.startswith(\"malayalam_File_\") and x.endswith(\".txt.tsv\") for x in tsv_files)\n    tsv_files = sorted(tsv_files, key = lambda filename: int(filename.split(\".\")[0].split(\"_\")[2]))\n    return tsv_files\n\ndef find_word(original_text, target, start_index, end_index):\n    for word in original_text[start_index:end_index]:\n        if word == target:\n            return True\n    return False\n\ndef scan_file(original_text, current_index, tsv_file):\n    relabeled_text = read_words(tsv_file)\n    # for now, at least, we ignore these markers\n    relabeled_indices = [idx for idx, x in enumerate(relabeled_text) if x != '$' and x != '^']\n    
relabeled_text = [x for x in relabeled_text if x != '$' and x != '^']\n    diffs = SequenceMatcher(None, original_text, relabeled_text, False)\n\n    blocks = diffs.get_matching_blocks()\n    assert blocks[-1].size == 0\n    if len(blocks) == 1:\n        raise ValueError(\"Could not find a match between %s and the original text\" % tsv_file)\n\n    sentences = []\n    current_sentence = []\n\n    in_mwt = False\n    bad_sentence = False\n    current_mwt = []\n    block_index = 0\n    current_block = blocks[0]\n    for tsv_index, next_word in enumerate(relabeled_text):\n        if not next_word:\n            if in_mwt:\n                current_mwt = []\n                in_mwt = False\n                bad_sentence = True\n                print(\"Unclosed MWT found at %s line %d\" % (tsv_file, tsv_index))\n            if current_sentence:\n                if not bad_sentence:\n                    sentences.append(current_sentence)\n                bad_sentence = False\n                current_sentence = []\n            continue\n\n        # tsv_index will now be inside the current block or before the current block\n        while tsv_index >= blocks[block_index].b + current_block.size:\n            block_index += 1\n            current_block = blocks[block_index]\n        #print(tsv_index, current_block.b, current_block.size)\n\n        if next_word == ',' or next_word == '.':\n            # many of these punctuations were added by the relabelers\n            current_sentence.append(next_word)\n            continue\n        if tsv_index >= current_block.b and tsv_index < current_block.b + current_block.size:\n            # ideal case: in a matching block\n            current_sentence.append(next_word)\n            continue\n\n        # in between blocks... 
need to handle re-spelled words and MWTs\n        if not in_mwt and next_word == '@':\n            in_mwt = True\n            continue\n        if not in_mwt:\n            current_sentence.append(next_word)\n            continue\n        if in_mwt and next_word == '@' and (tsv_index + 1 < len(relabeled_text) and relabeled_text[tsv_index+1] == '@'):\n            # we'll stop the MWT next time around\n            continue\n        if in_mwt and next_word == '@':\n            if block_index > 0 and (len(current_mwt) == 2 or len(current_mwt) == 3):\n                mwt = \"\".join(current_mwt)\n                start_original = blocks[block_index-1].a + blocks[block_index-1].size\n                end_original = current_block.a\n                if find_word(original_text, mwt, start_original, end_original):\n                    current_sentence.append((mwt, current_mwt))\n                else:\n                    print(\"%d word MWT %s at %s %d.  Should be somewhere in %d %d\" % (len(current_mwt), mwt, tsv_file, relabeled_indices[tsv_index], start_original, end_original))\n                    bad_sentence = True\n            elif len(current_mwt) > 6:\n                raise ValueError(\"Unreasonably long MWT span in %s at line %d\" % (tsv_file, relabeled_indices[tsv_index]))\n            elif len(current_mwt) > 3:\n                print(\"%d word sequence, stop being lazy - %s %d\" % (len(current_mwt), tsv_file, relabeled_indices[tsv_index]))\n                bad_sentence = True\n            else:\n                # short MWT, but it was at the start of a file, and we don't want to search the whole file for the item\n                # TODO, could maybe search the 10 words or so before the start of the block?\n                bad_sentence = True\n            current_mwt = []\n            in_mwt = False\n            continue\n        # now we know we are in an MWT... 
TODO\n        current_mwt.append(next_word)\n\n    if len(current_sentence) > 0 and not bad_sentence:\n        sentences.append(current_sentence)\n\n    return current_index, sentences\n\ndef split_sentences(sentences):\n    train = []\n    dev = []\n    test = []\n\n    for sentence in sentences:\n        rand = random.random()\n        if rand < 0.8:\n            train.append(sentence)\n        elif rand < 0.9:\n            dev.append(sentence)\n        else:\n            test.append(sentence)\n\n    return train, dev, test\n\ndef main(input_dir, tokenizer_dir, relabeled_dir=\"relabeled_tsv\", split_data=True):\n    random.seed(1006)\n\n    input_dir = os.path.join(input_dir, \"malayalam\", \"cochin_ner\")\n    relabeled_dir = os.path.join(input_dir, relabeled_dir)\n    tsv_files = list_relabeled_files(relabeled_dir)\n\n    original_text = read_original_text(input_dir)\n    print(\"Original text len: %d\" %len(original_text))\n    current_index = 0\n    sentences = []\n    for tsv_file in tsv_files:\n        print(tsv_file)\n        current_index, new_sentences = scan_file(original_text, current_index, os.path.join(relabeled_dir, tsv_file))\n        sentences.extend(new_sentences)\n\n    print(\"Found %d sentences\" % len(sentences))\n\n    if split_data:\n        splits = split_sentences(sentences)\n        SHARDS = (\"train\", \"dev\", \"test\")\n    else:\n        splits = [sentences]\n        SHARDS = [\"train\"]\n\n    for split, shard in zip(splits, SHARDS):\n        output_filename = os.path.join(tokenizer_dir, \"ml_cochin.%s.gold.conllu\" % shard)\n        print(\"Writing %d sentences to %s\" % (len(split), output_filename))\n        with open(output_filename, \"w\", encoding=\"utf-8\") as fout:\n            for sentence in split:\n                word_idx = 1\n                for token in sentence:\n                    if isinstance(token, str):\n                        fake_dep = \"\\t0\\troot\" if word_idx == 1 else \"\\t1\\tdep\"\n                     
   fout.write(\"%d\\t%s\" % (word_idx, token) + \"\\t_\" * 4 + fake_dep + \"\\t_\\t_\\n\")\n                        word_idx += 1\n                    else:\n                        text = token[0]\n                        mwt = token[1]\n                        fout.write(\"%d-%d\\t%s\" % (word_idx, word_idx + len(mwt) - 1, text) + \"\\t_\" * 8 + \"\\n\")\n                        for piece in mwt:\n                            fake_dep = \"\\t0\\troot\" if word_idx == 1 else \"\\t1\\tdep\"\n                            fout.write(\"%d\\t%s\" % (word_idx, piece) + \"\\t_\" * 4 + fake_dep + \"\\t_\\t_\\n\")\n                            word_idx += 1\n                fout.write(\"\\n\")\n\nif __name__ == '__main__':\n    sys.stdout.reconfigure(encoding='utf-8')\n    paths = default_paths.get_default_paths()\n    tokenizer_dir = paths[\"TOKENIZE_DATA_DIR\"]\n    input_dir = paths[\"STANZA_EXTERN_DIR\"]\n    main(input_dir, tokenizer_dir, \"relabeled_tsv_v2\", False)\n\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_my_alt.py",
    "content": "\"\"\"Converts the Myanmar ALT corpus to a tokenizer dataset.\n\nThe ALT corpus is in the form of constituency trees, which basically\nmeans there is no guidance on where the whitespace belongs.  However,\nin Myanmar writing, whitespace is apparently not actually required\nanywhere.  The plan will be to make sentences where there is no\nwhitespace at all, along with a random selection of sentences\nwhere some whitespace is randomly inserted.\n\nThe treebank is available here:\n\nhttps://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/\n\nThe following files describe the splits of the data:\n\nhttps://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-train.txt\nhttps://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-dev.txt\nhttps://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-test.txt\n\nand this is the actual treebank:\n\nhttps://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/my-alt-190530.zip\n\nDownload each of the files, then unzip the my-alt zip in place.\nThe expectation is this will produce a file\n  my-alt-190530/data\n\nThe default expected path to the Myanmar data is\n  extern_data/constituency/myanmar/my_alt/my-alt-190530/data\n\"\"\"\n\nimport os\nimport random\n\nfrom stanza.models.constituency.tree_reader import read_trees\n\ndef read_split(input_dir, section):\n    \"\"\"\n    Reads the split description for train, dev, or test\n\n    Format (at least for the Myanmar section of ALT) is:\n      one description per line\n      each line is URL.<number> <URL>\n      we actually don't care about the URL itself\n      all we want is the number, which we use to split up\n      the tree file later\n\n    Returns a set of numbers (as strings)\n    \"\"\"\n    urls = set()\n    filename = os.path.join(input_dir, \"myanmar\", \"my_alt\", \"URL-%s.txt\" % section)\n    with open(filename) as fin:\n        lines = fin.readlines()\n    for line in lines:\n        line = line.strip()\n        if not line or not line.startswith(\"URL\"):\n  
          continue\n        # split into URL.100161 and a bunch of description we don't care about\n        line = line.split(maxsplit=1)\n        # get just the number\n        line = line[0].split(\".\")\n        assert len(line) == 2\n        assert line[0] == 'URL'\n        urls.add(line[1])\n    return urls\n    \nSPLITS = (\"train\", \"dev\", \"test\")\n\ndef read_dataset_splits(input_dir):\n    \"\"\"\n    Call read_split for train, dev, and test\n\n    Returns three sets: train, dev, test in order\n    \"\"\"\n    url_splits = [read_split(input_dir, section) for section in SPLITS]\n    for url_split, split in zip(url_splits, SPLITS):\n        print(\"Split %s has %d files in it\" % (split, len(url_split)))\n    return url_splits\n\ndef read_alt_treebank(constituency_input_dir):\n    \"\"\"\n    Read the splits, read the trees, and split the trees based on the split descriptions\n\n    Trees in ALT are:\n      <tree id> <tree brackets>\n    The tree id will look like\n      SNT.<url_id>.<line>\n    All we care about from this id is the url_id, which we crossreference in the splits\n    to figure out which split the tree is in.\n\n    The tree itself we don't process much, although we do convert it to a ParseTree\n\n    The result is three lists: train, dev, test trees\n    \"\"\"\n    train_split, dev_split, test_split = read_dataset_splits(constituency_input_dir)\n\n    datafile = os.path.join(constituency_input_dir, \"myanmar\", \"my_alt\", \"my-alt-190530\", \"data\")\n    print(\"Reading trees from %s\" % datafile)\n    with open(datafile) as fin:\n        tree_lines = fin.readlines()\n\n    train_trees = []\n    dev_trees = []\n    test_trees = []\n\n    for idx, tree_line in enumerate(tree_lines):\n        tree_line = tree_line.strip()\n        if not tree_line:\n            continue\n        dataset, tree_text = tree_line.split(maxsplit=1)\n        dataset = dataset.split(\".\", 2)[1]\n\n        trees = read_trees(tree_text)\n        if len(trees) != 
1:\n            raise ValueError(\"Unexpected number of trees in line %d: %d\" % (idx, len(trees)))\n        tree = trees[0]\n\n        if dataset in train_split:\n            train_trees.append(tree)\n        elif dataset in dev_split:\n            dev_trees.append(tree)\n        elif dataset in test_split:\n            test_trees.append(tree)\n        else:\n            raise ValueError(\"Could not figure out which split line %d belongs to\" % idx)\n\n    return train_trees, dev_trees, test_trees\n\ndef write_sentence(fout, words, spaces):\n    \"\"\"\n    Write a sentence based on the list of words.\n\n    spaces is a fraction of the words which should randomly have spaces\n    If 0.0, none of the words will have spaces\n    This is because the Myanmar language doesn't require spaces, but\n      spaces always separate words\n    \"\"\"\n    full_text = \"\".join(words)\n    fout.write(\"# text = %s\\n\" % full_text)\n\n    for word_idx, word in enumerate(words):\n        fake_dep = \"root\" if word_idx == 0 else \"dep\"\n        fout.write(\"%d\\t%s\\t%s\" % ((word_idx+1), word, word))\n        fout.write(\"\\t_\\t_\\t_\")\n        fout.write(\"\\t%d\\t%s\" % (word_idx, fake_dep))\n        fout.write(\"\\t_\\t\")\n        if random.random() > spaces:\n            fout.write(\"SpaceAfter=No\")\n        else:\n            fout.write(\"_\")\n        fout.write(\"\\n\")\n    fout.write(\"\\n\")\n\n\ndef write_dataset(filename, trees, split):\n    \"\"\"\n    Write all of the trees to the given filename\n    \"\"\"\n    count = 0\n    with open(filename, \"w\") as fout:\n        # TODO: make some fraction have random spaces inserted\n        for tree in trees:\n            count = count + 1\n            words = tree.leaf_labels()\n            write_sentence(fout, words, spaces=0.0)\n            # We include a small number of spaces to teach the model\n            # that spaces always separate a word\n            if split == 'train' and random.random() < 0.1:\n        
        count = count + 1\n                write_sentence(fout, words, spaces=0.05)\n    print(\"Wrote %d sentences from %d trees to %s\" % (count, len(trees), filename))\n\ndef convert_my_alt(constituency_input_dir, tokenizer_dir):\n    \"\"\"\n    Read and then convert the Myanmar ALT treebank\n    \"\"\"\n    random.seed(1234)\n    tree_splits = read_alt_treebank(constituency_input_dir)\n\n    output_filenames = [os.path.join(tokenizer_dir, \"my_alt.%s.gold.conllu\") % split for split in SPLITS]\n\n    for filename, trees, split in zip(output_filenames, tree_splits, SPLITS):\n        write_dataset(filename, trees, split)\n\ndef main():\n    convert_my_alt(\"extern_data/constituency\", \"data/tokenize\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_text_files.py",
    "content": "\"\"\"\nGiven a text file and a file with one word per line, convert the text file\n\nSentence splits should be represented as blank lines at the end of a sentence.\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nfrom stanza.models.tokenization.utils import match_tokens_with_text\nimport stanza.utils.datasets.common as common\n\ndef read_tokens_file(token_file):\n    \"\"\"\n    Returns a list of list of tokens\n\n    Each sentence is a list of tokens\n    \"\"\"\n    sentences = []\n    current_sentence = []\n    with open(token_file, encoding=\"utf-8\") as fin:\n        for line in fin:\n            line = line.strip()\n            if not line:\n                if current_sentence:\n                    sentences.append(current_sentence)\n                    current_sentence = []\n            else:\n                current_sentence.append(line)\n        if current_sentence:\n            sentences.append(current_sentence)\n\n    return sentences\n\ndef read_sentences_file(sentence_file):\n    sentences = []\n    with open(sentence_file, encoding=\"utf-8\") as fin:\n        for line in fin:\n            line = line.strip()\n            if not line:\n                continue\n            sentences.append(line)\n    return sentences\n\ndef process_raw_file(text_file, token_file, sentence_file, base_sent_idx=0):\n    \"\"\"\n    Process a text file separated into a list of tokens using match_tokens_with_text from the tokenizer\n\n    The tokens are one per line in the token_file\n    The tokens in the token_file must add up to the text_file modulo whitespace.\n\n    Sentences are also one per line in the sentence_file\n    These must also add up to text_file\n\n    The return format is a list of list of conllu lines representing the sentences.\n    The only fields set will be the token index, the token text, and possibly SpaceAfter=No\n    where SpaceAfter=No is true if the next token started with no whitespace in the text file\n    \"\"\"\n    
with open(text_file, encoding=\"utf-8\") as fin:\n        text = fin.read()\n\n    tokens = read_tokens_file(token_file)\n    tokens = [[token for sentence in tokens for token in sentence]]\n    tokens_doc = match_tokens_with_text(tokens, text)\n\n    assert len(tokens_doc.sentences) == 1\n    assert len(tokens_doc.sentences[0].tokens) == len(tokens[0])\n\n    sentences = read_sentences_file(sentence_file)\n    sentences_doc = match_tokens_with_text([sentences], text)\n\n    assert len(sentences_doc.sentences) == 1\n    assert len(sentences_doc.sentences[0].tokens) == len(sentences)\n\n    start_token_idx = 0\n    sentences = []\n    for sent_idx, sentence in enumerate(sentences_doc.sentences[0].tokens):\n        tokens = []\n        tokens.append(\"# sent_id = %d\" % (base_sent_idx + sent_idx + 1))\n        tokens.append(\"# text = %s\" % text[sentence.start_char:sentence.end_char].replace(\"\\n\", \" \"))\n        token_idx = 0\n        while token_idx + start_token_idx < len(tokens_doc.sentences[0].tokens):\n            token = tokens_doc.sentences[0].tokens[token_idx + start_token_idx]\n            if token.start_char >= sentence.end_char:\n                # have reached the end of this sentence\n                # continue with the next sentence\n                start_token_idx += token_idx\n                break\n\n            if token_idx + start_token_idx == len(tokens_doc.sentences[0].tokens) - 1:\n                # definitely the end of the document\n                space_after = True\n            elif token.end_char == tokens_doc.sentences[0].tokens[token_idx + start_token_idx + 1].start_char:\n                space_after = False\n            else:\n                space_after = True\n            token = [str(token_idx+1), token.text] + [\"_\"] * 7 + [\"_\" if space_after else \"SpaceAfter=No\"]\n            assert len(token) == 10, \"Token length: %d\" % len(token)\n            token = \"\\t\".join(token)\n            tokens.append(token)\n            
token_idx += 1\n        sentences.append(tokens)\n    return sentences\n\ndef extract_sentences(dataset_files):\n    sentences = []\n    for text_file, token_file, sentence_file in dataset_files:\n        print(\"Extracting sentences from %s and tokens from %s from the text file %s\" % (sentence_file, token_file, text_file))\n        sentences.extend(process_raw_file(text_file, token_file, sentence_file, len(sentences)))\n    return sentences\n\ndef split_sentences(sentences, train_split=0.8, dev_split=0.1):\n    \"\"\"\n    Splits randomly without shuffling\n    \"\"\"\n    generator = random.Random(1234)\n\n    train = []\n    dev = []\n    test = []\n    for sentence in sentences:\n        r = generator.random()\n        if r < train_split:\n            train.append(sentence)\n        elif r < train_split + dev_split:\n            dev.append(sentence)\n        else:\n            test.append(sentence)\n    return (train, dev, test)\n\ndef find_dataset_files(input_path, token_prefix, sentence_prefix):\n    files = os.listdir(input_path)\n    print(\"Found %d files in %s\" % (len(files), input_path))\n    if len(files) > 0:\n        if len(files) < 20:\n            print(\"Files:\", end=\"\\n  \")\n        else:\n            print(\"First few files:\", end=\"\\n  \")\n        print(\"\\n  \".join(files[:20]))\n    token_files = {}\n    sentence_files = {}\n    text_files = []\n    for filename in files:\n        if filename.endswith(\".zip\"):\n            continue\n        if filename.startswith(token_prefix):\n            short_filename = filename[len(token_prefix):]\n            if short_filename.startswith(\"_\"):\n                short_filename = short_filename[1:]\n            token_files[short_filename] = filename\n        elif filename.startswith(sentence_prefix):\n            short_filename = filename[len(sentence_prefix):]\n            if short_filename.startswith(\"_\"):\n                short_filename = short_filename[1:]\n            
sentence_files[short_filename] = filename\n        else:\n            text_files.append(filename)\n    dataset_files = []\n    for filename in text_files:\n        if filename not in token_files:\n            raise FileNotFoundError(\"When looking in %s, found %s as a text file, but did not find a corresponding tokens file at %s_%s.  Please give an input directory which has only the text files, tokens files, and sentences files\" % (input_path, filename, token_prefix, filename))\n        if filename not in sentence_files:\n            raise FileNotFoundError(\"When looking in %s, found %s as a text file, but did not find a corresponding sentences file at %s_%s.  Please give an input directory which has only the text files, tokens files, and sentences files\" % (input_path, filename, sentence_prefix, filename))\n        text_file = os.path.join(input_path, filename)\n        token_file = os.path.join(input_path, token_files[filename])\n        sentence_file = os.path.join(input_path, sentence_files[filename])\n        dataset_files.append((text_file, token_file, sentence_file))\n    return dataset_files\n\nSHARDS = (\"train\", \"dev\", \"test\")\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--token_prefix', type=str, default=\"tkns\", help=\"Prefix for the token files\")\n    parser.add_argument('--sentence_prefix', type=str, default=\"stns\", help=\"Prefix for the sentence files\")\n    parser.add_argument('--input_path', type=str, default=\"extern_data/sindhi/tokenization\", help=\"Where to find all of the input files.  
Files with the prefix tkns_ will be treated as token files, files with the prefix stns_ will be treated as sentence files, and all others will be the text files.\")\n    parser.add_argument('--output_path', type=str, default=\"data/tokenize\", help=\"Where to output the results\")\n    parser.add_argument('--dataset', type=str, default=\"sd_isra\", help=\"What name to give this dataset\")\n    args = parser.parse_args()\n\n    dataset_files = find_dataset_files(args.input_path, args.token_prefix, args.sentence_prefix)\n\n    tokenizer_dir = args.output_path\n    short_name = args.dataset  # todo: convert a full name?\n\n    sentences = extract_sentences(dataset_files)\n    splits = split_sentences(sentences)\n\n    os.makedirs(args.output_path, exist_ok=True)\n    for dataset, shard in zip(splits, SHARDS):\n        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, shard)\n        common.write_sentences_to_conllu(output_conllu, dataset)\n\n    common.convert_conllu_to_txt(tokenizer_dir, short_name)\n    common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_th_best.py",
    "content": "\"\"\"Parses the BEST Thai dataset.\n\nThat is to say, the dataset named BEST.  We have not yet figured out\nwhich segmentation standard we prefer.\n\nNote that the version of BEST we used actually had some strange\nsentence splits according to a native Thai speaker.  Not sure how to\nfix that.  Options include doing it automatically or finding some\nknowledgable annotators to resplit it for us (or just not using BEST)\n\nThis outputs the tokenization results in a conll format similar to\nthat of the UD treebanks, so we pretend to be a UD treebank for ease\nof compatibility with the stanza tools.\n\nBEST can be downloaded from here:\n\nhttps://aiforthai.in.th/corpus.php\n\npython3 -m stanza.utils.datasets.tokenization.process_best extern_data/thai/best data/tokenize\n./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000\n\"\"\"\nimport glob\nimport os\nimport random\nimport re\nimport sys\n\ntry:\n    from pythainlp import sent_tokenize\nexcept ImportError:\n    pass\n\nfrom stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines, write_dataset_best, write_dataset\n\ndef clean_line(line):\n    line = line.replace(\"html>\", \"html|>\")\n    # news_00089.txt\n    line = line.replace(\"<NER>\", \"<NE>\")\n    line = line.replace(\"</NER>\", \"</NE>\")\n    # specific error that occurs in encyclopedia_00095.txt\n    line = line.replace(\"</AB>Penn\", \"</AB>|Penn>\")\n    # news_00058.txt\n    line = line.replace(\"<AB>จม.</AB>เปิดผนึก\", \"<AB>จม.</AB>|เปิดผนึก\")\n    # news_00015.txt\n    line = re.sub(\"<NE><AB>([^|<>]+)</AB>([^|<>]+)</NE>\", \"\\\\1|\\\\2\", line)\n    # news_00024.txt\n    line = re.sub(\"<NE><AB>([^|<>]+)</AB></NE>\", \"\\\\1\", line)\n    # news_00055.txt\n    line = re.sub(\"<NE>([^|<>]+)<AB>([^|<>]+)</AB></NE>\", \"\\\\1|\\\\2\", line)\n    line = re.sub(\"<NE><AB>([^|<>]+)</AB><AB>([^|<>]+)</AB></NE>\", \"\\\\1|\\\\2\", 
line)\n    line = re.sub(\"<NE>([^|<>]+)<AB>([^|<>]+)</AB> <AB>([^|<>]+)</AB></NE>\", \"\\\\1|\\\\2|\\\\3\", line)\n    # news_00008.txt and other news articles\n    line = re.sub(\"</AB>([0-9])\", \"</AB>|\\\\1\", line)\n    line = line.replace(\"</AB> \", \"</AB>|\")\n    line = line.replace(\"<EM>\", \"<POEM>\")\n    line = line.replace(\"</EM>\", \"</POEM>\")\n    line = line.strip()\n    return line\n\n\ndef clean_word(word):\n    # novel_00078.txt\n    if word == '<NEพี่มน</NE>':\n        return 'พี่มน'\n    if word.startswith(\"<NE>\") and word.endswith(\"</NE>\"):\n        return word[4:-5]\n    if word.startswith(\"<AB>\") and word.endswith(\"</AB>\"):\n        return word[4:-5]\n    if word.startswith(\"<POEM>\") and word.endswith(\"</POEM>\"):\n        return word[6:-7]\n    \"\"\"\n    if word.startswith(\"<EM>\"):\n        return word[4:]\n    if word.endswith(\"</EM>\"):\n        return word[:-5]\n    \"\"\"\n    if word.startswith(\"<NE>\"):\n        return word[4:]\n    if word.endswith(\"</NE>\"):\n        return word[:-5]\n    if word.startswith(\"<POEM>\"):\n        return word[6:]\n    if word.endswith(\"</POEM>\"):\n        return word[:-7]\n    if word == '<':\n        return word\n    return word\n\ndef read_data(input_dir):\n    # data for test sets\n    test_files = [os.path.join(input_dir, 'TEST_100K_ANS.txt')]\n    print(test_files)\n\n    # data for train and dev sets\n    subdirs = [os.path.join(input_dir, 'article'),\n               os.path.join(input_dir, 'encyclopedia'),\n               os.path.join(input_dir, 'news'),\n               os.path.join(input_dir, 'novel')]\n    files = []\n    for subdir in subdirs:\n        if not os.path.exists(subdir):\n            raise FileNotFoundError(\"Expected a directory that did not exist: {}\".format(subdir))\n        files.extend(glob.glob(os.path.join(subdir, '*.txt')))\n\n    test_documents = []\n    for filename in test_files:\n        print(\"File name:\", filename)\n        with 
open(filename) as fin:\n            processed_lines = []\n            for line in fin.readlines():\n                line = clean_line(line)\n                words = line.split(\"|\")\n                words = [clean_word(x) for x in words]\n                for word in words:\n                    if len(word) > 1 and word[0] == '<':\n                        raise ValueError(\"Unexpected word '{}' in document {}\".format(word, filename))\n                words = [x for x in words if x]\n                processed_lines.append(words)\n\n            processed_lines = reprocess_lines(processed_lines)\n            paragraphs = convert_processed_lines(processed_lines)\n\n            test_documents.extend(paragraphs)\n    print(\"Test document finished.\")\n\n    documents = []\n\n    for filename in files:\n        with open(filename) as fin:\n            print(\"File:\", filename)\n            processed_lines = []\n            for line in fin.readlines():\n                line = clean_line(line)\n                words = line.split(\"|\")\n                words = [clean_word(x) for x in words]\n                for word in words:\n                    if len(word) > 1 and word[0] == '<':\n                        raise ValueError(\"Unexpected word '{}' in document {}\".format(word, filename))\n                words = [x for x in words if x]\n                processed_lines.append(words)\n\n            processed_lines = reprocess_lines(processed_lines)\n            paragraphs = convert_processed_lines(processed_lines)\n\n            documents.extend(paragraphs)\n\n    print(\"All documents finished.\")\n\n    return documents, test_documents\n\n\ndef main(*args):\n    random.seed(1000)\n    if not args:\n        args = sys.argv[1:]\n\n    input_dir = args[0]\n    full_input_dir = os.path.join(input_dir, \"thai\", \"best\")\n    if os.path.exists(full_input_dir):\n        # otherwise hopefully the user gave us the full path?\n        input_dir = full_input_dir\n\n    output_dir 
= args[1]\n    documents, test_documents = read_data(input_dir)\n    print(\"Finished reading data.\")\n    write_dataset_best(documents, test_documents, output_dir, \"best\")\n\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_th_lst20.py",
    "content": "\"\"\"Processes the tokenization section of the LST20 Thai dataset\n\nThe dataset is available here:\n\nhttps://aiforthai.in.th/corpus.php\n\nThe data should be installed under ${EXTERN_DATA}/thai/LST20_Corpus\n\npython3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data data/tokenize\n\nUnlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test.\n\n./scripts/run_tokenize.sh UD_Thai-lst20 --dropout 0.05 --unit_dropout 0.05\n\"\"\"\n\n\nimport argparse\nimport glob\nimport os\nimport sys\n\nfrom stanza.utils.datasets.tokenization.process_thai_tokenization import write_section, convert_processed_lines, reprocess_lines\n\ndef read_document(lines, spaces_after, split_clauses):\n    document = []\n    sentence = []\n    for line in lines:\n        line = line.strip()\n        if not line:\n            if sentence:\n                if spaces_after:\n                    sentence[-1] = (sentence[-1][0], True)\n                document.append(sentence)\n                sentence = []\n        else:\n            pieces = line.split(\"\\t\")\n            # there are some nbsp in tokens in lst20, but the downstream tools expect spaces\n            pieces = [p.replace(\"\\xa0\", \" \") for p in pieces]\n            if split_clauses and pieces[0] == '_' and pieces[3] == 'O':\n                if sentence:\n                    # note that we don't need to check spaces_after\n                    # the \"token\" is a space anyway\n                    sentence[-1] = (sentence[-1][0], True)\n                    document.append(sentence)\n                    sentence = []\n            elif pieces[0] == '_':\n                sentence[-1] = (sentence[-1][0], True)\n            else:\n                sentence.append((pieces[0], False))\n\n    if sentence:\n        if spaces_after:\n            sentence[-1] = (sentence[-1][0], True)\n        document.append(sentence)\n        sentence = []\n    # TODO: is there any way to divide 
up a single document into paragraphs?\n    return [[document]]\n\ndef retokenize_document(lines):\n    processed_lines = []\n    sentence = []\n    for line in lines:\n        line = line.strip()\n        if not line:\n            if sentence:\n                processed_lines.append(sentence)\n                sentence = []\n        else:\n            pieces = line.split(\"\\t\")\n            if pieces[0] == '_':\n                sentence.append(' ')\n            else:\n                sentence.append(pieces[0])\n    if sentence:\n        processed_lines.append(sentence)\n\n    processed_lines = reprocess_lines(processed_lines)\n    paragraphs = convert_processed_lines(processed_lines)\n    return paragraphs\n\n\ndef read_data(input_dir, section, resegment, spaces_after, split_clauses):\n    glob_path = os.path.join(input_dir, section, \"*.txt\")\n    filenames = glob.glob(glob_path)\n    print(\"  Found {} files in {}\".format(len(filenames), glob_path))\n    if len(filenames) == 0:\n        raise FileNotFoundError(\"Could not find any files for the {} section.  Is LST20 installed in {}?\".format(section, input_dir))\n    documents = []\n    for filename in filenames:\n        with open(filename) as fin:\n            lines = fin.readlines()\n        if resegment:\n            document = retokenize_document(lines)\n        else:\n            document = read_document(lines, spaces_after, split_clauses)\n        documents.extend(document)\n    return documents\n\ndef add_lst20_args(parser):\n    parser.add_argument('--no_lst20_resegment', action='store_false', dest=\"lst20_resegment\", default=True, help='When processing th_lst20 tokenization, use pythainlp to resegment the text.  The other option is to keep the original sentence segmentation.  
Currently our model is not good at that')\n    parser.add_argument('--lst20_spaces_after', action='store_true', dest=\"lst20_spaces_after\", default=False, help='When processing th_lst20 without pythainlp, put spaces after each sentence.  This better fits the language but gets lower scores for some reason')\n    parser.add_argument('--split_clauses', action='store_true', dest=\"split_clauses\", default=False, help='When processing th_lst20 without pythainlp, turn spaces which are labeled as between clauses into sentence splits')\n\ndef parse_lst20_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('input_dir', help=\"Directory to use when processing lst20\")\n    parser.add_argument('output_dir', help=\"Directory to use when saving lst20\")\n    add_lst20_args(parser)\n    return parser.parse_args()\n\n\n\ndef convert(input_dir, output_dir, args):\n    input_dir = os.path.join(input_dir, \"thai\", \"LST20_Corpus\")\n    if not os.path.exists(input_dir):\n        raise FileNotFoundError(\"Could not find LST20 corpus in {}\".format(input_dir))\n\n    for (in_section, out_section) in ((\"train\", \"train\"),\n                                      (\"eval\", \"dev\"),\n                                      (\"test\", \"test\")):\n        print(\"Processing %s\" % out_section)\n        documents = read_data(input_dir, in_section, args.lst20_resegment, args.lst20_spaces_after, args.split_clauses)\n        print(\"  Read in %d documents\" % len(documents))\n        write_section(output_dir, \"lst20\", out_section, documents)\n\ndef main():\n    args = parse_lst20_args()\n    convert(args.input_dir, args.output_dir, args)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_th_orchid.py",
    "content": "\"\"\"Parses the xml conversion of orchid\n\nhttps://github.com/korakot/thainlp/blob/master/xmlchid.xml\n\nFor example, if you put the data file in the above link in\nextern_data/thai/orchid/xmlchid.xml\nyou would then run\npython3 -m stanza.utils.datasets.tokenization.convert_th_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize\n\nBecause there is no definitive train/dev/test split that we have found\nso far, we randomly shuffle the data on a paragraph level and split it\n80/10/10.  A random seed is chosen so that the splits are reproducible.\n\nThe datasets produced have a similar format to the UD datasets, so we\ngive it a fake UD name to make life easier for the downstream tools.\n\nTraining on this dataset seems to work best with low dropout numbers.\nFor example:\n\npython3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05\n\nThis results in a model with dev set scores:\n th_orchid 87.98 70.94\ntest set scores:\n 91.60   72.43\n\nApparently the random split produced a test set easier than the dev set.\n\"\"\"\n\nimport os\nimport random\nimport sys\nimport xml.etree.ElementTree as ET\n\nfrom stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset\n\n# line \"122819\" has some error in the tokenization of the musical notation\n# line \"209380\" is also messed up\n# others have @ followed by a part of speech, which is clearly wrong\n\nskipped_lines = {\n    \"122819\",\n    \"209380\",\n    \"227769\",\n    \"245992\",\n    \"347163\",\n    \"409708\",\n    \"431227\",\n}\n\nescape_sequences = {\n    '<left_parenthesis>': '(',\n    '<right_parenthesis>': ')',\n    '<circumflex_accent>': '^',\n    '<full_stop>': '.',\n    '<minus>': '-',\n    '<asterisk>': '*',\n    '<quotation>': '\"',\n    '<slash>': '/',\n    '<colon>': ':',\n    '<equal>': '=',\n    '<comma>': ',',\n    '<semi_colon>': ';',\n    '<less_than>': '<',\n    '<greater_than>': '>',\n    '<ampersand>': '&',\n    
'<left_curly_bracket>': '{',\n    '<right_curly_bracket>': '}',\n    '<apostrophe>': \"'\",\n    '<plus>': '+',\n    '<number>': '#',\n    '<dollar>': '$',\n    '<at_mark>': '@',\n    '<question_mark>': '?',\n    '<exclamation>': '!',\n    'app<LI>ances': 'appliances',\n    'intel<LI>gence': 'intelligence',\n    \"<slash>'\": \"/'\",\n    '<100>': '100',\n}\n\nallowed_sequences = {\n    '<a>',\n    '<b>',\n    '<c>',\n    '<e>',\n    '<f>',\n    '<LI>',\n    '<---vp',\n    '<---',\n    '<----',\n}\n\ndef read_data(input_filename):\n    print(\"Reading {}\".format(input_filename))\n    tree = ET.parse(input_filename)\n    documents = parse_xml(tree)\n    print(\"Number of documents: {}\".format(len(documents)))\n    print(\"Number of paragraphs: {}\".format(sum(len(document) for document in documents)))\n    return documents\n\ndef parse_xml(tree):\n    # we will put each paragraph in a separate block in the output file\n    # we won't pay any attention to the document boundaries unless we\n    # later find out it was necessary\n    # a paragraph will be a list of sentences\n    # a sentence is a list of words, where each word is a string\n    documents = []\n\n    root = tree.getroot()\n    for document in root:\n        # these should all be documents\n        if document.tag != 'document':\n            raise ValueError(\"Unexpected orchid xml layout: {}\".format(document.tag))\n        paragraphs = []\n        for paragraph in document:\n            if paragraph.tag != 'paragraph':\n                raise ValueError(\"Unexpected orchid xml layout: {} under {}\".format(paragraph.tag, document.tag))\n            sentences = []\n            for sentence in paragraph:\n                if sentence.tag != 'sentence':\n                    raise ValueError(\"Unexpected orchid xml layout: {} under {}\".format(sentence.tag, document.tag))\n                if sentence.attrib['line_num'] in skipped_lines:\n                    continue\n                words = []\n             
   for word_idx, word in enumerate(sentence):\n                    if word.tag != 'word':\n                        raise ValueError(\"Unexpected orchid xml layout: {} under {}\".format(word.tag, sentence.tag))\n                    word = word.attrib['surface']\n                    word = escape_sequences.get(word, word)\n                    if word == '<space>':\n                        if word_idx == 0:\n                            raise ValueError(\"Space character was the first token in a sentence: {}\".format(sentence.attrib['line_num']))\n                        else:\n                            words[-1] = (words[-1][0], True)\n                            continue\n                    if len(word) > 1 and word[0] == '<' and word not in allowed_sequences:\n                        raise ValueError(\"Unknown escape sequence {}\".format(word))\n                    words.append((word, False))\n                if len(words) == 0:\n                    continue\n                words[-1] = (words[-1][0], True)\n                sentences.append(words)\n            paragraphs.append(sentences)\n        documents.append(paragraphs)\n\n    return documents\n\n\ndef main(*args):\n    random.seed(1007)\n    if not args:\n        args = sys.argv[1:]\n    input_filename = args[0]\n    if os.path.isdir(input_filename):\n        input_filename = os.path.join(input_filename, \"thai\", \"orchid\", \"xmlchid.xml\")\n    output_dir = args[1]\n    documents = read_data(input_filename)\n    write_dataset(documents, output_dir, \"orchid\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/convert_vi_vlsp.py",
    "content": "\nimport os\n\npunctuation_set = (',', '.', '!', '?', ')', ':', ';', '”', '…', '...')\n\ndef find_spaces(sentence):\n    # TODO: there are some sentences where there is only one quote,\n    # and some of them should be attached to the previous word instead\n    # of the next word.  Training should work this way, though\n    odd_quotes = False\n\n    spaces = []\n    for word_idx, word in enumerate(sentence):\n        space = True\n        # Quote period at the end of a sentence needs to be attached\n        # to the rest of the text.  Some sentences have `\"... text`\n        # in the middle, though, so look for that\n        if word_idx < len(sentence) - 2 and sentence[word_idx+1] == '\"':\n            if sentence[word_idx+2] == '.':\n                space = False\n            elif word_idx == len(sentence) - 3 and sentence[word_idx+2] == '...':\n                space = False\n        if word_idx < len(sentence) - 1:\n            if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...','/', '%'):\n                space = False\n        if word in ('(', '“', '/'):\n            space = False\n        if word == '\"':\n            if odd_quotes:\n                # already saw one quote.  put this one at the end of the PREVIOUS word\n                # note that we know there must be at least one word already\n                odd_quotes = False\n                spaces[word_idx-1] = False\n            else:\n                odd_quotes = True\n                space = False\n        spaces.append(space)\n    return spaces\n\ndef add_vlsp_args(parser):\n    parser.add_argument('--include_pos_data', action='store_true', default=False, help='To include or not POS training dataset for tokenization training. The path to POS dataset is expected to be in the same dir with WS path. 
For example, extern_dir/vietnamese/VLSP2013-POS-data')\n    parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces.  Otherwise, we try to turn the text back into standard text')\n\n\ndef write_file(vlsp_include_spaces, output_filename, sentences, shard):\n    with open(output_filename, \"w\") as fout:\n        check_headlines = False\n        for sent_idx, sentence in enumerate(sentences):\n            fout.write(\"# sent_id = %s.%d\\n\" % (shard, sent_idx))\n            orig_text = \" \".join(sentence)\n            #check if the previous line is a headline (no ending mark at the end) then make this sentence a new par\n            if check_headlines:\n                fout.write(\"# newpar id =%s.%d.1\\n\" % (shard, sent_idx))\n                check_headlines = False\n            if sentence[len(sentence) - 1] not in punctuation_set:\n                check_headlines = True\n\n            if vlsp_include_spaces:\n                fout.write(\"# text = %s\\n\" % orig_text)\n            else:\n                spaces = find_spaces(sentence)\n                full_text = \"\"\n                for word, space in zip(sentence, spaces):\n                    # could be made more efficient, but shouldn't matter\n                    full_text = full_text + word\n                    if space:\n                        full_text = full_text + \" \"\n                fout.write(\"# text = %s\\n\" % full_text)\n                fout.write(\"# orig_text = %s\\n\" % orig_text)\n            for word_idx, word in enumerate(sentence):\n                fake_dep = \"root\" if word_idx == 0 else \"dep\"\n                fout.write(\"%d\\t%s\\t%s\" % ((word_idx+1), word, word))\n                fout.write(\"\\t_\\t_\\t_\")\n                fout.write(\"\\t%d\\t%s\" % (word_idx, fake_dep))\n                fout.write(\"\\t_\\t\")\n                if vlsp_include_spaces or spaces[word_idx]:\n           
         fout.write(\"_\")\n                else:\n                    fout.write(\"SpaceAfter=No\")\n                fout.write(\"\\n\")\n            fout.write(\"\\n\")\n\ndef convert_pos_dataset(file_path):\n    \"\"\"\n    Read the POS dataset and return its sentences as lists of words\n    \"\"\"\n    # use a context manager so the file handle is always closed\n    with open(file_path, \"r\") as fin:\n        document = fin.readlines()\n    sentences = []\n    sent = []\n    for line in document:\n        if line == \"\\n\" and len(sent)>1:\n            if sent not in sentences:\n                sentences.append(sent)\n            sent = []\n        elif line != \"\\n\":\n            sent.append(line.split(\"\\t\")[0].replace(\"_\",\" \").strip())\n    return sentences\n        \ndef convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None, pos_data = None):\n    with open(input_filename) as fin:\n        lines = fin.readlines()\n\n    sentences = []\n    set_sentences = set()\n    for line in lines:\n        if len(line.replace(\"_\", \" \").split())>1:\n            words = line.split()\n            # one-syllable lines are eliminated\n            if len(words) == 1 and len(words[0].split(\"_\")) == 1:\n                continue\n            else:\n                words = [w.replace(\"_\", \" \") for w in words]\n                # only add sentences that haven't been added before\n                if words not in sentences:\n                    sentences.append(words)\n                    set_sentences.add(' '.join(words))\n                \n    if split_filename is not None:\n        # even this is a larger dev set than the train set\n        split_point = int(len(sentences) * 0.95)\n        # keep only the pos_data sentences which don't overlap with the current VLSP WS dataset\n        sentences_pos = [] if pos_data is None else [sent for sent in pos_data if ' '.join(sent) not in set_sentences]\n        print(\"Added \", len(sentences_pos), \" sentences from POS dataset.\")\n        
write_file(vlsp_include_spaces, output_filename, sentences[:split_point]+sentences_pos, shard)\n        write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard)\n    else:\n        write_file(vlsp_include_spaces, output_filename, sentences, shard)\n\ndef convert_vi_vlsp(extern_dir, tokenizer_dir, args):\n    input_path = os.path.join(extern_dir, \"vietnamese\", \"VLSP2013-WS-data\")\n    input_pos_path = os.path.join(extern_dir, \"vietnamese\", \"VLSP2013-POS-data\")\n    input_train_filename = os.path.join(input_path, \"VLSP2013_WS_train_gold.txt\")\n    input_test_filename = os.path.join(input_path, \"VLSP2013_WS_test_gold.txt\")\n    \n    input_pos_filename = os.path.join(input_pos_path, \"VLSP2013_POS_train_BI_POS_Column.txt.goldSeg\")\n    if not os.path.exists(input_train_filename):\n        raise FileNotFoundError(\"Cannot find train set for VLSP at %s\" % input_train_filename)\n    if not os.path.exists(input_test_filename):\n        raise FileNotFoundError(\"Cannot find test set for VLSP at %s\" % input_test_filename)\n    pos_data = None\n    if args.include_pos_data:\n        if not os.path.exists(input_pos_filename):\n            raise FileNotFoundError(\"Cannot find pos dataset for VLSP at %s\" % input_pos_filename)\n        else:\n            pos_data = convert_pos_dataset(input_pos_filename) \n\n    output_train_filename = os.path.join(tokenizer_dir, \"vi_vlsp.train.gold.conllu\")\n    output_dev_filename = os.path.join(tokenizer_dir,   \"vi_vlsp.dev.gold.conllu\")\n    output_test_filename = os.path.join(tokenizer_dir,  \"vi_vlsp.test.gold.conllu\")\n\n    convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, \"train\", output_dev_filename, \"dev\", pos_data)\n    convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, \"test\")\n\n"
  },
  {
    "path": "stanza/utils/datasets/tokenization/process_thai_tokenization.py",
    "content": "import os\nimport random\n\ntry:\n    from pythainlp import sent_tokenize\nexcept ImportError:\n    pass\n\ndef write_section(output_dir, dataset_name, section, documents):\n    \"\"\"\n    Writes a list of documents for tokenization, including a file in conll format\n\n    The Thai datasets generally have no MWT (apparently not relevant for Thai)\n\n    output_dir: the destination directory for the output files\n    dataset_name: orchid, BEST, lst20, etc\n    section: train/dev/test\n    documents: a nested list of documents, paragraphs, sentences, words\n      words is a list of (word, space_follows)\n    \"\"\"\n    with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w') as fout:\n        fout.write(\"[]\\n\")\n\n    text_out = open(os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section)), 'w')\n    label_out = open(os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section)), 'w')\n    for document in documents:\n        for paragraph in document:\n            for sentence_idx, sentence in enumerate(paragraph):\n                for word_idx, word in enumerate(sentence):\n                    # TODO: split with newlines to make it more readable?\n                    text_out.write(word[0])\n                    for i in range(len(word[0]) - 1):\n                        label_out.write(\"0\")\n                    if word_idx == len(sentence) - 1:\n                        label_out.write(\"2\")\n                    else:\n                        label_out.write(\"1\")\n                    if word[1] and (sentence_idx != len(paragraph) - 1 or word_idx != len(sentence) - 1):\n                        text_out.write(' ')\n                        label_out.write('0')\n\n            text_out.write(\"\\n\\n\")\n            label_out.write(\"\\n\\n\")\n\n    text_out.close()\n    label_out.close()\n\n    with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as 
fout:\n        for document in documents:\n            for paragraph in document:\n                new_par = True\n                for sentence in paragraph:\n                    for word_idx, word in enumerate(sentence):\n                        # SpaceAfter is left blank if there is space after the word\n                        if word[1] and new_par:\n                            space = 'NewPar=Yes'\n                        elif word[1]:\n                            space = '_'\n                        elif new_par:\n                            space = 'SpaceAfter=No|NewPar=Yes'\n                        else:\n                            space = 'SpaceAfter=No'\n                        new_par = False\n\n                        # Note the faked dependency structure: the conll reading code\n                        # needs it even if it isn't being used in any way\n                        fake_dep = 'root' if word_idx == 0 else 'dep'\n                        fout.write('{}\\t{}\\t_\\t_\\t_\\t_\\t{}\\t{}\\t{}:{}\\t{}\\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space))\n                    fout.write('\\n')\n\ndef write_dataset(documents, output_dir, dataset_name):\n    \"\"\"\n    Shuffle a list of documents, write three sections\n    \"\"\"\n    random.shuffle(documents)\n    num_train = int(len(documents) * 0.8)\n    num_dev = int(len(documents) * 0.1)\n    os.makedirs(output_dir, exist_ok=True)\n    write_section(output_dir, dataset_name, 'train', documents[:num_train])\n    write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev])\n    write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:])\n\ndef write_dataset_best(documents, test_documents, output_dir, dataset_name):\n    \"\"\"\n    Shuffle a list of documents, write three sections\n    \"\"\"\n    random.shuffle(documents)\n    num_train = int(len(documents) * 0.85)\n    num_dev = int(len(documents) * 0.15)\n    
os.makedirs(output_dir, exist_ok=True)\n    write_section(output_dir, dataset_name, 'train', documents[:num_train])\n    write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev])\n    write_section(output_dir, dataset_name, 'test', test_documents)\n\n\ndef reprocess_lines(processed_lines):\n    \"\"\"\n    Reprocesses lines using pythainlp to cut up sentences into shorter sentences.\n\n    Many of the lines in BEST seem to be multiple Thai sentences concatenated, according to native Thai speakers.\n\n    Input: a list of lines, where each line is a list of words.  Space characters can be included as words\n    Output: a new list of lines, resplit using pythainlp\n    \"\"\"\n    reprocessed_lines = []\n    for line in processed_lines:\n        text = \"\".join(line)\n        try:\n            chunks = sent_tokenize(text)\n        except NameError as e:\n            raise NameError(\"Sentences cannot be reprocessed without first installing pythainlp\") from e\n        # Check that the total text back is the same as the text in\n        if sum(len(x) for x in chunks) != len(text):\n            raise ValueError(\"Got unexpected text length: \\n{}\\nvs\\n{}\".format(text, chunks))\n\n        chunk_lengths = [len(x) for x in chunks]\n\n        current_length = 0\n        new_line = []\n        for word in line:\n            if len(word) + current_length < chunk_lengths[0]:\n                new_line.append(word)\n                current_length = current_length + len(word)\n            elif len(word) + current_length == chunk_lengths[0]:\n                new_line.append(word)\n                reprocessed_lines.append(new_line)\n                new_line = []\n                chunk_lengths = chunk_lengths[1:]\n                current_length = 0\n            else:\n                remaining_len = chunk_lengths[0] - current_length\n                new_line.append(word[:remaining_len])\n                reprocessed_lines.append(new_line)\n           
     word = word[remaining_len:]\n                chunk_lengths = chunk_lengths[1:]\n                while len(word) > chunk_lengths[0]:\n                    new_line = [word[:chunk_lengths[0]]]\n                    reprocessed_lines.append(new_line)\n                    word = word[chunk_lengths[0]:]\n                    chunk_lengths = chunk_lengths[1:]\n                new_line = [word]\n                current_length = len(word)\n        reprocessed_lines.append(new_line)\n    return reprocessed_lines\n\ndef convert_processed_lines(processed_lines):\n    \"\"\"\n    Convert a list of sentences into documents suitable for the output methods in this module.\n\n    Input: a list of lines, including space words\n    Output: a list of documents, each document containing a list of sentences\n            Each sentence is a list of words: (text, space_follows)\n            Space words will be eliminated.\n    \"\"\"\n    paragraphs = []\n    sentences = []\n    for words in processed_lines:\n        # turn the words into a sentence\n        if len(words) > 1 and \" \" == words[0]:\n            words = words[1:]\n        elif len(words) == 1 and \" \" == words[0]:\n            words = []\n\n        sentence = []\n        for word in words:\n            word = word.strip()\n            if not word:\n                if len(sentence) == 0:\n                    # no filename is available in this function, so report the offending line instead\n                    raise ValueError(\"Unexpected space at start of sentence: {}\".format(words))\n                sentence[-1] = (sentence[-1][0], True)\n            else:\n                sentence.append((word, False))\n        # blank lines are very rare in best, but why not treat them as a paragraph break\n        if len(sentence) == 0:\n            paragraphs.append([sentences])\n            sentences = []\n            continue\n        sentence[-1] = (sentence[-1][0], True)\n        sentences.append(sentence)\n    paragraphs.append([sentences])\n    return paragraphs\n\n\n\n\n\n"
  },
  {
    "path": "stanza/utils/datasets/vietnamese/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/datasets/vietnamese/renormalize.py",
    "content": "\"\"\"\nScript to renormalize diacritics for Vietnamese text\n\nfrom BARTpho\nhttps://github.com/VinAIResearch/BARTpho/blob/main/VietnameseToneNormalization.md\nhttps://github.com/VinAIResearch/BARTpho/blob/main/LICENSE\n\nMIT License\n\nCopyright (c) 2021 VinAI Research\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nimport argparse\nimport os\n\nDICT_MAP = {\n    \"òa\": \"oà\",\n    \"Òa\": \"Oà\",\n    \"ÒA\": \"OÀ\",\n    \"óa\": \"oá\",\n    \"Óa\": \"Oá\",\n    \"ÓA\": \"OÁ\",\n    \"ỏa\": \"oả\",\n    \"Ỏa\": \"Oả\",\n    \"ỎA\": \"OẢ\",\n    \"õa\": \"oã\",\n    \"Õa\": \"Oã\",\n    \"ÕA\": \"OÃ\",\n    \"ọa\": \"oạ\",\n    \"Ọa\": \"Oạ\",\n    \"ỌA\": \"OẠ\",\n    \"òe\": \"oè\",\n    \"Òe\": \"Oè\",\n    \"ÒE\": \"OÈ\",\n    \"óe\": \"oé\",\n    \"Óe\": \"Oé\",\n    \"ÓE\": \"OÉ\",\n    \"ỏe\": \"oẻ\",\n    \"Ỏe\": \"Oẻ\",\n    \"ỎE\": \"OẺ\",\n    \"õe\": \"oẽ\",\n    \"Õe\": \"Oẽ\",\n    \"ÕE\": \"OẼ\",\n    \"ọe\": \"oẹ\",\n    \"Ọe\": \"Oẹ\",\n    \"ỌE\": \"OẸ\",\n    \"ùy\": \"uỳ\",\n    \"Ùy\": \"Uỳ\",\n    \"ÙY\": \"UỲ\",\n    \"úy\": \"uý\",\n    \"Úy\": \"Uý\",\n    \"ÚY\": \"UÝ\",\n    \"ủy\": \"uỷ\",\n    \"Ủy\": \"Uỷ\",\n    \"ỦY\": \"UỶ\",\n    \"ũy\": \"uỹ\",\n    \"Ũy\": \"Uỹ\",\n    \"ŨY\": \"UỸ\",\n    \"ụy\": \"uỵ\",\n    \"Ụy\": \"Uỵ\",\n    \"ỤY\": \"UỴ\",\n}\n\n\ndef replace_all(text):\n    for i, j in DICT_MAP.items():\n        text = text.replace(i, j)\n    return text\n\ndef convert_file(org_file, new_file):\n    with open(org_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer:\n        content = reader.readlines()\n        for line in content:\n            new_line = replace_all(line)\n            writer.write(new_line)\n\ndef convert_files(file_list, new_dir):\n    for file_name in file_list:\n        base_name = os.path.split(file_name)[-1]\n        new_file_path = os.path.join(new_dir, base_name)\n\n        convert_file(file_name, new_file_path)\n\n\ndef convert_dir(org_dir, new_dir, suffix):\n    os.makedirs(new_dir, exist_ok=True)\n  
  file_list = os.listdir(org_dir)\n    file_list = [os.path.join(org_dir, f) for f in file_list if os.path.splitext(f)[1] == suffix]\n    convert_files(file_list, new_dir)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description='Script that renormalizes diacritics'\n    )\n\n    parser.add_argument(\n        'orig',\n        help='Location of the original directory'\n    )\n\n    parser.add_argument(\n        'converted',\n        help='The location of new directory'\n    )\n\n    parser.add_argument(\n        '--suffix',\n        type=str,\n        default='.txt',\n        help='Which suffix to look for when renormalizing a directory'\n    )\n\n    args = parser.parse_args()\n\n    if os.path.isfile(args.orig):\n        convert_file(args.orig, args.converted)\n    else:\n        convert_dir(args.orig, args.converted, args.suffix)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/default_paths.py",
    "content": "import os\n\ndef get_default_paths():\n    \"\"\"\n    Gets base paths for the data directories\n\n    If DATA_ROOT is set in the environment, use that as the root\n    otherwise use \"./data\"\n    individual paths can also be set in the environment\n    \"\"\"\n    DATA_ROOT = os.environ.get(\"DATA_ROOT\", \"data\")\n    defaults = {\n        \"TOKENIZE_DATA_DIR\": DATA_ROOT + \"/tokenize\",\n        \"MWT_DATA_DIR\": DATA_ROOT + \"/mwt\",\n        \"LEMMA_DATA_DIR\": DATA_ROOT + \"/lemma\",\n        \"POS_DATA_DIR\": DATA_ROOT + \"/pos\",\n        \"DEPPARSE_DATA_DIR\": DATA_ROOT + \"/depparse\",\n        \"ETE_DATA_DIR\": DATA_ROOT + \"/ete\",\n        \"NER_DATA_DIR\": DATA_ROOT + \"/ner\",\n        \"CHARLM_DATA_DIR\": DATA_ROOT + \"/charlm\",\n        \"SENTIMENT_DATA_DIR\": DATA_ROOT + \"/sentiment\",\n        \"CONSTITUENCY_DATA_DIR\": DATA_ROOT + \"/constituency\",\n        \"COREF_DATA_DIR\": DATA_ROOT + \"/coref\",\n        \"LEMMA_CLASSIFIER_DATA_DIR\": DATA_ROOT + \"/lemma_classifier\",\n\n        # Set directories to store external word vector data\n        \"WORDVEC_DIR\": \"extern_data/wordvec\",\n\n        # TODO: not sure what other people actually have\n        # TODO: also, could make this automatically update to the latest\n        \"UDBASE\": \"extern_data/ud2/ud-treebanks-v2.11\",\n        \"UDBASE_GIT\": \"extern_data/ud2/git\",\n\n        \"NERBASE\": \"extern_data/ner\",\n        \"CONSTITUENCY_BASE\": \"extern_data/constituency\",\n        \"SENTIMENT_BASE\": \"extern_data/sentiment\",\n        \"COREF_BASE\": \"extern_data/coref\",\n\n        # there's a stanford github, stanfordnlp/handparsed-treebank,\n        # with some data for different languages\n        \"HANDPARSED_DIR\": \"extern_data/handparsed-treebank\",\n\n        # directory with the contents of https://nlp.stanford.edu/projects/stanza/bio/\n        # on the cluster, for example, /u/nlp/software/stanza/bio_ud\n        \"BIO_UD_DIR\": 
\"extern_data/bio\",\n\n        # data root for other general input files, such as VI_VLSP\n        \"STANZA_EXTERN_DIR\": \"extern_data\",\n    }\n\n    paths = { \"DATA_ROOT\" : DATA_ROOT }\n    for k, v in defaults.items():\n        paths[k] = os.environ.get(k, v)\n\n    return paths\n"
  },
  {
    "path": "stanza/utils/get_tqdm.py",
    "content": "import sys\n\ndef get_tqdm():\n    \"\"\"\n    Return a tqdm appropriate for the situation\n\n    imports tqdm depending on if we're at a console, redir to a file, notebook, etc\n\n    from @tcrimi at https://github.com/tqdm/tqdm/issues/506\n\n    This replaces `import tqdm`, so for example, you do this:\n      from stanza.utils.get_tqdm import get_tqdm\n      tqdm = get_tqdm()\n    then do this when you want a scroll bar or regular iterator depending on context:\n      tqdm(list)\n\n    If there is no tty, the returned tqdm will always be disabled\n    unless disable=False is specifically set.\n    \"\"\"\n    ipy_str = \"\"\n    try:\n        from IPython import get_ipython\n        ipy_str = str(type(get_ipython()))\n    except ImportError:\n        pass\n\n    if 'zmqshell' in ipy_str:\n        from tqdm import tqdm_notebook as tqdm\n        return tqdm\n    if 'terminal' in ipy_str:\n        from tqdm import tqdm\n        return tqdm\n\n    if sys.stderr is not None and hasattr(sys.stderr, \"isatty\") and sys.stderr.isatty():\n        from tqdm import tqdm\n        return tqdm\n\n    from tqdm import tqdm\n    def hidden_tqdm(*args, **kwargs):\n        if \"disable\" in kwargs:\n            return tqdm(*args, **kwargs)\n        kwargs[\"disable\"] = True\n        return tqdm(*args, **kwargs)\n\n    return hidden_tqdm\n\n"
  },
  {
    "path": "stanza/utils/helper_func.py",
    "content": "def make_table(header, content, column_width=None):\n    '''\n    Input:\n    header -> List[str]: table header\n    content -> List[List[str]]: table content\n    column_width -> int: table column width; set to None for dynamically calculated widths\n    \n    Output:\n    table_str -> str: well-formatted string for the table\n    '''\n    table_str = ''\n    len_column, len_row = len(header), len(content) + 1\n    if column_width is None:\n        # dynamically decide column widths\n        lens = [[len(str(h)) for h in header]]\n        lens += [[len(str(x)) for x in row] for row in content]\n        column_widths = [max(c)+3 for c in zip(*lens)]\n    else:\n        column_widths = [column_width] * len_column\n    \n    table_str += '=' * (sum(column_widths) + 1) + '\\n'\n    \n    table_str += '|'\n    for i, item in enumerate(header):\n        table_str += ' ' + str(item).ljust(column_widths[i] - 2) + '|'\n    table_str += '\\n'\n    \n    table_str += '-' * (sum(column_widths) + 1) + '\\n'\n    \n    for line in content:\n        table_str += '|'\n        for i, item in enumerate(line):\n            table_str += ' ' + str(item).ljust(column_widths[i] - 2) + '|'\n        table_str += '\\n'\n    \n    table_str += '=' * (sum(column_widths) + 1) + '\\n'\n    \n    return table_str\n"
  },
  {
    "path": "stanza/utils/languages/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/languages/kazakh_transliteration.py",
    "content": "\"\"\"\nKazakh Transliteration:\n    Cyrillic Kazakh --> Latin Kazakh\n\n\n\"\"\"\n\nimport argparse\nimport os\nfrom re import M\nimport string\nimport sys\n\nfrom stanza.models.common.utils import open_read_text, get_tqdm\ntqdm = get_tqdm()\n\n\"\"\"\nThis dictionary isn't used in the code, just put this here in case you want to implement it more\nefficiently and in case the need to look up the unicode encodings for these letters might arise.\nSome letters are mapped to multiple latin letters, for these, I separated the unicde with a '%' delimiter\nbetween the two unicode characters.\n\"\"\"\nalph_map = {\n    '\\u0410' # А\n    : '\\u0041', # A\n    '\\u0430' # а\n    : '\\u0061', # a\n\n    '\\u04D8' # Ә\n    : '\\u00c4', # Ä\n    '\\u04D9' # ә\n    : '\\u00e4', # ä\n\n    '\\u0411' # Б\n    : '\\u0042', # B\n    '\\u0431' # б\n    : '\\u0062', # b\n\n    '\\u0412' # В\n    : '\\u0056', # V\n    '\\u0432' # в\n    : '\\u0076', # v\n\n    '\\u0413' # Г\n    : '\\u0047', # G\n    '\\u0433' # г\n    : '\\u0067', # g\n\n    '\\u0492' # Ғ\n    : '\\u011e', # Ğ\n    '\\u0493' # ғ\n    : '\\u011f', # ğ\n\n    '\\u0414' # Д\n    : '\\u0044', # D\n    '\\u0434' # д\n    : '\\u0064', # d\n\n    '\\u0415' # Е\n    : '\\u0045', # E\n    '\\u0435' # е\n    : '\\u0065', # e\n\n    '\\u0401' # Ё\n    : '\\u0130%\\u006f', # İo\n    '\\u0451' # ё\n    : '\\u0069%\\u006f', #io\n\n    '\\u0416' # Ж\n    : '\\u004a', # J\n    '\\u0436' # ж\n    : '\\u006a', # j\n\n    '\\u0417' # З\n    : '\\u005a', # Z\n    '\\u0437' # з\n    : '\\u007a', # z\n\n    '\\u0418' # И\n    : '\\u0130', # İ\n    '\\u0438' # и\n    : '\\u0069', # i\n\n    '\\u0419' # Й\n    : '\\u0130', # İ\n    '\\u0439' # й\n    : '\\u0069', # i\n\n    '\\u041A' # К\n    : '\\u004b', # K\n    '\\u043A' # к\n    : '\\u006b', # k\n\n    '\\u049A' # Қ\n    : '\\u0051', # Q\n    '\\u049B' # қ\n    : '\\u0071', # q\n\n    '\\u041B' # Л\n    : '\\u004c', # L\n    '\\u043B' # л\n    : '\\u006c', # l\n\n    
'\\u041C' # М\n    : '\\u004d', # M\n    '\\u043C' # м\n    : '\\u006d', # m\n\n    '\\u041D' # Н\n    : '\\u004e', # N\n    '\\u043D' # н\n    : '\\u006e', # n\n\n    '\\u04A2' # Ң\n    : '\\u00d1', # Ñ\n    '\\u04A3' # ң\n    : '\\u00f1', # ñ\n\n    '\\u041E' # О\n    : '\\u004f', # O\n    '\\u043E' # о\n    : '\\u006f', # o\n\n    '\\u04E8' # Ө\n    : '\\u00d6', # Ö\n    '\\u04E9' # ө\n    : '\\u00f6', # ö\n\n    '\\u041F' # П\n    : '\\u0050', # P\n    '\\u043F' # п\n    : '\\u0070', # p\n\n    '\\u0420' # Р\n    : '\\u0052', # R\n    '\\u0440' # р\n    : '\\u0072', # r\n\n    '\\u0421' # С\n    : '\\u0053', # S\n    '\\u0441' # с\n    : '\\u0073', # s\n\n    '\\u0422' # Т\n    : '\\u0054', # T\n    '\\u0442' # т\n    : '\\u0074', # t\n\n    '\\u0423' # У\n    : '\\u0055', # U\n    '\\u0443' # у\n    : '\\u0075', # u\n\n    '\\u04B0' # Ұ\n    : '\\u016a', # Ū\n    '\\u04B1' # ұ\n    : '\\u016b', # ū\n\n    '\\u04AE' # Ү\n    : '\\u00dc', # Ü\n    '\\u04AF' # ү\n    : '\\u00fc', # ü\n\n    '\\u0424' # Ф\n    : '\\u0046', # F\n    '\\u0444' # ф\n    : '\\u0066', # f\n\n    '\\u0425' # Х\n    : '\\u0048', # H\n    '\\u0445' # х\n    : '\\u0068', # h\n\n    '\\u04BA' # Һ\n    : '\\u0048', # H\n    '\\u04BB' # һ\n    : '\\u0068', # h\n\n    '\\u0426' # Ц\n    : '\\u0043', # C\n    '\\u0446' # ц\n    : '\\u0063', # c\n\n    '\\u0427' # Ч\n    : '\\u00c7', # Ç\n    '\\u0447' # ч\n    : '\\u00e7', # ç\n\n    '\\u0428' # Ш\n    : '\\u015e', # Ş\n    '\\u0448' # ш\n    : '\\u015f', # ş\n\n    '\\u0429' # Щ\n    : '\\u015e%\\u00e7', # Şç\n    '\\u0449' # щ\n    : '\\u015f%\\u00e7', # şç\n\n    '\\u042A' # Ъ\n    : '', # Empty String\n    '\\u044A' # ъ\n    : '', # Empty String \\u\n\n    '\\u042B' # Ы\n    : '\\u0059', # Y\n    '\\u044B' # ы\n    : '\\u0079', # y\n\n    '\\u0406' # І\n    : '\\u0130', # İ\n    '\\u0456' # і\n    : '\\u0069', # i\n\n    '\\u042C' # Ь\n    : '', # Empty String\n    '\\u044C' # ь\n    : '', # Empty String\n\n    '\\u042D' # Э\n    : 
'\\u0045', # E\n    '\\u044D' # э\n    : '\\u0065', # e\n\n    '\\u042E' # Ю\n    : '\\u0130%\\u0075', # İu\n    '\\u044E' # ю\n    : '\\u0069%\\u0075', # iu\n\n    '\\u042F' # Я\n    : '\\u0130%\\u0061', # İa\n    '\\u044F' # я\n    : '\\u0069%\\u0061' # ia\n}\n\nkazakh_alph = \"АаӘәБбВвГгҒғДдЕеЁёЖжЗзИиЙйКкҚқЛлМмНнҢңОоӨөПпРрСсТтУуҰұҮүФфХхҺһЦцЧчШшЩщЪъЫыІіЬьЭэЮюЯя\"\nlatin_alph = \"AaÄäBbVvGgĞğDdEeİoioJjZzİiİiKkQqLlMmNnÑñOoÖöPpRrSsTtUuŪūÜüFfHhHhCcÇçŞşŞçşçYyİiEeİuiuİaia\"\nmult_mapping = \"ЁёЩщЮюЯя\"\nempty_mapping = \"ЪъЬь\"\n\n\n\"\"\"\nϵ : Ukrainian letter for 'ё'\nə : Russian utf-8 encoding for Kazakh 'ә'\nó : 2016 Kazakh Latin adopted this instead of 'ö'\nã : 1 occurrence in the dataset -- mapped to 'a'\n\n\"\"\"\nrussian_alph = \"ϵəóã\"\nrussian_counterpart = \"ioäaöa\"\ndef create_dic(source_alph, target_alph, mult_mapping, empty_mapping):\n    res = {}\n    idx = 0\n    for i in range(len(source_alph)):\n        l_s = source_alph[i]\n        if l_s in mult_mapping:\n            res[l_s] = target_alph[idx] + target_alph[idx+1]\n            idx += 1\n        elif l_s in empty_mapping:\n            res[l_s] = ''\n            idx -= 1\n        else:\n            res[l_s] = target_alph[idx]\n        idx += 1\n\n    res['ϵ'] = 'io'\n    res['ə'] = 'ä'\n    res['ó'] = 'ö'\n    res['ã'] = 'a'\n\n    print(res)\n    return res\n\n\nsupp_alph = \"IWwXx0123456789–«»—\"\n\ndef transliterate(source):\n    output = \"\"\n    tr_dict = create_dic(kazakh_alph, latin_alph, mult_mapping, empty_mapping)\n    punc = string.punctuation\n    white_spc = string.whitespace\n    for c in source:\n        if c in punc or c in white_spc:\n            output += c\n\n        elif c in latin_alph or c in supp_alph:\n            output += c\n\n        elif c in tr_dict:\n            output += tr_dict[c]\n\n        else:\n            print(f\"Transliteration Error: {c}\")\n\n    return output\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    
parser.add_argument('input_file', type=str, nargs=\"+\", help=\"Files to process\")\n    parser.add_argument('--output_dir', type=str, default=None, help=\"Directory to output results\")\n    args = parser.parse_args()\n\n    tr_dict = create_dic(kazakh_alph, latin_alph, mult_mapping, empty_mapping)\n    for filename in tqdm(args.input_file):\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n            directory, basename = os.path.split(filename)\n            output_name = os.path.join(args.output_dir, basename)\n            if output_name.endswith(\".xz\"):\n                output_name = output_name[:-3]\n            output_name = output_name + \".trans\"\n        else:\n            output_name = filename + \".trans\"\n\n        tqdm.write(\"Transliterating %s to %s\" % (filename, output_name))\n\n        with open_read_text(filename) as f_in:\n            data = f_in.read()\n        with open(output_name, 'w') as f_out:\n            punc = string.punctuation\n            white_spc = string.whitespace\n            for c in tqdm(data, leave=False):\n                if c in tr_dict:\n                    f_out.write(tr_dict[c])\n\n                else:\n                    f_out.write(c)\n\n\n    print(\"Process Completed Successfully!\")\n\n"
  },
  {
    "path": "stanza/utils/lemma/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/lemma/count_ambiguous_lemmas.py",
    "content": "\"\"\"\nRead in a UD file, report any word/verb pairs which get lemmatized to different lemmas\n\"\"\"\n\nfrom collections import Counter, defaultdict\nimport sys\n\nfrom stanza.utils.conll import CoNLL\n\nfilename = sys.argv[1]\nprint(filename)\n\nlemma_counters = defaultdict(Counter)\n\ndoc = CoNLL.conll2doc(input_file=filename)\nfor sentence in doc.sentences:\n    for word in sentence.words:\n        text = word.text\n        upos = word.upos\n        lemma = word.lemma\n\n        lemma_counters[(text, upos)][lemma] += 1\n\nkeys = lemma_counters.keys()\nkeys = sorted(keys, reverse=True, key=lambda x: sum(lemma_counters[x][y] for y in lemma_counters[x]))\nfor text, upos in keys:\n    if len(lemma_counters[(text, upos)]) > 1:\n        print(text, upos, lemma_counters[(text, upos)])\n\n"
  },
  {
    "path": "stanza/utils/max_mwt_length.py",
    "content": "import sys\n\nimport json\n\ndef max_mwt_length(filenames):\n    max_len = 0\n    for filename in filenames:\n        with open(filename) as f:\n            d = json.load(f)\n            max_len = max([max_len] + [len(\" \".join(x[0][1])) for x in d])\n    return max_len\n\nif __name__ == '__main__':\n    print(max_max_jlength(sys.argv[1:]))\n"
  },
  {
    "path": "stanza/utils/ner/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/ner/flair_ner_tag_dataset.py",
    "content": "\"\"\"\nTest a flair model on a 4 class dataset\n\"\"\"\n\nimport argparse\nimport json\n\nfrom flair.data import Sentence\nfrom flair.models import SequenceTagger\n\nfrom stanza.models.ner.utils import process_tags\nfrom stanza.models.ner.scorer import score_by_entity, score_by_token\n\ndef test_file(eval_file, tagger):\n    with open(eval_file) as fin:\n        gold_doc = json.load(fin)\n    gold_doc = [[(x['text'], x['ner']) for x in sentence] for sentence in gold_doc]\n    gold_doc = process_tags(gold_doc, 'bioes')\n\n    pred_doc = []\n    for gold_sentence in gold_doc:\n        pred_sentence = [[x[0], 'O'] for x in gold_sentence]\n        flair_sentence = Sentence(\" \".join(x[0] for x in pred_sentence), use_tokenizer=False)\n        tagger.predict(flair_sentence)\n\n        for entity in flair_sentence.get_spans('ner'):\n            tag = entity.tag\n            tokens = entity.tokens\n            start_idx = tokens[0].idx - 1\n            end_idx = tokens[-1].idx\n            if len(tokens) == 1:\n                pred_sentence[start_idx][1] = \"S-\" + tag\n            else:\n                pred_sentence[start_idx][1] = \"B-\" + tag\n                pred_sentence[end_idx - 1][1] = \"E-\" + tag\n                for idx in range(start_idx+1, end_idx - 1):\n                    pred_sentence[idx][1] = \"I-\" + tag\n\n        pred_doc.append(pred_sentence)\n\n    pred_tags = [[x[1] for x in sentence] for sentence in pred_doc]\n    gold_tags = [[x[1] for x in sentence] for sentence in gold_doc]\n    print(\"RESULTS ON: %s\" % eval_file)\n    _, _, f_micro, _ = score_by_entity(pred_tags, gold_tags)\n    score_by_token(pred_tags, gold_tags)\n    return f_micro\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--ner_model', type=str, default=None,  help='Which NER model to test')\n    parser.add_argument('filename', type=str, nargs='*', help='which files to test')\n    args = parser.parse_args()\n\n    if args.ner_model 
is None:\n        ner_models = [\"ner-fast\", \"ner\", \"ner-large\"]\n    else:\n        ner_models = [args.ner_model]\n\n    if not args.filename:\n        args.filename = [\"data/ner/en_conll03.test.json\",\n                         \"data/ner/en_worldwide-4class.test.json\",\n                         \"data/ner/en_worldwide-4class-africa.test.json\",\n                         \"data/ner/en_worldwide-4class-asia.test.json\",\n                         \"data/ner/en_worldwide-4class-indigenous.test.json\",\n                         \"data/ner/en_worldwide-4class-latam.test.json\",\n                         \"data/ner/en_worldwide-4class-middle_east.test.json\"]\n\n    print(\"Processing the files: %s\" % \",\".join(args.filename))\n\n    results = []\n    model_results = {}\n\n    for ner_model in ner_models:\n        model_results[ner_model] = []\n\n        # load tagger\n        #tagger = SequenceTagger.load(\"ner-fast\")\n        print(\"-----------------------------\")\n        print(\"Running %s\" % ner_model)\n        print(\"-----------------------------\")\n        tagger = SequenceTagger.load(ner_model)\n\n        for filename in args.filename:\n            f_micro = test_file(filename, tagger)\n            f_micro = \"%.2f\" % (f_micro * 100)\n            results.append((ner_model, filename, f_micro))\n            model_results[ner_model].append(f_micro)\n\n    for result in results:\n        print(result)\n\n    for model in model_results.keys():\n        result = [model] + model_results[model]\n        print(\" & \".join(result))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/ner/paying_annotators.py",
    "content": "import json\nimport os\n\n\ndef get_worker_subs(json_string):\n    \"\"\"\n    Gets the AWS worker IDs from the annotation file in output folder.\n\n    Returns a list of the AWS worker subs\n    \"\"\"\n    subs = []\n    # json.loads() works on JSON strings, json.load() is for JSON files\n    job_data = json.loads(json_string)\n    for i in range(len(job_data[\"answers\"])):\n        subs.append(job_data[\"answers\"][i][\"workerMetadata\"][\"identityData\"][\"sub\"])\n    return subs\n\n\ndef track_tasks(input_path, worker_map=None):\n    \"\"\"\n    Takes a path to a folder containing the worker annotation metadata from AWS Sagemaker labeling job and a\n    dictionary mapping AWS worker subs to their names or identification tags and returns a dictionary mapping\n    the names/identification tags to the number of labeling tasks completed.\n\n    If no worker map is provided, this function returns a dictionary mapping the worker \"sub\" fields to\n    the number of tasks they completed.\n\n    :param input_path: string of the path to the directory containing the worker annotation sub-directories\n    :param worker_map: dictionary mapping AWS worker subs to the worker identifications\n    :return: dictionary mapping worker identifications to the number of tasks completed\n    \"\"\"\n    tracker = {}\n    res = {}\n    for direc in os.listdir(input_path):\n        subdir_path = os.path.join(input_path, direc)\n        subdir = os.listdir(subdir_path)\n        json_file_path = os.path.join(subdir_path, subdir[0])\n        with open(json_file_path) as json_file:\n            json_string = json_file.read()\n        subs = get_worker_subs(json_string)\n        for sub in subs:\n            tracker[sub] = tracker.get(sub, 0) + 1\n\n    if worker_map:\n        for sub in tracker:\n            worker = worker_map[sub]\n            res[worker] = tracker[sub]\n        return res\n    return tracker\n\n\ndef main():\n    # sample from completed labeling job\n 
   print(track_tasks('..\\\\tests\\\\ner\\\\aws_labeling_copy', worker_map={\n        \"7efc17ac-3397-4472-afe5-89184ad145d0\": \"Worker1\",\n        \"afce8c28-969c-4e73-a20f-622ef122f585\": \"Worker2\",\n        \"91f6236e-63c6-4a84-8fd6-1efbab6dedab\": \"Worker3\",\n        \"6f202e93-e6b6-4e1d-8f07-0484b9a9093a\": \"Worker4\",\n        \"2b674d33-f656-44b0-8f90-d70a1ab71ec2\": \"Worker5\"\n        }\n    ))\n    # sample from completed labeling job -- no worker map provided\n    print(track_tasks('..\\\\tests\\\\ner\\\\aws_labeling_copy'))\n    return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/ner/spacy_ner_tag_dataset.py",
    "content": "\"\"\"\nTest a spacy model on a 4 class dataset\n\"\"\"\n\nimport argparse\nimport json\n\nimport spacy\nfrom spacy.tokens import Doc\n\nfrom stanza.models.ner.utils import process_tags\nfrom stanza.models.ner.scorer import score_by_entity, score_by_token\n\nfrom stanza.utils.confusion import format_confusion\nfrom stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide\n\nfrom stanza.utils.get_tqdm import get_tqdm\ntqdm = get_tqdm()\n\n\"\"\"\nSimplified classes used in the Worldwide dataset are:\n\nDate\nFacility\nLocation\nMisc\nMoney\nNORP\nOrganization\nPerson\nProduct\n\nvs OntoNotes classes:\n\nCARDINAL\nDATE\nEVENT\nFAC\nGPE\nLANGUAGE\nLAW\nLOC\nMONEY\nNORP\nORDINAL\nORG\nPERCENT\nPERSON\nPRODUCT\nQUANTITY\nTIME\nWORK_OF_ART\n\"\"\"\n\ndef test_file(eval_file, tagger, simplify):\n    with open(eval_file) as fin:\n        gold_doc = json.load(fin)\n    gold_doc = [[(x['text'], x['ner']) for x in sentence] for sentence in gold_doc]\n    gold_doc = process_tags(gold_doc, 'bioes')\n\n    if simplify:\n        for doc in gold_doc:\n            for idx, word in enumerate(doc):\n                if word[1] != \"O\":\n                    word = [word[0], simplify_ontonotes_to_worldwide(word[1])]\n                    doc[idx] = word\n\n    ignore_tags = \"Date,DATE\" if simplify else None\n\n    original_text = [[x[0] for x in gold_sentence] for gold_sentence in gold_doc]\n    pred_doc = []\n    for sentence in tqdm(original_text):\n        spacy_sentence = Doc(tagger.vocab, sentence)\n        spacy_sentence = tagger(spacy_sentence)\n        entities = [\"O\" if not token.ent_type_ else \"%s-%s\" % (token.ent_iob_, token.ent_type_) for token in spacy_sentence]\n        if simplify:\n            entities = [simplify_ontonotes_to_worldwide(x) for x in entities]\n        pred_sentence = [[token.text, entity] for token, entity in zip(spacy_sentence, entities)]\n        pred_doc.append(pred_sentence)\n\n    pred_doc 
= process_tags(pred_doc, 'bioes')\n    pred_tags = [[x[1] for x in sentence] for sentence in pred_doc]\n    gold_tags = [[x[1] for x in sentence] for sentence in gold_doc]\n    print(\"RESULTS ON: %s\" % eval_file)\n    _, _, f_micro, _ = score_by_entity(pred_tags, gold_tags, ignore_tags=ignore_tags)\n    _, _, _, confusion = score_by_token(pred_tags, gold_tags, ignore_tags=ignore_tags)\n    print(\"NER token confusion matrix:\\n{}\".format(format_confusion(confusion, hide_blank=True, transpose=True)))\n    return f_micro\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--ner_model', type=str, default=None,  help='Which spacy model to test')\n    parser.add_argument('filename', type=str, nargs='*', help='which files to test')\n    parser.add_argument('--simplify', default=False, action='store_true', help='Simplify classes to the 8 class Worldwide model')\n    args = parser.parse_args()\n\n    if args.ner_model is None:\n        ner_models = ['en_core_web_sm', 'en_core_web_trf']\n    else:\n        ner_models = [args.ner_model]\n\n    if not args.filename:\n        args.filename = [\"data/ner/en_ontonotes-8class.test.json\",\n                         \"data/ner/en_worldwide-8class.test.json\",\n                         \"data/ner/en_worldwide-8class-africa.test.json\",\n                         \"data/ner/en_worldwide-8class-asia.test.json\",\n                         \"data/ner/en_worldwide-8class-indigenous.test.json\",\n                         \"data/ner/en_worldwide-8class-latam.test.json\",\n                         \"data/ner/en_worldwide-8class-middle_east.test.json\"]\n\n    print(\"Processing the files: %s\" % \",\".join(args.filename))\n\n    results = []\n    model_results = {}\n\n    for ner_model in ner_models:\n        model_results[ner_model] = []\n        # load tagger\n        print(\"-----------------------------\")\n        print(\"Running %s\" % ner_model)\n        print(\"-----------------------------\")\n        
tagger = spacy.load(ner_model, disable=[\"tagger\", \"parser\", \"attribute_ruler\", \"lemmatizer\"])\n\n        for filename in args.filename:\n            f_micro = test_file(filename, tagger, args.simplify)\n            f_micro = \"%.2f\" % (f_micro * 100)\n            results.append((ner_model, filename, f_micro))\n            model_results[ner_model].append(f_micro)\n\n    for result in results:\n        print(result)\n\n    for model in model_results.keys():\n        result = [model] + model_results[model]\n        print(\" & \".join(result))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/pretrain/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/pretrain/compare_pretrains.py",
    "content": "import sys\nimport numpy as np\n\nfrom stanza.models.common.pretrain import Pretrain\n\npt1_filename = sys.argv[1]\npt2_filename = sys.argv[2]\n\npt1 = Pretrain(pt1_filename)\npt2 = Pretrain(pt2_filename)\n\nvocab1 = pt1.vocab\nvocab2 = pt2.vocab\n\ncommon_words = [x for x in vocab1 if x in vocab2]\nprint(\"%d shared words, out of %d in %s and %d in %s\" % (len(common_words), len(vocab1), pt1_filename, len(vocab2), pt2_filename))\n\neps = 0.0001\ntotal_norm = 0.0\ntotal_close = 0\n\nwords_different = []\n\nfor word, idx in vocab1._unit2id.items():\n    if word not in vocab2:\n        continue\n    v1 = pt1.emb[idx]\n    v2 = pt2.emb[pt2.vocab[word]]\n    norm = np.linalg.norm(v1 - v2)\n\n    if norm < eps:\n        total_close += 1\n    else:\n        total_norm += norm\n        if len(words_different) < 10:\n            words_different.append(\"|%s|\" % word)\n            #print(word, idx, pt2.vocab[word])\n            #print(v1)\n            #print(v2)\n\nif total_close < len(common_words):\n    avg_norm = total_norm / (len(common_words) - total_close)\n    print(\"%d vectors were close.  Average difference of the others: %f\" % (total_close, avg_norm))\n    print(\"The first few different words were:\\n  %s\" % \"\\n  \".join(words_different))\nelse:\n    print(\"All %d vectors were close!\" % total_close)\n\n    for word, idx in vocab1._unit2id.items():\n        if word not in vocab2:\n            continue\n        if pt2.vocab[word] != idx:\n            break\n    else:\n        print(\"All indices are the same\")\n"
  },
  {
    "path": "stanza/utils/select_backoff.py",
    "content": "import sys\n\nbackoff_models = { \"UD_Breton-KEB\": \"ga_idt\",\n                   \"UD_Czech-PUD\": \"cs_pdt\",\n                   \"UD_English-PUD\": \"en_ewt\",\n                   \"UD_Faroese-OFT\": \"nn_nynorsk\",\n                   \"UD_Finnish-PUD\": \"fi_tdt\",\n                   \"UD_Japanese-Modern\": \"ja_gsd\",\n                   \"UD_Naija-NSC\": \"en_ewt\",\n                   \"UD_Swedish-PUD\": \"sv_talbanken\"\n                 }\n\nprint(backoff_models[sys.argv[1]])\n"
  },
  {
    "path": "stanza/utils/training/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/training/common.py",
    "content": "import argparse\nimport glob\nimport logging\nimport os\nimport pathlib\nimport random\nimport sys\n\nfrom enum import Enum\ntry:\n    from udtools.udeval import build_evaluation_table\nexcept ImportError:\n    from udtools.src.udtools.udeval import build_evaluation_table\n\nfrom stanza.resources.default_packages import default_charlms, lemma_charlms, tokenizer_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS\nfrom stanza.resources.default_packages import no_pretrain_languages, pos_pretrains, depparse_pretrains, default_pretrains\nfrom stanza.models.common.constant import treebank_to_short_name\nfrom stanza.models.common.utils import ud_scores\nfrom stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError\nfrom stanza.utils.datasets import common\nimport stanza.utils.default_paths as default_paths\n\nlogger = logging.getLogger('stanza')\n\nclass Mode(Enum):\n    TRAIN = 1\n    SCORE_DEV = 2\n    SCORE_TEST = 3\n    SCORE_TRAIN = 4\n\nclass ArgumentParserWithExtraHelp(argparse.ArgumentParser):\n    def __init__(self, sub_argparse, *args, **kwargs):\n        super().__init__(*args, **kwargs)  # forwards all unused arguments\n\n        self.sub_argparse = sub_argparse\n\n    def print_help(self, file=None):\n        super().print_help(file=file)\n\n    def format_help(self):\n        help_text = super().format_help()\n        if self.sub_argparse is not None:\n            sub_text = self.sub_argparse.format_help().split(\"\\n\")\n            first_line = -1\n            for line_idx, line in enumerate(sub_text):\n                if line.strip().startswith(\"usage:\"):\n                    first_line = line_idx\n                elif first_line >= 0 and not line.strip():\n                    first_line = line_idx\n                    break\n            help_text = help_text + \"\\n\\nmodel arguments:\" + \"\\n\".join(sub_text[first_line:])\n        return help_text\n\n\ndef 
build_argparse(sub_argparse=None):\n    parser = ArgumentParserWithExtraHelp(sub_argparse=sub_argparse, formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--save_output', dest='save_output', default=False, action='store_true', help=\"Save output - default is to use a temp directory.\")\n\n    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on.  Use all_ud or ud_all for all UD treebanks')\n\n    parser.add_argument('--train', dest='mode', default=Mode.TRAIN, action='store_const', const=Mode.TRAIN, help='Run in train mode')\n    parser.add_argument('--score_dev', dest='mode', action='store_const', const=Mode.SCORE_DEV, help='Score the dev set')\n    parser.add_argument('--score_test', dest='mode', action='store_const', const=Mode.SCORE_TEST, help='Score the test set')\n    parser.add_argument('--score_train', dest='mode', action='store_const', const=Mode.SCORE_TRAIN, help='Score the train set as a test set.  Currently only implemented for some models')\n\n    # These arguments need to be here so we can identify if the model already exists in the user-specified home\n    # TODO: when all of the model scripts handle their own names, can eliminate this argument\n    parser.add_argument('--save_dir', type=str, default=None, help=\"Root dir for saving models.  If set, will override the model's default.\")\n    parser.add_argument('--save_name', type=str, default=None, help=\"Base name for saving models.  
If set, will override the model's default.\")\n\n    parser.add_argument('--charlm_only', action='store_true', default=False, help='When asking for ud_all, filter the ones which have charlms')\n    parser.add_argument('--transformer_only', action='store_true', default=False, help='When asking for ud_all, filter the ones for languages where we have transformers')\n\n    parser.add_argument('--force', dest='force', action='store_true', default=False, help='Retrain existing models')\n    return parser\n\ndef add_charlm_args(parser):\n    parser.add_argument('--charlm', default=\"default\", type=str, help='Which charlm to run on.  Will use the default charlm for this language/model if not set.  Set to None to turn off charlm for languages with a default charlm')\n    parser.add_argument('--no_charlm', dest='charlm', action=\"store_const\", const=None, help=\"Don't use a charlm, even if one is used by default for this package\")\n\ndef main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None, args=None):\n    \"\"\"\n    A main program for each of the run_xyz scripts\n\n    It collects the arguments and runs the main method for each dataset provided.\n    It also tries to look for an existing model and not overwrite it unless --force is provided\n\n    model_name can be a callable expecting the args\n      - the charlm, for example, needs this feature, since it makes\n        both forward and backward models\n    \"\"\"\n    if args is None:\n        logger.info(\"Training program called with:\\n\" + \" \".join(sys.argv))\n        args = sys.argv[1:]\n    else:\n        logger.info(\"Training program called with:\\n\" + \" \".join(args))\n\n    paths = default_paths.get_default_paths()\n\n    parser = build_argparse(sub_argparse)\n    if add_specific_args is not None:\n        add_specific_args(parser)\n    if '--extra_args' in sys.argv:\n        idx = sys.argv.index('--extra_args')\n       
 extra_args = sys.argv[idx+1:]\n        command_args = parser.parse_args(sys.argv[:idx])\n    else:\n        command_args, extra_args = parser.parse_known_args(args=args)\n\n    # Pass this through to the underlying model as well as use it here\n    # we don't put --save_name here for the awkward situation of\n    # --save_name being specified for an invocation with multiple treebanks\n    if command_args.save_dir:\n        extra_args.extend([\"--save_dir\", command_args.save_dir])\n\n    # if --no_seed is added to the args, we actually want to pick a seed here\n    # that way, save file names will be consistent...\n    # otherwise, it might try to use different save names when using the\n    # train and dev sets, if the random seed is used as part of the save name\n    while '--no_seed' in extra_args:\n        idx = extra_args.index('--no_seed')\n        random_seed = random.randint(0, 1000000000)\n        logger.info(\"Using random seed %d\", random_seed)\n        extra_args[idx:idx+1] = ['--seed', str(random_seed)]\n\n    if callable(model_name):\n        model_name = model_name(command_args)\n\n    mode = command_args.mode\n    treebanks = []\n\n    for treebank in command_args.treebanks:\n        # this is a really annoying typo to make if you copy/paste a\n        # UD directory name on the cluster and your job dies 30s after\n        # being queued for an hour\n        if treebank.endswith(\"/\"):\n            treebank = treebank[:-1]\n        if treebank.lower() in ('ud_all', 'all_ud'):\n            ud_treebanks = common.get_ud_treebanks(paths[\"UDBASE\"])\n            if choose_charlm_method is not None and command_args.charlm_only:\n                logger.info(\"Filtering ud_all treebanks to only those which can use charlm for this model\")\n                ud_treebanks = [x for x in ud_treebanks\n                                if choose_charlm_method(*treebank_to_short_name(x).split(\"_\", 1), 'default') is not None]\n            if 
command_args.transformer_only:\n                logger.info(\"Filtering ud_all treebanks to only those which can use a transformer for this model\")\n                ud_treebanks = [x for x in ud_treebanks if treebank_to_short_name(x).split(\"_\")[0] in TRANSFORMERS]\n            logger.info(\"Expanding %s to %s\", treebank, \" \".join(ud_treebanks))\n            treebanks.extend(ud_treebanks)\n        else:\n            treebanks.append(treebank)\n\n    for treebank_idx, treebank in enumerate(treebanks):\n        if treebank_idx > 0:\n            logger.info(\"=========================================\")\n\n        short_name = treebank_to_short_name(treebank)\n        logger.debug(\"%s: %s\" % (treebank, short_name))\n\n        save_name_args = []\n        if model_name != 'ete':\n            # ete is several models at once, so we don't set --save_name\n            # theoretically we could handle a parametrized save_name\n            if command_args.save_name:\n                save_name = command_args.save_name\n                # if there's more than 1 treebank, we can't save them all to this save_name\n                # we have to override that value for each treebank\n                if len(treebanks) > 1:\n                    save_name_dir, save_name_filename = os.path.split(save_name)\n                    save_name_filename = \"%s_%s\" % (short_name, save_name_filename)\n                    save_name = os.path.join(save_name_dir, save_name_filename)\n                    logger.info(\"Save file for %s model for %s: %s\", short_name, treebank, save_name)\n                save_name_args = ['--save_name', save_name]\n            # some run scripts can build the model filename\n            # in order to check for models that are already created\n            elif build_model_filename is None:\n                save_name = \"%s_%s.pt\" % (short_name, model_name)\n                logger.info(\"Save file for %s model: %s\", short_name, save_name)\n                
save_name_args = ['--save_name', save_name]\n            else:\n                save_name_args = []\n\n            if mode == Mode.TRAIN and not command_args.force:\n                if build_model_filename is not None:\n                    model_path = build_model_filename(paths, short_name, command_args, extra_args)\n                elif command_args.save_dir:\n                    model_path = os.path.join(command_args.save_dir, save_name)\n                else:\n                    save_dir = os.path.join(\"saved_models\", model_dir)\n                    save_name_args.extend([\"--save_dir\", save_dir])\n                    model_path = os.path.join(save_dir, save_name)\n\n                if model_path is None:\n                    # this can happen with the identity lemmatizer, for example\n                    pass\n                elif os.path.exists(model_path):\n                    logger.info(\"%s: %s exists, skipping!\" % (treebank, model_path))\n                    continue\n                else:\n                    logger.info(\"%s: %s does not exist, training new model\" % (treebank, model_path))\n\n        run_treebank(mode, paths, treebank, short_name, command_args, extra_args + save_name_args)\n\ndef run_eval_script(gold_conllu_file, system_conllu_file, evals=None):\n    \"\"\" Wrapper for lemma scorer. 
\"\"\"\n    evaluation = ud_scores(gold_conllu_file, system_conllu_file)\n\n    if evals is None:\n        # build_evaluation_table is imported directly at the top of this module\n        return build_evaluation_table(evaluation, verbose=True, counts=False, enhanced=False)\n    else:\n        results = [evaluation[key].f1 for key in evals]\n        max_len = max(5, max(len(e) for e in evals))\n        evals_string = \" \".join((\"{:>%d}\" % max_len).format(e) for e in evals)\n        results_string = \" \".join((\"{:%d.2f}\" % max_len).format(100 * x) for x in results)\n        return evals_string + \"\\n\" + results_string\n\ndef run_eval_script_tokens(eval_gold, eval_pred):\n    return run_eval_script(eval_gold, eval_pred, evals=[\"Tokens\", \"Sentences\", \"Words\"])\n\ndef run_eval_script_mwt(eval_gold, eval_pred):\n    return run_eval_script(eval_gold, eval_pred, evals=[\"Words\"])\n\ndef run_eval_script_pos(eval_gold, eval_pred):\n    return run_eval_script(eval_gold, eval_pred, evals=[\"UPOS\", \"XPOS\", \"UFeats\", \"AllTags\"])\n\ndef run_eval_script_depparse(eval_gold, eval_pred):\n    return run_eval_script(eval_gold, eval_pred, evals=[\"UAS\", \"LAS\", \"CLAS\", \"MLAS\", \"BLEX\"])\n\n\ndef find_wordvec_pretrain(language, default_pretrains, dataset_pretrains=None, dataset=None, model_dir=DEFAULT_MODEL_DIR):\n    # try to get the default pretrain for the language,\n    # but allow the package specific value to override it if that is set\n    default_pt = default_pretrains.get(language, None)\n    if dataset is not None and dataset_pretrains is not None:\n        default_pt = dataset_pretrains.get(language, {}).get(dataset, default_pt)\n\n    if default_pt is not None:\n        default_pt_path = '{}/{}/pretrain/{}.pt'.format(model_dir, language, default_pt)\n        if not os.path.exists(default_pt_path):\n            logger.info(\"Default pretrain should be {}  Attempting to download\".format(default_pt_path))\n            try:\n                download(lang=language, package=None, processors={\"pretrain\": default_pt}, 
model_dir=model_dir)\n            except UnknownLanguageError:\n                # if there's a pretrain in the directory, hiding this\n                # error will let us find that pretrain later\n                pass\n        if os.path.exists(default_pt_path):\n            if dataset is not None and dataset_pretrains is not None and language in dataset_pretrains and dataset in dataset_pretrains[language]:\n                logger.info(f\"Using default pretrain for {language}:{dataset}, found in {default_pt_path}  To use a different pretrain, specify --wordvec_pretrain_file\")\n            else:\n                logger.info(f\"Using default pretrain for language {language}, found in {default_pt_path}  To use a different pretrain, specify --wordvec_pretrain_file\")\n            return default_pt_path\n\n    pretrain_path = '{}/{}/pretrain/*.pt'.format(model_dir, language)\n    pretrains = glob.glob(pretrain_path)\n    if len(pretrains) == 0:\n        # we already tried to download the default pretrain once\n        # and it didn't work.  maybe the default language package\n        # will have something?\n        logger.warning(f\"Cannot figure out which pretrain to use for '{language}'.  Will download the default package and hope for the best\")\n        try:\n            download(lang=language, model_dir=model_dir)\n        except UnknownLanguageError as e:\n            # this is a very unusual situation\n            # basically, there was a language which we started to add\n            # to the resources, but then didn't release the models\n            # as part of resources.json\n            raise FileNotFoundError(f\"Cannot find any pretrains in {pretrain_path}  No pretrains in the system for this language.  
Please prepare an embedding as a .pt and use --wordvec_pretrain_file to specify a .pt file to use\") from e\n        pretrains = glob.glob(pretrain_path)\n    if len(pretrains) == 0:\n        raise FileNotFoundError(f\"Cannot find any pretrains in {pretrain_path}  Try 'stanza.download(\\\"{language}\\\")' to get a default pretrain or use --wordvec_pretrain_file to specify a .pt file to use\")\n    if len(pretrains) > 1:\n        raise FileNotFoundError(f\"Too many pretrains to choose from in {pretrain_path}  Must specify an exact path to a --wordvec_pretrain_file\")\n    pt = pretrains[0]\n    logger.info(f\"Using pretrain found in {pt}  To use a different pretrain, specify --wordvec_pretrain_file\")\n    return pt\n\ndef choose_depparse_pretrain(language, dataset):\n    if language in no_pretrain_languages:\n        return None\n    return find_wordvec_pretrain(language, default_pretrains, depparse_pretrains, dataset)\n\ndef find_charlm_file(direction, language, charlm, model_dir=DEFAULT_MODEL_DIR):\n    \"\"\"\n    Return the path to the forward or backward charlm if it exists for the given package\n\n    If we can figure out the package, but can't find it anywhere, we try to download it\n    \"\"\"\n    saved_path = 'saved_models/charlm/{}_{}_{}_charlm.pt'.format(language, charlm, direction)\n    if os.path.exists(saved_path):\n        logger.info(f'Using model {saved_path} for {direction} charlm')\n        return saved_path\n\n    resource_path = '{}/{}/{}_charlm/{}.pt'.format(model_dir, language, direction, charlm)\n    if os.path.exists(resource_path):\n        logger.info(f'Using model {resource_path} for {direction} charlm')\n        return resource_path\n\n    try:\n        download(lang=language, package=None, processors={f\"{direction}_charlm\": charlm}, model_dir=model_dir)\n        if os.path.exists(resource_path):\n            logger.info(f'Downloaded model, using model {resource_path} for {direction} charlm')\n            return resource_path\n    
except ValueError as e:\n        raise FileNotFoundError(f\"Cannot find {direction} charlm in either {saved_path} or {resource_path}  Attempted downloading {charlm} but that did not work\") from e\n\n    raise FileNotFoundError(f\"Cannot find {direction} charlm in either {saved_path} or {resource_path}  Attempted downloading {charlm} but that did not work\")\n\ndef build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR, use_backward_model=True):\n    \"\"\"\n    If specified, return forward and backward charlm args\n    \"\"\"\n    if charlm:\n        try:\n            forward = find_charlm_file('forward', language, charlm, model_dir=model_dir)\n            if use_backward_model:\n                backward = find_charlm_file('backward', language, charlm, model_dir=model_dir)\n        except FileNotFoundError as e:\n            # if we couldn't find sd_isra when training an SD model,\n            # for example, but isra exists, we try to download the\n            # shorter model name\n            if charlm.startswith(language + \"_\"):\n                short_charlm = charlm[len(language)+1:]\n                try:\n                    forward = find_charlm_file('forward', language, short_charlm, model_dir=model_dir)\n                    if use_backward_model:\n                        backward = find_charlm_file('backward', language, short_charlm, model_dir=model_dir)\n                except FileNotFoundError as e2:\n                    raise FileNotFoundError(\"Tried to find charlm %s, which doesn't exist.  Also tried %s, but didn't find that either\" % (charlm, short_charlm)) from e\n                logger.warning(\"Was asked to find charlm %s, which does not exist.  
Did find %s though\", charlm, short_charlm)\n            else:\n                raise\n\n        char_args = ['--charlm_forward_file', forward]\n        if use_backward_model:\n            char_args += ['--charlm_backward_file', backward]\n        if not base_args:\n            return char_args\n        return ['--charlm',\n                '--charlm_shorthand', f'{language}_{charlm}'] + char_args\n\n    return []\n\ndef choose_charlm(language, dataset, charlm, language_charlms, dataset_charlms):\n    \"\"\"\n    charlm == \"default\" means the default charlm for this dataset or language\n    charlm == None is no charlm\n    \"\"\"\n    default_charlm = language_charlms.get(language, None)\n    specific_charlm = dataset_charlms.get(language, {}).get(dataset, None)\n\n    if charlm is None:\n        return None\n    elif charlm != \"default\":\n        return charlm\n    elif dataset in dataset_charlms.get(language, {}):\n        # this way, a \"\" or None result gets honored\n        # thus treating \"not in the map\" as a way for dataset_charlms to signal to use the default\n        return specific_charlm\n    elif default_charlm:\n        return default_charlm\n    else:\n        return None\n\ndef choose_pos_charlm(short_language, dataset, charlm):\n    \"\"\"\n    charlm == \"default\" means the default charlm for this dataset or language\n    charlm == None is no charlm\n    \"\"\"\n    return choose_charlm(short_language, dataset, charlm, default_charlms, pos_charlms)\n\ndef choose_depparse_charlm(short_language, dataset, charlm):\n    \"\"\"\n    charlm == \"default\" means the default charlm for this dataset or language\n    charlm == None is no charlm\n    \"\"\"\n    return choose_charlm(short_language, dataset, charlm, default_charlms, depparse_charlms)\n\ndef choose_lemma_charlm(short_language, dataset, charlm):\n    \"\"\"\n    charlm == \"default\" means the default charlm for this dataset or language\n    charlm == None is no charlm\n    \"\"\"\n    
return choose_charlm(short_language, dataset, charlm, default_charlms, lemma_charlms)\n\ndef choose_tokenizer_charlm(short_language, dataset, charlm):\n    \"\"\"\n    charlm == \"default\" means the default charlm for this dataset or language\n    charlm == None is no charlm\n    \"\"\"\n    return choose_charlm(short_language, dataset, charlm, default_charlms, tokenizer_charlms)\n\ndef choose_transformer(short_language, command_args, extra_args, warn=True, layers=False):\n    \"\"\"\n    Choose a transformer using the default options for this language\n    \"\"\"\n    bert_args = []\n    if command_args is not None and command_args.use_bert and '--bert_model' not in extra_args:\n        if short_language in TRANSFORMERS:\n            bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]\n            if layers and short_language in TRANSFORMER_LAYERS and '--bert_hidden_layers' not in extra_args:\n                bert_args.extend(['--bert_hidden_layers', str(TRANSFORMER_LAYERS.get(short_language))])\n        elif warn:\n            logger.error(\"Transformer requested, but no default transformer for %s  Specify one using --bert_model\" % short_language)\n\n    return bert_args\n\ndef build_pos_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):\n    charlm = choose_pos_charlm(short_language, dataset, charlm)\n    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)\n    return charlm_args\n\ndef build_lemma_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):\n    charlm = choose_lemma_charlm(short_language, dataset, charlm)\n    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)\n    return charlm_args\n\ndef build_depparse_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):\n    charlm = choose_depparse_charlm(short_language, dataset, charlm)\n    charlm_args = 
build_charlm_args(short_language, charlm, base_args, model_dir)\n    return charlm_args\n\ndef build_tokenizer_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):\n    charlm = choose_tokenizer_charlm(short_language, dataset, charlm)\n    charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir, use_backward_model=False)\n    return charlm_args\n\n\ndef build_wordvec_args(short_language, dataset, extra_args, task_pretrains):\n    if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args:\n        return []\n\n    if short_language in no_pretrain_languages:\n        # we couldn't find word vectors for a few languages...:\n        # coptic, naija, old russian, turkish german, swedish sign language\n        logger.warning(\"No known word vectors for language {}  If those vectors can be found, please update the training scripts.\".format(short_language))\n        return [\"--no_pretrain\"]\n    else:\n        if short_language in task_pretrains and dataset in task_pretrains[short_language]:\n            dataset_pretrains = task_pretrains\n        else:\n            dataset_pretrains = {}\n        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset)\n        return [\"--wordvec_pretrain_file\", wordvec_pretrain]\n\ndef build_pos_wordvec_args(short_language, dataset, extra_args):\n    return build_wordvec_args(short_language, dataset, extra_args, pos_pretrains)\n\ndef build_depparse_wordvec_args(short_language, dataset, extra_args):\n    return build_wordvec_args(short_language, dataset, extra_args, depparse_pretrains)\n\n"
  },
  {
    "path": "stanza/utils/training/compose_ete_results.py",
    "content": "\"\"\"\nTurn the ETE results into markdown\n\nParses blocks like this from the model eval script\n\n2022-01-14 01:23:34 INFO: End to end results for af_afribooms models on af_afribooms test data:\nMetric     | Precision |    Recall |  F1 Score | AligndAcc\n-----------+-----------+-----------+-----------+-----------\nTokens     |     99.93 |     99.92 |     99.93 |\nSentences  |    100.00 |    100.00 |    100.00 |\nWords      |     99.93 |     99.92 |     99.93 |\nUPOS       |     97.97 |     97.96 |     97.97 |     98.04\nXPOS       |     93.98 |     93.97 |     93.97 |     94.04\nUFeats     |     97.23 |     97.22 |     97.22 |     97.29\nAllTags    |     93.89 |     93.88 |     93.88 |     93.95\nLemmas     |     97.40 |     97.39 |     97.39 |     97.46\nUAS        |     87.39 |     87.38 |     87.38 |     87.45\nLAS        |     83.57 |     83.56 |     83.57 |     83.63\nCLAS       |     76.88 |     76.45 |     76.66 |     76.52\nMLAS       |     72.28 |     71.87 |     72.07 |     71.94\nBLEX       |     73.20 |     72.79 |     73.00 |     72.86\n\n\nTurns them into a markdown table.\n\nIncluded is an attempt to mark the default packages with a green check.\n  <i class=\"fas fa-check\" style=\"color:#33a02c\"></i>\n\"\"\"\n\nimport argparse\n\nfrom stanza.models.common.constant import pretty_langcode_to_lang\nfrom stanza.models.common.short_name_to_treebank import short_name_to_treebank\nfrom stanza.utils.training.run_ete import RESULTS_STRING\nfrom stanza.resources.default_packages import default_treebanks\n\nEXPECTED_ORDER = [\"Tokens\", \"Sentences\", \"Words\", \"UPOS\", \"XPOS\", \"UFeats\", \"AllTags\", \"Lemmas\", \"UAS\", \"LAS\", \"CLAS\", \"MLAS\", \"BLEX\"]\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"filenames\", type=str, nargs=\"+\", help=\"Which file(s) to read\")\nargs = parser.parse_args()\n\nlines = []\nfor filename in args.filenames:\n    with open(filename) as fin:\n        
lines.extend(fin.readlines())\n\nblocks = []\nindex = 0\nwhile index < len(lines):\n    line = lines[index]\n    if line.find(RESULTS_STRING) < 0:\n        index = index + 1\n        continue\n\n    line = line[line.find(RESULTS_STRING) + len(RESULTS_STRING):].strip()\n    short_name = line.split()[0]\n\n    # skip the header of the expected output\n    index = index + 1\n    line = lines[index]\n    pieces = line.split(\"|\")\n    assert pieces[0].strip() == 'Metric', \"output format changed?\"\n    assert pieces[3].strip() == 'F1 Score', \"output format changed?\"\n\n    index = index + 1\n    line = lines[index]\n    assert line.startswith(\"-----\"), \"output format changed?\"\n\n    index = index + 1\n\n    block = lines[index:index+13]\n    assert len(block) == 13\n    index = index + 13\n\n    block = [x.split(\"|\") for x in block]\n    assert all(x[0].strip() == y for x, y in zip(block, EXPECTED_ORDER)), \"output format changed?\"\n    lcode, short_dataset = short_name.split(\"_\", 1)\n    language = pretty_langcode_to_lang(lcode)\n    treebank = short_name_to_treebank(short_name)\n    long_dataset = treebank.split(\"-\")[-1]\n\n    checkmark = \"\"\n    if default_treebanks[lcode] == short_dataset:\n        checkmark = '<i class=\"fas fa-check\" style=\"color:#33a02c\"></i>'\n\n    block = [language, \"[%s](%s)\" % (long_dataset, \"https://github.com/UniversalDependencies/%s\" % treebank), lcode, checkmark] + [x[3].strip() for x in block]\n    blocks.append(block)\n\nPREFIX = [\"&#8203;Macro Avg\", \"&#8203;\", \"&#8203;\", \"\"]\n\navg = [sum(float(x[i]) for x in blocks) / len(blocks) for i in range(len(PREFIX), len(EXPECTED_ORDER) + len(PREFIX))]\navg = PREFIX + [\"%.2f\" % x for x in avg]\nblocks = sorted(blocks)\nblocks = [avg] + blocks\n\nchart = [\"|%s|\" % \"  |  \".join(x) for x in blocks]\nfor line in chart:\n    print(line)\n\n"
  },
  {
    "path": "stanza/utils/training/remove_constituency_optimizer.py",
    "content": "\"\"\"Saved a huge, bloated model with an optimizer?  Use this to remove it, greatly shrinking the model size\n\nThis tries to find reasonable defaults for word vectors and charlm\n(which need to be loaded so that the model knows the matrix sizes)\n\nso ideally all that needs to be run is\n\npython3 stanza/utils/training/remove_constituency_optimizer.py <treebanks>\npython3 stanza/utils/training/remove_constituency_optimizer.py da_arboretum ...\n\nThis can also be used to load and save models as part of an update\nto the serialized format\n\"\"\"\n\nimport argparse\nimport logging\nimport os\n\nfrom stanza.models import constituency_parser\nfrom stanza.models.common.constant import treebank_to_short_name\nfrom stanza.resources.default_packages import default_charlms, default_pretrains\nfrom stanza.utils.training import common\n\nlogger = logging.getLogger('stanza')\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')\n    parser.add_argument('--charlm', default=\"default\", type=str, help='Which charlm to run on.  Will use the default charlm for this language/model if not set.  Set to None to turn off charlm for languages with a default charlm')\n    parser.add_argument('--no_charlm', dest='charlm', action=\"store_const\", const=None, help=\"Don't use a charlm, even if one is used by default for this package\")\n\n    parser.add_argument('--load_dir', type=str, default=\"saved_models/constituency\", help=\"Root dir for getting the models to resave.\")\n    parser.add_argument('--save_dir', type=str, default=\"resaved_models/constituency\", help=\"Root dir for resaving the models.\")\n\n    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on.  
Use all_ud or ud_all for all UD treebanks')\n\n    args = parser.parse_args()\n    return args\n\ndef main():\n    \"\"\"\n    For each of the models specified, load and resave the model\n\n    The resaved model will have the optimizer removed\n    \"\"\"\n    args = parse_args()\n    os.makedirs(args.save_dir, exist_ok=True)\n\n    for treebank in args.treebanks:\n        logger.info(\"PROCESSING %s\", treebank)\n        short_name = treebank_to_short_name(treebank)\n        language, dataset = short_name.split(\"_\", maxsplit=1)\n        logger.info(\"%s: %s %s\", short_name, language, dataset)\n\n        if not args.wordvec_pretrain_file:\n            # will throw an error if the pretrain can't be found\n            wordvec_pretrain = common.find_wordvec_pretrain(language, default_pretrains)\n            wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]\n        else:\n            wordvec_args = []\n\n        charlm = common.choose_charlm(language, dataset, args.charlm, default_charlms, {})\n        charlm_args = common.build_charlm_args(language, charlm, base_args=False)\n\n        base_name = '{}_constituency.pt'.format(short_name)\n        load_name = os.path.join(args.load_dir, base_name)\n        save_name = os.path.join(args.save_dir, base_name)\n        resave_args = ['--mode', 'remove_optimizer',\n                       '--load_name', load_name,\n                       '--save_name', save_name,\n                       '--save_dir', \".\",\n                       '--shorthand', short_name]\n        resave_args = resave_args + wordvec_args + charlm_args\n        constituency_parser.main(resave_args)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/training/run_charlm.py",
    "content": "\"\"\"\nTrains or scores a charlm model.\n\"\"\"\n\nimport logging\nimport os\n\nfrom stanza.models import charlm\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode\n\nlogger = logging.getLogger('stanza')\n\n\ndef add_charlm_args(parser):\n    \"\"\"\n    Extra args for the charlm: forward/backward\n    \"\"\"\n    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help=\"Forward or backward language model\")\n    parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help=\"Train a forward language model\")\n    parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help=\"Train a backward language model\")\n\n\ndef run_treebank(mode, paths, treebank, short_name,\n                 command_args, extra_args):\n    short_language, dataset_name = short_name.split(\"_\", 1)\n\n    train_dir = os.path.join(paths[\"CHARLM_DATA_DIR\"], short_language, dataset_name, \"train\")\n\n    dev_file  = os.path.join(paths[\"CHARLM_DATA_DIR\"], short_language, dataset_name, \"dev.txt\")\n    if not os.path.exists(dev_file) and os.path.exists(dev_file + \".xz\"):\n        dev_file = dev_file + \".xz\"\n\n    test_file = os.path.join(paths[\"CHARLM_DATA_DIR\"], short_language, dataset_name, \"test.txt\")\n    if not os.path.exists(test_file) and os.path.exists(test_file + \".xz\"):\n        test_file = test_file + \".xz\"\n\n    # python -m stanza.models.charlm --train_dir $train_dir --eval_file $dev_file \\\n    #     --direction $direction --shorthand $short --mode train $args\n    # python -m stanza.models.charlm --eval_file $dev_file \\\n    #     --direction $direction --shorthand $short --mode predict $args\n    # python -m stanza.models.charlm --eval_file $test_file \\\n    #     --direction $direction --shorthand $short --mode predict $args\n\n    direction = command_args.direction\n    default_args = ['--%s' % 
direction,\n                    '--shorthand', short_name]\n    if mode == Mode.TRAIN:\n        train_args = ['--mode', 'train']\n        if '--train_dir' not in extra_args:\n            train_args += ['--train_dir', train_dir]\n        if '--eval_file' not in extra_args:\n            train_args += ['--eval_file', dev_file]\n        train_args = train_args + default_args + extra_args\n        logger.info(\"Running train step with args: %s\", train_args)\n        charlm.main(train_args)\n\n    if mode == Mode.SCORE_DEV:\n        dev_args = ['--mode', 'predict']\n        if '--eval_file' not in extra_args:\n            dev_args += ['--eval_file', dev_file]\n        dev_args = dev_args + default_args + extra_args\n        logger.info(\"Running dev step with args: %s\", dev_args)\n        charlm.main(dev_args)\n\n    if mode == Mode.SCORE_TEST:\n        test_args = ['--mode', 'predict']\n        if '--eval_file' not in extra_args:\n            test_args += ['--eval_file', test_file]\n        test_args = test_args + default_args + extra_args\n        logger.info(\"Running test step with args: %s\", test_args)\n        charlm.main(test_args)\n\n\ndef get_model_name(args):\n    \"\"\"\n    The charlm saves forward and backward charlms to the same dir, but with different filenames\n    \"\"\"\n    return \"%s_charlm\" % args.direction\n\ndef main():\n    common.main(run_treebank, \"charlm\", get_model_name, add_charlm_args, charlm.build_argparse())\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_constituency.py",
    "content": "\"\"\"\nTrains or scores a constituency model.\n\nCurrently a suuuuper preliminary script.\n\nExample of how to run on multiple parsers at the same time on the Stanford workqueue:\n\nfor i in `echo 1000 1001 1002 1003 1004`; do nlprun -d a6000 \"python3 stanza/utils/training/run_constituency.py vi_vlsp23 --use_bert --stage1_bert_finetun --save_name vi_vlsp23_$i.pt --seed $i --epochs 200 --force\" -o vi_vlsp23_$i.out; done\n\n\"\"\"\n\nimport logging\nimport os\n\nfrom stanza.models import constituency_parser\nfrom stanza.models.constituency.retagging import RETAG_METHOD\nfrom stanza.utils.datasets.constituency import prepare_con_dataset\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain\n\nfrom stanza.resources.default_packages import default_charlms, default_pretrains\n\nlogger = logging.getLogger('stanza')\n\ndef add_constituency_args(parser):\n    add_charlm_args(parser)\n\n    parser.add_argument('--use_bert', default=False, action=\"store_true\", help='Use the default transformer for this language')\n\n    parser.add_argument('--parse_text', dest='mode', action='store_const', const=\"parse_text\", help='Parse a text file')\n\ndef build_wordvec_args(short_language, dataset, extra_args):\n    if '--wordvec_pretrain_file' not in extra_args:\n        # will throw an error if the pretrain can't be found\n        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains)\n        wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]\n    else:\n        wordvec_args = []\n\n    return wordvec_args\n\ndef build_default_args(paths, short_language, dataset, command_args, extra_args):\n    if short_language in RETAG_METHOD:\n        retag_args = [\"--retag_method\", RETAG_METHOD[short_language]]\n    else:\n        retag_args = []\n\n    wordvec_args = build_wordvec_args(short_language, dataset, extra_args)\n\n    charlm 
= choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {})\n    charlm_args = build_charlm_args(short_language, charlm, base_args=False)\n\n    bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=True, layers=True)\n    default_args = retag_args + wordvec_args + charlm_args + bert_args\n\n    return default_args\n\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)\n\n    train_args = [\"--shorthand\", short_name,\n                  \"--mode\", \"train\"]\n    train_args = train_args + default_args\n    if command_args.save_name is not None:\n        train_args.extend([\"--save_name\", command_args.save_name])\n    if command_args.save_dir is not None:\n        train_args.extend([\"--save_dir\", command_args.save_dir])\n    args = constituency_parser.parse_args(train_args)\n    save_name = constituency_parser.build_model_filename(args)\n    return save_name\n\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    constituency_dir = paths[\"CONSTITUENCY_DATA_DIR\"]\n    # split only on the first underscore, since dataset names may contain underscores\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    train_file = os.path.join(constituency_dir, f\"{short_name}_train.mrg\")\n    dev_file   = os.path.join(constituency_dir, f\"{short_name}_dev.mrg\")\n    test_file  = os.path.join(constituency_dir, f\"{short_name}_test.mrg\")\n\n    if not os.path.exists(train_file) or not os.path.exists(dev_file) or not os.path.exists(test_file):\n        logger.warning(f\"The data for {short_name} is missing or incomplete.  Attempting to rebuild...\")\n        try:\n            prepare_con_dataset.main(short_name)\n        except:\n            logger.error(f\"Unable to build the data.  
Please correctly build the files in {train_file}, {dev_file}, {test_file} and then try again.\")\n            raise\n\n    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)\n\n    if mode == Mode.TRAIN:\n        train_args = ['--train_file', train_file,\n                      '--eval_file', dev_file,\n                      '--shorthand', short_name,\n                      '--mode', 'train']\n        train_args = train_args + default_args + extra_args\n        logger.info(\"Running train step with args: {}\".format(train_args))\n        constituency_parser.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = ['--eval_file', dev_file,\n                    '--shorthand', short_name,\n                    '--mode', 'predict']\n        dev_args = dev_args + default_args + extra_args\n        logger.info(\"Running dev step with args: {}\".format(dev_args))\n        constituency_parser.main(dev_args)\n\n    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:\n        test_args = ['--eval_file', test_file,\n                     '--shorthand', short_name,\n                     '--mode', 'predict']\n        test_args = test_args + default_args + extra_args\n        logger.info(\"Running test step with args: {}\".format(test_args))\n        constituency_parser.main(test_args)\n\n    if mode == \"parse_text\":\n        text_args = ['--shorthand', short_name,\n                     '--mode', 'parse_text']\n        text_args = text_args + default_args + extra_args\n        logger.info(\"Processing text with args: {}\".format(text_args))\n        constituency_parser.main(text_args)\n\ndef main():\n    common.main(run_treebank, \"constituency\", \"constituency\", add_constituency_args, sub_argparse=constituency_parser.build_argparse(), build_model_filename=build_model_filename)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_depparse.py",
    "content": "import io\nimport logging\nimport os\n\nfrom stanza.models import parser\n\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer\nfrom stanza.utils.training.common import build_depparse_wordvec_args\n\nfrom stanza.resources.default_packages import default_charlms, depparse_charlms\n\nlogger = logging.getLogger('stanza')\n\ndef add_depparse_args(parser):\n    add_charlm_args(parser)\n\n    parser.add_argument('--use_bert', default=False, action=\"store_true\", help='Use the default transformer for this language')\n\n#  TODO: refactor with run_pos\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    # TODO: can avoid downloading the charlm at this point, since we\n    # might not even be training\n    charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)\n\n    bert_args = choose_transformer(short_language, command_args, extra_args, warn=False)\n\n    train_args = [\"--shorthand\", short_name,\n                  \"--mode\", \"train\"]\n    # TODO: also, this downloads the wordvec, which we might not want to do yet\n    train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args\n    if command_args.save_name is not None:\n        train_args.extend([\"--save_name\", command_args.save_name])\n    if command_args.save_dir is not None:\n        train_args.extend([\"--save_dir\", command_args.save_dir])\n    args = parser.parse_args(train_args)\n    save_name = parser.model_file_name(args)\n    return save_name\n\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    # TODO: refactor these blocks?\n    depparse_dir   = paths[\"DEPPARSE_DATA_DIR\"]\n    
train_file     = f\"{depparse_dir}/{short_name}.train.in.conllu\"\n    dev_in_file    = f\"{depparse_dir}/{short_name}.dev.in.conllu\"\n    dev_pred_file  = f\"{depparse_dir}/{short_name}.dev.pred.conllu\"\n    test_in_file   = f\"{depparse_dir}/{short_name}.test.in.conllu\"\n    test_pred_file = f\"{depparse_dir}/{short_name}.test.pred.conllu\"\n\n    eval_file = None\n    if '--eval_file' in extra_args:\n        eval_file = extra_args[extra_args.index('--eval_file') + 1]\n\n    charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)\n\n    bert_args = choose_transformer(short_language, command_args, extra_args)\n\n    if mode == Mode.TRAIN:\n        zip_train_file = os.path.splitext(train_file)[0] + \".zip\"\n        if os.path.exists(train_file) and os.path.exists(zip_train_file):\n            logger.error(\"DEPPARSE TRAIN FILE %s and %s both exist... this is very confusing, skipping %s\" % (train_file, zip_train_file, short_name))\n            return\n        if os.path.exists(zip_train_file):\n            train_file = zip_train_file\n        if not os.path.exists(train_file):\n            logger.error(\"TRAIN FILE NOT FOUND: %s ... 
skipping\" % train_file)\n            return\n\n        # some languages need reduced batch size\n        if short_name == 'de_hdt':\n            # 'UD_German-HDT'\n            batch_size = \"1300\"\n        elif short_name in ('hr_set', 'fi_tdt', 'ru_taiga', 'cs_cltt', 'gl_treegal', 'lv_lvtb', 'ro_simonero'):\n            # 'UD_Croatian-SET', 'UD_Finnish-TDT', 'UD_Russian-Taiga',\n            # 'UD_Czech-CLTT', 'UD_Galician-TreeGal', 'UD_Latvian-LVTB', 'UD_Romanian-SiMoNERo'\n            batch_size = \"3000\"\n        else:\n            batch_size = \"5000\"\n\n        train_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                      \"--train_file\", train_file,\n                      \"--eval_file\", eval_file if eval_file else dev_in_file,\n                      \"--batch_size\", batch_size,\n                      \"--lang\", short_language,\n                      \"--shorthand\", short_name,\n                      \"--mode\", \"train\"]\n        train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args\n        train_args = train_args + extra_args\n        logger.info(\"Running train depparse for {} with args {}\".format(treebank, train_args))\n        parser.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                    \"--eval_file\", eval_file if eval_file else dev_in_file,\n                    \"--lang\", short_language,\n                    \"--shorthand\", short_name,\n                    \"--mode\", \"predict\"]\n        if command_args.save_output:\n            dev_args.extend([\"--output_file\", dev_pred_file])\n        dev_args = dev_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args\n        dev_args = dev_args + extra_args\n        logger.info(\"Running dev depparse for {} with args {}\".format(treebank, dev_args))\n        _, dev_doc = 
parser.main(dev_args)\n\n        if '--no_gold_labels' not in extra_args:\n            if not command_args.save_output:\n                dev_pred_file = \"{:C}\\n\\n\".format(dev_doc)\n                dev_pred_file = io.StringIO(dev_pred_file)\n            results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file)\n            logger.info(\"Finished running dev set on\\n{}\\n{}\".format(treebank, results))\n        if command_args.save_output:\n            logger.info(\"Output saved to %s\", dev_pred_file)\n\n    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:\n        test_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                     \"--eval_file\", eval_file if eval_file else test_in_file,\n                     \"--lang\", short_language,\n                     \"--shorthand\", short_name,\n                     \"--mode\", \"predict\"]\n        if command_args.save_output:\n            test_args.extend([\"--output_file\", test_pred_file])\n        test_args = test_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args\n        test_args = test_args + extra_args\n        logger.info(\"Running test depparse for {} with args {}\".format(treebank, test_args))\n        _, test_doc = parser.main(test_args)\n\n        if '--no_gold_labels' not in extra_args:\n            if not command_args.save_output:\n                test_pred_file = \"{:C}\\n\\n\".format(test_doc)\n                test_pred_file = io.StringIO(test_pred_file)\n            results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file)\n            logger.info(\"Finished running test set on\\n{}\\n{}\".format(treebank, results))\n        if command_args.save_output:\n            logger.info(\"Output saved to %s\", test_pred_file)\n\n\ndef main():\n    common.main(run_treebank, \"depparse\", \"parser\", add_depparse_args, sub_argparse=parser.build_argparse(), 
build_model_filename=build_model_filename, choose_charlm_method=choose_depparse_charlm)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_ete.py",
    "content": "\"\"\"\nRuns a pipeline end-to-end, reports conll scores.\n\nFor example, you can do\n  python3 stanza/utils/training/run_ete.py it_isdt --score_test\nYou can run on all models at once:\n  python3 stanza/utils/training/run_ete.py ud_all --score_test\n\nYou can also run one model on a different model's data:\n  python3 stanza/utils/training/run_ete.py it_isdt --score_dev --test_data it_vit\n  python3 stanza/utils/training/run_ete.py it_isdt --score_test --test_data it_vit\n\nRunning multiple models with a --test_data flag will run them all on the same data:\n  python3 stanza/utils/training/run_ete.py it_combined it_isdt it_vit --score_test --test_data it_vit\n\nIf run with no dataset arguments, then the dataset used is the train\ndata, which may or may not be useful.\n\"\"\"\n\nimport logging\nimport os\nimport tempfile\n\nfrom stanza.models import identity_lemmatizer\nfrom stanza.models import lemmatizer\nfrom stanza.models import mwt_expander\nfrom stanza.models import parser\nfrom stanza.models import tagger\nfrom stanza.models import tokenizer\n\nfrom stanza.models.common.constant import treebank_to_short_name\n\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, build_tokenizer_charlm_args, build_pos_charlm_args, build_lemma_charlm_args, build_depparse_charlm_args, build_pos_wordvec_args, build_depparse_wordvec_args\nfrom stanza.utils.training.run_lemma import check_lemmas\nfrom stanza.utils.training.run_mwt import check_mwt\n\nlogger = logging.getLogger('stanza')\n\n# a constant so that the script which looks for these results knows what to look for\nRESULTS_STRING = \"End to end results for\"\n\ndef add_args(parser):\n    parser.add_argument('--test_data', default=None, type=str, help='Which data to test on, if not using the default data for this model')\n    common.add_charlm_args(parser)\n\ndef run_ete(paths, dataset, short_name, command_args, extra_args):\n    short_language, package = 
short_name.split(\"_\", 1)\n\n    tokenize_dir = paths[\"TOKENIZE_DATA_DIR\"]\n    mwt_dir      = paths[\"MWT_DATA_DIR\"]\n    lemma_dir    = paths[\"LEMMA_DATA_DIR\"]\n    ete_dir      = paths[\"ETE_DATA_DIR\"]\n    wordvec_dir  = paths[\"WORDVEC_DIR\"]\n\n    # run models in the following order:\n    #   tokenize\n    #   mwt, if exists\n    #   pos\n    #   lemma, if exists\n    #   depparse\n    # the output of each step is either kept or discarded based on the\n    # value of command_args.save_output\n\n    if command_args and command_args.test_data:\n        test_short_name = treebank_to_short_name(command_args.test_data)\n    else:\n        test_short_name = short_name\n\n    # TOKENIZE step\n    # the raw data to process starts in tokenize_dir\n    # retokenize it using the saved model\n    tokenizer_type = \"--txt_file\"\n    tokenizer_file = f\"{tokenize_dir}/{test_short_name}.{dataset}.txt\"\n\n    tokenizer_output = f\"{ete_dir}/{short_name}.{dataset}.tokenizer.conllu\"\n\n    tokenizer_args = [\"--mode\", \"predict\",\n                      tokenizer_type, tokenizer_file,\n                      \"--lang\", short_language,\n                      \"--conll_file\", tokenizer_output,\n                      \"--shorthand\", short_name]\n    tokenizer_charlm_args = build_tokenizer_charlm_args(short_language, package, command_args.charlm)\n    tokenizer_args = tokenizer_args + tokenizer_charlm_args + extra_args\n    logger.info(\"-----  TOKENIZER  ----------\")\n    logger.info(\"Running tokenizer step with args: {}\".format(tokenizer_args))\n    tokenizer.main(tokenizer_args)\n\n    # If the data has any MWT in it, there should be an MWT model\n    # trained, so run that.  
Otherwise, we skip MWT\n    mwt_train_file = f\"{mwt_dir}/{short_name}.train.in.conllu\"\n    logger.info(\"-----  MWT        ----------\")\n    if check_mwt(mwt_train_file):\n        mwt_output = f\"{ete_dir}/{short_name}.{dataset}.mwt.conllu\"\n        mwt_args = ['--eval_file', tokenizer_output,\n                    '--output_file', mwt_output,\n                    '--lang', short_language,\n                    '--shorthand', short_name,\n                    '--mode', 'predict']\n        mwt_args = mwt_args + extra_args\n        logger.info(\"Running mwt step with args: {}\".format(mwt_args))\n        mwt_expander.main(mwt_args)\n    else:\n        logger.info(\"No MWT in training data.  Skipping\")\n        mwt_output = tokenizer_output\n\n    # Run the POS step\n    # TODO: add batch args\n    # TODO: add transformer args\n    logger.info(\"-----  POS        ----------\")\n    pos_output = f\"{ete_dir}/{short_name}.{dataset}.pos.conllu\"\n    pos_args = ['--wordvec_dir', wordvec_dir,\n                '--eval_file', mwt_output,\n                '--output_file', pos_output,\n                '--lang', short_language,\n                '--shorthand', short_name,\n                '--mode', 'predict',\n                # the MWT is not preserving the tags,\n                # so we don't ask the tagger to report a score\n                # the ETE will score the whole thing at the end\n                '--no_gold_labels']\n\n    pos_charlm_args = build_pos_charlm_args(short_language, package, command_args.charlm)\n\n    pos_args = pos_args + build_pos_wordvec_args(short_language, package, extra_args) + pos_charlm_args + extra_args\n    logger.info(\"Running pos step with args: {}\".format(pos_args))\n    tagger.main(pos_args)\n\n    # Run the LEMMA step.  
If there are no lemmas in the training\n    # data, use the identity lemmatizer.\n    logger.info(\"-----  LEMMA      ----------\")\n    lemma_train_file = f\"{lemma_dir}/{short_name}.train.in.conllu\"\n    lemma_output = f\"{ete_dir}/{short_name}.{dataset}.lemma.conllu\"\n    lemma_args = ['--eval_file', pos_output,\n                  '--output_file', lemma_output,\n                  '--shorthand', short_name,\n                  '--mode', 'predict']\n    if check_lemmas(lemma_train_file):\n        lemma_charlm_args = build_lemma_charlm_args(short_language, package, command_args.charlm)\n        lemma_args = lemma_args + lemma_charlm_args + extra_args\n        logger.info(\"Running lemmatizer step with args: {}\".format(lemma_args))\n        lemmatizer.main(lemma_args)\n    else:\n        lemma_args = lemma_args + extra_args\n        logger.info(\"No lemmas in training data\")\n        logger.info(\"Running identity lemmatizer step with args: {}\".format(lemma_args))\n        identity_lemmatizer.main(lemma_args)\n\n    # Run the DEPPARSE step.  This is the last step\n    # Note that we do NOT use the depparse directory's data.  
That is\n    # because it has either gold tags, or predicted tags based on\n    # retagging using gold tokenization, and we aren't sure which at\n    # this point in the process.\n    # TODO: add batch args\n    logger.info(\"-----  DEPPARSE   ----------\")\n    depparse_output = f\"{ete_dir}/{short_name}.{dataset}.depparse.conllu\"\n    depparse_args = ['--wordvec_dir', wordvec_dir,\n                     '--eval_file', lemma_output,\n                     '--output_file', depparse_output,\n                     '--lang', short_language,\n                     '--shorthand', short_name,\n                     '--mode', 'predict',\n                     # we don't ask the parser to report a score either\n                     '--no_gold_labels']\n    depparse_charlm_args = build_depparse_charlm_args(short_language, package, command_args.charlm)\n    depparse_args = depparse_args + build_depparse_wordvec_args(short_language, package, extra_args) + depparse_charlm_args + extra_args\n    logger.info(\"Running depparse step with args: {}\".format(depparse_args))\n    parser.main(depparse_args)\n\n    logger.info(\"-----  EVALUATION ----------\")\n    gold_file = f\"{tokenize_dir}/{test_short_name}.{dataset}.gold.conllu\"\n    ete_file = depparse_output\n    results = common.run_eval_script(gold_file, ete_file)\n    logger.info(\"{} {} models on {} {} data:\\n{}\".format(RESULTS_STRING, short_name, test_short_name, dataset, results))\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    if mode == Mode.TRAIN:\n        dataset = 'train'\n    elif mode == Mode.SCORE_DEV:\n        dataset = 'dev'\n    elif mode == Mode.SCORE_TEST:\n        dataset = 'test'\n\n    if not command_args.save_output:\n        with tempfile.TemporaryDirectory() as ete_dir:\n            paths = dict(paths)\n            paths[\"ETE_DATA_DIR\"] = ete_dir\n            run_ete(paths, dataset, short_name, command_args, extra_args)\n    else:\n        
os.makedirs(paths[\"ETE_DATA_DIR\"], exist_ok=True)\n        run_ete(paths, dataset, short_name, command_args, extra_args)\n\ndef main():\n    common.main(run_treebank, \"ete\", \"ete\", add_args)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_lemma.py",
    "content": "\"\"\"\nThis script allows for training or testing on dev / test of the UD lemmatizer.\n\nIf run with a single treebank name, it will train or test that treebank.\nIf run with ud_all or all_ud, it will iterate over all UD treebanks it can find.\n\nMode can be set to train&dev with --train, to dev set only\nwith --score_dev, and to test set only with --score_test.\n\nTreebanks are specified as a list.  all_ud or ud_all means to look for\nall UD treebanks.\n\nExtra arguments are passed to the lemmatizer.  In case the run script\nitself is shadowing arguments, you can specify --extra_args as a\nparameter to mark where the lemmatizer arguments start.\n\"\"\"\n\nimport logging\nimport os\n\nfrom stanza.models import identity_lemmatizer\nfrom stanza.models import lemmatizer\n\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm\n\nfrom stanza.utils.datasets.prepare_lemma_treebank import check_lemmas\nimport stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier\n\nlogger = logging.getLogger('stanza')\n\ndef add_lemma_args(parser):\n    add_charlm_args(parser)\n\n    parser.add_argument('--lemma_classifier', dest='lemma_classifier', action='store_true', default=None,\n                        help=\"Use the lemma classifier datasets.  Default is to build lemma classifier as part of the original lemmatizer if the charlm is used\")\n    parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false',\n                        help=\"Don't use the lemma classifier datasets.  
Default is to build lemma classifier as part of the original lemmatizer if the charlm is used\")\n\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    \"\"\"\n    Figure out what the model savename will be, taking into account the model settings.\n\n    Useful for figuring out if the model already exists\n\n    None will represent that there is no expected save_name\n    \"\"\"\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    lemma_dir      = paths[\"LEMMA_DATA_DIR\"]\n    train_file     = f\"{lemma_dir}/{short_name}.train.in.conllu\"\n\n    if not os.path.exists(train_file):\n        logger.debug(\"Treebank %s is not prepared for training the lemmatizer.  Could not find any training data at %s  Cannot figure out the expected save_name without looking at the data, but a later step in the process will skip the training anyway\" % (short_name, train_file))\n        return None\n\n    has_lemmas = check_lemmas(train_file)\n    if not has_lemmas:\n        return None\n\n    # TODO: can avoid downloading the charlm at this point, since we\n    # might not even be training\n    charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)\n\n    train_args = [\"--train_file\", train_file,\n                  \"--shorthand\", short_name,\n                  \"--mode\", \"train\"]\n    train_args = train_args + charlm_args + extra_args\n    args = lemmatizer.parse_args(train_args)\n    save_name = lemmatizer.build_model_filename(args)\n    return save_name\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    lemma_dir      = paths[\"LEMMA_DATA_DIR\"]\n    train_file     = f\"{lemma_dir}/{short_name}.train.in.conllu\"\n    dev_in_file    = f\"{lemma_dir}/{short_name}.dev.in.conllu\"\n    dev_pred_file  = f\"{lemma_dir}/{short_name}.dev.pred.conllu\"\n    test_in_file   = f\"{lemma_dir}/{short_name}.test.in.conllu\"\n    
test_pred_file = f\"{lemma_dir}/{short_name}.test.pred.conllu\"\n\n    charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)\n\n    if not os.path.exists(train_file):\n        logger.error(\"Treebank %s is not prepared for training the lemmatizer.  Could not find any training data at %s  Skipping...\" % (treebank, train_file))\n        return\n\n    has_lemmas = check_lemmas(train_file)\n    if not has_lemmas:\n        logger.info(\"Treebank \" + treebank + \" (\" + short_name +\n                    \") has no lemmas.  Using identity lemmatizer\")\n        if mode == Mode.TRAIN or mode == Mode.SCORE_DEV:\n            train_args = [\"--train_file\", train_file,\n                          \"--eval_file\", dev_in_file,\n                          \"--gold_file\", dev_in_file,\n                          \"--shorthand\", short_name]\n            if command_args.save_output:\n                train_args.extend([\"--output_file\", dev_pred_file])\n            logger.info(\"Running identity lemmatizer for {} with args {}\".format(treebank, train_args))\n            identity_lemmatizer.main(train_args)\n        elif mode == Mode.SCORE_TEST:\n            train_args = [\"--train_file\", train_file,\n                          \"--eval_file\", test_in_file,\n                          \"--gold_file\", test_in_file,\n                          \"--shorthand\", short_name]\n            if command_args.save_output:\n                train_args.extend([\"--output_file\", test_pred_file])\n            logger.info(\"Running identity lemmatizer for {} with args {}\".format(treebank, train_args))\n            identity_lemmatizer.main(train_args)            \n    else:\n        if mode == Mode.TRAIN:\n            # ('UD_Czech-PDT', 'UD_Russian-SynTagRus', 'UD_German-HDT')\n            if short_name in ('cs_pdt', 'ru_syntagrus', 'de_hdt'):\n                num_epochs = \"30\"\n            else:\n                num_epochs = \"60\"\n\n            train_args = 
[\"--train_file\", train_file,\n                          \"--eval_file\", dev_in_file,\n                          \"--shorthand\", short_name,\n                          \"--num_epoch\", num_epochs,\n                          \"--mode\", \"train\"]\n            train_args = train_args + charlm_args + extra_args\n            logger.info(\"Running train lemmatizer for {} with args {}\".format(treebank, train_args))\n            lemmatizer.main(train_args)\n\n        if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n            dev_args = [\"--eval_file\", dev_in_file,\n                        \"--shorthand\", short_name,\n                        \"--mode\", \"predict\"]\n            if command_args.save_output:\n                dev_args.extend([\"--output_file\", dev_pred_file])\n            dev_args = dev_args + charlm_args + extra_args\n            logger.info(\"Running dev lemmatizer for {} with args {}\".format(treebank, dev_args))\n            lemmatizer.main(dev_args)\n\n        if mode == Mode.SCORE_TEST:\n            test_args = [\"--eval_file\", test_in_file,\n                         \"--shorthand\", short_name,\n                         \"--mode\", \"predict\"]\n            if command_args.save_output:\n                test_args.extend([\"--output_file\", test_pred_file])\n            test_args = test_args + charlm_args + extra_args\n            logger.info(\"Running test lemmatizer for {} with args {}\".format(treebank, test_args))\n            lemmatizer.main(test_args)\n\n        use_lemma_classifier = command_args.lemma_classifier\n        if use_lemma_classifier is None:\n            use_lemma_classifier = command_args.charlm is not None\n        use_lemma_classifier = use_lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING\n        if use_lemma_classifier and mode == Mode.TRAIN:\n            # some installations may not have transformers,\n            # so we bury the lemma_classifier import in the codepath\n            # 
which actually needs it\n            from stanza.models.lemma import attach_lemma_classifier\n            from stanza.utils.training import run_lemma_classifier\n\n            lc_charlm_args = ['--no_charlm'] if command_args.charlm is None else ['--charlm', command_args.charlm]\n            lemma_classifier_args = [treebank] + lc_charlm_args\n            if command_args.force:\n                lemma_classifier_args.append('--force')\n            run_lemma_classifier.main(lemma_classifier_args)\n\n            save_name = build_model_filename(paths, short_name, command_args, extra_args)\n            # TODO: use a temp path for the lemma_classifier or keep it somewhere\n            attach_args = ['--input', save_name,\n                           '--output', save_name,\n                           '--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]\n            attach_lemma_classifier.main(attach_args)\n\n            # now we rerun the dev set - the HI in particular demonstrates some good improvement\n            lemmatizer.main(dev_args)\n\ndef main():\n    common.main(run_treebank, \"lemma\", \"lemmatizer\", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_lemma_classifier.py",
    "content": "import os\n\nfrom stanza.models.lemma_classifier import evaluate_models\nfrom stanza.models.lemma_classifier import train_lstm_model\nfrom stanza.models.lemma_classifier import train_transformer_model\nfrom stanza.models.lemma_classifier.constants import ModelType\n\nfrom stanza.resources.default_packages import default_pretrains, TRANSFORMERS\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm, find_wordvec_pretrain\n\ndef add_lemma_args(parser):\n    add_charlm_args(parser)\n\n    parser.add_argument('--model_type', default=ModelType.LSTM, type=lambda x: ModelType[x.upper()],\n                        help='Model type to use.  {}'.format(\", \".join(x.name for x in ModelType)))\n\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    return os.path.join(\"saved_models\", \"lemma_classifier\", short_name + \"_lemma_classifier.pt\")\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    base_args = []\n    if '--save_name' not in extra_args:\n        base_args += ['--save_name', build_model_filename(paths, short_name, command_args, extra_args)]\n\n    embedding_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)\n    if '--wordvec_pretrain_file' not in extra_args:\n        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, {}, dataset)\n        embedding_args += [\"--wordvec_pretrain_file\", wordvec_pretrain]\n\n    bert_args = []\n    if command_args.model_type is ModelType.TRANSFORMER:\n        if '--bert_model' not in extra_args:\n            if short_language in TRANSFORMERS:\n                bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]\n            else:\n                raise ValueError(\"--bert_model not specified, so cannot figure out which transformer to use for 
language %s\" % short_language)\n\n    extra_train_args = []\n    if command_args.force:\n        extra_train_args.append('--force')\n\n    if mode == Mode.TRAIN:\n        train_args = []\n        if \"--train_file\" not in extra_args:\n            train_file = os.path.join(\"data\", \"lemma_classifier\", \"%s.train.lemma\" % short_name)\n            train_args += ['--train_file', train_file]\n        if \"--eval_file\" not in extra_args:\n            eval_file = os.path.join(\"data\", \"lemma_classifier\", \"%s.dev.lemma\" % short_name)\n            train_args += ['--eval_file', eval_file]\n        train_args = base_args + train_args + extra_args + extra_train_args\n\n        if command_args.model_type == ModelType.LSTM:\n            train_args = embedding_args + train_args\n            train_lstm_model.main(train_args)\n        else:\n            model_type_args = [\"--model_type\", command_args.model_type.name.lower()]\n            train_args = bert_args + model_type_args + train_args\n            train_transformer_model.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        eval_args = []\n        if \"--eval_file\" not in extra_args:\n            eval_file = os.path.join(\"data\", \"lemma_classifier\", \"%s.dev.lemma\" % short_name)\n            eval_args += ['--eval_file', eval_file]\n        model_type_args = [\"--model_type\", command_args.model_type.name.lower()]\n        eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args\n        evaluate_models.main(eval_args)\n\n    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:\n        eval_args = []\n        if \"--eval_file\" not in extra_args:\n            eval_file = os.path.join(\"data\", \"lemma_classifier\", \"%s.test.lemma\" % short_name)\n            eval_args += ['--eval_file', eval_file]\n        model_type_args = [\"--model_type\", command_args.model_type.name.lower()]\n        eval_args = bert_args + model_type_args + base_args + 
eval_args + embedding_args + extra_args\n        evaluate_models.main(eval_args)\n\ndef main(args=None):\n    common.main(run_treebank, \"lemma_classifier\", \"lemma_classifier\", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/training/run_mwt.py",
    "content": "\"\"\"\nThis script allows for training or testing on dev / test of the UD mwt tools.\n\nIf run with a single treebank name, it will train or test that treebank.\nIf run with ud_all or all_ud, it will iterate over all UD treebanks it can find.\n\nMode can be set to train&dev with --train, to dev set only\nwith --score_dev, and to test set only with --score_test.\n\nTreebanks are specified as a list.  all_ud or ud_all means to look for\nall UD treebanks.\n\nExtra arguments are passed to mwt.  In case the run script\nitself is shadowing arguments, you can specify --extra_args as a\nparameter to mark where the mwt arguments start.\n\"\"\"\n\n\nimport io\nimport logging\nimport math\n\nfrom stanza.models import mwt_expander\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.conll import CoNLL\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode\n\nfrom stanza.utils.max_mwt_length import max_mwt_length\n\nlogger = logging.getLogger('stanza')\n\ndef check_mwt(filename):\n    \"\"\"\n    Checks whether or not there are MWTs in the given conll file\n    \"\"\"\n    doc = CoNLL.conll2doc(filename)\n    data = doc.get_mwt_expansions(False)\n    return len(data) > 0\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    short_language = short_name.split(\"_\")[0]\n\n    mwt_dir          = paths[\"MWT_DATA_DIR\"]\n\n    train_file       = f\"{mwt_dir}/{short_name}.train.in.conllu\"\n    dev_in_file      = f\"{mwt_dir}/{short_name}.dev.in.conllu\"\n    dev_gold_file    = f\"{mwt_dir}/{short_name}.dev.gold.conllu\"\n    dev_output_file  = f\"{mwt_dir}/{short_name}.dev.pred.conllu\"\n    test_in_file     = f\"{mwt_dir}/{short_name}.test.in.conllu\"\n    test_gold_file   = f\"{mwt_dir}/{short_name}.test.gold.conllu\"\n    test_output_file = f\"{mwt_dir}/{short_name}.test.pred.conllu\"\n\n    train_json       = f\"{mwt_dir}/{short_name}-ud-train-mwt.json\"\n    dev_json         = 
f\"{mwt_dir}/{short_name}-ud-dev-mwt.json\"\n    test_json        = f\"{mwt_dir}/{short_name}-ud-test-mwt.json\"\n\n    eval_file = None\n    if '--eval_file' in extra_args:\n        eval_file = extra_args[extra_args.index('--eval_file') + 1]\n\n    gold_file = None\n    if '--gold_file' in extra_args:\n        gold_file = extra_args[extra_args.index('--gold_file') + 1]\n\n    if not check_mwt(train_file):\n        logger.info(\"No training MWTS found for %s.  Skipping\" % treebank)\n        return\n    \n    if not check_mwt(dev_in_file) and mode == Mode.TRAIN:\n        logger.info(\"No dev MWTS found for %s.  Training only the deterministic MWT expander\" % treebank)\n        extra_args.append('--dict_only')\n\n    if mode == Mode.TRAIN:\n        max_mwt_len = math.ceil(max_mwt_length([train_json, dev_json]) * 1.1 + 1)\n        logger.info(\"Max len: %f\" % max_mwt_len)\n        train_args = ['--train_file', train_file,\n                      '--eval_file', eval_file if eval_file else dev_in_file,\n                      '--gold_file', gold_file if gold_file else dev_gold_file,\n                      '--lang', short_language,\n                      '--shorthand', short_name,\n                      '--mode', 'train',\n                      '--max_dec_len', str(max_mwt_len)]\n        train_args = train_args + extra_args\n        logger.info(\"Running train step with args: {}\".format(train_args))\n        mwt_expander.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = ['--eval_file', eval_file if eval_file else dev_in_file,\n                    '--gold_file', gold_file if gold_file else dev_gold_file,\n                    '--lang', short_language,\n                    '--shorthand', short_name,\n                    '--mode', 'predict']\n        if command_args.save_output:\n            dev_args.extend(['--output_file', dev_output_file])\n        dev_args = dev_args + extra_args\n        logger.info(\"Running dev step with 
args: {}\".format(dev_args))\n        _, dev_doc = mwt_expander.main(dev_args)\n        if not command_args.save_output:\n            dev_output_file = \"{:C}\\n\\n\".format(dev_doc)\n            dev_output_file = io.StringIO(dev_output_file)\n\n        results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file)\n        logger.info(\"Finished running dev set on\\n{}\\n{}\".format(treebank, results))\n\n    if mode == Mode.SCORE_TEST:\n        test_args = ['--eval_file', eval_file if eval_file else test_in_file,\n                     '--gold_file', gold_file if gold_file else test_gold_file,\n                     '--lang', short_language,\n                     '--shorthand', short_name,\n                     '--mode', 'predict']\n        if command_args.save_output:\n            test_args.extend(['--output_file', test_output_file])\n        test_args = test_args + extra_args\n        logger.info(\"Running test step with args: {}\".format(test_args))\n        _, test_doc = mwt_expander.main(test_args)\n        if not command_args.save_output:\n            test_output_file = \"{:C}\\n\\n\".format(test_doc)\n            test_output_file = io.StringIO(test_output_file)\n\n        results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file)\n        logger.info(\"Finished running test set on\\n{}\\n{}\".format(treebank, results))\n\ndef main():\n    common.main(run_treebank, \"mwt\", \"mwt_expander\", sub_argparse=mwt_expander.build_argparse())\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_ner.py",
    "content": "\"\"\"\nTrains or scores an NER model.\n\nWill attempt to guess the appropriate word vector file if none is\nspecified, and will use the charlms specified in the resources\nfor a given dataset or language if possible.\n\nExample command line:\n  python3 -m stanza.utils.training.run_ner hu_combined\n\nThis script expects the prepared data to be in\n  data/ner/{lang}_{dataset}.train.json, {lang}_{dataset}.dev.json, {lang}_{dataset}.test.json\n\nIf those files don't exist, it will make an attempt to rebuild them\nusing the prepare_ner_dataset script.  However, this will fail if the\ndata is not already downloaded.  More information on where to find\nmost of the datasets online is in that script.  Some of the datasets\nhave licenses which must be agreed to, so no attempt is made to\nautomatically download the data.\n\"\"\"\n\nimport logging\nimport os\n\nfrom stanza.models import ner_tagger\nfrom stanza.resources.common import DEFAULT_MODEL_DIR\nfrom stanza.utils.datasets.ner import prepare_ner_dataset\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain\n\nfrom stanza.resources.default_packages import default_charlms, default_pretrains, ner_charlms, ner_pretrains\n\n# extra arguments specific to a particular dataset\nDATASET_EXTRA_ARGS = {\n    \"da_ddt\":   [ \"--dropout\", \"0.6\" ],\n    \"fa_arman\": [ \"--dropout\", \"0.6\" ],\n    \"vi_vlsp\":  [ \"--dropout\", \"0.6\",\n                  \"--word_dropout\", \"0.1\",\n                  \"--locked_dropout\", \"0.1\",\n                  \"--char_dropout\", \"0.1\" ],\n}\n\nlogger = logging.getLogger('stanza')\n\ndef add_ner_args(parser):\n    add_charlm_args(parser)\n\n    parser.add_argument('--use_bert', default=False, action=\"store_true\", help='Use the default transformer for this language')\n\n\ndef build_pretrain_args(language, dataset, charlm=\"default\", command_args=None, 
extra_args=None, model_dir=DEFAULT_MODEL_DIR):\n    \"\"\"\n    Returns one list with the args for this language & dataset's charlm and pretrained embedding\n    \"\"\"\n    charlm = choose_charlm(language, dataset, charlm, default_charlms, ner_charlms)\n    charlm_args = build_charlm_args(language, charlm, model_dir=model_dir)\n\n    wordvec_args = []\n    if '--wordvec_pretrain_file' not in extra_args and '--no_pretrain' not in extra_args:\n        # will throw an error if the pretrain can't be found\n        wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains, ner_pretrains, dataset, model_dir=model_dir)\n        wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]\n\n    bert_args = common.choose_transformer(language, command_args, extra_args, warn=False)\n\n    return charlm_args + wordvec_args + bert_args\n\n\n# TODO: refactor?  tagger and depparse should be pretty similar\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    # TODO: can avoid downloading the charlm at this point, since we\n    # might not even be training\n    pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, command_args, extra_args)\n\n    dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])\n\n    train_args = [\"--shorthand\", short_name,\n                  \"--mode\", \"train\"]\n    train_args = train_args + pretrain_args + dataset_args + extra_args\n    if command_args.save_name is not None:\n        train_args.extend([\"--save_name\", command_args.save_name])\n    if command_args.save_dir is not None:\n        train_args.extend([\"--save_dir\", command_args.save_dir])\n    args = ner_tagger.parse_args(train_args)\n    save_name = ner_tagger.model_file_name(args)\n    return save_name\n\n\n# Technically NER datasets are not necessarily treebanks\n# (usually not, in fact)\n# However, to keep the naming consistent, we leave the\n# method which does 
the training as run_treebank\n# TODO: rename treebank -> dataset everywhere\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    ner_dir = paths[\"NER_DATA_DIR\"]\n    language, dataset = short_name.split(\"_\", 1)\n\n    train_file = os.path.join(ner_dir, f\"{treebank}.train.json\")\n    dev_file   = os.path.join(ner_dir, f\"{treebank}.dev.json\")\n    test_file  = os.path.join(ner_dir, f\"{treebank}.test.json\")\n\n    # if any files are missing, try to rebuild the dataset\n    # if that still doesn't work, we have to throw an error\n    missing_file = [x for x in (train_file, dev_file, test_file) if not os.path.exists(x)]\n    if len(missing_file) > 0:\n        logger.warning(f\"The data for {treebank} is missing or incomplete.  Cannot find {missing_file}  Attempting to rebuild...\")\n        try:\n            prepare_ner_dataset.main(treebank)\n        except Exception as e:\n            raise FileNotFoundError(f\"An exception occurred while trying to build the data for {treebank}  At least one portion of the data was missing: {missing_file}  Please correctly build these files and then try again.\") from e\n\n    pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, command_args, extra_args)\n\n    if mode == Mode.TRAIN:\n        # VI example arguments:\n        #   --wordvec_pretrain_file ~/stanza_resources/vi/pretrain/vtb.pt\n        #   --train_file data/ner/vi_vlsp.train.json\n        #   --eval_file data/ner/vi_vlsp.dev.json\n        #   --lang vi\n        #   --shorthand vi_vlsp\n        #   --mode train\n        #   --charlm --charlm_shorthand vi_conll17\n        #   --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1\n        dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])\n\n        train_args = ['--train_file', train_file,\n                      '--eval_file', dev_file,\n                      '--shorthand', short_name,\n                      '--mode', 'train']\n        
train_args = train_args + pretrain_args + dataset_args + extra_args\n        logger.info(\"Running train step with args: {}\".format(train_args))\n        ner_tagger.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = ['--eval_file', dev_file,\n                    '--shorthand', short_name,\n                    '--mode', 'predict']\n        dev_args = dev_args + pretrain_args + extra_args\n        logger.info(\"Running dev step with args: {}\".format(dev_args))\n        ner_tagger.main(dev_args)\n\n    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:\n        test_args = ['--eval_file', test_file,\n                     '--shorthand', short_name,\n                     '--mode', 'predict']\n        test_args = test_args + pretrain_args + extra_args\n        logger.info(\"Running test step with args: {}\".format(test_args))\n        ner_tagger.main(test_args)\n\n\ndef main():\n    common.main(run_treebank, \"ner\", \"nertagger\", add_ner_args, ner_tagger.build_argparse(), build_model_filename=build_model_filename)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_pos.py",
    "content": "\nimport io\nimport logging\nimport os\n\nfrom stanza.models import tagger\n\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain, build_pos_wordvec_args\n\nlogger = logging.getLogger('stanza')\n\ndef add_pos_args(parser):\n    add_charlm_args(parser)\n\n    parser.add_argument('--use_bert', default=False, action=\"store_true\", help='Use the default transformer for this language')\n\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    # TODO: can avoid downloading the charlm at this point, since we\n    # might not even be training\n    charlm_args = build_pos_charlm_args(short_language, dataset, command_args.charlm)\n    bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=False)\n\n    train_args = [\"--shorthand\", short_name,\n                  \"--mode\", \"train\"]\n    # TODO: also, this downloads the wordvec, which we might not want to do yet\n    train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args\n    if command_args.save_name is not None:\n        train_args.extend([\"--save_name\", command_args.save_name])\n    if command_args.save_dir is not None:\n        train_args.extend([\"--save_dir\", command_args.save_dir])\n    args = tagger.parse_args(train_args)\n    save_name = tagger.model_file_name(args)\n    return save_name\n\n\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    pos_dir        = paths[\"POS_DATA_DIR\"]\n    train_file     = f\"{pos_dir}/{short_name}.train.in.conllu\"\n    if short_name == 'vi_vlsp22':\n        train_file += f\";{pos_dir}/vi_vtb.train.in.conllu\"\n    dev_in_file    = f\"{pos_dir}/{short_name}.dev.in.conllu\"\n 
   dev_pred_file  = f\"{pos_dir}/{short_name}.dev.pred.conllu\"\n    test_in_file   = f\"{pos_dir}/{short_name}.test.in.conllu\"\n    test_pred_file = f\"{pos_dir}/{short_name}.test.pred.conllu\"\n\n    charlm_args = build_pos_charlm_args(short_language, dataset, command_args.charlm)\n    bert_args = common.choose_transformer(short_language, command_args, extra_args)\n\n    eval_file = None\n    if '--eval_file' in extra_args:\n        eval_file = extra_args[extra_args.index('--eval_file') + 1]\n\n    if mode == Mode.TRAIN:\n        train_pieces = []\n        for train_piece in train_file.split(\";\"):\n            zip_piece = os.path.splitext(train_piece)[0] + \".zip\"\n            if os.path.exists(train_piece) and os.path.exists(zip_piece):\n                logger.error(\"POS TRAIN FILE %s and %s both exist... this is very confusing, skipping %s\" % (train_piece, zip_piece, short_name))\n                return\n            if os.path.exists(train_piece):\n                train_pieces.append(train_piece)\n            else: # not os.path.exists(train_piece):\n                if os.path.exists(zip_piece):\n                    train_pieces.append(zip_piece)\n                    continue\n                logger.error(\"TRAIN FILE NOT FOUND: %s ... 
skipping\" % train_piece)\n                return\n        train_file = \";\".join(train_pieces)\n\n        train_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                      \"--train_file\", train_file,\n                      \"--lang\", short_language,\n                      \"--shorthand\", short_name,\n                      \"--mode\", \"train\"]\n        if eval_file is None:\n            train_args += ['--eval_file', dev_in_file]\n        train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args\n        train_args = train_args + extra_args\n        logger.info(\"Running train POS for {} with args {}\".format(treebank, train_args))\n        tagger.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                    \"--lang\", short_language,\n                    \"--shorthand\", short_name,\n                    \"--mode\", \"predict\"]\n        if eval_file is None:\n            dev_args += ['--eval_file', dev_in_file]\n        if command_args.save_output:\n            dev_args.extend([\"--output_file\", dev_pred_file])\n        dev_args = dev_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args\n        dev_args = dev_args + extra_args\n        logger.info(\"Running dev POS for {} with args {}\".format(treebank, dev_args))\n        _, dev_doc = tagger.main(dev_args)\n        if not command_args.save_output:\n            dev_pred_file = \"{:C}\\n\\n\".format(dev_doc)\n            dev_pred_file = io.StringIO(dev_pred_file)\n\n        results = common.run_eval_script_pos(eval_file if eval_file else dev_in_file, dev_pred_file)\n        logger.info(\"Finished running dev set on\\n{}\\n{}\".format(treebank, results))\n        if command_args.save_output:\n            logger.info(\"Output saved to %s\", dev_pred_file)\n\n    if mode == Mode.SCORE_TEST or mode == 
Mode.TRAIN:\n        test_args = [\"--wordvec_dir\", paths[\"WORDVEC_DIR\"],\n                     \"--lang\", short_language,\n                     \"--shorthand\", short_name,\n                     \"--mode\", \"predict\"]\n        if eval_file is None:\n            test_args += ['--eval_file', test_in_file]\n        if command_args.save_output:\n            test_args.extend([\"--output_file\", test_pred_file])\n        test_args = test_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args\n        test_args = test_args + extra_args\n        logger.info(\"Running test POS for {} with args {}\".format(treebank, test_args))\n        _, test_doc = tagger.main(test_args)\n        if not command_args.save_output:\n            test_pred_file = \"{:C}\\n\\n\".format(test_doc)\n            test_pred_file = io.StringIO(test_pred_file)\n\n        results = common.run_eval_script_pos(eval_file if eval_file else test_in_file, test_pred_file)\n        logger.info(\"Finished running test set on\\n{}\\n{}\".format(treebank, results))\n        if command_args.save_output:\n            logger.info(\"Output saved to %s\", test_pred_file)\n\n\ndef main():\n    common.main(run_treebank, \"pos\", \"tagger\", add_pos_args, tagger.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_pos_charlm)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_sentiment.py",
    "content": "\"\"\"\nTrains or tests a sentiment model using the classifier package\n\nThe prep script has separate entries for the root-only version of SST,\nwhich is what people typically use to test.  When training a model for\nSST which uses all the data, the root-only version is used for\ndev and test\n\"\"\"\n\nimport logging\nimport os\n\nfrom stanza.models import classifier\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, build_charlm_args, choose_charlm, find_wordvec_pretrain\n\nfrom stanza.resources.default_packages import default_charlms, default_pretrains\n\nlogger = logging.getLogger('stanza')\n\n# TODO: refactor with ner & conparse\ndef add_sentiment_args(parser):\n    parser.add_argument('--charlm', default=\"default\", type=str, help='Which charlm to run on.  Will use the default charlm for this language/model if not set.  Set to None to turn off charlm for languages with a default charlm')\n    parser.add_argument('--no_charlm', dest='charlm', action=\"store_const\", const=None, help=\"Don't use a charlm, even if one is used by default for this package\")\n    parser.add_argument('--use_charlm', action='store_true', help='If --use_bert is set, charlm will be turned off.  
This turns it on anyway')\n\n    parser.add_argument('--use_bert', default=False, action=\"store_true\", help='Use the default transformer for this language')\n\nALTERNATE_DATASET = {\n    \"en_sst2\":    \"en_sst2roots\",\n    \"en_sstplus\": \"en_sst3roots\",\n}\n\ndef build_default_args(paths, short_language, dataset, command_args, extra_args):\n    if '--wordvec_pretrain_file' not in extra_args:\n        # will throw an error if the pretrain can't be found\n        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains)\n        wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]\n    else:\n        wordvec_args = []\n\n    if command_args.use_bert and not command_args.use_charlm:\n        charlm = None\n    else:\n        charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {})\n    charlm_args = build_charlm_args(short_language, charlm, base_args=False)\n\n    bert_args = common.choose_transformer(short_language, command_args, extra_args)\n    default_args = wordvec_args + charlm_args + bert_args\n\n    return default_args\n\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)\n\n    train_args = [\"--shorthand\", short_name]\n    train_args = train_args + default_args\n    if command_args.save_name is not None:\n        train_args.extend([\"--save_name\", command_args.save_name])\n    if command_args.save_dir is not None:\n        train_args.extend([\"--save_dir\", command_args.save_dir])\n    args = classifier.parse_args(train_args + extra_args)\n    save_name = classifier.build_model_filename(args)\n    return save_name\n\n\ndef run_dataset(mode, paths, treebank, short_name, command_args, extra_args):\n    sentiment_dir = paths[\"SENTIMENT_DATA_DIR\"]\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    
train_file = os.path.join(sentiment_dir, f\"{short_name}.train.json\")\n\n    other_name = ALTERNATE_DATASET.get(short_name, short_name)\n    dev_file   = os.path.join(sentiment_dir, f\"{other_name}.dev.json\")\n    test_file  = os.path.join(sentiment_dir, f\"{other_name}.test.json\")\n\n    for filename in (train_file, dev_file, test_file):\n        if not os.path.exists(filename):\n            raise FileNotFoundError(\"Cannot find %s\" % filename)\n\n    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)\n\n    if mode == Mode.TRAIN:\n        train_args = ['--train_file', train_file,\n                      '--dev_file', dev_file,\n                      '--test_file', test_file,\n                      '--shorthand', short_name,\n                      '--wordvec_type', 'word2vec',   # TODO: chinese is fasttext\n                      '--extra_wordvec_method', 'SUM']\n        train_args = train_args + default_args + extra_args\n        logger.info(\"Running train step with args: {}\".format(train_args))\n        classifier.main(train_args)\n\n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = ['--no_train',\n                    '--test_file', dev_file,\n                    '--shorthand', short_name,\n                    '--wordvec_type', 'word2vec']   # TODO: chinese is fasttext\n        dev_args = dev_args + default_args + extra_args\n        logger.info(\"Running dev step with args: {}\".format(dev_args))\n        classifier.main(dev_args)\n\n    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:\n        test_args = ['--no_train',\n                     '--test_file', test_file,\n                     '--shorthand', short_name,\n                     '--wordvec_type', 'word2vec']   # TODO: chinese is fasttext\n        test_args = test_args + default_args + extra_args\n        logger.info(\"Running test step with args: {}\".format(test_args))\n        classifier.main(test_args)\n\n\n\ndef main():\n    
common.main(run_dataset, \"classifier\", \"classifier\", add_sentiment_args, classifier.build_argparse(), build_model_filename=build_model_filename)\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "stanza/utils/training/run_tokenizer.py",
    "content": "\"\"\"\nThis script allows for training or testing on dev / test of the UD tokenizer.\n\nIf run with a single treebank name, it will train or test that treebank.\nIf run with ud_all or all_ud, it will iterate over all UD treebanks it can find.\n\nMode can be set to train&dev with --train, to dev set only\nwith --score_dev, and to test set only with --score_test.\n\nTreebanks are specified as a list.  all_ud or ud_all means to look for\nall UD treebanks.\n\nExtra arguments are passed to tokenizer.  In case the run script\nitself is shadowing arguments, you can specify --extra_args as a\nparameter to mark where the tokenizer arguments start.\n\nDefault behavior is to discard the output and just print the results.\nTo keep the results instead, use --save_output\n\"\"\"\n\nimport io\nimport logging\nimport math\nimport os\n\nfrom stanza.models import tokenizer\nfrom stanza.models.common.doc import Document\nfrom stanza.utils.avg_sent_len import avg_sent_len\nfrom stanza.utils.training import common\nfrom stanza.utils.training.common import Mode, add_charlm_args, build_tokenizer_charlm_args\n\nlogger = logging.getLogger('stanza')\n\ndef add_tokenizer_args(parser):\n    add_charlm_args(parser)\n\n\ndef build_model_filename(paths, short_name, command_args, extra_args):\n    short_language, dataset = short_name.split(\"_\", 1)\n\n    # TODO: can avoid downloading the charlm at this point, since we\n    # might not even be training\n    charlm_args = build_tokenizer_charlm_args(short_language, dataset, command_args.charlm)\n\n    train_args = [\"--shorthand\", short_name,\n                  \"--mode\", \"train\"]\n    train_args = train_args + charlm_args + extra_args\n    if command_args.save_name is not None:\n        train_args.extend([\"--save_name\", command_args.save_name])\n    if command_args.save_dir is not None:\n        train_args.extend([\"--save_dir\", command_args.save_dir])\n    args = tokenizer.parse_args(train_args)\n    save_name = 
tokenizer.model_file_name(args)\n    return save_name\n\n\n\ndef uses_dictionary(short_language):\n    \"\"\"\n    Some of the languages (as shown here) have external dictionaries\n\n    We found this helped the overall tokenizer performance\n    If these can't be found, they can be extracted from the previous iteration of models\n    \"\"\"\n    if short_language in ('ja', 'th', 'zh', 'zh-hans', 'zh-hant'):\n        return True\n    return False\n\ndef run_treebank(mode, paths, treebank, short_name, command_args, extra_args):\n    tokenize_dir = paths[\"TOKENIZE_DATA_DIR\"]\n\n    short_language, dataset = short_name.split(\"_\", 1)\n    label_type = \"--label_file\"\n    label_file = f\"{tokenize_dir}/{short_name}-ud-train.toklabels\"\n    dev_type = \"--txt_file\"\n    dev_file = f\"{tokenize_dir}/{short_name}.dev.txt\"\n    test_type = \"--txt_file\"\n    test_file = f\"{tokenize_dir}/{short_name}.test.txt\"\n    train_type = \"--txt_file\"\n    train_file = f\"{tokenize_dir}/{short_name}.train.txt\"\n    train_dev_args = [\"--dev_txt_file\", dev_file, \"--dev_label_file\", f\"{tokenize_dir}/{short_name}-ud-dev.toklabels\"]\n    \n    if short_language == \"zh\" or short_language.startswith(\"zh-\"):\n        extra_args = [\"--skip_newline\"] + extra_args\n\n    train_gold = f\"{tokenize_dir}/{short_name}.train.gold.conllu\"\n    dev_gold = f\"{tokenize_dir}/{short_name}.dev.gold.conllu\"\n    test_gold = f\"{tokenize_dir}/{short_name}.test.gold.conllu\"\n\n    train_mwt = f\"{tokenize_dir}/{short_name}-ud-train-mwt.json\"\n    dev_mwt = f\"{tokenize_dir}/{short_name}-ud-dev-mwt.json\"\n    test_mwt = f\"{tokenize_dir}/{short_name}-ud-test-mwt.json\"\n\n    train_pred = f\"{tokenize_dir}/{short_name}.train.pred.conllu\"\n    dev_pred = f\"{tokenize_dir}/{short_name}.dev.pred.conllu\"\n    test_pred = f\"{tokenize_dir}/{short_name}.test.pred.conllu\"\n\n    charlm_args = build_tokenizer_charlm_args(short_language, dataset, command_args.charlm)\n\n    if mode == 
Mode.TRAIN:\n        seqlen = str(math.ceil(avg_sent_len(label_file) * 3 / 100) * 100)\n        train_args = ([label_type, label_file, train_type, train_file, \"--lang\", short_language,\n                       \"--max_seqlen\", seqlen, \"--mwt_json_file\", dev_mwt] +\n                      train_dev_args +\n                      [\"--dev_conll_gold\", dev_gold, \"--shorthand\", short_name])\n        if uses_dictionary(short_language):\n            train_args = train_args + [\"--use_dictionary\"]\n        train_args = train_args + charlm_args + extra_args\n        logger.info(\"Running train step with args: {}\".format(train_args))\n        tokenizer.main(train_args)\n    \n    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:\n        dev_args = [\"--mode\", \"predict\", dev_type, dev_file, \"--lang\", short_language,\n                    \"--shorthand\", short_name, \"--mwt_json_file\", dev_mwt]\n        if command_args.save_output:\n            dev_args.extend([\"--conll_file\", dev_pred])\n        dev_args = dev_args + charlm_args + extra_args\n        logger.info(\"Running dev step with args: {}\".format(dev_args))\n        _, dev_doc = tokenizer.main(dev_args)\n\n        # TODO: log these results?  
The original script logged them to\n        # echo $results $args >> ${TOKENIZE_DATA_DIR}/${short}.results\n\n        if not command_args.save_output:\n            dev_pred = \"{:C}\\n\\n\".format(Document(dev_doc))\n            dev_pred = io.StringIO(dev_pred)\n        results = common.run_eval_script_tokens(dev_gold, dev_pred)\n        logger.info(\"Finished running dev set on\\n{}\\n{}\".format(treebank, results))\n\n    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:\n        test_args = [\"--mode\", \"predict\", test_type, test_file, \"--lang\", short_language,\n                     \"--shorthand\", short_name, \"--mwt_json_file\", test_mwt]\n        if command_args.save_output:\n            test_args.extend([\"--conll_file\", test_pred])\n        test_args = test_args + charlm_args + extra_args\n        logger.info(\"Running test step with args: {}\".format(test_args))\n        _, test_doc = tokenizer.main(test_args)\n\n        if not command_args.save_output:\n            test_pred = \"{:C}\\n\\n\".format(Document(test_doc))\n            test_pred = io.StringIO(test_pred)\n        results = common.run_eval_script_tokens(test_gold, test_pred)\n        logger.info(\"Finished running test set on\\n{}\\n{}\".format(treebank, results))\n\n    if mode == Mode.SCORE_TRAIN:\n        test_args = [\"--mode\", \"predict\", test_type, train_file, \"--lang\", short_language,\n                     \"--shorthand\", short_name, \"--mwt_json_file\", train_mwt]\n        if command_args.save_output:\n            test_args.extend([\"--conll_file\", train_pred])\n        test_args = test_args + charlm_args + extra_args\n        logger.info(\"Running test step with args: {}\".format(test_args))\n        _, train_doc = tokenizer.main(test_args)\n\n        if not command_args.save_output:\n            train_pred = \"{:C}\\n\\n\".format(Document(train_doc))\n            train_pred = io.StringIO(train_pred)\n        results = common.run_eval_script_tokens(train_gold, train_pred)\n  
      logger.info(\"Finished running train set as a test on\\n{}\\n{}\".format(treebank, results))\n\n\n\ndef main():\n    common.main(run_treebank, \"tokenize\", \"tokenizer\", add_tokenizer_args, sub_argparse=tokenizer.build_argparse(), build_model_filename=build_model_filename)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/training/separate_ner_pretrain.py",
    "content": "\"\"\"\nLoads NER models & separates out the word vectors to base & delta\n\nThe model will then be resaved without the base word vector,\ngreatly reducing the size of the model\n\nThis may be useful for any external users of stanza who have an NER\nmodel they wish to reuse without retraining\n\nIf you know which pretrain was used to build an NER model, you can\nprovide that pretrain.  Otherwise, you can give a directory of\npretrains and the script will test each one.  In the latter case,\nthe name of the pretrain needs to look like lang_dataset_pretrain.pt\n\"\"\"\n\nimport argparse\nfrom collections import defaultdict\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom stanza import Pipeline\nfrom stanza.models.common.constant import lang_to_langcode\nfrom stanza.models.common.pretrain import Pretrain, PretrainedWordVocab\nfrom stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX\nfrom stanza.models.ner.trainer import Trainer\n\nlogger = logging.getLogger('stanza')\nlogger.setLevel(logging.ERROR)\n\nDEBUG = False\nEPS = 0.0001\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input_path', type=str, default='saved_models/ner', help='Where to find NER models (dir or filename)')\n    parser.add_argument('--output_path', type=str, default='saved_models/shrunk', help='Where to write shrunk NER models (dir)')\n    parser.add_argument('--pretrain_path', type=str, default='saved_models/pretrain', help='Where to find pretrains (dir or filename)')\n    args = parser.parse_args()\n\n    # get list of NER models to shrink\n    if os.path.isdir(args.input_path):\n        ner_model_dir = args.input_path\n        ners = os.listdir(ner_model_dir)\n        if len(ners) == 0:\n            raise FileNotFoundError(\"No ner models found in {}\".format(args.input_path))\n    else:\n        if not os.path.isfile(args.input_path):\n            raise FileNotFoundError(\"No ner model found at 
path {}\".format(args.input_path))\n        ner_model_dir, ners = os.path.split(args.input_path)\n        ners = [ners]\n\n    # get map from language to candidate pretrains\n    if os.path.isdir(args.pretrain_path):\n        pt_model_dir = args.pretrain_path\n        pretrains = os.listdir(pt_model_dir)\n        lang_to_pretrain = defaultdict(list)\n        for pt in pretrains:\n            lang_to_pretrain[pt.split(\"_\")[0]].append(pt)\n    else:\n        pt_model_dir, pretrains = os.path.split(args.pretrain_path)\n        pretrains = [pretrains]\n        lang_to_pretrain = defaultdict(lambda: pretrains)\n\n    # shrunk models will all go in this directory\n    new_dir = args.output_path\n    os.makedirs(new_dir, exist_ok=True)\n\n    final_pretrains = []\n    missing_pretrains = []\n    no_finetune = []\n\n    # for each model, go through the various pretrains\n    # until we find one that works or none of them work\n    for ner_model in ners:\n        ner_path = os.path.join(ner_model_dir, ner_model)\n\n        expected_ending = \"_nertagger.pt\"\n        if not ner_model.endswith(expected_ending):\n            raise ValueError(\"Unexpected name: {}\".format(ner_model))\n        short_name = ner_model[:-len(expected_ending)]\n        lang, package = short_name.split(\"_\", maxsplit=1)\n        print(\"===============================================\")\n        print(\"Processing lang %s package %s\" % (lang, package))\n\n        # this may look funny - basically, the pipeline has machinery\n        # to make sure the model has everything it needs to load,\n        # including downloading other pieces if needed\n        pipe = Pipeline(lang, processors=\"tokenize,ner\", tokenize_pretokenized=True, package={\"ner\": package}, ner_model_path=ner_path)\n        ner_processor = pipe.processors['ner']\n        print(\"Loaded NER processor: {}\".format(ner_processor))\n        trainer = ner_processor.trainers[0]\n        vocab = trainer.model.vocab\n        word_vocab = 
vocab['word']\n        num_vectors = trainer.model.word_emb.weight.shape[0]\n\n        # sanity check, make sure the model loaded matches the\n        # language from the model's filename\n        lcode = lang_to_langcode(trainer.args['lang'])\n        if lang != lcode and not (lcode == 'zh' and lang == 'zh-hans'):\n            raise ValueError(\"lang not as expected: {} vs {} ({})\".format(lang, trainer.args['lang'], lcode))\n\n        ner_pretrains = sorted(set(lang_to_pretrain[lang] + lang_to_pretrain[lcode]))\n        for pt_model in ner_pretrains:\n            pt_path = os.path.join(pt_model_dir, pt_model)\n            print(\"Attempting pretrain: {}\".format(pt_path))\n            pt = Pretrain(filename=pt_path)\n            print(\"  pretrain shape:               {}\".format(pt.emb.shape))\n            print(\"  embedding in ner model shape: {}\".format(trainer.model.word_emb.weight.shape))\n            if pt.emb.shape[1] != trainer.model.word_emb.weight.shape[1]:\n                print(\"  DIMENSION DOES NOT MATCH.  SKIPPING\")\n                continue\n            N = min(pt.emb.shape[0], trainer.model.word_emb.weight.shape[0])\n            if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]:\n                # If the vocab was exactly the same, that's a good\n                # sign this pretrain was used, just with a different size\n                # In such a case, we can reuse the rest of the pretrain\n                # Minor issue: some vectors which were trained will be\n                # lost in the case of |pt| < |model.word_emb|\n                if all(word_vocab.id2unit(x) == pt.vocab.id2unit(x) for x in range(N)):\n                    print(\"  Attempting to use pt vectors to replace ner model's vectors\")\n                else:\n                    print(\"  NUM VECTORS DO NOT MATCH.  WORDS DO NOT MATCH.  
SKIPPING\")\n                    continue\n                if pt.emb.shape[0] < trainer.model.word_emb.weight.shape[0]:\n                    print(\"  WARNING: if any vectors beyond {} were fine tuned, that fine tuning will be lost\".format(N))\n            device = next(trainer.model.parameters()).device\n            delta = trainer.model.word_emb.weight[:N, :] - pt.emb.to(device)[:N, :]\n            delta = delta.detach()\n            delta_norms = torch.linalg.norm(delta, dim=1).cpu().numpy()\n            if np.sum(delta_norms < 0) > 0:\n                raise ValueError(\"This should not be - a norm was less than 0!\")\n            num_matching = np.sum(delta_norms < EPS)\n            if num_matching > N / 2:\n                print(\"  Accepted!  %d of %d vectors match for %s\" % (num_matching, N, pt_path))\n                if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]:\n                    print(\"  Setting model vocab to match the pretrain\")\n                    word_vocab = pt.vocab\n                    vocab['word'] = word_vocab\n                    trainer.args['word_emb_dim'] = pt.emb.shape[1]\n                break\n            else:\n                print(\"  %d of %d vectors matched for %s - SKIPPING\" % (num_matching, N, pt_path))\n                vocab_same = sum(x in pt.vocab for x in word_vocab)\n                print(\"  %d words were in both vocabs\" % vocab_same)\n                # this is expensive, and in practice doesn't happen,\n                # but theoretically we might have missed a mostly matching pt\n                # if the vocab had been scrambled\n                if DEBUG:\n                    rearranged_count = 0\n                    for x in word_vocab:\n                        if x not in pt.vocab:\n                            continue\n                        x_id = word_vocab.unit2id(x)\n                        x_vec = trainer.model.word_emb.weight[x_id, :]\n                        pt_id = pt.vocab.unit2id(x)\n   
                     pt_vec = pt.emb[pt_id, :]\n                        if (x_vec.detach().cpu() - pt_vec).norm() < EPS:\n                            rearranged_count += 1\n                    print(\"  %d vectors were close when ignoring id ordering\" % rearranged_count)\n        else:\n            print(\"COULD NOT FIND A MATCHING PT: {}\".format(ner_processor))\n            missing_pretrains.append(ner_model)\n            continue\n\n        # build a delta vector & embedding\n        assert 'delta' not in vocab.keys()\n        delta_vectors = [delta[i].cpu() for i in range(4)]\n        delta_vocab = []\n        for i in range(4, len(delta_norms)):\n            if delta_norms[i] > 0.0:\n                delta_vocab.append(word_vocab.id2unit(i))\n                delta_vectors.append(delta[i].cpu())\n\n        trainer.model.unsaved_modules.append(\"word_emb\")\n        if len(delta_vocab) == 0:\n            print(\"No vectors were changed!  Perhaps this model was trained without finetune.\")\n            no_finetune.append(ner_model)\n        else:\n            print(\"%d delta vocab\" % len(delta_vocab))\n            print(\"%d vectors in the delta set\" % len(delta_vectors))\n            delta_vectors = np.stack(delta_vectors)\n            delta_vectors = torch.from_numpy(delta_vectors)\n            assert delta_vectors.shape[0] == len(delta_vocab) + len(VOCAB_PREFIX)\n            print(delta_vectors.shape)\n\n            delta_vocab = PretrainedWordVocab(delta_vocab, lang=word_vocab.lang, lower=word_vocab.lower)\n            vocab['delta'] = delta_vocab\n            trainer.model.delta_emb = nn.Embedding(delta_vectors.shape[0], delta_vectors.shape[1], PAD_ID)\n            trainer.model.delta_emb.weight.data.copy_(delta_vectors)\n\n        new_path = os.path.join(new_dir, ner_model)\n        trainer.save(new_path)\n\n        final_pretrains.append((ner_model, pt_model))\n\n    print()\n    if len(final_pretrains) > 0:\n        print(\"Final pretrain 
mappings:\")\n        for i in final_pretrains:\n            print(i)\n    if len(missing_pretrains) > 0:\n        print(\"MISSING EMBEDDINGS:\")\n        for i in missing_pretrains:\n            print(i)\n    if len(no_finetune) > 0:\n        print(\"NOT FINE TUNED:\")\n        for i in no_finetune:\n            print(i)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/README",
    "content": "# Overview\n\nThe code in this directory contains tooling required for Semgrex and Ssurgeon visualization.\nSearching dependency graphs and manipulating them can be a time consuming and challenging task to get right.\nSemgrex is a system for searching dependency graphs and Ssurgeon is a system for manipulating the output of Semgrex.\nThe compact language used by these systems allows for easy command line or API processing of dependencies.\n\nWe now offer Semgrex and Ssurgeon through a web interface, now accessible via Streamlit with visualizations.\n\n## How to run visualizations through Streamlit\n\nStreamlit can be used to visualize Semgrex and Ssurgeon results and process files.\nHere are instructions for setting up a Streamlit webpage:\n\n1. install Streamlit. `pip install streamlit`\n2. install Stanford CoreNLP if you have not. You can find an installation here: https://stanfordnlp.github.io/CoreNLP/download.html\n3. set the $CLASSPATH environment variable to your local installation of CoreNLP.\n4. install streamlit, spacy, and ipython.  You can use the \"visualization\" stanza setup option for that\n5. Run `streamlit run stanza/utils/visualization/semgrex_app.py --theme.backgroundColor \"#FFFFFF\"`\n\nThis should begin a Streamlit runtime application on your local machine that can be interacted with.\n\nFor instructions on how to use Ssurgeon and Semgrex, refer to these helpful pages:\nhttps://aclanthology.org/2023.tlt-1.7.pdf\nhttps://nlp.stanford.edu/nlp/javadoc/javanlp-3.5.0/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html\nhttps://stanfordnlp.github.io/stanza/client_regex.html\nhttps://stanfordnlp.github.io/CoreNLP/corenlp-server.html#query-tokensregex-tokensregex\n"
  },
  {
    "path": "stanza/utils/visualization/__init__.py",
    "content": ""
  },
  {
    "path": "stanza/utils/visualization/conll_deprel_visualization.py",
    "content": "from stanza.models.common.constant import is_right_to_left\nimport spacy\nimport argparse\nfrom spacy import displacy\nfrom spacy.tokens import Doc\nfrom stanza.utils import conll\nfrom stanza.utils.visualization import dependency_visualization as viz\n\n\ndef conll_to_visual(conll_file, pipeline, sent_count=10, display_all=False):\n    \"\"\"\n    Takes in a conll file and visualizes it by converting the conll file to a Stanza Document object\n    and visualizing it with the visualize_doc method.\n\n    Input should be a proper conll file.\n\n    The pipeline for the conll file to be processed in must be provided as well.\n\n    Optionally, the sent_count argument can be tweaked to display a different amount of sentences.\n\n    To display all of the sentences in a conll file, the display_all argument can optionally be set to True.\n    BEWARE: setting this argument for a large conll file may result in too many renderings, resulting in a crash.\n    \"\"\"\n    # convert conll file to doc\n    doc = conll.CoNLL.conll2doc(conll_file)\n\n    if display_all:\n        viz.visualize_doc(conll.CoNLL.conll2doc(conll_file), pipeline)\n    else:  # visualize a given number of sentences\n        visualization_options = {\"compact\": True, \"bg\": \"#09a3d5\", \"color\": \"white\", \"distance\": 100,\n                                 \"font\": \"Source Sans Pro\", \"offset_x\": 30,\n                                 \"arrow_spacing\": 20}  # see spaCy visualization settings doc for more options\n        nlp = spacy.blank(\"en\")\n        sentences_to_visualize, rtl, num_sentences = [], is_right_to_left(pipeline), len(doc.sentences)\n\n        for i in range(sent_count):\n            if i >= num_sentences:  # case where there are less sentences than amount requested\n                break\n            sentence = doc.sentences[i]\n            words, lemmas, heads, deps, tags = [], [], [], [], []\n            sentence_words = sentence.words\n            if rtl:  
# rtl languages will be visually rendered from right to left as well\n                sentence_words = reversed(sentence.words)\n                sent_len = len(sentence.words)\n            for word in sentence_words:\n                words.append(word.text)\n                lemmas.append(word.lemma)\n                deps.append(word.deprel)\n                tags.append(word.upos)\n                if rtl and word.head == 0:  # word heads are off-by-1 in spaCy doc inits compared to Stanza\n                    heads.append(sent_len - word.id)\n                elif rtl and word.head != 0:\n                    heads.append(sent_len - word.head)\n                elif not rtl and word.head == 0:\n                    heads.append(word.id - 1)\n                elif not rtl and word.head != 0:\n                    heads.append(word.head - 1)\n\n            document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)\n            sentences_to_visualize.append(document_result)\n\n        print(sentences_to_visualize)\n        for line in sentences_to_visualize:  # render all sentences through displaCy\n            displacy.render(line, style=\"dep\", options=visualization_options)\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--conll_file', type=str,\n                        default=\"C:\\\\Users\\\\Alex\\\\stanza\\\\demo\\\\en_test.conllu.txt\",\n                        help=\"File path of the CoNLL file to visualize dependencies of\")\n    parser.add_argument('--pipeline', type=str, default=\"en\",\n                        help=\"Language code of the language pipeline to use (ex: 'en' for English)\")\n    parser.add_argument('--sent_count', type=int, default=10, help=\"Number of sentences to visualize from CoNLL file\")\n    # store_true instead of type=bool: argparse's bool() turns any non-empty string, including \"False\", into True\n    parser.add_argument('--display_all', action='store_true', default=False,\n                        help=\"Whether or not to visualize all of the sentences from the file. 
Overrides sent_count if set to True\")\n    args = parser.parse_args()\n    conll_to_visual(args.conll_file, args.pipeline, args.sent_count, args.display_all)\n    return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/constants.py",
    "content": "\"\"\"\nConstants used for visualization tooling\n\"\"\"\n\n# Ssurgeon constants\nSAMPLE_SSURGEON_DOC = \"\"\"\n    # sent_id = 271\n    # text = Hers is easy to clean.\n    # previous = What did the dealer like about Alex's car?\n    # comment = extraction/raising via \"tough extraction\" and clausal subject\n    1\tHers\thers\tPRON\tPRP\tGender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs\t3\tnsubj\t_\t_\n    2\tis\tbe\tAUX\tVBZ\tMood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin\t3\tcop\t_\t_\n    3\teasy\teasy\tADJ\tJJ\tDegree=Pos\t0\troot\t_\t_\n    4\tto\tto\tPART\tTO\t_\t5\tmark\t_\t_\n    5\tclean\tclean\tVERB\tVB\tVerbForm=Inf\t3\tcsubj\t_\tSpaceAfter=No\n    6\t.\t.\tPUNCT\t.\t_\t5\tpunct\t_\t_\n    \"\"\"\n\n# Semgrex constants\nDEFAULT_SAMPLE_TEXT = \"Banning opal removed artifact decks from the meta.\"\nDEFAULT_SEMGREX_QUERY = \"{pos:NN}=object <obl {}=action, {cpos:NOUN}=thing <obj {cpos:VERB}=action\"\n\n\n"
  },
  {
    "path": "stanza/utils/visualization/dependency_visualization.py",
    "content": "\"\"\"\nFunctions to visualize dependency relations in texts and Stanza documents \n\"\"\"\n\nfrom stanza.models.common.constant import is_right_to_left\nimport stanza\nimport spacy\nfrom spacy import displacy\nfrom spacy.tokens import Doc\n\n\ndef visualize_doc(doc, language):\n    \"\"\"\n    Takes in a Document and visualizes it using displacy.\n\n    The document to visualize must be from the stanza pipeline.\n\n    right-to-left languages such as Arabic are displayed right-to-left based on the language code\n    \"\"\"\n    visualization_options = {\"compact\": True, \"bg\": \"#09a3d5\", \"color\": \"white\", \"distance\": 90,\n                             \"font\": \"Source Sans Pro\", \"arrow_spacing\": 25}\n    # blank model - we don't use any of the model features, just the viz\n    nlp = spacy.blank(\"en\")\n    sentences_to_visualize = []\n    for sentence in doc.sentences:\n        words, lemmas, heads, deps, tags = [], [], [], [], []\n        if is_right_to_left(language):  # order of words displayed is reversed, dependency arcs remain intact\n            sent_len = len(sentence.words)\n            for word in reversed(sentence.words):\n                words.append(word.text)\n                lemmas.append(word.lemma)\n                deps.append(word.deprel)\n                tags.append(word.upos)\n                if word.head == 0:  # spaCy head indexes are formatted differently than that of Stanza\n                    heads.append(sent_len - word.id)\n                else:\n                    heads.append(sent_len - word.head)\n        else:   # left to right rendering\n            for word in sentence.words:\n                words.append(word.text)\n                lemmas.append(word.lemma)\n                deps.append(word.deprel)\n                tags.append(word.upos)\n                if word.head == 0:\n                    heads.append(word.id - 1)\n                else:\n                    heads.append(word.head - 1)\n        
document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)\n        sentences_to_visualize.append(document_result)\n\n    for line in sentences_to_visualize:  # render all sentences through displaCy\n        # If this program is NOT being run in a Jupyter notebook, replace displacy.render with displacy.serve\n        # and the visualization will be hosted locally, link being provided in the program output.\n        displacy.render(line, style=\"dep\", options=visualization_options)\n\n\ndef visualize_str(text, pipeline_code, pipe):\n    \"\"\"\n    Takes a string and visualizes it using displacy.\n\n    The string is processed using the stanza pipeline and its\n    dependencies are formatted into a spaCy doc object for easy\n    visualization. Accepts valid stanza (UD) pipelines as the pipeline\n    argument. Must supply the stanza pipeline code (the two-letter\n    abbreviation of the language, such as 'en' for English). Must also\n    supply the stanza pipeline object as the third argument.\n    \"\"\"\n    doc = pipe(text)\n    visualize_doc(doc, pipeline_code)\n\n\ndef visualize_docs(docs, lang_code):\n    \"\"\"\n    Takes in a list of Stanza document objects and a language code (ex: 'en' for English) and visualizes the\n    dependency relationships within each document.\n\n    This function uses spaCy visualizations. 
See the visualize_doc function for more details.\n    \"\"\"\n    for doc in docs:\n        visualize_doc(doc, lang_code)\n\n\ndef visualize_strings(texts, lang_code):\n    \"\"\"\n    Takes a language code (ex: 'en' for English) and a list of strings to process and visualizes the\n    dependency relationships in each text.\n\n    This function loads the Stanza pipeline for the given language and uses it to visualize all of the strings provided.\n    \"\"\"\n    pipe = stanza.Pipeline(lang_code, processors=\"tokenize,pos,lemma,depparse\")\n    for text in texts:\n        visualize_str(text, lang_code, pipe)\n\n\ndef main():\n    ar_strings = ['برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة \"ليوبارد\" الالمانية', \"هل بإمكاني مساعدتك؟\",\n               \"أراك في مابعد\", \"لحظة من فضلك\"]\n    en_strings = [\"This is a sentence.\",\n                  \"Barack Obama was born in Hawaii. He was elected President of the United States in 2008.\"]\n    zh_strings = [\"中国是一个很有意思的国家。\"]\n    # Testing with right to left language\n    visualize_strings(ar_strings, \"ar\")\n    # Testing with left to right languages\n    visualize_strings(en_strings, \"en\")\n    visualize_strings(zh_strings, \"zh\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/ner_visualization.py",
    "content": "\"\"\"\nVisualize named entities from different texts and Stanza documents (+ CoNLL files)\n\"\"\"\n\nfrom spacy import displacy\nfrom spacy.tokens import Doc\nfrom spacy.tokens import Span\nfrom stanza.models.common.constant import is_right_to_left\nimport stanza\nimport spacy\nimport copy\n\n\ndef visualize_ner_doc(doc, language, select=None, colors=None):\n    \"\"\"\n    Takes a stanza doc object and language pipeline and visualizes the named entities within it.\n\n    Stanza currently supports a limited amount of languages for NER, which you can view here:\n    https://stanfordnlp.github.io/stanza/ner_models.html\n\n    To view only a specific type(s) of named entities, set the optional 'select' argument to\n    a list of the named entity types. Ex: select=[\"PER\", \"ORG\", \"GPE\"] to only see entities tagged as Person(s),\n    Organizations, and Geo-political entities. A full list of the available types can be found here:\n    https://stanfordnlp.github.io/stanza/ner_models.html (ctrl + F \"The following table\").\n\n    The colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be\n    represented as a string (ex: \"blue\"), a color hex value (ex: #aa9cfc), or as a linear gradient of color\n    values (ex: \"linear-gradient(90deg, #aa9cfc, #fc9ce7)\").\n\n    Do not change the 'rtl_clr_adjusted' argument; it is used for ensuring that the visualize_strings function\n    works properly on rtl languages.\n    \"\"\"\n    model, documents, visualization_colors = spacy.blank('en'), [], copy.deepcopy(colors)  # blank model, spacy is only used for visualization purposes\n    sentences, rtl, RTL_OVERRIDE = doc.sentences, is_right_to_left(language), \"‮\"\n    if rtl:  # need to flip order of all the sentences in rendered display\n        sentences = reversed(doc.sentences)\n        # adjust colors to be in LTR flipped format due to the RLO unicode char flipping words\n        if colors:\n            for 
color in list(visualization_colors):  # iterate over a snapshot of the keys; the dict is mutated below\n                if RTL_OVERRIDE not in color:\n                    clr_val = visualization_colors[color]\n                    visualization_colors.pop(color)\n                    visualization_colors[RTL_OVERRIDE + color[::-1]] = clr_val\n    for sentence in sentences:\n        words, display_ents, already_found = [], [], False\n        # initialize doc object with words first\n        for i, word in enumerate(sentence.words):\n            if rtl and word.text.isascii() and not already_found:\n                to_append = [word.text[::-1]]\n                next_word_index = i + 1\n                # account for flipping non-Arabic words back to original form and order. two flips -> original order\n                while next_word_index <= len(sentence.words) - 1 and sentence.words[next_word_index].text.isascii():\n                    to_append.append(sentence.words[next_word_index].text[::-1])\n                    next_word_index += 1\n                to_append = reversed(to_append)\n                for token in to_append:\n                    words.append(token)\n                already_found = True\n            elif rtl and word.text.isascii() and already_found:  # skip over already collected words\n                continue\n            else:  # arabic chars\n                words.append(word.text)\n                already_found = False\n\n        document = Doc(model.vocab, words=words)\n\n        # tag all NER tokens found\n        for ent in sentence.ents:\n            if select and ent.type not in select:\n                continue\n            found_indexes = []\n            for token in ent.tokens:\n                found_indexes.append(token.id[0] - 1)\n            if not rtl:\n                to_add = Span(document, found_indexes[0], found_indexes[-1] + 1, ent.type)\n            else:  # RTL languages need the override char to flip order\n                to_add = Span(document, found_indexes[0], found_indexes[-1] + 1, 
RTL_OVERRIDE + ent.type[::-1])\n            display_ents.append(to_add)\n        document.set_ents(display_ents)\n        documents.append(document)\n\n    # Visualize doc objects\n    visualization_options = {\"ents\": select}\n    if colors:\n        visualization_options[\"colors\"] = visualization_colors\n    for document in documents:\n        displacy.render(document, style='ent', options=visualization_options)\n\n\ndef visualize_ner_str(text, pipe, select=None, colors=None):\n    \"\"\"\n    Takes in a text string and visualizes the named entities within the text.\n\n    Required args also include a pipeline code, the two-letter code for a language defined by Universal Dependencies (ex: \"en\" for English).\n\n    Lastly, the user must provide an NLP pipeline - we recommend Stanza (ex: pipe = stanza.Pipeline('en')).\n\n    Optionally, the 'select' argument allows for specific NER tags to be highlighted; the 'color' argument allows\n    for specific NER tags to have certain color(s).\n    \"\"\"\n    doc = pipe(text)\n    visualize_ner_doc(doc, pipe.lang, select, colors)\n\n\ndef visualize_strings(texts, language_code, select=None, colors=None):\n    \"\"\"\n    Takes in a list of strings and a language code (Stanza defines these, ex: 'en' for English) to visualize all\n    of the strings' named entities.\n\n    The strings are processed by the Stanza pipeline and the named entities are displayed. 
Each text is separated by a delimiting line.\n\n    Optionally, the 'select' argument may be configured to only visualize given named entities (ex: select=['ORG', 'PERSON']).\n\n    The optional colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be\n    represented as a string (ex: \"blue\"), a color hex value (ex: #aa9cfc), or as a linear gradient of color\n    values (ex: \"linear-gradient(90deg, #aa9cfc, #fc9ce7)\").\n    \"\"\"\n    lang_pipe = stanza.Pipeline(language_code, processors=\"tokenize,ner\")\n\n    for text in texts:\n        visualize_ner_str(text, lang_pipe, select=select, colors=colors)\n\n\ndef visualize_docs(docs, language_code, select=None, colors=None):\n    \"\"\"\n    Takes in a list of doc and a language code (Stanza defines these, ex: 'en' for English) to visualize all\n    of the strings' named entities.\n\n    Each text is separated by a delimiting line.\n\n    Optionally, the 'select' argument may be configured to only visualize given named entities (ex: select=['ORG', 'PERSON']).\n\n    The optional colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be\n    represented as a string (ex: \"blue\"), a color hex value (ex: #aa9cfc), or as a linear gradient of color\n    values (ex: \"linear-gradient(90deg, #aa9cfc, #fc9ce7)\").\n    \"\"\"\n    for doc in docs:\n        visualize_ner_doc(doc, language_code, select=select, colors=colors)\n\n\ndef main():\n    en_strings = ['''Samuel Jackson, a Christian man from Utah, went to the JFK Airport for a flight to New York.\n                               He was thinking of attending the US Open, his favorite tennis tournament besides Wimbledon.\n                               That would be a dream trip, certainly not possible since it is $5000 attendance and 5000 miles away.\n                               On the way there, he watched the Super Bowl for 2 hours and read War and Piece by Tolstoy for 
1 hour.\n                               In New York, he crossed the Brooklyn Bridge and listened to the 5th symphony of Beethoven as well as\n                               \"All I want for Christmas is You\" by Mariah Carey.''',\n                  \"Barack Obama was born in Hawaii. He was elected President of the United States in 2008\"]\n    zh_strings = ['''来自犹他州的基督徒塞缪尔杰克逊前往肯尼迪机场搭乘航班飞往纽约。\n                             他正在考虑参加美国公开赛，这是除了温布尔登之外他最喜欢的网球赛事。\n                             那将是一次梦想之旅，当然不可能，因为它的出勤费为 5000 美元，距离 5000 英里。\n                             在去的路上，他看了 2 个小时的超级碗比赛，看了 1 个小时的托尔斯泰的《战争与碎片》。\n                               在纽约，他穿过布鲁克林大桥，聆听了贝多芬的第五交响曲以及 玛丽亚凯莉的“圣诞节我想要的就是你”。''',\n                  \"我觉得罗家费德勒住在加州, 在美国里面。\"]\n    ar_strings = [\n        \".أعيش في سان فرانسيسكو ، كاليفورنيا. اسمي أليكس وأنا ألتحق بجامعة ستانفورد. أنا أدرس علوم الكمبيوتر وأستاذي هو كريس مانينغ\"\n        , \"اسمي أليكس ، أنا من الولايات المتحدة.\",\n        '''صامويل جاكسون ، رجل مسيحي من ولاية يوتا ، ذهب إلى مطار جون كنيدي في رحلة إلى نيويورك. كان يفكر في حضور بطولة الولايات المتحدة المفتوحة للتنس ، بطولة التنس المفضلة لديه إلى جانب بطولة ويمبلدون. ستكون هذه رحلة الأحلام ، وبالتأكيد ليست ممكنة لأنها تبلغ 5000 دولار للحضور و 5000 ميل. في الطريق إلى هناك ، شاهد Super Bowl لمدة ساعتين وقرأ War and Piece by Tolstoy لمدة ساعة واحدة. في نيويورك ، عبر جسر بروكلين واستمع إلى السيمفونية الخامسة لبيتهوفن وكذلك \"كل ما أريده في عيد الميلاد هو أنت\" لماريا كاري.''']\n\n    visualize_strings(en_strings, \"en\")\n    visualize_strings(zh_strings, \"zh\", colors={\"PERSON\": \"yellow\", \"DATE\": \"red\", \"GPE\": \"blue\"})\n    visualize_strings(zh_strings, \"zh\", select=['PERSON', 'DATE'])\n    visualize_strings(ar_strings, \"ar\",\n                      colors={\"PER\": \"pink\", \"LOC\": \"linear-gradient(90deg, #aa9cfc, #fc9ce7)\", \"ORG\": \"yellow\"})\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/semgrex_app.py",
    "content": "import os \nimport sys \nimport streamlit as st\nimport streamlit.components.v1 as components\nimport stanza.utils.visualization.ssurgeon_visualizer as ssv\nimport logging\n\nfrom stanza.utils.visualization.semgrex_visualizer import visualize_search_str\nfrom stanza.utils.visualization.semgrex_visualizer import edit_html_overflow\nfrom stanza.utils.visualization.constants import *\nfrom stanza.utils.conll import CoNLL\nfrom stanza.server.ssurgeon import *\nfrom stanza.pipeline.core import Pipeline\n\nfrom io import StringIO\nimport os\nfrom typing import List, Tuple, Any\nimport argparse\n\n\ndef get_semgrex_text_and_query() -> Tuple[str, str]:\n    \"\"\"\n    Gets user input for the Semgrex text and queries to process.\n\n    @return: A tuple containing the user's input text and their input queries\n    \"\"\"\n    input_txt = st.text_area(\n        \"Text to analyze\",\n        DEFAULT_SAMPLE_TEXT,\n        placeholder=DEFAULT_SAMPLE_TEXT,\n    )\n    input_queries = st.text_area(\n        \"Semgrex search queries (separate each query with a comma)\",\n        DEFAULT_SEMGREX_QUERY,\n        placeholder=DEFAULT_SEMGREX_QUERY,\n    )\n    return input_txt, input_queries\n\n\ndef get_file_input() -> List[str]:\n    \"\"\"\n    Allows user to submit files for analysis.\n\n    @return: List of strings containing the file contents of each submitted file. 
The i-th element of res is the\n    string representing the i-th file uploaded.\n    \"\"\"\n    st.markdown(\"\"\"**Alternatively, upload file(s) to analyze.**\"\"\")\n    uploaded_files = st.file_uploader(\n        \"button_label\", accept_multiple_files=True, label_visibility=\"collapsed\"\n    )\n    res = []\n    for file in uploaded_files:\n        stringio = StringIO(file.getvalue().decode(\"utf-8\"))\n        string_data = stringio.read()\n        res.append(string_data)\n    return res\n\n\ndef get_semgrex_window_input() -> Tuple[bool, int, int]:\n    \"\"\"\n    Allows the user to specify a window of Semgrex hits to visualize. Works similarly to Python slicing.\n\n    @return: A tuple containing a bool representing whether or not the user wants to visualize a slice of\n    the visualizations, and two ints representing the start and end indices of the slice.\n    \"\"\"\n    show_window = st.checkbox(\n        \"Visualize a specific window of Semgrex search hits?\",\n        help=\"\"\"If you want to visualize all search results, leave this unmarked.\"\"\",\n    )\n    start_window, end_window = None, None\n    if show_window:\n        start_window = st.number_input(\n            \"Which search hit should visualizations start from?\",\n            help=\"\"\"If you want to visualize the first 10 search results, set this to 0.\"\"\",\n            min_value=0,\n        )\n        end_window = st.number_input(\n            \"Which search hit should visualizations stop on?\",\n            help=\"\"\"If you want to visualize the first 10 search results, set this to 11.\n                                     The 11th result will NOT be displayed.\"\"\",\n            value=11,\n            min_value=start_window + 1,\n        )\n    return show_window, start_window, end_window\n\n\ndef get_pos_input() -> bool:\n    \"\"\"\n    Prompts client for whether they want to see xpos tags instead of upos.\n    \"\"\"\n    use_xpos = st.checkbox(\"Would you like to 
visualize xpos tags?\",\n                           help=\"The default visualization options use upos tags for part-of-speech labeling. If xpos tags aren't available for the sentence, displays upos.\")\n    return use_xpos\n\n\ndef get_input() -> Tuple[str, str, List[str], Tuple[bool, int, int], bool]:\n    \"\"\"\n    Gathers every user input: the text, the queries, the uploaded files, the window settings, and the xpos toggle.\n    \"\"\"\n    input_txt, input_queries = get_semgrex_text_and_query()\n    client_files = get_file_input()  # this is already converted to string format\n    window_input = get_semgrex_window_input()\n    visualize_xpos = get_pos_input()\n    return input_txt, input_queries, client_files, window_input, visualize_xpos\n\n\ndef run_semgrex_process(\n    input_txt: str,\n    input_queries: str,\n    client_files: List[str],\n    show_window: bool,\n    clicked: bool,\n    pipe: Any,\n    start_window: int,\n    end_window: int,\n    visualize_xpos: bool,\n    show_success: bool = True\n) -> None:\n    \"\"\"\n    Run Semgrex search on the input text/files with input query and serve the HTML on the app.\n\n    @param input_txt: Text to analyze and draw sentences from.\n    @param input_queries: Semgrex queries to parse the input with.\n    @param client_files: Alternative to input text, we can parse the content of files for scaled analysis.\n    @param show_window: Whether or not the user wants a slice of the visualizations\n    @param clicked: Whether or not the button has been clicked to run Semgrex search\n    @param pipe: NLP pipeline to process input with\n    @param start_window: If displaying a slice of visualizations, this is the start idx\n    @param end_window: If displaying a slice of visualizations, this is the end idx\n    @param visualize_xpos: Set to true if using xpos tags for part of speech labels, otherwise use upos tags\n    @param show_success: Whether or not to display a success message once the results are rendered\n\n    \"\"\"\n\n    if clicked:\n\n        # process inputs, reject bad ones\n        if not input_txt and not client_files:\n            
st.error(\"Please provide a text input or upload files for analysis.\")\n        elif input_txt and client_files:\n            st.error(\n                \"Please only choose to visualize your input text or your uploaded files, not both.\"\n            )\n        elif not input_queries:\n            st.error(\"Please provide a set of Semgrex queries.\")\n        else:  # no input errors\n            try:\n                with st.spinner(\"Processing...\"):\n                    queries = [\n                        query.strip() for query in input_queries.split(\",\")\n                    ]  # separate queries into individual parts\n                    if client_files:\n                        html_strings, begin_viz_idx, end_viz_idx = [], 0, float(\"inf\")\n                        if show_window:\n                            begin_viz_idx, end_viz_idx = (\n                                start_window - 1,\n                                end_window - 1,\n                            )\n                        for client_file in client_files:\n                            client_file_html_strings = visualize_search_str(\n                                client_file,\n                                queries,\n                                \"en\",\n                                start_match=begin_viz_idx,\n                                end_match=end_viz_idx,\n                                pipe=pipe,\n                                visualize_xpos=visualize_xpos\n                            )\n                            html_strings += client_file_html_strings\n                    else:  # just input text, no files\n                        if show_window:\n                            html_strings = visualize_search_str(\n                                input_txt,\n                                queries,\n                                \"en\",\n                                start_match=start_window - 1,\n                                end_match=end_window - 
1,\n                                pipe=pipe,\n                                visualize_xpos=visualize_xpos\n                            )\n                        else:\n                            html_strings = visualize_search_str(\n                                input_txt,\n                                queries,\n                                \"en\",\n                                end_match=float(\"inf\"),\n                                pipe=pipe,\n                                visualize_xpos=visualize_xpos\n                            )\n\n\n                    if len(html_strings) == 0:\n                        st.write(\"No Semgrex match hits!\")\n\n                    # Render successful Semgrex results\n                    for s in html_strings:\n                        s_no_overflow = edit_html_overflow(s)\n                        components.html(\n                            s_no_overflow, height=200, width=1000, scrolling=True\n                        )\n                    if show_success:\n                        if len(html_strings) == 1:\n                            st.success(\n                                f\"Completed! Visualized {len(html_strings)} Semgrex search hit.\"\n                            )\n                        else:\n                            st.success(\n                                f\"Completed! Visualized {len(html_strings)} Semgrex search hits.\"\n                            )\n            except OSError:\n                st.error(\n                    \"Your text input or your provided Semgrex queries are incorrect. 
Please try again.\"\n                )\n\n\ndef semgrex_state():\n    \"\"\"\n    Contains the Semgrex portion of the webpage.\n\n    This contains the markdown and calls to the processes which run when a query is made.\n\n    When the `Load Semgrex search visualization` button is pressed, the function `run_semgrex_process`\n    is called inside this function and the rendered visual is placed onto the webpage.\n    \"\"\"\n\n    # Title Markdown for page header\n    st.title(\"Displaying Semgrex Queries\")\n\n    html_string = (\n        \"<h3>Enter a text below, along with your Semgrex query of choice.</h3>\"\n    )\n    st.markdown(html_string, unsafe_allow_html=True)\n    input_txt, input_queries, client_files, window_input, visualize_xpos = get_input()\n\n    show_window, start_window, end_window = window_input\n\n    clicked = st.button(\n        \"Load Semgrex search visualization\",\n        help=\"\"\"Semgrex search visualizations only display \n    sentences with a query match. Non-matching sentences are not shown.\"\"\",\n    )  # use the on_click param\n\n    run_semgrex_process(\n        input_txt=input_txt,\n        input_queries=input_queries,\n        client_files=client_files,\n        show_window=show_window,\n        clicked=clicked,\n        pipe=st.session_state[\"pipeline\"],\n        start_window=start_window,\n        end_window=end_window,\n        visualize_xpos=visualize_xpos\n    )\n\n\ndef ssurgeon_state():\n    \"\"\"\n    Contains the ssurgeon state for the webpage.\n\n    This contains the markdown and calls the processes that run Ssurgeon operations.\n\n    When the text boxes, buttons, or other interactable features are edited by the user, this function\n    runs with the updated page state and conducts operations (e.g. 
runs a Ssurgeon operation on a submitted file)\n    \"\"\"\n\n    st.title(\"Displaying Ssurgeon Results\")\n\n    # Textbox for input to SSurgeon (text)\n    input_txt = st.text_area(\n        \"Text to analyze\",\n        SAMPLE_SSURGEON_DOC,\n        placeholder=SAMPLE_SSURGEON_DOC,\n    )\n\n    # Textbox for input queries to SSurgeon (commands + queries)\n    semgrex_input_queries = st.text_area(\n        \"Semgrex search queries (separate each query with a comma)\",\n        \"{}=source >nsubj {} >csubj=bad {}\",\n        placeholder=\"\"\"{}=source >nsubj {} >csubj=bad {}\"\"\",\n    )\n    ssurgeon_input_queries = st.text_area(\n        \"Ssurgeon commands\",\n        \"relabelNamedEdge -edge bad -reln advcl\",\n        placeholder=\"relabelNamedEdge -edge bad -reln advcl\"\n    )\n\n    # File uploading box\n    st.markdown(\"\"\"**Alternatively, upload file(s) to edit.**\"\"\")\n    uploaded_files = st.file_uploader(\n        \"\", accept_multiple_files=True, label_visibility=\"collapsed\"\n    )\n    res = []\n    # Convert uploaded files to strings for processing\n    for file in uploaded_files: \n        stringio = StringIO(file.getvalue().decode(\"utf-8\"))\n        string_data = stringio.read()\n        res.append(string_data)\n\n    # Input button to trigger processing phase \n    clicked = st.button(\n        \"Load Ssurgeon visualization\",\n    )\n    clicked_for_file_edit = st.button(\n        \"Edit File\"\n    )\n    # Once the user requests the Ssurgeon operation, run this block:\n    if clicked:\n        try:\n            with st.spinner(\"Processing...\"):\n                semgrex_queries = semgrex_input_queries # separate queries into individual parts\n                ssurgeon_queries = [ssurgeon_input_queries]\n\n                # use SSurgeon to edit the deprel, get the HTML for new relations\n                html_strings = ssv.visualize_ssurgeon_deprel_adjusted_str_input(input_txt, semgrex_queries, ssurgeon_queries)\n                doc 
= CoNLL.conll2doc(input_str=input_txt)\n                string_txt = \" \".join([word.text for sentence in doc.sentences for word in sentence.words])\n\n                # Render pre-edited input\n                html_string = (\n                    \"<h3>Previous deprel visualization:</h3>\"\n                )\n                st.markdown(html_string, unsafe_allow_html=True)\n                # run_semgrex_process renders its own output and returns None, so it is called directly\n                run_semgrex_process(input_txt=string_txt, input_queries=semgrex_queries, clicked=clicked,\n                                    show_window=False, client_files=[], pipe=st.session_state[\"pipeline\"],\n                                    start_window=1, end_window=11, visualize_xpos=False, show_success=False)\n\n                if len(html_strings) == 0:\n                    st.write(\"No Semgrex match hits!\")\n\n                # Render edited outputs\n                for s in html_strings:\n                    html_string = (\n                        \"<h3>Edited deprel visualization:</h3>\"\n                    )\n                    st.markdown(html_string, unsafe_allow_html=True)\n                    s_no_overflow = edit_html_overflow(s)\n                    components.html(\n                        s_no_overflow, height=200, width=1000, scrolling=True\n                    )\n        except OSError:\n            st.error(\n                \"Your text input or your provided Semgrex/Ssurgeon queries are incorrect. 
Please try again.\"\n            )\n    # If the input is a file instead of raw text, process the file with Ssurgeon and give an output\n    # that can be downloaded by the client\n    if clicked_for_file_edit:\n        # files are in res\n        if len(res) == 0:\n            st.error(\"You must provide files for analysis.\")\n        else:\n            with st.spinner(\"Editing...\"):\n                single_file = res[0]  # only the first uploaded file is edited\n                doc = CoNLL.conll2doc(input_str=single_file)\n                ssurgeon_response = process_doc_one_operation(doc, semgrex_input_queries, [ssurgeon_input_queries])\n                updated_doc = convert_response_to_doc(doc, ssurgeon_response)\n                output = CoNLL.doc2conll(updated_doc)[0]\n                output_str = \"\\n\".join(output)\n                st.download_button(\"Download your edited file\", data=output_str, file_name=\"SSurgeon.conll\")\n\n\ndef main():\n    \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--CLASSPATH\",\n        type=str,\n        default=os.environ.get(\"CLASSPATH\", None),\n        help=\"Path to your CoreNLP directory.\",\n    )  # for example, set $CLASSPATH to \"C:\\\\stanford-corenlp-4.5.2\\\\stanford-corenlp-4.5.2\\\\*\"\n    args = parser.parse_args()\n    \n    CLASSPATH = args.CLASSPATH\n\n    # os.environ only accepts strings, so validate the value before assigning it\n    if CLASSPATH is None:\n        logging.error(\"Provide a valid $CLASSPATH value (path to your CoreNLP installation).\")\n        raise ValueError(\"Provide a valid $CLASSPATH value (path to your CoreNLP installation).\")\n    os.environ[\"CLASSPATH\"] = CLASSPATH\n\n    # load the pipeline once per user session\n    if \"pipeline\" not in st.session_state:  \n        en_nlp_stanza = Pipeline(\n            \"en\", processors=\"tokenize, pos, lemma, depparse\"\n        )\n        st.session_state[\"pipeline\"] = en_nlp_stanza\n\n    #### Below are the webpage states that run. 
Streamlit renders the HTML once; when the user interacts with\n    # the page, these states are run again with their internal state possibly altered (e.g. the user clicks a button). \n\n    semgrex_state()\n    ssurgeon_state()\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/semgrex_visualizer.py",
    "content": "import os\nimport argparse\nimport sys\n\nroot_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))\nsys.path.append(root_dir)\n\nfrom stanza.pipeline.core import Pipeline\nfrom stanza.server.semgrex import Semgrex\nfrom stanza.models.common.constant import is_right_to_left\nimport spacy\nfrom spacy import displacy\nfrom spacy.tokens import Doc\nfrom IPython.display import display, HTML\nimport typing\nfrom typing import List, Tuple, Any\n\nfrom stanza.utils.visualization.utils import find_nth, round_base\n\n\ndef get_sentences_html(doc: Any, language: str, visualize_xpos: bool = False) -> List[str]:\n    \"\"\"\n    Returns a list of HTML strings representing the dependency visualizations of a given stanza document.\n    One HTML string is generated per sentence of the document object. Converts the stanza document object\n    to a spaCy doc object and generates HTML with displaCy.\n\n    @param doc: a stanza document object which can be generated with an NLP pipeline.\n    @param language: the two letter language code for the document e.g. 
\"en\" for English.\n    @param visualize_xpos: A toggled option to use xpos tags for part-of-speech labels instead of upos.\n\n    @return: a list of HTML strings which visualize the dependencies of the doc object.\n    \"\"\"\n    USE_FINE_GRAINED = False if not visualize_xpos else True\n    html_strings, sentences_to_visualize = [], []\n    nlp = spacy.blank(\n        \"en\"\n    )  # blank model - we don't use any of the model features, just the visualization\n    for sentence in doc.sentences:\n        words, lemmas, heads, deps, tags = [], [], [], [], []\n        if is_right_to_left(\n            language\n        ):  # order of words displayed is reversed, dependency arcs remain intact\n            sentence_len = len(sentence.words)\n            for word in reversed(sentence.words):\n                words.append(word.text)\n                lemmas.append(word.lemma)\n                deps.append(word.deprel)\n                if visualize_xpos and word.xpos:\n                    tags.append(word.xpos)\n                else:\n                    tags.append(word.upos)\n                if word.head == 0:  # spaCy head indexes are one-off from Stanza's\n                    heads.append(sentence_len - word.id)\n                else:\n                    heads.append(sentence_len - word.head)\n        else:  # left to right rendering\n            for word in sentence.words:\n                words.append(word.text)\n                lemmas.append(word.lemma)\n                deps.append(word.deprel)\n                if visualize_xpos and word.xpos:\n                    tags.append(word.xpos)\n                else:\n                    tags.append(word.upos)\n                if word.head == 0:\n                    heads.append(word.id - 1)\n                else:\n                    heads.append(word.head - 1)\n        if USE_FINE_GRAINED:\n            stanza_to_spacy_doc = Doc(\n                nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, tags=tags\n 
           )\n        else:\n            stanza_to_spacy_doc = Doc(\n                nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags\n            )\n        sentences_to_visualize.append(stanza_to_spacy_doc)\n\n    for line in sentences_to_visualize:  # render all sentences through displaCy\n        html_strings.append(\n            displacy.render(\n                line,\n                style=\"dep\",\n                options={\n                    \"compact\": True,\n                    \"word_spacing\": 30,\n                    \"distance\": 100,\n                    \"arrow_spacing\": 20,\n                    \"fine_grained\": USE_FINE_GRAINED\n                },\n                jupyter=False,\n            )\n        )\n    return html_strings\n\n\ndef semgrexify_html(orig_html: str, semgrex_sentence) -> str:\n    \"\"\"\n    Modifies the HTML of a sentence's dependency visualization, highlighting words involved in the\n    semgrex_sentence search queries and adding the label of the word inside of the match.\n\n\n    @param orig_html: unedited HTML of a sentence's dependency visualization.\n    @param semgrex_sentence: a Semgrex result object containing the matches to a provided query.\n    @return: edited HTML containing the visual changes described above.\n    \"\"\"\n    tracker = {}  # keep track of which words have multiple labels\n    DEFAULT_TSPAN_COUNT = (\n        2  # the original displacy html assigns two <tspan> objects per <text> object\n    )\n    CLOSING_TSPAN_LEN = 8  # </tspan> is 8 chars long\n    colors = [\n        \"#4477AA\",\n        \"#66CCEE\",\n        \"#228833\",\n        \"#CCBB44\",\n        \"#EE6677\",\n        \"#AA3377\",\n        \"#BBBBBB\",\n    ]  # colorblind-friendly scheme\n    css_bolded_class = \"<style> .bolded{font-weight: bold;} </style>\\n\"\n    opening_svg_end_idx = orig_html.find(\"\\n\")\n    # insert the new style class\n    orig_html = (\n        orig_html[: opening_svg_end_idx + 
1]\n        + css_bolded_class\n        + orig_html[opening_svg_end_idx + 1 :]\n    )\n\n    # Color and bold words involved in each Semgrex match\n    for query in semgrex_sentence.result:\n        for i, match in enumerate(query.match):\n            color = colors[i]\n            paired_dy = 2\n            for node in match.node:\n                name, match_index = node.name, node.matchIndex\n                # edit existing <tspan> to change color and bold the text\n                start = find_nth(\n                    orig_html, \"<text\", match_index\n                )  # finds start of svg <text> of interest\n                if (\n                    match_index not in tracker\n                ):  # if we've already bolded and colored, keep the first color\n                    tspan_start = orig_html.find(\n                        \"<tspan\", start\n                    )  # finds start of the first svg <tspan> inside of the <text>\n                    tspan_end = orig_html.find(\n                        \"</tspan>\", start\n                    )  # finds start of the end of the above <tspan>\n                    tspan_substr = (\n                        orig_html[tspan_start : tspan_end + CLOSING_TSPAN_LEN + 1]\n                        + \"\\n\"\n                    )\n                    # color and bold words in the search hit\n                    edited_tspan = tspan_substr.replace(\n                        'class=\"displacy-word\"', 'class=\"bolded\"'\n                    ).replace('fill=\"currentColor\"', f'fill=\"{color}\"')\n                    # insert edited <tspan> object into html string\n\n                    # TODO: DEBUG. 
This code has a bug in it that causes the svg to not end on an input like\n                    # \"The Wimbledon grass-court tennis tournament banned players, resulting in players hating others.\"\n                    # to malfunction and add another <svg> copy to the tail-end of the first svg rendering.\n                    # This bug has been patched in the end of this function, but need to find out what is going on.\n                    orig_html = (\n                        orig_html[:tspan_start]\n                        + edited_tspan\n                        + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2 :]\n                    )\n\n                    tracker[match_index] = DEFAULT_TSPAN_COUNT\n\n                # next, we have to insert the new <tspan> object for the label\n                # Copy old <tspan> to copy formatting when creating new <tspan> later\n                prev_tspan_start = (\n                    find_nth(orig_html[start:], \"<tspan\", tracker[match_index] - 1)\n                    + start\n                )  # find the previous <tspan> start index\n                prev_tspan_end = (\n                    find_nth(orig_html[start:], \"</tspan>\", tracker[match_index] - 1)\n                    + start\n                )  # find the prev </tspan> start index\n                prev_tspan = orig_html[\n                    prev_tspan_start : prev_tspan_end + CLOSING_TSPAN_LEN + 1\n                ]\n\n                # Find spot to insert new tspan\n                closing_tspan_start = (\n                    find_nth(orig_html[start:], \"</tspan>\", tracker[match_index])\n                    + start\n                )\n                up_to_new_tspan = orig_html[\n                    : closing_tspan_start + CLOSING_TSPAN_LEN + 1\n                ]\n                rest = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1 :]\n\n                # Calculate proper x value in svg\n                x_value_start = prev_tspan.find('x=\"')\n        
        x_value_end = (\n                    prev_tspan[x_value_start + 3 :].find('\"') + 3\n                )  # 3 is the length of the 'x=\"' substring\n                x_value = prev_tspan[x_value_start + 3 : x_value_end + x_value_start]\n\n                # Calculate proper y value in svg\n                DEFAULT_DY_VAL, dy = 2, 2\n                if (\n                    paired_dy != DEFAULT_DY_VAL and node == match.node[1]\n                ):  # we're on the second node and need to adjust height to match the paired node\n                    dy = paired_dy\n                if node == match.node[0] and len(match.node) > 1:\n                    paired_node_level = 2\n                    if (\n                        match.node[1].matchIndex in tracker\n                    ):  # check if we need to adjust heights of labels\n                        paired_node_level = tracker[match.node[1].matchIndex]\n                        dif = tracker[match_index] - paired_node_level\n                        if dif > 0:  # current node has more labels\n                            paired_dy = DEFAULT_DY_VAL * dif + 1\n                            dy = DEFAULT_DY_VAL\n                        else:  # paired node has more labels, adjust this label down\n                            dy = DEFAULT_DY_VAL * (abs(dif) + 1)\n                            paired_dy = DEFAULT_DY_VAL\n\n                # Insert new <tspan> object\n                new_tspan = f'  <tspan class=\"displacy-word\" dy=\"{dy}em\" fill=\"{color}\" x={x_value}>{name[: 3].title()}.</tspan>\\n'  # abbreviate label names to 3 chars\n                orig_html = up_to_new_tspan + new_tspan + rest\n                tracker[match_index] += 1\n\n        # process out extra term if present -- TODO: Figure out why the semgrexify_html function lines 164-168 cause a duplication bug\n        end = find_nth(haystack=orig_html, needle=\"</svg\", n=1)\n        LENGTH_OF_END_SVG = 7  # </svg> has length 6 so add 1 to the end too\n    
    if len(orig_html) > end + LENGTH_OF_END_SVG:\n            orig_html = orig_html[: end + LENGTH_OF_END_SVG]\n\n    return orig_html\n\n\ndef render_html_strings(edited_html_strings: List[str]) -> None:\n    \"\"\"\n    Renders the HTML of each HTML string.\n    \"\"\"\n    for html_string in edited_html_strings:\n        display(HTML(html_string))\n\n\ndef visualize_search_doc(\n    doc: Any,\n    semgrex_queries: List[str],\n    lang_code: str,\n    start_match: int = 0,\n    end_match: int = 11,\n    render: bool = True,\n    visualize_xpos: bool = False\n) -> List[str]:\n    \"\"\"\n    Visualizes the result of running Semgrex search on a document. The i-th element of\n    the returned list is the HTML representation of the i-th sentence's dependency\n    relationships. Only shows sentences that have a match on the Semgrex search.\n\n    @param doc: A Stanza document object that contains dependency relationships .\n    @param semgrex_queries: A list of Semgrex queries to search for in the document.\n    @param lang_code: A two letter language abbreviation for the language that the Stanza document is written in.\n    @param start_match: Beginning of the splice for which to display elements with.\n    @param end_match: End of the splice for which to display elements with.\n    @param render: A toggled option to render the HTML strings within the returned list\n    @param visualize_xpos: A toggled option to use xpos tags in part-of-speech labels, defaulting to upos tags.\n\n    @return: A list of HTML strings representing the dependency relations of the doc object.\n    \"\"\"\n\n    matches_count = 0  # Limits number of visualizations\n    with Semgrex(classpath=\"$CLASSPATH\") as sem:\n        edited_html_strings = []\n        semgrex_results = sem.process(doc, *semgrex_queries)\n        # one html string for each sentence\n        unedited_html_strings = get_sentences_html(doc, lang_code, visualize_xpos=visualize_xpos)\n\n        for semgrex_result in 
semgrex_results.result:\n            if matches_count >= end_match:  # we've collected enough matches\n                break\n\n            # read the sentence_idx off the matches,\n            # in case they came back in an unexpected order\n            sentence_idx = None\n            for sentence_result in semgrex_result.result:\n                for match in sentence_result.match:\n                    sentence_idx = match.sentenceIndex\n                    break\n            # don't count empty match objects as having matched\n            if sentence_idx is None:\n                continue\n            if start_match <= matches_count < end_match:\n                unedited_html_string = unedited_html_strings[sentence_idx]\n                edited_string = semgrexify_html(\n                    unedited_html_string, semgrex_result\n                )\n                edited_string = adjust_dep_arrows(edited_string)\n                edited_html_strings.append(edited_string)\n            matches_count += 1\n        if render:\n            render_html_strings(edited_html_strings)\n    return edited_html_strings\n\n\ndef visualize_search_str(\n    text: str,\n    semgrex_queries: List[str],\n    lang_code: str,\n    start_match: int = 0,\n    end_match: int = 11,\n    pipe=None,\n    render: bool = True,\n    visualize_xpos: bool = False\n):\n    \"\"\"\n    Visualizes the result of running Semgrex search on a string. The i-th element of\n    the returned list is the HTML representation of the i-th sentence's dependency\n    relationships. 
Only shows sentences that have a match on the Semgrex search.\n\n    @param text: The string on which Semgrex search will be run.\n    @param semgrex_queries: A list of Semgrex queries to search for in the document.\n    @param lang_code: A two-letter language abbreviation for the language that the text is written in.\n    @param start_match: Index of the first matching sentence to display (inclusive).\n    @param end_match: Index at which to stop displaying matching sentences (exclusive).\n    @param pipe: An NLP pipeline through which the text will be processed. If None, a default pipeline is built for lang_code.\n    @param render: A toggled option to render the HTML strings within the returned list.\n    @param visualize_xpos: A toggled option to use xpos tags for part-of-speech labeling, defaulting to upos tags.\n\n    @return: A list of HTML strings representing the dependency relations of the matching sentences.\n    \"\"\"\n    if pipe is None:\n        nlp = Pipeline(lang_code, processors=\"tokenize, pos, lemma, depparse\")\n    else:\n        nlp = pipe\n    doc = nlp(text)\n    return visualize_search_doc(\n        doc,\n        semgrex_queries,\n        lang_code,\n        start_match=start_match,\n        end_match=end_match,\n        render=render,\n        visualize_xpos=visualize_xpos\n    )\n\n\ndef adjust_dep_arrows(raw_html: str) -> str:\n    \"\"\"\n    Default spaCy dependency visualizations have misaligned arrows. 
Fix arrows by aligning arrow ends and bodies\n    to the word that they are directed to.\n\n    @param raw_html: Dependency relation visualization generated HTML from displaCy\n    @return: Edited HTML string with fixed arrow placements\n    \"\"\"\n\n    HTML_ARROW_BEGINNING = '<g class=\"displacy-arrow\">'\n    HTML_ARROW_ENDING = \"</g>\"\n    HTML_ARROW_ENDING_LEN = 6  # there are 2 newline chars after the arrow ending\n    arrows_start_idx = find_nth(\n        haystack=raw_html, needle='<g class=\"displacy-arrow\">', n=1\n    )\n    words_html, arrows_html = (\n        raw_html[:arrows_start_idx],\n        raw_html[arrows_start_idx:],\n    )  # separate html for words and arrows\n    final_html = (\n        words_html  # continually concatenate to this after processing each arrow\n    )\n    arrow_number = 1  # which arrow we're currently editing (1-indexed)\n    start_idx, end_of_class_idx = (\n        find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number),\n        find_nth(haystack=arrows_html, needle=HTML_ARROW_ENDING, n=arrow_number),\n    )\n    while start_idx != -1:  # edit every arrow\n        arrow_section = arrows_html[\n            start_idx : end_of_class_idx + HTML_ARROW_ENDING_LEN\n        ]  # slice a single svg arrow object\n        if (\n            arrow_section[-1] == \"<\"\n        ):  # this is the last arrow in the HTML, don't cut the splice early\n            arrow_section = arrows_html[start_idx:]\n        edited_arrow_section = edit_dep_arrow(arrow_section)\n\n        final_html = (\n            final_html + edited_arrow_section\n        )  # continually update html with new arrow html until done\n\n        # Prepare for next iteration\n        arrow_number += 1\n        start_idx = find_nth(arrows_html, '<g class=\"displacy-arrow\">', arrow_number)\n        end_of_class_idx = find_nth(arrows_html, \"</g>\", arrow_number)\n    return final_html\n\n\ndef edit_dep_arrow(arrow_html: str) -> str:\n    \"\"\"\n    The 
formatting of a single displacy arrow in svg is the following:\n    <g class=\"displacy-arrow\">\n        <path class=\"displacy-arc\" id=\"arrow-c628889ffbf343e3848193a08606f10a-0-0\" stroke-width=\"2px\" d=\"M70,352.0 C70,177.0 390.0,177.0 390.0,352.0\" fill=\"none\" stroke=\"currentColor\"/>\n        <text dy=\"1.25em\" style=\"font-size: 0.8em; letter-spacing: 1px\">\n            <textPath xlink:href=\"#arrow-c628889ffbf343e3848193a08606f10a-0-0\" class=\"displacy-label\" startOffset=\"50%\" side=\"left\" fill=\"currentColor\" text-anchor=\"middle\">csubj</textPath>\n        </text>\n        <path class=\"displacy-arrowhead\" d=\"M70,354.0 L62,342.0 78,342.0\" fill=\"currentColor\"/>\n    </g>\n\n    We edit the 'd = ...' parts of the <path class ...> section to fix the arrow direction and length to round to\n    the nearest 50 units, centering on each word's center. This is because the words start at x=50 and have spacing\n    of 100, so each word is at an x-value that is a multiple of 50.\n\n    @param arrow_html: Original SVG for a single displaCy arrow.\n    @return: Edited SVG for the displaCy arrow, adjusting its placement\n    \"\"\"\n\n    WORD_SPACING = 50  # words start at x=50 and are separated by 100s so their x values are multiples of 50\n    M_OFFSET = 4  # length of 'd=\"M' that we search for to extract the number from d=\"M70, for instance\n    ARROW_PIXEL_SIZE = 4\n    first_d_idx, second_d_idx = (\n        find_nth(arrow_html, 'd=\"M', 1),\n        find_nth(arrow_html, 'd=\"M', 2),\n    )  # find where d=\"M starts\n    first_d_cutoff, second_d_cutoff = (\n        arrow_html.find(\",\", first_d_idx),\n        arrow_html.find(\",\", second_d_idx),\n    )  # isolate the number after 'M' e.g. 
'M70'\n    # gives svg x values of arrow body starting position and arrowhead position\n    arrow_position, arrowhead_position = (\n        float(arrow_html[first_d_idx + M_OFFSET : first_d_cutoff]),\n        float(arrow_html[second_d_idx + M_OFFSET : second_d_cutoff]),\n    )\n    # gives starting index of where 'fill=\"none\"' or 'fill=\"currentColor\"' begin, reference points to end the d= section\n    first_fill_start_idx, second_fill_start_idx = (\n        find_nth(arrow_html, \"fill\", n=1),\n        find_nth(arrow_html, \"fill\", n=3),\n    )\n\n    # isolate the d= ... section to edit\n    first_d, second_d = (\n        arrow_html[first_d_idx:first_fill_start_idx],\n        arrow_html[second_d_idx:second_fill_start_idx],\n    )\n    first_d_split, second_d_split = first_d.split(\",\"), second_d.split(\",\")\n\n    if (\n        arrow_position == arrowhead_position\n    ):  # This arrow is incoming onto the word, center the arrow/head to word center\n        corrected_arrow_pos = corrected_arrowhead_pos = round_base(\n            arrow_position, base=WORD_SPACING\n        )\n\n        # edit first_d  -- arrow body\n        second_term = first_d_split[1].split(\" \")[0] + \" \" + str(corrected_arrow_pos)\n        first_d = (\n            'd=\"M'\n            + str(corrected_arrow_pos)\n            + \",\"\n            + second_term\n            + \",\"\n            + \",\".join(first_d_split[2:])\n        )\n\n        # edit second_d  -- arrowhead\n        second_term = (\n            second_d_split[1].split(\" \")[0]\n            + \" L\"\n            + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n        )\n        third_term = (\n            second_d_split[2].split(\" \")[0]\n            + \" \"\n            + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n        )\n        second_d = (\n            'd=\"M'\n            + str(corrected_arrowhead_pos)\n            + \",\"\n            + second_term\n            + \",\"\n            + third_term\n     
       + \",\"\n            + \",\".join(second_d_split[3:])\n        )\n    else:  # This arrow is outgoing to another word, center the arrow/head to that word's center\n        corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING)\n\n        # edit first_d -- arrow body\n        third_term = first_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n        fourth_term = (\n            first_d_split[3].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n        )\n        terms = [\n            first_d_split[0],\n            first_d_split[1],\n            third_term,\n            fourth_term,\n        ] + first_d_split[4:]\n        first_d = \",\".join(terms)\n\n        # edit second_d -- arrow head\n        first_term = f'd=\"M{corrected_arrowhead_pos}'\n        second_term = (\n            second_d_split[1].split(\" \")[0]\n            + \" L\"\n            + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n        )\n        third_term = (\n            second_d_split[2].split(\" \")[0]\n            + \" \"\n            + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n        )\n        terms = [first_term, second_term, third_term] + second_d_split[3:]\n        second_d = \",\".join(terms)\n    # rebuild and return html from its individual sections\n    return (\n        arrow_html[:first_d_idx]\n        + first_d\n        + \" \"\n        + arrow_html[first_fill_start_idx:second_d_idx]\n        + second_d\n        + \" \"\n        + arrow_html[second_fill_start_idx:]\n    )\n\n\ndef edit_html_overflow(html_string: str) -> str:\n    \"\"\"\n    Adds to overflow and display settings to the SVG header to visualize overflowing HTML renderings in the\n    Semgrex streamlit app. 
Prevents Semgrex search tags from being cut off at the bottom of visualizations.\n\n    The opening of each HTML string looks similar to this; we add to the end of the SVG header.\n\n    <svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" xml:lang=\"en\" id=\"fa9446a525de4862b233007f26dbbecb-0\" class=\"displacy\" width=\"850\" height=\"242.0\" direction=\"ltr\" style=\"max-width: none; height: 242.0px; color: #000000; background: #ffffff; font-family: Arial; direction: ltr\">\n    <style> .bolded{font-weight: bold;} </style>\n    <text class=\"displacy-token\" fill=\"currentColor\" text-anchor=\"middle\" y=\"182.0\">\n        <tspan class=\"bolded\" fill=\"#66CCEE\" x=\"50\">Banning</tspan>\n\n       <tspan class=\"displacy-tag\" dy=\"2em\" fill=\"currentColor\" x=\"50\">VERB</tspan>\n      <tspan class=\"displacy-word\" dy=\"2em\" fill=\"#66CCEE\" x=50>Act.</tspan>\n    </text>\n\n    @param html_string: HTML of the result of running Semgrex search on a text\n    @return: Edited HTML to visualize the dependencies even in the case of overflow.\n    \"\"\"\n\n    BUFFER_LEN = 14  # length of 'direction: ltr\"'\n    editing_start_idx = find_nth(html_string, \"direction: ltr\", n=1)\n    SVG_HEADER_ADDITION = \"overflow: visible; display: block\"\n    return (\n        html_string[:editing_start_idx]\n        + \"; \"\n        + SVG_HEADER_ADDITION\n        + html_string[editing_start_idx + BUFFER_LEN :]\n    )\n\n\ndef main():\n    \"\"\"\n    IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. 
Additionally,\n    set an environment variable CLASSPATH equal to the path of your corenlp directory.\n\n    Example: CLASSPATH=C:\\\\Users\\\\Alex\\\\PycharmProjects\\\\pythonProject\\\\stanford-corenlp-4.5.0\\\\stanford-corenlp-4.5.0\\\\*\n    \"\"\"\n    nlp = Pipeline(\"en\", processors=\"tokenize,pos,lemma,depparse\")\n    doc = nlp(\n        \"Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people.\"\n    )\n    queries = [\n        \"{pos:NN}=object <obl {}=action\",\n        \"{cpos:NOUN}=thing <obj {cpos:VERB}=action\",\n    ]\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--doc\", type=Any, default=doc, help=\"Stanza document to process.\"\n    )\n    parser.add_argument(\n        \"--queries\",\n        type=List[str],\n        default=queries,\n        help=\"Semgrex queries to search for\",\n    )\n    parser.add_argument(\n        \"--lang_code\",\n        type=str,\n        default=\"en\",\n        help=\"Two letter abbreviation the document's language e.g. 'en' for English\",\n    )\n    parser.add_argument(\n        \"--CLASSPATH\",\n        type=str,\n        default=\"C:\\\\stanford-corenlp-4.5.2\\\\stanford-corenlp-4.5.2\\\\*\",\n        help=\"Path to your coreNLP directory\",\n    )\n    args = parser.parse_args()\n    os.environ[\"CLASSPATH\"] = args.CLASSPATH\n    try:\n        res = visualize_search_doc(doc, queries, \"en\")\n        print(res[0])  # first sentence's HTML\n    except TypeError:\n        raise TypeError(\n            \"\"\"For the code in this module to run, you must have corenlp and Java installed on your machine. \n            Once installed, you can pass in the path to your corenlp directory as a command-line argument named \n            \"CLASSPATH\". Alternatively, set an environment variable CLASSPATH equal to the path of your corenlp \n            directory.\"\"\"\n        )\n    return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/ssurgeon_visualizer.py",
    "content": "\"\"\"\nVisualization tooling for Ssurgeon\n\"\"\"\nimport os\nimport sys\nimport stanza.utils.visualization.semgrex_visualizer as sv\nimport stanza.server.ssurgeon\nfrom stanza.server.ssurgeon import process_doc_one_operation, convert_response_to_doc\nfrom stanza.utils.conll import CoNLL\nfrom stanza.utils.visualization.constants import *\nimport logging\n\n\ndef generate_edited_deprel_unadjusted(edited_doc, lang_code, visualize_xpos):\n    \"\"\"\n    Generates HTML for the sentences of a doc edited by Ssurgeon, without adjusting the dependency arrows\n    :param edited_doc: Stanza document produced by the Ssurgeon edit\n    :param lang_code: language code of the document, e.g. \"en\" for English\n    :param visualize_xpos: if True, use xpos tags for part-of-speech labels instead of upos\n    :return: a list of HTML strings, one per sentence\n    \"\"\"\n    return sv.get_sentences_html(doc=edited_doc, language=lang_code, visualize_xpos=visualize_xpos)\n\n\ndef visualize_ssurgeon_deprel_adjusted_str_input(input_str, semgrex_query, ssurgeon_query, lang_code=\"en\", visualize_xpos=False, render=False):\n    \"\"\"\n    Visualizes the edited side of the ssurgeon edit, with the dependency arrows adjusted\n    :param input_str: CoNLL-formatted text of the document to edit\n    :param semgrex_query: Semgrex query selecting what the Ssurgeon edit applies to\n    :param ssurgeon_query: Ssurgeon edit operation(s) to apply\n    :param lang_code: language code of the document, e.g. \"en\" for English\n    :param visualize_xpos: if True, use xpos tags for part-of-speech labels instead of upos\n    :param render: if True, also render the resulting HTML strings\n    :return: a list of HTML strings, one per sentence of the edited doc\n    \"\"\"\n    doc = CoNLL.conll2doc(input_str=input_str)\n    ssurgeon_response = process_doc_one_operation(doc, semgrex_query, ssurgeon_query)\n    updated_doc = convert_response_to_doc(doc, ssurgeon_response)\n    html_strings = generate_edited_deprel_unadjusted(updated_doc, lang_code, visualize_xpos=visualize_xpos)\n    # realign the arrows of every sentence's visualization\n    edited_html_strings = [sv.adjust_dep_arrows(html_string) for html_string in html_strings]\n\n    if render:\n        sv.render_html_strings(edited_html_strings)\n\n    return edited_html_strings\n\n\ndef main():\n    # Remind the user to set CLASSPATH if it is not already set\n    if not os.environ.get('CLASSPATH'):\n        logging.info(\"Load the path to wherever CoreNLP is installed on your machine to $CLASSPATH.\")\n\n    # The default semgrex detects sentences in the UD_English-Pronouns dataset which have both nsubj and csubj on the same 
word.\n    # The default ssurgeon transforms the unwanted csubj to advcl\n    # See https://github.com/UniversalDependencies/docs/issues/923\n    ssurgeon = [\"relabelNamedEdge -edge bad -reln advcl\"]  # example\n    semgrex = \"{}=source >nsubj {} >csubj=bad {}\"  # example\n    SSURGEON_JAVA = \"edu.stanford.nlp.semgraph.semgrex.ssurgeon.ProcessSsurgeonRequest\"\n    doc = CoNLL.conll2doc(input_str=SAMPLE_SSURGEON_DOC)\n\n    print(\"{:C}\".format(doc))\n    ssurgeon_response = process_doc_one_operation(doc, semgrex, ssurgeon)\n    updated_doc = convert_response_to_doc(doc, ssurgeon_response)\n    print(\"{:C}\".format(updated_doc))\n    print(generate_edited_deprel_unadjusted(updated_doc, lang_code='en', visualize_xpos=False))\n    visualize_ssurgeon_deprel_adjusted_str_input(SAMPLE_SSURGEON_DOC, semgrex, ssurgeon)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "stanza/utils/visualization/utils.py",
    "content": "def find_nth(haystack, needle, n):\n    \"\"\"\n    Returns the starting index of the nth (1-indexed, non-overlapping) occurrence of the substring 'needle'\n    in the string 'haystack', or -1 if there are fewer than n occurrences.\n    \"\"\"\n    start = haystack.find(needle)\n    while start >= 0 and n > 1:\n        start = haystack.find(needle, start + len(needle))\n        n -= 1\n    return start\n\n\ndef round_base(num, base=10):\n    \"\"\"\n    Rounds a number to the nearest multiple of the base, e.g. round_base(49.2, base=50) == 50.\n    \"\"\"\n    return base * round(num / base)"
  }
]