Repository: xhlulu/dl-translate Branch: main Commit: d9f1fcd6d7d9 Files: 40 Total size: 180.1 KB Directory structure: gitextract_lc7n6m8w/ ├── .github/ │ └── workflows/ │ ├── generate-docs.yml │ ├── main.yaml │ └── python-publish.yml ├── .gitignore ├── .readthedocs.yaml ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README.md ├── demos/ │ ├── colab_demo.ipynb │ └── nllb_demo.ipynb ├── dl_translate/ │ ├── __init__.py │ ├── _pairs.py │ ├── _translation_model.py │ ├── lang/ │ │ ├── __init__.py │ │ ├── m2m100.py │ │ ├── mbart50.py │ │ └── nllb200.py │ └── utils.py ├── docs/ │ ├── available_languages.md │ ├── contributing.md │ ├── index.md │ ├── references.md │ └── requirements.txt ├── mkdocs-rtd.yml ├── mkdocs.yml ├── scripts/ │ ├── generate_langs.py │ ├── langs_coverage/ │ │ ├── m2m100.json │ │ ├── mbart50.json │ │ └── nllb200.json │ ├── render_available_langs.py │ ├── render_references.py │ └── templates/ │ ├── available_languages.md.jinja2 │ └── references.md.jinja2 ├── setup.py └── tests/ ├── long/ │ ├── test_save_load.py │ └── test_translate.py └── quick/ ├── test_lang.py ├── test_translation_model.py └── test_utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/generate-docs.yml ================================================ name: Publish Docs on: # Trigger the workflow on push or pull request, # but only for the main branch push: branches: - main jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: "3.7" - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r docs/requirements.txt - name: Publish Docs run: | mkdocs gh-deploy -t material --force ================================================ FILE: .github/workflows/main.yaml ================================================ name: Lint and run tests 
on: push: branches: - main pull_request: branches: - main jobs: build: runs-on: ubuntu-latest strategy: matrix: python-version: [3.8, 3.9] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - uses: actions/cache@v2 # source: https://medium.com/ai2-blog/e9452698e98d with: path: ${{ env.pythonLocation }} key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }} - name: Install requirements run: | python -m pip install --upgrade pip pip install --upgrade --upgrade-strategy eager -e .[dev] - name: Show setup install requires versions run: | pip show transformers torch tqdm protobuf - name: Lint code with black run: | black . --check - name: Run quick tests with pytest run: | pytest tests/quick - name: Run long tests with pytest if: ${{ matrix.python-version == '3.8' }} run: | pytest tests/long ================================================ FILE: .github/workflows/python-publish.yml ================================================ # This workflow will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries name: Publish Python Package on: release: types: [published] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip pip install setuptools wheel twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | python setup.py sdist bdist_wheel twine upload dist/* ================================================ FILE: .gitignore ================================================ # Custom .vscode # Byte-compiled / optimized / DLL files __pycache__/ 
*.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ ================================================ FILE: .readthedocs.yaml ================================================ # .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 mkdocs: configuration: mkdocs-rtd.yml # Optionally build your docs in additional formats such as PDF formats: - pdf # Optionally set the version of Python and requirements required to build your docs python: version: 3.7 ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "If you use this software, please cite it as below." 
authors: - family-names: "Lu" given-names: "Xing Han" orcid: "https://orcid.org/0000-0001-9027-8425" title: "DL Translate: A deep learning-based translation library built on Huggingface transformers" version: 0.3.0 doi: 10.5281/zenodo.5230676 date-released: 2021-08-21 url: "https://github.com/xhlulu/dl-translate" ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Xing Han Lu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: MANIFEST.in ================================================ include CITATION.cff ================================================ FILE: README.md ================================================ # DL Translate [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5230676.svg)](https://doi.org/10.5281/zenodo.5230676) [![Downloads](https://static.pepy.tech/personalized-badge/dl-translate?period=total&units=abbreviation&left_color=grey&right_color=orange&left_text=Downloads)](https://pepy.tech/project/dl-translate) [![License](https://img.shields.io/badge/license-MIT-green)](https://github.com/xhluca/dl-translate/blob/main/LICENSE) *A translation library for 200 languages built on Huggingface `transformers`* 💻 [GitHub Repository](https://github.com/xhluca/dl-translate)
📚 [Documentation](https://xhluca.github.io/dl-translate)
🐍 [PyPi project](https://pypi.org/project/dl-translate/)
🧪 [Colab Demo](https://colab.research.google.com/github/xhluca/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/) ## Quickstart Install the library with pip: ``` pip install dl-translate ``` To translate some text: ```python import dl_translate as dlt mt = dlt.TranslationModel() # Slow when you load it for the first time text_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है" mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH) ``` Above, you can see that `dlt.lang` contains variables representing each of the 200 available languages with auto-complete support. Alternatively, you can specify the language (e.g. "Arabic") or the language code (e.g. "fr" for French): ```python text_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا." mt.translate(text_ar, source="Arabic", target="fr") ``` If you want to verify whether a language is available, you can check it: ```python print(mt.available_languages()) # All languages that you can use print(mt.available_codes()) # Code corresponding to each language accepted print(mt.get_lang_code_map()) # Dictionary of lang -> code ``` ## Usage ### Selecting a device When you load the model, you can specify the device: ```python mt = dlt.TranslationModel(device="auto") ``` By default, the value will be `device="auto"`, which means it will use a GPU if possible. You can also explicitly set `device="cpu"` or `device="gpu"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). __In general, it is recommended to use a GPU if you want a reasonable processing time.__ ### Choosing a different model By default, the `m2m100` model will be used. However, there are a few options: * [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages. 
* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages. * [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (On RTX A6000, we can see speed up of 3x). Here's an example: ```python # The default approval mt = dlt.TranslationModel("m2m100") # Shorthand mt = dlt.TranslationModel("facebook/m2m100_418M") # Huggingface repo # If you want to use mBART-50 Large mt = dlt.TranslationModel("mbart50") mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt") # Or NLLB-200 (faster and has 200 languages) mt = dlt.TranslationModel("nllb200") mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M") ``` Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`. By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`: ```python mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50") mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100") mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200") ``` Notes: * Make sure your tokenizer is also stored in the same directory if you load from a file. * The available languages will change if you select a different model, so you will not be able to leverage `dlt.lang` or `dlt.utils`. ### Splitting into sentences It is not recommended to use extremely long texts as it takes more time to process. 
Instead, you can try to break them down into sentences with the help of `nltk`. First install the library with `pip install nltk`, then run: ```python import nltk nltk.download("punkt") text = "Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe." sents = nltk.tokenize.sent_tokenize(text, "english") # don't use dlt.lang.ENGLISH " ".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH)) ``` ### Batch size during translation It's possible to set a batch size (i.e. the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not: ```python # ... mt = dlt.TranslationModel() mt.translate(text, source, target, batch_size=32, verbose=True) ``` If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into "chunks". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA memory error. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized. ### `dlt.utils` module An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available: ```python print(dlt.utils.available_languages('mbart50')) # All languages that you can use print(dlt.utils.available_codes('m2m100')) # Code corresponding to each language accepted print(dlt.utils.get_lang_code_map('nllb200')) # Dictionary of lang -> code ``` ### Offline usage Unlike the Google translate or MSFT Translator APIs, this library can be fully used offline. However, you will need to first download the packages and models, and move them to your offline environment to be installed and loaded inside a venv. First, run in your terminal: ```bash mkdir dlt cd dlt mkdir libraries pip download -d libraries/ dl-translate ``` Once all the required packages are downloaded, you will need to use huggingface hub to download the files. Install it with `pip install huggingface-hub`. 
Then, run inside Python: ```python import shutil import huggingface_hub as hub dirname = hub.snapshot_download("facebook/m2m100_418M") shutil.copytree(dirname, "cached_model_m2m100") # Copy to a permanent folder ``` Now, move everything in the `dlt` directory to your offline environment. Create a virtual environment and run the following in terminal: ```bash pip install --no-index --find-links libraries/ dl-translate ``` Now, run inside Python: ```python import dl_translate as dlt mt = dlt.TranslationModel("cached_model_m2m100", model_family="m2m100") ``` ## Advanced If you have knowledge of PyTorch and Huggingface Transformers, you can access advanced aspects of the library for more customization: * **Saving and loading**: If you wish to accelerate the loading time the translation model, you can use `save_obj` and reload it later with `load_obj`. This method is only recommended if you are familiar with `huggingface` and `torch`; please read the docs for more information. * **Interacting with underlying model and tokenizer**: When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer with `model_options` and `tokenizer_options` respectively. You can also access the underlying `transformers` with `mt.get_transformers_model()`. * **Keyword arguments for the `generate()` method**: When running `mt.translate`, you can also give `generation_options` that is passed to the `generate()` method of the underlying transformer model. For more information, please visit the [advanced section of the user guide](https://xhluca.github.io/dl-translate/#advanced). ## Acknowledgement `dl-translate` is built on top of Huggingface's implementation of two models created by Facebook AI Research. 1. The multilingual BART finetuned on many-to-many translation of over 50 languages, which is [documented here](https://huggingface.co/transformers/master/model_doc/mbart.html) The original paper was written by Tang et. 
al from Facebook AI Research; you can [find it here](https://arxiv.org/pdf/2008.00401.pdf) and cite it using the following: ``` @article{tang2020multilingual, title={Multilingual translation with extensible multilingual pretraining and finetuning}, author={Tang, Yuqing and Tran, Chau and Li, Xian and Chen, Peng-Jen and Goyal, Naman and Chaudhary, Vishrav and Gu, Jiatao and Fan, Angela}, journal={arXiv preprint arXiv:2008.00401}, year={2020} } ``` 2. The transformer model published in [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Fan et. al, which supports over 100 languages. You can cite it here: ``` @misc{fan2020englishcentric, title={Beyond English-Centric Multilingual Machine Translation}, author={Angela Fan and Shruti Bhosale and Holger Schwenk and Zhiyi Ma and Ahmed El-Kishky and Siddharth Goyal and Mandeep Baines and Onur Celebi and Guillaume Wenzek and Vishrav Chaudhary and Naman Goyal and Tom Birch and Vitaliy Liptchinsky and Sergey Edunov and Edouard Grave and Michael Auli and Armand Joulin}, year={2020}, eprint={2010.11125}, archivePrefix={arXiv}, primaryClass={cs.CL} } ``` 3. The [no language left behind](https://arxiv.org/abs/2207.04672) model, which extends NMT to 200+ languages. You can cite it here: ``` @misc{nllbteam2022language, title={No Language Left Behind: Scaling Human-Centered Machine Translation}, author={NLLB Team and Marta R. 
Costa-jussà and James Cross and Onur Çelebi and Maha Elbayad and Kenneth Heafield and Kevin Heffernan and Elahe Kalbassi and Janice Lam and Daniel Licht and Jean Maillard and Anna Sun and Skyler Wang and Guillaume Wenzek and Al Youngblood and Bapi Akula and Loic Barrault and Gabriel Mejia Gonzalez and Prangthip Hansanti and John Hoffman and Semarley Jarrett and Kaushik Ram Sadagopan and Dirk Rowe and Shannon Spruit and Chau Tran and Pierre Andrews and Necip Fazil Ayan and Shruti Bhosale and Sergey Edunov and Angela Fan and Cynthia Gao and Vedanuj Goswami and Francisco Guzmán and Philipp Koehn and Alexandre Mourachko and Christophe Ropers and Safiyyah Saleem and Holger Schwenk and Jeff Wang}, year={2022}, eprint={2207.04672}, archivePrefix={arXiv}, primaryClass={cs.CL} } ``` `dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example: ```python from transformers import MBartForConditionalGeneration, MBart50TokenizerFast article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है" article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا." model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") # translate Hindi to French tokenizer.src_lang = "hi_IN" encoded_hi = tokenizer(article_hi, return_tensors="pt") generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"]) tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) # => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria." 
# translate Arabic to English tokenizer.src_lang = "ar_AR" encoded_ar = tokenizer(article_ar, return_tensors="pt") generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"]) tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) # => "The Secretary-General of the United Nations says there is no military solution in Syria." ``` With `dlt`, you can run: ```python import dl_translate as dlt article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है" article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا." mt = dlt.TranslationModel() translated_fr = mt.translate(article_hi, source=dlt.lang.HINDI, target=dlt.lang.FRENCH) translated_en = mt.translate(article_ar, source=dlt.lang.ARABIC, target=dlt.lang.ENGLISH) ``` Notice you don't have to think about tokenizers, condition generation, pretrained models, and regional codes; you can just tell the model what to translate! If you are experienced with `huggingface`'s ecosystem, then you should be familiar enough with the example above that you wouldn't need this library. 
However, if you've never heard of huggingface or mBART, then I hope using this library will give you enough motivation to [learn more about them](https://github.com/huggingface/transformers) :) ================================================ FILE: demos/colab_demo.ipynb ================================================ { "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "name": "python3", "display_name": "Python 3", "language": "python" }, "language_info": { "name": "python", "version": "3.7.9", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "colab": { "name": "dl-translate demo.ipynb", "provenance": [], "collapsed_sections": [] }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "9695e0e8562c4104b8e28a25bf05991e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_7e7d388cb3ea475098dcca168cba2635", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_7424952a9ad34bc9a0094ed6df2881ab", "IPY_MODEL_bd39d9cea7f949babdb9048b41f4055f" ] } }, "7e7d388cb3ea475098dcca168cba2635": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, 
"grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "7424952a9ad34bc9a0094ed6df2881ab": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_fee217673faf4424883be8bf33574584", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 3708092, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 3708092, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_2ad1b48139f14387bac9b6046b0d8d60" } }, "bd39d9cea7f949babdb9048b41f4055f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_79733c788f914b66a5803f1b71554c26", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 3.71M/3.71M [00:03<00:00, 1.22MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_cdc4f3fc6bf34edfa6ed14d1edc1c0da" } }, "fee217673faf4424883be8bf33574584": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": 
"ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "2ad1b48139f14387bac9b6046b0d8d60": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "79733c788f914b66a5803f1b71554c26": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "cdc4f3fc6bf34edfa6ed14d1edc1c0da": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": 
"@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "451f6d64f9d14cfe821ef82af4313223": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_d213c156b7f941de86f89704200898cc", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_d2ab70c5a6b24d3582e13bd7300e4b48", "IPY_MODEL_d596d132f81a4577ab6eeff01c858f4a" ] } }, "d213c156b7f941de86f89704200898cc": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, 
"grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "d2ab70c5a6b24d3582e13bd7300e4b48": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_e4ef14da423a4e37b23815228dff047c", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 2423393, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 2423393, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_f489df53fce4489c85248d0a0fc0347b" } }, "d596d132f81a4577ab6eeff01c858f4a": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_347117e2b4084588af0120de4cba9992", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 2.42M/2.42M [00:00<00:00, 5.19MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_f7dde0ddeb954feeb6334f91957b3b4f" } }, "e4ef14da423a4e37b23815228dff047c": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", 
"_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "f489df53fce4489c85248d0a0fc0347b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "347117e2b4084588af0120de4cba9992": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "f7dde0ddeb954feeb6334f91957b3b4f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, 
"_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "d85f7d324e534e7fb7665440581927af": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_36a4ac5796544fe5b50e802cd0b4b599", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_65782adbd1e943f39a3d282e4f217daf", "IPY_MODEL_d78f6159d9fe4b0e840b2d1912a253d8" ] } }, "36a4ac5796544fe5b50e802cd0b4b599": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": 
null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "65782adbd1e943f39a3d282e4f217daf": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_db2021d70d584ef987b1580962ad919c", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 272, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 272, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_2cc9fc8c66aa4381ab3a6e742d0547dc" } }, "d78f6159d9fe4b0e840b2d1912a253d8": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_256273b11d094044a36d0056498cc487", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 272/272 [00:01<00:00, 236B/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_361e88b5998d4df996113c7b6c55bd1e" } }, "db2021d70d584ef987b1580962ad919c": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": 
"ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "2cc9fc8c66aa4381ab3a6e742d0547dc": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "256273b11d094044a36d0056498cc487": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "361e88b5998d4df996113c7b6c55bd1e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": 
"@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "0236fdbddbcc4049b48713bb17ba4b74": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_878fc2e400354baf84a8a134d9038174", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_e2a4fd0970f94aeb8b841870d95662a8", "IPY_MODEL_bc263f4b7ac94cfeaa7609341d5394e2" ] } }, "878fc2e400354baf84a8a134d9038174": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, 
"grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "e2a4fd0970f94aeb8b841870d95662a8": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_266b8b8d1dda42afa1b5fdafb66af4ac", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 1140, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 1140, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_dc8604c9207b4a8ca364a817c6516378" } }, "bc263f4b7ac94cfeaa7609341d5394e2": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_f1958b44d65243a0b4ca5ba2803d3598", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 1.14k/1.14k [00:00<00:00, 2.38kB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_d26580979feb41cdb423d151bd5ebd15" } }, "266b8b8d1dda42afa1b5fdafb66af4ac": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", 
"_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "dc8604c9207b4a8ca364a817c6516378": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "f1958b44d65243a0b4ca5ba2803d3598": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "d26580979feb41cdb423d151bd5ebd15": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, 
"_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "e7a58c85b79041eaaeeb9a9c3fa102aa": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_57b95194b9c0404d84ee4a73208e10b0", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_477c9d28d9624f448b75b8c38207fed5", "IPY_MODEL_e2e7b8f3477c4fd486a0602377992efe" ] } }, "57b95194b9c0404d84ee4a73208e10b0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": 
null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "477c9d28d9624f448b75b8c38207fed5": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_72755dae20114e74a26fa5f556eed8c8", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 908, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 908, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_c961071096794f7580d0d37db02e8ff0" } }, "e2e7b8f3477c4fd486a0602377992efe": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_cf26e32a7c6b48caa2d7727365d33da3", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 908/908 [00:00<00:00, 25.8kB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_2edbb013778f4cb2a05510875c8b1d66" } }, "72755dae20114e74a26fa5f556eed8c8": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", 
"_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "c961071096794f7580d0d37db02e8ff0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "cf26e32a7c6b48caa2d7727365d33da3": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "2edbb013778f4cb2a05510875c8b1d66": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, 
"_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "b4bb0987873c46f9865e5308e2cd6cab": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_8fa51cc1231146d9be01452d4baedcab", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_5123c4c6e7524b4983ca0701394af3a0", "IPY_MODEL_4d97b210440f431bb5440d5b3b49032b" ] } }, "8fa51cc1231146d9be01452d4baedcab": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": 
null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "5123c4c6e7524b4983ca0701394af3a0": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_d5a94e8542ac48babbc228b80e8c5c91", "_dom_classes": [], "description": "Downloading: 63%", "_model_name": "FloatProgressModel", "bar_style": "", "max": 1935796948, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 1223096320, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_47f2a197eb0b47d583ec331a89a7657f" } }, "4d97b210440f431bb5440d5b3b49032b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_c1998ef598ff4359b0de7961665dba6d", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 1.22G/1.94G [00:22<00:13, 52.5MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_ba0a677b24604277981595701dfc4574" } }, "d5a94e8542ac48babbc228b80e8c5c91": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", 
"_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "47f2a197eb0b47d583ec331a89a7657f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "c1998ef598ff4359b0de7961665dba6d": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "ba0a677b24604277981595701dfc4574": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, 
"_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } } } } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "tx6xJha5YIiA" }, "source": [ "# DL Translate\n", "\n", "*A deep learning-based translation library built on Huggingface `transformers` and Facebook's `mBART-Large`*\n", "\n", "💻 [GitHub Repository](https://github.com/xhlulu/dl-translate)\\\n", "📚 [Documentation](https://git.io/dlt-docs) / [readthedocs](https://dl-translate.readthedocs.io)\\\n", "🐍 [PyPi project](https://pypi.org/project/dl-translate/)" ] }, { "cell_type": "markdown", "metadata": { "id": "bCjxVhyxYIiD" }, "source": [ "## Quickstart\n", "\n", "Install the library with pip:" ] }, { "cell_type": "code", "metadata": { "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "trusted": true, "_kg_hide-input": false, "_kg_hide-output": true, "id": "BI3mAoRnYIiF", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c5674616-a545-4a8a-9813-b127c1777efa" }, "source": [ "!pip install -q dl-translate" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "\u001b[K 
|████████████████████████████████| 1.2MB 9.2MB/s \n", "\u001b[K |████████████████████████████████| 2.2MB 29.2MB/s \n", "\u001b[K |████████████████████████████████| 870kB 50.4MB/s \n", "\u001b[K |████████████████████████████████| 3.3MB 51.1MB/s \n", "\u001b[?25h Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "p1oeb4czYIiG" }, "source": [ "To translate some text:" ] }, { "cell_type": "code", "metadata": { "trusted": true, "colab": { "base_uri": "https://localhost:8080/", "height": 300, "referenced_widgets": [ "9695e0e8562c4104b8e28a25bf05991e", "7e7d388cb3ea475098dcca168cba2635", "7424952a9ad34bc9a0094ed6df2881ab", "bd39d9cea7f949babdb9048b41f4055f", "fee217673faf4424883be8bf33574584", "2ad1b48139f14387bac9b6046b0d8d60", "79733c788f914b66a5803f1b71554c26", "cdc4f3fc6bf34edfa6ed14d1edc1c0da", "451f6d64f9d14cfe821ef82af4313223", "d213c156b7f941de86f89704200898cc", "d2ab70c5a6b24d3582e13bd7300e4b48", "d596d132f81a4577ab6eeff01c858f4a", "e4ef14da423a4e37b23815228dff047c", "f489df53fce4489c85248d0a0fc0347b", "347117e2b4084588af0120de4cba9992", "f7dde0ddeb954feeb6334f91957b3b4f", "d85f7d324e534e7fb7665440581927af", "36a4ac5796544fe5b50e802cd0b4b599", "65782adbd1e943f39a3d282e4f217daf", "d78f6159d9fe4b0e840b2d1912a253d8", "db2021d70d584ef987b1580962ad919c", "2cc9fc8c66aa4381ab3a6e742d0547dc", "256273b11d094044a36d0056498cc487", "361e88b5998d4df996113c7b6c55bd1e", "0236fdbddbcc4049b48713bb17ba4b74", "878fc2e400354baf84a8a134d9038174", "e2a4fd0970f94aeb8b841870d95662a8", "bc263f4b7ac94cfeaa7609341d5394e2", "266b8b8d1dda42afa1b5fdafb66af4ac", "dc8604c9207b4a8ca364a817c6516378", "f1958b44d65243a0b4ca5ba2803d3598", "d26580979feb41cdb423d151bd5ebd15", "e7a58c85b79041eaaeeb9a9c3fa102aa", "57b95194b9c0404d84ee4a73208e10b0", "477c9d28d9624f448b75b8c38207fed5", "e2e7b8f3477c4fd486a0602377992efe", "72755dae20114e74a26fa5f556eed8c8", "c961071096794f7580d0d37db02e8ff0", 
"cf26e32a7c6b48caa2d7727365d33da3", "2edbb013778f4cb2a05510875c8b1d66", "b4bb0987873c46f9865e5308e2cd6cab", "8fa51cc1231146d9be01452d4baedcab", "5123c4c6e7524b4983ca0701394af3a0", "4d97b210440f431bb5440d5b3b49032b", "d5a94e8542ac48babbc228b80e8c5c91", "47f2a197eb0b47d583ec331a89a7657f", "c1998ef598ff4359b0de7961665dba6d", "ba0a677b24604277981595701dfc4574" ] }, "id": "qdefSjR_YIiG", "outputId": "a1002eb7-cceb-45ee-dbb9-6a7329860af1" }, "source": [ "import dl_translate as dlt\n", "\n", "mt = dlt.TranslationModel()\n", "\n", "text_hi = \"संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है\"\n", "mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH)" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9695e0e8562c4104b8e28a25bf05991e", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3708092.0, style=ProgressStyle(descript…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "451f6d64f9d14cfe821ef82af4313223", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2423393.0, style=ProgressStyle(descript…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d85f7d324e534e7fb7665440581927af", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=272.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", 
"data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0236fdbddbcc4049b48713bb17ba4b74", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1140.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e7a58c85b79041eaaeeb9a9c3fa102aa", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=908.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b4bb0987873c46f9865e5308e2cd6cab", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1935796948.0, style=ProgressStyle(descr…" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "DDQGpznwYIiH" }, "source": [ "Above, you can see that `dlt.lang` contains variables representing each of the 50 available languages with auto-complete support. Alternatively, you can specify the language (e.g. \"Arabic\") or the language code (e.g. 
\"fr\" for French):" ] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "yC3LMjmNYIiI" }, "source": [ "text_ar = \"الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.\"\n", "mt.translate(text_ar, source=\"Arabic\", target=\"fr\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "erptbvbiYIiI" }, "source": [ "If you want to verify whether a language is available, you can check it:" ] }, { "cell_type": "code", "metadata": { "trusted": true, "_kg_hide-output": false, "id": "saHalYvsYIiJ" }, "source": [ "print(mt.available_languages()) # All languages that you can use\n", "print(mt.available_codes()) # Code corresponding to each language accepted\n", "print(mt.get_lang_code_map()) # Dictionary of lang -> code" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "trusted": true, "id": "lz3Rq5t0YIiJ" }, "source": [ "## Usage\n", "\n", "### Selecting a device\n", "\n", "When you load the model, you can specify the device:\n", "```python\n", "mt = dlt.TranslationModel(device=\"auto\")\n", "```\n", "\n", "By default, the value will be `device=\"auto\"`, which means it will use a GPU if possible. You can also explicitly set `device=\"cpu\"` or `device=\"gpu\"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). __In general, it is recommended to use a GPU if you want a reasonable processing time.__\n", "\n", "Let's check what we originally loaded:" ] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "fKAlEmzbYIiJ" }, "source": [ "mt.device" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Nm7hVlSZYIiL" }, "source": [ "### Loading from a path\n", "\n", "By default, `dlt.TranslationModel` will download the model from the [huggingface repo](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt) and cache it. 
However, you are free to load from a path:\n", "```python\n", "mt = dlt.TranslationModel(\"/path/to/your/model/directory/\", model_family=\"mbart50\")\n", "```\n", "Make sure that your tokenizer is also stored in the same directory if you use this approach.\n", "\n", "\n", "### Using a different model\n", "\n", "You can also choose another model that has [a similar format](https://huggingface.co/models?filter=mbart-50). In those cases, it's preferable to specify the model family:\n", "```python\n", "mt = dlt.TranslationModel(\"facebook/mbart-large-50-one-to-many-mmt\")\n", "mt = dlt.TranslationModel(\"facebook/m2m100_1.2B\", model_family=\"m2m100\")\n", "```\n", "Note that the available languages will change if you do this, so you will not be able to leverage `dlt.lang` or `dlt.utils`.\n", "\n", "\n", "### Breaking down into sentences\n", "\n", "It is not recommended to use extremely long texts as it takes more time to process. Instead, you can try to break them down into sentences with the help of `nltk`. First install the library with `pip install nltk`, then run:" ] }, { "cell_type": "code", "metadata": { "_kg_hide-output": true, "trusted": true, "id": "XkjrydidYIiL" }, "source": [ "import nltk\n", "nltk.download(\"punkt\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "j-cyjxQCYIiL" }, "source": [ "text = \"Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe.\"\n", "sents = nltk.tokenize.sent_tokenize(text, \"english\") # don't use dlt.lang.ENGLISH\n", "\" \".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "u6y8-SthYIiM" }, "source": [ "### Setting a `batch_size` and verbosity when calling `dlt.TranslationModel.translate`\n", "\n", "It's possible to set a batch size (i.e. 
the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not:" ] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "fcxUFmjAYIiM" }, "source": [ "mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH, batch_size=32, verbose=True)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "iS0bfSh_YIiM" }, "source": [ "If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into \"chunks\". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA memory error. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized.\n", "\n", "\n", "### `dlt.utils` module\n", "\n", "An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available:\n" ] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "U7iS_wKTYIiM" }, "source": [ "print(dlt.utils.available_languages('mbart50')) # All languages that you can use\n", "print(dlt.utils.available_codes('mbart50')) # Code corresponding to each language accepted\n", "print(dlt.utils.get_lang_code_map('mbart50')) # Dictionary of lang -> code" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "RFN5AplfYIiM" }, "source": [ "## Advanced\n", "\n", "The following section assumes you have knowledge of PyTorch and Huggingface Transformers.\n", "\n", "### Saving and loading\n", "\n", "If you wish to accelerate the loading time the translation model, you can use `save_obj`:\n" ] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "UaDQVFlzYIiN" }, "source": [ "mt.save_obj(\"saved_model\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "c_uMZ9exYIiN" }, "source": [ "\n", "Then later you can reload it with `load_obj`:" ] }, { "cell_type": "code", 
"metadata": { "trusted": true, "id": "SEkmXUDaYIiN" }, "source": [ "%%time\n", "mt = dlt.TranslationModel.load_obj('saved_model')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "j4G9tRL8YIiO" }, "source": [ "\n", "**Warning:** Only use this if you are certain the torch module saved in `saved_model/weights.pt` can be correctly loaded. Indeed, it is possible that the `huggingface`, `torch` or some other dependencies change between when you called `save_obj` and `load_obj`, and that might break your code. Thus, it is recommend to only run `load_obj` in the same environment/session as `save_obj`. **Note this method might be deprecated in the future once there's no speed benefit in loading this way.**\n", "\n", "\n", "### Interacting with underlying model and tokenizer\n", "\n", "When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer (which will respectively be passed to `MBartForConditionalGeneration.from_pretrained` and `MBart50TokenizerFast.from_pretrained`):\n", "\n", "```python\n", "mt = dlt.TranslationModel(\n", " model_options=dict(\n", " state_dict=...,\n", " cache_dir=...,\n", " ...\n", " ),\n", " tokenizer_options=dict(\n", " tokenizer_file=...,\n", " eos_token=...,\n", " ...\n", " )\n", ")\n", "```\n", "\n", "You can also access the underlying `transformers` model and `tokenizer`:" ] }, { "cell_type": "code", "metadata": { "_kg_hide-output": true, "trusted": true, "id": "UX1lyl_uYIiO" }, "source": [ "bart = mt.get_transformers_model()\n", "tokenizer = mt.get_tokenizer()\n", "\n", "print(tokenizer)\n", "print(bart)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "QQPqtEeuYIiO" }, "source": [ "See the [huggingface docs](https://huggingface.co/transformers/master/model_doc/mbart.html) for more information.\n", "\n", "\n", "### `bart_model.generate()` keyword arguments\n", "\n", "When running `mt.translate`, you can also give a 
`generation_options` dictionary that is passed as keyword arguments to the underlying `bart_model.generate()` method:" ] }, { "cell_type": "code", "metadata": { "trusted": true, "id": "XuTbvJBWYIiP" }, "source": [ "mt.translate(\n", " sents,\n", " source=dlt.lang.ENGLISH,\n", " target=dlt.lang.SPANISH,\n", " generation_options=dict(num_beams=5, max_length=128)\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "QvwPa2b1YIiP" }, "source": [ "Learn more in the [huggingface docs](https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate).\n", "\n", "\n", "## Acknowledgement\n", "\n", "`dl-translate` is built on top of Huggingface's implementation of two models created by Facebook AI Research.\n", "\n", "1. The multilingual BART finetuned on many-to-many translation of over 50 languages, which is [documented here](https://huggingface.co/transformers/master/model_doc/mbart.html) The original paper was written by Tang et. al from Facebook AI Research; you can [find it here](https://arxiv.org/pdf/2008.00401.pdf) and cite it using the following:\n", " ```\n", " @article{tang2020multilingual,\n", " title={Multilingual translation with extensible multilingual pretraining and finetuning},\n", " author={Tang, Yuqing and Tran, Chau and Li, Xian and Chen, Peng-Jen and Goyal, Naman and Chaudhary, Vishrav and Gu, Jiatao and Fan, Angela},\n", " journal={arXiv preprint arXiv:2008.00401},\n", " year={2020}\n", " }\n", " ```\n", "2. The transformer model published in [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Fan et. al, which supports over 100 languages. 
You can cite it here:\n", " ```\n", " @misc{fan2020englishcentric,\n", " title={Beyond English-Centric Multilingual Machine Translation}, \n", " author={Angela Fan and Shruti Bhosale and Holger Schwenk and Zhiyi Ma and Ahmed El-Kishky and Siddharth Goyal and Mandeep Baines and Onur Celebi and Guillaume Wenzek and Vishrav Chaudhary and Naman Goyal and Tom Birch and Vitaliy Liptchinsky and Sergey Edunov and Edouard Grave and Michael Auli and Armand Joulin},\n", " year={2020},\n", " eprint={2010.11125},\n", " archivePrefix={arXiv},\n", " primaryClass={cs.CL}\n", " }\n", " ```\n", "\n", "`dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example:\n", "```python\n", "from transformers import MBartForConditionalGeneration, MBart50TokenizerFast\n", "\n", "article_hi = \"संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है\"\n", "article_ar = \"الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.\"\n", "\n", "model = MBartForConditionalGeneration.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n", "tokenizer = MBart50TokenizerFast.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n", "\n", "# translate Hindi to French\n", "tokenizer.src_lang = \"hi_IN\"\n", "encoded_hi = tokenizer(article_hi, return_tensors=\"pt\")\n", "generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id[\"fr_XX\"])\n", "tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n", "# => \"Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria.\"\n", "\n", "# translate Arabic to English\n", "tokenizer.src_lang = \"ar_AR\"\n", "encoded_ar = tokenizer(article_ar, return_tensors=\"pt\")\n", "generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id[\"en_XX\"])\n", "tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n", "# => \"The Secretary-General of the 
United Nations says there is no military solution in Syria.\"\n", "```\n", "\n", "With `dlt`, you can run:\n", "```python\n", "import dl_translate as dlt\n", "\n", "article_hi = \"संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है\"\n", "article_ar = \"الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.\"\n", "\n", "mt = dlt.TranslationModel()\n", "translated_fr = mt.translate(article_hi, source=dlt.lang.HINDI, target=dlt.lang.FRENCH)\n", "translated_en = mt.translate(article_ar, source=dlt.lang.ARABIC, target=dlt.lang.ENGLISH)\n", "```\n", "\n", "Notice you don't have to think about tokenizers, condition generation, pretrained models, and regional codes; you can just tell the model what to translate!\n", "\n", "If you are experienced with `huggingface`'s ecosystem, then you should be familiar enough with the example above that you wouldn't need this library. However, if you've never heard of huggingface or mBART, then I hope using this library will give you enough motivation to [learn more about them](https://github.com/huggingface/transformers) :)" ] } ] } ================================================ FILE: demos/nllb_demo.ipynb ================================================ {"cells":[{"cell_type":"code","execution_count":1,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2023-07-18T05:15:13.999614Z","iopub.status.busy":"2023-07-18T05:15:13.999228Z","iopub.status.idle":"2023-07-18T05:15:31.978108Z","shell.execute_reply":"2023-07-18T05:15:31.976681Z","shell.execute_reply.started":"2023-07-18T05:15:13.999573Z"},"trusted":true},"outputs":[],"source":["!pip install dl-translate==0.3.* 
-q"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:15:31.982361Z","iopub.status.busy":"2023-07-18T05:15:31.981992Z","iopub.status.idle":"2023-07-18T05:16:23.731908Z","shell.execute_reply":"2023-07-18T05:16:23.730776Z","shell.execute_reply.started":"2023-07-18T05:15:31.982327Z"},"trusted":true},"outputs":[],"source":["import dl_translate as dlt\n","\n","mt = dlt.TranslationModel(\"nllb200\")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:23.734336Z","iopub.status.busy":"2023-07-18T05:16:23.733295Z","iopub.status.idle":"2023-07-18T05:16:28.025038Z","shell.execute_reply":"2023-07-18T05:16:28.023933Z","shell.execute_reply.started":"2023-07-18T05:16:23.734293Z"},"trusted":true},"outputs":[],"source":["text = \"Meta AI has built a single AI model capable of translating across 200 different languages with state-of-the-art quality.\"\n","\n","# The new translation is much faster than before\n","%time print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.FRENCH))"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.028919Z","iopub.status.busy":"2023-07-18T05:16:28.028286Z","iopub.status.idle":"2023-07-18T05:16:28.717521Z","shell.execute_reply":"2023-07-18T05:16:28.716343Z","shell.execute_reply.started":"2023-07-18T05:16:28.028882Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["मेटाएआई एकमेव एआई मॉडलं निर्मितवान्, यत् 200 भिन्नभाषायां अवधीतमतमतमगुणैः अनुवादं कर्तुं समर्थः अस्ति।\n"]}],"source":["# Sanskrit is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, 
target=dlt.lang.nllb200.SANSKRIT))"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.719596Z","iopub.status.busy":"2023-07-18T05:16:28.719227Z","iopub.status.idle":"2023-07-18T05:16:29.443696Z","shell.execute_reply":"2023-07-18T05:16:29.442668Z","shell.execute_reply.started":"2023-07-18T05:16:28.719560Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Meta AI hà custruitu un solu mudellu d'AI capace di tradurisce in 200 lingue sfarenti cù qualità di u statu di l'arte.\n"]}],"source":["# Sicilian is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.SICILIAN))"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:29.447147Z","iopub.status.busy":"2023-07-18T05:16:29.445331Z","iopub.status.idle":"2023-07-18T05:16:30.145637Z","shell.execute_reply":"2023-07-18T05:16:30.144623Z","shell.execute_reply.started":"2023-07-18T05:16:29.447108Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["基於Meta AI 建立咗一個 AI 模型 可以用最先端嘅質量翻譯到 200 個唔同語言\n"]}],"source":["# Yue Chinese is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.YUE_CHINESE))"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4} ================================================ FILE: dl_translate/__init__.py ================================================ from . import lang from . 
import utils from ._translation_model import TranslationModel ================================================ FILE: dl_translate/_pairs.py ================================================ # Auto-generated. Do not modify, use scripts/generate_langs.py instead. _PAIRS_M2M100 = ( ("Afrikaans", "af"), ("Amharic", "am"), ("Arabic", "ar"), ("Asturian", "ast"), ("Azerbaijani", "az"), ("Bashkir", "ba"), ("Belarusian", "be"), ("Bulgarian", "bg"), ("Bengali", "bn"), ("Breton", "br"), ("Bosnian", "bs"), ("Catalan", "ca"), ("Valencian", "ca"), ("Cebuano", "ceb"), ("Czech", "cs"), ("Welsh", "cy"), ("Danish", "da"), ("German", "de"), ("Greek", "el"), ("English", "en"), ("Spanish", "es"), ("Estonian", "et"), ("Persian", "fa"), ("Fulah", "ff"), ("Finnish", "fi"), ("French", "fr"), ("Western Frisian", "fy"), ("Irish", "ga"), ("Gaelic", "gd"), ("Scottish Gaelic", "gd"), ("Galician", "gl"), ("Gujarati", "gu"), ("Hausa", "ha"), ("Hebrew", "he"), ("Hindi", "hi"), ("Croatian", "hr"), ("Haitian", "ht"), ("Haitian Creole", "ht"), ("Hungarian", "hu"), ("Armenian", "hy"), ("Indonesian", "id"), ("Igbo", "ig"), ("Iloko", "ilo"), ("Icelandic", "is"), ("Italian", "it"), ("Japanese", "ja"), ("Javanese", "jv"), ("Georgian", "ka"), ("Kazakh", "kk"), ("Khmer", "km"), ("Central Khmer", "km"), ("Kannada", "kn"), ("Korean", "ko"), ("Luxembourgish", "lb"), ("Letzeburgesch", "lb"), ("Ganda", "lg"), ("Lingala", "ln"), ("Lao", "lo"), ("Lithuanian", "lt"), ("Latvian", "lv"), ("Malagasy", "mg"), ("Macedonian", "mk"), ("Malayalam", "ml"), ("Mongolian", "mn"), ("Marathi", "mr"), ("Malay", "ms"), ("Burmese", "my"), ("Nepali", "ne"), ("Dutch", "nl"), ("Flemish", "nl"), ("Norwegian", "no"), ("Northern Sotho", "ns"), ("Occitan", "oc"), ("Oriya", "or"), ("Panjabi", "pa"), ("Punjabi", "pa"), ("Polish", "pl"), ("Pushto", "ps"), ("Pashto", "ps"), ("Portuguese", "pt"), ("Romanian", "ro"), ("Moldavian", "ro"), ("Moldovan", "ro"), ("Russian", "ru"), ("Sindhi", "sd"), ("Sinhala", "si"), ("Sinhalese", "si"), ("Slovak", 
"sk"), ("Slovenian", "sl"), ("Somali", "so"), ("Albanian", "sq"), ("Serbian", "sr"), ("Swati", "ss"), ("Sundanese", "su"), ("Swedish", "sv"), ("Swahili", "sw"), ("Tamil", "ta"), ("Thai", "th"), ("Tagalog", "tl"), ("Tswana", "tn"), ("Turkish", "tr"), ("Ukrainian", "uk"), ("Urdu", "ur"), ("Uzbek", "uz"), ("Vietnamese", "vi"), ("Wolof", "wo"), ("Xhosa", "xh"), ("Yiddish", "yi"), ("Yoruba", "yo"), ("Chinese", "zh"), ("Zulu", "zu"), ) _PAIRS_MBART50 = ( ("Arabic", "ar_AR"), ("Czech", "cs_CZ"), ("German", "de_DE"), ("English", "en_XX"), ("Spanish", "es_XX"), ("Estonian", "et_EE"), ("Finnish", "fi_FI"), ("French", "fr_XX"), ("Gujarati", "gu_IN"), ("Hindi", "hi_IN"), ("Italian", "it_IT"), ("Japanese", "ja_XX"), ("Kazakh", "kk_KZ"), ("Korean", "ko_KR"), ("Lithuanian", "lt_LT"), ("Latvian", "lv_LV"), ("Burmese", "my_MM"), ("Nepali", "ne_NP"), ("Dutch", "nl_XX"), ("Romanian", "ro_RO"), ("Russian", "ru_RU"), ("Sinhala", "si_LK"), ("Turkish", "tr_TR"), ("Vietnamese", "vi_VN"), ("Chinese", "zh_CN"), ("Afrikaans", "af_ZA"), ("Azerbaijani", "az_AZ"), ("Bengali", "bn_IN"), ("Persian", "fa_IR"), ("Hebrew", "he_IL"), ("Croatian", "hr_HR"), ("Indonesian", "id_ID"), ("Georgian", "ka_GE"), ("Khmer", "km_KH"), ("Macedonian", "mk_MK"), ("Malayalam", "ml_IN"), ("Mongolian", "mn_MN"), ("Marathi", "mr_IN"), ("Polish", "pl_PL"), ("Pashto", "ps_AF"), ("Portuguese", "pt_XX"), ("Swedish", "sv_SE"), ("Swahili", "sw_KE"), ("Tamil", "ta_IN"), ("Telugu", "te_IN"), ("Thai", "th_TH"), ("Tagalog", "tl_XX"), ("Ukrainian", "uk_UA"), ("Urdu", "ur_PK"), ("Xhosa", "xh_ZA"), ("Galician", "gl_ES"), ("Slovene", "sl_SI"), ) _PAIRS_NLLB200 = ( ("Acehnese (Arabic script)", "ace_Arab"), ("Acehnese (Latin script)", "ace_Latn"), ("Mesopotamian Arabic", "acm_Arab"), ("Ta'izzi-Adeni Arabic", "acq_Arab"), ("Tunisian Arabic", "aeb_Arab"), ("Afrikaans", "afr_Latn"), ("South Levantine Arabic", "ajp_Arab"), ("Akan", "aka_Latn"), ("Amharic", "amh_Ethi"), ("North Levantine Arabic", "apc_Arab"), ("Modern Standard Arabic", 
"arb_Arab"), ("Modern Standard Arabic (Romanized)", "arb_Latn"), ("Najdi Arabic", "ars_Arab"), ("Moroccan Arabic", "ary_Arab"), ("Egyptian Arabic", "arz_Arab"), ("Assamese", "asm_Beng"), ("Asturian", "ast_Latn"), ("Awadhi", "awa_Deva"), ("Central Aymara", "ayr_Latn"), ("South Azerbaijani", "azb_Arab"), ("North Azerbaijani", "azj_Latn"), ("Bashkir", "bak_Cyrl"), ("Bambara", "bam_Latn"), ("Balinese", "ban_Latn"), ("Belarusian", "bel_Cyrl"), ("Bemba", "bem_Latn"), ("Bengali", "ben_Beng"), ("Bhojpuri", "bho_Deva"), ("Banjar (Arabic script)", "bjn_Arab"), ("Banjar (Latin script)", "bjn_Latn"), ("Standard Tibetan", "bod_Tibt"), ("Bosnian", "bos_Latn"), ("Buginese", "bug_Latn"), ("Bulgarian", "bul_Cyrl"), ("Catalan", "cat_Latn"), ("Cebuano", "ceb_Latn"), ("Czech", "ces_Latn"), ("Chokwe", "cjk_Latn"), ("Central Kurdish", "ckb_Arab"), ("Crimean Tatar", "crh_Latn"), ("Welsh", "cym_Latn"), ("Danish", "dan_Latn"), ("German", "deu_Latn"), ("Southwestern Dinka", "dik_Latn"), ("Dyula", "dyu_Latn"), ("Dzongkha", "dzo_Tibt"), ("Greek", "ell_Grek"), ("English", "eng_Latn"), ("Esperanto", "epo_Latn"), ("Estonian", "est_Latn"), ("Basque", "eus_Latn"), ("Ewe", "ewe_Latn"), ("Faroese", "fao_Latn"), ("Fijian", "fij_Latn"), ("Finnish", "fin_Latn"), ("Fon", "fon_Latn"), ("French", "fra_Latn"), ("Friulian", "fur_Latn"), ("Nigerian Fulfulde", "fuv_Latn"), ("Scottish Gaelic", "gla_Latn"), ("Irish", "gle_Latn"), ("Galician", "glg_Latn"), ("Guarani", "grn_Latn"), ("Gujarati", "guj_Gujr"), ("Haitian Creole", "hat_Latn"), ("Hausa", "hau_Latn"), ("Hebrew", "heb_Hebr"), ("Hindi", "hin_Deva"), ("Chhattisgarhi", "hne_Deva"), ("Croatian", "hrv_Latn"), ("Hungarian", "hun_Latn"), ("Armenian", "hye_Armn"), ("Igbo", "ibo_Latn"), ("Ilocano", "ilo_Latn"), ("Indonesian", "ind_Latn"), ("Icelandic", "isl_Latn"), ("Italian", "ita_Latn"), ("Javanese", "jav_Latn"), ("Japanese", "jpn_Jpan"), ("Kabyle", "kab_Latn"), ("Jingpho", "kac_Latn"), ("Kamba", "kam_Latn"), ("Kannada", "kan_Knda"), ("Kashmiri (Arabic 
script)", "kas_Arab"), ("Kashmiri (Devanagari script)", "kas_Deva"), ("Georgian", "kat_Geor"), ("Central Kanuri (Arabic script)", "knc_Arab"), ("Central Kanuri (Latin script)", "knc_Latn"), ("Kazakh", "kaz_Cyrl"), ("Kabiyè", "kbp_Latn"), ("Kabuverdianu", "kea_Latn"), ("Khmer", "khm_Khmr"), ("Kikuyu", "kik_Latn"), ("Kinyarwanda", "kin_Latn"), ("Kyrgyz", "kir_Cyrl"), ("Kimbundu", "kmb_Latn"), ("Northern Kurdish", "kmr_Latn"), ("Kikongo", "kon_Latn"), ("Korean", "kor_Hang"), ("Lao", "lao_Laoo"), ("Ligurian", "lij_Latn"), ("Limburgish", "lim_Latn"), ("Lingala", "lin_Latn"), ("Lithuanian", "lit_Latn"), ("Lombard", "lmo_Latn"), ("Latgalian", "ltg_Latn"), ("Luxembourgish", "ltz_Latn"), ("Luba-Kasai", "lua_Latn"), ("Ganda", "lug_Latn"), ("Luo", "luo_Latn"), ("Mizo", "lus_Latn"), ("Standard Latvian", "lvs_Latn"), ("Magahi", "mag_Deva"), ("Maithili", "mai_Deva"), ("Malayalam", "mal_Mlym"), ("Marathi", "mar_Deva"), ("Minangkabau (Arabic script)", "min_Arab"), ("Minangkabau (Latin script)", "min_Latn"), ("Macedonian", "mkd_Cyrl"), ("Plateau Malagasy", "plt_Latn"), ("Maltese", "mlt_Latn"), ("Meitei (Bengali script)", "mni_Beng"), ("Halh Mongolian", "khk_Cyrl"), ("Mossi", "mos_Latn"), ("Maori", "mri_Latn"), ("Burmese", "mya_Mymr"), ("Dutch", "nld_Latn"), ("Norwegian Nynorsk", "nno_Latn"), ("Norwegian Bokmål", "nob_Latn"), ("Nepali", "npi_Deva"), ("Northern Sotho", "nso_Latn"), ("Nuer", "nus_Latn"), ("Nyanja", "nya_Latn"), ("Occitan", "oci_Latn"), ("West Central Oromo", "gaz_Latn"), ("Odia", "ory_Orya"), ("Pangasinan", "pag_Latn"), ("Eastern Panjabi", "pan_Guru"), ("Papiamento", "pap_Latn"), ("Western Persian", "pes_Arab"), ("Polish", "pol_Latn"), ("Portuguese", "por_Latn"), ("Dari", "prs_Arab"), ("Southern Pashto", "pbt_Arab"), ("Ayacucho Quechua", "quy_Latn"), ("Romanian", "ron_Latn"), ("Rundi", "run_Latn"), ("Russian", "rus_Cyrl"), ("Sango", "sag_Latn"), ("Sanskrit", "san_Deva"), ("Santali", "sat_Olck"), ("Sicilian", "scn_Latn"), ("Shan", "shn_Mymr"), ("Sinhala", "sin_Sinh"), 
("Slovak", "slk_Latn"), ("Slovenian", "slv_Latn"), ("Samoan", "smo_Latn"), ("Shona", "sna_Latn"), ("Sindhi", "snd_Arab"), ("Somali", "som_Latn"), ("Southern Sotho", "sot_Latn"), ("Spanish", "spa_Latn"), ("Tosk Albanian", "als_Latn"), ("Sardinian", "srd_Latn"), ("Serbian", "srp_Cyrl"), ("Swati", "ssw_Latn"), ("Sundanese", "sun_Latn"), ("Swedish", "swe_Latn"), ("Swahili", "swh_Latn"), ("Silesian", "szl_Latn"), ("Tamil", "tam_Taml"), ("Tatar", "tat_Cyrl"), ("Telugu", "tel_Telu"), ("Tajik", "tgk_Cyrl"), ("Tagalog", "tgl_Latn"), ("Thai", "tha_Thai"), ("Tigrinya", "tir_Ethi"), ("Tamasheq (Latin script)", "taq_Latn"), ("Tamasheq (Tifinagh script)", "taq_Tfng"), ("Tok Pisin", "tpi_Latn"), ("Tswana", "tsn_Latn"), ("Tsonga", "tso_Latn"), ("Turkmen", "tuk_Latn"), ("Tumbuka", "tum_Latn"), ("Turkish", "tur_Latn"), ("Twi", "twi_Latn"), ("Central Atlas Tamazight", "tzm_Tfng"), ("Uyghur", "uig_Arab"), ("Ukrainian", "ukr_Cyrl"), ("Umbundu", "umb_Latn"), ("Urdu", "urd_Arab"), ("Northern Uzbek", "uzn_Latn"), ("Venetian", "vec_Latn"), ("Vietnamese", "vie_Latn"), ("Waray", "war_Latn"), ("Wolof", "wol_Latn"), ("Xhosa", "xho_Latn"), ("Eastern Yiddish", "ydd_Hebr"), ("Yoruba", "yor_Latn"), ("Yue Chinese", "yue_Hant"), ("Chinese (Simplified)", "zho_Hans"), ("Chinese (Traditional)", "zho_Hant"), ("Standard Malay", "zsm_Latn"), ("Zulu", "zul_Latn"), ) ================================================ FILE: dl_translate/_translation_model.py ================================================ import os import json from typing import Union, List, Dict import transformers import torch from tqdm.auto import tqdm from . 
import utils from .utils import _infer_model_family, _infer_model_or_path def _select_device(device_selection): selected = device_selection.lower() if selected == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") elif selected == "cpu": device = torch.device("cpu") elif selected == "gpu": device = torch.device("cuda") else: device = torch.device(selected) return device def _resolve_lang_codes(lang: str, name: str, model_family: str): def error_message(variable, value): return f'Your {variable}="{value}" is not valid. Please run `print(mt.available_languages())` to see which languages are available. Make sure you are using the correct capital letters.' # If can't find in the lang -> code mapping, assumes it's already a code. lang_code_map = utils.get_lang_code_map(model_family) if lang in lang_code_map: code = lang_code_map[lang] elif lang.capitalize() in lang_code_map: code = lang_code_map[lang.capitalize()] else: lang_upper = lang.upper() lang_code_map_upper = {k.upper(): v for k, v in lang_code_map.items()} if lang_upper in lang_code_map_upper: code = lang_code_map_upper[lang_upper] else: code = lang # If the code is not valid, raises an error if code not in utils.available_codes(model_family): raise ValueError(error_message(name, code)) return code def _resolve_tokenizer(model_family): di = { "mbart50": transformers.MBart50TokenizerFast, "m2m100": transformers.M2M100Tokenizer, "nllb200": transformers.AutoTokenizer, } if model_family in di: return di[model_family] else: error_msg = f"{model_family} is not a valid value for model_family. 
def _resolve_transformers_model(model_family):
    """Return the transformers conditional-generation model class for a
    supported model family.

    Raises:
        ValueError: if `model_family` is not one of the supported families.
    """
    model_classes = {
        "mbart50": transformers.MBartForConditionalGeneration,
        "m2m100": transformers.M2M100ForConditionalGeneration,
        "nllb200": transformers.AutoModelForSeq2SeqLM,
    }
    if model_family not in model_classes:
        raise ValueError(
            f"{model_family} is not a valid value for model_family. Please "
            "choose model_family to be equal to one of the following values: "
            f"{list(model_classes.keys())}"
        )
    return model_classes[model_family]
""" model_or_path = _infer_model_or_path(model_or_path) self.model_or_path = model_or_path self.device = _select_device(device) # Resolve default values tokenizer_path = tokenizer_path or self.model_or_path model_options = model_options or {} tokenizer_options = tokenizer_options or {} self.model_family = model_family or _infer_model_family(model_or_path) # Load the tokenizer TokenizerFast = _resolve_tokenizer(self.model_family) self._tokenizer = TokenizerFast.from_pretrained( tokenizer_path, **tokenizer_options ) # Load the model either from a saved torch model or from transformers. if model_or_path.endswith(".pt"): self._transformers_model = torch.load( model_or_path, map_location=self.device ).eval() else: ModelForConditionalGeneration = _resolve_transformers_model( self.model_family ) self._transformers_model = ( ModelForConditionalGeneration.from_pretrained( self.model_or_path, **model_options ) .to(self.device) .eval() ) def translate( self, text: Union[str, List[str]], source: str, target: str, batch_size: int = 32, verbose: bool = False, generation_options: dict = None, ) -> Union[str, List[str]]: """ *Translates a string or a list of strings from a source to a target language.* {{params}} {{text}} The content you want to translate. {{source}} The language of the original text. {{target}} The language of the translated text. {{batch_size}} The number of samples to load at once. If set to `None`, it will process everything at once. {{verbose}} Whether to display the progress bar for every batch processed. {{generation_options}} The keyword arguments passed to `model.generate()`, where `model` is the underlying transformers model. Note: - Run `print(dlt.utils.available_languages())` to see what's available. - A smaller value is preferred for `batch_size` if your (video) RAM is limited. 
""" if generation_options is None: generation_options = {} source = _resolve_lang_codes(source, "source", self.model_family) target = _resolve_lang_codes(target, "target", self.model_family) self._tokenizer.src_lang = source original_text_type = type(text) if original_text_type is str: text = [text] if batch_size is None: batch_size = len(text) generation_options.setdefault( "forced_bos_token_id", self._tokenizer.convert_tokens_to_ids(target) ) generation_options.setdefault("max_new_tokens", 512) data_loader = torch.utils.data.DataLoader(text, batch_size=batch_size) output_text = [] tqdm_iterator = data_loader if verbose is True: tqdm_iterator = tqdm(data_loader) with torch.no_grad(): for batch in tqdm_iterator: encoded = self._tokenizer(batch, return_tensors="pt", padding=True) encoded.to(self.device) generated_tokens = self._transformers_model.generate( **encoded, **generation_options ).cpu() decoded = self._tokenizer.batch_decode( generated_tokens, skip_special_tokens=True ) output_text.extend(decoded) # If text: str and output_text: List[str], then we should convert output_text to str if original_text_type is str and len(output_text) == 1: output_text = output_text[0] return output_text def get_transformers_model(self): """ *Retrieve the underlying mBART transformer model.* """ return self._transformers_model def get_tokenizer(self): """ *Retrieve the mBART huggingface tokenizer.* """ return self._tokenizer def available_languages(self) -> List[str]: """ *Returns all the available languages for a given `dlt.TranslationModel` instance.* """ return utils.available_languages(self.model_family) def available_codes(self) -> List[str]: """ *Returns all the available codes for a given `dlt.TranslationModel` instance.* """ return utils.available_codes(self.model_family) def get_lang_code_map(self) -> Dict[str, str]: """ *Returns the language -> codes dictionary for a given `dlt.TranslationModel` instance.* """ return utils.get_lang_code_map(self.model_family) def 
save_obj(self, path: str = "saved_model") -> None: """ *Saves your model as a torch object and save your tokenizer.* {{params}} {{path}} The directory where you want to save your model and tokenizer """ os.makedirs(path, exist_ok=True) torch.save(self._transformers_model, os.path.join(path, "weights.pt")) self._tokenizer.save_pretrained(path) dlt_config = dict(model_family=self.model_family) json.dump(dlt_config, open(os.path.join(path, "dlt_config.json"), "w")) @classmethod def load_obj(cls, path: str = "saved_model", **kwargs): """ *Initialize `dlt.TranslationModel` from the torch object and tokenizer saved with `dlt.TranslationModel.save_obj`* {{params}} {{path}} The directory where your torch model and tokenizer are stored """ config_prev = json.load(open(os.path.join(path, "dlt_config.json"), "rb")) config_prev.update(kwargs) return cls( model_or_path=os.path.join(path, "weights.pt"), tokenizer_path=path, **config_prev, ) ================================================ FILE: dl_translate/lang/__init__.py ================================================ from .m2m100 import * from . import m2m100, mbart50, nllb200 ================================================ FILE: dl_translate/lang/m2m100.py ================================================ # Auto-generated. Do not modify, use scripts/generate_langs.py instead. 
AFRIKAANS = "Afrikaans" AMHARIC = "Amharic" ARABIC = "Arabic" ASTURIAN = "Asturian" AZERBAIJANI = "Azerbaijani" BASHKIR = "Bashkir" BELARUSIAN = "Belarusian" BULGARIAN = "Bulgarian" BENGALI = "Bengali" BRETON = "Breton" BOSNIAN = "Bosnian" CATALAN = "Catalan" VALENCIAN = "Valencian" CEBUANO = "Cebuano" CZECH = "Czech" WELSH = "Welsh" DANISH = "Danish" GERMAN = "German" GREEK = "Greek" ENGLISH = "English" SPANISH = "Spanish" ESTONIAN = "Estonian" PERSIAN = "Persian" FULAH = "Fulah" FINNISH = "Finnish" FRENCH = "French" WESTERN_FRISIAN = "Western Frisian" IRISH = "Irish" GAELIC = "Gaelic" SCOTTISH_GAELIC = "Scottish Gaelic" GALICIAN = "Galician" GUJARATI = "Gujarati" HAUSA = "Hausa" HEBREW = "Hebrew" HINDI = "Hindi" CROATIAN = "Croatian" HAITIAN = "Haitian" HAITIAN_CREOLE = "Haitian Creole" HUNGARIAN = "Hungarian" ARMENIAN = "Armenian" INDONESIAN = "Indonesian" IGBO = "Igbo" ILOKO = "Iloko" ICELANDIC = "Icelandic" ITALIAN = "Italian" JAPANESE = "Japanese" JAVANESE = "Javanese" GEORGIAN = "Georgian" KAZAKH = "Kazakh" KHMER = "Khmer" CENTRAL_KHMER = "Central Khmer" KANNADA = "Kannada" KOREAN = "Korean" LUXEMBOURGISH = "Luxembourgish" LETZEBURGESCH = "Letzeburgesch" GANDA = "Ganda" LINGALA = "Lingala" LAO = "Lao" LITHUANIAN = "Lithuanian" LATVIAN = "Latvian" MALAGASY = "Malagasy" MACEDONIAN = "Macedonian" MALAYALAM = "Malayalam" MONGOLIAN = "Mongolian" MARATHI = "Marathi" MALAY = "Malay" BURMESE = "Burmese" NEPALI = "Nepali" DUTCH = "Dutch" FLEMISH = "Flemish" NORWEGIAN = "Norwegian" NORTHERN_SOTHO = "Northern Sotho" OCCITAN = "Occitan" ORIYA = "Oriya" PANJABI = "Panjabi" PUNJABI = "Punjabi" POLISH = "Polish" PUSHTO = "Pushto" PASHTO = "Pashto" PORTUGUESE = "Portuguese" ROMANIAN = "Romanian" MOLDAVIAN = "Moldavian" MOLDOVAN = "Moldovan" RUSSIAN = "Russian" SINDHI = "Sindhi" SINHALA = "Sinhala" SINHALESE = "Sinhalese" SLOVAK = "Slovak" SLOVENIAN = "Slovenian" SOMALI = "Somali" ALBANIAN = "Albanian" SERBIAN = "Serbian" SWATI = "Swati" SUNDANESE = "Sundanese" SWEDISH = 
"Swedish" SWAHILI = "Swahili" TAMIL = "Tamil" THAI = "Thai" TAGALOG = "Tagalog" TSWANA = "Tswana" TURKISH = "Turkish" UKRAINIAN = "Ukrainian" URDU = "Urdu" UZBEK = "Uzbek" VIETNAMESE = "Vietnamese" WOLOF = "Wolof" XHOSA = "Xhosa" YIDDISH = "Yiddish" YORUBA = "Yoruba" CHINESE = "Chinese" ZULU = "Zulu" ================================================ FILE: dl_translate/lang/mbart50.py ================================================ # Auto-generated. Do not modify, use scripts/generate_langs.py instead. ARABIC = "Arabic" CZECH = "Czech" GERMAN = "German" ENGLISH = "English" SPANISH = "Spanish" ESTONIAN = "Estonian" FINNISH = "Finnish" FRENCH = "French" GUJARATI = "Gujarati" HINDI = "Hindi" ITALIAN = "Italian" JAPANESE = "Japanese" KAZAKH = "Kazakh" KOREAN = "Korean" LITHUANIAN = "Lithuanian" LATVIAN = "Latvian" BURMESE = "Burmese" NEPALI = "Nepali" DUTCH = "Dutch" ROMANIAN = "Romanian" RUSSIAN = "Russian" SINHALA = "Sinhala" TURKISH = "Turkish" VIETNAMESE = "Vietnamese" CHINESE = "Chinese" AFRIKAANS = "Afrikaans" AZERBAIJANI = "Azerbaijani" BENGALI = "Bengali" PERSIAN = "Persian" HEBREW = "Hebrew" CROATIAN = "Croatian" INDONESIAN = "Indonesian" GEORGIAN = "Georgian" KHMER = "Khmer" MACEDONIAN = "Macedonian" MALAYALAM = "Malayalam" MONGOLIAN = "Mongolian" MARATHI = "Marathi" POLISH = "Polish" PASHTO = "Pashto" PORTUGUESE = "Portuguese" SWEDISH = "Swedish" SWAHILI = "Swahili" TAMIL = "Tamil" TELUGU = "Telugu" THAI = "Thai" TAGALOG = "Tagalog" UKRAINIAN = "Ukrainian" URDU = "Urdu" XHOSA = "Xhosa" GALICIAN = "Galician" SLOVENE = "Slovene" ================================================ FILE: dl_translate/lang/nllb200.py ================================================ # Auto-generated. Do not modify, use scripts/generate_langs.py instead. 
ACEHNESE_ARABIC_SCRIPT = "Acehnese (Arabic script)" ACEHNESE_LATIN_SCRIPT = "Acehnese (Latin script)" MESOPOTAMIAN_ARABIC = "Mesopotamian Arabic" TAIZZI_ADENI_ARABIC = "Ta'izzi-Adeni Arabic" TUNISIAN_ARABIC = "Tunisian Arabic" AFRIKAANS = "Afrikaans" SOUTH_LEVANTINE_ARABIC = "South Levantine Arabic" AKAN = "Akan" AMHARIC = "Amharic" NORTH_LEVANTINE_ARABIC = "North Levantine Arabic" MODERN_STANDARD_ARABIC = "Modern Standard Arabic" MODERN_STANDARD_ARABIC_ROMANIZED = "Modern Standard Arabic (Romanized)" NAJDI_ARABIC = "Najdi Arabic" MOROCCAN_ARABIC = "Moroccan Arabic" EGYPTIAN_ARABIC = "Egyptian Arabic" ASSAMESE = "Assamese" ASTURIAN = "Asturian" AWADHI = "Awadhi" CENTRAL_AYMARA = "Central Aymara" SOUTH_AZERBAIJANI = "South Azerbaijani" NORTH_AZERBAIJANI = "North Azerbaijani" BASHKIR = "Bashkir" BAMBARA = "Bambara" BALINESE = "Balinese" BELARUSIAN = "Belarusian" BEMBA = "Bemba" BENGALI = "Bengali" BHOJPURI = "Bhojpuri" BANJAR_ARABIC_SCRIPT = "Banjar (Arabic script)" BANJAR_LATIN_SCRIPT = "Banjar (Latin script)" STANDARD_TIBETAN = "Standard Tibetan" BOSNIAN = "Bosnian" BUGINESE = "Buginese" BULGARIAN = "Bulgarian" CATALAN = "Catalan" CEBUANO = "Cebuano" CZECH = "Czech" CHOKWE = "Chokwe" CENTRAL_KURDISH = "Central Kurdish" CRIMEAN_TATAR = "Crimean Tatar" WELSH = "Welsh" DANISH = "Danish" GERMAN = "German" SOUTHWESTERN_DINKA = "Southwestern Dinka" DYULA = "Dyula" DZONGKHA = "Dzongkha" GREEK = "Greek" ENGLISH = "English" ESPERANTO = "Esperanto" ESTONIAN = "Estonian" BASQUE = "Basque" EWE = "Ewe" FAROESE = "Faroese" FIJIAN = "Fijian" FINNISH = "Finnish" FON = "Fon" FRENCH = "French" FRIULIAN = "Friulian" NIGERIAN_FULFULDE = "Nigerian Fulfulde" SCOTTISH_GAELIC = "Scottish Gaelic" IRISH = "Irish" GALICIAN = "Galician" GUARANI = "Guarani" GUJARATI = "Gujarati" HAITIAN_CREOLE = "Haitian Creole" HAUSA = "Hausa" HEBREW = "Hebrew" HINDI = "Hindi" CHHATTISGARHI = "Chhattisgarhi" CROATIAN = "Croatian" HUNGARIAN = "Hungarian" ARMENIAN = "Armenian" IGBO = "Igbo" ILOCANO = "Ilocano" 
INDONESIAN = "Indonesian" ICELANDIC = "Icelandic" ITALIAN = "Italian" JAVANESE = "Javanese" JAPANESE = "Japanese" KABYLE = "Kabyle" JINGPHO = "Jingpho" KAMBA = "Kamba" KANNADA = "Kannada" KASHMIRI_ARABIC_SCRIPT = "Kashmiri (Arabic script)" KASHMIRI_DEVANAGARI_SCRIPT = "Kashmiri (Devanagari script)" GEORGIAN = "Georgian" CENTRAL_KANURI_ARABIC_SCRIPT = "Central Kanuri (Arabic script)" CENTRAL_KANURI_LATIN_SCRIPT = "Central Kanuri (Latin script)" KAZAKH = "Kazakh" KABIYÈ = "Kabiyè" KABUVERDIANU = "Kabuverdianu" KHMER = "Khmer" KIKUYU = "Kikuyu" KINYARWANDA = "Kinyarwanda" KYRGYZ = "Kyrgyz" KIMBUNDU = "Kimbundu" NORTHERN_KURDISH = "Northern Kurdish" KIKONGO = "Kikongo" KOREAN = "Korean" LAO = "Lao" LIGURIAN = "Ligurian" LIMBURGISH = "Limburgish" LINGALA = "Lingala" LITHUANIAN = "Lithuanian" LOMBARD = "Lombard" LATGALIAN = "Latgalian" LUXEMBOURGISH = "Luxembourgish" LUBA_KASAI = "Luba-Kasai" GANDA = "Ganda" LUO = "Luo" MIZO = "Mizo" STANDARD_LATVIAN = "Standard Latvian" MAGAHI = "Magahi" MAITHILI = "Maithili" MALAYALAM = "Malayalam" MARATHI = "Marathi" MINANGKABAU_ARABIC_SCRIPT = "Minangkabau (Arabic script)" MINANGKABAU_LATIN_SCRIPT = "Minangkabau (Latin script)" MACEDONIAN = "Macedonian" PLATEAU_MALAGASY = "Plateau Malagasy" MALTESE = "Maltese" MEITEI_BENGALI_SCRIPT = "Meitei (Bengali script)" HALH_MONGOLIAN = "Halh Mongolian" MOSSI = "Mossi" MAORI = "Maori" BURMESE = "Burmese" DUTCH = "Dutch" NORWEGIAN_NYNORSK = "Norwegian Nynorsk" NORWEGIAN_BOKMÅL = "Norwegian Bokmål" NEPALI = "Nepali" NORTHERN_SOTHO = "Northern Sotho" NUER = "Nuer" NYANJA = "Nyanja" OCCITAN = "Occitan" WEST_CENTRAL_OROMO = "West Central Oromo" ODIA = "Odia" PANGASINAN = "Pangasinan" EASTERN_PANJABI = "Eastern Panjabi" PAPIAMENTO = "Papiamento" WESTERN_PERSIAN = "Western Persian" POLISH = "Polish" PORTUGUESE = "Portuguese" DARI = "Dari" SOUTHERN_PASHTO = "Southern Pashto" AYACUCHO_QUECHUA = "Ayacucho Quechua" ROMANIAN = "Romanian" RUNDI = "Rundi" RUSSIAN = "Russian" SANGO = "Sango" SANSKRIT = 
"Sanskrit" SANTALI = "Santali" SICILIAN = "Sicilian" SHAN = "Shan" SINHALA = "Sinhala" SLOVAK = "Slovak" SLOVENIAN = "Slovenian" SAMOAN = "Samoan" SHONA = "Shona" SINDHI = "Sindhi" SOMALI = "Somali" SOUTHERN_SOTHO = "Southern Sotho" SPANISH = "Spanish" TOSK_ALBANIAN = "Tosk Albanian" SARDINIAN = "Sardinian" SERBIAN = "Serbian" SWATI = "Swati" SUNDANESE = "Sundanese" SWEDISH = "Swedish" SWAHILI = "Swahili" SILESIAN = "Silesian" TAMIL = "Tamil" TATAR = "Tatar" TELUGU = "Telugu" TAJIK = "Tajik" TAGALOG = "Tagalog" THAI = "Thai" TIGRINYA = "Tigrinya" TAMASHEQ_LATIN_SCRIPT = "Tamasheq (Latin script)" TAMASHEQ_TIFINAGH_SCRIPT = "Tamasheq (Tifinagh script)" TOK_PISIN = "Tok Pisin" TSWANA = "Tswana" TSONGA = "Tsonga" TURKMEN = "Turkmen" TUMBUKA = "Tumbuka" TURKISH = "Turkish" TWI = "Twi" CENTRAL_ATLAS_TAMAZIGHT = "Central Atlas Tamazight" UYGHUR = "Uyghur" UKRAINIAN = "Ukrainian" UMBUNDU = "Umbundu" URDU = "Urdu" NORTHERN_UZBEK = "Northern Uzbek" VENETIAN = "Venetian" VIETNAMESE = "Vietnamese" WARAY = "Waray" WOLOF = "Wolof" XHOSA = "Xhosa" EASTERN_YIDDISH = "Eastern Yiddish" YORUBA = "Yoruba" YUE_CHINESE = "Yue Chinese" CHINESE_SIMPLIFIED = "Chinese (Simplified)" CHINESE_TRADITIONAL = "Chinese (Traditional)" STANDARD_MALAY = "Standard Malay" ZULU = "Zulu" ================================================ FILE: dl_translate/utils.py ================================================ from typing import Dict, List from ._pairs import _PAIRS_MBART50, _PAIRS_M2M100, _PAIRS_NLLB200 def _infer_model_family(model_or_path): di = { "facebook/mbart-large-50-many-to-many-mmt": "mbart50", "facebook/m2m100_418M": "m2m100", "facebook/m2m100_1.2B": "m2m100", "facebook/nllb-200-distilled-600M": "nllb200", "facebook/nllb-200-distilled-1.3B": "nllb200", "facebook/nllb-200-1.3B": "nllb200", "facebook/nllb-200-3.3B": "nllb200", } if model_or_path in di: return di[model_or_path] else: error_msg = f'Unable to infer the model_family from "{model_or_path}". 
Try explicitly setting the value of model_family to "mbart50" or "m2m100".' raise ValueError(error_msg) def _infer_model_or_path(model_or_path): di = { "mbart50": "facebook/mbart-large-50-many-to-many-mmt", "m2m100": "facebook/m2m100_418M", "m2m100-small": "facebook/m2m100_418M", "m2m100-medium": "facebook/m2m100_1.2B", "nllb200": "facebook/nllb-200-distilled-600M", "nllb200-small": "facebook/nllb-200-distilled-600M", "nllb200-medium": "facebook/nllb-200-distilled-1.3B", "nllb200-medium-regular": "facebook/nllb-200-1.3B", "nllb200-large": "facebook/nllb-200-3.3B", } return di.get(model_or_path, model_or_path) def _weights2pairs(): return { "mbart50": _PAIRS_MBART50, "mbart-large-50-many-to-many-mmt": _PAIRS_MBART50, "facebook/mbart-large-50-many-to-many-mmt": _PAIRS_MBART50, "m2m100": _PAIRS_M2M100, "m2m100_418M": _PAIRS_M2M100, "m2m100_1.2B": _PAIRS_M2M100, "facebook/m2m100_418M": _PAIRS_M2M100, "facebook/m2m100_1.2B": _PAIRS_M2M100, "nllb200": _PAIRS_NLLB200, "nllb-200-distilled": _PAIRS_NLLB200, "nllb-200-distilled-600M": _PAIRS_NLLB200, "nllb-200-distilled-1.3B": _PAIRS_NLLB200, "nllb-200-1.3B": _PAIRS_NLLB200, "nllb-200-3.3B": _PAIRS_NLLB200, "facebook/nllb-200-distilled-600M": _PAIRS_NLLB200, "facebook/nllb-200-distilled-1.3B": _PAIRS_NLLB200, "facebook/nllb-200-1.3B": _PAIRS_NLLB200, "facebook/nllb-200-3.3B": _PAIRS_NLLB200, } def _dict_from_weights(weights: str) -> dict: """Returns a dictionary of lang, codes, pairs if the provided weights is supported.""" if weights in _weights2pairs(): pairs = _weights2pairs()[weights] return { "langs": tuple(pair[0] for pair in pairs), "codes": tuple(pair[1] for pair in pairs), "pairs": dict(pairs), } elif weights.lower() in _weights2pairs(): pairs = _weights2pairs()[weights.lower()] return { "langs": tuple(pair[0] for pair in pairs), "codes": tuple(pair[1] for pair in pairs), "pairs": dict(pairs), } else: error_message = f"Incorrect argument '{weights}' for parameter weights. 
Please choose from: {list(_weights2pairs().keys())}" raise ValueError(error_message) def get_lang_code_map(weights: str = "m2m100") -> Dict[str, str]: """ *Get a dictionary mapping a language -> code for a given model. The code will depend on the model you choose.* {{params}} {{weights}} The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use. """ return _dict_from_weights(weights)["pairs"] def available_languages(weights: str = "m2m100") -> List[str]: """ *Get all the languages available for a given model.* {{params}} {{weights}} The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use. """ return _dict_from_weights(weights)["langs"] def available_codes(weights: str = "m2m100") -> List[str]: """ *Get all the codes available for a given model. The code format will depend on the model you select.* {{params}} {{weights}} The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 codes available to use. """ return _dict_from_weights(weights)["codes"] ================================================ FILE: docs/available_languages.md ================================================ # Languages Available This page gives all the languages available for each model family. 
## MBart 50 | Language Name | Code | | --- | --- | | Arabic | ar_AR | | Czech | cs_CZ | | German | de_DE | | English | en_XX | | Spanish | es_XX | | Estonian | et_EE | | Finnish | fi_FI | | French | fr_XX | | Gujarati | gu_IN | | Hindi | hi_IN | | Italian | it_IT | | Japanese | ja_XX | | Kazakh | kk_KZ | | Korean | ko_KR | | Lithuanian | lt_LT | | Latvian | lv_LV | | Burmese | my_MM | | Nepali | ne_NP | | Dutch | nl_XX | | Romanian | ro_RO | | Russian | ru_RU | | Sinhala | si_LK | | Turkish | tr_TR | | Vietnamese | vi_VN | | Chinese | zh_CN | | Afrikaans | af_ZA | | Azerbaijani | az_AZ | | Bengali | bn_IN | | Persian | fa_IR | | Hebrew | he_IL | | Croatian | hr_HR | | Indonesian | id_ID | | Georgian | ka_GE | | Khmer | km_KH | | Macedonian | mk_MK | | Malayalam | ml_IN | | Mongolian | mn_MN | | Marathi | mr_IN | | Polish | pl_PL | | Pashto | ps_AF | | Portuguese | pt_XX | | Swedish | sv_SE | | Swahili | sw_KE | | Tamil | ta_IN | | Telugu | te_IN | | Thai | th_TH | | Tagalog | tl_XX | | Ukrainian | uk_UA | | Urdu | ur_PK | | Xhosa | xh_ZA | | Galician | gl_ES | | Slovene | sl_SI | ## M2M-100 | Language Name | Code | | --- | --- | | Afrikaans | af | | Amharic | am | | Arabic | ar | | Asturian | ast | | Azerbaijani | az | | Bashkir | ba | | Belarusian | be | | Bulgarian | bg | | Bengali | bn | | Breton | br | | Bosnian | bs | | Catalan | ca | | Valencian | ca | | Cebuano | ceb | | Czech | cs | | Welsh | cy | | Danish | da | | German | de | | Greek | el | | English | en | | Spanish | es | | Estonian | et | | Persian | fa | | Fulah | ff | | Finnish | fi | | French | fr | | Western Frisian | fy | | Irish | ga | | Gaelic | gd | | Scottish Gaelic | gd | | Galician | gl | | Gujarati | gu | | Hausa | ha | | Hebrew | he | | Hindi | hi | | Croatian | hr | | Haitian | ht | | Haitian Creole | ht | | Hungarian | hu | | Armenian | hy | | Indonesian | id | | Igbo | ig | | Iloko | ilo | | Icelandic | is | | Italian | it | | Japanese | ja | | Javanese | jv | | Georgian | ka | | 
Kazakh | kk | | Khmer | km | | Central Khmer | km | | Kannada | kn | | Korean | ko | | Luxembourgish | lb | | Letzeburgesch | lb | | Ganda | lg | | Lingala | ln | | Lao | lo | | Lithuanian | lt | | Latvian | lv | | Malagasy | mg | | Macedonian | mk | | Malayalam | ml | | Mongolian | mn | | Marathi | mr | | Malay | ms | | Burmese | my | | Nepali | ne | | Dutch | nl | | Flemish | nl | | Norwegian | no | | Northern Sotho | ns | | Occitan | oc | | Oriya | or | | Panjabi | pa | | Punjabi | pa | | Polish | pl | | Pushto | ps | | Pashto | ps | | Portuguese | pt | | Romanian | ro | | Moldavian | ro | | Moldovan | ro | | Russian | ru | | Sindhi | sd | | Sinhala | si | | Sinhalese | si | | Slovak | sk | | Slovenian | sl | | Somali | so | | Albanian | sq | | Serbian | sr | | Swati | ss | | Sundanese | su | | Swedish | sv | | Swahili | sw | | Tamil | ta | | Thai | th | | Tagalog | tl | | Tswana | tn | | Turkish | tr | | Ukrainian | uk | | Urdu | ur | | Uzbek | uz | | Vietnamese | vi | | Wolof | wo | | Xhosa | xh | | Yiddish | yi | | Yoruba | yo | | Chinese | zh | | Zulu | zu | ## NLLB-200 | Language Name | Code | | --- | --- | | Acehnese (Arabic script) | ace_Arab | | Acehnese (Latin script) | ace_Latn | | Mesopotamian Arabic | acm_Arab | | Ta'izzi-Adeni Arabic | acq_Arab | | Tunisian Arabic | aeb_Arab | | Afrikaans | afr_Latn | | South Levantine Arabic | ajp_Arab | | Akan | aka_Latn | | Amharic | amh_Ethi | | North Levantine Arabic | apc_Arab | | Modern Standard Arabic | arb_Arab | | Modern Standard Arabic (Romanized) | arb_Latn | | Najdi Arabic | ars_Arab | | Moroccan Arabic | ary_Arab | | Egyptian Arabic | arz_Arab | | Assamese | asm_Beng | | Asturian | ast_Latn | | Awadhi | awa_Deva | | Central Aymara | ayr_Latn | | South Azerbaijani | azb_Arab | | North Azerbaijani | azj_Latn | | Bashkir | bak_Cyrl | | Bambara | bam_Latn | | Balinese | ban_Latn | | Belarusian | bel_Cyrl | | Bemba | bem_Latn | | Bengali | ben_Beng | | Bhojpuri | bho_Deva | | Banjar (Arabic script) | 
bjn_Arab | | Banjar (Latin script) | bjn_Latn | | Standard Tibetan | bod_Tibt | | Bosnian | bos_Latn | | Buginese | bug_Latn | | Bulgarian | bul_Cyrl | | Catalan | cat_Latn | | Cebuano | ceb_Latn | | Czech | ces_Latn | | Chokwe | cjk_Latn | | Central Kurdish | ckb_Arab | | Crimean Tatar | crh_Latn | | Welsh | cym_Latn | | Danish | dan_Latn | | German | deu_Latn | | Southwestern Dinka | dik_Latn | | Dyula | dyu_Latn | | Dzongkha | dzo_Tibt | | Greek | ell_Grek | | English | eng_Latn | | Esperanto | epo_Latn | | Estonian | est_Latn | | Basque | eus_Latn | | Ewe | ewe_Latn | | Faroese | fao_Latn | | Fijian | fij_Latn | | Finnish | fin_Latn | | Fon | fon_Latn | | French | fra_Latn | | Friulian | fur_Latn | | Nigerian Fulfulde | fuv_Latn | | Scottish Gaelic | gla_Latn | | Irish | gle_Latn | | Galician | glg_Latn | | Guarani | grn_Latn | | Gujarati | guj_Gujr | | Haitian Creole | hat_Latn | | Hausa | hau_Latn | | Hebrew | heb_Hebr | | Hindi | hin_Deva | | Chhattisgarhi | hne_Deva | | Croatian | hrv_Latn | | Hungarian | hun_Latn | | Armenian | hye_Armn | | Igbo | ibo_Latn | | Ilocano | ilo_Latn | | Indonesian | ind_Latn | | Icelandic | isl_Latn | | Italian | ita_Latn | | Javanese | jav_Latn | | Japanese | jpn_Jpan | | Kabyle | kab_Latn | | Jingpho | kac_Latn | | Kamba | kam_Latn | | Kannada | kan_Knda | | Kashmiri (Arabic script) | kas_Arab | | Kashmiri (Devanagari script) | kas_Deva | | Georgian | kat_Geor | | Central Kanuri (Arabic script) | knc_Arab | | Central Kanuri (Latin script) | knc_Latn | | Kazakh | kaz_Cyrl | | Kabiyè | kbp_Latn | | Kabuverdianu | kea_Latn | | Khmer | khm_Khmr | | Kikuyu | kik_Latn | | Kinyarwanda | kin_Latn | | Kyrgyz | kir_Cyrl | | Kimbundu | kmb_Latn | | Northern Kurdish | kmr_Latn | | Kikongo | kon_Latn | | Korean | kor_Hang | | Lao | lao_Laoo | | Ligurian | lij_Latn | | Limburgish | lim_Latn | | Lingala | lin_Latn | | Lithuanian | lit_Latn | | Lombard | lmo_Latn | | Latgalian | ltg_Latn | | Luxembourgish | ltz_Latn | | Luba-Kasai | 
lua_Latn | | Ganda | lug_Latn | | Luo | luo_Latn | | Mizo | lus_Latn | | Standard Latvian | lvs_Latn | | Magahi | mag_Deva | | Maithili | mai_Deva | | Malayalam | mal_Mlym | | Marathi | mar_Deva | | Minangkabau (Arabic script) | min_Arab | | Minangkabau (Latin script) | min_Latn | | Macedonian | mkd_Cyrl | | Plateau Malagasy | plt_Latn | | Maltese | mlt_Latn | | Meitei (Bengali script) | mni_Beng | | Halh Mongolian | khk_Cyrl | | Mossi | mos_Latn | | Maori | mri_Latn | | Burmese | mya_Mymr | | Dutch | nld_Latn | | Norwegian Nynorsk | nno_Latn | | Norwegian Bokmål | nob_Latn | | Nepali | npi_Deva | | Northern Sotho | nso_Latn | | Nuer | nus_Latn | | Nyanja | nya_Latn | | Occitan | oci_Latn | | West Central Oromo | gaz_Latn | | Odia | ory_Orya | | Pangasinan | pag_Latn | | Eastern Panjabi | pan_Guru | | Papiamento | pap_Latn | | Western Persian | pes_Arab | | Polish | pol_Latn | | Portuguese | por_Latn | | Dari | prs_Arab | | Southern Pashto | pbt_Arab | | Ayacucho Quechua | quy_Latn | | Romanian | ron_Latn | | Rundi | run_Latn | | Russian | rus_Cyrl | | Sango | sag_Latn | | Sanskrit | san_Deva | | Santali | sat_Olck | | Sicilian | scn_Latn | | Shan | shn_Mymr | | Sinhala | sin_Sinh | | Slovak | slk_Latn | | Slovenian | slv_Latn | | Samoan | smo_Latn | | Shona | sna_Latn | | Sindhi | snd_Arab | | Somali | som_Latn | | Southern Sotho | sot_Latn | | Spanish | spa_Latn | | Tosk Albanian | als_Latn | | Sardinian | srd_Latn | | Serbian | srp_Cyrl | | Swati | ssw_Latn | | Sundanese | sun_Latn | | Swedish | swe_Latn | | Swahili | swh_Latn | | Silesian | szl_Latn | | Tamil | tam_Taml | | Tatar | tat_Cyrl | | Telugu | tel_Telu | | Tajik | tgk_Cyrl | | Tagalog | tgl_Latn | | Thai | tha_Thai | | Tigrinya | tir_Ethi | | Tamasheq (Latin script) | taq_Latn | | Tamasheq (Tifinagh script) | taq_Tfng | | Tok Pisin | tpi_Latn | | Tswana | tsn_Latn | | Tsonga | tso_Latn | | Turkmen | tuk_Latn | | Tumbuka | tum_Latn | | Turkish | tur_Latn | | Twi | twi_Latn | | Central Atlas Tamazight | 
tzm_Tfng | | Uyghur | uig_Arab | | Ukrainian | ukr_Cyrl | | Umbundu | umb_Latn | | Urdu | urd_Arab | | Northern Uzbek | uzn_Latn | | Venetian | vec_Latn | | Vietnamese | vie_Latn | | Waray | war_Latn | | Wolof | wol_Latn | | Xhosa | xho_Latn | | Eastern Yiddish | ydd_Hebr | | Yoruba | yor_Latn | | Yue Chinese | yue_Hant | | Chinese (Simplified) | zho_Hans | | Chinese (Traditional) | zho_Hant | | Standard Malay | zsm_Latn | | Zulu | zul_Latn | ================================================ FILE: docs/contributing.md ================================================ # Contributions If you wish to contribute to the project, please do the following: 1. Verify if there's an existing similar issue. 2. If no issue exists, create it. 3. Once the contribution has been discussed inside the issue, fork this repo. 4. Before modifying any code, make sure to read the sections below. 5. Once you are done with your contribution, start a PR and tag a codeowner. ## Setup To set up the development environment, clone the repo: ```bash git clone https://github.com/xhlulu/dl-translate cd dl-translate ``` Create a new venv and install the dev dependencies ```bash python -m venv venv source venv/bin/activate pip install -e .[dev] ``` ## Code linting To ensure consistent and readable code, we use `black`. To run: ```bash python black . 
```

## Running tests

To run **all** the tests:

```bash
python -m pytest tests
```

For quick tests, run:

```bash
python -m pytest tests/quick
```

## Documentation

To re-generate the documentation after the source code was modified:

```bash
python scripts/render_references.py
```

To run the docs locally, run:

```
mkdocs serve -t material
```

Once ready, you can build it:

```
mkdocs build -t material
```

Or release it on GitHub Pages:

```
mkdocs gh-deploy -t material
```

================================================ FILE: docs/index.md ================================================
# User Guide

Quick links:

💻 [GitHub Repository](https://github.com/xhlulu/dl-translate)
📚 [Documentation](https://xhluca.github.io/dl-translate)
🐍 [PyPi project](https://pypi.org/project/dl-translate/)
🧪 [Colab Demo](https://colab.research.google.com/github/xhlulu/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/) ## Quickstart Install the library with pip: ``` pip install dl-translate ``` To translate some text: ```python import dl_translate as dlt mt = dlt.TranslationModel() # Slow when you load it for the first time text_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है" mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH) ``` Above, you can see that `dlt.lang` contains variables representing each of the 50 available languages with auto-complete support. Alternatively, you can specify the language (e.g. "Arabic") or the language code (e.g. "fr" for French): ```python text_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا." mt.translate(text_ar, source="Arabic", target="fr") ``` If you want to verify whether a language is available, you can check it: ```python print(mt.available_languages()) # All languages that you can use print(mt.available_codes()) # Code corresponding to each language accepted print(mt.get_lang_code_map()) # Dictionary of lang -> code ``` ## Usage ### Selecting a device When you load the model, you can specify the device using the `device` argument. By default, the value will be `device="auto"`, which means it will use a GPU if possible. You can also explicitly set `device="cpu"` or `device="gpu"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). 
__In general, it is recommended to use a GPU if you want a reasonable processing time.__

```python
mt = dlt.TranslationModel(device="auto")  # Automatically select device
mt = dlt.TranslationModel(device="cpu")  # Force you to use a CPU
mt = dlt.TranslationModel(device="gpu")  # Force you to use a GPU
mt = dlt.TranslationModel(device="cuda:2")  # Use the 3rd GPU available
```

### Choosing a different model

By default, the `m2m100` model will be used. However, there are a few options:

* [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages.
* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages.
* [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (on an RTX A6000, we can see a speedup of 3x).

Here's an example:

```python
# The default approach
mt = dlt.TranslationModel("m2m100")  # Shorthand
mt = dlt.TranslationModel("facebook/m2m100_418M")  # Huggingface repo

# If you want to use mBART-50 Large
mt = dlt.TranslationModel("mbart50")
mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt")

# Or NLLB-200 (faster and has 200 languages)
mt = dlt.TranslationModel("nllb200")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M")
```

Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`.

By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it.
It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`: ```python mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50") mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100") mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200") ``` Notes: * Make sure your tokenizer is also stored in the same directory if you load from a file. * The available languages will change if you select a different model, so you will not be able to leverage `dlt.lang` or `dlt.utils`. ### Breaking down into sentences It is not recommended to use extremely long texts as it takes more time to process. Instead, you can try to break them down into sentences. Multiple solutions exists for that, including doing it manually and using the `nltk` library. A quick approach would be to split them by period. However, you have to ensure that there are no periods used for abbreviations (such as `Mr.` or `Dr.`). For example, it will work in the following case: ```python text = "Mr Smith went to his favorite cafe. There, he met his friend Dr Doe." sents = text.split(".") ".".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH)) ``` For more complex cases (e.g. where you use periods for abbreviations), you can use `nltk`. First install the library with `pip install nltk`, then run: ```python import nltk nltk.download("punkt") text = "Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe." sents = nltk.tokenize.sent_tokenize(text, "english") # don't use dlt.lang.ENGLISH " ".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH)) ``` ### Batch size and verbosity when using `translate` It's possible to set a batch size (i.e. the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not: ```python ... 
mt = dlt.TranslationModel() mt.translate(text, source, target, batch_size=32, verbose=True) ``` If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into "chunks". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA memory error. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized. ### `dlt.utils` module An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available: ```python print(dlt.utils.available_languages('mbart50')) # All languages that you can use print(dlt.utils.available_codes('mbart50')) # Code corresponding to each language accepted print(dlt.utils.get_lang_code_map('mbart50')) # Dictionary of lang -> code print(dlt.utils.available_languages('m2m100')) # write the name of the model family ``` At the moment, the following models are accepted: - `"mbart50"` - `"m2m100"` - `"nllb200"` ### Offline usage Unlike the Google translate or MSFT Translator APIs, this library can be fully used offline. However, you will need to first download the packages and models, and move them to your offline environment to be installed and loaded inside a venv. First, run in your terminal: ```bash mkdir dlt cd dlt mkdir libraries pip download -d libraries/ dl-translate ``` Once all the required packages are downloaded, you will need to use huggingface hub to download the files. Install it with `pip install huggingface-hub`. Then, run inside Python: ```python import shutil import huggingface_hub as hub dirname = hub.snapshot_download("facebook/m2m100_418M") shutil.copytree(dirname, "cached_model_m2m100") # Copy to a permanent folder ``` Now, move everything in the `dlt` directory to your offline environment. 
Create a virtual environment and run the following in terminal: ```bash pip install --no-index --find-links libraries/ dl-translate ``` Now, run inside Python: ```python import dl_translate as dlt mt = dlt.TranslationModel("cached_model_m2m100", model_family="m2m100") ``` ## Advanced The following section assumes you have knowledge of PyTorch and Huggingface Transformers. ### Saving and loading If you wish to accelerate the loading time the translation model, you can use `save_obj`. Later you can reload it with `load_obj` by specifying the same directory that you are using to save. ```python mt = dlt.TranslationModel() # ... mt.save_obj('saved_model') # ... mt = dlt.TranslationModel.load_obj('saved_model') ``` **Warning:** Only use this if you are certain the torch module saved in `saved_model/weights.pt` can be correctly loaded. Indeed, it is possible that the `huggingface`, `torch` or some other dependencies change between when you called `save_obj` and `load_obj`, and that might break your code. Thus, it is recommend to only run `load_obj` in the same environment/session as `save_obj`. **Note this method might be deprecated in the future once there's no speed benefit in loading this way.** ### Interacting with underlying model and tokenizer When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer (which will respectively be passed to `ModelForConditionalGeneration.from_pretrained` and `TokenizerFast.from_pretrained`): ```python mt = dlt.TranslationModel( model_options=dict( state_dict=..., cache_dir=..., ... ), tokenizer_options=dict( tokenizer_file=..., eos_token=..., ... 
) ) ``` You can also access the underlying `transformers` model and `tokenizer`: ```python transformers_model = mt.get_transformers_model() tokenizer = mt.get_tokenizer() ``` For more information about the models themselves, please read the docs on [mBART](https://huggingface.co/transformers/master/model_doc/mbart.html) and [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html). ### Keyword arguments for the `generate()` method of the underlying model When running `mt.translate`, you can also give a `generation_options` dictionary that is passed as keyword arguments to the underlying `mt.get_transformers_model().generate()` method: ```python mt.translate( text, source=dlt.lang.GERMAN, target=dlt.lang.SPANISH, generation_options=dict(num_beams=5, max_length=...) ) ``` Learn more in the [huggingface docs](https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate). ================================================ FILE: docs/references.md ================================================ # API Reference ## dlt.TranslationModel ### __init__ ```python dlt.TranslationModel.__init__(self, model_or_path: str = 'm2m100', tokenizer_path: str = None, device: str = 'auto', model_family: str = None, model_options: dict = None, tokenizer_options: dict = None) ``` *Instantiates a multilingual transformer model for translation.* | Parameter | Type | Default | Description | |-|-|-|-| | **model_or_path** | *str* | `m2m100` | The path or the name of the model. Equivalent to the first argument of `AutoModel.from_pretrained()`. You can also specify shorthands ("mbart50" and "m2m100"). | **tokenizer_path** | *str* | *optional* | The path to the tokenizer. By default, it will be set to `model_or_path`. | **device** | *str* | `auto` | "cpu", "gpu" or "auto". If it's set to "auto", will try to select a GPU when available or else fall back to CPU. | **model_family** | *str* | *optional* | Either "mbart50" or "m2m100". 
By default, it will be inferred based on `model_or_path`. Needs to be explicitly set if `model_or_path` is a path. | **model_options** | *dict* | *optional* | The keyword arguments passed to the model, which is a transformer for conditional generation. | **tokenizer_options** | *dict* | *optional* | The keyword arguments passed to the model's tokenizer.
### translate ```python dlt.TranslationModel.translate(self, text: Union[str, List[str]], source: str, target: str, batch_size: int = 32, verbose: bool = False, generation_options: dict = None) -> Union[str, List[str]] ``` *Translates a string or a list of strings from a source to a target language.* | Parameter | Type | Default | Description | |-|-|-|-| | **text** | *Union[str, List[str]]* | *required* | The content you want to translate. | **source** | *str* | *required* | The language of the original text. | **target** | *str* | *required* | The language of the translated text. | **batch_size** | *int* | `32` | The number of samples to load at once. If set to `None`, it will process everything at once. | **verbose** | *bool* | `False` | Whether to display the progress bar for every batch processed. | **generation_options** | *dict* | *optional* | The keyword arguments passed to `model.generate()`, where `model` is the underlying transformers model. Note: - Run `print(dlt.utils.available_languages())` to see what's available. - A smaller value is preferred for `batch_size` if your (video) RAM is limited.
### get_transformers_model ```python dlt.TranslationModel.get_transformers_model(self) ``` *Retrieve the underlying mBART transformer model.*
### get_tokenizer ```python dlt.TranslationModel.get_tokenizer(self) ``` *Retrieve the mBART huggingface tokenizer.*
### available_codes ```python dlt.TranslationModel.available_codes(self) -> List[str] ``` *Returns all the available codes for a given `dlt.TranslationModel` instance.*
### available_languages ```python dlt.TranslationModel.available_languages(self) -> List[str] ``` *Returns all the available languages for a given `dlt.TranslationModel` instance.*
### get_lang_code_map ```python dlt.TranslationModel.get_lang_code_map(self) -> Dict[str, str] ``` *Returns the language -> codes dictionary for a given `dlt.TranslationModel` instance.*
### save_obj ```python dlt.TranslationModel.save_obj(self, path: str = 'saved_model') -> None ``` *Saves your model as a torch object and saves your tokenizer.* | Parameter | Type | Default | Description | |-|-|-|-| | **path** | *str* | `saved_model` | The directory where you want to save your model and tokenizer
### load_obj ```python dlt.TranslationModel.load_obj(path: str = 'saved_model', **kwargs) ``` *Initialize `dlt.TranslationModel` from the torch object and tokenizer saved with `dlt.TranslationModel.save_obj`* | Parameter | Type | Default | Description | |-|-|-|-| | **path** | *str* | `saved_model` | The directory where your torch model and tokenizer are stored

## dlt.utils ### get_lang_code_map ```python dlt.utils.get_lang_code_map(weights: str = 'mbart50') -> Dict[str, str] ``` *Get a dictionary mapping a language -> code for a given model. The code will depend on the model you choose.* | Parameter | Type | Default | Description | |-|-|-|-| | **weights** | *str* | `mbart50` | The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use.
### available_codes ```python dlt.utils.available_codes(weights: str = 'mbart50') -> List[str] ``` *Get all the codes available for a given model. The code format will depend on the model you select.* | Parameter | Type | Default | Description | |-|-|-|-| | **weights** | *str* | `mbart50` | The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 codes available to use.
### available_languages ```python dlt.utils.available_languages(weights: str = 'mbart50') -> List[str] ``` *Get all the languages available for a given model.* | Parameter | Type | Default | Description | |-|-|-|-| | **weights** | *str* | `mbart50` | The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use.

================================================ FILE: docs/requirements.txt ================================================ mkdocs mkdocs-material jinja2<3.1.0 ================================================ FILE: mkdocs-rtd.yml ================================================ site_name: DL Translate repo_url: https://github.com/xhluca/dl-translate edit_uri: blob/main/docs/ nav: - index.md - references.md - contributing.md - available_languages.md theme: readthedocs ================================================ FILE: mkdocs.yml ================================================ site_name: DL Translate repo_url: https://github.com/xhluca/dl-translate nav: - index.md - references.md - contributing.md - available_languages.md theme: material markdown_extensions: - pymdownx.highlight - pymdownx.superfences ================================================ FILE: scripts/generate_langs.py ================================================ import json import os def name_to_var(lang_name): return ( lang_name.upper() .replace(" ", "_") .replace("(", "") .replace(")", "") .replace("-", "_") .replace("'", "") ) def load_json(name): filepath = os.path.join(os.path.dirname(__file__), "langs_coverage", f"{name}.json") return json.loads(open(filepath).read()) auto_gen_comment = f"# Auto-generated. 
Do not modify, use {__file__} instead.\n" name2json = {} for name in ["m2m100", "mbart50", "nllb200"]: name2json[name] = lang2code = load_json(name) with open(f"./dl_translate/lang/{name}.py", "w") as f: f.write(auto_gen_comment) for lang, code in lang2code.items(): f.write(f'{name_to_var(lang)} = "{lang}"\n') with open("./dl_translate/_pairs.py", "w") as f: f.write(auto_gen_comment) for name, lang2code in name2json.items(): f.write(f"_PAIRS_{name.upper()} = {tuple(lang2code.items())}\n") ================================================ FILE: scripts/langs_coverage/m2m100.json ================================================ { "Afrikaans": "af", "Amharic": "am", "Arabic": "ar", "Asturian": "ast", "Azerbaijani": "az", "Bashkir": "ba", "Belarusian": "be", "Bulgarian": "bg", "Bengali": "bn", "Breton": "br", "Bosnian": "bs", "Catalan": "ca", "Valencian": "ca", "Cebuano": "ceb", "Czech": "cs", "Welsh": "cy", "Danish": "da", "German": "de", "Greek": "el", "English": "en", "Spanish": "es", "Estonian": "et", "Persian": "fa", "Fulah": "ff", "Finnish": "fi", "French": "fr", "Western Frisian": "fy", "Irish": "ga", "Gaelic": "gd", "Scottish Gaelic": "gd", "Galician": "gl", "Gujarati": "gu", "Hausa": "ha", "Hebrew": "he", "Hindi": "hi", "Croatian": "hr", "Haitian": "ht", "Haitian Creole": "ht", "Hungarian": "hu", "Armenian": "hy", "Indonesian": "id", "Igbo": "ig", "Iloko": "ilo", "Icelandic": "is", "Italian": "it", "Japanese": "ja", "Javanese": "jv", "Georgian": "ka", "Kazakh": "kk", "Khmer": "km", "Central Khmer": "km", "Kannada": "kn", "Korean": "ko", "Luxembourgish": "lb", "Letzeburgesch": "lb", "Ganda": "lg", "Lingala": "ln", "Lao": "lo", "Lithuanian": "lt", "Latvian": "lv", "Malagasy": "mg", "Macedonian": "mk", "Malayalam": "ml", "Mongolian": "mn", "Marathi": "mr", "Malay": "ms", "Burmese": "my", "Nepali": "ne", "Dutch": "nl", "Flemish": "nl", "Norwegian": "no", "Northern Sotho": "ns", "Occitan": "oc", "Oriya": "or", "Panjabi": "pa", "Punjabi": "pa", "Polish": "pl", 
"Pushto": "ps", "Pashto": "ps", "Portuguese": "pt", "Romanian": "ro", "Moldavian": "ro", "Moldovan": "ro", "Russian": "ru", "Sindhi": "sd", "Sinhala": "si", "Sinhalese": "si", "Slovak": "sk", "Slovenian": "sl", "Somali": "so", "Albanian": "sq", "Serbian": "sr", "Swati": "ss", "Sundanese": "su", "Swedish": "sv", "Swahili": "sw", "Tamil": "ta", "Thai": "th", "Tagalog": "tl", "Tswana": "tn", "Turkish": "tr", "Ukrainian": "uk", "Urdu": "ur", "Uzbek": "uz", "Vietnamese": "vi", "Wolof": "wo", "Xhosa": "xh", "Yiddish": "yi", "Yoruba": "yo", "Chinese": "zh", "Zulu": "zu" } ================================================ FILE: scripts/langs_coverage/mbart50.json ================================================ { "Arabic": "ar_AR", "Czech": "cs_CZ", "German": "de_DE", "English": "en_XX", "Spanish": "es_XX", "Estonian": "et_EE", "Finnish": "fi_FI", "French": "fr_XX", "Gujarati": "gu_IN", "Hindi": "hi_IN", "Italian": "it_IT", "Japanese": "ja_XX", "Kazakh": "kk_KZ", "Korean": "ko_KR", "Lithuanian": "lt_LT", "Latvian": "lv_LV", "Burmese": "my_MM", "Nepali": "ne_NP", "Dutch": "nl_XX", "Romanian": "ro_RO", "Russian": "ru_RU", "Sinhala": "si_LK", "Turkish": "tr_TR", "Vietnamese": "vi_VN", "Chinese": "zh_CN", "Afrikaans": "af_ZA", "Azerbaijani": "az_AZ", "Bengali": "bn_IN", "Persian": "fa_IR", "Hebrew": "he_IL", "Croatian": "hr_HR", "Indonesian": "id_ID", "Georgian": "ka_GE", "Khmer": "km_KH", "Macedonian": "mk_MK", "Malayalam": "ml_IN", "Mongolian": "mn_MN", "Marathi": "mr_IN", "Polish": "pl_PL", "Pashto": "ps_AF", "Portuguese": "pt_XX", "Swedish": "sv_SE", "Swahili": "sw_KE", "Tamil": "ta_IN", "Telugu": "te_IN", "Thai": "th_TH", "Tagalog": "tl_XX", "Ukrainian": "uk_UA", "Urdu": "ur_PK", "Xhosa": "xh_ZA", "Galician": "gl_ES", "Slovene": "sl_SI" } ================================================ FILE: scripts/langs_coverage/nllb200.json ================================================ { "Acehnese (Arabic script)": "ace_Arab", "Acehnese (Latin script)": "ace_Latn", "Mesopotamian 
Arabic": "acm_Arab", "Ta'izzi-Adeni Arabic": "acq_Arab", "Tunisian Arabic": "aeb_Arab", "Afrikaans": "afr_Latn", "South Levantine Arabic": "ajp_Arab", "Akan": "aka_Latn", "Amharic": "amh_Ethi", "North Levantine Arabic": "apc_Arab", "Modern Standard Arabic": "arb_Arab", "Modern Standard Arabic (Romanized)": "arb_Latn", "Najdi Arabic": "ars_Arab", "Moroccan Arabic": "ary_Arab", "Egyptian Arabic": "arz_Arab", "Assamese": "asm_Beng", "Asturian": "ast_Latn", "Awadhi": "awa_Deva", "Central Aymara": "ayr_Latn", "South Azerbaijani": "azb_Arab", "North Azerbaijani": "azj_Latn", "Bashkir": "bak_Cyrl", "Bambara": "bam_Latn", "Balinese": "ban_Latn", "Belarusian": "bel_Cyrl", "Bemba": "bem_Latn", "Bengali": "ben_Beng", "Bhojpuri": "bho_Deva", "Banjar (Arabic script)": "bjn_Arab", "Banjar (Latin script)": "bjn_Latn", "Standard Tibetan": "bod_Tibt", "Bosnian": "bos_Latn", "Buginese": "bug_Latn", "Bulgarian": "bul_Cyrl", "Catalan": "cat_Latn", "Cebuano": "ceb_Latn", "Czech": "ces_Latn", "Chokwe": "cjk_Latn", "Central Kurdish": "ckb_Arab", "Crimean Tatar": "crh_Latn", "Welsh": "cym_Latn", "Danish": "dan_Latn", "German": "deu_Latn", "Southwestern Dinka": "dik_Latn", "Dyula": "dyu_Latn", "Dzongkha": "dzo_Tibt", "Greek": "ell_Grek", "English": "eng_Latn", "Esperanto": "epo_Latn", "Estonian": "est_Latn", "Basque": "eus_Latn", "Ewe": "ewe_Latn", "Faroese": "fao_Latn", "Fijian": "fij_Latn", "Finnish": "fin_Latn", "Fon": "fon_Latn", "French": "fra_Latn", "Friulian": "fur_Latn", "Nigerian Fulfulde": "fuv_Latn", "Scottish Gaelic": "gla_Latn", "Irish": "gle_Latn", "Galician": "glg_Latn", "Guarani": "grn_Latn", "Gujarati": "guj_Gujr", "Haitian Creole": "hat_Latn", "Hausa": "hau_Latn", "Hebrew": "heb_Hebr", "Hindi": "hin_Deva", "Chhattisgarhi": "hne_Deva", "Croatian": "hrv_Latn", "Hungarian": "hun_Latn", "Armenian": "hye_Armn", "Igbo": "ibo_Latn", "Ilocano": "ilo_Latn", "Indonesian": "ind_Latn", "Icelandic": "isl_Latn", "Italian": "ita_Latn", "Javanese": "jav_Latn", "Japanese": "jpn_Jpan", 
"Kabyle": "kab_Latn", "Jingpho": "kac_Latn", "Kamba": "kam_Latn", "Kannada": "kan_Knda", "Kashmiri (Arabic script)": "kas_Arab", "Kashmiri (Devanagari script)": "kas_Deva", "Georgian": "kat_Geor", "Central Kanuri (Arabic script)": "knc_Arab", "Central Kanuri (Latin script)": "knc_Latn", "Kazakh": "kaz_Cyrl", "Kabiyè": "kbp_Latn", "Kabuverdianu": "kea_Latn", "Khmer": "khm_Khmr", "Kikuyu": "kik_Latn", "Kinyarwanda": "kin_Latn", "Kyrgyz": "kir_Cyrl", "Kimbundu": "kmb_Latn", "Northern Kurdish": "kmr_Latn", "Kikongo": "kon_Latn", "Korean": "kor_Hang", "Lao": "lao_Laoo", "Ligurian": "lij_Latn", "Limburgish": "lim_Latn", "Lingala": "lin_Latn", "Lithuanian": "lit_Latn", "Lombard": "lmo_Latn", "Latgalian": "ltg_Latn", "Luxembourgish": "ltz_Latn", "Luba-Kasai": "lua_Latn", "Ganda": "lug_Latn", "Luo": "luo_Latn", "Mizo": "lus_Latn", "Standard Latvian": "lvs_Latn", "Magahi": "mag_Deva", "Maithili": "mai_Deva", "Malayalam": "mal_Mlym", "Marathi": "mar_Deva", "Minangkabau (Arabic script)": "min_Arab", "Minangkabau (Latin script)": "min_Latn", "Macedonian": "mkd_Cyrl", "Plateau Malagasy": "plt_Latn", "Maltese": "mlt_Latn", "Meitei (Bengali script)": "mni_Beng", "Halh Mongolian": "khk_Cyrl", "Mossi": "mos_Latn", "Maori": "mri_Latn", "Burmese": "mya_Mymr", "Dutch": "nld_Latn", "Norwegian Nynorsk": "nno_Latn", "Norwegian Bokmål": "nob_Latn", "Nepali": "npi_Deva", "Northern Sotho": "nso_Latn", "Nuer": "nus_Latn", "Nyanja": "nya_Latn", "Occitan": "oci_Latn", "West Central Oromo": "gaz_Latn", "Odia": "ory_Orya", "Pangasinan": "pag_Latn", "Eastern Panjabi": "pan_Guru", "Papiamento": "pap_Latn", "Western Persian": "pes_Arab", "Polish": "pol_Latn", "Portuguese": "por_Latn", "Dari": "prs_Arab", "Southern Pashto": "pbt_Arab", "Ayacucho Quechua": "quy_Latn", "Romanian": "ron_Latn", "Rundi": "run_Latn", "Russian": "rus_Cyrl", "Sango": "sag_Latn", "Sanskrit": "san_Deva", "Santali": "sat_Olck", "Sicilian": "scn_Latn", "Shan": "shn_Mymr", "Sinhala": "sin_Sinh", "Slovak": "slk_Latn", "Slovenian": 
"slv_Latn", "Samoan": "smo_Latn", "Shona": "sna_Latn", "Sindhi": "snd_Arab", "Somali": "som_Latn", "Southern Sotho": "sot_Latn", "Spanish": "spa_Latn", "Tosk Albanian": "als_Latn", "Sardinian": "srd_Latn", "Serbian": "srp_Cyrl", "Swati": "ssw_Latn", "Sundanese": "sun_Latn", "Swedish": "swe_Latn", "Swahili": "swh_Latn", "Silesian": "szl_Latn", "Tamil": "tam_Taml", "Tatar": "tat_Cyrl", "Telugu": "tel_Telu", "Tajik": "tgk_Cyrl", "Tagalog": "tgl_Latn", "Thai": "tha_Thai", "Tigrinya": "tir_Ethi", "Tamasheq (Latin script)": "taq_Latn", "Tamasheq (Tifinagh script)": "taq_Tfng", "Tok Pisin": "tpi_Latn", "Tswana": "tsn_Latn", "Tsonga": "tso_Latn", "Turkmen": "tuk_Latn", "Tumbuka": "tum_Latn", "Turkish": "tur_Latn", "Twi": "twi_Latn", "Central Atlas Tamazight": "tzm_Tfng", "Uyghur": "uig_Arab", "Ukrainian": "ukr_Cyrl", "Umbundu": "umb_Latn", "Urdu": "urd_Arab", "Northern Uzbek": "uzn_Latn", "Venetian": "vec_Latn", "Vietnamese": "vie_Latn", "Waray": "war_Latn", "Wolof": "wol_Latn", "Xhosa": "xho_Latn", "Eastern Yiddish": "ydd_Hebr", "Yoruba": "yor_Latn", "Yue Chinese": "yue_Hant", "Chinese (Simplified)": "zho_Hans", "Chinese (Traditional)": "zho_Hant", "Standard Malay": "zsm_Latn", "Zulu": "zul_Latn" } ================================================ FILE: scripts/render_available_langs.py ================================================ import os import json from jinja2 import Template def load_json(name): filepath = os.path.join(os.path.dirname(__file__), "langs_coverage", f"{name}.json") return json.loads(open(filepath).read()) template_values = {} for name in ["m2m100", "mbart50", "nllb200"]: content = "" di = load_json(name) content += "| Language Name | Code |\n" content += "| --- | --- |\n" for key, val in di.items(): content += f"| {key} | {val} |\n" template_values[name] = content template_path = os.path.join( os.path.dirname(__file__), "templates", "available_languages.md.jinja2" ) save_path = os.path.join( os.path.dirname(__file__), "..", "docs", 
import inspect
import os
from typing import Any, List, NamedTuple, Optional

from jinja2 import Template

import dl_translate as dlt

# Maps common annotation objects to the short strings shown in the docs.
# NOTE: the original `from typing import` line imported NamedTuple twice;
# the duplicate has been removed.
type2str = {
    int: "int",
    float: "float",
    str: "str",
    bool: "bool",
    dict: "dict",
    inspect._empty: "unspecified",
    Optional[Any]: "optional",
}

# Maps special default values to the strings shown in the docs tables.
default2str = {inspect._empty: "*required*", None: "*optional*"}


def preprocess_annot(annotation):
    """Render a type annotation as a short human-readable string."""
    annotation = type2str.get(annotation, str(annotation))
    return annotation.replace("typing.", "")


def preprocess_default(default):
    """Render a parameter default as a markdown-friendly string."""
    return default2str.get(default, f"`{default}`")


class FunctionReference:
    """Wraps a function to expose the pieces needed by the references template."""

    def __init__(self, function, modname=None):
        self.func = function
        if modname is None:
            # Display "dlt...." rather than the real package name "dl_translate".
            self.modname = inspect.getmodule(self.func).__name__.replace(
                "dl_translate", "dlt"
            )
        else:
            self.modname = modname

    @property
    def name(self):
        """Bare function name, e.g. "translate"."""
        return self.func.__name__

    @property
    def signature(self):
        """inspect.Signature of the wrapped function."""
        return inspect.signature(self.func)

    @property
    def sig_desc(self):
        """Fully-qualified call signature shown in the docs code block."""
        return self.modname + "." + self.name + str(self.signature)

    @property
    def doc(self):
        """Docstring rendered as markdown.

        The docstring itself is treated as a Jinja template: each parameter
        placeholder is replaced by a markdown table row built from the
        function's signature (`self` is skipped).
        """
        doc_template = Template(inspect.getdoc(self.func))
        kwargs = {"params": "| Parameter | Type | Default | Description |\n|-|-|-|-|"}
        for arg_name, param in self.signature.parameters.items():
            if arg_name == "self":
                continue
            annot = preprocess_annot(param.annotation)
            default = preprocess_default(param.default)
            kwargs[arg_name] = f"| **{arg_name}** | *{annot}* | {default} |"
        return doc_template.render(**kwargs)


class ModuleReferences(NamedTuple):
    # Section title (e.g. "dlt.utils") and the functions documented under it.
    name: str
    funcs: List[FunctionReference]


template_path = os.path.join(
    os.path.dirname(__file__), "templates", "references.md.jinja2"
)
save_path = os.path.join(os.path.dirname(__file__), "..", "docs", "references.md")

with open(template_path) as f:
    template = Template(f.read())

rendered = template.render(
    modules=[
        ModuleReferences(
            "dlt.TranslationModel",
            [
                FunctionReference(
                    dlt.TranslationModel.__init__, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.translate, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.get_transformers_model, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.get_tokenizer, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.available_codes, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.available_languages, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.get_lang_code_map, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.save_obj, "dlt.TranslationModel"
                ),
                FunctionReference(
                    dlt.TranslationModel.load_obj, "dlt.TranslationModel"
                ),
            ],
        ),
        ModuleReferences(
            "dlt.utils",
            [
                FunctionReference(dlt.utils.get_lang_code_map),
                FunctionReference(dlt.utils.available_codes),
                FunctionReference(dlt.utils.available_languages),
            ],
        ),
    ]
)

with open(save_path, "w") as f:
    f.write(rendered)
Languages Available This page gives all the languages available for each model family. ## MBart 50 {{mbart50}} ## M2M-100 {{m2m100}} ## NLLB-200 {{nllb200}} ================================================ FILE: scripts/templates/references.md.jinja2 ================================================ # API Reference {% for module in modules %} ## {{module.name}} {% for func in module.funcs %} ### {{ func.name }} ```python {{ func.sig_desc }} ``` {{ func.doc }}
{% endfor %}
import os

import dl_translate as dlt


def test_save():
    """save_obj writes both the torch weights and the tokenizer config to disk."""
    model = dlt.TranslationModel()
    model.save_obj("saved_model")
    for expected_file in ("saved_model/weights.pt", "saved_model/tokenizer_config.json"):
        assert os.path.exists(expected_file)


def test_load():
    """load_obj restores a TranslationModel from the directory saved above."""
    restored = dlt.TranslationModel.load_obj("saved_model")
    assert isinstance(restored, dlt.TranslationModel)
import pytest
import torch

import dl_translate as dlt
from dl_translate._translation_model import (
    _resolve_lang_codes,
    _select_device,
    _infer_model_or_path,
    _infer_model_family,
)


def test_resolve_lang_codes_mbart50():
    """dlt.lang constants, raw codes, and plain names all resolve to mBART codes."""
    sources = [dlt.lang.FRENCH, "fr_XX", "French"]
    targets = [dlt.lang.ENGLISH, "en_XX", "English"]
    for source, target in zip(sources, targets):
        s = _resolve_lang_codes(source, "source", "mbart50")
        t = _resolve_lang_codes(target, "target", "mbart50")
        assert s == "fr_XX"
        assert t == "en_XX"


def test_resolve_lang_codes_m2m100():
    """dlt.lang constants, raw codes, and plain names all resolve to m2m100 codes."""
    sources = [dlt.lang.m2m100.FRENCH, "fr", "French"]
    targets = [dlt.lang.m2m100.ENGLISH, "en", "English"]
    for source, target in zip(sources, targets):
        s = _resolve_lang_codes(source, "source", "m2m100")
        t = _resolve_lang_codes(target, "target", "m2m100")
        assert s == "fr"
        assert t == "en"


def test_resolve_lang_codes_nllb200():
    """dlt.lang constants, raw codes, and plain names all resolve to NLLB-200 codes.

    BUG FIX: this function was previously also named
    test_resolve_lang_codes_m2m100, which shadowed the m2m100 test above so
    that test was silently never collected or run by pytest.
    """
    sources = [dlt.lang.nllb200.FRENCH, "fra_Latn", "French"]
    targets = [dlt.lang.nllb200.ENGLISH, "eng_Latn", "English"]
    for source, target in zip(sources, targets):
        s = _resolve_lang_codes(source, "source", "nllb200")
        t = _resolve_lang_codes(target, "target", "nllb200")
        assert s == "fra_Latn"
        assert t == "eng_Latn"

    # Names containing parentheses and apostrophes must also resolve.
    sources = ["Central Kanuri (Latin script)"]
    targets = ["Ta'izzi-Adeni Arabic"]
    for source, target in zip(sources, targets):
        s = _resolve_lang_codes(source, "source", "nllb200")
        t = _resolve_lang_codes(target, "target", "nllb200")
        assert s == "knc_Latn"
        assert t == "acq_Arab"


def test_select_device():
    """Device strings map to the expected torch.device; "auto" prefers CUDA."""
    assert _select_device("cpu") == torch.device("cpu")
    assert _select_device("gpu") == torch.device("cuda")
    assert _select_device("cuda:0") == torch.device("cuda", index=0)

    if torch.cuda.is_available():
        assert _select_device("auto") == torch.device("cuda")
    else:
        assert _select_device("auto") == torch.device("cpu")


def test_infer_model_or_path():
    """Shorthands expand to HF repo names; unknown values pass through as paths."""
    assert _infer_model_or_path("mbart50") == "facebook/mbart-large-50-many-to-many-mmt"
    assert _infer_model_or_path("m2m100") == "facebook/m2m100_418M"
    assert _infer_model_or_path("m2m100-small") == "facebook/m2m100_418M"
    assert _infer_model_or_path("m2m100-medium") == "facebook/m2m100_1.2B"
    assert _infer_model_or_path("non-existing-value") == "non-existing-value"


def test_infer_model_family():
    """Known repo names map to a family; unknown ones raise ValueError."""
    assert _infer_model_family("facebook/mbart-large-50-many-to-many-mmt") == "mbart50"
    assert _infer_model_family("facebook/m2m100_418M") == "m2m100"
    assert _infer_model_family("facebook/m2m100_1.2B") == "m2m100"

    with pytest.raises(ValueError):
        _infer_model_family("non-existing-value")
import pytest

from dl_translate import utils
from dl_translate._pairs import _PAIRS_MBART50, _PAIRS_M2M100, _PAIRS_NLLB200


def test_dict_from_weights():
    """Every accepted weights alias yields a dict with the langs/codes/pairs keys."""
    weights = [
        "mbart50",
        "mbart-large-50-many-to-many-mmt",
        "facebook/mbart-large-50-many-to-many-mmt",
        "m2m100",
        "m2m100_418M",
        "m2m100_1.2B",
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]
    valid_keys = ["langs", "codes", "pairs"]
    for w in weights:
        assert type(utils._dict_from_weights(w)) is dict
        keys = utils._dict_from_weights(w).keys()
        for key in valid_keys:
            assert key in keys


def test_dict_from_weights_exception():
    """An unrecognized weights name raises ValueError."""
    with pytest.raises(ValueError):
        utils._dict_from_weights("incorrect")


def test_available_languages():
    """The default family is m2m100 and every pair's language is listed.

    BUG FIX: the first assertion previously compared available_languages()
    with itself (always true); the intent, mirroring test_available_codes,
    is to check that the default equals the explicit "m2m100" family.
    """
    assert utils.available_languages() == utils.available_languages("m2m100")

    langs = utils.available_languages()
    for lang, _ in _PAIRS_M2M100:
        assert lang in langs

    langs = utils.available_languages("mbart50")
    for lang, _ in _PAIRS_MBART50:
        assert lang in langs

    langs = utils.available_languages("nllb200")
    for lang, _ in _PAIRS_NLLB200:
        assert lang in langs


def test_available_codes():
    """The default family is m2m100 and every pair's code is listed."""
    assert utils.available_codes() == utils.available_codes("m2m100")

    codes = utils.available_codes()
    for _, code in _PAIRS_M2M100:
        assert code in codes

    codes = utils.available_codes("mbart50")
    for _, code in _PAIRS_MBART50:
        assert code in codes

    codes = utils.available_codes("nllb200")
    for _, code in _PAIRS_NLLB200:
        assert code in codes