Repository: xhlulu/dl-translate
Branch: main
Commit: d9f1fcd6d7d9
Files: 40
Total size: 180.1 KB
Directory structure:
gitextract_lc7n6m8w/
├── .github/
│ └── workflows/
│ ├── generate-docs.yml
│ ├── main.yaml
│ └── python-publish.yml
├── .gitignore
├── .readthedocs.yaml
├── CITATION.cff
├── LICENSE
├── MANIFEST.in
├── README.md
├── demos/
│ ├── colab_demo.ipynb
│ └── nllb_demo.ipynb
├── dl_translate/
│ ├── __init__.py
│ ├── _pairs.py
│ ├── _translation_model.py
│ ├── lang/
│ │ ├── __init__.py
│ │ ├── m2m100.py
│ │ ├── mbart50.py
│ │ └── nllb200.py
│ └── utils.py
├── docs/
│ ├── available_languages.md
│ ├── contributing.md
│ ├── index.md
│ ├── references.md
│ └── requirements.txt
├── mkdocs-rtd.yml
├── mkdocs.yml
├── scripts/
│ ├── generate_langs.py
│ ├── langs_coverage/
│ │ ├── m2m100.json
│ │ ├── mbart50.json
│ │ └── nllb200.json
│ ├── render_available_langs.py
│ ├── render_references.py
│ └── templates/
│ ├── available_languages.md.jinja2
│ └── references.md.jinja2
├── setup.py
└── tests/
├── long/
│ ├── test_save_load.py
│ └── test_translate.py
└── quick/
├── test_lang.py
├── test_translation_model.py
└── test_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/generate-docs.yml
================================================
name: Publish Docs
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r docs/requirements.txt
- name: Publish Docs
run: |
mkdocs gh-deploy -t material --force
================================================
FILE: .github/workflows/main.yaml
================================================
name: Lint and run tests
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
# source: https://medium.com/ai2-blog/e9452698e98d
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install --upgrade --upgrade-strategy eager -e .[dev]
- name: Show setup install requires versions
run: |
pip show transformers torch tqdm protobuf
- name: Lint code with black
run: |
black . --check
- name: Run quick tests with pytest
run: |
pytest tests/quick
- name: Run long tests with pytest
if: ${{ matrix.python-version == '3.8' }}
run: |
pytest tests/long
================================================
FILE: .github/workflows/python-publish.yml
================================================
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Publish Python Package
on:
release:
types: [published]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
================================================
FILE: .gitignore
================================================
# Custom
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
================================================
FILE: .readthedocs.yaml
================================================
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
mkdocs:
configuration: mkdocs-rtd.yml
# Optionally build your docs in additional formats such as PDF
formats:
- pdf
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Lu"
given-names: "Xing Han"
orcid: "https://orcid.org/0000-0001-9027-8425"
title: "DL Translate: A deep learning-based translation library built on Huggingface transformers"
version: 0.3.0
doi: 10.5281/zenodo.5230676
date-released: 2021-08-21
url: "https://github.com/xhlulu/dl-translate"
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Xing Han Lu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
include CITATION.cff
================================================
FILE: README.md
================================================
# DL Translate
[](https://doi.org/10.5281/zenodo.5230676)
[](https://pepy.tech/project/dl-translate)
[](https://github.com/xhluca/dl-translate/blob/main/LICENSE)
*A translation library for 200 languages built on Huggingface `transformers`*
💻 [GitHub Repository](https://github.com/xhluca/dl-translate)
📚 [Documentation](https://xhluca.github.io/dl-translate)
🐍 [PyPi project](https://pypi.org/project/dl-translate/)
🧪 [Colab Demo](https://colab.research.google.com/github/xhluca/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/)
## Quickstart
Install the library with pip:
```
pip install dl-translate
```
To translate some text:
```python
import dl_translate as dlt
mt = dlt.TranslationModel() # Slow when you load it for the first time
text_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH)
```
Above, you can see that `dlt.lang` contains variables representing each of the 200 available languages with auto-complete support. Alternatively, you can specify the language (e.g. "Arabic") or the language code (e.g. "fr" for French):
```python
text_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
mt.translate(text_ar, source="Arabic", target="fr")
```
If you want to verify whether a language is available, you can check it:
```python
print(mt.available_languages()) # All languages that you can use
print(mt.available_codes()) # Code corresponding to each language accepted
print(mt.get_lang_code_map()) # Dictionary of lang -> code
```
## Usage
### Selecting a device
When you load the model, you can specify the device:
```python
mt = dlt.TranslationModel(device="auto")
```
By default, the value will be `device="auto"`, which means it will use a GPU if possible. You can also explicitly set `device="cpu"` or `device="gpu"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). __In general, it is recommended to use a GPU if you want a reasonable processing time.__
### Choosing a different model
By default, the `m2m100` model will be used. However, there are a few options:
* [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages.
* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages.
* [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (On RTX A6000, we can see speed up of 3x).
Here's an example:
```python
# The default model
mt = dlt.TranslationModel("m2m100") # Shorthand
mt = dlt.TranslationModel("facebook/m2m100_418M") # Huggingface repo
# If you want to use mBART-50 Large
mt = dlt.TranslationModel("mbart50")
mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt")
# Or NLLB-200 (faster and has 200 languages)
mt = dlt.TranslationModel("nllb200")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M")
```
Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`.
By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`:
```python
mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50")
mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200")
```
Notes:
* Make sure your tokenizer is also stored in the same directory if you load from a file.
* The available languages will change if you select a different model, so you will not be able to leverage `dlt.lang` or `dlt.utils`.
### Splitting into sentences
It is not recommended to use extremely long texts as it takes more time to process. Instead, you can try to break them down into sentences with the help of `nltk`. First install the library with `pip install nltk`, then run:
```python
import nltk
nltk.download("punkt")
text = "Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe."
sents = nltk.tokenize.sent_tokenize(text, "english") # don't use dlt.lang.ENGLISH
" ".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH))
```
### Batch size during translation
It's possible to set a batch size (i.e. the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not:
```python
# ...
mt = dlt.TranslationModel()
mt.translate(text, source, target, batch_size=32, verbose=True)
```
If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into "chunks". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA memory error. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized.
### `dlt.utils` module
An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available:
```python
print(dlt.utils.available_languages('mbart50')) # All languages that you can use
print(dlt.utils.available_codes('m2m100')) # Code corresponding to each language accepted
print(dlt.utils.get_lang_code_map('nllb200')) # Dictionary of lang -> code
```
### Offline usage
Unlike the Google translate or MSFT Translator APIs, this library can be fully used offline. However, you will need to first download the packages and models, and move them to your offline environment to be installed and loaded inside a venv.
First, run in your terminal:
```bash
mkdir dlt
cd dlt
mkdir libraries
pip download -d libraries/ dl-translate
```
Once all the required packages are downloaded, you will need to use huggingface hub to download the files. Install it with `pip install huggingface-hub`. Then, run inside Python:
```python
import shutil
import huggingface_hub as hub
dirname = hub.snapshot_download("facebook/m2m100_418M")
shutil.copytree(dirname, "cached_model_m2m100") # Copy to a permanent folder
```
Now, move everything in the `dlt` directory to your offline environment. Create a virtual environment and run the following in terminal:
```bash
pip install --no-index --find-links libraries/ dl-translate
```
Now, run inside Python:
```python
import dl_translate as dlt
mt = dlt.TranslationModel("cached_model_m2m100", model_family="m2m100")
```
## Advanced
If you have knowledge of PyTorch and Huggingface Transformers, you can access advanced aspects of the library for more customization:
* **Saving and loading**: If you wish to accelerate the loading time of the translation model, you can use `save_obj` and reload it later with `load_obj`. This method is only recommended if you are familiar with `huggingface` and `torch`; please read the docs for more information.
* **Interacting with underlying model and tokenizer**: When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer with `model_options` and `tokenizer_options` respectively. You can also access the underlying `transformers` with `mt.get_transformers_model()`.
* **Keyword arguments for the `generate()` method**: When running `mt.translate`, you can also give `generation_options` that is passed to the `generate()` method of the underlying transformer model.
For more information, please visit the [advanced section of the user guide](https://xhluca.github.io/dl-translate/#advanced).
## Acknowledgement
`dl-translate` is built on top of Huggingface's implementation of three models created by Facebook AI Research.
1. The multilingual BART finetuned on many-to-many translation of over 50 languages, which is [documented here](https://huggingface.co/transformers/master/model_doc/mbart.html). The original paper was written by Tang et. al from Facebook AI Research; you can [find it here](https://arxiv.org/pdf/2008.00401.pdf) and cite it using the following:
```
@article{tang2020multilingual,
title={Multilingual translation with extensible multilingual pretraining and finetuning},
author={Tang, Yuqing and Tran, Chau and Li, Xian and Chen, Peng-Jen and Goyal, Naman and Chaudhary, Vishrav and Gu, Jiatao and Fan, Angela},
journal={arXiv preprint arXiv:2008.00401},
year={2020}
}
```
2. The transformer model published in [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Fan et. al, which supports over 100 languages. You can cite it here:
```
@misc{fan2020englishcentric,
title={Beyond English-Centric Multilingual Machine Translation},
author={Angela Fan and Shruti Bhosale and Holger Schwenk and Zhiyi Ma and Ahmed El-Kishky and Siddharth Goyal and Mandeep Baines and Onur Celebi and Guillaume Wenzek and Vishrav Chaudhary and Naman Goyal and Tom Birch and Vitaliy Liptchinsky and Sergey Edunov and Edouard Grave and Michael Auli and Armand Joulin},
year={2020},
eprint={2010.11125},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
3. The [no language left behind](https://arxiv.org/abs/2207.04672) model, which extends NMT to 200+ languages. You can cite it here:
```
@misc{nllbteam2022language,
title={No Language Left Behind: Scaling Human-Centered Machine Translation},
author={NLLB Team and Marta R. Costa-jussà and James Cross and Onur Çelebi and Maha Elbayad and Kenneth Heafield and Kevin Heffernan and Elahe Kalbassi and Janice Lam and Daniel Licht and Jean Maillard and Anna Sun and Skyler Wang and Guillaume Wenzek and Al Youngblood and Bapi Akula and Loic Barrault and Gabriel Mejia Gonzalez and Prangthip Hansanti and John Hoffman and Semarley Jarrett and Kaushik Ram Sadagopan and Dirk Rowe and Shannon Spruit and Chau Tran and Pierre Andrews and Necip Fazil Ayan and Shruti Bhosale and Sergey Edunov and Angela Fan and Cynthia Gao and Vedanuj Goswami and Francisco Guzmán and Philipp Koehn and Alexandre Mourachko and Christophe Ropers and Safiyyah Saleem and Holger Schwenk and Jeff Wang},
year={2022},
eprint={2207.04672},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
`dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example:
```python
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# translate Hindi to French
tokenizer.src_lang = "hi_IN"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."
# translate Arabic to English
tokenizer.src_lang = "ar_AR"
encoded_ar = tokenizer(article_ar, return_tensors="pt")
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "The Secretary-General of the United Nations says there is no military solution in Syria."
```
With `dlt`, you can run:
```python
import dl_translate as dlt
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
mt = dlt.TranslationModel()
translated_fr = mt.translate(article_hi, source=dlt.lang.HINDI, target=dlt.lang.FRENCH)
translated_en = mt.translate(article_ar, source=dlt.lang.ARABIC, target=dlt.lang.ENGLISH)
```
Notice you don't have to think about tokenizers, condition generation, pretrained models, and regional codes; you can just tell the model what to translate!
If you are experienced with `huggingface`'s ecosystem, then you should be familiar enough with the example above that you wouldn't need this library. However, if you've never heard of huggingface or mBART, then I hope using this library will give you enough motivation to [learn more about them](https://github.com/huggingface/transformers) :)
================================================
FILE: demos/colab_demo.ipynb
================================================
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.9",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"colab": {
"name": "dl-translate demo.ipynb",
"provenance": [],
"collapsed_sections": []
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"9695e0e8562c4104b8e28a25bf05991e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_7e7d388cb3ea475098dcca168cba2635",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_7424952a9ad34bc9a0094ed6df2881ab",
"IPY_MODEL_bd39d9cea7f949babdb9048b41f4055f"
]
}
},
"7e7d388cb3ea475098dcca168cba2635": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"7424952a9ad34bc9a0094ed6df2881ab": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_fee217673faf4424883be8bf33574584",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 3708092,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 3708092,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_2ad1b48139f14387bac9b6046b0d8d60"
}
},
"bd39d9cea7f949babdb9048b41f4055f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_79733c788f914b66a5803f1b71554c26",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 3.71M/3.71M [00:03<00:00, 1.22MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_cdc4f3fc6bf34edfa6ed14d1edc1c0da"
}
},
"fee217673faf4424883be8bf33574584": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"2ad1b48139f14387bac9b6046b0d8d60": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"79733c788f914b66a5803f1b71554c26": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"cdc4f3fc6bf34edfa6ed14d1edc1c0da": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"451f6d64f9d14cfe821ef82af4313223": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_d213c156b7f941de86f89704200898cc",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_d2ab70c5a6b24d3582e13bd7300e4b48",
"IPY_MODEL_d596d132f81a4577ab6eeff01c858f4a"
]
}
},
"d213c156b7f941de86f89704200898cc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"d2ab70c5a6b24d3582e13bd7300e4b48": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_e4ef14da423a4e37b23815228dff047c",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 2423393,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 2423393,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_f489df53fce4489c85248d0a0fc0347b"
}
},
"d596d132f81a4577ab6eeff01c858f4a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_347117e2b4084588af0120de4cba9992",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 2.42M/2.42M [00:00<00:00, 5.19MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_f7dde0ddeb954feeb6334f91957b3b4f"
}
},
"e4ef14da423a4e37b23815228dff047c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"f489df53fce4489c85248d0a0fc0347b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"347117e2b4084588af0120de4cba9992": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"f7dde0ddeb954feeb6334f91957b3b4f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"d85f7d324e534e7fb7665440581927af": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_36a4ac5796544fe5b50e802cd0b4b599",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_65782adbd1e943f39a3d282e4f217daf",
"IPY_MODEL_d78f6159d9fe4b0e840b2d1912a253d8"
]
}
},
"36a4ac5796544fe5b50e802cd0b4b599": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"65782adbd1e943f39a3d282e4f217daf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_db2021d70d584ef987b1580962ad919c",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 272,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 272,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_2cc9fc8c66aa4381ab3a6e742d0547dc"
}
},
"d78f6159d9fe4b0e840b2d1912a253d8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_256273b11d094044a36d0056498cc487",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 272/272 [00:01<00:00, 236B/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_361e88b5998d4df996113c7b6c55bd1e"
}
},
"db2021d70d584ef987b1580962ad919c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"2cc9fc8c66aa4381ab3a6e742d0547dc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"256273b11d094044a36d0056498cc487": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"361e88b5998d4df996113c7b6c55bd1e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"0236fdbddbcc4049b48713bb17ba4b74": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_878fc2e400354baf84a8a134d9038174",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_e2a4fd0970f94aeb8b841870d95662a8",
"IPY_MODEL_bc263f4b7ac94cfeaa7609341d5394e2"
]
}
},
"878fc2e400354baf84a8a134d9038174": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"e2a4fd0970f94aeb8b841870d95662a8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_266b8b8d1dda42afa1b5fdafb66af4ac",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 1140,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 1140,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_dc8604c9207b4a8ca364a817c6516378"
}
},
"bc263f4b7ac94cfeaa7609341d5394e2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_f1958b44d65243a0b4ca5ba2803d3598",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 1.14k/1.14k [00:00<00:00, 2.38kB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_d26580979feb41cdb423d151bd5ebd15"
}
},
"266b8b8d1dda42afa1b5fdafb66af4ac": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"dc8604c9207b4a8ca364a817c6516378": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"f1958b44d65243a0b4ca5ba2803d3598": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"d26580979feb41cdb423d151bd5ebd15": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"e7a58c85b79041eaaeeb9a9c3fa102aa": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_57b95194b9c0404d84ee4a73208e10b0",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_477c9d28d9624f448b75b8c38207fed5",
"IPY_MODEL_e2e7b8f3477c4fd486a0602377992efe"
]
}
},
"57b95194b9c0404d84ee4a73208e10b0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"477c9d28d9624f448b75b8c38207fed5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_72755dae20114e74a26fa5f556eed8c8",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 908,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 908,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_c961071096794f7580d0d37db02e8ff0"
}
},
"e2e7b8f3477c4fd486a0602377992efe": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_cf26e32a7c6b48caa2d7727365d33da3",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 908/908 [00:00<00:00, 25.8kB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_2edbb013778f4cb2a05510875c8b1d66"
}
},
"72755dae20114e74a26fa5f556eed8c8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"c961071096794f7580d0d37db02e8ff0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"cf26e32a7c6b48caa2d7727365d33da3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"2edbb013778f4cb2a05510875c8b1d66": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"b4bb0987873c46f9865e5308e2cd6cab": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_8fa51cc1231146d9be01452d4baedcab",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_5123c4c6e7524b4983ca0701394af3a0",
"IPY_MODEL_4d97b210440f431bb5440d5b3b49032b"
]
}
},
"8fa51cc1231146d9be01452d4baedcab": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"5123c4c6e7524b4983ca0701394af3a0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_d5a94e8542ac48babbc228b80e8c5c91",
"_dom_classes": [],
"description": "Downloading: 63%",
"_model_name": "FloatProgressModel",
"bar_style": "",
"max": 1935796948,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 1223096320,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_47f2a197eb0b47d583ec331a89a7657f"
}
},
"4d97b210440f431bb5440d5b3b49032b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_c1998ef598ff4359b0de7961665dba6d",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 1.22G/1.94G [00:22<00:13, 52.5MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_ba0a677b24604277981595701dfc4574"
}
},
"d5a94e8542ac48babbc228b80e8c5c91": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"47f2a197eb0b47d583ec331a89a7657f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"c1998ef598ff4359b0de7961665dba6d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"ba0a677b24604277981595701dfc4574": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "tx6xJha5YIiA"
},
"source": [
"# DL Translate\n",
"\n",
"*A deep learning-based translation library built on Huggingface `transformers` and Facebook's `mBART-Large`*\n",
"\n",
"💻 [GitHub Repository](https://github.com/xhlulu/dl-translate)\\\n",
"📚 [Documentation](https://git.io/dlt-docs) / [readthedocs](https://dl-translate.readthedocs.io)\\\n",
"🐍 [PyPi project](https://pypi.org/project/dl-translate/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bCjxVhyxYIiD"
},
"source": [
"## Quickstart\n",
"\n",
"Install the library with pip:"
]
},
{
"cell_type": "code",
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"trusted": true,
"_kg_hide-input": false,
"_kg_hide-output": true,
"id": "BI3mAoRnYIiF",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c5674616-a545-4a8a-9813-b127c1777efa"
},
"source": [
"!pip install -q dl-translate"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 1.2MB 9.2MB/s \n",
"\u001b[K |████████████████████████████████| 2.2MB 29.2MB/s \n",
"\u001b[K |████████████████████████████████| 870kB 50.4MB/s \n",
"\u001b[K |████████████████████████████████| 3.3MB 51.1MB/s \n",
"\u001b[?25h Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p1oeb4czYIiG"
},
"source": [
"To translate some text:"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300,
"referenced_widgets": [
"9695e0e8562c4104b8e28a25bf05991e",
"7e7d388cb3ea475098dcca168cba2635",
"7424952a9ad34bc9a0094ed6df2881ab",
"bd39d9cea7f949babdb9048b41f4055f",
"fee217673faf4424883be8bf33574584",
"2ad1b48139f14387bac9b6046b0d8d60",
"79733c788f914b66a5803f1b71554c26",
"cdc4f3fc6bf34edfa6ed14d1edc1c0da",
"451f6d64f9d14cfe821ef82af4313223",
"d213c156b7f941de86f89704200898cc",
"d2ab70c5a6b24d3582e13bd7300e4b48",
"d596d132f81a4577ab6eeff01c858f4a",
"e4ef14da423a4e37b23815228dff047c",
"f489df53fce4489c85248d0a0fc0347b",
"347117e2b4084588af0120de4cba9992",
"f7dde0ddeb954feeb6334f91957b3b4f",
"d85f7d324e534e7fb7665440581927af",
"36a4ac5796544fe5b50e802cd0b4b599",
"65782adbd1e943f39a3d282e4f217daf",
"d78f6159d9fe4b0e840b2d1912a253d8",
"db2021d70d584ef987b1580962ad919c",
"2cc9fc8c66aa4381ab3a6e742d0547dc",
"256273b11d094044a36d0056498cc487",
"361e88b5998d4df996113c7b6c55bd1e",
"0236fdbddbcc4049b48713bb17ba4b74",
"878fc2e400354baf84a8a134d9038174",
"e2a4fd0970f94aeb8b841870d95662a8",
"bc263f4b7ac94cfeaa7609341d5394e2",
"266b8b8d1dda42afa1b5fdafb66af4ac",
"dc8604c9207b4a8ca364a817c6516378",
"f1958b44d65243a0b4ca5ba2803d3598",
"d26580979feb41cdb423d151bd5ebd15",
"e7a58c85b79041eaaeeb9a9c3fa102aa",
"57b95194b9c0404d84ee4a73208e10b0",
"477c9d28d9624f448b75b8c38207fed5",
"e2e7b8f3477c4fd486a0602377992efe",
"72755dae20114e74a26fa5f556eed8c8",
"c961071096794f7580d0d37db02e8ff0",
"cf26e32a7c6b48caa2d7727365d33da3",
"2edbb013778f4cb2a05510875c8b1d66",
"b4bb0987873c46f9865e5308e2cd6cab",
"8fa51cc1231146d9be01452d4baedcab",
"5123c4c6e7524b4983ca0701394af3a0",
"4d97b210440f431bb5440d5b3b49032b",
"d5a94e8542ac48babbc228b80e8c5c91",
"47f2a197eb0b47d583ec331a89a7657f",
"c1998ef598ff4359b0de7961665dba6d",
"ba0a677b24604277981595701dfc4574"
]
},
"id": "qdefSjR_YIiG",
"outputId": "a1002eb7-cceb-45ee-dbb9-6a7329860af1"
},
"source": [
"import dl_translate as dlt\n",
"\n",
"mt = dlt.TranslationModel()\n",
"\n",
"text_hi = \"संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है\"\n",
"mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9695e0e8562c4104b8e28a25bf05991e",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3708092.0, style=ProgressStyle(descript…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "451f6d64f9d14cfe821ef82af4313223",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2423393.0, style=ProgressStyle(descript…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d85f7d324e534e7fb7665440581927af",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=272.0, style=ProgressStyle(description_…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0236fdbddbcc4049b48713bb17ba4b74",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1140.0, style=ProgressStyle(description…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7a58c85b79041eaaeeb9a9c3fa102aa",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=908.0, style=ProgressStyle(description_…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4bb0987873c46f9865e5308e2cd6cab",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1935796948.0, style=ProgressStyle(descr…"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DDQGpznwYIiH"
},
"source": [
"Above, you can see that `dlt.lang` contains variables representing each of the 50 available languages with auto-complete support. Alternatively, you can specify the language (e.g. \"Arabic\") or the language code (e.g. \"fr\" for French):"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "yC3LMjmNYIiI"
},
"source": [
"text_ar = \"الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.\"\n",
"mt.translate(text_ar, source=\"Arabic\", target=\"fr\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "erptbvbiYIiI"
},
"source": [
"If you want to verify whether a language is available, you can check it:"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"_kg_hide-output": false,
"id": "saHalYvsYIiJ"
},
"source": [
"print(mt.available_languages()) # All languages that you can use\n",
"print(mt.available_codes()) # Code corresponding to each language accepted\n",
"print(mt.get_lang_code_map()) # Dictionary of lang -> code"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"trusted": true,
"id": "lz3Rq5t0YIiJ"
},
"source": [
"## Usage\n",
"\n",
"### Selecting a device\n",
"\n",
"When you load the model, you can specify the device:\n",
"```python\n",
"mt = dlt.TranslationModel(device=\"auto\")\n",
"```\n",
"\n",
"By default, the value will be `device=\"auto\"`, which means it will use a GPU if possible. You can also explicitly set `device=\"cpu\"` or `device=\"gpu\"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). __In general, it is recommended to use a GPU if you want a reasonable processing time.__\n",
"\n",
"Let's check what we originally loaded:"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "fKAlEmzbYIiJ"
},
"source": [
"mt.device"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nm7hVlSZYIiL"
},
"source": [
"### Loading from a path\n",
"\n",
"By default, `dlt.TranslationModel` will download the model from the [huggingface repo](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt) and cache it. However, you are free to load from a path:\n",
"```python\n",
"mt = dlt.TranslationModel(\"/path/to/your/model/directory/\", model_family=\"mbart50\")\n",
"```\n",
"Make sure that your tokenizer is also stored in the same directory if you use this approach.\n",
"\n",
"\n",
"### Using a different model\n",
"\n",
"You can also choose another model that has [a similar format](https://huggingface.co/models?filter=mbart-50). In those cases, it's preferable to specify the model family:\n",
"```python\n",
"mt = dlt.TranslationModel(\"facebook/mbart-large-50-one-to-many-mmt\")\n",
"mt = dlt.TranslationModel(\"facebook/m2m100_1.2B\", model_family=\"m2m100\")\n",
"```\n",
"Note that the available languages will change if you do this, so you will not be able to leverage `dlt.lang` or `dlt.utils`.\n",
"\n",
"\n",
"### Breaking down into sentences\n",
"\n",
"It is not recommended to use extremely long texts as it takes more time to process. Instead, you can try to break them down into sentences with the help of `nltk`. First install the library with `pip install nltk`, then run:"
]
},
{
"cell_type": "code",
"metadata": {
"_kg_hide-output": true,
"trusted": true,
"id": "XkjrydidYIiL"
},
"source": [
"import nltk\n",
"nltk.download(\"punkt\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "j-cyjxQCYIiL"
},
"source": [
"text = \"Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe.\"\n",
"sents = nltk.tokenize.sent_tokenize(text, \"english\") # don't use dlt.lang.ENGLISH\n",
"\" \".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "u6y8-SthYIiM"
},
"source": [
"### Setting a `batch_size` and verbosity when calling `dlt.TranslationModel.translate`\n",
"\n",
"It's possible to set a batch size (i.e. the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not:"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "fcxUFmjAYIiM"
},
"source": [
"mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH, batch_size=32, verbose=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iS0bfSh_YIiM"
},
"source": [
"If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into \"chunks\". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA memory error. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized.\n",
"\n",
"\n",
"### `dlt.utils` module\n",
"\n",
"An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available:\n"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "U7iS_wKTYIiM"
},
"source": [
"print(dlt.utils.available_languages('mbart50')) # All languages that you can use\n",
"print(dlt.utils.available_codes('mbart50')) # Code corresponding to each language accepted\n",
"print(dlt.utils.get_lang_code_map('mbart50')) # Dictionary of lang -> code"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFN5AplfYIiM"
},
"source": [
"## Advanced\n",
"\n",
"The following section assumes you have knowledge of PyTorch and Huggingface Transformers.\n",
"\n",
"### Saving and loading\n",
"\n",
"If you wish to accelerate the loading time of the translation model, you can use `save_obj`:\n"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "UaDQVFlzYIiN"
},
"source": [
"mt.save_obj(\"saved_model\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "c_uMZ9exYIiN"
},
"source": [
"\n",
"Then later you can reload it with `load_obj`:"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "SEkmXUDaYIiN"
},
"source": [
"%%time\n",
"mt = dlt.TranslationModel.load_obj('saved_model')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j4G9tRL8YIiO"
},
"source": [
"\n",
"**Warning:** Only use this if you are certain the torch module saved in `saved_model/weights.pt` can be correctly loaded. Indeed, it is possible that the `huggingface`, `torch` or some other dependencies change between when you called `save_obj` and `load_obj`, and that might break your code. Thus, it is recommended to only run `load_obj` in the same environment/session as `save_obj`. **Note this method might be deprecated in the future once there's no speed benefit in loading this way.**\n",
"\n",
"\n",
"### Interacting with underlying model and tokenizer\n",
"\n",
"When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer (which will respectively be passed to `MBartForConditionalGeneration.from_pretrained` and `MBart50TokenizerFast.from_pretrained`):\n",
"\n",
"```python\n",
"mt = dlt.TranslationModel(\n",
" model_options=dict(\n",
" state_dict=...,\n",
" cache_dir=...,\n",
" ...\n",
" ),\n",
" tokenizer_options=dict(\n",
" tokenizer_file=...,\n",
" eos_token=...,\n",
" ...\n",
" )\n",
")\n",
"```\n",
"\n",
"You can also access the underlying `transformers` model and `tokenizer`:"
]
},
{
"cell_type": "code",
"metadata": {
"_kg_hide-output": true,
"trusted": true,
"id": "UX1lyl_uYIiO"
},
"source": [
"bart = mt.get_transformers_model()\n",
"tokenizer = mt.get_tokenizer()\n",
"\n",
"print(tokenizer)\n",
"print(bart)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QQPqtEeuYIiO"
},
"source": [
"See the [huggingface docs](https://huggingface.co/transformers/master/model_doc/mbart.html) for more information.\n",
"\n",
"\n",
"### `bart_model.generate()` keyword arguments\n",
"\n",
"When running `mt.translate`, you can also give a `generation_options` dictionary that is passed as keyword arguments to the underlying `bart_model.generate()` method:"
]
},
{
"cell_type": "code",
"metadata": {
"trusted": true,
"id": "XuTbvJBWYIiP"
},
"source": [
"mt.translate(\n",
" sents,\n",
" source=dlt.lang.ENGLISH,\n",
" target=dlt.lang.SPANISH,\n",
" generation_options=dict(num_beams=5, max_length=128)\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QvwPa2b1YIiP"
},
"source": [
"Learn more in the [huggingface docs](https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate).\n",
"\n",
"\n",
"## Acknowledgement\n",
"\n",
"`dl-translate` is built on top of Huggingface's implementation of two models created by Facebook AI Research.\n",
"\n",
"1. The multilingual BART finetuned on many-to-many translation of over 50 languages, which is [documented here](https://huggingface.co/transformers/master/model_doc/mbart.html) The original paper was written by Tang et. al from Facebook AI Research; you can [find it here](https://arxiv.org/pdf/2008.00401.pdf) and cite it using the following:\n",
" ```\n",
" @article{tang2020multilingual,\n",
" title={Multilingual translation with extensible multilingual pretraining and finetuning},\n",
" author={Tang, Yuqing and Tran, Chau and Li, Xian and Chen, Peng-Jen and Goyal, Naman and Chaudhary, Vishrav and Gu, Jiatao and Fan, Angela},\n",
" journal={arXiv preprint arXiv:2008.00401},\n",
" year={2020}\n",
" }\n",
" ```\n",
"2. The transformer model published in [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Fan et. al, which supports over 100 languages. You can cite it here:\n",
" ```\n",
" @misc{fan2020englishcentric,\n",
" title={Beyond English-Centric Multilingual Machine Translation}, \n",
" author={Angela Fan and Shruti Bhosale and Holger Schwenk and Zhiyi Ma and Ahmed El-Kishky and Siddharth Goyal and Mandeep Baines and Onur Celebi and Guillaume Wenzek and Vishrav Chaudhary and Naman Goyal and Tom Birch and Vitaliy Liptchinsky and Sergey Edunov and Edouard Grave and Michael Auli and Armand Joulin},\n",
" year={2020},\n",
" eprint={2010.11125},\n",
" archivePrefix={arXiv},\n",
" primaryClass={cs.CL}\n",
" }\n",
" ```\n",
"\n",
"`dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example:\n",
"```python\n",
"from transformers import MBartForConditionalGeneration, MBart50TokenizerFast\n",
"\n",
"article_hi = \"संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है\"\n",
"article_ar = \"الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.\"\n",
"\n",
"model = MBartForConditionalGeneration.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n",
"tokenizer = MBart50TokenizerFast.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n",
"\n",
"# translate Hindi to French\n",
"tokenizer.src_lang = \"hi_IN\"\n",
"encoded_hi = tokenizer(article_hi, return_tensors=\"pt\")\n",
"generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id[\"fr_XX\"])\n",
"tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
"# => \"Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria.\"\n",
"\n",
"# translate Arabic to English\n",
"tokenizer.src_lang = \"ar_AR\"\n",
"encoded_ar = tokenizer(article_ar, return_tensors=\"pt\")\n",
"generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id[\"en_XX\"])\n",
"tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
"# => \"The Secretary-General of the United Nations says there is no military solution in Syria.\"\n",
"```\n",
"\n",
"With `dlt`, you can run:\n",
"```python\n",
"import dl_translate as dlt\n",
"\n",
"article_hi = \"संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है\"\n",
"article_ar = \"الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.\"\n",
"\n",
"mt = dlt.TranslationModel()\n",
"translated_fr = mt.translate(article_hi, source=dlt.lang.HINDI, target=dlt.lang.FRENCH)\n",
"translated_en = mt.translate(article_ar, source=dlt.lang.ARABIC, target=dlt.lang.ENGLISH)\n",
"```\n",
"\n",
"Notice you don't have to think about tokenizers, condition generation, pretrained models, and regional codes; you can just tell the model what to translate!\n",
"\n",
"If you are experienced with `huggingface`'s ecosystem, then you should be familiar enough with the example above that you wouldn't need this library. However, if you've never heard of huggingface or mBART, then I hope using this library will give you enough motivation to [learn more about them](https://github.com/huggingface/transformers) :)"
]
}
]
}
================================================
FILE: demos/nllb_demo.ipynb
================================================
{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2023-07-18T05:15:13.999614Z","iopub.status.busy":"2023-07-18T05:15:13.999228Z","iopub.status.idle":"2023-07-18T05:15:31.978108Z","shell.execute_reply":"2023-07-18T05:15:31.976681Z","shell.execute_reply.started":"2023-07-18T05:15:13.999573Z"},"trusted":true},"outputs":[],"source":["!pip install dl-translate==0.3.* -q"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:15:31.982361Z","iopub.status.busy":"2023-07-18T05:15:31.981992Z","iopub.status.idle":"2023-07-18T05:16:23.731908Z","shell.execute_reply":"2023-07-18T05:16:23.730776Z","shell.execute_reply.started":"2023-07-18T05:15:31.982327Z"},"trusted":true},"outputs":[],"source":["import dl_translate as dlt\n","\n","mt = dlt.TranslationModel(\"nllb200\")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:23.734336Z","iopub.status.busy":"2023-07-18T05:16:23.733295Z","iopub.status.idle":"2023-07-18T05:16:28.025038Z","shell.execute_reply":"2023-07-18T05:16:28.023933Z","shell.execute_reply.started":"2023-07-18T05:16:23.734293Z"},"trusted":true},"outputs":[],"source":["text = \"Meta AI has built a single AI model capable of translating across 200 different languages with state-of-the-art quality.\"\n","\n","# The new translation is much faster than before\n","%time print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, 
target=dlt.lang.nllb200.FRENCH))"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.028919Z","iopub.status.busy":"2023-07-18T05:16:28.028286Z","iopub.status.idle":"2023-07-18T05:16:28.717521Z","shell.execute_reply":"2023-07-18T05:16:28.716343Z","shell.execute_reply.started":"2023-07-18T05:16:28.028882Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["मेटाएआई एकमेव एआई मॉडलं निर्मितवान्, यत् 200 भिन्नभाषायां अवधीतमतमतमगुणैः अनुवादं कर्तुं समर्थः अस्ति।\n"]}],"source":["# Sanskrit is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.SANSKRIT))"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.719596Z","iopub.status.busy":"2023-07-18T05:16:28.719227Z","iopub.status.idle":"2023-07-18T05:16:29.443696Z","shell.execute_reply":"2023-07-18T05:16:29.442668Z","shell.execute_reply.started":"2023-07-18T05:16:28.719560Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Meta AI hà custruitu un solu mudellu d'AI capace di tradurisce in 200 lingue sfarenti cù qualità di u statu di l'arte.\n"]}],"source":["# Sicilian is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.SICILIAN))"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:29.447147Z","iopub.status.busy":"2023-07-18T05:16:29.445331Z","iopub.status.idle":"2023-07-18T05:16:30.145637Z","shell.execute_reply":"2023-07-18T05:16:30.144623Z","shell.execute_reply.started":"2023-07-18T05:16:29.447108Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["基於Meta AI 建立咗一個 AI 模型 可以用最先端嘅質量翻譯到 200 個唔同語言\n"]}],"source":["# Yue Chinese is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, 
target=dlt.lang.nllb200.YUE_CHINESE))"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}
================================================
FILE: dl_translate/__init__.py
================================================
from . import lang
from . import utils
from ._translation_model import TranslationModel
================================================
FILE: dl_translate/_pairs.py
================================================
# Auto-generated. Do not modify, use scripts/generate_langs.py instead.
#
# Each constant below is a tuple of (language name, language code) pairs for
# one model family. Pairs (not a dict keyed by name) are used because several
# names can alias the same code (e.g. "Catalan" and "Valencian" both -> "ca").

# Languages accepted by the m2m100 family (short ISO-639-style codes).
_PAIRS_M2M100 = (
    ("Afrikaans", "af"),
    ("Amharic", "am"),
    ("Arabic", "ar"),
    ("Asturian", "ast"),
    ("Azerbaijani", "az"),
    ("Bashkir", "ba"),
    ("Belarusian", "be"),
    ("Bulgarian", "bg"),
    ("Bengali", "bn"),
    ("Breton", "br"),
    ("Bosnian", "bs"),
    ("Catalan", "ca"),
    ("Valencian", "ca"),
    ("Cebuano", "ceb"),
    ("Czech", "cs"),
    ("Welsh", "cy"),
    ("Danish", "da"),
    ("German", "de"),
    ("Greek", "el"),
    ("English", "en"),
    ("Spanish", "es"),
    ("Estonian", "et"),
    ("Persian", "fa"),
    ("Fulah", "ff"),
    ("Finnish", "fi"),
    ("French", "fr"),
    ("Western Frisian", "fy"),
    ("Irish", "ga"),
    ("Gaelic", "gd"),
    ("Scottish Gaelic", "gd"),
    ("Galician", "gl"),
    ("Gujarati", "gu"),
    ("Hausa", "ha"),
    ("Hebrew", "he"),
    ("Hindi", "hi"),
    ("Croatian", "hr"),
    ("Haitian", "ht"),
    ("Haitian Creole", "ht"),
    ("Hungarian", "hu"),
    ("Armenian", "hy"),
    ("Indonesian", "id"),
    ("Igbo", "ig"),
    ("Iloko", "ilo"),
    ("Icelandic", "is"),
    ("Italian", "it"),
    ("Japanese", "ja"),
    ("Javanese", "jv"),
    ("Georgian", "ka"),
    ("Kazakh", "kk"),
    ("Khmer", "km"),
    ("Central Khmer", "km"),
    ("Kannada", "kn"),
    ("Korean", "ko"),
    ("Luxembourgish", "lb"),
    ("Letzeburgesch", "lb"),
    ("Ganda", "lg"),
    ("Lingala", "ln"),
    ("Lao", "lo"),
    ("Lithuanian", "lt"),
    ("Latvian", "lv"),
    ("Malagasy", "mg"),
    ("Macedonian", "mk"),
    ("Malayalam", "ml"),
    ("Mongolian", "mn"),
    ("Marathi", "mr"),
    ("Malay", "ms"),
    ("Burmese", "my"),
    ("Nepali", "ne"),
    ("Dutch", "nl"),
    ("Flemish", "nl"),
    ("Norwegian", "no"),
    ("Northern Sotho", "ns"),
    ("Occitan", "oc"),
    ("Oriya", "or"),
    ("Panjabi", "pa"),
    ("Punjabi", "pa"),
    ("Polish", "pl"),
    ("Pushto", "ps"),
    ("Pashto", "ps"),
    ("Portuguese", "pt"),
    ("Romanian", "ro"),
    ("Moldavian", "ro"),
    ("Moldovan", "ro"),
    ("Russian", "ru"),
    ("Sindhi", "sd"),
    ("Sinhala", "si"),
    ("Sinhalese", "si"),
    ("Slovak", "sk"),
    ("Slovenian", "sl"),
    ("Somali", "so"),
    ("Albanian", "sq"),
    ("Serbian", "sr"),
    ("Swati", "ss"),
    ("Sundanese", "su"),
    ("Swedish", "sv"),
    ("Swahili", "sw"),
    ("Tamil", "ta"),
    ("Thai", "th"),
    ("Tagalog", "tl"),
    ("Tswana", "tn"),
    ("Turkish", "tr"),
    ("Ukrainian", "uk"),
    ("Urdu", "ur"),
    ("Uzbek", "uz"),
    ("Vietnamese", "vi"),
    ("Wolof", "wo"),
    ("Xhosa", "xh"),
    ("Yiddish", "yi"),
    ("Yoruba", "yo"),
    ("Chinese", "zh"),
    ("Zulu", "zu"),
)

# Languages accepted by the mbart50 family (codes carry a region suffix,
# e.g. "fr_XX", "zh_CN", matching the mBART-50 tokenizer's lang codes).
_PAIRS_MBART50 = (
    ("Arabic", "ar_AR"),
    ("Czech", "cs_CZ"),
    ("German", "de_DE"),
    ("English", "en_XX"),
    ("Spanish", "es_XX"),
    ("Estonian", "et_EE"),
    ("Finnish", "fi_FI"),
    ("French", "fr_XX"),
    ("Gujarati", "gu_IN"),
    ("Hindi", "hi_IN"),
    ("Italian", "it_IT"),
    ("Japanese", "ja_XX"),
    ("Kazakh", "kk_KZ"),
    ("Korean", "ko_KR"),
    ("Lithuanian", "lt_LT"),
    ("Latvian", "lv_LV"),
    ("Burmese", "my_MM"),
    ("Nepali", "ne_NP"),
    ("Dutch", "nl_XX"),
    ("Romanian", "ro_RO"),
    ("Russian", "ru_RU"),
    ("Sinhala", "si_LK"),
    ("Turkish", "tr_TR"),
    ("Vietnamese", "vi_VN"),
    ("Chinese", "zh_CN"),
    ("Afrikaans", "af_ZA"),
    ("Azerbaijani", "az_AZ"),
    ("Bengali", "bn_IN"),
    ("Persian", "fa_IR"),
    ("Hebrew", "he_IL"),
    ("Croatian", "hr_HR"),
    ("Indonesian", "id_ID"),
    ("Georgian", "ka_GE"),
    ("Khmer", "km_KH"),
    ("Macedonian", "mk_MK"),
    ("Malayalam", "ml_IN"),
    ("Mongolian", "mn_MN"),
    ("Marathi", "mr_IN"),
    ("Polish", "pl_PL"),
    ("Pashto", "ps_AF"),
    ("Portuguese", "pt_XX"),
    ("Swedish", "sv_SE"),
    ("Swahili", "sw_KE"),
    ("Tamil", "ta_IN"),
    ("Telugu", "te_IN"),
    ("Thai", "th_TH"),
    ("Tagalog", "tl_XX"),
    ("Ukrainian", "uk_UA"),
    ("Urdu", "ur_PK"),
    ("Xhosa", "xh_ZA"),
    ("Galician", "gl_ES"),
    ("Slovene", "sl_SI"),
)

# Languages accepted by the nllb200 family (FLORES-200 codes: ISO-639-3
# language id plus a script tag, e.g. "eng_Latn", "zho_Hans").
_PAIRS_NLLB200 = (
    ("Acehnese (Arabic script)", "ace_Arab"),
    ("Acehnese (Latin script)", "ace_Latn"),
    ("Mesopotamian Arabic", "acm_Arab"),
    ("Ta'izzi-Adeni Arabic", "acq_Arab"),
    ("Tunisian Arabic", "aeb_Arab"),
    ("Afrikaans", "afr_Latn"),
    ("South Levantine Arabic", "ajp_Arab"),
    ("Akan", "aka_Latn"),
    ("Amharic", "amh_Ethi"),
    ("North Levantine Arabic", "apc_Arab"),
    ("Modern Standard Arabic", "arb_Arab"),
    ("Modern Standard Arabic (Romanized)", "arb_Latn"),
    ("Najdi Arabic", "ars_Arab"),
    ("Moroccan Arabic", "ary_Arab"),
    ("Egyptian Arabic", "arz_Arab"),
    ("Assamese", "asm_Beng"),
    ("Asturian", "ast_Latn"),
    ("Awadhi", "awa_Deva"),
    ("Central Aymara", "ayr_Latn"),
    ("South Azerbaijani", "azb_Arab"),
    ("North Azerbaijani", "azj_Latn"),
    ("Bashkir", "bak_Cyrl"),
    ("Bambara", "bam_Latn"),
    ("Balinese", "ban_Latn"),
    ("Belarusian", "bel_Cyrl"),
    ("Bemba", "bem_Latn"),
    ("Bengali", "ben_Beng"),
    ("Bhojpuri", "bho_Deva"),
    ("Banjar (Arabic script)", "bjn_Arab"),
    ("Banjar (Latin script)", "bjn_Latn"),
    ("Standard Tibetan", "bod_Tibt"),
    ("Bosnian", "bos_Latn"),
    ("Buginese", "bug_Latn"),
    ("Bulgarian", "bul_Cyrl"),
    ("Catalan", "cat_Latn"),
    ("Cebuano", "ceb_Latn"),
    ("Czech", "ces_Latn"),
    ("Chokwe", "cjk_Latn"),
    ("Central Kurdish", "ckb_Arab"),
    ("Crimean Tatar", "crh_Latn"),
    ("Welsh", "cym_Latn"),
    ("Danish", "dan_Latn"),
    ("German", "deu_Latn"),
    ("Southwestern Dinka", "dik_Latn"),
    ("Dyula", "dyu_Latn"),
    ("Dzongkha", "dzo_Tibt"),
    ("Greek", "ell_Grek"),
    ("English", "eng_Latn"),
    ("Esperanto", "epo_Latn"),
    ("Estonian", "est_Latn"),
    ("Basque", "eus_Latn"),
    ("Ewe", "ewe_Latn"),
    ("Faroese", "fao_Latn"),
    ("Fijian", "fij_Latn"),
    ("Finnish", "fin_Latn"),
    ("Fon", "fon_Latn"),
    ("French", "fra_Latn"),
    ("Friulian", "fur_Latn"),
    ("Nigerian Fulfulde", "fuv_Latn"),
    ("Scottish Gaelic", "gla_Latn"),
    ("Irish", "gle_Latn"),
    ("Galician", "glg_Latn"),
    ("Guarani", "grn_Latn"),
    ("Gujarati", "guj_Gujr"),
    ("Haitian Creole", "hat_Latn"),
    ("Hausa", "hau_Latn"),
    ("Hebrew", "heb_Hebr"),
    ("Hindi", "hin_Deva"),
    ("Chhattisgarhi", "hne_Deva"),
    ("Croatian", "hrv_Latn"),
    ("Hungarian", "hun_Latn"),
    ("Armenian", "hye_Armn"),
    ("Igbo", "ibo_Latn"),
    ("Ilocano", "ilo_Latn"),
    ("Indonesian", "ind_Latn"),
    ("Icelandic", "isl_Latn"),
    ("Italian", "ita_Latn"),
    ("Javanese", "jav_Latn"),
    ("Japanese", "jpn_Jpan"),
    ("Kabyle", "kab_Latn"),
    ("Jingpho", "kac_Latn"),
    ("Kamba", "kam_Latn"),
    ("Kannada", "kan_Knda"),
    ("Kashmiri (Arabic script)", "kas_Arab"),
    ("Kashmiri (Devanagari script)", "kas_Deva"),
    ("Georgian", "kat_Geor"),
    ("Central Kanuri (Arabic script)", "knc_Arab"),
    ("Central Kanuri (Latin script)", "knc_Latn"),
    ("Kazakh", "kaz_Cyrl"),
    ("Kabiyè", "kbp_Latn"),
    ("Kabuverdianu", "kea_Latn"),
    ("Khmer", "khm_Khmr"),
    ("Kikuyu", "kik_Latn"),
    ("Kinyarwanda", "kin_Latn"),
    ("Kyrgyz", "kir_Cyrl"),
    ("Kimbundu", "kmb_Latn"),
    ("Northern Kurdish", "kmr_Latn"),
    ("Kikongo", "kon_Latn"),
    ("Korean", "kor_Hang"),
    ("Lao", "lao_Laoo"),
    ("Ligurian", "lij_Latn"),
    ("Limburgish", "lim_Latn"),
    ("Lingala", "lin_Latn"),
    ("Lithuanian", "lit_Latn"),
    ("Lombard", "lmo_Latn"),
    ("Latgalian", "ltg_Latn"),
    ("Luxembourgish", "ltz_Latn"),
    ("Luba-Kasai", "lua_Latn"),
    ("Ganda", "lug_Latn"),
    ("Luo", "luo_Latn"),
    ("Mizo", "lus_Latn"),
    ("Standard Latvian", "lvs_Latn"),
    ("Magahi", "mag_Deva"),
    ("Maithili", "mai_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Marathi", "mar_Deva"),
    ("Minangkabau (Arabic script)", "min_Arab"),
    ("Minangkabau (Latin script)", "min_Latn"),
    ("Macedonian", "mkd_Cyrl"),
    ("Plateau Malagasy", "plt_Latn"),
    ("Maltese", "mlt_Latn"),
    ("Meitei (Bengali script)", "mni_Beng"),
    ("Halh Mongolian", "khk_Cyrl"),
    ("Mossi", "mos_Latn"),
    ("Maori", "mri_Latn"),
    ("Burmese", "mya_Mymr"),
    ("Dutch", "nld_Latn"),
    ("Norwegian Nynorsk", "nno_Latn"),
    ("Norwegian Bokmål", "nob_Latn"),
    ("Nepali", "npi_Deva"),
    ("Northern Sotho", "nso_Latn"),
    ("Nuer", "nus_Latn"),
    ("Nyanja", "nya_Latn"),
    ("Occitan", "oci_Latn"),
    ("West Central Oromo", "gaz_Latn"),
    ("Odia", "ory_Orya"),
    ("Pangasinan", "pag_Latn"),
    ("Eastern Panjabi", "pan_Guru"),
    ("Papiamento", "pap_Latn"),
    ("Western Persian", "pes_Arab"),
    ("Polish", "pol_Latn"),
    ("Portuguese", "por_Latn"),
    ("Dari", "prs_Arab"),
    ("Southern Pashto", "pbt_Arab"),
    ("Ayacucho Quechua", "quy_Latn"),
    ("Romanian", "ron_Latn"),
    ("Rundi", "run_Latn"),
    ("Russian", "rus_Cyrl"),
    ("Sango", "sag_Latn"),
    ("Sanskrit", "san_Deva"),
    ("Santali", "sat_Olck"),
    ("Sicilian", "scn_Latn"),
    ("Shan", "shn_Mymr"),
    ("Sinhala", "sin_Sinh"),
    ("Slovak", "slk_Latn"),
    ("Slovenian", "slv_Latn"),
    ("Samoan", "smo_Latn"),
    ("Shona", "sna_Latn"),
    ("Sindhi", "snd_Arab"),
    ("Somali", "som_Latn"),
    ("Southern Sotho", "sot_Latn"),
    ("Spanish", "spa_Latn"),
    ("Tosk Albanian", "als_Latn"),
    ("Sardinian", "srd_Latn"),
    ("Serbian", "srp_Cyrl"),
    ("Swati", "ssw_Latn"),
    ("Sundanese", "sun_Latn"),
    ("Swedish", "swe_Latn"),
    ("Swahili", "swh_Latn"),
    ("Silesian", "szl_Latn"),
    ("Tamil", "tam_Taml"),
    ("Tatar", "tat_Cyrl"),
    ("Telugu", "tel_Telu"),
    ("Tajik", "tgk_Cyrl"),
    ("Tagalog", "tgl_Latn"),
    ("Thai", "tha_Thai"),
    ("Tigrinya", "tir_Ethi"),
    ("Tamasheq (Latin script)", "taq_Latn"),
    ("Tamasheq (Tifinagh script)", "taq_Tfng"),
    ("Tok Pisin", "tpi_Latn"),
    ("Tswana", "tsn_Latn"),
    ("Tsonga", "tso_Latn"),
    ("Turkmen", "tuk_Latn"),
    ("Tumbuka", "tum_Latn"),
    ("Turkish", "tur_Latn"),
    ("Twi", "twi_Latn"),
    ("Central Atlas Tamazight", "tzm_Tfng"),
    ("Uyghur", "uig_Arab"),
    ("Ukrainian", "ukr_Cyrl"),
    ("Umbundu", "umb_Latn"),
    ("Urdu", "urd_Arab"),
    ("Northern Uzbek", "uzn_Latn"),
    ("Venetian", "vec_Latn"),
    ("Vietnamese", "vie_Latn"),
    ("Waray", "war_Latn"),
    ("Wolof", "wol_Latn"),
    ("Xhosa", "xho_Latn"),
    ("Eastern Yiddish", "ydd_Hebr"),
    ("Yoruba", "yor_Latn"),
    ("Yue Chinese", "yue_Hant"),
    ("Chinese (Simplified)", "zho_Hans"),
    ("Chinese (Traditional)", "zho_Hant"),
    ("Standard Malay", "zsm_Latn"),
    ("Zulu", "zul_Latn"),
)
================================================
FILE: dl_translate/_translation_model.py
================================================
import os
import json
from typing import Union, List, Dict
import transformers
import torch
from tqdm.auto import tqdm
from . import utils
from .utils import _infer_model_family, _infer_model_or_path
def _select_device(device_selection):
selected = device_selection.lower()
if selected == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
elif selected == "cpu":
device = torch.device("cpu")
elif selected == "gpu":
device = torch.device("cuda")
else:
device = torch.device(selected)
return device
def _resolve_lang_codes(lang: str, name: str, model_family: str) -> str:
    """Turn a language name (or an already-valid code) into a model code.

    Lookup order: exact name, capitalized name, case-insensitive name;
    if none match, `lang` is assumed to already be a code. A ValueError
    is raised when the resulting code is unknown to `model_family`.
    `name` is only used to label the error message ("source"/"target").
    """
    lang_code_map = utils.get_lang_code_map(model_family)

    if lang in lang_code_map:
        code = lang_code_map[lang]
    elif lang.capitalize() in lang_code_map:
        code = lang_code_map[lang.capitalize()]
    else:
        # Last resort: fully case-insensitive match, else pass through.
        folded = {k.upper(): v for k, v in lang_code_map.items()}
        code = folded.get(lang.upper(), lang)

    if code not in utils.available_codes(model_family):
        raise ValueError(
            f'Your {name}="{code}" is not valid. Please run `print(mt.available_languages())` to see which languages are available. Make sure you are using the correct capital letters.'
        )
    return code
def _resolve_tokenizer(model_family):
    """Return the huggingface tokenizer class for a given model family.

    Raises ValueError when `model_family` is not one of the supported keys.
    """
    tokenizer_classes = {
        "mbart50": transformers.MBart50TokenizerFast,
        "m2m100": transformers.M2M100Tokenizer,
        "nllb200": transformers.AutoTokenizer,
    }
    if model_family not in tokenizer_classes:
        raise ValueError(
            f"{model_family} is not a valid value for model_family. Please choose model_family to be equal to one of the following values: {list(tokenizer_classes.keys())}"
        )
    return tokenizer_classes[model_family]
def _resolve_transformers_model(model_family):
    """Return the huggingface conditional-generation model class for a family.

    Raises ValueError when `model_family` is not one of the supported keys.
    """
    model_classes = {
        "mbart50": transformers.MBartForConditionalGeneration,
        "m2m100": transformers.M2M100ForConditionalGeneration,
        "nllb200": transformers.AutoModelForSeq2SeqLM,
    }
    if model_family not in model_classes:
        raise ValueError(
            f"{model_family} is not a valid value for model_family. Please choose model_family to be equal to one of the following values: {list(model_classes.keys())}"
        )
    return model_classes[model_family]
class TranslationModel:
def __init__(
self,
model_or_path: str = "m2m100",
tokenizer_path: str = None,
device: str = "auto",
model_family: str = None,
model_options: dict = None,
tokenizer_options: dict = None,
):
"""
*Instantiates a multilingual transformer model for translation.*
{{params}}
{{model_or_path}} The path or the name of the model. Equivalent to the first argument of `AutoModel.from_pretrained()`. You can also specify shorthands ("mbart50" and "m2m100").
{{tokenizer_path}} The path to the tokenizer. By default, it will be set to `model_or_path`.
{{device}} "cpu", "gpu" or "auto". If it's set to "auto", will try to select a GPU when available or else fall back to CPU.
{{model_family}} Either "mbart50" or "m2m100". By default, it will be inferred based on `model_or_path`. Needs to be explicitly set if `model_or_path` is a path.
{{model_options}} The keyword arguments passed to the model, which is a transformer for conditional generation.
{{tokenizer_options}} The keyword arguments passed to the model's tokenizer.
"""
model_or_path = _infer_model_or_path(model_or_path)
self.model_or_path = model_or_path
self.device = _select_device(device)
# Resolve default values
tokenizer_path = tokenizer_path or self.model_or_path
model_options = model_options or {}
tokenizer_options = tokenizer_options or {}
self.model_family = model_family or _infer_model_family(model_or_path)
# Load the tokenizer
TokenizerFast = _resolve_tokenizer(self.model_family)
self._tokenizer = TokenizerFast.from_pretrained(
tokenizer_path, **tokenizer_options
)
# Load the model either from a saved torch model or from transformers.
if model_or_path.endswith(".pt"):
self._transformers_model = torch.load(
model_or_path, map_location=self.device
).eval()
else:
ModelForConditionalGeneration = _resolve_transformers_model(
self.model_family
)
self._transformers_model = (
ModelForConditionalGeneration.from_pretrained(
self.model_or_path, **model_options
)
.to(self.device)
.eval()
)
def translate(
self,
text: Union[str, List[str]],
source: str,
target: str,
batch_size: int = 32,
verbose: bool = False,
generation_options: dict = None,
) -> Union[str, List[str]]:
"""
*Translates a string or a list of strings from a source to a target language.*
{{params}}
{{text}} The content you want to translate.
{{source}} The language of the original text.
{{target}} The language of the translated text.
{{batch_size}} The number of samples to load at once. If set to `None`, it will process everything at once.
{{verbose}} Whether to display the progress bar for every batch processed.
{{generation_options}} The keyword arguments passed to `model.generate()`, where `model` is the underlying transformers model.
Note:
- Run `print(dlt.utils.available_languages())` to see what's available.
- A smaller value is preferred for `batch_size` if your (video) RAM is limited.
"""
if generation_options is None:
generation_options = {}
source = _resolve_lang_codes(source, "source", self.model_family)
target = _resolve_lang_codes(target, "target", self.model_family)
self._tokenizer.src_lang = source
original_text_type = type(text)
if original_text_type is str:
text = [text]
if batch_size is None:
batch_size = len(text)
generation_options.setdefault(
"forced_bos_token_id", self._tokenizer.convert_tokens_to_ids(target)
)
generation_options.setdefault("max_new_tokens", 512)
data_loader = torch.utils.data.DataLoader(text, batch_size=batch_size)
output_text = []
tqdm_iterator = data_loader
if verbose is True:
tqdm_iterator = tqdm(data_loader)
with torch.no_grad():
for batch in tqdm_iterator:
encoded = self._tokenizer(batch, return_tensors="pt", padding=True)
encoded.to(self.device)
generated_tokens = self._transformers_model.generate(
**encoded, **generation_options
).cpu()
decoded = self._tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
output_text.extend(decoded)
# If text: str and output_text: List[str], then we should convert output_text to str
if original_text_type is str and len(output_text) == 1:
output_text = output_text[0]
return output_text
def get_transformers_model(self):
    """
    *Retrieve the underlying transformers model (e.g. mBART) used by this instance.*
    """
    underlying_model = self._transformers_model
    return underlying_model
def get_tokenizer(self):
    """
    *Retrieve the underlying huggingface tokenizer used by this instance.*
    """
    underlying_tokenizer = self._tokenizer
    return underlying_tokenizer
def available_languages(self) -> List[str]:
    """
    *Lists every language usable with this `dlt.TranslationModel`
    instance (depends on the model family it was loaded with).*
    """
    family = self.model_family
    return utils.available_languages(family)
def available_codes(self) -> List[str]:
    """
    *Lists every language code usable with this `dlt.TranslationModel`
    instance (depends on the model family it was loaded with).*
    """
    family = self.model_family
    return utils.available_codes(family)
def get_lang_code_map(self) -> Dict[str, str]:
    """
    *Builds the language -> code dictionary for this `dlt.TranslationModel`
    instance (depends on the model family it was loaded with).*
    """
    family = self.model_family
    return utils.get_lang_code_map(family)
def save_obj(self, path: str = "saved_model") -> None:
    """
    *Saves your model as a torch object and save your tokenizer.*

    {{params}}
    {{path}} The directory where you want to save your model and tokenizer
    """
    os.makedirs(path, exist_ok=True)
    # Persist the full torch module (not just a state_dict) plus the tokenizer files.
    torch.save(self._transformers_model, os.path.join(path, "weights.pt"))
    self._tokenizer.save_pretrained(path)
    dlt_config = dict(model_family=self.model_family)
    # Fix: use a context manager so the config file handle is closed
    # deterministically (the original left `open(...)` dangling until GC).
    with open(os.path.join(path, "dlt_config.json"), "w") as f:
        json.dump(dlt_config, f)
@classmethod
def load_obj(cls, path: str = "saved_model", **kwargs):
    """
    *Initialize `dlt.TranslationModel` from the torch object and tokenizer
    saved with `dlt.TranslationModel.save_obj`*

    {{params}}
    {{path}} The directory where your torch model and tokenizer are stored

    Any extra keyword arguments override the values stored in dlt_config.json
    and are forwarded to the constructor.
    """
    # Fix: open in text mode (json.dump wrote text) and close the handle via a
    # context manager; the original opened in "rb" and never closed the file.
    with open(os.path.join(path, "dlt_config.json"), "r") as f:
        config_prev = json.load(f)
    config_prev.update(kwargs)
    return cls(
        model_or_path=os.path.join(path, "weights.pt"),
        tokenizer_path=path,
        **config_prev,
    )
================================================
FILE: dl_translate/lang/__init__.py
================================================
from .m2m100 import *
from . import m2m100, mbart50, nllb200
================================================
FILE: dl_translate/lang/m2m100.py
================================================
# Auto-generated. Do not modify, use scripts/generate_langs.py instead.
"""Language-name constants for the m2m100 model family.

Each UPPER_SNAKE_CASE constant holds the human-readable language name
(e.g. ``FRENCH == "French"``) that can be passed as ``source``/``target``
to ``dlt.TranslationModel.translate``. Regenerate this file with
``scripts/generate_langs.py`` instead of editing it by hand.
"""
AFRIKAANS = "Afrikaans"
AMHARIC = "Amharic"
ARABIC = "Arabic"
ASTURIAN = "Asturian"
AZERBAIJANI = "Azerbaijani"
BASHKIR = "Bashkir"
BELARUSIAN = "Belarusian"
BULGARIAN = "Bulgarian"
BENGALI = "Bengali"
BRETON = "Breton"
BOSNIAN = "Bosnian"
CATALAN = "Catalan"
VALENCIAN = "Valencian"
CEBUANO = "Cebuano"
CZECH = "Czech"
WELSH = "Welsh"
DANISH = "Danish"
GERMAN = "German"
GREEK = "Greek"
ENGLISH = "English"
SPANISH = "Spanish"
ESTONIAN = "Estonian"
PERSIAN = "Persian"
FULAH = "Fulah"
FINNISH = "Finnish"
FRENCH = "French"
WESTERN_FRISIAN = "Western Frisian"
IRISH = "Irish"
GAELIC = "Gaelic"
SCOTTISH_GAELIC = "Scottish Gaelic"
GALICIAN = "Galician"
GUJARATI = "Gujarati"
HAUSA = "Hausa"
HEBREW = "Hebrew"
HINDI = "Hindi"
CROATIAN = "Croatian"
HAITIAN = "Haitian"
HAITIAN_CREOLE = "Haitian Creole"
HUNGARIAN = "Hungarian"
ARMENIAN = "Armenian"
INDONESIAN = "Indonesian"
IGBO = "Igbo"
ILOKO = "Iloko"
ICELANDIC = "Icelandic"
ITALIAN = "Italian"
JAPANESE = "Japanese"
JAVANESE = "Javanese"
GEORGIAN = "Georgian"
KAZAKH = "Kazakh"
KHMER = "Khmer"
CENTRAL_KHMER = "Central Khmer"
KANNADA = "Kannada"
KOREAN = "Korean"
LUXEMBOURGISH = "Luxembourgish"
LETZEBURGESCH = "Letzeburgesch"
GANDA = "Ganda"
LINGALA = "Lingala"
LAO = "Lao"
LITHUANIAN = "Lithuanian"
LATVIAN = "Latvian"
MALAGASY = "Malagasy"
MACEDONIAN = "Macedonian"
MALAYALAM = "Malayalam"
MONGOLIAN = "Mongolian"
MARATHI = "Marathi"
MALAY = "Malay"
BURMESE = "Burmese"
NEPALI = "Nepali"
DUTCH = "Dutch"
FLEMISH = "Flemish"
NORWEGIAN = "Norwegian"
NORTHERN_SOTHO = "Northern Sotho"
OCCITAN = "Occitan"
ORIYA = "Oriya"
PANJABI = "Panjabi"
PUNJABI = "Punjabi"
POLISH = "Polish"
PUSHTO = "Pushto"
PASHTO = "Pashto"
PORTUGUESE = "Portuguese"
ROMANIAN = "Romanian"
MOLDAVIAN = "Moldavian"
MOLDOVAN = "Moldovan"
RUSSIAN = "Russian"
SINDHI = "Sindhi"
SINHALA = "Sinhala"
SINHALESE = "Sinhalese"
SLOVAK = "Slovak"
SLOVENIAN = "Slovenian"
SOMALI = "Somali"
ALBANIAN = "Albanian"
SERBIAN = "Serbian"
SWATI = "Swati"
SUNDANESE = "Sundanese"
SWEDISH = "Swedish"
SWAHILI = "Swahili"
TAMIL = "Tamil"
THAI = "Thai"
TAGALOG = "Tagalog"
TSWANA = "Tswana"
TURKISH = "Turkish"
UKRAINIAN = "Ukrainian"
URDU = "Urdu"
UZBEK = "Uzbek"
VIETNAMESE = "Vietnamese"
WOLOF = "Wolof"
XHOSA = "Xhosa"
YIDDISH = "Yiddish"
YORUBA = "Yoruba"
CHINESE = "Chinese"
ZULU = "Zulu"
================================================
FILE: dl_translate/lang/mbart50.py
================================================
# Auto-generated. Do not modify, use scripts/generate_langs.py instead.
"""Language-name constants for the mBART-50 model family.

Each UPPER_SNAKE_CASE constant holds the human-readable language name
(e.g. ``FRENCH == "French"``) that can be passed as ``source``/``target``
to ``dlt.TranslationModel.translate``. Regenerate this file with
``scripts/generate_langs.py`` instead of editing it by hand.
"""
ARABIC = "Arabic"
CZECH = "Czech"
GERMAN = "German"
ENGLISH = "English"
SPANISH = "Spanish"
ESTONIAN = "Estonian"
FINNISH = "Finnish"
FRENCH = "French"
GUJARATI = "Gujarati"
HINDI = "Hindi"
ITALIAN = "Italian"
JAPANESE = "Japanese"
KAZAKH = "Kazakh"
KOREAN = "Korean"
LITHUANIAN = "Lithuanian"
LATVIAN = "Latvian"
BURMESE = "Burmese"
NEPALI = "Nepali"
DUTCH = "Dutch"
ROMANIAN = "Romanian"
RUSSIAN = "Russian"
SINHALA = "Sinhala"
TURKISH = "Turkish"
VIETNAMESE = "Vietnamese"
CHINESE = "Chinese"
AFRIKAANS = "Afrikaans"
AZERBAIJANI = "Azerbaijani"
BENGALI = "Bengali"
PERSIAN = "Persian"
HEBREW = "Hebrew"
CROATIAN = "Croatian"
INDONESIAN = "Indonesian"
GEORGIAN = "Georgian"
KHMER = "Khmer"
MACEDONIAN = "Macedonian"
MALAYALAM = "Malayalam"
MONGOLIAN = "Mongolian"
MARATHI = "Marathi"
POLISH = "Polish"
PASHTO = "Pashto"
PORTUGUESE = "Portuguese"
SWEDISH = "Swedish"
SWAHILI = "Swahili"
TAMIL = "Tamil"
TELUGU = "Telugu"
THAI = "Thai"
TAGALOG = "Tagalog"
UKRAINIAN = "Ukrainian"
URDU = "Urdu"
XHOSA = "Xhosa"
GALICIAN = "Galician"
SLOVENE = "Slovene"
================================================
FILE: dl_translate/lang/nllb200.py
================================================
# Auto-generated. Do not modify, use scripts/generate_langs.py instead.
"""Language-name constants for the NLLB-200 model family.

Each UPPER_SNAKE_CASE constant holds the human-readable language name
(e.g. ``FRENCH == "French"``) that can be passed as ``source``/``target``
to ``dlt.TranslationModel.translate``. Script variants are distinct
constants (e.g. ``ACEHNESE_ARABIC_SCRIPT`` vs ``ACEHNESE_LATIN_SCRIPT``).
Regenerate this file with ``scripts/generate_langs.py`` instead of
editing it by hand.
"""
ACEHNESE_ARABIC_SCRIPT = "Acehnese (Arabic script)"
ACEHNESE_LATIN_SCRIPT = "Acehnese (Latin script)"
MESOPOTAMIAN_ARABIC = "Mesopotamian Arabic"
TAIZZI_ADENI_ARABIC = "Ta'izzi-Adeni Arabic"
TUNISIAN_ARABIC = "Tunisian Arabic"
AFRIKAANS = "Afrikaans"
SOUTH_LEVANTINE_ARABIC = "South Levantine Arabic"
AKAN = "Akan"
AMHARIC = "Amharic"
NORTH_LEVANTINE_ARABIC = "North Levantine Arabic"
MODERN_STANDARD_ARABIC = "Modern Standard Arabic"
MODERN_STANDARD_ARABIC_ROMANIZED = "Modern Standard Arabic (Romanized)"
NAJDI_ARABIC = "Najdi Arabic"
MOROCCAN_ARABIC = "Moroccan Arabic"
EGYPTIAN_ARABIC = "Egyptian Arabic"
ASSAMESE = "Assamese"
ASTURIAN = "Asturian"
AWADHI = "Awadhi"
CENTRAL_AYMARA = "Central Aymara"
SOUTH_AZERBAIJANI = "South Azerbaijani"
NORTH_AZERBAIJANI = "North Azerbaijani"
BASHKIR = "Bashkir"
BAMBARA = "Bambara"
BALINESE = "Balinese"
BELARUSIAN = "Belarusian"
BEMBA = "Bemba"
BENGALI = "Bengali"
BHOJPURI = "Bhojpuri"
BANJAR_ARABIC_SCRIPT = "Banjar (Arabic script)"
BANJAR_LATIN_SCRIPT = "Banjar (Latin script)"
STANDARD_TIBETAN = "Standard Tibetan"
BOSNIAN = "Bosnian"
BUGINESE = "Buginese"
BULGARIAN = "Bulgarian"
CATALAN = "Catalan"
CEBUANO = "Cebuano"
CZECH = "Czech"
CHOKWE = "Chokwe"
CENTRAL_KURDISH = "Central Kurdish"
CRIMEAN_TATAR = "Crimean Tatar"
WELSH = "Welsh"
DANISH = "Danish"
GERMAN = "German"
SOUTHWESTERN_DINKA = "Southwestern Dinka"
DYULA = "Dyula"
DZONGKHA = "Dzongkha"
GREEK = "Greek"
ENGLISH = "English"
ESPERANTO = "Esperanto"
ESTONIAN = "Estonian"
BASQUE = "Basque"
EWE = "Ewe"
FAROESE = "Faroese"
FIJIAN = "Fijian"
FINNISH = "Finnish"
FON = "Fon"
FRENCH = "French"
FRIULIAN = "Friulian"
NIGERIAN_FULFULDE = "Nigerian Fulfulde"
SCOTTISH_GAELIC = "Scottish Gaelic"
IRISH = "Irish"
GALICIAN = "Galician"
GUARANI = "Guarani"
GUJARATI = "Gujarati"
HAITIAN_CREOLE = "Haitian Creole"
HAUSA = "Hausa"
HEBREW = "Hebrew"
HINDI = "Hindi"
CHHATTISGARHI = "Chhattisgarhi"
CROATIAN = "Croatian"
HUNGARIAN = "Hungarian"
ARMENIAN = "Armenian"
IGBO = "Igbo"
ILOCANO = "Ilocano"
INDONESIAN = "Indonesian"
ICELANDIC = "Icelandic"
ITALIAN = "Italian"
JAVANESE = "Javanese"
JAPANESE = "Japanese"
KABYLE = "Kabyle"
JINGPHO = "Jingpho"
KAMBA = "Kamba"
KANNADA = "Kannada"
KASHMIRI_ARABIC_SCRIPT = "Kashmiri (Arabic script)"
KASHMIRI_DEVANAGARI_SCRIPT = "Kashmiri (Devanagari script)"
GEORGIAN = "Georgian"
CENTRAL_KANURI_ARABIC_SCRIPT = "Central Kanuri (Arabic script)"
CENTRAL_KANURI_LATIN_SCRIPT = "Central Kanuri (Latin script)"
KAZAKH = "Kazakh"
KABIYÈ = "Kabiyè"
KABUVERDIANU = "Kabuverdianu"
KHMER = "Khmer"
KIKUYU = "Kikuyu"
KINYARWANDA = "Kinyarwanda"
KYRGYZ = "Kyrgyz"
KIMBUNDU = "Kimbundu"
NORTHERN_KURDISH = "Northern Kurdish"
KIKONGO = "Kikongo"
KOREAN = "Korean"
LAO = "Lao"
LIGURIAN = "Ligurian"
LIMBURGISH = "Limburgish"
LINGALA = "Lingala"
LITHUANIAN = "Lithuanian"
LOMBARD = "Lombard"
LATGALIAN = "Latgalian"
LUXEMBOURGISH = "Luxembourgish"
LUBA_KASAI = "Luba-Kasai"
GANDA = "Ganda"
LUO = "Luo"
MIZO = "Mizo"
STANDARD_LATVIAN = "Standard Latvian"
MAGAHI = "Magahi"
MAITHILI = "Maithili"
MALAYALAM = "Malayalam"
MARATHI = "Marathi"
MINANGKABAU_ARABIC_SCRIPT = "Minangkabau (Arabic script)"
MINANGKABAU_LATIN_SCRIPT = "Minangkabau (Latin script)"
MACEDONIAN = "Macedonian"
PLATEAU_MALAGASY = "Plateau Malagasy"
MALTESE = "Maltese"
MEITEI_BENGALI_SCRIPT = "Meitei (Bengali script)"
HALH_MONGOLIAN = "Halh Mongolian"
MOSSI = "Mossi"
MAORI = "Maori"
BURMESE = "Burmese"
DUTCH = "Dutch"
NORWEGIAN_NYNORSK = "Norwegian Nynorsk"
NORWEGIAN_BOKMÅL = "Norwegian Bokmål"
NEPALI = "Nepali"
NORTHERN_SOTHO = "Northern Sotho"
NUER = "Nuer"
NYANJA = "Nyanja"
OCCITAN = "Occitan"
WEST_CENTRAL_OROMO = "West Central Oromo"
ODIA = "Odia"
PANGASINAN = "Pangasinan"
EASTERN_PANJABI = "Eastern Panjabi"
PAPIAMENTO = "Papiamento"
WESTERN_PERSIAN = "Western Persian"
POLISH = "Polish"
PORTUGUESE = "Portuguese"
DARI = "Dari"
SOUTHERN_PASHTO = "Southern Pashto"
AYACUCHO_QUECHUA = "Ayacucho Quechua"
ROMANIAN = "Romanian"
RUNDI = "Rundi"
RUSSIAN = "Russian"
SANGO = "Sango"
SANSKRIT = "Sanskrit"
SANTALI = "Santali"
SICILIAN = "Sicilian"
SHAN = "Shan"
SINHALA = "Sinhala"
SLOVAK = "Slovak"
SLOVENIAN = "Slovenian"
SAMOAN = "Samoan"
SHONA = "Shona"
SINDHI = "Sindhi"
SOMALI = "Somali"
SOUTHERN_SOTHO = "Southern Sotho"
SPANISH = "Spanish"
TOSK_ALBANIAN = "Tosk Albanian"
SARDINIAN = "Sardinian"
SERBIAN = "Serbian"
SWATI = "Swati"
SUNDANESE = "Sundanese"
SWEDISH = "Swedish"
SWAHILI = "Swahili"
SILESIAN = "Silesian"
TAMIL = "Tamil"
TATAR = "Tatar"
TELUGU = "Telugu"
TAJIK = "Tajik"
TAGALOG = "Tagalog"
THAI = "Thai"
TIGRINYA = "Tigrinya"
TAMASHEQ_LATIN_SCRIPT = "Tamasheq (Latin script)"
TAMASHEQ_TIFINAGH_SCRIPT = "Tamasheq (Tifinagh script)"
TOK_PISIN = "Tok Pisin"
TSWANA = "Tswana"
TSONGA = "Tsonga"
TURKMEN = "Turkmen"
TUMBUKA = "Tumbuka"
TURKISH = "Turkish"
TWI = "Twi"
CENTRAL_ATLAS_TAMAZIGHT = "Central Atlas Tamazight"
UYGHUR = "Uyghur"
UKRAINIAN = "Ukrainian"
UMBUNDU = "Umbundu"
URDU = "Urdu"
NORTHERN_UZBEK = "Northern Uzbek"
VENETIAN = "Venetian"
VIETNAMESE = "Vietnamese"
WARAY = "Waray"
WOLOF = "Wolof"
XHOSA = "Xhosa"
EASTERN_YIDDISH = "Eastern Yiddish"
YORUBA = "Yoruba"
YUE_CHINESE = "Yue Chinese"
CHINESE_SIMPLIFIED = "Chinese (Simplified)"
CHINESE_TRADITIONAL = "Chinese (Traditional)"
STANDARD_MALAY = "Standard Malay"
ZULU = "Zulu"
================================================
FILE: dl_translate/utils.py
================================================
from typing import Dict, List
from ._pairs import _PAIRS_MBART50, _PAIRS_M2M100, _PAIRS_NLLB200
def _infer_model_family(model_or_path):
di = {
"facebook/mbart-large-50-many-to-many-mmt": "mbart50",
"facebook/m2m100_418M": "m2m100",
"facebook/m2m100_1.2B": "m2m100",
"facebook/nllb-200-distilled-600M": "nllb200",
"facebook/nllb-200-distilled-1.3B": "nllb200",
"facebook/nllb-200-1.3B": "nllb200",
"facebook/nllb-200-3.3B": "nllb200",
}
if model_or_path in di:
return di[model_or_path]
else:
error_msg = f'Unable to infer the model_family from "{model_or_path}". Try explicitly setting the value of model_family to "mbart50" or "m2m100".'
raise ValueError(error_msg)
def _infer_model_or_path(model_or_path):
di = {
"mbart50": "facebook/mbart-large-50-many-to-many-mmt",
"m2m100": "facebook/m2m100_418M",
"m2m100-small": "facebook/m2m100_418M",
"m2m100-medium": "facebook/m2m100_1.2B",
"nllb200": "facebook/nllb-200-distilled-600M",
"nllb200-small": "facebook/nllb-200-distilled-600M",
"nllb200-medium": "facebook/nllb-200-distilled-1.3B",
"nllb200-medium-regular": "facebook/nllb-200-1.3B",
"nllb200-large": "facebook/nllb-200-3.3B",
}
return di.get(model_or_path, model_or_path)
def _weights2pairs():
    """Build the mapping from every accepted weights alias (shorthand or
    huggingface repo name) to its (language, code) pair list.
    """
    alias_groups = (
        (
            _PAIRS_MBART50,
            (
                "mbart50",
                "mbart-large-50-many-to-many-mmt",
                "facebook/mbart-large-50-many-to-many-mmt",
            ),
        ),
        (
            _PAIRS_M2M100,
            (
                "m2m100",
                "m2m100_418M",
                "m2m100_1.2B",
                "facebook/m2m100_418M",
                "facebook/m2m100_1.2B",
            ),
        ),
        (
            _PAIRS_NLLB200,
            (
                "nllb200",
                "nllb-200-distilled",
                "nllb-200-distilled-600M",
                "nllb-200-distilled-1.3B",
                "nllb-200-1.3B",
                "nllb-200-3.3B",
                "facebook/nllb-200-distilled-600M",
                "facebook/nllb-200-distilled-1.3B",
                "facebook/nllb-200-1.3B",
                "facebook/nllb-200-3.3B",
            ),
        ),
    )
    # Insertion order matches the original literal so the error message in
    # _dict_from_weights lists the keys in the same order.
    return {alias: pairs for pairs, aliases in alias_groups for alias in aliases}
def _dict_from_weights(weights: str) -> dict:
    """Returns a dictionary of lang, codes, pairs if the provided weights is supported.

    The lookup is tried with the name as given, then lowercased. Raises a
    ValueError listing the accepted names otherwise.
    """
    # Build the alias table once; the original rebuilt it up to five times
    # and duplicated the whole result-construction branch.
    weights2pairs = _weights2pairs()
    pairs = weights2pairs.get(weights)
    if pairs is None:
        pairs = weights2pairs.get(weights.lower())
    if pairs is None:
        error_message = f"Incorrect argument '{weights}' for parameter weights. Please choose from: {list(weights2pairs.keys())}"
        raise ValueError(error_message)
    return {
        "langs": tuple(pair[0] for pair in pairs),
        "codes": tuple(pair[1] for pair in pairs),
        "pairs": dict(pairs),
    }
def get_lang_code_map(weights: str = "m2m100") -> Dict[str, str]:
    """
    *Get a dictionary mapping a language -> code for a given model. The code will depend on the model you choose.*

    {{params}}
    {{weights}} The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use.
    """
    info = _dict_from_weights(weights)
    return info["pairs"]
def available_languages(weights: str = "m2m100") -> List[str]:
    """
    *Get all the languages available for a given model.*

    {{params}}
    {{weights}} The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use.
    """
    info = _dict_from_weights(weights)
    return info["langs"]
def available_codes(weights: str = "m2m100") -> List[str]:
    """
    *Get all the codes available for a given model. The code format will depend on the model you select.*

    {{params}}
    {{weights}} The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 codes available to use.
    """
    info = _dict_from_weights(weights)
    return info["codes"]
================================================
FILE: docs/available_languages.md
================================================
# Languages Available
This page gives all the languages available for each model family.
## MBart 50
| Language Name | Code |
| --- | --- |
| Arabic | ar_AR |
| Czech | cs_CZ |
| German | de_DE |
| English | en_XX |
| Spanish | es_XX |
| Estonian | et_EE |
| Finnish | fi_FI |
| French | fr_XX |
| Gujarati | gu_IN |
| Hindi | hi_IN |
| Italian | it_IT |
| Japanese | ja_XX |
| Kazakh | kk_KZ |
| Korean | ko_KR |
| Lithuanian | lt_LT |
| Latvian | lv_LV |
| Burmese | my_MM |
| Nepali | ne_NP |
| Dutch | nl_XX |
| Romanian | ro_RO |
| Russian | ru_RU |
| Sinhala | si_LK |
| Turkish | tr_TR |
| Vietnamese | vi_VN |
| Chinese | zh_CN |
| Afrikaans | af_ZA |
| Azerbaijani | az_AZ |
| Bengali | bn_IN |
| Persian | fa_IR |
| Hebrew | he_IL |
| Croatian | hr_HR |
| Indonesian | id_ID |
| Georgian | ka_GE |
| Khmer | km_KH |
| Macedonian | mk_MK |
| Malayalam | ml_IN |
| Mongolian | mn_MN |
| Marathi | mr_IN |
| Polish | pl_PL |
| Pashto | ps_AF |
| Portuguese | pt_XX |
| Swedish | sv_SE |
| Swahili | sw_KE |
| Tamil | ta_IN |
| Telugu | te_IN |
| Thai | th_TH |
| Tagalog | tl_XX |
| Ukrainian | uk_UA |
| Urdu | ur_PK |
| Xhosa | xh_ZA |
| Galician | gl_ES |
| Slovene | sl_SI |
## M2M-100
| Language Name | Code |
| --- | --- |
| Afrikaans | af |
| Amharic | am |
| Arabic | ar |
| Asturian | ast |
| Azerbaijani | az |
| Bashkir | ba |
| Belarusian | be |
| Bulgarian | bg |
| Bengali | bn |
| Breton | br |
| Bosnian | bs |
| Catalan | ca |
| Valencian | ca |
| Cebuano | ceb |
| Czech | cs |
| Welsh | cy |
| Danish | da |
| German | de |
| Greek | el |
| English | en |
| Spanish | es |
| Estonian | et |
| Persian | fa |
| Fulah | ff |
| Finnish | fi |
| French | fr |
| Western Frisian | fy |
| Irish | ga |
| Gaelic | gd |
| Scottish Gaelic | gd |
| Galician | gl |
| Gujarati | gu |
| Hausa | ha |
| Hebrew | he |
| Hindi | hi |
| Croatian | hr |
| Haitian | ht |
| Haitian Creole | ht |
| Hungarian | hu |
| Armenian | hy |
| Indonesian | id |
| Igbo | ig |
| Iloko | ilo |
| Icelandic | is |
| Italian | it |
| Japanese | ja |
| Javanese | jv |
| Georgian | ka |
| Kazakh | kk |
| Khmer | km |
| Central Khmer | km |
| Kannada | kn |
| Korean | ko |
| Luxembourgish | lb |
| Letzeburgesch | lb |
| Ganda | lg |
| Lingala | ln |
| Lao | lo |
| Lithuanian | lt |
| Latvian | lv |
| Malagasy | mg |
| Macedonian | mk |
| Malayalam | ml |
| Mongolian | mn |
| Marathi | mr |
| Malay | ms |
| Burmese | my |
| Nepali | ne |
| Dutch | nl |
| Flemish | nl |
| Norwegian | no |
| Northern Sotho | ns |
| Occitan | oc |
| Oriya | or |
| Panjabi | pa |
| Punjabi | pa |
| Polish | pl |
| Pushto | ps |
| Pashto | ps |
| Portuguese | pt |
| Romanian | ro |
| Moldavian | ro |
| Moldovan | ro |
| Russian | ru |
| Sindhi | sd |
| Sinhala | si |
| Sinhalese | si |
| Slovak | sk |
| Slovenian | sl |
| Somali | so |
| Albanian | sq |
| Serbian | sr |
| Swati | ss |
| Sundanese | su |
| Swedish | sv |
| Swahili | sw |
| Tamil | ta |
| Thai | th |
| Tagalog | tl |
| Tswana | tn |
| Turkish | tr |
| Ukrainian | uk |
| Urdu | ur |
| Uzbek | uz |
| Vietnamese | vi |
| Wolof | wo |
| Xhosa | xh |
| Yiddish | yi |
| Yoruba | yo |
| Chinese | zh |
| Zulu | zu |
## NLLB-200
| Language Name | Code |
| --- | --- |
| Acehnese (Arabic script) | ace_Arab |
| Acehnese (Latin script) | ace_Latn |
| Mesopotamian Arabic | acm_Arab |
| Ta'izzi-Adeni Arabic | acq_Arab |
| Tunisian Arabic | aeb_Arab |
| Afrikaans | afr_Latn |
| South Levantine Arabic | ajp_Arab |
| Akan | aka_Latn |
| Amharic | amh_Ethi |
| North Levantine Arabic | apc_Arab |
| Modern Standard Arabic | arb_Arab |
| Modern Standard Arabic (Romanized) | arb_Latn |
| Najdi Arabic | ars_Arab |
| Moroccan Arabic | ary_Arab |
| Egyptian Arabic | arz_Arab |
| Assamese | asm_Beng |
| Asturian | ast_Latn |
| Awadhi | awa_Deva |
| Central Aymara | ayr_Latn |
| South Azerbaijani | azb_Arab |
| North Azerbaijani | azj_Latn |
| Bashkir | bak_Cyrl |
| Bambara | bam_Latn |
| Balinese | ban_Latn |
| Belarusian | bel_Cyrl |
| Bemba | bem_Latn |
| Bengali | ben_Beng |
| Bhojpuri | bho_Deva |
| Banjar (Arabic script) | bjn_Arab |
| Banjar (Latin script) | bjn_Latn |
| Standard Tibetan | bod_Tibt |
| Bosnian | bos_Latn |
| Buginese | bug_Latn |
| Bulgarian | bul_Cyrl |
| Catalan | cat_Latn |
| Cebuano | ceb_Latn |
| Czech | ces_Latn |
| Chokwe | cjk_Latn |
| Central Kurdish | ckb_Arab |
| Crimean Tatar | crh_Latn |
| Welsh | cym_Latn |
| Danish | dan_Latn |
| German | deu_Latn |
| Southwestern Dinka | dik_Latn |
| Dyula | dyu_Latn |
| Dzongkha | dzo_Tibt |
| Greek | ell_Grek |
| English | eng_Latn |
| Esperanto | epo_Latn |
| Estonian | est_Latn |
| Basque | eus_Latn |
| Ewe | ewe_Latn |
| Faroese | fao_Latn |
| Fijian | fij_Latn |
| Finnish | fin_Latn |
| Fon | fon_Latn |
| French | fra_Latn |
| Friulian | fur_Latn |
| Nigerian Fulfulde | fuv_Latn |
| Scottish Gaelic | gla_Latn |
| Irish | gle_Latn |
| Galician | glg_Latn |
| Guarani | grn_Latn |
| Gujarati | guj_Gujr |
| Haitian Creole | hat_Latn |
| Hausa | hau_Latn |
| Hebrew | heb_Hebr |
| Hindi | hin_Deva |
| Chhattisgarhi | hne_Deva |
| Croatian | hrv_Latn |
| Hungarian | hun_Latn |
| Armenian | hye_Armn |
| Igbo | ibo_Latn |
| Ilocano | ilo_Latn |
| Indonesian | ind_Latn |
| Icelandic | isl_Latn |
| Italian | ita_Latn |
| Javanese | jav_Latn |
| Japanese | jpn_Jpan |
| Kabyle | kab_Latn |
| Jingpho | kac_Latn |
| Kamba | kam_Latn |
| Kannada | kan_Knda |
| Kashmiri (Arabic script) | kas_Arab |
| Kashmiri (Devanagari script) | kas_Deva |
| Georgian | kat_Geor |
| Central Kanuri (Arabic script) | knc_Arab |
| Central Kanuri (Latin script) | knc_Latn |
| Kazakh | kaz_Cyrl |
| Kabiyè | kbp_Latn |
| Kabuverdianu | kea_Latn |
| Khmer | khm_Khmr |
| Kikuyu | kik_Latn |
| Kinyarwanda | kin_Latn |
| Kyrgyz | kir_Cyrl |
| Kimbundu | kmb_Latn |
| Northern Kurdish | kmr_Latn |
| Kikongo | kon_Latn |
| Korean | kor_Hang |
| Lao | lao_Laoo |
| Ligurian | lij_Latn |
| Limburgish | lim_Latn |
| Lingala | lin_Latn |
| Lithuanian | lit_Latn |
| Lombard | lmo_Latn |
| Latgalian | ltg_Latn |
| Luxembourgish | ltz_Latn |
| Luba-Kasai | lua_Latn |
| Ganda | lug_Latn |
| Luo | luo_Latn |
| Mizo | lus_Latn |
| Standard Latvian | lvs_Latn |
| Magahi | mag_Deva |
| Maithili | mai_Deva |
| Malayalam | mal_Mlym |
| Marathi | mar_Deva |
| Minangkabau (Arabic script) | min_Arab |
| Minangkabau (Latin script) | min_Latn |
| Macedonian | mkd_Cyrl |
| Plateau Malagasy | plt_Latn |
| Maltese | mlt_Latn |
| Meitei (Bengali script) | mni_Beng |
| Halh Mongolian | khk_Cyrl |
| Mossi | mos_Latn |
| Maori | mri_Latn |
| Burmese | mya_Mymr |
| Dutch | nld_Latn |
| Norwegian Nynorsk | nno_Latn |
| Norwegian Bokmål | nob_Latn |
| Nepali | npi_Deva |
| Northern Sotho | nso_Latn |
| Nuer | nus_Latn |
| Nyanja | nya_Latn |
| Occitan | oci_Latn |
| West Central Oromo | gaz_Latn |
| Odia | ory_Orya |
| Pangasinan | pag_Latn |
| Eastern Panjabi | pan_Guru |
| Papiamento | pap_Latn |
| Western Persian | pes_Arab |
| Polish | pol_Latn |
| Portuguese | por_Latn |
| Dari | prs_Arab |
| Southern Pashto | pbt_Arab |
| Ayacucho Quechua | quy_Latn |
| Romanian | ron_Latn |
| Rundi | run_Latn |
| Russian | rus_Cyrl |
| Sango | sag_Latn |
| Sanskrit | san_Deva |
| Santali | sat_Olck |
| Sicilian | scn_Latn |
| Shan | shn_Mymr |
| Sinhala | sin_Sinh |
| Slovak | slk_Latn |
| Slovenian | slv_Latn |
| Samoan | smo_Latn |
| Shona | sna_Latn |
| Sindhi | snd_Arab |
| Somali | som_Latn |
| Southern Sotho | sot_Latn |
| Spanish | spa_Latn |
| Tosk Albanian | als_Latn |
| Sardinian | srd_Latn |
| Serbian | srp_Cyrl |
| Swati | ssw_Latn |
| Sundanese | sun_Latn |
| Swedish | swe_Latn |
| Swahili | swh_Latn |
| Silesian | szl_Latn |
| Tamil | tam_Taml |
| Tatar | tat_Cyrl |
| Telugu | tel_Telu |
| Tajik | tgk_Cyrl |
| Tagalog | tgl_Latn |
| Thai | tha_Thai |
| Tigrinya | tir_Ethi |
| Tamasheq (Latin script) | taq_Latn |
| Tamasheq (Tifinagh script) | taq_Tfng |
| Tok Pisin | tpi_Latn |
| Tswana | tsn_Latn |
| Tsonga | tso_Latn |
| Turkmen | tuk_Latn |
| Tumbuka | tum_Latn |
| Turkish | tur_Latn |
| Twi | twi_Latn |
| Central Atlas Tamazight | tzm_Tfng |
| Uyghur | uig_Arab |
| Ukrainian | ukr_Cyrl |
| Umbundu | umb_Latn |
| Urdu | urd_Arab |
| Northern Uzbek | uzn_Latn |
| Venetian | vec_Latn |
| Vietnamese | vie_Latn |
| Waray | war_Latn |
| Wolof | wol_Latn |
| Xhosa | xho_Latn |
| Eastern Yiddish | ydd_Hebr |
| Yoruba | yor_Latn |
| Yue Chinese | yue_Hant |
| Chinese (Simplified) | zho_Hans |
| Chinese (Traditional) | zho_Hant |
| Standard Malay | zsm_Latn |
| Zulu | zul_Latn |
================================================
FILE: docs/contributing.md
================================================
# Contributions
If you wish to contribute to the project, please do the following:
1. Verify if there's an existing similar issue.
2. If no issue exists, create it.
3. Once the contribution has been discussed inside the issue, fork this repo.
4. Before modifying any code, make sure to read the sections below.
5. Once you are done with your contribution, start a PR and tag a codeowner.
## Setup
To set up the development environment, clone the repo:
```bash
git clone https://github.com/xhlulu/dl-translate
cd dl-translate
```
Create a new venv and install the dev dependencies
```bash
python -m venv venv
source venv/bin/activate
pip install -e .[dev]
```
## Code linting
To ensure consistent and readable code, we use `black`. To run:
```bash
python -m black .
```
## Running tests
To run **all** the tests:
```bash
python -m pytest tests
```
For quick tests, run:
```bash
python -m pytest tests/quick
```
## Documentation
To re-generate the documentation after the source code was modified:
```bash
python scripts/render_references.py
```
To run the docs locally, run:
```
mkdocs serve -t material
```
Once ready, you can build it:
```
mkdocs build -t material
```
Or release it on GitHub Pages:
```
mkdocs gh-deploy -t material
```
================================================
FILE: docs/index.md
================================================
# User Guide
Quick links:
💻 [GitHub Repository](https://github.com/xhlulu/dl-translate)
📚 [Documentation](https://xhluca.github.io/dl-translate)
🐍 [PyPi project](https://pypi.org/project/dl-translate/)
🧪 [Colab Demo](https://colab.research.google.com/github/xhlulu/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/)
## Quickstart
Install the library with pip:
```
pip install dl-translate
```
To translate some text:
```python
import dl_translate as dlt
mt = dlt.TranslationModel() # Slow when you load it for the first time
text_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH)
```
Above, you can see that `dlt.lang` contains variables representing each of the 50 available languages with auto-complete support. Alternatively, you can specify the language (e.g. "Arabic") or the language code (e.g. "fr" for French):
```python
text_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
mt.translate(text_ar, source="Arabic", target="fr")
```
If you want to verify whether a language is available, you can check it:
```python
print(mt.available_languages()) # All languages that you can use
print(mt.available_codes()) # Code corresponding to each language accepted
print(mt.get_lang_code_map()) # Dictionary of lang -> code
```
## Usage
### Selecting a device
When you load the model, you can specify the device using the `device` argument. By default, the value will be `device="auto"`, which means it will use a GPU if possible. You can also explicitly set `device="cpu"` or `device="gpu"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). __In general, it is recommended to use a GPU if you want a reasonable processing time.__
```python
mt = dlt.TranslationModel(device="auto") # Automatically select device
mt = dlt.TranslationModel(device="cpu") # Force you to use a CPU
mt = dlt.TranslationModel(device="gpu") # Force you to use a GPU
mt = dlt.TranslationModel(device="cuda:2") # Use the 3rd GPU available
```
### Choosing a different model
By default, the `m2m100` model will be used. However, there are a few options:
* [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages.
* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages.
* [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (On RTX A6000, we can see speed up of 3x).
Here's an example:
```python
# The default approach
mt = dlt.TranslationModel("m2m100") # Shorthand
mt = dlt.TranslationModel("facebook/m2m100_418M") # Huggingface repo
# If you want to use mBART-50 Large
mt = dlt.TranslationModel("mbart50")
mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt")
# Or NLLB-200 (faster and has 200 languages)
mt = dlt.TranslationModel("nllb200")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M")
```
Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`.
By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`:
```python
mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50")
mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200")
```
Notes:
* Make sure your tokenizer is also stored in the same directory if you load from a file.
* The available languages will change if you select a different model, so you will not be able to leverage `dlt.lang` or `dlt.utils`.
### Breaking down into sentences
It is not recommended to use extremely long texts as it takes more time to process. Instead, you can try to break them down into sentences. Multiple solutions exist for that, including doing it manually and using the `nltk` library.
A quick approach would be to split them by period. However, you have to ensure that there are no periods used for abbreviations (such as `Mr.` or `Dr.`). For example, it will work in the following case:
```python
text = "Mr Smith went to his favorite cafe. There, he met his friend Dr Doe."
sents = text.split(".")
".".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH))
```
For more complex cases (e.g. where you use periods for abbreviations), you can use `nltk`. First install the library with `pip install nltk`, then run:
```python
import nltk
nltk.download("punkt")
text = "Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe."
sents = nltk.tokenize.sent_tokenize(text, "english") # don't use dlt.lang.ENGLISH
" ".join(mt.translate(sents, source=dlt.lang.ENGLISH, target=dlt.lang.FRENCH))
```
### Batch size and verbosity when using `translate`
It's possible to set a batch size (i.e. the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not:
```python
...
mt = dlt.TranslationModel()
mt.translate(text, source, target, batch_size=32, verbose=True)
```
If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into "chunks". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA memory error. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized.
### `dlt.utils` module
An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available:
```python
print(dlt.utils.available_languages('mbart50')) # All languages that you can use
print(dlt.utils.available_codes('mbart50')) # Code corresponding to each language accepted
print(dlt.utils.get_lang_code_map('mbart50')) # Dictionary of lang -> code
print(dlt.utils.available_languages('m2m100')) # write the name of the model family
```
At the moment, the following models are accepted:
- `"mbart50"`
- `"m2m100"`
- `"nllb200"`
### Offline usage
Unlike the Google translate or MSFT Translator APIs, this library can be fully used offline. However, you will need to first download the packages and models, and move them to your offline environment to be installed and loaded inside a venv.
First, run in your terminal:
```bash
mkdir dlt
cd dlt
mkdir libraries
pip download -d libraries/ dl-translate
```
Once all the required packages are downloaded, you will need to use huggingface hub to download the files. Install it with `pip install huggingface-hub`. Then, run inside Python:
```python
import shutil
import huggingface_hub as hub
dirname = hub.snapshot_download("facebook/m2m100_418M")
shutil.copytree(dirname, "cached_model_m2m100") # Copy to a permanent folder
```
Now, move everything in the `dlt` directory to your offline environment. Create a virtual environment and run the following in terminal:
```bash
pip install --no-index --find-links libraries/ dl-translate
```
Now, run inside Python:
```python
import dl_translate as dlt
mt = dlt.TranslationModel("cached_model_m2m100", model_family="m2m100")
```
## Advanced
The following section assumes you have knowledge of PyTorch and Huggingface Transformers.
### Saving and loading
If you wish to accelerate the loading time of the translation model, you can use `save_obj`. Later you can reload it with `load_obj` by specifying the same directory that you used to save it.
```python
mt = dlt.TranslationModel()
# ...
mt.save_obj('saved_model')
# ...
mt = dlt.TranslationModel.load_obj('saved_model')
```
**Warning:** Only use this if you are certain the torch module saved in `saved_model/weights.pt` can be correctly loaded. Indeed, it is possible that the `huggingface`, `torch` or some other dependencies change between when you called `save_obj` and `load_obj`, and that might break your code. Thus, it is recommended to only run `load_obj` in the same environment/session as `save_obj`. **Note this method might be deprecated in the future once there's no speed benefit in loading this way.**
### Interacting with underlying model and tokenizer
When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer (which will respectively be passed to `ModelForConditionalGeneration.from_pretrained` and `TokenizerFast.from_pretrained`):
```python
mt = dlt.TranslationModel(
model_options=dict(
state_dict=...,
cache_dir=...,
...
),
tokenizer_options=dict(
tokenizer_file=...,
eos_token=...,
...
)
)
```
You can also access the underlying `transformers` model and `tokenizer`:
```python
transformers_model = mt.get_transformers_model()
tokenizer = mt.get_tokenizer()
```
For more information about the models themselves, please read the docs on [mBART](https://huggingface.co/transformers/master/model_doc/mbart.html) and [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html).
### Keyword arguments for the `generate()` method of the underlying model
When running `mt.translate`, you can also give a `generation_options` dictionary that is passed as keyword arguments to the underlying `mt.get_transformers_model().generate()` method:
```python
mt.translate(
text,
source=dlt.lang.GERMAN,
target=dlt.lang.SPANISH,
generation_options=dict(num_beams=5, max_length=...)
)
```
Learn more in the [huggingface docs](https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate).
================================================
FILE: docs/references.md
================================================
# API Reference
## dlt.TranslationModel
### __init__
```python
dlt.TranslationModel.__init__(self, model_or_path: str = 'm2m100', tokenizer_path: str = None, device: str = 'auto', model_family: str = None, model_options: dict = None, tokenizer_options: dict = None)
```
*Instantiates a multilingual transformer model for translation.*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **model_or_path** | *str* | `m2m100` | The path or the name of the model. Equivalent to the first argument of `AutoModel.from_pretrained()`. You can also specify shorthands ("mbart50" and "m2m100").
| **tokenizer_path** | *str* | *optional* | The path to the tokenizer. By default, it will be set to `model_or_path`.
| **device** | *str* | `auto` | "cpu", "gpu" or "auto". If it's set to "auto", will try to select a GPU when available or else fall back to CPU.
| **model_family** | *str* | *optional* | Either "mbart50" or "m2m100". By default, it will be inferred based on `model_or_path`. Needs to be explicitly set if `model_or_path` is a path.
| **model_options** | *dict* | *optional* | The keyword arguments passed to the model, which is a transformer for conditional generation.
| **tokenizer_options** | *dict* | *optional* | The keyword arguments passed to the model's tokenizer.
### translate
```python
dlt.TranslationModel.translate(self, text: Union[str, List[str]], source: str, target: str, batch_size: int = 32, verbose: bool = False, generation_options: dict = None) -> Union[str, List[str]]
```
*Translates a string or a list of strings from a source to a target language.*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **text** | *Union[str, List[str]]* | *required* | The content you want to translate.
| **source** | *str* | *required* | The language of the original text.
| **target** | *str* | *required* | The language of the translated text.
| **batch_size** | *int* | `32` | The number of samples to load at once. If set to `None`, it will process everything at once.
| **verbose** | *bool* | `False` | Whether to display the progress bar for every batch processed.
| **generation_options** | *dict* | *optional* | The keyword arguments passed to `model.generate()`, where `model` is the underlying transformers model.
Note:
- Run `print(dlt.utils.available_languages())` to see what's available.
- A smaller value is preferred for `batch_size` if your (video) RAM is limited.
### get_transformers_model
```python
dlt.TranslationModel.get_transformers_model(self)
```
*Retrieve the underlying mBART transformer model.*
### get_tokenizer
```python
dlt.TranslationModel.get_tokenizer(self)
```
*Retrieve the mBART huggingface tokenizer.*
### available_codes
```python
dlt.TranslationModel.available_codes(self) -> List[str]
```
*Returns all the available codes for a given `dlt.TranslationModel`
instance.*
### available_languages
```python
dlt.TranslationModel.available_languages(self) -> List[str]
```
*Returns all the available languages for a given `dlt.TranslationModel`
instance.*
### get_lang_code_map
```python
dlt.TranslationModel.get_lang_code_map(self) -> Dict[str, str]
```
*Returns the language -> codes dictionary for a given `dlt.TranslationModel`
instance.*
### save_obj
```python
dlt.TranslationModel.save_obj(self, path: str = 'saved_model') -> None
```
*Saves your model as a torch object and saves your tokenizer.*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **path** | *str* | `saved_model` | The directory where you want to save your model and tokenizer
### load_obj
```python
dlt.TranslationModel.load_obj(path: str = 'saved_model', **kwargs)
```
*Initialize `dlt.TranslationModel` from the torch object and tokenizer
saved with `dlt.TranslationModel.save_obj`*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **path** | *str* | `saved_model` | The directory where your torch model and tokenizer are stored
## dlt.utils
### get_lang_code_map
```python
dlt.utils.get_lang_code_map(weights: str = 'mbart50') -> Dict[str, str]
```
*Get a dictionary mapping a language -> code for a given model. The code will depend on the model you choose.*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **weights** | *str* | `mbart50` | The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use.
### available_codes
```python
dlt.utils.available_codes(weights: str = 'mbart50') -> List[str]
```
*Get all the codes available for a given model. The code format will depend on the model you select.*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **weights** | *str* | `mbart50` | The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 codes available to use.
### available_languages
```python
dlt.utils.available_languages(weights: str = 'mbart50') -> List[str]
```
*Get all the languages available for a given model.*
| Parameter | Type | Default | Description |
|-|-|-|-|
| **weights** | *str* | `mbart50` | The name of the model you are using. For example, "mbart50" is the multilingual BART Large with 50 languages available to use.
================================================
FILE: docs/requirements.txt
================================================
mkdocs
mkdocs-material
jinja2<3.1.0
================================================
FILE: mkdocs-rtd.yml
================================================
site_name: DL Translate
repo_url: https://github.com/xhluca/dl-translate
edit_uri: blob/main/docs/
nav:
- index.md
- references.md
- contributing.md
- available_languages.md
theme: readthedocs
================================================
FILE: mkdocs.yml
================================================
site_name: DL Translate
repo_url: https://github.com/xhluca/dl-translate
nav:
- index.md
- references.md
- contributing.md
- available_languages.md
theme: material
markdown_extensions:
- pymdownx.highlight
- pymdownx.superfences
================================================
FILE: scripts/generate_langs.py
================================================
import json
import os
def name_to_var(lang_name):
    """Turn a human-readable language name into an UPPER_SNAKE_CASE identifier.

    Spaces and hyphens become underscores; parentheses and apostrophes are
    dropped (e.g. "Ta'izzi-Adeni Arabic" -> "TAIZZI_ADENI_ARABIC").
    """
    # Single-character substitutions, applied in one pass via str.translate.
    table = str.maketrans({" ": "_", "-": "_", "(": "", ")": "", "'": ""})
    return lang_name.upper().translate(table)
def load_json(name):
    """Load the language -> code mapping stored in langs_coverage/<name>.json.

    Fixes: the original `open(filepath).read()` never closed the file handle;
    a context manager closes it deterministically, and the encoding is pinned
    to UTF-8 so non-ASCII language names decode consistently across platforms.
    """
    filepath = os.path.join(os.path.dirname(__file__), "langs_coverage", f"{name}.json")
    with open(filepath, encoding="utf-8") as f:
        return json.load(f)
# Header stamped at the top of every generated module.
auto_gen_comment = f"# Auto-generated. Do not modify, use {__file__} instead.\n"

# Generate one dl_translate/lang/<family>.py module per model family, each
# exposing LANGUAGE_NAME = "Language Name" constants.
name2json = {}
for name in ["m2m100", "mbart50", "nllb200"]:
    lang2code = load_json(name)
    name2json[name] = lang2code
    lines = [auto_gen_comment]
    lines.extend(f'{name_to_var(lang)} = "{lang}"\n' for lang in lang2code)
    with open(f"./dl_translate/lang/{name}.py", "w") as f:
        f.writelines(lines)

# Collect every family's (language, code) pairs into dl_translate/_pairs.py.
with open("./dl_translate/_pairs.py", "w") as f:
    f.write(auto_gen_comment)
    for name, lang2code in name2json.items():
        f.write(f"_PAIRS_{name.upper()} = {tuple(lang2code.items())}\n")
================================================
FILE: scripts/langs_coverage/m2m100.json
================================================
{
"Afrikaans": "af",
"Amharic": "am",
"Arabic": "ar",
"Asturian": "ast",
"Azerbaijani": "az",
"Bashkir": "ba",
"Belarusian": "be",
"Bulgarian": "bg",
"Bengali": "bn",
"Breton": "br",
"Bosnian": "bs",
"Catalan": "ca",
"Valencian": "ca",
"Cebuano": "ceb",
"Czech": "cs",
"Welsh": "cy",
"Danish": "da",
"German": "de",
"Greek": "el",
"English": "en",
"Spanish": "es",
"Estonian": "et",
"Persian": "fa",
"Fulah": "ff",
"Finnish": "fi",
"French": "fr",
"Western Frisian": "fy",
"Irish": "ga",
"Gaelic": "gd",
"Scottish Gaelic": "gd",
"Galician": "gl",
"Gujarati": "gu",
"Hausa": "ha",
"Hebrew": "he",
"Hindi": "hi",
"Croatian": "hr",
"Haitian": "ht",
"Haitian Creole": "ht",
"Hungarian": "hu",
"Armenian": "hy",
"Indonesian": "id",
"Igbo": "ig",
"Iloko": "ilo",
"Icelandic": "is",
"Italian": "it",
"Japanese": "ja",
"Javanese": "jv",
"Georgian": "ka",
"Kazakh": "kk",
"Khmer": "km",
"Central Khmer": "km",
"Kannada": "kn",
"Korean": "ko",
"Luxembourgish": "lb",
"Letzeburgesch": "lb",
"Ganda": "lg",
"Lingala": "ln",
"Lao": "lo",
"Lithuanian": "lt",
"Latvian": "lv",
"Malagasy": "mg",
"Macedonian": "mk",
"Malayalam": "ml",
"Mongolian": "mn",
"Marathi": "mr",
"Malay": "ms",
"Burmese": "my",
"Nepali": "ne",
"Dutch": "nl",
"Flemish": "nl",
"Norwegian": "no",
"Northern Sotho": "ns",
"Occitan": "oc",
"Oriya": "or",
"Panjabi": "pa",
"Punjabi": "pa",
"Polish": "pl",
"Pushto": "ps",
"Pashto": "ps",
"Portuguese": "pt",
"Romanian": "ro",
"Moldavian": "ro",
"Moldovan": "ro",
"Russian": "ru",
"Sindhi": "sd",
"Sinhala": "si",
"Sinhalese": "si",
"Slovak": "sk",
"Slovenian": "sl",
"Somali": "so",
"Albanian": "sq",
"Serbian": "sr",
"Swati": "ss",
"Sundanese": "su",
"Swedish": "sv",
"Swahili": "sw",
"Tamil": "ta",
"Thai": "th",
"Tagalog": "tl",
"Tswana": "tn",
"Turkish": "tr",
"Ukrainian": "uk",
"Urdu": "ur",
"Uzbek": "uz",
"Vietnamese": "vi",
"Wolof": "wo",
"Xhosa": "xh",
"Yiddish": "yi",
"Yoruba": "yo",
"Chinese": "zh",
"Zulu": "zu"
}
================================================
FILE: scripts/langs_coverage/mbart50.json
================================================
{
"Arabic": "ar_AR",
"Czech": "cs_CZ",
"German": "de_DE",
"English": "en_XX",
"Spanish": "es_XX",
"Estonian": "et_EE",
"Finnish": "fi_FI",
"French": "fr_XX",
"Gujarati": "gu_IN",
"Hindi": "hi_IN",
"Italian": "it_IT",
"Japanese": "ja_XX",
"Kazakh": "kk_KZ",
"Korean": "ko_KR",
"Lithuanian": "lt_LT",
"Latvian": "lv_LV",
"Burmese": "my_MM",
"Nepali": "ne_NP",
"Dutch": "nl_XX",
"Romanian": "ro_RO",
"Russian": "ru_RU",
"Sinhala": "si_LK",
"Turkish": "tr_TR",
"Vietnamese": "vi_VN",
"Chinese": "zh_CN",
"Afrikaans": "af_ZA",
"Azerbaijani": "az_AZ",
"Bengali": "bn_IN",
"Persian": "fa_IR",
"Hebrew": "he_IL",
"Croatian": "hr_HR",
"Indonesian": "id_ID",
"Georgian": "ka_GE",
"Khmer": "km_KH",
"Macedonian": "mk_MK",
"Malayalam": "ml_IN",
"Mongolian": "mn_MN",
"Marathi": "mr_IN",
"Polish": "pl_PL",
"Pashto": "ps_AF",
"Portuguese": "pt_XX",
"Swedish": "sv_SE",
"Swahili": "sw_KE",
"Tamil": "ta_IN",
"Telugu": "te_IN",
"Thai": "th_TH",
"Tagalog": "tl_XX",
"Ukrainian": "uk_UA",
"Urdu": "ur_PK",
"Xhosa": "xh_ZA",
"Galician": "gl_ES",
"Slovene": "sl_SI"
}
================================================
FILE: scripts/langs_coverage/nllb200.json
================================================
{
"Acehnese (Arabic script)": "ace_Arab",
"Acehnese (Latin script)": "ace_Latn",
"Mesopotamian Arabic": "acm_Arab",
"Ta'izzi-Adeni Arabic": "acq_Arab",
"Tunisian Arabic": "aeb_Arab",
"Afrikaans": "afr_Latn",
"South Levantine Arabic": "ajp_Arab",
"Akan": "aka_Latn",
"Amharic": "amh_Ethi",
"North Levantine Arabic": "apc_Arab",
"Modern Standard Arabic": "arb_Arab",
"Modern Standard Arabic (Romanized)": "arb_Latn",
"Najdi Arabic": "ars_Arab",
"Moroccan Arabic": "ary_Arab",
"Egyptian Arabic": "arz_Arab",
"Assamese": "asm_Beng",
"Asturian": "ast_Latn",
"Awadhi": "awa_Deva",
"Central Aymara": "ayr_Latn",
"South Azerbaijani": "azb_Arab",
"North Azerbaijani": "azj_Latn",
"Bashkir": "bak_Cyrl",
"Bambara": "bam_Latn",
"Balinese": "ban_Latn",
"Belarusian": "bel_Cyrl",
"Bemba": "bem_Latn",
"Bengali": "ben_Beng",
"Bhojpuri": "bho_Deva",
"Banjar (Arabic script)": "bjn_Arab",
"Banjar (Latin script)": "bjn_Latn",
"Standard Tibetan": "bod_Tibt",
"Bosnian": "bos_Latn",
"Buginese": "bug_Latn",
"Bulgarian": "bul_Cyrl",
"Catalan": "cat_Latn",
"Cebuano": "ceb_Latn",
"Czech": "ces_Latn",
"Chokwe": "cjk_Latn",
"Central Kurdish": "ckb_Arab",
"Crimean Tatar": "crh_Latn",
"Welsh": "cym_Latn",
"Danish": "dan_Latn",
"German": "deu_Latn",
"Southwestern Dinka": "dik_Latn",
"Dyula": "dyu_Latn",
"Dzongkha": "dzo_Tibt",
"Greek": "ell_Grek",
"English": "eng_Latn",
"Esperanto": "epo_Latn",
"Estonian": "est_Latn",
"Basque": "eus_Latn",
"Ewe": "ewe_Latn",
"Faroese": "fao_Latn",
"Fijian": "fij_Latn",
"Finnish": "fin_Latn",
"Fon": "fon_Latn",
"French": "fra_Latn",
"Friulian": "fur_Latn",
"Nigerian Fulfulde": "fuv_Latn",
"Scottish Gaelic": "gla_Latn",
"Irish": "gle_Latn",
"Galician": "glg_Latn",
"Guarani": "grn_Latn",
"Gujarati": "guj_Gujr",
"Haitian Creole": "hat_Latn",
"Hausa": "hau_Latn",
"Hebrew": "heb_Hebr",
"Hindi": "hin_Deva",
"Chhattisgarhi": "hne_Deva",
"Croatian": "hrv_Latn",
"Hungarian": "hun_Latn",
"Armenian": "hye_Armn",
"Igbo": "ibo_Latn",
"Ilocano": "ilo_Latn",
"Indonesian": "ind_Latn",
"Icelandic": "isl_Latn",
"Italian": "ita_Latn",
"Javanese": "jav_Latn",
"Japanese": "jpn_Jpan",
"Kabyle": "kab_Latn",
"Jingpho": "kac_Latn",
"Kamba": "kam_Latn",
"Kannada": "kan_Knda",
"Kashmiri (Arabic script)": "kas_Arab",
"Kashmiri (Devanagari script)": "kas_Deva",
"Georgian": "kat_Geor",
"Central Kanuri (Arabic script)": "knc_Arab",
"Central Kanuri (Latin script)": "knc_Latn",
"Kazakh": "kaz_Cyrl",
"Kabiyè": "kbp_Latn",
"Kabuverdianu": "kea_Latn",
"Khmer": "khm_Khmr",
"Kikuyu": "kik_Latn",
"Kinyarwanda": "kin_Latn",
"Kyrgyz": "kir_Cyrl",
"Kimbundu": "kmb_Latn",
"Northern Kurdish": "kmr_Latn",
"Kikongo": "kon_Latn",
"Korean": "kor_Hang",
"Lao": "lao_Laoo",
"Ligurian": "lij_Latn",
"Limburgish": "lim_Latn",
"Lingala": "lin_Latn",
"Lithuanian": "lit_Latn",
"Lombard": "lmo_Latn",
"Latgalian": "ltg_Latn",
"Luxembourgish": "ltz_Latn",
"Luba-Kasai": "lua_Latn",
"Ganda": "lug_Latn",
"Luo": "luo_Latn",
"Mizo": "lus_Latn",
"Standard Latvian": "lvs_Latn",
"Magahi": "mag_Deva",
"Maithili": "mai_Deva",
"Malayalam": "mal_Mlym",
"Marathi": "mar_Deva",
"Minangkabau (Arabic script)": "min_Arab",
"Minangkabau (Latin script)": "min_Latn",
"Macedonian": "mkd_Cyrl",
"Plateau Malagasy": "plt_Latn",
"Maltese": "mlt_Latn",
"Meitei (Bengali script)": "mni_Beng",
"Halh Mongolian": "khk_Cyrl",
"Mossi": "mos_Latn",
"Maori": "mri_Latn",
"Burmese": "mya_Mymr",
"Dutch": "nld_Latn",
"Norwegian Nynorsk": "nno_Latn",
"Norwegian Bokmål": "nob_Latn",
"Nepali": "npi_Deva",
"Northern Sotho": "nso_Latn",
"Nuer": "nus_Latn",
"Nyanja": "nya_Latn",
"Occitan": "oci_Latn",
"West Central Oromo": "gaz_Latn",
"Odia": "ory_Orya",
"Pangasinan": "pag_Latn",
"Eastern Panjabi": "pan_Guru",
"Papiamento": "pap_Latn",
"Western Persian": "pes_Arab",
"Polish": "pol_Latn",
"Portuguese": "por_Latn",
"Dari": "prs_Arab",
"Southern Pashto": "pbt_Arab",
"Ayacucho Quechua": "quy_Latn",
"Romanian": "ron_Latn",
"Rundi": "run_Latn",
"Russian": "rus_Cyrl",
"Sango": "sag_Latn",
"Sanskrit": "san_Deva",
"Santali": "sat_Olck",
"Sicilian": "scn_Latn",
"Shan": "shn_Mymr",
"Sinhala": "sin_Sinh",
"Slovak": "slk_Latn",
"Slovenian": "slv_Latn",
"Samoan": "smo_Latn",
"Shona": "sna_Latn",
"Sindhi": "snd_Arab",
"Somali": "som_Latn",
"Southern Sotho": "sot_Latn",
"Spanish": "spa_Latn",
"Tosk Albanian": "als_Latn",
"Sardinian": "srd_Latn",
"Serbian": "srp_Cyrl",
"Swati": "ssw_Latn",
"Sundanese": "sun_Latn",
"Swedish": "swe_Latn",
"Swahili": "swh_Latn",
"Silesian": "szl_Latn",
"Tamil": "tam_Taml",
"Tatar": "tat_Cyrl",
"Telugu": "tel_Telu",
"Tajik": "tgk_Cyrl",
"Tagalog": "tgl_Latn",
"Thai": "tha_Thai",
"Tigrinya": "tir_Ethi",
"Tamasheq (Latin script)": "taq_Latn",
"Tamasheq (Tifinagh script)": "taq_Tfng",
"Tok Pisin": "tpi_Latn",
"Tswana": "tsn_Latn",
"Tsonga": "tso_Latn",
"Turkmen": "tuk_Latn",
"Tumbuka": "tum_Latn",
"Turkish": "tur_Latn",
"Twi": "twi_Latn",
"Central Atlas Tamazight": "tzm_Tfng",
"Uyghur": "uig_Arab",
"Ukrainian": "ukr_Cyrl",
"Umbundu": "umb_Latn",
"Urdu": "urd_Arab",
"Northern Uzbek": "uzn_Latn",
"Venetian": "vec_Latn",
"Vietnamese": "vie_Latn",
"Waray": "war_Latn",
"Wolof": "wol_Latn",
"Xhosa": "xho_Latn",
"Eastern Yiddish": "ydd_Hebr",
"Yoruba": "yor_Latn",
"Yue Chinese": "yue_Hant",
"Chinese (Simplified)": "zho_Hans",
"Chinese (Traditional)": "zho_Hant",
"Standard Malay": "zsm_Latn",
"Zulu": "zul_Latn"
}
================================================
FILE: scripts/render_available_langs.py
================================================
import os
import json
from jinja2 import Template
def load_json(name):
    """Load the language -> code mapping stored in langs_coverage/<name>.json.

    Fixes: the original `open(filepath).read()` never closed the file handle;
    a context manager closes it deterministically, and the encoding is pinned
    to UTF-8 so non-ASCII language names decode consistently across platforms.
    """
    filepath = os.path.join(os.path.dirname(__file__), "langs_coverage", f"{name}.json")
    with open(filepath, encoding="utf-8") as f:
        return json.load(f)
def _markdown_table(lang2code):
    # Render one markdown table (header + one row per language/code pair).
    rows = ["| Language Name | Code |\n", "| --- | --- |\n"]
    rows += [f"| {lang} | {code} |\n" for lang, code in lang2code.items()]
    return "".join(rows)


# One rendered table per model family, keyed by its template placeholder name.
template_values = {
    name: _markdown_table(load_json(name)) for name in ["m2m100", "mbart50", "nllb200"]
}

template_path = os.path.join(
    os.path.dirname(__file__), "templates", "available_languages.md.jinja2"
)
save_path = os.path.join(
    os.path.dirname(__file__), "..", "docs", "available_languages.md"
)

with open(template_path) as f:
    rendered = Template(f.read()).render(template_values)

with open(save_path, "w") as f:
    f.write(rendered)
================================================
FILE: scripts/render_references.py
================================================
import os
from typing import NamedTuple, List, Optional, Any, NamedTuple
import inspect
from jinja2 import Template
import dl_translate as dlt
# Render-friendly names for common annotation objects; `inspect._empty`
# marks a parameter that carries no annotation at all.
type2str = {
    int: "int",
    float: "float",
    str: "str",
    bool: "bool",
    dict: "dict",
    inspect._empty: "unspecified",
    Optional[Any]: "optional",
}

# Rendered form of special default values: no default -> required argument,
# None -> optional argument; everything else is shown verbatim in backticks.
default2str = {inspect._empty: "*required*", None: "*optional*"}


def preprocess_annot(annotation):
    """Return the markdown-friendly name of a parameter annotation."""
    label = type2str.get(annotation, str(annotation))
    # Strip the "typing." prefix so e.g. typing.Union[...] reads as Union[...].
    return label.replace("typing.", "")


def preprocess_default(default):
    """Return the markdown-friendly rendering of a parameter's default value."""
    return default2str.get(default, f"`{default}`")
class FunctionReference:
    """Wraps a callable so its signature and docstring can be rendered as markdown."""

    def __init__(self, function, modname=None):
        self.func = function
        if modname is not None:
            self.modname = modname
        else:
            # Derive a display module name, shortening "dl_translate" to "dlt".
            module = inspect.getmodule(self.func)
            self.modname = module.__name__.replace("dl_translate", "dlt")

    @property
    def name(self):
        """Bare function name."""
        return self.func.__name__

    @property
    def signature(self):
        """`inspect.Signature` of the wrapped callable."""
        return inspect.signature(self.func)

    @property
    def sig_desc(self):
        """Fully-qualified signature line, e.g. "dlt.utils.available_codes(...)"."""
        return f"{self.modname}.{self.name}{self.signature}"

    @property
    def doc(self):
        """Docstring rendered as a jinja2 template.

        The docstring may reference {{params}} (the markdown table header)
        and one {{<arg_name>}} placeholder per parameter, each expanding to
        a markdown table row of name/type/default.
        """
        header = "| Parameter | Type | Default | Description |\n|-|-|-|-|"
        rows = {"params": header}
        for arg_name, param in self.signature.parameters.items():
            if arg_name == "self":
                continue  # `self` is not part of the public signature.
            annot = preprocess_annot(param.annotation)
            default = preprocess_default(param.default)
            rows[arg_name] = f"| **{arg_name}** | *{annot}* | {default} |"
        return Template(inspect.getdoc(self.func)).render(**rows)
# Record pairing a docs section heading (e.g. "dlt.utils") with the
# FunctionReference objects rendered beneath it.
ModuleReferences = NamedTuple(
    "ModuleReferences", [("name", str), ("funcs", List["FunctionReference"])]
)
template_path = os.path.join(
    os.path.dirname(__file__), "templates", "references.md.jinja2"
)
save_path = os.path.join(os.path.dirname(__file__), "..", "docs", "references.md")

# Methods of dlt.TranslationModel documented in the reference, in display order.
_MODEL_METHOD_NAMES = [
    "__init__",
    "translate",
    "get_transformers_model",
    "get_tokenizer",
    "available_codes",
    "available_languages",
    "get_lang_code_map",
    "save_obj",
    "load_obj",
]

modules = [
    ModuleReferences(
        "dlt.TranslationModel",
        [
            FunctionReference(
                getattr(dlt.TranslationModel, method), "dlt.TranslationModel"
            )
            for method in _MODEL_METHOD_NAMES
        ],
    ),
    ModuleReferences(
        "dlt.utils",
        [
            FunctionReference(dlt.utils.get_lang_code_map),
            FunctionReference(dlt.utils.available_codes),
            FunctionReference(dlt.utils.available_languages),
        ],
    ),
]

with open(template_path) as f:
    template = Template(f.read())

with open(save_path, "w") as f:
    f.write(template.render(modules=modules))
================================================
FILE: scripts/templates/available_languages.md.jinja2
================================================
# Languages Available
This page gives all the languages available for each model family.
## MBart 50
{{mbart50}}
## M2M-100
{{m2m100}}
## NLLB-200
{{nllb200}}
================================================
FILE: scripts/templates/references.md.jinja2
================================================
# API Reference
{% for module in modules %}
## {{module.name}}
{% for func in module.funcs %}
### {{ func.name }}
```python
{{ func.sig_desc }}
```
{{ func.doc }}
{% endfor %}
{% endfor %}
================================================
FILE: setup.py
================================================
import setuptools
# The README doubles as the PyPI long description.
with open("README.md", "r", encoding="utf-8") as readme:
    LONG_DESCRIPTION = readme.read()

CLASSIFIERS = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

# Runtime dependencies; sentencepiece/protobuf are needed by the tokenizers.
INSTALL_REQUIRES = [
    "transformers>=4.30.2",
    "torch>=2.0.0",
    "sentencepiece",
    "protobuf",
    "tqdm",
]

setuptools.setup(
    name="dl-translate",
    version="0.3.1",
    author="Xing Han Lu",
    author_email="github@xinghanlu.com",
    description="A deep learning-based translation library built on Huggingface transformers",
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url="https://github.com/xhlulu/dl-translate",
    classifiers=CLASSIFIERS,
    packages=setuptools.find_packages(),
    python_requires=">=3.7",
    install_requires=INSTALL_REQUIRES,
    extras_require={"dev": ["pytest", "black", "jinja2", "mkdocs", "mkdocs-material"]},
)
================================================
FILE: tests/long/test_save_load.py
================================================
import os
import dl_translate as dlt
def test_save():
    """save_obj writes both the torch weights and the tokenizer config to disk."""
    mt = dlt.TranslationModel()
    mt.save_obj("saved_model")
    for artifact in ("saved_model/weights.pt", "saved_model/tokenizer_config.json"):
        assert os.path.exists(artifact)
def test_load():
    """load_obj round-trips the model previously saved by test_save."""
    restored = dlt.TranslationModel.load_obj("saved_model")
    assert isinstance(restored, dlt.TranslationModel)
================================================
FILE: tests/long/test_translate.py
================================================
import dl_translate as dlt
def test_translate():
    """Default model: exact-output check plus batching and target-language sanity checks."""
    mt = dlt.TranslationModel()
    msg_en = "Hello everyone, how are you?"
    assert (
        mt.translate(msg_en, source="English", target="Spanish")
        == "Hola a todos, ¿cómo estás?"
    )
    fr_single = mt.translate(msg_en, source="English", target="French")
    zh = mt.translate(msg_en, source="English", target="Chinese")
    fr_batch = mt.translate([msg_en, msg_en + msg_en], source="English", target="French")
    # Batched translation must agree with the single-input translation, and
    # different target languages must yield different outputs.
    assert fr_single == fr_batch[0]
    assert zh != fr_single
def test_mbart50():
    """mbart50 model: batching and target-language sanity checks."""
    mt = dlt.TranslationModel("mbart50")
    msg_en = "Hello everyone, how are you?"
    fr_single = mt.translate(msg_en, source="English", target="French")
    zh = mt.translate(msg_en, source="English", target="Chinese")
    fr_batch = mt.translate([msg_en, msg_en + msg_en], source="English", target="French")
    # Batched translation must agree with the single-input translation, and
    # different target languages must yield different outputs.
    assert fr_single == fr_batch[0]
    assert zh != fr_single
================================================
FILE: tests/quick/test_lang.py
================================================
import dl_translate as dlt
from dl_translate._pairs import _PAIRS_MBART50, _PAIRS_M2M100
def test_lang():
    """Every m2m100 language name is exposed as a constant on dlt.lang."""
    for language, _ in _PAIRS_M2M100:
        attr = language.upper().replace(" ", "_")
        assert getattr(dlt.lang, attr) == language
def test_lang_m2m100():
    """Every m2m100 language name is exposed as a constant on dlt.lang.m2m100."""
    for language, _ in _PAIRS_M2M100:
        attr = language.upper().replace(" ", "_")
        assert getattr(dlt.lang.m2m100, attr) == language
def test_lang_mbart50():
    """Every mbart50 language name is exposed as a constant on dlt.lang.mbart50."""
    for language, _ in _PAIRS_MBART50:
        attr = language.upper().replace(" ", "_")
        assert getattr(dlt.lang.mbart50, attr) == language
================================================
FILE: tests/quick/test_translation_model.py
================================================
import pytest
import torch
import dl_translate as dlt
from dl_translate._translation_model import (
_resolve_lang_codes,
_select_device,
_infer_model_or_path,
_infer_model_family,
)
def test_resolve_lang_codes_mbart50():
    """Language constants, raw codes, and plain names all resolve to mbart50 codes."""
    cases = [
        (dlt.lang.FRENCH, dlt.lang.ENGLISH),
        ("fr_XX", "en_XX"),
        ("French", "English"),
    ]
    for source, target in cases:
        assert _resolve_lang_codes(source, "source", "mbart50") == "fr_XX"
        assert _resolve_lang_codes(target, "target", "mbart50") == "en_XX"
def test_resolve_lang_codes_m2m100():
    """Language constants, raw codes, and plain names all resolve to m2m100 codes."""
    cases = [
        (dlt.lang.m2m100.FRENCH, dlt.lang.m2m100.ENGLISH),
        ("fr", "en"),
        ("French", "English"),
    ]
    for source, target in cases:
        assert _resolve_lang_codes(source, "source", "m2m100") == "fr"
        assert _resolve_lang_codes(target, "target", "m2m100") == "en"
def test_resolve_lang_codes_nllb200():
    """NLLB-200 resolution from constants, raw codes, and plain names.

    Fix: this function was named `test_resolve_lang_codes_m2m100`, duplicating
    the definition above it; the second definition shadowed the first, so the
    real m2m100 test never ran under pytest. Renamed to match its content.
    """
    sources = [dlt.lang.nllb200.FRENCH, "fra_Latn", "French"]
    targets = [dlt.lang.nllb200.ENGLISH, "eng_Latn", "English"]
    for source, target in zip(sources, targets):
        s = _resolve_lang_codes(source, "source", "nllb200")
        t = _resolve_lang_codes(target, "target", "nllb200")
        assert s == "fra_Latn"
        assert t == "eng_Latn"
    # Names containing parentheses and apostrophes must also resolve.
    sources = ["Central Kanuri (Latin script)"]
    targets = ["Ta'izzi-Adeni Arabic"]
    for source, target in zip(sources, targets):
        s = _resolve_lang_codes(source, "source", "nllb200")
        t = _resolve_lang_codes(target, "target", "nllb200")
        assert s == "knc_Latn"
        assert t == "acq_Arab"
def test_select_device():
    """Explicit device strings map to torch devices; 'auto' picks CUDA when available."""
    expected = {
        "cpu": torch.device("cpu"),
        "gpu": torch.device("cuda"),
        "cuda:0": torch.device("cuda", index=0),
    }
    for spec, device in expected.items():
        assert _select_device(spec) == device
    auto = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    assert _select_device("auto") == auto
def test_infer_model_or_path():
    """Shorthands expand to full HF model names; unknown values pass through unchanged."""
    expected = {
        "mbart50": "facebook/mbart-large-50-many-to-many-mmt",
        "m2m100": "facebook/m2m100_418M",
        "m2m100-small": "facebook/m2m100_418M",
        "m2m100-medium": "facebook/m2m100_1.2B",
        "non-existing-value": "non-existing-value",
    }
    for shorthand, full_name in expected.items():
        assert _infer_model_or_path(shorthand) == full_name
def test_infer_model_family():
    """Known HF paths map to their model family; anything else raises ValueError."""
    cases = {
        "facebook/mbart-large-50-many-to-many-mmt": "mbart50",
        "facebook/m2m100_418M": "m2m100",
        "facebook/m2m100_1.2B": "m2m100",
    }
    for path, family in cases.items():
        assert _infer_model_family(path) == family
    with pytest.raises(ValueError):
        _infer_model_family("non-existing-value")
================================================
FILE: tests/quick/test_utils.py
================================================
import pytest
from dl_translate import utils
from dl_translate._pairs import _PAIRS_MBART50, _PAIRS_M2M100, _PAIRS_NLLB200
def test_dict_from_weights():
    """Every supported weights alias yields a dict exposing langs/codes/pairs."""
    aliases = [
        "mbart50",
        "mbart-large-50-many-to-many-mmt",
        "facebook/mbart-large-50-many-to-many-mmt",
        "m2m100",
        "m2m100_418M",
        "m2m100_1.2B",
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]
    for alias in aliases:
        result = utils._dict_from_weights(alias)
        assert type(result) is dict
        for key in ("langs", "codes", "pairs"):
            assert key in result.keys()
def test_dict_from_weights_exception():
    """An unrecognized weights name must raise ValueError."""
    with pytest.raises(ValueError):
        utils._dict_from_weights("incorrect")
def test_available_languages():
    """Languages from every pairs table are reported as available.

    Fix: the first assertion previously compared `available_languages()`
    with itself — a tautology that always passed. It now checks that the
    default weights are equivalent to explicitly requesting "m2m100",
    mirroring `test_available_codes`.
    """
    assert utils.available_languages() == utils.available_languages("m2m100")
    langs = utils.available_languages()
    for lang, _ in _PAIRS_M2M100:
        assert lang in langs
    langs = utils.available_languages("mbart50")
    for lang, _ in _PAIRS_MBART50:
        assert lang in langs
    langs = utils.available_languages("nllb200")
    for lang, _ in _PAIRS_NLLB200:
        assert lang in langs
def test_available_codes():
    """Codes from every pairs table are reported as available."""
    # The default weights must behave like an explicit "m2m100" request.
    assert utils.available_codes() == utils.available_codes("m2m100")
    for weights, pairs in [
        ("m2m100", _PAIRS_M2M100),
        ("mbart50", _PAIRS_MBART50),
        ("nllb200", _PAIRS_NLLB200),
    ]:
        codes = utils.available_codes(weights)
        for _, code in pairs:
            assert code in codes