Repository: allenai/tango
Branch: main
Commit: 6aaa8ff0f203
Files: 329
Total size: 1.2 MB
Directory structure:
gitextract_p3vof5f6/
├── .dockerignore
├── .github/
│ ├── CONTRIBUTING.md
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.yml
│ │ ├── documentation.yml
│ │ └── feature_request.yml
│ ├── dependabot.yml
│ └── workflows/
│ ├── changelog.yml
│ ├── docker.yml
│ ├── docker_testing.yml
│ ├── integration_tests.yml
│ ├── main.yml
│ └── update_dependency_pr.yml
├── .gitignore
├── .readthedocs.yaml
├── CHANGELOG.md
├── CITATION.cff
├── Dockerfile
├── Dockerfile.test
├── LICENSE
├── Makefile
├── README.md
├── RELEASE_PROCESS.md
├── docs/
│ ├── .gitignore
│ ├── Makefile
│ ├── make.bat
│ └── source/
│ ├── _static/
│ │ └── css/
│ │ └── custom.css
│ ├── api/
│ │ ├── commands.rst
│ │ ├── components/
│ │ │ ├── executor.rst
│ │ │ ├── format.rst
│ │ │ ├── index.rst
│ │ │ ├── step.rst
│ │ │ ├── step_cache.rst
│ │ │ ├── step_graph.rst
│ │ │ ├── step_info.rst
│ │ │ └── workspace.rst
│ │ ├── det_hash.rst
│ │ ├── exceptions.rst
│ │ ├── integrations/
│ │ │ ├── beaker.rst
│ │ │ ├── datasets.rst
│ │ │ ├── fairscale.rst
│ │ │ ├── flax.rst
│ │ │ ├── gs.rst
│ │ │ ├── index.rst
│ │ │ ├── torch.rst
│ │ │ ├── transformers.rst
│ │ │ └── wandb.rst
│ │ ├── logging.rst
│ │ ├── sequences.rst
│ │ ├── settings.rst
│ │ └── utilities.rst
│ ├── conf.py
│ ├── examples/
│ │ ├── euler.md
│ │ ├── eval_p3.md
│ │ ├── index.rst
│ │ └── train_lm.md
│ ├── faq.md
│ ├── first_steps.md
│ ├── index.md
│ └── installation.md
├── examples/
│ ├── euler/
│ │ ├── README.md
│ │ ├── complex_arithmetic.py
│ │ ├── euler.jsonnet
│ │ ├── euler_general.jsonnet
│ │ └── run.sh
│ ├── eval_p3/
│ │ ├── README.md
│ │ ├── config.jsonnet
│ │ └── eval.py
│ ├── finetune/
│ │ ├── __init__.py
│ │ ├── config.jsonnet
│ │ ├── snli_steps.py
│ │ └── test.py
│ ├── finetune_resnet/
│ │ ├── .gitignore
│ │ ├── config.jsonnet
│ │ └── resnet_steps.py
│ ├── flax/
│ │ ├── config.jsonnet
│ │ ├── run.sh
│ │ └── xsum.py
│ └── train_lm/
│ ├── .gitignore
│ ├── README.md
│ ├── config.jsonnet
│ ├── test.py
│ └── tokenize_step.py
├── integration_tests/
│ ├── README.md
│ └── fairscale_benchmarks/
│ ├── README.md
│ ├── config.jsonnet
│ └── run.sh
├── pyproject.toml
├── scripts/
│ ├── entrypoint.sh
│ ├── hash_extras.py
│ ├── prepare_changelog.py
│ ├── prepare_citation_cff.py
│ ├── release.sh
│ └── release_notes.py
├── tango/
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── aliases.py
│ │ ├── dataset_dict.py
│ │ ├── det_hash.py
│ │ ├── exceptions.py
│ │ ├── file_lock.py
│ │ ├── from_params.py
│ │ ├── lazy.py
│ │ ├── logging.py
│ │ ├── params.py
│ │ ├── registrable.py
│ │ ├── remote_utils.py
│ │ ├── sequences.py
│ │ ├── testing/
│ │ │ ├── __init__.py
│ │ │ └── steps.py
│ │ ├── tqdm.py
│ │ └── util.py
│ ├── executor.py
│ ├── executors/
│ │ ├── __init__.py
│ │ └── multicore_executor.py
│ ├── format.py
│ ├── integrations/
│ │ ├── __init__.py
│ │ ├── beaker/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── entrypoint.sh
│ │ │ ├── executor.py
│ │ │ ├── step_cache.py
│ │ │ └── workspace.py
│ │ ├── datasets/
│ │ │ └── __init__.py
│ │ ├── fairscale/
│ │ │ ├── __init__.py
│ │ │ ├── fsdp_config.py
│ │ │ ├── module_wrapper.py
│ │ │ └── training_engine.py
│ │ ├── flax/
│ │ │ ├── __init__.py
│ │ │ ├── data.py
│ │ │ ├── eval.py
│ │ │ ├── eval_callback.py
│ │ │ ├── format.py
│ │ │ ├── model.py
│ │ │ ├── optim.py
│ │ │ ├── train.py
│ │ │ ├── train_callback.py
│ │ │ ├── train_config.py
│ │ │ ├── util.py
│ │ │ └── wrapper.py
│ │ ├── gs/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── step_cache.py
│ │ │ └── workspace.py
│ │ ├── torch/
│ │ │ ├── __init__.py
│ │ │ ├── data.py
│ │ │ ├── eval.py
│ │ │ ├── eval_callback.py
│ │ │ ├── exceptions.py
│ │ │ ├── format.py
│ │ │ ├── model.py
│ │ │ ├── optim.py
│ │ │ ├── train.py
│ │ │ ├── train_callback.py
│ │ │ ├── train_config.py
│ │ │ ├── training_engine.py
│ │ │ └── util.py
│ │ ├── transformers/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── data.py
│ │ │ ├── finetune.py
│ │ │ ├── ia3.py
│ │ │ ├── model.py
│ │ │ ├── optim.py
│ │ │ ├── run_generation.py
│ │ │ ├── soft_prompt.py
│ │ │ └── tokenizer.py
│ │ └── wandb/
│ │ ├── __init__.py
│ │ ├── flax_train_callback.py
│ │ ├── step_cache.py
│ │ ├── torch_train_callback.py
│ │ ├── util.py
│ │ └── workspace.py
│ ├── py.typed
│ ├── settings.py
│ ├── step.py
│ ├── step_cache.py
│ ├── step_caches/
│ │ ├── __init__.py
│ │ ├── local_step_cache.py
│ │ ├── memory_step_cache.py
│ │ └── remote_step_cache.py
│ ├── step_graph.py
│ ├── step_info.py
│ ├── steps/
│ │ ├── __init__.py
│ │ ├── dataset_remix.py
│ │ ├── print.py
│ │ └── shell_step.py
│ ├── version.py
│ ├── workspace.py
│ └── workspaces/
│ ├── __init__.py
│ ├── local_workspace.py
│ ├── memory_workspace.py
│ └── remote_workspace.py
├── test_fixtures/
│ ├── __init__.py
│ ├── beaker/
│ │ └── nvidia_smi.yml
│ ├── common/
│ │ ├── params_example.jsonnet
│ │ └── params_example.yaml
│ ├── experiment/
│ │ ├── hello_world.jsonnet
│ │ ├── logging_check.jsonnet
│ │ ├── multiprocessing.jsonnet
│ │ ├── noisy.jsonnet
│ │ └── random.jsonnet
│ ├── integrations/
│ │ ├── __init__.py
│ │ ├── common/
│ │ │ └── __init__.py
│ │ ├── datasets/
│ │ │ └── config.json
│ │ ├── fairscale/
│ │ │ ├── __init__.py
│ │ │ ├── components.py
│ │ │ └── config.jsonnet
│ │ ├── flax/
│ │ │ ├── __init__.py
│ │ │ ├── config.jsonnet
│ │ │ └── xsum.py
│ │ └── torch/
│ │ ├── __init__.py
│ │ ├── eval.jsonnet
│ │ ├── train.jsonnet
│ │ ├── train_dist.jsonnet
│ │ └── train_streaming.jsonnet
│ └── v1_local_workspace/
│ └── cache/
│ ├── AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/
│ │ ├── cache-metadata.json
│ │ ├── conda-environment.yaml
│ │ ├── executor-metadata.json
│ │ ├── lock
│ │ ├── requirements.txt
│ │ └── stepinfo.dill
│ ├── CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/
│ │ ├── cache-metadata.json
│ │ ├── conda-environment.yaml
│ │ ├── executor-metadata.json
│ │ ├── lock
│ │ ├── requirements.txt
│ │ └── stepinfo.dill
│ ├── ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/
│ │ ├── cache-metadata.json
│ │ ├── conda-environment.yaml
│ │ ├── executor-metadata.json
│ │ ├── lock
│ │ ├── requirements.txt
│ │ └── stepinfo.dill
│ ├── MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/
│ │ ├── cache-metadata.json
│ │ ├── conda-environment.yaml
│ │ ├── executor-metadata.json
│ │ ├── lock
│ │ ├── requirements.txt
│ │ └── stepinfo.dill
│ ├── MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/
│ │ ├── cache-metadata.json
│ │ ├── conda-environment.yaml
│ │ ├── executor-metadata.json
│ │ ├── lock
│ │ ├── requirements.txt
│ │ └── stepinfo.dill
│ ├── SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/
│ │ ├── cache-metadata.json
│ │ ├── conda-environment.yaml
│ │ ├── executor-metadata.json
│ │ ├── lock
│ │ ├── requirements.txt
│ │ └── stepinfo.dill
│ └── SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/
│ ├── cache-metadata.json
│ ├── conda-environment.yaml
│ ├── executor-metadata.json
│ ├── lock
│ ├── requirements.txt
│ └── stepinfo.dill
└── tests/
├── __init__.py
├── common/
│ ├── __init__.py
│ ├── dataset_dict_test.py
│ ├── det_hash_test.py
│ ├── from_params_pep_563_test.py
│ ├── from_params_test.py
│ ├── params_test.py
│ ├── registrable_test.py
│ ├── sequences_test.py
│ └── util_test.py
├── end_to_end/
│ ├── test_dataset_dict_from_separate_steps.py
│ ├── test_lazy_input_with_another_step.py
│ ├── test_multicore_cli.py
│ ├── test_non_cacheable_into_cacheable_multiple_runs.py
│ ├── test_registered_runs.py
│ ├── test_run_single_step.py
│ ├── test_step_indexing.py
│ ├── test_steps_that_fail.py
│ └── test_uncacheable_leaf_steps.py
├── executor_test.py
├── executors/
│ ├── __init__.py
│ └── multicore_executor_test.py
├── format_test.py
├── integrations/
│ ├── __init__.py
│ ├── beaker/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── executor_test.py
│ │ ├── step_cache_test.py
│ │ └── workspace_test.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ └── dataset_test.py
│ ├── fairscale/
│ │ ├── __init__.py
│ │ └── train_test.py
│ ├── flax/
│ │ ├── __init__.py
│ │ ├── data_test.py
│ │ ├── format_test.py
│ │ ├── optim_test.py
│ │ └── train_test.py
│ ├── gs/
│ │ ├── __init__.py
│ │ ├── step_cache_test.py
│ │ └── workspace_test.py
│ ├── torch/
│ │ ├── __init__.py
│ │ ├── data_test.py
│ │ ├── det_hash_test.py
│ │ ├── eval_test.py
│ │ ├── format_test.py
│ │ ├── optim_test.py
│ │ ├── train_callback_test.py
│ │ ├── train_test.py
│ │ └── training_engine_test.py
│ ├── transformers/
│ │ ├── data_test.py
│ │ ├── finetune_test.py
│ │ ├── ia3_test.py
│ │ ├── run_generation_test.py
│ │ └── soft_prompt_test.py
│ └── wandb/
│ ├── __init__.py
│ ├── step_cache_test.py
│ └── workspace_test.py
├── main_test.py
├── step_caches/
│ ├── __init__.py
│ └── local_step_cache_test.py
├── step_graph_test.py
├── step_info_test.py
├── step_test.py
├── steps/
│ ├── __init__.py
│ ├── dataset_remix_test.py
│ └── shell_step_test.py
└── workspaces/
├── __init__.py
├── local_workspace_test.py
└── memory_workspace_test.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .dockerignore
================================================
.dockerignore
**.pyc
**/__pycache__
.gitignore
.git
.coverage
.mypy_cache
docs
examples
tests
test_fixtures
integration_tests
dist
*.egg-info
================================================
FILE: .github/CONTRIBUTING.md
================================================
# Contributing
Thanks for considering contributing! Please read this document to learn the various ways you can contribute to this project and how to go about doing it.
## Bug reports and feature requests
### Did you find a bug?
First, do [a quick search](https://github.com/allenai/tango/issues) to see whether your issue has already been reported.
If your issue has already been reported, please comment on the existing issue.
Otherwise, open [a new GitHub issue](https://github.com/allenai/tango/issues). Be sure to include a clear title
and description. The description should include as much relevant information as possible. The description should
explain how to reproduce the erroneous behavior as well as the behavior you expect to see. Ideally you would include a
code sample or an executable test case demonstrating the expected behavior.
### Do you have a suggestion for an enhancement or new feature?
We use GitHub issues to track feature requests. Before you create a feature request:
- Make sure you have a clear idea of the enhancement you would like. If you have a vague idea, consider discussing
it first on a GitHub issue.
- Check the documentation to make sure your feature does not already exist.
- Do [a quick search](https://github.com/allenai/tango/issues) to see whether your feature has already been suggested.
When creating your request, please:
- Provide a clear title and description.
- Explain why the enhancement would be useful. It may be helpful to highlight the feature in other libraries.
- Include code examples to demonstrate how the enhancement would be used.
## Making a pull request
When you're ready to contribute code to address an open issue, please follow these guidelines to help us be able to review your pull request (PR) quickly.
1. **Initial setup** (only do this once)
Expand details 👇
If you haven't already done so, please [fork](https://help.github.com/en/enterprise/2.13/user/articles/fork-a-repo) this repository on GitHub.
Then clone your fork locally with
git clone https://github.com/USERNAME/tango.git
or
git clone git@github.com:USERNAME/tango.git
At this point the local clone of your fork only knows that it came from _your_ repo, github.com/USERNAME/tango.git, but doesn't know anything about the _main_ repo, [https://github.com/allenai/tango.git](https://github.com/allenai/tango). You can see this by running
git remote -v
which will output something like this:
origin https://github.com/USERNAME/tango.git (fetch)
origin https://github.com/USERNAME/tango.git (push)
This means that your local clone can only track changes from your fork, but not from the main repo, and so you won't be able to keep your fork up-to-date with the main repo over time. Therefore you'll need to add another "remote" to your clone that points to [https://github.com/allenai/tango.git](https://github.com/allenai/tango). To do this, run the following:
git remote add upstream https://github.com/allenai/tango.git
Now if you do `git remote -v` again, you'll see
origin https://github.com/USERNAME/tango.git (fetch)
origin https://github.com/USERNAME/tango.git (push)
upstream https://github.com/allenai/tango.git (fetch)
upstream https://github.com/allenai/tango.git (push)
Finally, you'll need to create a Python 3 virtual environment suitable for working on this project. There are a number of tools out there that make working with virtual environments easier.
The most direct way is with the [`venv` module](https://docs.python.org/3.8/library/venv.html) in the standard library, but if you're new to Python or you don't already have a recent Python 3 version installed on your machine,
we recommend [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
On Mac, for example, you can install Miniconda with [Homebrew](https://brew.sh/):
brew install miniconda
Then you can create and activate a new Python environment by running:
conda create -n tango python=3.9
conda activate tango
Once your virtual environment is activated, you can install your local clone in "editable mode" with
pip install -U pip setuptools wheel
pip install -e '.[dev,all]'
The "editable mode" comes from the `-e` argument to `pip`, and essentially just creates a symbolic link from the site-packages directory of your virtual environment to the source code in your local clone. That way any changes you make will be immediately reflected in your virtual environment.
To test your installation, just run
tango info
2. **Ensure your fork is up-to-date**
Expand details 👇
Once you've added an "upstream" remote pointing to [https://github.com/allenai/tango.git](https://github.com/allenai/tango), keeping your fork up-to-date is easy:
git checkout main # if not already on main
git pull --rebase upstream main
git push
3. **Create a new branch to work on your fix or enhancement**
Expand details 👇
Committing directly to the main branch of your fork is not recommended. It will be easier to keep your fork clean if you work on a separate branch for each contribution you intend to make.
You can create a new branch with
# replace BRANCH with whatever name you want to give it
git checkout -b BRANCH
git push -u origin BRANCH
4. **Test your changes**
Expand details 👇
Our continuous integration (CI) testing runs [a number of checks](https://github.com/allenai/tango/actions) for each pull request on [GitHub Actions](https://github.com/features/actions). You can run most of these tests locally, which is something you should do _before_ opening a PR to help speed up the review process and make it easier for us.
First, you should run [`isort`](https://github.com/PyCQA/isort) and [`black`](https://github.com/psf/black) to make sure your code is formatted consistently.
Many IDEs support code formatters as plugins, so you may be able to set up isort and black to run automatically every time you save.
For example, [`black.vim`](https://github.com/psf/black/tree/master/plugin) will give you this functionality in Vim. But both `isort` and `black` are also easy to run directly from the command line.
Just run this from the root of your clone:
isort .
black .
Our CI also uses [`ruff`](https://github.com/charliermarsh/ruff) to lint the code base and [`mypy`](http://mypy-lang.org/) for type-checking. You should run both of these next with
ruff check .
and
mypy .
We also strive to maintain high test coverage, so most contributions should include additions to [the unit tests](https://github.com/allenai/tango/tree/main/tests). These tests are run with [`pytest`](https://docs.pytest.org/en/latest/), which you can use to locally run any test modules that you've added or changed.
For example, if you've fixed a bug in `tango/a/b.py`, you can run the tests specific to that module with
pytest -v tests/a/b_test.py
If your contribution involves additions to any public part of the API, we require that you write docstrings
for each function, method, class, or module that you add.
See the [Writing docstrings](#writing-docstrings) section below for details on the syntax.
You should test to make sure the API documentation can build without errors by running
make docs
If the build fails, it's most likely due to small formatting issues. If the error message isn't clear, feel free to comment on this in your pull request.
And finally, please update the [CHANGELOG](https://github.com/allenai/tango/blob/main/CHANGELOG.md) with notes on your contribution in the "Unreleased" section at the top.
After all of the above checks have passed, you can now open [a new GitHub pull request](https://github.com/allenai/tango/pulls).
Make sure you have a clear description of the problem and the solution, and include a link to relevant issues.
We look forward to reviewing your PR!
### Writing docstrings
We use [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to build our API docs, which automatically parses all docstrings
of public classes and methods. All docstrings should adhere to the [Numpy styling convention](https://www.sphinx-doc.org/en/master/usage/extensions/example_numpy.html).
## Adding a new integration
In order to add a new integration, there are several additional steps and guidelines you should follow
in addition to everything listed in [Making a pull request](#making-a-pull-request).
1. First start by creating a new submodule `tango.integrations.name_of_integration` and put all of the code for your integration in there.
2. Then you must add a module docstring to the `__init__.py` file of the submodule which imports all of the public components of the integration,
and defines the [`__all__`](https://docs.python.org/3/tutorial/modules.html#importing-from-a-package) special variable to include all of those components.
This ensures all of the public components will show up in the documentation.
3. Next, you should add unit tests of your code to `tests/integrations/name_of_integration/`.
4. Then add a new file `docs/source/api/integrations/name_of_integration.rst`, and include the directive:
```
.. automodule:: tango.integrations.name_of_integration
:members:
```
Take a look at any of the other files in that folder to see how it should look exactly.
5. And then add `name_of_integration` to the `toctree` in `docs/source/api/integrations/index.rst`.
6. After that, add any additional requirements that your integration depends on to `requirements.txt`. Be sure to put those under the "Extra dependencies for integrations" section,
and add the special inline comment `# needed by: name_of_integration`.
7. And finally, in the `checks` job definition in `.github/workflows/main.yml`, add a new object
to the matrix for your integration following the other examples there.
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: 🐛 Bug Report
description: Create a report to help us reproduce and fix the bug
labels: 'bug'
body:
- type: markdown
attributes:
value: >
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/allenai/tango/issues?q=is%3Aissue+sort%3Acreated-desc+).
- type: textarea
attributes:
label: 🐛 Describe the bug
description: |
Please provide a clear and concise description of what the bug is.
If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:
```python
# All necessary imports at the beginning
import tango
# A succinct reproducing example trimmed down to the essential parts:
assert False is True, "Oh no!"
```
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
Please also paste or describe the results you observe along with the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
placeholder: |
A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
attributes:
label: Versions
description: |
Please run the following and paste the output below.
```sh
python --version && pip freeze
```
validations:
required: true
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
================================================
FILE: .github/ISSUE_TEMPLATE/documentation.yml
================================================
name: 📚 Documentation
description: Report an issue related to https://ai2-tango.readthedocs.io/latest
labels: 'documentation'
body:
- type: textarea
attributes:
label: 📚 The doc issue
description: >
A clear and concise description of what content in https://ai2-tango.readthedocs.io/latest is an issue.
validations:
required: true
- type: textarea
attributes:
label: Suggest a potential alternative/fix
description: >
Tell us how we could improve the documentation in this regard.
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: 🚀 Feature request
description: Submit a proposal/request for a new feature
labels: 'feature request'
body:
- type: textarea
attributes:
label: 🚀 The feature, motivation and pitch
description: >
A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
validations:
required: true
- type: textarea
attributes:
label: Alternatives
description: >
A description of any alternative solutions or features you've considered, if any.
- type: textarea
attributes:
label: Additional context
description: >
Add any other context or screenshots about the feature request.
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
================================================
FILE: .github/dependabot.yml
================================================
version: 2
updates:
- package-ecosystem: pip
directory: "/"
schedule:
interval: "daily"
open-pull-requests-limit: 10
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
================================================
FILE: .github/workflows/changelog.yml
================================================
name: Changelog
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
on:
pull_request:
branches:
- main
paths:
- 'tango/**'
jobs:
changelog:
name: CHANGELOG
runs-on: ubuntu-latest
if: github.event_name == 'pull_request'
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check that CHANGELOG has been updated
run: |
# If this step fails, this means you haven't updated the CHANGELOG.md
# file with notes on your contribution.
git diff --name-only $(git merge-base origin/main HEAD) | grep '^CHANGELOG.md$' && echo "Thanks for helping keep our CHANGELOG up-to-date!"
================================================
FILE: .github/workflows/docker.yml
================================================
name: Docker
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
on:
pull_request:
branches:
- main
paths:
- "Dockerfile"
- ".dockerignore"
- "pyproject.toml"
push:
tags:
- "v*.*.*"
jobs:
build:
name: Build (${{ matrix.build.tag }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build:
- base_image: ghcr.io/allenai/pytorch:1.12.1-cuda11.3-python3.9
tag: cuda11.3
env:
IMAGE_NAME: ghcr.io/allenai/tango
steps:
- uses: actions/checkout@v3
- name: Build Docker image
run: |
docker build --build-arg BASE_IMAGE=${{ matrix.build.base_image }} -t "${IMAGE_NAME}:${{ matrix.build.tag }}" .
- name: Test Docker image
run: |
docker run --rm "${IMAGE_NAME}:${{ matrix.build.tag }}" info
- name: Log in to ghcr.io
if: github.event_name != 'pull_request'
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
- name: Push latest to ghcr.io
if: github.event_name != 'pull_request'
run: |
docker push "${IMAGE_NAME}:${{ matrix.build.tag }}"
- name: Push release version to ghcr.io
if: startsWith(github.ref, 'refs/tags/')
run: |
GITHUB_TAG=${GITHUB_REF#refs/tags/}
docker tag "${IMAGE_NAME}:${{ matrix.build.tag }}" "${IMAGE_NAME}:${GITHUB_TAG}-${{ matrix.build.tag }}"
docker push "${IMAGE_NAME}:${GITHUB_TAG}-${{ matrix.build.tag }}"
================================================
FILE: .github/workflows/docker_testing.yml
================================================
# This workflow is just for building our Docker image for GPU testing on Beaker,
# and pushing it to Beaker. We only run it when the relevant Dockerfile (or .dockerignore) changes.
name: Docker testing
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
on:
pull_request:
branches:
- main
paths:
- 'Dockerfile.test'
- '.dockerignore'
- 'scripts/entrypoint.sh'
push:
branches:
- main
paths:
- 'Dockerfile.test'
- '.dockerignore'
- 'scripts/entrypoint.sh'
jobs:
build:
name: Build
runs-on: ubuntu-latest
env:
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_WORKSPACE: ai2/tango-testing
IMAGE_NAME: tango-testing
steps:
- uses: actions/checkout@v3
- uses: allenai/setup-beaker@v2
with:
token: ${{ secrets.BEAKER_TOKEN }}
workspace: ${{ env.BEAKER_WORKSPACE }}
- name: Build Docker image
run: |
docker build -t "$IMAGE_NAME" -f Dockerfile.test .
- name: Determine current commit SHA (pull request)
if: github.event_name == 'pull_request'
run: |
echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV
- name: Determine current commit SHA (push)
if: github.event_name != 'pull_request'
run: |
echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV
- name: Test Docker image
run: |
docker run --rm --env COMMIT_SHA="$COMMIT_SHA" "$IMAGE_NAME" tango info
# In order to push a new version of an image to beaker, we have to delete the old version first.
# This doesn't actually delete the backing Docker image, so we'll still benefit from layer
# caching when we push new versions. But we have to be careful to minimize the amount
# of time between deletion and creation, because during that time any Beaker job trying to start
# that depends on that image will fail. So to minimize this downtime, we first push a
# "temp" version of the image, then delete the current one and quickly rename the "temp" one to take its place.
# The image might not exist yet though, so it's okay if the delete fails.
- name: Delete existing commit image
continue-on-error: true
run: |
beaker image delete petew/${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }}
- name: Upload new commit image
run: |
beaker image create --workspace ${{ env.BEAKER_WORKSPACE }} --name ${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} ${{ env.IMAGE_NAME }}
- name: Delete existing image
if: github.event_name != 'pull_request'
continue-on-error: true
run: |
beaker image delete petew/${{ env.IMAGE_NAME }}
- name: Rename new commit image to final image
if: github.event_name != 'pull_request'
run: |
beaker image rename petew/${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} ${{ env.IMAGE_NAME }}
================================================
FILE: .github/workflows/integration_tests.yml
================================================
name: Integration tests
on:
workflow_dispatch:
inputs:
test:
description: the integration test to run
default: fairscale_benchmarks
required: true
type: choice
options:
- fairscale_benchmarks
cluster:
description: the beaker cluster to run the test on
default: ai2/tango-integration-tests
required: true
type: choice
options:
- ai2/tango-integration-tests
- ai2/allennlp-cirrascale
# Uncomment this trigger to test changes on a pull request.
# You also have to uncomment the lines below that mention 'for pull request checks'
# pull_request:
# branches:
# - '*'
jobs:
run_test:
name: ${{ github.event.inputs.test }}
# name: fairscale_benchmarks # for pull request checks
runs-on: [ubuntu-latest]
timeout-minutes: 60
env:
TEST_NAME: ${{ github.event.inputs.test }}
# TEST_NAME: fairscale_benchmarks # for pull request checks
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_WORKSPACE: ai2/tango-integration-tests
BEAKER_CLUSTER: ${{ github.event.inputs.cluster }}
# BEAKER_CLUSTER: ai2/allennlp-cirrascale # for pull request checks
IMAGE_NAME: petew/tango-testing
steps:
- uses: actions/checkout@v3
- name: Validate inputs
run: |
# The 'test' input should be a directory in `integration_tests/`
test -d "integration_tests/${TEST_NAME}"
- name: Determine current commit SHA (pull request)
if: github.event_name == 'pull_request'
run: |
echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV
- name: Determine current commit SHA (push)
if: github.event_name != 'pull_request'
run: |
echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV
- name: Install beaker client
shell: bash
run: |
mkdir -p "$HOME/bin"
# Download and install from latest GitHub release.
curl -s https://api.github.com/repos/allenai/beaker/releases/latest \
| grep 'browser_download_url.*linux' \
| cut -d '"' -f 4 \
| wget -qi - \
&& tar -xvzf beaker_linux.tar.gz -C "$HOME/bin"
# Add to path.
echo "$HOME/bin" >> "$GITHUB_PATH"
- name: Verify beaker install
run: |
beaker account whoami
- name: Create beaker experiment config
run: |
cat >beaker_config.yml << EOL
version: v2-alpha
description: ${{ env.TEST_NAME }}
tasks:
- name: test
image:
beaker: ${{ env.IMAGE_NAME }}
command: ["/entrypoint.sh", "integration_tests/${{ env.TEST_NAME }}/run.sh"]
envVars:
- name: COMMIT_SHA
value: $COMMIT_SHA
- name: WANDB_API_KEY
secret: WANDB_API_KEY
- name: FILE_FRIENDLY_LOGGING
value: "true"
- name: TOKENIZERS_PARALLELISM # set this to avoid warnings
value: "true"
- name: PYTHONUNBUFFERED
value: "true"
result:
path: '/results'
resources:
gpuCount: 4
context:
cluster: ${{ env.BEAKER_CLUSTER }}
priority: normal
EOL
cat beaker_config.yml
- name: Submit beaker job
run: |
TIMESTAMP=$(date +%H%M%S)
EXPERIMENT=$(beaker experiment create beaker_config.yml --workspace $BEAKER_WORKSPACE --name "${TEST_NAME}-${{ github.run_number }}-${TIMESTAMP}" | awk '{print $2}')
if [ -z "$EXPERIMENT" ]; then
exit 1
else
echo "EXPERIMENT=$EXPERIMENT" >> $GITHUB_ENV
echo "Experiment $EXPERIMENT submitted. See progress at https://beaker.org/ex/$EXPERIMENT"
fi
- name: Wait for job to finish
run: |
beaker experiment await $EXPERIMENT test finalized --timeout 60m
# Check the job's exit code.
test $(beaker experiment get $EXPERIMENT --format=json | jq '.[0].jobs[0].status.exitCode') -eq 0
- name: Get logs
if: always()
run: |
# EXPERIMENT could be empty if the submission step failed.
# We'll exit right away if that's the case.
if [ -z "$EXPERIMENT" ]; then
echo "No logs to show"
exit 0
fi
# Download logs from beaker.
beaker experiment results $EXPERIMENT --prefix out.log --output results
# If the experiment failed during startup, there might not be any logs.
if [ -f results/test/out.log ]; then
echo ""
echo ">>> Logs:"
echo ""
cat results/test/out.log
else
echo "No logs to show"
fi
- name: Stop job
if: cancelled()
run: |
if [ ! -z "$EXPERIMENT" ]; then
beaker experiment stop $EXPERIMENT
fi
================================================
FILE: .github/workflows/main.yml
================================================
name: Main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
on:
pull_request:
branches:
- "*"
push:
branches:
- main
tags:
- "v*.*.*"
env:
CACHE_PREFIX: v5 # Change this to invalidate existing cache.
PYTHON_PATH: ./
DEFAULT_PYTHON: 3.9
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_WORKSPACE: ai2/tango-testing
BEAKER_IMAGE: petew/tango-testing
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:
checks:
name: python ${{ matrix.python }} - ${{ matrix.task.name }}
runs-on: [ubuntu-latest]
timeout-minutes: 30
permissions:
contents: "read"
id-token: "write"
strategy:
fail-fast: false
matrix:
python: ["3.9"]
task:
- name: Lint
extras: dev,all
requires_torch: true
run: |
ruff check .
- name: Type check
extras: dev,all
requires_torch: true
run: |
mypy --check-untyped-defs .
- name: Build
extras: dev,all
requires_torch: true
run: |
tango --version
python -m build
- name: Style
extras: dev
requires_torch: false
run: |
isort --check .
black --check .
- name: Docs
extras: dev,all
requires_torch: true
run: |
cd docs && make html SPHINXOPTS="-W --keep-going"
- name: Test
extras: dev
requires_torch: false
run: |
pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/
- name: Datasets integration
extras: dev,datasets
requires_torch: false
run: |
pytest -v --color=yes --doctest-modules tango/integrations/datasets tests/integrations/datasets
- name: PyTorch integration
extras: dev,torch
requires_torch: true
run: |
pytest -v --color=yes --doctest-modules tango/integrations/torch tests/integrations/torch
- name: Transformers integration
extras: dev,flax,transformers
requires_torch: true
run: |
pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers
- name: FairScale integration
extras: dev,fairscale
requires_torch: true
run: |
pytest -v --color=yes --doctest-modules tango/integrations/fairscale tests/integrations/fairscale
- name: W&B integration
extras: dev,torch,flax,wandb
requires_torch: true
run: |
pytest -v --color=yes --doctest-modules tango/integrations/wandb tests/integrations/wandb
- name: Beaker integration
extras: dev,beaker
requires_torch: false
run: |
pytest -v --color=yes --doctest-modules tango/integrations/beaker tests/integrations/beaker
- name: Flax integration
extras: dev,flax,transformers
requires_torch: false
run: |
pytest -v --color=yes --doctest-modules tango/integrations/flax tests/integrations/flax
- name: GS integration
extras: dev,gs
requires_torch: false
run: |
pytest -v --color=yes --doctest-modules tango/integrations/gs tests/integrations/gs
- name: Example - train_lm
extras: dev,all
requires_torch: true
run: |
cd examples/train_lm
pytest -v --color=yes test.py
include:
# Run the core tests on other Python versions as well.
- task:
name: Test
extras: dev
requires_torch: false
run: |
pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/
python: "3.8"
- task:
name: Test
extras: dev
requires_torch: false
run: |
pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/
python: "3.10"
steps:
- uses: "actions/checkout@v3"
- name: Checkout
if: github.event_name != 'pull_request'
uses: actions/checkout@v3
# For pull requests we need to checkout the HEAD commit instead of the merge
# commit since some tests depend on having an existing commit.
- name: Checkout (pull request)
if: github.event_name == 'pull_request'
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
- name: Install prerequisites
run: |
pip install --upgrade pip setuptools wheel virtualenv
- name: Set build variables
shell: bash
run: |
set -e
# Get the exact Python version to use in the cache key.
echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV
echo "RUNNER_ARCH=$(uname -m)" >> $GITHUB_ENV
# Use week number in cache key so we can refresh the cache weekly.
echo "WEEK_NUMBER=$(date +%V)" >> $GITHUB_ENV
echo "EXTRAS_HASH=$(python scripts/hash_extras.py ${{ matrix.task.extras }})" >> $GITHUB_ENV
- uses: actions/cache@v3
id: virtualenv-cache
with:
path: .venv
key: ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-${{ env.EXTRAS_HASH }}-${{ hashFiles('pyproject.toml') }}
- name: Setup virtual environment (no cache hit)
if: steps.virtualenv-cache.outputs.cache-hit != 'true'
run: |
test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
# Reference: https://github.com/marketplace/actions/authenticate-to-google-cloud#setup
- name: Authenticate to Google Cloud
if: matrix.task.name == 'GS integration'
uses: "google-github-actions/auth@v1"
with:
workload_identity_provider: "projects/10554368204/locations/global/workloadIdentityPools/tango-ci-pool/providers/tango-ci-provider"
service_account: "tango-service@ai2-allennlp.iam.gserviceaccount.com"
- name: Pre-install torch
if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'torch') || contains(matrix.task.extras, 'all') || matrix.task.requires_torch)
run: |
. .venv/bin/activate
pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
- name: Pre-install flax
if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'flax') || contains(matrix.task.extras, 'all'))
run: |
. .venv/bin/activate
pip install flax jax jaxlib "tensorflow-cpu>=2.9.1" optax
- name: Install editable (no cache hit)
if: steps.virtualenv-cache.outputs.cache-hit != 'true'
run: |
. .venv/bin/activate
pip install -e .[${{ matrix.task.extras }}]
- name: Install editable (cache hit)
if: steps.virtualenv-cache.outputs.cache-hit == 'true'
run: |
. .venv/bin/activate
pip install --no-deps -e .[${{ matrix.task.extras }}]
- name: Show environment info
run: |
. .venv/bin/activate
echo "========= Python location ==========="
which python
echo "========= Python version ============"
python --version
echo "========= Python packages ==========="
pip freeze
echo "========= Tango installation ========"
tango info
- name: ${{ matrix.task.name }}
run: |
. .venv/bin/activate
${{ matrix.task.run }}
- name: Upload package distribution files
if: matrix.task.name == 'Build' && matrix.python == env.DEFAULT_PYTHON
uses: actions/upload-artifact@v3
with:
name: package
path: dist
- name: Upload docs build
if: matrix.task.name == 'Docs' && matrix.python == env.DEFAULT_PYTHON
uses: actions/upload-artifact@v3
with:
name: docs
path: docs/build
- name: Clean up
if: always()
run: |
. .venv/bin/activate
pip uninstall -y ai2-tango
gpu_tests:
name: GPU Tests
runs-on: ubuntu-latest
steps:
- name: Determine current commit SHA (pull request)
if: github.event_name == 'pull_request'
run: |
echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV
- name: Determine current commit SHA (push)
if: github.event_name != 'pull_request'
run: |
echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV
- name: GPU Tests
uses: allenai/beaker-run-action@v1.2
with:
spec: |
version: v2
description: GPU Tests
budget: ai2/oe-training
tasks:
- name: tests
image:
beaker: ${{ env.BEAKER_IMAGE }}
context:
preemptible: true
resources:
gpuCount: 2
envVars:
- name: COMMIT_SHA
value: ${{ env.COMMIT_SHA }}
command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"]
result:
path: /unused
token: ${{ secrets.BEAKER_TOKEN }}
workspace: ${{ env.BEAKER_WORKSPACE }}
clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale
release:
name: Release
runs-on: ubuntu-latest
needs: [gpu_tests, checks]
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ env.DEFAULT_PYTHON }}
- name: Install requirements
run: |
pip install -e .[dev]
- name: Prepare environment
run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Download package distribution files
uses: actions/download-artifact@v3
with:
name: package
path: dist
- name: Generate release notes
run: |
python scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md
- name: Publish package to PyPI
run: |
twine upload -u __token__ -p ${{ secrets.PYPI_PASSWORD }} dist/*
- name: Publish GitHub release
uses: softprops/action-gh-release@v1
with:
body_path: ${{ github.workspace }}-RELEASE_NOTES.md
prerelease: ${{ contains(env.TAG, 'rc') }}
files: |
dist/*
================================================
FILE: .github/workflows/update_dependency_pr.yml
================================================
name: Update dependency PR
on:
pull_request:
types:
- opened
paths:
- "pyproject.toml"
permissions:
pull-requests: write
jobs:
torch:
name: torch
runs-on: ubuntu-latest
if: startsWith(github.head_ref, 'dependabot/pip/torch-')
steps:
- uses: actions/github-script@v6
with:
script: |
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: 'Hello! This is a [PyTorch](https://pytorch.org/) upgrade, which means you will also need to update:\n- [ ] The base image in `Dockerfile`\n- [ ] The base image in `Dockerfile.test`\n- [ ] The torch version hard-coded in `.github/workflows/main.yml`'
})
================================================
FILE: .gitignore
================================================
# build artifacts
.eggs/
.mypy_cache
ai2_tango.egg-info/
build/
dist/
pip-wheel-metadata/
runs/
workspace/
# dev tools
.envrc
.python-version
.idea
.venv/
.vscode/
/*.iml
# jupyter notebooks
.ipynb_checkpoints
# miscellaneous
.cache/
doc/_build/
*.swp
.DS_Store
# python
*.pyc
*.pyo
__pycache__
# testing and continuous integration
.coverage
.pytest_cache/
.benchmarks
# documentation build artifacts
docs/build
site/
# internal experiment configs
*-internal.jsonnet
*-internal.json
*-internal.yaml
*-internal.yml
================================================
FILE: .readthedocs.yaml
================================================
version: 2
sphinx:
configuration: docs/source/conf.py
fail_on_warning: true
build:
os: ubuntu-22.04
tools:
python: "3.10"
python:
install:
- method: pip
path: .
extra_requirements:
- dev
- all
================================================
FILE: CHANGELOG.md
================================================
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Fixed
- Fixed a bunch of dependencies
- Upgraded to new version of wandb
## [v1.3.2](https://github.com/allenai/tango/releases/tag/v1.3.2) - 2023-10-27
### Fixed
- Fix issues with gcloud auth in beaker executor.
## [v1.3.1](https://github.com/allenai/tango/releases/tag/v1.3.1) - 2023-10-25
### Fixed
- Minor bugs in the `GSWorkspace()`.
### Changed
- Added CLI-style execution functions for experiments defined in Python.
- Added `display()` to `ExecutorOutput` for producing a table that summarizes the run.
## [v1.3.0](https://github.com/allenai/tango/releases/tag/v1.3.0) - 2023-10-13
### Added
- Added the `Workspace.remove_step()` method to safely remove steps.
- The `GSWorkspace()` can now be initialized with google cloud bucket subfolders.
### Changed
- The `BeakerExecutor` now uses the HEAD commit at the time the executor is instantiated to execute a step instead of the HEAD commit at the time the step is run.
### Fixed
- Removed unnecessary code coverage dev requirements.
- Fixed issue where new version of torch caused no LR schedulers to be registered.
- Updated pinned versions of jax, jaxlib, and flax.
## [v1.2.1](https://github.com/allenai/tango/releases/tag/v1.2.1) - 2023-04-06
### Added
- Added the following workspace methods to support the Tango viz UI: `Workspace.search_registered_runs()`, `Workspace.search_step_info()`, `Workspace.num_registered_runs()`, and `Workspace.num_steps()`.
### Fixed
- Fixes a bug where `FromParams` would fail to parse when an object takes a `Step` argument directly.
- Changed a name so we don't override the built-in name `set`.
- Fixed a bug that would cause O(n^2) memory consumption in dense step graphs.
## [v1.2.0](https://github.com/allenai/tango/releases/tag/v1.2.0) - 2023-02-10
### Added
- You can now add arguments to steps without invalidating the cache. See `Step.SKIP_DEFAULT_ARGUMENTS`.
- Fixed integration status messages in `tango info` command.
- Added abstractions for `RemoteClient`, `RemoteStepCache`, and `RemoteWorkspace`.
- Added a GS integration that comes with `GSWorkspace`, a remote `Workspace` implementation that uses google cloud storage.
- You can now bind functional steps to the underlying `Step` instance with `@step(bind=True)`, meaning the first argument to the function will be a `Step`.
- Added `ShellStep` for running arbitrary shell commands.
- Added `@make_registrable` decorator to make arbitrary functions registrable, to make it easier to refer to them in tango configurations.
### Fixed
- Jsonnet parsing is now much faster and works on Windows.
- Warnings about locks are now reliably printed every 30 seconds
- We now make sure Beaker jobs have the latest version of beaker-py, so that we're compatible with the latest API changes.
- Stopping early now works when the metric doesn't change at all.
- Fixed bug with `FromParams` which didn't handle variable length tuples correctly.
### Changed
- The default log level for Tango is now `warning`.
- You can specify multiple steps with `-s` from the `tango run` command.
## [v1.1.0](https://github.com/allenai/tango/releases/tag/v1.1.0) - 2022-12-01
### Added
- Added `gpu_type` field to `StepResources`. The `BeakerExecutor` can use this to determine which clusters to submit a step to.
- Added `machine` field to `StepResources`. You can set this to "local" when using the `BeakerExecutor` to force it to run the step locally.
- Added `--ext-var` argument to `tango run` for setting JSONNET external variables
when loading the experiment config.
- Added `@step()` decorator to create `Step` classes from functions.
- Added the `transformers::with_soft_prompt` integration, to make soft-prompted prefix transformers easy.
### Removed
- Removed PyTorch Lightning integration.
- Removed `tango server` command and `--serve/--no-serve` option for `tango run`.
- Removed `source_release.py`, which was checked in by accident.
### Fixed
- Fixed issue where Executor `parallelism` option in a Tango settings file would be ignored.
- Fixed a bug where the unique ID of a step that depends on a key-value of the result of another step could change if the name of the other step changes.
- Fixed a bug where importing certain libraries (like torchmetrics) would mess with our exception handling because they set `sys.excepthook` for some reason. Now we always reset `sys.excepthook` after importing.
- The type hints for the flax trainer suggested that the training split is optional when in fact it's mandatory.
- Made `BeakerWorkspace` / `BeakerStepLock` more robust when a job is preempted.
- Minor performance improvements for the Beaker executor and workspace.
- Fixed bug with `step_extra_dependencies` where uncacheable dependencies wouldn't be run.
## [v1.0.2](https://github.com/allenai/tango/releases/tag/v1.0.2) - 2022-11-14
### Changed
- `BeakerScheduler` can now return a list of clusters.
## [v1.0.1](https://github.com/allenai/tango/releases/tag/v1.0.1) - 2022-10-20
### Fixed
- `LightningTrainStep` now can take a `Lazy` model object which results in a guaranteed deterministic hash.
- Fixed issue where remote `Workspace` implementations like `WandbWorkspace` and `BeakerWorkspace` would use the same local cache regardless of the W&B / Beaker workspace
being used.
- Fixed bug with `TorchEvalStep` when constructing callbacks.
- Fixed some import error issues caused when an integration is not installed.
- Fix incorrect reporting of final results in `MulticoreExecutor`.
### Changed
- Wandb step cache retries api call in case of timeout
- `beaker-py >= 1.11` required.
## [v1.0.0](https://github.com/allenai/tango/releases/tag/v1.0.0) - 2022-10-05
### Added
- Added `step_extra_dependencies` input field to `Step` class that can be used to force a dependency on another step even if the current step doesn't directly depend on the output of the other step. See [#418](https://github.com/allenai/tango/issues/418) for more context.
### Changed
- `beaker-py >= 1.10` required.
### Fixed
- Long log lines will be soft-wrapped to ensure that links are clickable.
- Fixed a bug where some workspaces could be left in a bad state if a step's `Format` failed to serialize the step's result in `Workspace.step_finished()`.
- Sometimes functions and methods end up as arguments to steps, which means we have to hash them. Instead of taking
a hash of the function, we now take a hash of the function's module and name.
- Fixed a bug with the Beaker executor where it would hang at the end of a run if a step failed that is a dependency of another step.
- Fixed tests to work with new version of transformers.
- Fixed `Executor.execute_sub_graph_for_step()` to be able to run the step's dependencies in parallel.
## [v0.14.0](https://github.com/allenai/tango/releases/tag/v0.14.0) - 2022-09-20
### Added
- Adds a function to modify a Hugging Face transformer with IA3 adaptors
- Added a `BeakerScheduler` registrable class, specified as the argument `scheduler` to `BeakerExecutor`, which controls the resources assigned to steps ran on Beaker.
Users can implement their own `BeakerScheduler` subclasses to customize the resource assignment behavior.
### Changed
- In the `tango run` command, `--no-server` is now the default. Use `--server` to start the server.
### Fixed
- Made `BeakerExecutor` more robust to connection, timeout, SSL, and other recoverable HTTP errors.
- Made the `BeakerStepLock` more robust, and as a result `BeakerWorkspace` is more
robust and should require less manual intervention for locks in a bad state.
- Fixed a bug with the internal scheduling logic of the `BeakerExecutor` which
could delay submitting some steps in parallel.
- Fixed a bug where creating a `StepInfo` object from params might result in unnecessary imports.
- Fixed a bug where canceling the Beaker executor might not work properly.
- Fixed a bug where the trainer trains too much when `train_epochs` is set and you're using gradient accumulation.
- Fixed a bug where included modules might not be found when using multiprocessing when they're not on `sys.path` / `PYTHONPATH`.
- Fixed how the results of uncacheable steps are displayed by `tango run`.
- Beaker executor won't run duplicate cacheable steps at the same time.
## [v0.13.0](https://github.com/allenai/tango/releases/tag/v0.13.0) - 2022-09-07
### Added
- You can now reference into a particular index of the result of another step in a config. For example: `{type: "ref", ref: "some_previous_step", key: 0}`.
The key field can be an integer if the result of the referenced step is a list or tuple, or a string if the result of the referenced step is a dictionary.
- Added `priority` parameter to Beaker executor for setting the default task priority for Beaker jobs.
- Added `Workspace.step_result()` method for getting a step's result from the latest
run.
- `tango run` will now display a URL to the logs for failed steps when you use the `BeakerExecutor`.
### Changed
- The `TorchTrainStep` now enables monitoring arbitrary model outputs during training. `TorchTrainEngine.forward_train` now returns a tuple `loss, model_outputs` for each micro batch and the list of model outputs for all micro batches in a batch is passed to the `TrainCallback.log_batch` and `TrainCallback.post_batch`.
- Tango will now automatically search Python modules in the current working directory
for registered classes so that you don't always need to use the `--include-package` setting.
- The minimum supported Python version is now 3.8.
- Added support for PyTorch Lightning 1.7.x
- The Beaker Executor will no-longer live-stream logs from Beaker jobs, but logs will be viewable on Beaker and more readable.
- Only the Beaker executor requires a clean working directory
### Fixed
- Fixed a bug that did not allow a wandb artifact's type to be set from a step's metadata dictionary.
- Fixed a bug with how the Beaker executor streams log lines from Beaker which sometimes resulted in messages missing some starting characters, and tqdm lines being duplicated.
- Fixed a bug in the Beaker workspace where the lock dataset wouldn't be removed if the step
was found to be in an invalid state.
- Improved cluster choice logic in `BeakerExecutor` to ensure greater diversity of clusters when submitting many steps at once.
- Fixed bug where sub-processes of the multicore executor would use the wrong executor if `executor` was defined in a `tango.yml` file.
- Deterministic hashes for numpy and torch tensors were not deterministic. Now they are.
## [v0.12.0](https://github.com/allenai/tango/releases/tag/v0.12.0) - 2022-08-23
### Added
- **Step resources:**
- Added a `step_resources` parameter to the `Step` class which should be used to describe the computational resources required to run a step.
`Executor` implementations can use this information. For example, if your step needs 2 GPUs, you should set
`step_resources=StepResources(gpu_count=2)` (`"step_resources": {"gpu_count": 2}` in the configuration language).
- Added a `Step.resources()` property method. By default this returns the value specified by the `step_resources` parameter.
If your step implementation always requires the same resources, you can just override this method so you don't have to provide
the `step_resources` parameter.
- **Step execution:**
- Added an `executor` field to the `tango.yml` settings. You can use this to define the executor you want to use by default.
- Added a Beaker `Executor` to the Beaker integration, registered as an `Executor` with the name "beaker".
To use this executor, add these lines to your `tango.yml` file:
```yaml
executor:
type: beaker
beaker_workspace: ai2/my-workspace
clusters:
- ai2/general-cirrascale
```
See the docs for the `BeakerExecutor` for more information on the input parameters.
- **Step class:**
- Added a metadata field to the step class API. This can be set through the class
variable `METADATA` or through the constructor argument `step_metadata`.
- **Weights & Biases integration:**
- You can now change the artifact kind for step result artifacts by adding a field
called "artifact_kind" to a step's metadata.
For models, setting "artifact_kind" to "model" will add the corresponding artifact to W&B's new model zoo.
### Changed
- **CLI:**
- The `tango run` command will throw an error if you have uncommitted changes in your repository, unless
you use the `--allow-dirty` flag.
- The `tango run` command will use the lightweight base executor (single process) by default.
To use the multi-process executor, set `-j/--parallelism` to 1 or higher or -1 to use all available CPU cores.
### Fixed
- Fixed bug where `StepInfo` environment and platform metadata could be out-of-date if a step is run again due to failure.
- Fixed a bug where an unfortunate combination of early stopping and decreasing model performance could result in a crash in the torch trainer.
## [v0.11.0](https://github.com/allenai/tango/releases/tag/v0.11.0) - 2022-08-04
### Added
- Added a [Flax](https://flax.readthedocs.io/en/latest/) integration along with an example config.
## [v0.10.1](https://github.com/allenai/tango/releases/tag/v0.10.1) - 2022-07-26
### Fixed
- Fixed issue where the StepInfo config argument could be parsed into a Step.
- Restored capability to run tests out-of-tree.
## [v0.10.0](https://github.com/allenai/tango/releases/tag/v0.10.0) - 2022-07-07
### Changed
- Renamed `workspace` parameter of `BeakerWorkspace` class to `beaker_workspace`.
- `Executor` class is now a `Registrable` base class. `MulticoreExecutor` is registered as "multicore".
### Removed
- Removed `StepExecutionMetadata`. Its fields have been absorbed into `StepInfo`.
### Fixed
- Improved `Step.ensure_result()` such that the step's result doesn't have to be read from the cache.
- Fixed an issue with the output from `MulticoreExecutor` such that it's now consistent with the default `Executor` for steps that were found in the cache.
- One of our error messages referred to a configuration file that no longer exists.
- Improved performance of `BeakerWorkspace`.
### Added
- Added the ability to train straight `Model` instead of just `Lazy[Model]`
## [v0.9.1](https://github.com/allenai/tango/releases/tag/v0.9.1) - 2022-06-24
### Fixed
- Fixed non-deterministic behavior in `TorchTrainStep`.
- Fixed bug in `BeakerWorkspace` where `.step_info(step)` would raise a `KeyError` if the step hasn't been registered as part of a run yet.
- Fixed a bug in `BeakerWorkspace` where it would send too many requests to the beaker service.
- Fixed a bug where `WandbWorkspace.step_finished()` or `.step_failed()` would crash if called
from a different process than `.step_starting()`.
- Fixed a bug in `WandbWorkspace.step_finished()` which led to a `RuntimeError` sometimes while
caching the result of a step.
## [v0.9.0](https://github.com/allenai/tango/releases/tag/v0.9.0) - 2022-06-01
### Added
- Added a [Beaker](https://beaker.org) integration that comes with `BeakerWorkspace`, a remote `Workspace` implementation that uses Beaker Datasets under the hood.
- Added a `datasets::dataset_remix` step that provides the split remixing functionality of `tango.steps.dataset_remix.DatasetRemixStep` now for Huggingface `DatasetDict`.
- Added a config and code example of Registrable to the First Step docs with edits for clarity.
### Changed
- If you try to import something from a tango integration that is not fully installed due to missing dependencies, an `IntegrationMissingError` will be raised
instead of `ModuleNotFound`.
- You can now set `-j 0` in `tango run` to disable multicore execution altogether.
### Fixed
- Improved how steps and workspaces handle race conditions when different processes are competing to execute the same step. This would result in a `RuntimeError` before with most workspaces, but now it's handled gracefully.
- Fixed bug which caused GradScaler state to not be saved and loaded with checkpoints.
## [v0.8.0](https://github.com/allenai/tango/releases/tag/v0.8.0) - 2022-05-19
### Added
- Added a Weights & Baises remote `Workspace` implementation: `WandbWorkspace`, registered as "wandb".
This can be instantiated from a workspace URL in the form "wandb://entity/project".
- Added a method `Workspace.step_result_for_run` which gives the result of a step given the run name and step name within that run.
- Added property `Workspace.url`, which returns a URL for the workspace that can be used to instantiate the exact same workspace using `Workspace.from_url()`. Subclasses must implement this.
### Changed
- `StepInfo` start and end times will be always be in UTC now.
- `WandbTrainCallback` now logs system metrics from each worker process in distributed training.
- `StepCache.__contains__()` and `StepCache.__getitem__()` now accept either a `Step` or `StepInfo` as an argument (`Union[Step, StepInfo]`).
- Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`.
- `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures.
### Fixed
- Fixed bug with `LocalWorkspace.from_parsed_url()` ([#278](https://github.com/allenai/tango/issues/278)).
- Deprecation warnings will now be logged from `tango` CLI.
- Fixed the text format in the case of serializing an iterator of string.
- Added missing default value of `None` to `TangoGlobalSettings.find_or_default()`.
- Mypy has become incompatible with transformers and datasets, so we have to disable the checks in some places.
- The `VERSION` member of step arguments that were wrapped in `Lazy` were not respected. Now they are.
## [v0.7.0](https://github.com/allenai/tango/releases/tag/v0.7.0) - 2022-04-19
### Added
- Added the "-n/--name" option to `tango run`. This option allows the user to give the run an arbitrary name.
- Added a convenience property `.workspace` to `Step` class that can be called from a step's `.run()` method to get the current `Workspace` being used.
- Gave `FromParams` objects (which includes all `Registrable` objects) the ability to version themselves.
- Added CLI option to run a single step in a config using `--step-name` or `-s`.
- Added a `MultiCoreExecutor` that executes steps in parallel.
- Added an `ExecutorOutput` dataclass that is returned by `Executor.execute_step_graph()`.
- `StepGraph` now prints itself in a readable way.
- Tango now automatically detects when it's running under a debugger, and disables multicore support accordingly. Many debuggers can't properly follow sub-processes, so this is a convenience for people who love debuggers.
- Added more models to the stuff we can import from the transformers library.
- Added new example for finetuning text-to-text models.
### Changed
- Renamed `click_logger` to `cli_logger`, and we now use [rich](https://github.com/Textualize/rich)'s logging `Handler` as the default handler, which means prettier output, better tracebacks, and you can use rich's markup syntax with the `cli_logger` to easily add style to text.
- Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`.
- `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures.
- Upgraded PyTorch version in `tango` Docker image to latest `v1.11.0+cu113`.
- `RunGeneration` now allows model object as input.
### Fixed
- Fixed bug that mistakenly disallowed fully-qualified names containing `"_"` (underscores) in the config.
- Fixed bug where `TorchTrainStep` working directory would be left in an unrecoverable state if training failed after saving the final model weights.
- Fixed bug in `FromParams` where `**kwargs` might be passed down to the constructors of arguments.
- Fixed bug in the way dependencies are tracked between steps.
- Fixed bug that caused `MulticoreExecutor` to hang in case of a failing step that was required recursively (not directly) downstream.
- Fixed bug in the way dependencies are tracked between steps
- Compatibility with PyTorch Lightning 1.6
## [v0.6.0](https://github.com/allenai/tango/releases/tag/v0.6.0) - 2022-02-25
### Added
- New example that finetunes a pre-trained ResNet model on the Cats & Dogs dataset.
- Added a '@requires_gpus' decorator for marking tests as needing GPUs. Tests marked with this will be run in the "GPU Tests" workflow
on dual k80 GPUs via Beaker.
- Added the "-w/--workspace" option to `tango run` and `tango server` commands. This option takes a path or URL, and instantiates the workspace from the URL using the newly added `Workspace.from_url()` method.
- Added the "workspace" field to `TangoGlobalSettings`.
- Added the "environment" field to `TangoGlobalSettings` for setting environment variables each
time `tango` is run.
- Added a utility function to get a `StepGraph` directly from a file.
- Added `tango.settings` module and `tango settings` group of commands.
- A format for storing sequences as `SqliteSparseSequence`
- A way to massage kwargs before they determine the unique ID of a `Step`
### Changed
- `local_workspace.ExecutorMetadata` renamed to `StepExecutionMetadata` and now saved as `execution-metadata.json`.
- `tango run` without the option "-w/--workspace" or "-d/--workspace-dir" will now use a `MemoryWorkspace` instead of a `LocalWorkspace` in a temp directory, unless you've specified
a default workspace in a `TangoGlobalSettings` file.
- Moved `tango.workspace.MemoryWorkspace` and `tango.local_workspace.LocalWorkspace` to `tango.workspaces.*`.
- Moved `tango.step_cache.MemoryStepCache` and `tango.step_cache.LocalStepCache` to `tango.step_caches.*`.
- Deprecated the `-d/--workspace-dir` command-line option. Please use `-w/--workspace` instead.
### Fixed
- Fixed a small bug where `LocalWorkspace` would fail to capture the conda environment in our Docker image.
- Fixed activation of `FILE_FRIENDLY_LOGGING` when set from the corresponding environment variable.
- Fixed setting log level via the environment variable `TANGO_LOG_LEVEL`.
- Use relative paths within the `work_dir` for symbolic links to the latest and the best checkpoints in `TorchTrainStep`.
- Fixed some scenarios where Tango can hang after finishing all steps.
- `distributed_port` and `log_every` parameters won't factor into `TorchTrainStep`'s unique ID.
- `MappedSequence` now works with slicing.
- `MappedSequence` now works with Huggingface `Dataset`.
- Uncacheable steps are now visible in Tango UI.
- Fixed bug in `Registrable.list_available()` where an error might be raised if the default implementation hadn't been explicitly imported.
- Fixed issue where having a default argument to the `run()` method wasn't getting applied to the step's unique ID.
## [v0.5.0](https://github.com/allenai/tango/releases/tag/v0.5.0) - 2022-02-09
### Added
- Added `TrainingEngine` abstraction to torch integration.
- Added [FairScale](https://fairscale.readthedocs.io/en/latest/) with a `FairScaleTrainingEngine`
that leverages FairScale's `FullyShardedDataParallel`. This is meant to be used within the `TorchTrainStep`.
- All PyTorch components (such as learning rate schedulers, optimizers, data collators, etc) from the
transformers library are now registered under the corresponding class in the torch integration.
For example, transformers `Adafactor` optimizer is registered as an `Optimizer` under the name
"transformers::Adafactor". More details can be found in the documentation for the transformers integration.
### Changed
- Various changes to the parameters of the `TorchTrainStep` due to the introduction of the `TrainingEngine` class.
- Params logged as `DEBUG` level instead of `INFO` to reduce noise in logs.
- The waiting message for `FileLock` is now clear about which file it's waiting for.
- Added an easier way to get the default Tango global config
- Most methods to `TorchTrainCallback` also take an `epoch` parameter now.
- `WandbTrainCallback` now logs peak GPU memory occupied by PyTorch tensors per worker. This is useful because W&B's system metrics only display the total GPU memory reserved by PyTorch, which is always higher than the actual amount of GPU memory occupied by tensors. So these new metrics give a more accurate view into how much memory your training job is actually using.
- Plain old Python functions can now be used in `Lazy` objects.
- `LocalWorkspace` now creates a symlink to the outputs of the latest run.
- Tango is now better at guessing when a step has died and should be re-run.
- Tango is now more lenient about registering the same class under the same name twice.
- When you use `dict` instead of `Dict` in your type annotations, you now get a legible error message. Same for `List`, `Tuple`, and `Set`.
### Fixed
- Fixed a bug in `Registrable` and `FromParams` where registered function constructors would not properly construct
arguments that were classes.
- Fixed a bug in `FromParams` that would cause a crash when an argument to the constructor had the name `params`.
- Made `FromParams` more efficient by only trying to parse the params as a `Step` when it looks like it actually could be a step.
- Fixed bug where `Executor` would crash if `git` command could not be found.
- Fixed bug where validation settings were not interpreted the right way by the torch trainer.
- When you register the same name twice using `Registrable`, you get an error message. That error message now contains the correct class name.
## [v0.4.0](https://github.com/allenai/tango/releases/tag/v0.4.0) - 2022-01-27
### Changed
- Default log level is `WARNING` instead of `ERROR`.
- The web UI now renders the step graph left-to-right.
- The web UI now shows runs by date, with the most recent run at the top.
- The web UI now shows steps in a color-coded way.
- The `tango run` command now prints user-friendly paths if possible.
- The `--include-package` flag now also accepts paths instead of module names.
- `tango.common.sqlite_sparse_sequence.SqliteSparseSequence` now lives at `tango.common.sequences.SqliteSparseSequence`.
### Fixed
- Ensure tqdm log lines always make it into the log file `out.log` even when log level is `WARNING` or `ERROR`.
- Numerous parts of Tango now have documentation when they didn't before.
## [v0.4.0rc5](https://github.com/allenai/tango/releases/tag/v0.4.0rc5) - 2022-01-19
### Added
- Added `TorchEvalStep` to torch integration, registered as "torch::eval".
### Changed
- Renamed `aggregate_val_metric` to `auto_aggregate_val_metric` in `TorchTrainStep`.
- `devices` parameter to `TorchTrainStep` replaced with `device_count: int`.
- Run name printed at the end of a run so it's easier to find.
- Type information added to package data. See [PEP 561](https://www.python.org/dev/peps/pep-0561) for more information.
- A new integration, `transformers`, with two new steps for running seq2seq models.
- Added `logging_tqdm`, if you don't want a progress bar, but you still want to see progress in the logs.
- Added `threaded_generator()`, for wrapping generators so that they run in a separate thread from the generator's consumer.
- Added a new example for evaluating the T0 model on XSum, a summarization task.
- Added `MappedSequence` for functionally wrapping sequences.
- Added `TextFormat`, in case you want to store the output of your steps in raw text instead of JSON.
- Steps can now list arguments in `SKIP_ID_ARGUMENTS` to indicate that the argument should not affect a step's
unique id. This is useful for arguments that affect the execution of a step, but not the output.
- `Step` now implements `__str__`, so steps look pretty in the debugger.
- Added `DatasetCombineStep`, a step that combines multiple datasets into one.
- Added `common.logging.initialize_worker_logging()` function for configuring logging from worker processes/threads.
- Logs from `tango run ...` will be written to a file called `out.log` in the run directory.
### Fixed
- Fixed torch `StopEarlyCallback` state not being recovered properly on restarts.
- Fixed file friendly logging by removing special styling characters.
- Ensured exceptions captured in logs.
- `LocalWorkspace` now works properly with uncacheable steps.
- When a Tango run got killed hard, with `kill -9`, or because the machine lost power, `LocalWorkspace` would
sometimes keep a step marked as "running", preventing further executions. This still happens sometimes, but it
is now much less likely (and Tango gives you instructions for how to fix it).
- To make all this happen, `LocalWorkspace` now saves step info in a Sqlite database. Unfortunately that means that
the workspace format changes and existing workspace directories won't work properly with it.
- Fixed premature cleanup of temporary directories when using `MemoryWorkspace`
## [v0.4.0rc4](https://github.com/allenai/tango/releases/tag/v0.4.0rc4) - 2021-12-20
### Fixed
- Fixed a bug where `StepInfo` fails to deserialize when `error` is an exception that can't be pickled.
## [v0.4.0rc3](https://github.com/allenai/tango/releases/tag/v0.4.0rc3) - 2021-12-15
### Added
- Added `DatasetsFormat` format and `LoadStreamingDataset` step to `datasets` integration.
- `SqliteDictFormat` for datasets.
- Added `pre_epoch()` and `post_epoch()` callback methods to PyTorch `TrainCallback`.
### Changed
- `LoadDataset` step from `datasets` integration is now cacheable, using the `DatasetsFormat` format by default.
But this only works with non-streaming datasets. For streaming datasets, you should use the `LoadStreamingDataset` step instead.
### Fixed
- Fixed bug where `KeyboardInterrupt` exceptions were not handled properly by steps and workspaces.
- `WandbTrainCallback` now will use part of the step's unique ID as the name for the W&B run by default, to make
it easier to identify which tango step corresponds to each run in W&B.
- `WandbTrainCallback` will save the entire `TrainConfig` object to the W&B config.
## [v0.4.0rc2](https://github.com/allenai/tango/releases/tag/v0.4.0rc2) - 2021-12-13
### Added
- Sample experiment configurations that prove Euler's identity
### Changed
- Loosened `Click` dependency to include v7.0.
- Loosened `datasets` dependency.
- Tightened `petname` dependency to exclude next major release for safety.
### Fixed
- `Workspace`, `MemoryWorkspace`, and `LocalWorkspace` can now be imported directly from the `tango`
base module.
- Uncacheable leaf steps would never get executed. This is now fixed.
- We were treating failed steps as if they were completed by accident.
- The visualization had a problem with showing steps that never executed because a dependency failed.
- Fixed a bug where `Lazy` inputs to a `Step` would fail to resolve arguments that come from the result
of another step.
- Fixed a bug in `TorchTrainStep` where some arguments for distributed training (`devices`, `distributed_port`) weren't being set properly.
## [v0.4.0rc1](https://github.com/allenai/tango/releases/tag/v0.4.0rc1) - 2021-11-30
### Added
- Introduced the concept of the `Workspace`, with `LocalWorkspace` and `MemoryWorkspace` as initial implementations.
- Added a stub of a webserver that will be able to visualize runs as they happen.
- Added separate classes for `LightningTrainingTypePlugin`, `LightningPrecisionPlugin`, `LightningClusterEnvironmentPlugin`, `LightningCheckpointPlugin` for compatibility with `pytorch-lightning>=1.5.0`.
- Added a visualization of workspaces that can show step graphs while they're executing.
### Removed
- Removed old `LightningPlugin` class
- Removed requirement of the `overrides` package
### Changed
- Made it possible to construct a step graph out of `Step` objects, instead of constructing it out of `StepStub` objects.
- Removed dataset fingerprinting code, since we can now use `Step` to make sure things are cached.
- Made steps deterministic by default.
- Brought back `MemoryStepCache`, so we can run steps without configuring anything.
- W&B `torch::TrainCallback` logs with `step=step+1` now so that training curves in the W&B dashboard
match up with checkpoints saved locally and are easier to read (e.g. step 10000 instead of 9999).
- `filelock >= 3.4` required, parameter `poll_intervall` to `tango.common.file_lock.FileLock.acquire` renamed
to `poll_interval`.
### Fixed
- Fixed bug in `FromParams` where a parameter to a `FromParams` class may not be instantiated correctly
if it's a class with a generic type parameter.
## [v0.3.6](https://github.com/allenai/tango/releases/tag/v0.3.6) - 2021-11-12
### Added
- Added a `.log_batch()` method on `torch::TrainCallback` which is given the average loss across
distributed workers, but only called every `log_every` steps.
### Removed
- Removed `.pre_log_batch()` method on `torch::TrainCallback`.
### Fixed
- Fixed typo in parameter name `remove_stale_checkpoints` in `TorchTrainStep` (previously was `remove_state_checkpoints`).
- Fixed bug in `FromParams` that would cause failures when `from __future__ import annotations`
was used with Python older than 3.10. See [PEP 563](https://www.python.org/dev/peps/pep-0563/)
for details.
## [v0.3.5](https://github.com/allenai/tango/releases/tag/v0.3.5) - 2021-11-05
### Fixed
- Fixed a bug in `FromParams` where the "type" parameter was ignored in some cases
where the `Registrable` base class did not directly inherit from `Registrable`.
## [v0.3.4](https://github.com/allenai/tango/releases/tag/v0.3.4) - 2021-11-04
### Added
- Added `StopEarlyCallback`, a `torch::TrainCallback` for early stopping.
- Added parameter `remove_stale_checkpoints` to `TorchTrainStep`.
### Changed
- Minor changes to `torch::TrainCallback` interface.
- Weights & Biases `torch::TrainCallback` now logs best validation metric score.
## [v0.3.3](https://github.com/allenai/tango/releases/tag/v0.3.3) - 2021-11-04
### Added
- Added support for PEP 604 in `FromParams`, i.e. writing union types as "X | Y" instead of "Union[X, Y]".
- [internals] Added a spot for miscellaneous end-to-end integration tests (not to be confused with "tests of integrations") in `tests/end_to_end/`.
- [internals] Core tests now run on all officially supported Python versions.
### Fixed
- Fixed a bug in `FromParams` where non-`FromParams` class parameters were not instantiated
properly (or at all).
- Fixed a bug in `FromParams` where kwargs were not passed on from a wrapper class to the wrapped class.
- Fixed small bug where some errors from git would be printed when executor metadata is created
outside of a git repository.
## [v0.3.2](https://github.com/allenai/tango/releases/tag/v0.3.2) - 2021-11-01
### Fixed
- Fixed a bug with `FromParams` that caused `.from_params()` to fail when the params contained
an object that was already instantiated.
- tango command no longer installs a SIGTERM handler, which fixes some bugs with integrations that use multiprocessing.
## [v0.3.1](https://github.com/allenai/tango/releases/tag/v0.3.1) - 2021-10-29
### Changed
- Updated the `LightningTrainStep` to optionally take in a `LightningDataModule` as input.
## [v0.3.0](https://github.com/allenai/tango/releases/tag/v0.3.0) - 2021-10-28
### Added
- Added `IterableDatasetDict`, a version of `DatasetDict` for streaming-like datasets.
- Added a [PyTorch Lightning](https://www.pytorchlightning.ai) integration with `LightningTrainStep`.
### Fixed
- Fixed bug with `FromParams` and `Lazy` where extra arguments would sometimes be passed down through
to a `Lazy` class when they shouldn't.
## [v0.2.4](https://github.com/allenai/tango/releases/tag/v0.2.4) - 2021-10-22
### Added
- Added support for [torch 1.10.0](https://github.com/pytorch/pytorch/releases).
### Changed
- `--file-friendly-logging` flag is now an option to the main `tango` command, so needs
to be passed before `run`, e.g. `tango --file-friendly-logging run ...`.
### Fixed
- Fixed bug with `Step.from_params`.
- Ensure logging is initialized in spawned processes during distributed training with `TorchTrainStep`.
## [v0.2.3](https://github.com/allenai/tango/releases/tag/v0.2.3) - 2021-10-21
### Added
- Added support for global settings file, `tango.yml`.
- Added 'include_package' (array of string) param to config spec.
- Added a custom error `StopEarly` that a `TrainCallback` can raise within the `TorchTrainStep`
to stop training early without crashing.
- Added step config, tango command, and tango version to executor metadata.
- Executor now also saves pip dependencies and conda environment files to the run directory
for each step.
### Fixed
- Ensured `**kwargs` arguments are logged in `FromParams`.
## [v0.2.2](https://github.com/allenai/tango/releases/tag/v0.2.2) - 2021-10-19
### Added
- Added new steps to `datasets` integration: `ConcatenateDatasets` ("datasets::concatenate") and `InterleaveDatasets` ("datasets::interleave").
- Added `__contains__` and `__iter__` methods on `DatasetDict` so that it is now a `Mapping` class.
- Added `tango info` command that - among other things - displays which integrations are installed.
## [v0.2.1](https://github.com/allenai/tango/releases/tag/v0.2.1) - 2021-10-18
### Added
- Added `convert_to_tango_dataset_dict()` function in the `datasets` integration.
It's important for step caching purposes to use this to convert a HF `DatasetDict`
to a native Tango `DatasetDict` when that `DatasetDict` is part of the input to another
step. Otherwise the HF `DatasetDict` will have to be pickled to determine its hash.
### Changed
- `Format.checksum()` is now an abstract method. Subclasses should only compute checksum
on the serialized artifact and nothing else in the directory.
- [internals] Changed the relationship between `Executor`, `StepCache`, and `Step.`
`Executor` now owns the `StepCache`, and `Step` never interacts with `StepCache` directly.
## [v0.2.0](https://github.com/allenai/tango/releases/tag/v0.2.0) - 2021-10-15
### Added
- Added a [Weights & Biases](https://wandb.ai) integration with a training callback ("wandb::log")
for `TorchTrainStep` ("torch::train") that logs training and validation metrics to W&B.
### Fixed
- Fixed `Format.checksum()` when there is a symlink to a directory in the cache folder.
## [v0.1.3](https://github.com/allenai/tango/releases/tag/v0.1.3) - 2021-10-15
### Added
- Added the ability to track a metric other than "loss" for validation in `TorchTrainStep` ("torch::train").
### Fixed
- Final model returned from `TorchTrainStep` ("torch::train") will have best weights loaded.
- Checkpoints are saved from `TorchTrainStep` ("torch::train") even when there is no validation loop.
- Fixed `TorchTrainStep` ("torch::train") when `validation_split` is `None`.
- Fixed distributed training with `TorchTrainStep` ("torch::train") on GPU devices.
## [v0.1.2](https://github.com/allenai/tango/releases/tag/v0.1.2) - 2021-10-13
### Added
- Added support for YAML configuration files.
## [v0.1.1](https://github.com/allenai/tango/releases/tag/v0.1.1) - 2021-10-12
### Added
- `TorchTrainStep` now displays a progress bar while saving a checkpoint to file.
- The default executor now saves an "executor-metadata.json" file to the directory for each step.
### Changed
- Renamed `DirectoryStepCache` to `LocalStepCache` (registered as "local").
- `LocalStepCache` saves metadata to `cache-metadata.json` instead of `metadata.json`.
### Fixed
- Fixed bug with `TorchTrainStep` during distributed training.
- `FromParams` will automatically convert strings into `Path` types now when the annotation
is `Path`.
## [v0.1.0](https://github.com/allenai/tango/releases/tag/v0.1.0) - 2021-10-11
### Added
- Added `StepGraph` and `Executor` abstractions.
- Added a basic PyTorch training step registered as `"torch::train"`, along with other registrable
components, such as `Model`, `DataLoader`, `Sampler`, `DataCollator`, `Optimizer`, and `LRScheduler`.
- Added `DatasetRemixStep` in `tango.steps`.
- Added module `tango.common.sequences`.
- Added `DatasetDict` class in `tango.common.dataset_dict`.
- Added [🤗 Datasets](https://github.com/huggingface/datasets) integration.
- Added command-line options to set log level or disable logging completely.
### Changed
- `Step.work_dir`, `Step.unique_id`, `Step.dependencies`, and `Step.recursive_dependencies`
are now properties instead of methods.
- `tango run` command will acquire a lock on the directory to avoid race conditions.
- Integrations can now be installed with `pip install tango[INTEGRATION_NAME]`. For example,
`pip install tango[torch]`.
- Added method `Registrable.search_modules()` for automatically finding and importing the modules
where a given ``name`` might be registered.
- `FromParams.from_params()` and `Registrable.resolve_class_name` will now call `Registrable.search_modules()` to automatically import modules where the type might be defined.
Thus for classes that are defined and registered within any `tango.*` submodules it is not necessary to explicitly import them.
### Fixed
- `Step` implementations can now take arbitrary `**kwargs` in their `run()` methods.
## [v0.0.3](https://github.com/allenai/tango/releases/tag/v0.0.3) - 2021-09-27
### Added
- Added `tango` command.
## [v0.0.2](https://github.com/allenai/tango/releases/tag/v0.0.2) - 2021-09-27
### Added
- Ported over core tango components from AllenNLP.
## [v0.0.1](https://github.com/allenai/tango/releases/tag/v0.0.1) - 2021-09-22
### Added
- Added initial project boilerplate.
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Groeneveld"
given-names: "Dirk"
affiliation: "Allen Institute for Artificial Intelligence"
- family-names: "Bhagia"
given-names: "Akshita"
affiliation: "Allen Institute for Artificial Intelligence"
- family-names: "Walsh"
given-names: "Pete"
affiliation: "Allen Institute for Artificial Intelligence"
title: "AI2 Tango"
abstract: "Organize your experiments into discrete steps that can be cached and reused throughout the lifetime of your research project."
version: "1.3.2"
repository-code: "https://github.com/allenai/tango"
license: "Apache-2.0"
date-released: "2023-10-27"
================================================
FILE: Dockerfile
================================================
# This Dockerfile can be used to build a Docker image suitable for tango projects.
# Base image is overridable at build time, e.g.:
#   docker build --build-arg BASE_IMAGE=... .
ARG BASE_IMAGE=ghcr.io/allenai/pytorch:2.0.0-cuda11.7-python3.10
FROM ${BASE_IMAGE}
# Copy the source tree into a staging directory so the final working
# directory (/workspace) stays free of build artifacts.
WORKDIR /stage
COPY . .
# Install tango with every optional integration (the "[all]" extra).
RUN /opt/conda/bin/pip install --no-cache-dir .[all]
WORKDIR /workspace
# The source tree is no longer needed once the package is installed;
# removing it keeps the image layer small.
RUN rm -rf /stage/
# Run the tango CLI by default, so `docker run IMAGE run ...` works.
ENTRYPOINT ["/opt/conda/bin/tango"]
================================================
FILE: Dockerfile.test
================================================
# This Dockerfile is for building an image suitable for running tango's GPU tests and integration tests.
# There are no instruction lines in this Dockerfile that install tango. Instead, the entrypoint
# script handles installing tango from a particular commit at runtime, based on the environment
# variable "COMMIT_SHA". That way we don't need to rebuild and push the image each time we run
# tests, and we can be sure the dependencies are always up-to-date.
FROM ghcr.io/allenai/pytorch:2.0.0-cuda11.7-python3.10
# The entrypoint script does the per-run install of tango (see note above).
COPY scripts/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Tests are expected to run from /testing inside the container.
WORKDIR /testing
ENTRYPOINT ["/entrypoint.sh"]
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
.PHONY : docs
docs :
rm -rf docs/build/
sphinx-autobuild -b html --watch tango/ --watch examples/ docs/source/ docs/build/
.PHONY : run-checks
run-checks :
isort --check .
black --check .
ruff check .
mypy --check-untyped-defs .
CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/
CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tango/integrations/torch tests/integrations/torch
CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers
================================================
FILE: README.md
================================================
AI2 Tango replaces messy directories and spreadsheets full of file versions by organizing experiments into discrete steps that can be cached and reused throughout the lifetime of a research project.
## Quick links
- [Documentation](https://ai2-tango.readthedocs.io/)
- [PyPI Package](https://pypi.org/project/ai2-tango/)
- [Contributing](https://github.com/allenai/tango/blob/main/CONTRIBUTING.md)
- [License](https://github.com/allenai/tango/blob/main/LICENSE)
## In this README
- [Quick start](#quick-start)
- [Installation](#installation)
- [Installing with PIP](#installing-with-pip)
- [Installing with Conda](#installing-with-conda)
- [Installing from source](#installing-from-source)
- [Checking your installation](#checking-your-installation)
- [Docker image](#docker-image)
- [FAQ](#faq)
- [Team](#team)
- [License](#license)
## Quick start
Create a Tango step:
```python
# hello.py
from tango import step
@step()
def hello(name: str) -> str:
message = f"Hello, {name}!"
print(message)
return message
```
And create a corresponding experiment configuration file:
```jsonnet
// hello.jsonnet
{
steps: {
hello: {
type: "hello",
name: "World",
}
}
}
```
Then run the experiment using a local workspace to cache the result:
```bash
tango run hello.jsonnet -w /tmp/workspace
```
You'll see something like this in the output:
```
Starting new run expert-llama
● Starting step "hello"...
Hello, World!
✓ Finished step "hello"
✓ Finished run expert-llama
```
If you run this a second time the output will now look like this:
```
Starting new run open-crab
✓ Found output for step "hello" in cache...
✓ Finished run open-crab
```
You won't see "Hello, World!" this time because the result of the step was found in the cache, so it wasn't run again.
For a more detailed introduction check out the [First Steps](https://ai2-tango.readthedocs.io/en/latest/first_steps.html) walk-through.
## Installation
**ai2-tango** requires Python 3.8 or later.
### Installing with `pip`
**ai2-tango** is available [on PyPI](https://pypi.org/project/ai2-tango/). Just run
```bash
pip install ai2-tango
```
To install with a specific integration, such as `torch` for example, run
```bash
pip install 'ai2-tango[torch]'
```
To install with all integrations, run
```bash
pip install 'ai2-tango[all]'
```
### Installing with `conda`
**ai2-tango** is available on conda-forge. You can install just the base package with
```bash
conda install tango -c conda-forge
```
You can pick and choose from the integrations with one of these:
```bash
conda install tango-datasets -c conda-forge
conda install tango-torch -c conda-forge
conda install tango-wandb -c conda-forge
```
You can also install everything:
```bash
conda install tango-all -c conda-forge
```
Even though **ai2-tango** itself is quite small, installing everything will pull in a lot of dependencies.
Don't be surprised if this takes a while!
### Installing from source
To install **ai2-tango** from source, first clone [the repository](https://github.com/allenai/tango):
```bash
git clone https://github.com/allenai/tango.git
cd tango
```
Then run
```bash
pip install -e '.[all]'
```
To install with only a specific integration, such as `torch` for example, run
```bash
pip install -e '.[torch]'
```
Or to install just the base tango library, you can run
```bash
pip install -e .
```
### Checking your installation
Run
```bash
tango info
```
to check your installation.
### Docker image
You can build a Docker image suitable for tango projects by using [the official Dockerfile](https://github.com/allenai/tango/blob/main/Dockerfile) as a starting point for your own Dockerfile, or you can simply use one of our [prebuilt images](https://github.com/allenai/tango/pkgs/container/tango) as a base image in your Dockerfile. For example:
```Dockerfile
# Start from a prebuilt tango base image.
# You can choose the right tag from the available options here:
# https://github.com/allenai/tango/pkgs/container/tango/versions
FROM ghcr.io/allenai/tango:cuda11.3
# Install your project's additional requirements.
COPY requirements.txt .
RUN /opt/conda/bin/pip install --no-cache-dir -r requirements.txt
# Install source code.
# This instruction copies EVERYTHING in the current directory (build context),
# which may not be what you want. Consider using a ".dockerignore" file to
# exclude files and directories that you don't want on the image.
COPY . .
```
Make sure to choose the right base image for your use case depending on the version of tango you're using and the CUDA version that your host machine supports.
You can see a list of all available image tags [on GitHub](https://github.com/allenai/tango/pkgs/container/tango/versions).
## FAQ
### Why is the library named Tango?
The motivation behind this library is that we can make research easier by composing it into well-defined steps. What happens when you choreograph a number of steps together? Well, you get a dance. And since our [team's leader](https://nasmith.github.io/) is part of a tango band, "AI2 Tango" was an obvious choice!
### How can I debug my steps through the Tango CLI?
You can run the `tango` command through [pdb](https://docs.python.org/3/library/pdb.html). For example:
```bash
python -m pdb -m tango run config.jsonnet
```
### How is Tango different from [Metaflow](https://metaflow.org), [Airflow](https://airflow.apache.org), or [redun](https://github.com/insitro/redun)?
We've found that existing DAG execution engines like these tools are great for production workflows but not as well suited for messy, collaborative research projects
where code is changing constantly. AI2 Tango was built *specifically* for these kinds of research projects.
### How does Tango's caching mechanism work?
AI2 Tango caches the results of steps based on the `unique_id` of the step. The `unique_id` is essentially a hash of all of the inputs to the step along with:
1. the step class's fully qualified name, and
2. the step class's `VERSION` class variable (an arbitrary string).
Unlike other workflow engines like [redun](https://github.com/insitro/redun), Tango does *not* take into account the source code of the class itself (other than its fully qualified name) because we've found that using a hash of the source code bytes is way too sensitive and less transparent for users.
When you change the source code of your step in a meaningful way you can just manually change the `VERSION` class variable to indicate to Tango
that the step has been updated.
## Team
**ai2-tango** is developed and maintained by the AllenNLP team, backed by [the Allen Institute for Artificial Intelligence (AI2)](https://allenai.org/).
AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering.
To learn more about who specifically contributed to this codebase, see [our contributors](https://github.com/allenai/tango/graphs/contributors) page.
## License
**ai2-tango** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
A full copy of the license can be found [on GitHub](https://github.com/allenai/tango/blob/main/LICENSE).
================================================
FILE: RELEASE_PROCESS.md
================================================
# GitHub Release Process
## Steps
1. Update the version in `tango/version.py`.
2. Run the release script:
```bash
./scripts/release.sh
```
This will automatically update the CHANGELOG, commit the changes to the CHANGELOG and `version.py` (and any other files you might have changed),
and then create a new tag in git which will trigger a workflow on GitHub Actions that handles the rest.
## Fixing a failed release
If for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete both the tag and corresponding release from GitHub. After you've pushed a fix, delete the tag from your local clone with
```bash
git tag -l | xargs git tag -d && git fetch -t
```
Then repeat the steps above.
================================================
FILE: docs/.gitignore
================================================
build
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
================================================
FILE: docs/source/_static/css/custom.css
================================================
================================================
FILE: docs/source/api/commands.rst
================================================
Commands
========
.. automodule:: tango.__main__
================================================
FILE: docs/source/api/components/executor.rst
================================================
Executor
========
Base class
----------
.. autoclass:: tango.executor.Executor
:members:
.. autoclass:: tango.executor.ExecutorOutput
:members:
.. autoclass:: tango.executor.ExecutionMetadata
:members:
================================================
FILE: docs/source/api/components/format.rst
================================================
Format
======
Base class
----------
.. autoclass:: tango.format.Format
:members:
:private-members:
Implementations
---------------
.. automodule:: tango.format
:members:
:exclude-members: Format,read,write,checksum
================================================
FILE: docs/source/api/components/index.rst
================================================
Components
==========
The core components of **AI2 Tango**.
.. toctree::
:maxdepth: 2
:caption: Components
step
step_info
step_graph
workspace
step_cache
format
executor
================================================
FILE: docs/source/api/components/step.rst
================================================
Step
====
Base class
----------
.. autoclass:: tango.step.Step
:members:
:special-members:
:exclude-members: from_params
.. autofunction:: tango.step.step
.. autoclass:: tango.step.WithUnresolvedSteps
:members:
.. autoclass:: tango.step.StepResources
:members:
Implementations
---------------
.. automodule:: tango.steps
:members:
================================================
FILE: docs/source/api/components/step_cache.rst
================================================
StepCache
=========
Base class
----------
.. autoclass:: tango.step_cache.StepCache
:members:
:special-members:
Implementations
---------------
.. autoclass:: tango.step_caches.LocalStepCache
:members:
.. autoclass:: tango.step_caches.MemoryStepCache
Metadata
--------
.. autoclass:: tango.step_cache.CacheMetadata
:members:
================================================
FILE: docs/source/api/components/step_graph.rst
================================================
StepGraph
=========
.. autoclass:: tango.step_graph.StepGraph
:members:
================================================
FILE: docs/source/api/components/step_info.rst
================================================
StepInfo
========
.. autoclass:: tango.step_info.StepInfo
:member-order: bysource
:members:
.. autoclass:: tango.step_info.StepState
:member-order: bysource
:members:
.. autoclass:: tango.step_info.PlatformMetadata
:member-order: bysource
:members:
.. autoclass:: tango.step_info.EnvironmentMetadata
:member-order: bysource
:members:
.. autoclass:: tango.step_info.GitMetadata
:member-order: bysource
:members:
.. autoclass:: tango.step_info.TangoMetadata
:member-order: bysource
:members:
================================================
FILE: docs/source/api/components/workspace.rst
================================================
Workspace
=========
Base class
----------
.. autoclass:: tango.workspace.Workspace
:members:
Implementations
---------------
.. autoclass:: tango.workspaces.LocalWorkspace
.. autoclass:: tango.workspaces.MemoryWorkspace
Metadata
--------
.. autoclass:: tango.workspace.Run
:members:
.. autoclass:: tango.workspace.RunInfo
:members:
Miscellaneous
-------------
.. autoclass:: tango.workspace.RunSort
:members:
.. autoclass:: tango.workspace.StepInfoSort
:members:
================================================
FILE: docs/source/api/det_hash.rst
================================================
Deterministic Hashing
=====================
In order to detect whether a :class:`~tango.step.Step` has to be re-run or not, Tango relies on some tools to compute
deterministic hashes from the inputs to the :class:`~tango.step.Step`.
The center-piece of this module is the :func:`~tango.common.det_hash.det_hash` function, which computes a deterministic hash of an
arbitrary Python object. The other things in this module influence how that works in various ways.
.. automodule:: tango.common.det_hash
:members:
================================================
FILE: docs/source/api/exceptions.rst
================================================
Exceptions
==========
.. autoexception:: tango.common.exceptions.TangoError
:members:
.. automodule:: tango.common.exceptions
:members:
:exclude-members: TangoError
================================================
FILE: docs/source/api/integrations/beaker.rst
================================================
🧪 Beaker
=========
.. automodule:: tango.integrations.beaker
Reference
---------
.. autoclass:: tango.integrations.beaker.BeakerWorkspace
.. autoclass:: tango.integrations.beaker.BeakerStepCache
.. autoclass:: tango.integrations.beaker.BeakerExecutor
:members: DEFAULT_BEAKER_IMAGE
.. autoclass:: tango.integrations.beaker.BeakerScheduler
:members:
.. autoclass:: tango.integrations.beaker.SimpleBeakerScheduler
.. autoclass:: tango.integrations.beaker.ResourceAssignment
:members:
.. autoclass:: tango.integrations.beaker.ResourceAssignmentError
================================================
FILE: docs/source/api/integrations/datasets.rst
================================================
🤗 Datasets
===========
.. automodule:: tango.integrations.datasets
Reference
---------
.. autofunction:: tango.integrations.datasets.convert_to_tango_dataset_dict
.. autoclass:: tango.integrations.datasets.DatasetsFormat
.. autoclass:: tango.integrations.datasets.LoadDataset
:members:
.. autoclass:: tango.integrations.datasets.LoadStreamingDataset
:members:
.. autoclass:: tango.integrations.datasets.InterleaveDatasets
:members:
.. autoclass:: tango.integrations.datasets.ConcatenateDatasets
:members:
.. autoclass:: tango.integrations.datasets.DatasetRemixStep
:members:
================================================
FILE: docs/source/api/integrations/fairscale.rst
================================================
🔥 FairScale
============
.. automodule:: tango.integrations.fairscale
Reference
---------
.. autoclass:: tango.integrations.fairscale.FairScaleTrainingEngine
.. autoclass:: tango.integrations.fairscale.FSDPConfig
:members:
.. autofunction:: tango.integrations.fairscale.with_wrapped_modules
================================================
FILE: docs/source/api/integrations/flax.rst
================================================
Flax
=======
.. automodule:: tango.integrations.flax
Reference
---------
Train step
~~~~~~~~~~
.. autoclass:: tango.integrations.flax.FlaxTrainStep
:members:
.. autoclass:: tango.integrations.flax.TrainConfig
:members:
Eval step
~~~~~~~~~
.. autoclass:: tango.integrations.flax.FlaxEvalStep
:members:
Flax format
~~~~~~~~~~~~
.. autoclass:: tango.integrations.flax.FlaxFormat
Model
~~~~~
.. autoclass:: tango.integrations.flax.Model
:members:
Optim
~~~~~
.. autoclass:: tango.integrations.flax.Optimizer
:members:
.. autoclass:: tango.integrations.flax.LRScheduler
:members:
Data
~~~~
.. autoclass:: tango.integrations.flax.DataLoader
:members:
.. autoclass:: tango.integrations.flax.FlaxDataLoader
:members:
Callbacks
~~~~~~~~~
.. autoclass:: tango.integrations.flax.TrainCallback
:members:
:member-order: bysource
.. autoclass:: tango.integrations.flax.EvalCallback
:members:
:member-order: bysource
================================================
FILE: docs/source/api/integrations/gs.rst
================================================
☁️ Google Cloud Storage
=======================
.. automodule:: tango.integrations.gs
Reference
---------
.. autoclass:: tango.integrations.gs.GSWorkspace
.. autoclass:: tango.integrations.gs.GSStepCache
================================================
FILE: docs/source/api/integrations/index.rst
================================================
Integrations
============
.. automodule:: tango.integrations
.. toctree::
:maxdepth: 2
:caption: Integrations
torch
fairscale
datasets
transformers
wandb
beaker
flax
gs
================================================
FILE: docs/source/api/integrations/torch.rst
================================================
🔥 PyTorch
==========
.. automodule:: tango.integrations.torch
Reference
---------
Train step
~~~~~~~~~~
.. autoclass:: tango.integrations.torch.TorchTrainStep
:members:
.. autoclass:: tango.integrations.torch.TrainConfig
:members:
Eval step
~~~~~~~~~
.. autoclass:: tango.integrations.torch.TorchEvalStep
:members:
Torch format
~~~~~~~~~~~~
.. autoclass:: tango.integrations.torch.TorchFormat
Model
~~~~~
.. autoclass:: tango.integrations.torch.Model
:members:
TrainingEngine
~~~~~~~~~~~~~~
.. autoclass:: tango.integrations.torch.TrainingEngine
:members:
.. autoclass:: tango.integrations.torch.TorchTrainingEngine
Optim
~~~~~
.. autoclass:: tango.integrations.torch.Optimizer
:members:
.. autoclass:: tango.integrations.torch.LRScheduler
:members:
Data
~~~~
.. autoclass:: tango.integrations.torch.DataLoader
:members:
.. autoclass:: tango.integrations.torch.Sampler
:members:
.. autoclass:: tango.integrations.torch.DataCollator
:members:
:special-members: __call__
.. autoclass:: tango.integrations.torch.ConcatTensorDictsCollator
:members:
Callbacks
~~~~~~~~~
.. autoclass:: tango.integrations.torch.TrainCallback
:members:
:member-order: bysource
.. autoclass:: tango.integrations.torch.EvalCallback
:members:
:member-order: bysource
.. autoclass:: tango.integrations.torch.StopEarlyCallback
.. autoclass:: tango.integrations.torch.StopEarly
:members:
================================================
FILE: docs/source/api/integrations/transformers.rst
================================================
🤗 Transformers
===============
.. automodule:: tango.integrations.transformers
:members:
.. autofunction:: tango.integrations.transformers.ia3.modify_with_ia3
================================================
FILE: docs/source/api/integrations/wandb.rst
================================================
⚖️ Weights & Biases
===================
.. automodule:: tango.integrations.wandb
Reference
---------
.. autoclass:: tango.integrations.wandb.WandbWorkspace
.. autoclass:: tango.integrations.wandb.WandbStepCache
.. autoclass:: tango.integrations.wandb.WandbTrainCallback
.. autoclass:: tango.integrations.wandb.WandbFlaxTrainCallback
================================================
FILE: docs/source/api/logging.rst
================================================
Logging
=======
.. automodule:: tango.common.logging
Reference
---------
.. autodata:: tango.common.logging.TANGO_LOG_LEVEL
.. autodata:: tango.common.logging.FILE_FRIENDLY_LOGGING
.. autodata:: tango.common.logging.cli_logger
.. autofunction:: tango.common.logging.initialize_logging
.. autofunction:: tango.common.logging.initialize_worker_logging
.. autofunction:: tango.common.logging.initialize_prefix_logging
.. autofunction:: tango.common.logging.teardown_logging
.. autofunction:: tango.common.logging.file_handler
================================================
FILE: docs/source/api/sequences.rst
================================================
Sequences
=========
This module contains some utilities to make sequences out of other sequences. All of these are lazy, so they
take minimal time and memory when you create them. These work particularly well when used together. For example,
you can concatenate two sequences (:class:`~tango.common.sequences.ConcatenatedSequence`), and then shuffle
them (:class:`~tango.common.sequences.ShuffledSequence`).
This module is not dependent on other Tango modules and can be used in isolation.
.. automodule:: tango.common.sequences
:members:
================================================
FILE: docs/source/api/settings.rst
================================================
Global settings
---------------
Some command-line options can be set globally in a ``tango.yml`` or ``tango.yaml`` settings file.
Tango will check the current directory and ``~/.config/``, in that order.
The full spec of this config is defined by the :class:`~tango.settings.TangoGlobalSettings` class.
.. autoclass:: tango.settings.TangoGlobalSettings
:members:
:exclude-members: path,find_or_default
:member-order: bysource
================================================
FILE: docs/source/api/utilities.rst
================================================
Utilities
=========
.. automodule:: tango.common
:members:
:exclude-members: det_hash
================================================
FILE: docs/source/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import logging
import os
import sys
from datetime import datetime
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
# This file lives in docs/source/, so "../../" is the repository root; putting
# it on sys.path lets autodoc (and the import just below) find the `tango`
# package without installing it.
sys.path.insert(0, os.path.abspath("../../"))
# Must come after the sys.path tweak above, hence the noqa for the
# "import not at top of file" lint rule.
from tango.version import VERSION, VERSION_SHORT  # noqa: E402
# -- Project information -----------------------------------------------------
project = "AI2 Tango"
# Computed at build time so the copyright year never goes stale.
copyright = f"{datetime.today().year}, Allen Institute for Artificial Intelligence"
author = "Allen Institute for Artificial Intelligence"
# Short (X.Y.Z) and full version strings, both sourced from tango/version.py
# so the docs always match the installed package.
version = VERSION_SHORT
release = VERSION
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "myst_parser",
    "sphinx.ext.intersphinx",
    "sphinx.ext.viewcode",
    "sphinx.ext.doctest",
    "sphinx_copybutton",
    "sphinx_autodoc_typehints",
]
# Silence myst-parser warnings about non-consecutive header levels.
suppress_warnings = ["myst.header"]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build"]
# Both reST and (via myst-parser) Markdown sources are accepted.
source_suffix = [".rst", ".md"]
# -- Extension configuration -------------------------------------------------
# Projects whose documentation we cross-reference with :class:/:func: roles.
# Each entry maps a project name to (inventory URL, None).
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "rich": ("https://rich.readthedocs.io/en/latest", None),
    "torch": ("https://pytorch.org/docs/stable", None),
    "flax": ("https://flax.readthedocs.io/en/latest", None),
    "fairscale": ("https://fairscale.readthedocs.io/en/latest/", None),
    "datasets": ("https://huggingface.co/docs/datasets/master/en", None),
    "transformers": ("https://huggingface.co/docs/transformers/master/en", None),
    "beaker": ("https://beaker-py.readthedocs.io/en/latest/", None),
}
# Tell myst-parser to assign header anchors for h1-h3.
myst_heading_anchors = 3
# By default, sort documented members by type within classes and modules.
autodoc_member_order = "groupwise"
# Render cross-referenced type names without their full module path.
python_use_unqualified_type_names = True
# Include default values when documenting parameter types.
typehints_defaults = "comma"
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "furo"
html_title = f"ai2-tango v{VERSION}"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]
html_favicon = "_static/favicon.ico"
html_theme_options = {
"light_css_variables": {
"color-announcement-background": "#1B4596",
"color-announcement-text": "#FFFFFF",
},
"dark_css_variables": {},
"light_logo": "tango_final_squareish.png",
"dark_logo": "tango_final_squareish.png",
"footer_icons": [
{
"name": "GitHub",
"url": "https://github.com/allenai/tango",
"html": """
""", # noqa: E501
"class": "",
},
],
}
# -- Hack to get rid of stupid warnings from sphinx_autodoc_typehints --------
class ShutupSphinxAutodocTypehintsFilter(logging.Filter):
    """Drop the noisy, non-actionable 'Cannot resolve forward reference'
    warnings that sphinx_autodoc_typehints emits during the build."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Use getMessage() rather than record.msg: msg may be a non-str
        # object (``in`` would raise TypeError), and getMessage() also
        # expands any %-style lazy arguments before we match on the text.
        return "Cannot resolve forward reference" not in record.getMessage()


logging.getLogger("sphinx.sphinx_autodoc_typehints").addFilter(ShutupSphinxAutodocTypehintsFilter())
================================================
FILE: docs/source/examples/euler.md
================================================
```{include} ../../../examples/euler/README.md
```
## Running the experiment
If you haven't already, clone the [tango repository](https://github.com/allenai/tango) and then
change directories into `examples/euler`.
You can then run the experiment with:
```bash
tango run euler_general.jsonnet -i complex_arithmetic -w workspace
```
This will leave its results in a subdirectory of `workspace/runs/` corresponding to the name of the run.
The output it prints should look something like this:
```
Starting new run comic-heron
Server started at http://localhost:8080/run/comic-heron
[step i_times_pi] ● Starting step "i_times_pi"...
[step i_times_pi] ✓ Finished step "i_times_pi"
[step cos] ● Starting step "cos"...
[step cos] ✓ Finished step "cos"
[step sin] ● Starting step "sin"...
[step sin] ✓ Finished step "sin"
[step pow_e] ✓ Found output for step "i_times_pi" in cache (needed by "pow_e")...
[step pow_e] ● Starting step "pow_e"...
[step pow_e] ✓ Finished step "pow_e"
[step i_times_sin] ✓ Found output for step "sin" in cache (needed by "i_times_sin")...
[step i_times_sin] ● Starting step "i_times_sin"...
[step i_times_sin] ✓ Finished step "i_times_sin"
[step sum] ✓ Found output for step "cos" in cache (needed by "sum")...
[step sum] ✓ Found output for step "i_times_sin" in cache (needed by "sum")...
[step sum] ● Starting step "sum"...
[step sum] ✓ Finished step "sum"
[step sub] ✓ Found output for step "sum" in cache (needed by "sub")...
[step sub] ✓ Found output for step "pow_e" in cache (needed by "sub")...
[step sub] ● Starting step "sub"...
[step sub] ✓ Finished step "sub"
[step print] ✓ Found output for step "sub" in cache (needed by "print")...
[step print] ● Starting step "print"...
[step print] 0j
[step print] ✓ Finished step "print"
✓ Finished run comic-heron
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Step Name ┃ Status ┃ Cached Result ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ cos │ ✓ succeeded │ workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │
│ i_times_pi │ ✓ succeeded │ workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae │
│ i_times_sin │ ✓ succeeded │ workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf │
│ pow_e │ ✓ succeeded │ workspace/cache/ExponentiateStep-1swPpNipP6HBSP5rKdNjEqbYAWNf4CdG │
│ print │ ✓ succeeded │ N/A │
│ sin │ ✓ succeeded │ workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │
│ sub │ ✓ succeeded │ workspace/cache/SubtractionStep-4ygj1UyLk6TCVBxN7DWTCccbMa7M1C5v │
│ sum │ ✓ succeeded │ workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP │
└─────────────┴─────────────┴───────────────────────────────────────────────────────────────────┘
✓ 8 succeeded
Use your workspace to get the cached result of a step, e.g.
>>> from tango import Workspace
>>> workspace = Workspace.from_url(...)
>>> workspace.step_result_for_run("comic-heron", "sum")
```
A few things are of note here:
1. Tango assigns a name to your run. In this case, the name is "comic-heron".
2. In this configuration, the "print" step prints the output ("`0j`"). Most of the time though, you will look
for the output in the output directories that are given in the table.
3. You might notice that the "print" step has no cached result ("N/A" in the table). That's because it is uncacheable, and thus writes
out nothing.
## Change a step
Let's make an update to a step! Open `complex_arithmetic.py` and change `AdditionStep`. The actual change you make
in the `run()` method does not matter, but the important thing is to update the `VERSION` member of the
`AdditionStep` class. `AdditionStep` does not yet have a `VERSION`, so we will give it one:
```Python
@Step.register("cadd")
class AdditionStep(Step):
VERSION = "002" # This is the important change.
def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex: # type: ignore
return make_complex(a) + make_complex(b)
```
Now run the config again with
```bash
tango run euler_general.jsonnet -i complex_arithmetic -w workspace
```
This time, the output will look like this:
```
Starting new run right-amoeba
Server started at http://localhost:8080/run/right-amoeba
[step sum] ✓ Found output for step "cos" in cache (needed by "sum")...
[step sum] ✓ Found output for step "i_times_sin" in cache (needed by "sum")...
[step sum] ● Starting step "sum"...
[step sum] ✓ Finished step "sum"
[step sub] ✓ Found output for step "sum" in cache (needed by "sub")...
[step sub] ✓ Found output for step "pow_e" in cache (needed by "sub")...
[step sub] ● Starting step "sub"...
[step sub] ✓ Finished step "sub"
[step print] ✓ Found output for step "sub" in cache (needed by "print")...
[step print] ● Starting step "print"...
[step print] 0j
[step print] ✓ Finished step "print"
✓ Finished run right-amoeba
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Step Name ┃ Status ┃ Cached Result ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ cos │ - not run │ workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │
│ i_times_pi │ - not run │ workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae │
│ i_times_sin │ - not run │ workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf │
│ pow_e │ - not run │ workspace/cache/ExponentiateStep-1swPpNipP6HBSP5rKdNjEqbYAWNf4CdG │
│ print │ ✓ succeeded │ N/A │
│ sin │ - not run │ workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │
│ sub │ ✓ succeeded │ workspace/cache/SubtractionStep-42mdcQBtrNAYvxYhmzdd1vj2uCG8N5Yf │
│ sum │ ✓ succeeded │ workspace/cache/AdditionStep-002-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP │
└─────────────┴─────────────┴───────────────────────────────────────────────────────────────────┘
✓ 3 succeeded, 5 not run
Use your workspace to get the cached result of a step, e.g.
>>> from tango import Workspace
>>> workspace = Workspace.from_url(...)
>>> workspace.step_result_for_run("right-amoeba", "sum")
```
As you can see, it re-used the cached results for several of the steps, and only ran three steps anew.
```{eval-rst}
:class:`tango.step.Step.VERSION` is just one of the ways in which you can change the behavior of a step. Head over to the
documentation of the :class:`tango.step.Step` class to see the others.
```
================================================
FILE: docs/source/examples/eval_p3.md
================================================
```{include} ../../../examples/eval_p3/README.md
```
## `RougeScoreStep`
`RougeScoreStep` is defined in `eval.py`:
```{literalinclude} ../../../examples/eval_p3/eval.py
:language: py
```
## Config
The configuration file, `config.jsonnet`, uses some advanced [Jsonnet](https://jsonnet.org) concepts like `std.foldl`
to create the same configuration for all 10 prompts:
```{literalinclude} ../../../examples/eval_p3/config.jsonnet
```
## Run it
You can run the experiment with:
```bash
tango run config.jsonnet -i eval -d /tmp/workspace
```
================================================
FILE: docs/source/examples/index.rst
================================================
Examples
========
Real-world examples of using Tango.
You can find all of these `on GitHub <https://github.com/allenai/tango/tree/main/examples>`_ as well.
.. toctree::
:maxdepth: 2
:caption: Examples
euler
train_lm
eval_p3
================================================
FILE: docs/source/examples/train_lm.md
================================================
# Fine-tuning a language model
```{include} ../../../examples/train_lm/README.md
:start-after:
:end-before:
```
```{tip}
You can find the full code for this example on [GitHub](https://github.com/allenai/tango/tree/main/examples/train_lm).
```
## Components
We'll need to write a step for tokenizing the data and preparing it for language model training.
All of the other steps we need are provided by Tango integrations.
So, create a file called `tokenize_step.py` with following contents:
```{literalinclude} ../../../examples/train_lm/tokenize_step.py
:language: py
```
## Configuration file
Next you'll need to create a configuration file that defines the experiment. Just copy over these contents into a file called `config.jsonnet`:
```{literalinclude} ../../../examples/train_lm/config.jsonnet
```
## Run it
Now we can run the experiment with:
```bash
tango run config.jsonnet -i tokenize_step.py -d /tmp/results
```
================================================
FILE: docs/source/faq.md
================================================
# FAQ
```{include} ../../README.md
:start-after:
:end-before:
```
================================================
FILE: docs/source/first_steps.md
================================================
# First Steps
## What is a Step?
Tango is a Python library for choreographing machine learning research experiments by executing
a series of steps.
A step can do anything, really, such as [prepare a dataset](tango.integrations.datasets.LoadDataset), [train a model](tango.integrations.torch.TorchTrainStep), send an email to your mother wishing her happy birthday, *etc*.
Concretely, each step is just a subclass of {class}`~tango.step.Step`, where the {meth}`~tango.step.Step.run` method in particular defines what the step actually does.
So anything that can be implemented in Python can be run as a step.
Steps can also depend on other steps in that the output of one step can be part of the input to another step.
Therefore, the steps that make up an experiment form a [directed graph](tango.step_graph.StepGraph).
The concept of the {class}`~tango.step.Step` is the bread and butter that makes Tango so general and powerful.
*So* powerful, in fact, that you might be wondering if Tango is [Turing-complete](https://en.wikipedia.org/wiki/Turing_completeness)?
Well, we don't know yet, but we can say at least that Tango is **Tango-complete** 😉
## Configuration files
Experiments themselves are defined through JSON, [Jsonnet](https://jsonnet.org/), or YAML configuration files.
At a minimum, these files must contain the "steps" field, which should be a mapping of arbitrary (yet unique) step names to the configuration of the corresponding step.
For example, let's create a config file called `config.jsonnet` with the following contents:
```json
{
"steps": {
"random_name": {
"type": "random_choice",
"choices": ["Turing", "Tango", "Larry"],
},
"say_hello": {
"type": "concat_strings",
"string1": "Hello, ",
"string2": {
"type": "ref",
"ref": "random_name"
}
},
"print": {
"type": "print",
"input": {
"type": "ref",
"ref": "say_hello"
}
}
}
}
```
*Can you guess what this experiment does?*
There are three steps in this experiment graph: "random_name" is the name of one step, "say_hello" is the name of another, and "print" is the name of the last.
The "type" parameter within the config of each step tells Tango which {class}`~tango.step.Step` class implementation to use for that step.
So, within the "random_name" step config
```json
"random_name": {
"type": "random_choice",
"choices": ["Turing", "Tango", "Larry"],
}
```
the `"type": "random_choice"` part tells Tango to use the {class}`~tango.step.Step` subclass that is registered by the name "random_choice".
But wait... what do we mean by *registered*?
Tango keeps track of an internal registry for certain classes (such as the {class}`~tango.step.Step` class) that is just a mapping of arbitrary unique names to subclasses.
When you look through Tango's source code, you'll see things like:
```python
@Step.register("foo")
class Foo(Step):
...
```
This is how subclasses get added to the registry.
In this case the subclass `Foo` is added to the `Step` registry under the name "foo", so if you were to use `"type": "foo"` in your configuration file, Tango would understand
that you mean to use the `Foo` class for the given step.
```{tip}
Any class that inherits from {class}`~tango.common.registrable.Registrable` can have its own
registry.
```
Now back to our example.
The step classes referenced in our configuration file ("random_choice" and "concat_strings") don't actually exist in the Tango library (though the ["print" step](tango.steps.PrintStep) does),
but we can easily implement and register them on our own.
Let's put them in a file called `components.py`:
```python
# file: components.py
import random
from typing import List
from tango import Step
@Step.register("random_choice")
class RandomChoiceStep(Step):
DETERMINISTIC = False
def run(self, choices: List[str]) -> str:
return random.choice(choices)
@Step.register("concat_strings")
class ConcatStringsStep(Step):
def run(self, string1: str, string2: str) -> str:
return string1 + string2
```
```{important}
It's important that you use type hints in your code so that Tango can properly construct Python objects from the corresponding serialized (JSON) objects
and warn you when the types don't match up.
```
So as long as Tango is able to import this module (`components.py`) these step implementations will be added to the registry
and Tango will know how to instantiate and run them.
There's also a short-hand way of implementing steps, using the {func}`@step() ` function decorator:
```python
from tango import step
@step(deterministic=False)
def random_choice(choices: List[str]) -> str:
return random.choice(choices)
@step()
def concat_strings(string1: str, string2: str) -> str:
return string1 + string2
```
This will register these steps under the name of the corresponding function, i.e. "random_choice" and "concat_strings", by default, though that can be overridden by specifying the "name" parameter to the decorator:
```python
@step(name="random-string", deterministic=False)
def random_choice(choices: List[str]) -> str:
return random.choice(choices)
```
## Executing an experiment
At this point we've implemented our custom steps (`components.py`) and created our configuration
file `config.jsonnet`, so we're ready to actually run this experiment.
For that, just use the `tango run` command:
```
$ tango run config.jsonnet -i components
```
```{tip}
- The `-i` option is short for `--include-package`, which takes the name of a Python package which Tango will try to import.
In this case our custom steps are in `components.py`, so we need Tango to import this module to find those steps.
As long as `components.py` is in the current directory or somewhere else on the `PYTHONPATH`, Tango will be able to find and import
this module when you pass `-i components` (note the lack of the `.py` at the end).
```
You should see something like this in the output:
```
Starting new run cute-kitten
● Starting step "random_name"
✓ Finished step "random_name"
● Starting step "say_hello"
✓ Finished step "say_hello"
● Starting step "print"
Hello, Tango
✓ Finished step "print"
```
## Step caching
This particular experiment didn't write any results to disk, but in many situations you'll want to save the output of at least some of your steps.
For example, if you're using the {class}`~tango.integrations.torch.TorchTrainStep` step, the output is a trained model, which is certainly a useful thing to keep around.
In other cases, you may not actually care about the direct result of a particular step, but it could still be useful to save it when possible so that Tango doesn't need to run the step
again unnecessarily.
This is where Tango's caching mechanism comes in.
To demonstrate this, let's look at another example that pretends to do some expensive computation.
Here is the `config.jsonnet` file:
```json
{
"steps": {
"add_numbers": {
"type": "really_inefficient_addition",
"num1": 34,
"num2": 8
}
}
}
```
And let's implement "really_inefficient_addition":
```python
# components.py
import time
from tango import Step, JsonFormat
from tango.common import Tqdm
@Step.register("really_inefficient_addition")
class ReallyInefficientAdditionStep(Step):
DETERMINISTIC = True
CACHEABLE = True
FORMAT = JsonFormat()
def run(self, num1: int, num2: int) -> int:
for _ in Tqdm.tqdm(range(100), desc="Computing...", total=100):
time.sleep(0.05)
return num1 + num2
```
There are a couple of things to note about this step, other than the obvious inefficiencies; the class variables
we've defined: {attr}`~tango.step.Step.DETERMINISTIC`, {attr}`~tango.step.Step.CACHEABLE`, and
{attr}`~tango.step.Step.FORMAT`.
`DETERMINISTIC = True` tells Tango that, given particular inputs, the output to this step will always be the same
every time it is run, which has implications for caching.
By default, Tango assumes steps are deterministic.
You can override this by saying `DETERMINISTIC = False`.
Tango will warn you when you try to cache a non-deterministic step.
`CACHEABLE = True` tells Tango that it can cache this step and `FORMAT = JsonFormat()` defines which
{class}`~tango.format.Format` Tango will use to serialize the result of the step.
This time when we run the experiment we'll designate a specific directory for Tango to use:
```bash
$ tango run config.jsonnet -i components -d workspace/
```
```
Starting new run live-tarpon
● Starting step "add_numbers"
Computing...: 100%|##########| 100/100 [00:05<00:00, 18.99it/s]
✓ Finished step "add_numbers"
✓ The output for "add_numbers" is in workspace/runs/live-tarpon/add_numbers
```
The last line in the output tells us where we can find the result of our "add_numbers" step. `live-tarpon` is
the name of the run. Run names are randomly generated and may be different on your machine. `add_numbers` is the
name of the step in your config. The whole path is a symlink to a directory, which contains (among other things)
a file `data.json`:
```bash
$ cat workspace/runs/live-tarpon/add_numbers/data.json
```
```
42
```
Now look what happens when we run this step again:
```bash
$ tango run config.jsonnet -i components -d workspace/
```
```
Starting new run modest-shrimp
✓ Found output for "add_numbers" in cache
✓ The output for "add_numbers" is in workspace/runs/modest-shrimp/add_numbers
```
Tango didn't have to run our really inefficient addition step this time because it found the previous cached
result. It put the results in the result directory for a different run (in our case, the `modest-shrimp` run),
but once again it is a symlink that links to the same results from our first run.
If we changed the inputs to the step in `config.jsonnet`:
```diff
"add_numbers": {
"type": "really_inefficient_addition",
"num1": 34,
- "num2": 8
+ "num2": 2
}
}
}
```
And ran it again:
```bash
$ tango run config.jsonnet -i components -d workspace/
```
```
Starting new run true-parrot
● Starting step "add_numbers"
Computing...: 100%|##########| 100/100 [00:05<00:00, 19.13it/s]
✓ Finished step "add_numbers"
✓ The output for "add_numbers" is in workspace/runs/true-parrot/add_numbers
```
You'd see that Tango had to run our "add_numbers" step again.
You may have noticed that `workspace/runs/true-parrot/add_numbers` is now a symlink that points to a different
place than it did for the first two runs. That's because it produced a different result this time. All the
result symlinks point into the `workspace/cache/` directory, where all the step's results are cached.
This means that if we ran the experiment again with the original inputs, Tango would still find the cached result
and wouldn't need to rerun the step.
## Arbitrary objects as inputs
### `FromParams`
So far the inputs to all of the steps in our examples have been built-in Python types that can be deserialized from JSON (e.g. {class}`int`, {class}`str`, etc.),
but sometimes you need the input to a step to be an instance of an arbitrary Python class.
Tango allows this as well as it can infer from type hints what the class is and how to instantiate it.
When writing your own classes, it's recommended that you have your class inherit from the {class}`~tango.common.from_params.FromParams` class, which will guarantee that
Tango can instantiate it from a config file.
For example, suppose we had a step like this:
```python
from tango import Step
from tango.common import FromParams
class Bar(FromParams):
def __init__(self, x: int) -> None:
self.x = x
@Step.register("foo")
class FooStep(Step):
def run(self, bar: Bar) -> int:
return bar.x
```
```{tip}
If you've used [AllenNLP](https://github.com/allenai/allennlp) before, this will look familiar!
In fact, it's the same system under the hood.
```
Then we could create a config like this:
```json
{
"steps": {
"foo": {
"type": "foo",
"bar": {"x": 1}
}
}
}
```
And Tango will figure out how to deserialize `{"x": 1}` into a `Bar` instance.
You can also have `FromParams` objects nested within other `FromParams` objects or standard containers
like {class}`list`:
```python
from typing import List
from tango import Step
from tango.common import FromParams
class Bar(FromParams):
def __init__(self, x: int) -> None:
self.x = x
class Baz(FromParams):
def __init__(self, bar: Bar) -> None:
self.bar = bar
@Step.register("foo")
class FooStep(Step):
def run(self, bars: List[Bar], baz: Baz) -> int:
return sum([bar.x for bar in bars]) + baz.bar.x
```
### `Registrable`
The {class}`~tango.common.registrable.Registrable` class is a special kind of {class}`~tango.common.from_params.FromParams` class that allows you to specify from the config which subclass of an expected class to deserialize into.
This is actually how we've been instantiating specific `Step` subclasses. Because {class}`~tango.step.Step` inherits from {class}`~tango.common.registrable.Registrable`, we can use the `"type"` fields in the config file to specify a `Step` subclass.
This is also very useful when you're writing a step that requires a certain type as input, but you want to be able to change the exact subclass of the type from your config file. For example, the {class}`~tango.integrations.torch.TorchTrainStep` takes `Registrable` inputs such as {class}`~tango.integrations.torch.Model`. Model variants can then be subclasses that are specified in the config file by their registered names. A sketch of this might look like the following:
```python
from tango import Step
from tango.common import FromParams, Registrable
class Model(torch.nn.Module, Registrable):
...
@Model.register("variant1")
class Variant1(Model):
...
@Model.register("variant2")
class Variant2(Model):
...
@Step.register("torch::train")
class TorchTrainerStep(Step):
def run(self, model: Model, ...) -> Model:
...
```
And a sketch of the config file would be something like this:
```json
{
"steps": {
"train": {
"type": "torch::train",
"model": {
"type": "variant1",
}
}
}
}
```
As in the `FromParams` example the specifications can be nested, but now we also denote the subclass with the `"type": "..."` field. To swap models we need only change "variant1" to "variant2" in the config. The value for "type" can either be the name that the class is registered under (e.g. "train" for `TorchTrainStep`), or the fully qualified class name (e.g. `tango.integrations.torch.TorchTrainStep`).
You'll see more examples of this in the [next section](examples/index).
================================================
FILE: docs/source/index.md
================================================
# **AI2 Tango**
```{include} ../../README.md
:start-after:
:end-before:
```
```{toctree}
:maxdepth: 2
:hidden:
:caption: Getting started
installation
first_steps
examples/index
faq
```
```{toctree}
:maxdepth: 2
:hidden:
:caption: API Reference
api/commands
api/components/index
api/integrations/index
api/settings
api/exceptions
api/logging
api/sequences
api/det_hash
api/utilities
```
```{toctree}
:hidden:
:caption: Development
CONTRIBUTING
CHANGELOG
License
GitHub Repository
```
To learn about Tango in 5 minutes, head over to the [First Steps section](first_steps).
If you'd rather learn from examples, check out the [Examples section](examples/index).
## Team
```{include} ../../README.md
:start-after:
:end-before:
```
## License
```{include} ../../README.md
:start-after:
:end-before:
```
## Indices and tables
```{eval-rst}
* :ref:`genindex`
* :ref:`modindex`
```
================================================
FILE: docs/source/installation.md
================================================
Installation
============
```{include} ../../README.md
:start-after:
:end-before:
```
================================================
FILE: examples/euler/README.md
================================================
Euler
=====
This is a toy example that proves Euler's identity using Tango. You can use this to play with the concept of a
`Step` and see how Tango runs things without getting distracted by the details of what you're running.
================================================
FILE: examples/euler/complex_arithmetic.py
================================================
import cmath
from typing import Tuple, Union
from tango import Step
# A complex value may be given directly or as a (real, imag) pair.
ComplexOrTuple = Union[complex, Tuple[float, float]]


def make_complex(x: ComplexOrTuple) -> complex:
    """Coerce *x* into a ``complex`` number.

    Accepts a ``complex`` (returned unchanged), a plain ``int``/``float``
    (treated as the real part), or a ``(real, imag)`` pair.
    """
    if isinstance(x, complex):
        return x
    if isinstance(x, (int, float)):
        return complex(x)
    # Anything else is assumed to be a (real, imag) sequence.
    return complex(*x)
@Step.register("cadd")
class AdditionStep(Step):
    """Complex addition: returns ``a + b``."""

    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore
        lhs = make_complex(a)
        rhs = make_complex(b)
        return lhs + rhs
@Step.register("csub")
class SubtractionStep(Step):
    """Complex subtraction: returns ``a - b``."""

    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore
        lhs = make_complex(a)
        rhs = make_complex(b)
        return lhs - rhs
@Step.register("cexp")
class ExponentiateStep(Step):
    """Complex exponentiation: returns ``base ** x`` (base defaults to *e*)."""

    def run(self, x: ComplexOrTuple, base: ComplexOrTuple = cmath.e) -> complex:  # type: ignore
        exponent = make_complex(x)
        return make_complex(base) ** exponent
@Step.register("cmul")
class MultiplyStep(Step):
    """Complex multiplication: returns ``a * b``."""

    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore
        lhs = make_complex(a)
        rhs = make_complex(b)
        return lhs * rhs
@Step.register("csin")
class SineStep(Step):
    """Complex sine: returns ``sin(x)``."""

    def run(self, x: ComplexOrTuple) -> complex:  # type: ignore
        value = make_complex(x)
        return cmath.sin(value)
@Step.register("ccos")
class CosineStep(Step):
    """Complex cosine: returns ``cos(x)``."""

    def run(self, x: ComplexOrTuple) -> complex:  # type: ignore
        value = make_complex(x)
        return cmath.cos(value)
================================================
FILE: examples/euler/euler.jsonnet
================================================
local i = [0.0, 1.0];
local pi = [3.1415926535, 0.0];
{
"steps": {
"i_times_pi": {
"type": "cmul",
"a": i,
"b": pi
},
"pow_e": {
"type": "cexp",
"x": { "type": "ref", "ref": "i_times_pi" }
},
"plus_one": {
"type": "cadd",
"a": { "type": "ref", "ref": "pow_e" },
"b": [1, 0]
},
"print": {
"type": "print",
"input": { "type": "ref", "ref": "plus_one" }
}
}
}
================================================
FILE: examples/euler/euler_general.jsonnet
================================================
local i = [0.0, 1.0];
local pi = [3.1415926535, 0.0];
{
"steps": {
"cos": {
"type": "ccos",
"x": pi
},
"sin": {
"type": "csin",
"x": pi
},
"i_times_sin": {
"type": "cmul",
"a": i,
"b": { "type": "ref", "ref": "sin" }
},
"sum": {
"type": "cadd",
"a": { "type": "ref", "ref": "cos" },
"b": { "type": "ref", "ref": "i_times_sin" },
},
"i_times_pi": {
"type": "cmul",
"a": i,
"b": pi
},
"pow_e": {
"type": "cexp",
"x": { "type": "ref", "ref": "i_times_pi" }
},
"sub": {
"type": "csub",
"a": { "type": "ref", "ref": "sum" },
"b": { "type": "ref", "ref": "pow_e" },
},
"print": {
"type": "print",
"input": { "type": "ref", "ref": "sub" }
}
}
}
================================================
FILE: examples/euler/run.sh
================================================
#!/bin/bash
# Run the Euler-identity example pipeline defined in euler_general.jsonnet,
# caching step outputs under ./workspace and loading the custom steps from
# complex_arithmetic.py (imported via --include-package).
tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic
================================================
FILE: examples/eval_p3/README.md
================================================
# Evaluating T0
This example uses the `transformers::run_generation_dataset` step to run the
[T0 model](https://api.semanticscholar.org/CorpusID:239009562). It runs the
[XSum summarization data](https://github.com/EdinburghNLP/XSum), prompted in 10 different ways, and computes
ROUGE scores for all variants. Finally, it computes an overall ROUGE score.
This example uses mostly built-in Tango steps. You will need the `datasets` and `transformers` integrations.
The only custom step in this example is the `RougeScoreStep`, which computes ROUGE scores from the
generated text.
================================================
FILE: examples/eval_p3/config.jsonnet
================================================
local model = "bigscience/T0_3B";
local batch_size = 8;
local datasets = [
'xsum_DOC_boils_down_to_simple_idea_that',
'xsum_DOC_given_above_write_one_sentence',
'xsum_DOC_how_would_you_rephrase_few_words',
'xsum_DOC_tldr',
'xsum_DOC_write_summary_of_above',
'xsum_article_DOC_summary',
'xsum_college_roommate_asked_DOC_so_I_recap',
'xsum_read_below_DOC_write_abstract',
'xsum_summarize_DOC',
'xsum_summarize_this_DOC_summary'
];
# This creates three steps for each of the datasets:
# 1. Load the dataset.
# 2. Generate output based on the dataset.
# 3. Evaluate the output against the gold answers.
local dataset_steps = std.foldl(
function(x, dataset_name) x + {
["dataset_" + dataset_name]: {
"type": "datasets::load",
"path": "bigscience/P3",
"name": dataset_name,
},
["generation_" + dataset_name]: {
"type": "transformers::run_generation_dataset",
"max_length": 200,
"input": {"ref": "dataset_" + dataset_name},
"batch_size": batch_size,
"model": model,
"prompt_field": "inputs_pretokenized",
"output_field": "generation",
"splits": ["validation"]
},
["eval_" + dataset_name]: {
"type": "rouge_score",
"input": {"ref": "generation_" + dataset_name},
"input_split": "validation",
"target_field": "targets_pretokenized",
"prediction_field": "generation"
}
},
datasets,
{}
);
# In addition to the three steps per dataset, we also combine all the generations and
# evaluate them all together.
{
"steps": dataset_steps + {
"all_generations": {
"type": "dataset_combine",
"inputs": std.map(
function(dataset_name) {"ref": "generation_" + dataset_name},
datasets
)
},
"all_evaluations": {
"type": "rouge_score",
"input": {"ref": "all_generations"},
"input_split": "validation",
"target_field": "targets_pretokenized",
"prediction_field": "generation"
}
}
}
================================================
FILE: examples/eval_p3/eval.py
================================================
import logging
from typing import Dict
from torch import Tensor
from torchmetrics.text.rouge import ROUGEScore
from tango import Format, JsonFormat, Step
from tango.common import DatasetDict
from tango.common.tqdm import Tqdm
logger = logging.getLogger(__name__)
@Step.register("rouge_score")
class RougeScoreStep(Step[Dict[str, Tensor]]):
    """
    Score generated text against gold targets with ROUGE-1/2/L, averaging
    the scores over all (prediction, target) pairs in the chosen split.
    """

    VERSION = "002"
    FORMAT: Format = JsonFormat()

    def run(  # type: ignore
        self,
        input: DatasetDict,
        input_split: str,
        target_field: str,
        prediction_field: str,
        use_stemmer: bool = True,
    ) -> Dict[str, Tensor]:
        scorer = ROUGEScore(
            use_stemmer=use_stemmer,
            rouge_keys=("rouge1", "rouge2", "rougeL"),
            accumulate="avg",
        )
        # Feed every generated candidate for each instance into the metric.
        for instance in Tqdm.tqdm(input[input_split], desc="Calculating scores"):
            gold = instance[target_field]
            for candidate in instance[prediction_field]:
                scorer.update(candidate, gold)
        return scorer.compute()
================================================
FILE: examples/finetune/__init__.py
================================================
================================================
FILE: examples/finetune/config.jsonnet
================================================
##################
# Model settings #
##################

# Fine-tune a seq2seq/causal LM on SNLI cast to text-to-text form
# (the "snli-text2text" and "subset-data" steps live in snli_steps.py).
local pretrained_model = "t5-base";
local load_with_low_cpu_mem_usage = false;
local modules_to_wrap = ["[a-zA-Z_.]+\\.[0-9]+"]; # TODO: works for t5 and gpt2. confirm with other models too.

####################
# Trainer settings #
####################

# Trainer settings, adjust to your use-case.
local training_steps = 20;  # total number of optimization steps to train for
local validate_every = 5;  # how often to validate and save checkpoints

local devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)
local grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)
# This is the batch size per GPU, ignoring gradient accumulation:
local batch_size = 2;
# So the effective batch size is `batch_size * grad_accum * devices`

local activation_checkpointing = false;  # use activation/gradient checkpointing (probably needed for GPT-J 6B, but not gpt2)
local amp = false;  # use PyTorch's native automatic mixed precision
local fsdp = false;  # Use FairScale's FullyShardedDataParallel (probably needed for GPT-J 6B, but not gpt2)
local cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.

######################
# Optimizer settings #
######################

local warmup_steps = 20;
local learning_rate = 0.00005;  # you can probably use a higher LR for a small model like "gpt2"

assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp";

# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
    reshard_after_forward: true,
    move_params_to_cpu: cpu_offloading,
    move_grads_to_cpu: cpu_offloading,
    mixed_precision: amp,
} else null;

local training_engine = {
    type: if fsdp then "fairscale" else "torch",
    optimizer: {
        type: "torch::AdamW",
        lr: learning_rate,
        betas: [0.9, 0.95],
        eps: 1e-6,
    },
    lr_scheduler: {
        type: "transformers::linear",
        num_warmup_steps: warmup_steps,
        num_training_steps: training_steps,
    },
    amp: amp,
    # Only the fairscale engine accepts `fsdp_config`, so omit the key otherwise:
    [if fsdp then "fsdp_config" else null]: fsdp_config,
};

local distributed_dataloader = {
    batch_size: batch_size,
    sampler: {
        type: "torch::DistributedSampler",
        shuffle: true,
        drop_last: true,
    },
};

local single_device_dataloader = {
    shuffle: true,
    batch_size: batch_size,
};

local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;

{
    steps: {
        raw_data: {
            type: "datasets::load",
            path: "snli",
        },
        # Commented out by default; test.py enables an equivalent step via
        # config overrides to run on a small subset of the data.
        /*"subset_data": {
            type: "subset-data",
            data: { type: "ref", ref: "raw_data" },
            max_samples: 10,
        },*/
        processed_data: {
            type: "snli-text2text",
            data: { type: "ref", ref: "raw_data" },
        },
        trained_model: {
            type: "transformers::finetune",
            model: {
                type: "fairscale::with_wrapped_modules",
                model: {
                    type: "transformers::finetune::from_pretrained",
                    pretrained_model_name_or_path: pretrained_model,
                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,
                },
                modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually
                fsdp_config: fsdp_config,
                activation_checkpointing: activation_checkpointing,
            },
            tokenizer: {
                pretrained_model_name_or_path: pretrained_model
            },
            dataset_dict: { type: "ref", ref: "processed_data" },
            train_dataloader: dataloader,
            validation_split: "validation",
            grad_accum: grad_accum,
            train_steps: training_steps,
            validate_every: validate_every,
            checkpoint_every: validate_every,
            log_every: 1,
            device_count: devices,
            training_engine: training_engine,
        },
        generations: {
            type: "transformers::run_generation_dataset",
            max_length: 5,
            input: {"type": "ref", "ref": "processed_data"},
            batch_size: batch_size,
            model: {"type": "ref", "ref": "trained_model"},
            prompt_field: "source",
            output_field: "generation",
            splits: ["validation"]
        }
    }
}
================================================
FILE: examples/finetune/snli_steps.py
================================================
from typing import Union
import datasets as ds
from tango.integrations.datasets import DatasetsFormat
from tango.step import Step
@Step.register("subset-data")
class SubsetData(Step):
    """
    Produce a truncated copy of a dataset, keeping only the first few samples
    of every split. Intended for quick testing/debugging runs.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"
    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: Union[ds.DatasetDict, ds.Dataset],
        max_samples: int = 5,
    ) -> Union[ds.DatasetDict, ds.Dataset]:
        """
        Return a copy of ``data`` limited to at most ``max_samples`` rows per split.

        :param data:
            The dataset or dataset dict object.
        :param max_samples:
            The maximum number of samples to return per split.
        """

        # `ds.Dataset.select` only exists on a plain `Dataset`; filtering by row
        # index works uniformly on both `Dataset` and `DatasetDict`.
        def keep_first_n(example, indices):
            return indices < max_samples

        return data.filter(keep_first_n, with_indices=True)
@Step.register("snli-text2text")
class SnliText2Text(Step):
    """
    Cast the SNLI dataset into a text-to-text ("source"/"target") format.

    For example::

        {"premise": "Two cats are sitting on a wall.",
         "hypothesis": "The cats are chasing a mouse.",
         "label": 2}  # contradiction

    becomes::

        {"source": "nli premise: Two cats are sitting on a wall. hypothesis: The cats are chasing a mouse. label: ",
         "target": "contradiction"}
    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"
    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: Union[ds.DatasetDict, ds.Dataset],
        source_prefix: str = "nli",
        premise_prefix: str = "premise",
        hypothesis_prefix: str = "hypothesis",
        label_prefix: str = "label",
        num_workers: int = 1,
    ) -> Union[ds.DatasetDict, ds.Dataset]:
        """
        :param data:
            The snli `Dataset` or `DatasetDict` object.
        :param source_prefix:
            The str to add before the start of the source sequence.
        :param premise_prefix:
            The str to add before the start of the `premise` in the source sequence.
        :param hypothesis_prefix:
            The str to add before the start of the `hypothesis` in the source sequence.
        :param label_prefix:
            The str to add as the prompt for the label.
        :param num_workers:
            The number of workers to use for processing the data.
        """

        # SNLI marks instances without a gold annotation with label == -1;
        # drop those before converting.
        def has_gold_label(example, indices):
            return example["label"] != -1

        data = data.filter(has_gold_label, with_indices=True)

        id_to_label = {0: "entailment", 1: "neutral", 2: "contradiction"}

        def to_text2text(example):
            return {
                "source": (
                    f'{source_prefix} {premise_prefix}: {example["premise"]} '
                    f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: '
                ),
                "target": f'{id_to_label[example["label"]]}',
            }

        # Column names live directly on a `Dataset` but per-split on a `DatasetDict`.
        if isinstance(data, ds.Dataset):
            columns_to_drop = data.column_names
        else:
            columns_to_drop = list(data.column_names.values())[0]

        return data.map(
            to_text2text,
            batched=False,
            num_proc=num_workers,
            remove_columns=columns_to_drop,  # remove all old columns
            desc="Converting data to text-to-text format",
        )
================================================
FILE: examples/finetune/test.py
================================================
import typing
import datasets as ds
import pytest
from tango.common import Params
from tango.common.testing import TangoTestCase, run_experiment
class TestFinetuneSNLI(TangoTestCase):
    """
    End-to-end check of the finetune example: runs the full config with tiny
    stand-in models and verifies the processed data and trained-model outputs.
    """

    @pytest.mark.parametrize(
        "model, model_type",
        [("patrickvonplaten/t5-tiny-random", "t5"), ("sshleifer/tiny-gpt2", "gpt2")],
    )
    @typing.no_type_check  # mypy has become incompatible with the datasets library
    def test_config(self, model: str, model_type: str):
        # Swap in the tiny model and enable a "subset_data" step so the run is fast.
        overrides = {
            "steps.trained_model.model.model.pretrained_model_name_or_path": model,
            "steps.trained_model.tokenizer.pretrained_model_name_or_path": model,
            "steps.subset_data": {
                "type": "subset-data",
                "data": {"type": "ref", "ref": "raw_data"},
                "max_samples": 10,
            },
            # Re-point the processing step at the subset instead of the raw data.
            "steps.processed_data.data.ref": "subset_data",
        }
        config = Params.from_file("config.jsonnet", params_overrides=overrides)
        # Make sure we've overrode the model entirely.
        flattened = config.as_flat_dict()
        for key, value in flattened.items():
            if "model_name" in key or (isinstance(value, str) and model_type in value):
                assert value == model
        with run_experiment(config, include_package=["snli_steps.py"]) as run_dir:
            # The processed dataset should contain exactly the "source"/"target" columns.
            assert (run_dir / "processed_data").is_dir()
            processed = ds.load_from_disk(run_dir / "processed_data" / "data")
            assert len(processed["train"][0].keys()) == 2
            assert "source" in processed["train"][0].keys()
            assert "target" in processed["train"][0].keys()
            assert processed["train"][0]["source"].startswith("nli premise:")
            assert (run_dir / "trained_model").is_dir()
================================================
FILE: examples/finetune_resnet/.gitignore
================================================
data/
results/
extra_testing.py
================================================
FILE: examples/finetune_resnet/config.jsonnet
================================================
# Fine-tune a pre-trained ResNet-18 on the cats-vs-dogs dataset, then run a
# single prediction on a downloaded image (steps defined in resnet_steps.py).
local input_size = 224;
local batch_size = 32;
local num_classes = 2;
local val_size = 0.05;
local model = "resnet";
local feature_extract = true;  # freeze the backbone and only train the new head
local distributed = false;
local devices = if distributed then 2 else 1;
local pretrained_model = "resnet_ft";
local training_steps = 500;
local validate_every = 50;
local image_url = "https://tinyurl.com/2p9xjvn9";

local distributed_dataloader = {
    batch_size: batch_size,
    sampler: {
        type: "torch::DistributedSampler",
        shuffle: true,
        drop_last: true,
    },
    collate_fn: {"type": "image_collator"},
};

local single_device_dataloader = {
    shuffle: true,
    batch_size: batch_size,
    collate_fn: {"type": "image_collator"},
};

# Pick the dataloader matching the (non-)distributed setting. Previously the
# distributed dataloader was defined but never used, so `distributed = true`
# would still train with the single-device dataloader.
local dataloader = if distributed then distributed_dataloader else single_device_dataloader;

{
    steps: {
        raw_data: {
            type: "datasets::load",
            path: "nateraw/auto-cats-and-dogs",
            name: "cats_and_dogs",
        },
        transform_data: {
            type: "transform_data",
            dataset: { type: 'ref', ref: 'raw_data' },
            input_size: input_size,
            val_size: val_size,
        },
        trained_model: {
            type: "torch::train",
            model: {
                type: pretrained_model,
                num_classes: num_classes,
                # Was hard-coded `true`; now wired to the constant above.
                feature_extract: feature_extract,
                use_pretrained: true,
            },
            training_engine: {
                optimizer: {
                    type: "torch_adam",
                    lr: 0.001,
                },
            },
            dataset_dict: {"type": "ref", "ref": "transform_data"},
            train_dataloader: dataloader,
            validation_split: "val",
            val_metric_name: "accuracy",
            train_steps: training_steps,
            validate_every: validate_every,
            checkpoint_every: validate_every,
            log_every: 1,
            device_count: devices,
            minimize_val_metric: false,
        },
        prediction: {
            type: "prediction",
            image_url: image_url,
            input_size: input_size,
            model: {"type": "ref", "ref": "trained_model"},
        },
    },
}
================================================
FILE: examples/finetune_resnet/resnet_steps.py
================================================
from typing import Any, Dict, List, Optional
import datasets
import torch
from cached_path import cached_path
from PIL import Image
from torch import nn
from torch.optim import Adam
from torchvision import models, transforms
from tango import Format, JsonFormat, Step
from tango.integrations.torch import DataCollator, Model, Optimizer
# Register the Adam optimizer as a tango `Optimizer` so the training engine in
# config.jsonnet can refer to it as `type: "torch_adam"`.
Optimizer.register("torch_adam")(Adam)
# Wrapper class around the pre-trained ResNet-18 model that modifies the final layer.
@Model.register("resnet_ft")
class ResNetWrapper(Model):
    """
    Pre-trained ResNet-18 whose final fully-connected layer is replaced with one
    sized to ``num_classes``. With ``feature_extract=True`` the backbone is
    frozen so only the new classification head is trained.
    """

    def __init__(self, num_classes: int, feature_extract: bool, use_pretrained: bool):
        super().__init__()
        self.model_ft = models.resnet18(pretrained=use_pretrained)
        self.set_parameter_requires_grad(self.model_ft, feature_extract)
        # Swap the ImageNet classification head for one sized to our label set.
        num_features = self.model_ft.fc.in_features
        self.model_ft.fc = nn.Linear(num_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def set_parameter_requires_grad(self, model: nn.Module, feature_extracting: bool):
        # When feature extracting, freeze every existing parameter; the fresh
        # `fc` layer created afterwards keeps requires_grad=True.
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False

    def forward(  # type: ignore
        self, image: torch.Tensor, label: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Return ``{"preds": ...}`` when no ``label`` is given (inference),
        otherwise the training loss and batch accuracy.
        """
        output = self.model_ft(image)
        preds = torch.argmax(output, dim=1)
        if label is None:
            return {"preds": preds}
        loss = self.loss_fn(output, label)
        accuracy = (preds == label).float().mean()
        return {"loss": loss, "accuracy": accuracy}
# Custom data collator for images: stacks each example's image tensor and label
# into batched tensors the model consumes directly.
@DataCollator.register("image_collator")
class ImageCollator(DataCollator[Dict[str, Any]]):
    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = torch.stack([example["image"] for example in batch], dim=0)
        labels = torch.tensor([example["labels"] for example in batch])
        return {"image": images, "label": labels}
# Build the train/val image transformation pipelines for a given input size.
def get_data_transforms(input_size: int):
    """
    Return a dict of torchvision pipelines keyed by "train" and "val": training
    adds random crop/flip augmentation, validation uses a deterministic resize
    plus center crop. Both normalize with the standard ImageNet statistics.
    """
    imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_pipeline = transforms.Compose(
        [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            imagenet_norm,
        ]
    )
    val_pipeline = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            imagenet_norm,
        ]
    )
    return {"train": train_pipeline, "val": val_pipeline}
# Loads one image from disk and applies the requested ("train" or "val") transform.
def pil_loader(path: str, input_size: int, transform_type: str):
    """Open the image at ``path`` as RGB and return it transformed to a tensor."""
    with open(path, "rb") as f:
        image = Image.open(f)
        image = image.convert("RGB")
    transform = get_data_transforms(input_size=input_size)[transform_type]
    return transform(image)
# Applies `pil_loader` to every image file referenced by a batch of examples.
def image_loader(example_batch, input_size: int, transform_type: str):
    loaded_images = [
        pil_loader(file_path, input_size, transform_type)
        for file_path in example_batch["file"]
    ]
    example_batch["image"] = loaded_images
    return example_batch
# This step attaches on-the-fly image transforms to the raw data and carves a
# validation split out of the training set.
@Step.register("transform_data")
class TransformData(Step):
    DETERMINISTIC = True
    CACHEABLE = False

    def run(  # type: ignore
        self, dataset: datasets.DatasetDict, val_size: float, input_size: int
    ) -> datasets.DatasetDict:
        def apply_train_transform(example_batch):
            return image_loader(example_batch, input_size=input_size, transform_type="train")

        # `with_transform` applies the loader lazily, at example-access time.
        dataset = dataset.with_transform(apply_train_transform)
        split = dataset["train"].train_test_split(test_size=val_size)
        # Rename the freshly-created "test" split to "val" for the trainer.
        split["val"] = split.pop("test")
        return split
# Maps an integer class id to its human-readable label.
def convert_to_label(int_label: int) -> str:
    """Return "cat" for class 0 and "dog" for anything else."""
    return "cat" if int_label == 0 else "dog"
@Step.register("prediction")
class Prediction(Step):
    """
    Download an image, run it through the trained classifier, and return the
    predicted label along with the image's URL and local cache path.
    """

    FORMAT: Format = JsonFormat()

    def run(  # type: ignore
        self, image_url: str, input_size: int, model: Model, device: Optional[str] = "cpu"
    ) -> Dict[str, Any]:
        # Download the image (cached locally by `cached_path`).
        image_path = cached_path(image_url)
        # Apply the deterministic "val" transform pipeline.
        transformed_image = pil_loader(image_path, input_size, transform_type="val")
        # Add a batch dimension and move to the target device.
        transformed_image = transformed_image.unsqueeze(0).to(device)
        # Run the model in inference mode (no label) and take the first prediction.
        prediction = model(image=transformed_image, label=None)["preds"][0].float()
        label = convert_to_label(prediction)
        return {"image_url": image_url, "local_path": image_path, "label": label}
================================================
FILE: examples/flax/config.jsonnet
================================================
// Fine-tune facebook/bart-base on XSum summarization with the Flax trainer,
// then evaluate on the validation split. The "tokenize_data" step and the
// "xsum_wrapper" are defined in xsum.py.
{
    "steps": {
        "data": {
            "type": "datasets::load",
            "path": "xsum",
        },
        "tokenize": {
            "type": "tokenize_data",
            "dataset": {
                "type": "ref",
                "ref": "data"
            }
        },
        "train": {
            "type": "flax::train",
            "model": {
                "type": "transformers::FlaxAutoModelForSeq2SeqLM::from_pretrained",
                "pretrained_model_name_or_path": "facebook/bart-base"
            },
            "dataset": {
                "type": "ref",
                "ref": "tokenize"
            },
            "optimizer": {
                "type" : "optax::adamw",
                "learning_rate" : 2e-5
            },
            "train_dataloader": {
                "batch_size": 16,
                "drop_last": true
            },
            "wrapper": {
                "type": "xsum_wrapper"
            },
            "train_split": "train",
            "validation_split" : "validation",
            "validate_every" : 1000,
            "validation_dataloader": {
                "batch_size": 16,
                "drop_last": true
            },
            "train_epoch": 5,
            "checkpoint_every": 1000,
            "log_every": 1000,
            "callbacks" : [
                // Uncomment to log metrics to Weights & Biases:
                //{"type" : "wandb::log_flax"},
                {"type": "flax::generate_step"}
            ]
        },
        "eval": {
            "type": "flax::eval",
            "state": {
                "type": "ref",
                "ref": "train"
            },
            "dataset": {
                "type": "ref",
                "ref": "tokenize"
            },
            "dataloader": {
                "batch_size": 16,
                "drop_last": true
            },
            "wrapper": {
                "type" : "xsum_wrapper"
            }
        }
    }
}
================================================
FILE: examples/flax/run.sh
================================================
#!/bin/bash
# Run the XSum fine-tuning experiment defined in config.jsonnet, storing
# results in the local ./workspace directory. `--include-package xsum`
# makes tango import xsum.py so its custom steps get registered.
tango run config.jsonnet -d workspace --include-package xsum
================================================
FILE: examples/flax/xsum.py
================================================
import logging
from typing import List, Optional
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_metric
from flax.training.common_utils import onehot
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSeq2SeqLM
from tango.integrations.flax import FlaxWrapper
from tango.integrations.flax.train_callback import TrainCallback
from tango.step import Step
"""
XSum Summarization with facebook/bart-base
"""
@Step.register("tokenize_data")
class PreProcessing(Step):
    """
    Tokenize the raw XSum dataset for facebook/bart-base: documents become
    model inputs, summaries become labels, and right-shifted decoder inputs
    plus a decoder attention mask are added for teacher forcing.
    """

    # NOTE(review): tokenization looks deterministic in practice; marked
    # non-deterministic here, which forces the step to always re-run — confirm intent.
    DETERMINISTIC = False

    def run(self, dataset):
        tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
        # Fetch `shift_tokens_right` from the model's own module. (The name in
        # `fromlist` was previously misspelled "shift_tokens_tight"; harmless at
        # runtime since `fromlist` names are only used to trigger submodule
        # imports, but fixed here for clarity.)
        model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
        shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
        config = AutoConfig.from_pretrained("facebook/bart-base")

        MAX_SOURCE_LENGTH = 512
        MAX_TGT_LENGTH = 64

        def preprocess_function(examples):
            inputs = examples["document"]
            targets = examples["summary"]
            model_inputs = tokenizer(
                inputs,
                max_length=MAX_SOURCE_LENGTH,
                padding="max_length",
                truncation=True,
                return_tensors="np",
            )
            # Setup the tokenizer for targets
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets,
                    max_length=MAX_TGT_LENGTH,
                    padding="max_length",
                    truncation=True,
                    return_tensors="np",
                )
            model_inputs["labels"] = labels["input_ids"]
            decoder_input_ids = shift_tokens_right_fn(
                labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
            )
            model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
            # We need decoder_attention_mask so we can ignore pad tokens from loss
            model_inputs["decoder_attention_mask"] = labels["attention_mask"]
            return model_inputs

        column_names = dataset["train"].column_names
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )
        return dataset
@FlaxWrapper.register("xsum_wrapper")  # type: ignore
class TransformerWrapper(FlaxWrapper):
    """
    Seq2seq wrapper computing a (label-smoothing-capable) cross-entropy loss
    that ignores padded decoder positions.
    """

    def loss_helper(self, logits, labels, batch):
        # A smoothing factor of 0 reduces to plain cross-entropy; the machinery
        # is kept so a non-zero factor can be dropped in easily.
        label_smoothing_factor = 0
        padding_mask = batch["decoder_attention_mask"]
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence)
            + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
        per_token = optax.softmax_cross_entropy(logits, soft_labels)
        per_token = per_token - normalizing_constant
        # Zero out padded tokens, then average over the real ones.
        masked = per_token * padding_mask
        return masked.sum() / padding_mask.sum()

    def train_loss(self, params, state, batch, dropout_rng, labels):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        return self.loss_helper(logits, labels, batch)

    def val_metrics(self, batch, logits, labels):
        return {"loss": self.loss_helper(logits, labels, batch)}

    def eval_metrics(self, batch, logits, labels):
        return {"loss": self.loss_helper(logits, labels, batch)}
@TrainCallback.register("flax::generate_step")
class GenerateCallback(TrainCallback):
    """
    Training callback that generates summaries for every validation batch and
    logs ROUGE metrics over the full validation loop.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger(GenerateCallback.__name__)

    def generate_step(self, params, batch):
        # Push the current training params into the model before generating.
        self.model.params = params
        gen_kwargs = {"max_length": 64, "num_beams": self.model.config.num_beams}
        output_ids = self.model.generate(
            batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs
        )
        return output_ids.sequences

    def pre_train_loop(self) -> None:
        # With multiple devices, run generation data-parallel via `pmap`.
        if len(jax.devices()) > 1:
            self.p_generate_step = jax.pmap(self.generate_step, axis_name="batch")

    def pre_val_loop(self, step: int, val_step: int, state) -> None:
        # Reset the accumulators at the start of every validation loop.
        self.state = state
        self.eval_preds: List = []
        self.eval_labels: List = []

    def pre_val_batch(self, step: int, val_step: int, epoch: int, val_batch) -> None:
        labels = val_batch["labels"]
        if len(jax.devices()) > 1:
            generated_ids = self.p_generate_step(self.state.params, val_batch)
        else:
            generated_ids = self.generate_step(self.state.params, val_batch)
        # NOTE(review): the reshape width (64) must stay in sync with
        # `max_length` in `generate_step` above.
        self.eval_preds.extend(jax.device_get(generated_ids.reshape(-1, 64)))
        self.eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))

    def postprocess_text(self, preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels

    def compute_metrics(self, preds, labels):
        # Decode token ids back to text before scoring.
        tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Some simple post-processing
        decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels)
        metric = load_metric("rouge")
        result = metric.compute(
            predictions=decoded_preds, references=decoded_labels, use_stemmer=True
        )
        # Extract a few results from ROUGE (mid-F1, scaled to percent).
        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    def post_val_loop(
        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]
    ) -> None:
        rouge_metrics = self.compute_metrics(self.eval_preds, self.eval_labels)
        rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
        self.logger.info(rouge_desc)
================================================
FILE: examples/train_lm/.gitignore
================================================
runs
run
================================================
FILE: examples/train_lm/README.md
================================================
# Fine-tuning a language model
This Tango example showcases how you could train or fine-tune a causal language model like [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)
or [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset.
It's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally), otherwise Tango will fall back to CPU-only and it will be extremely slow.
This example also depends on [FairScale](https://fairscale.readthedocs.io/en/latest/), which allows you to leverage [`FullyShardedDataParallel`](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html) (FSDP) and [activation checkpointing](https://fairscale.readthedocs.io/en/latest/api/nn/checkpoint/checkpoint_activations.html) to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`.
Without using CPU offloading you'll need at least 4 x 40GiB A100 GPUs, or a different configuration with a comparable amount of total GPU memory.
To get started, just run
```
tango run config.jsonnet -i tokenize_step.py
```
================================================
FILE: examples/train_lm/config.jsonnet
================================================
##################
# Model settings #
##################

local pretrained_model = "gpt2";
# With 'fsdp' and 'activation_checkpointing' (see constants below), you should be able to train
# a 6B model on 4x ~40GB GPUs:
# local pretrained_model = "EleutherAI/gpt-j-6B";

# This doesn't seem to work with gpt2, but works fine with gpt-j.
local load_with_low_cpu_mem_usage = std.startsWith(pretrained_model, "EleutherAI/gpt-j");

####################
# Trainer settings #
####################

# Trainer settings, adjust to your use-case.
local training_steps = 200;  # total number of optimization steps to train for
local validate_every = 20;  # how often to validate and save checkpoints

local devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)
local grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)
# This is the batch size per GPU, ignoring gradient accumulation:
local batch_size = 8;
# So the effective batch size is `batch_size * grad_accum * devices`

local activation_checkpointing = false;  # use activation/gradient checkpointing (probably needed for GPT-J 6B, but not gpt2)
local amp = false;  # use PyTorch's native automatic mixed precision
local fsdp = false;  # Use FairScale's FullyShardedDataParallel (probably needed for GPT-J 6B, but not gpt2)
local cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.

######################
# Optimizer settings #
######################

local warmup_steps = 20;
local learning_rate = 0.00005;  # you can probably use a higher LR for a small model like "gpt2"

# <----- you probably don't need to edit below this line ----> #

assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp";

# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
    reshard_after_forward: true,
    move_params_to_cpu: cpu_offloading,
    move_grads_to_cpu: cpu_offloading,
    mixed_precision: amp,
} else null;

local training_engine = {
    type: if fsdp then "fairscale" else "torch",
    optimizer: {
        type: "torch::AdamW",
        lr: learning_rate,
        betas: [0.9, 0.95],
        eps: 1e-6,
    },
    lr_scheduler: {
        type: "transformers::linear",
        num_warmup_steps: warmup_steps,
        num_training_steps: training_steps,
    },
    amp: amp,
    # Only the fairscale engine accepts `fsdp_config`, so omit the key otherwise:
    [if fsdp then "fsdp_config" else null]: fsdp_config,
};

local distributed_dataloader = {
    batch_size: batch_size,
    collate_fn: { type: "transformers::DefaultDataCollator" },
    sampler: {
        type: "torch::DistributedSampler",
        shuffle: true,
        drop_last: true,
    },
};

local single_device_dataloader = {
    shuffle: true,
    batch_size: batch_size,
    collate_fn: { type: "transformers::DefaultDataCollator" },
};

local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;

{
    steps: {
        raw_data: {
            type: "datasets::load",
            path: "wikitext",
            name: "wikitext-2-raw-v1",
        },
        # Tokenize and chunk the raw text (see tokenize_step.py).
        tokenized_data: {
            type: "tokenize_data",
            dataset: { type: "ref", ref: "raw_data" },
            tokenizer: { pretrained_model_name_or_path: pretrained_model }
        },
        trained_model: {
            type: "torch::train",
            model: {
                type: "fairscale::with_wrapped_modules",
                model: {
                    type: "transformers::AutoModelForCausalLM::from_pretrained",
                    pretrained_model_name_or_path: pretrained_model,
                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,
                },
                modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually
                fsdp_config: fsdp_config,
                activation_checkpointing: activation_checkpointing,
            },
            dataset_dict: { type: "ref", ref: "tokenized_data" },
            train_dataloader: dataloader,
            validation_split: "validation",
            grad_accum: grad_accum,
            train_steps: training_steps,
            validate_every: validate_every,
            checkpoint_every: validate_every,
            log_every: 1,
            device_count: devices,
            training_engine: training_engine,
        },
        # Final evaluation of the trained model on the test split.
        final_metrics: {
            type: "torch::eval",
            model: { type: "ref", ref: "trained_model" },
            dataset_dict: { type: "ref", ref: "tokenized_data" },
            dataloader: single_device_dataloader,
            test_split: "test",
        },
    }
}
================================================
FILE: examples/train_lm/test.py
================================================
from tango.common import Params
from tango.common.testing import run_experiment
def test_small_experiment():
    """
    Run the full train_lm config end-to-end with a tiny GPT-2 and a handful of
    steps so the whole experiment completes quickly.
    """
    model = "sshleifer/tiny-gpt2"
    dataloader = {
        "batch_size": 2,
        "collate_fn": {"type": "transformers::DefaultDataCollator"},
    }
    steps = 4
    overrides = {
        "steps.tokenized_data.block_size": 64,
        # Override the model in the config with the tiny alternative so training is fast.
        "steps.tokenized_data.tokenizer.pretrained_model_name_or_path": model,
        "steps.trained_model.model.model.pretrained_model_name_or_path": model,
        # Use a small number of training/validation/eval steps.
        "steps.trained_model.training_engine.lr_scheduler.num_warmup_steps": 1,
        "steps.trained_model.training_engine.lr_scheduler.num_training_steps": steps,
        "steps.trained_model.train_steps": steps,
        "steps.trained_model.validation_steps": 2,
        "steps.trained_model.validate_every": steps,
        "steps.final_metrics.eval_steps": 2,
        "steps.trained_model.checkpoint_every": steps,
        "steps.trained_model.device_count": 1,
        # Override data loaders.
        "steps.trained_model.train_dataloader": dataloader,
        "steps.trained_model.validation_dataloader": dataloader,
        "steps.final_metrics.dataloader": dataloader,
    }
    # Load the config.
    config = Params.from_file("config.jsonnet", params_overrides=overrides)
    # Make sure we've overrode the model entirely.
    flattened = config.as_flat_dict()
    for key, value in flattened.items():
        if "model_name" in key or (isinstance(value, str) and "gpt" in value):
            assert value == model
    with run_experiment(config, include_package=["tokenize_step.py"]) as run_dir:
        assert (run_dir / "trained_model").is_dir()
================================================
FILE: examples/train_lm/tokenize_step.py
================================================
import datasets
from tango import Step
from tango.integrations.datasets import DatasetsFormat
from tango.integrations.transformers import Tokenizer
# We need a step to tokenize the raw data. The result of this step will be passed
# directly into the "torch::train" step.
@Step.register("tokenize_data")
class TokenizeData(Step):
    """
    Tokenize a raw text dataset and regroup the resulting token stream into
    fixed-size blocks suitable for causal language-model training.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT = DatasetsFormat()

    def run(  # type: ignore[override]
        self,
        dataset: datasets.DatasetDict,
        tokenizer: Tokenizer,
        block_size: int = 1024,
        num_workers: int = 1,
        field_to_tokenize: str = "text",
    ) -> datasets.DatasetDict:
        """
        :param dataset:
            The raw `DatasetDict` (e.g. WikiText) to tokenize.
        :param tokenizer:
            The pre-trained tokenizer matching the model to be trained.
        :param block_size:
            Number of tokens per training example after regrouping.
        :param num_workers:
            Number of processes to use for the `map` calls.
        :param field_to_tokenize:
            Name of the text column in `dataset`.
        """

        def tokenize_function(example):
            return tokenizer(example[field_to_tokenize])

        dataset = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=num_workers,
            remove_columns=list(dataset.column_names.values())[0],  # remove all old columns
            desc="Tokenizing dataset",
        )

        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}  # type: ignore
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder, we could add padding if the model supported
            # it instead of this drop, you can customize this part to your needs.
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            # Split by chunks of max_len.
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            # For causal LM training the labels are the inputs themselves
            # (the shift happens inside the model).
            result["labels"] = result["input_ids"].copy()
            return result

        dataset = dataset.map(
            group_texts,
            batched=True,
            num_proc=num_workers,
            desc=f"Grouping texts into chunks of {block_size}",
        )
        return dataset
================================================
FILE: integration_tests/README.md
================================================
# Integration tests
These are a collection of longer running end-to-end tests of various parts of the Tango library.
The easiest way to run any of these integration tests is by triggering the [**Integration tests**](https://github.com/allenai/tango/actions/workflows/integration_tests.yml)
workflow on GitHub Actions. Just select the "Run workflow" dropdown, then pick the test to run and the Beaker cluster to run it on,
and finally hit the "Run workflow" button.
Each test should have a `run.sh` file in its folder that will run the relevant tango command.
This is what the **Integration tests** workflow will call, and you can also use it to run the test manually.
================================================
FILE: integration_tests/fairscale_benchmarks/README.md
================================================
# FairScale Benchmarks
This integration test is for checking the performance of the `FairScaleTrainingEngine` with various configurations.
**When to run it:** It should be run every time there is a major PyTorch or FairScale upgrade.
**Where to run it:** A server with 4 A100 GPUs. Make sure you set your `WANDB_API_KEY` environment variable.
**How to run it:** From the root directory of this repository, run:
```
integration_tests/fairscale_benchmarks/run.sh
```
By default, not all configurations are run. If you want to change which configurations are run, open `config.jsonnet`
and search for "enabled". Then toggle the `enabled` field to `true` or `false` for each configuration.
**What to look for:** The training jobs shouldn't fail, for one. After `tango run` completes, check the corresponding Weights & Biases
dashboard and inspect the results. Compare the various "fsdp" training runs with the baseline to ensure you see memory savings.
================================================
FILE: integration_tests/fairscale_benchmarks/config.jsonnet
================================================
##################
# Model settings #
##################

local pretrained_model = "gpt2";
# local pretrained_model = "EleutherAI/gpt-j-6B";

# This doesn't seem to work with gpt2, but works fine with gpt-j-6B.
local load_with_low_cpu_mem_usage = pretrained_model == "EleutherAI/gpt-j-6B";

####################
# Trainer settings #
####################

# Trainer settings, adjust to your use-case.
local training_steps = 100;  # total number of optimization steps to train for
local validate_every = 20;  # how often to validate and save checkpoints

# Number of GPUs to train across.
local devices = 4;
local grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)

# This is the batch size per GPU, ignoring gradient accumulation:
local batch_size = 8;
# So the effective batch size is `batch_size * grad_accum * devices`

######################
# Optimizer settings #
######################

local warmup_steps = 20;
# Use a smaller learning rate for the much larger 6B model.
local learning_rate = if pretrained_model == "EleutherAI/gpt-j-6B" then 0.00001 else 0.0001;

# <----- you probably don't need to edit below this line ----> #

# Dataloader used for distributed (multi-GPU) training: the DistributedSampler
# gives each worker a disjoint shard of the data.
local distributed_dataloader = {
  batch_size: batch_size,
  collate_fn: { type: "transformers::DefaultDataCollator" },
  sampler: {
    type: "torch::DistributedSampler",
    shuffle: true,
    drop_last: true,
  },
};

# Dataloader for single-device runs (kept for convenience; the train step
# below uses the distributed one).
local single_device_dataloader = {
  shuffle: true,
  batch_size: batch_size,
  collate_fn: { type: "transformers::DefaultDataCollator" },
};
# Builds one "torch::train" step config for a single benchmark configuration.
# ``options`` must provide: name, enabled, amp, fsdp_config (object or null),
# and activation_checkpointing.
local TrainStep(options) =
  # Use the FairScale training engine whenever an FSDP config is given,
  # otherwise fall back to the plain torch engine.
  local training_engine = {
    type: if options.fsdp_config != null then "fairscale" else "torch",
    optimizer: {
      type: "torch::AdamW",
      lr: learning_rate,
      betas: [0.9, 0.95],
      eps: 1e-6,
    },
    lr_scheduler: {
      type: "transformers::linear",
      num_warmup_steps: warmup_steps,
      num_training_steps: training_steps,
    },
    amp: options.amp,
    # Conditionally include the "fsdp_config" field only when one was given.
    [if options.fsdp_config != null then "fsdp_config" else null]: options.fsdp_config,
  };
  {
    type: "torch::train",
    model: {
      type: "fairscale::with_wrapped_modules",
      model: {
        type: "transformers::AutoModelForCausalLM::from_pretrained",
        pretrained_model_name_or_path: pretrained_model,
        low_cpu_mem_usage: load_with_low_cpu_mem_usage,
      },
      modules_to_wrap: ["transformer\\.h\\.[0-9]+"],  # tell FairScale to wrap the transformer's blocks individually
      fsdp_config: options.fsdp_config,
      activation_checkpointing: options.activation_checkpointing,
    },
    dataset_dict: { type: "ref", ref: "tokenized_data" },
    train_dataloader: distributed_dataloader,
    validation_split: "validation",
    grad_accum: grad_accum,
    train_steps: training_steps,
    validate_every: validate_every,
    checkpoint_every: validate_every,
    log_every: 1,
    device_count: devices,
    training_engine: training_engine,
    callbacks: [
      {
        # Log metrics plus the full benchmark configuration to W&B so runs
        # can be compared on the dashboard.
        type: "wandb::log",
        entity: "allennlp",
        project: "tango-fairscale-benchmarks",
        wandb_config: options + {
          effective_batch_size: batch_size * devices * grad_accum,
          model: pretrained_model,
        },
      },
    ],
  };
# The experiment: load wikitext, tokenize it with the custom "tokenize_data"
# step (registered in examples/train_lm/tokenize_step.py), then run one
# training step per *enabled* benchmark configuration below.
{
  steps: {
    raw_data: {
      type: "datasets::load",
      path: "wikitext",
      name: "wikitext-2-raw-v1",
    },
    tokenized_data: {
      type: "tokenize_data",
      dataset: { type: "ref", ref: "raw_data" },
      tokenizer: { pretrained_model_name_or_path: pretrained_model }
    },
  } + {
    # Toggle each configuration's ``enabled`` field to control which
    # benchmarks actually run.
    ["trained_model_" + options.name]: TrainStep(options)
    for options in [
      # NOTE: With 6B model, baseline and many others will fail with CUDA OOM.
      # FSDP and activation checkpointing will be required for a 6B model.
      {
        name: "baseline",
        enabled: false,
        amp: false,
        fsdp_config: null,
        activation_checkpointing: false,
      },
      {
        name: "amp",
        enabled: false,
        amp: true,
        fsdp_config: null,
        activation_checkpointing: false,
      },
      {
        name: "checkpointing",
        enabled: false,
        amp: false,
        fsdp_config: null,
        activation_checkpointing: true,
      },
      {
        name: "amp_and_checkpointing",
        enabled: false,
        amp: true,
        fsdp_config: null,
        activation_checkpointing: true,
      },
      {
        name: "fsdp",
        enabled: false,
        amp: false,
        activation_checkpointing: false,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: false,
        },
      },
      {
        name: "fsdp_no_reshard",
        enabled: false,
        amp: false,
        activation_checkpointing: false,
        fsdp_config: {
          reshard_after_forward: false,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: false,
        },
      },
      {
        name: "amp_and_fsdp",
        enabled: false,
        amp: true,
        activation_checkpointing: false,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: false,
        },
      },
      {
        name: "amp_and_fsdp_no_reshard",
        enabled: false,
        amp: true,
        activation_checkpointing: false,
        fsdp_config: {
          reshard_after_forward: false,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: false,
        },
      },
      {
        name: "amp_and_fsdp_mp",
        enabled: false,
        amp: true,
        activation_checkpointing: false,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: true,
        },
      },
      {
        name: "amp_and_fsdp_mp_no_reshard",
        enabled: false,
        amp: true,
        activation_checkpointing: false,
        fsdp_config: {
          reshard_after_forward: false,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: true,
        },
      },
      {
        name: "checkpointing_and_fsdp",
        enabled: false,
        amp: false,
        activation_checkpointing: true,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: false,
        },
      },
      {
        name: "amp_and_checkpointing_and_fsdp",
        enabled: false,
        amp: true,
        activation_checkpointing: true,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: false,
        },
      },
      {
        name: "amp_and_checkpointing_and_fsdp_mp",
        enabled: true,
        amp: true,
        activation_checkpointing: true,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: true,
        },
      },
      {
        name: "checkpointing_and_fsdp_mp",
        enabled: false,
        amp: false,
        activation_checkpointing: true,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: false,
          move_grads_to_cpu: false,
          mixed_precision: true,
        },
      },
      { # This configuration currently does not work. Tracking https://github.com/facebookresearch/fairscale/issues/918
        name: "amp_and_checkpointing_and_fsdp_mp_with_partial_offloading",
        enabled: false,
        amp: true,
        activation_checkpointing: true,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: true,
          move_grads_to_cpu: false,
          mixed_precision: true,
        },
      },
      {
        name: "amp_and_checkpointing_and_fsdp_mp_with_full_offloading",
        enabled: false,
        amp: true,
        activation_checkpointing: true,
        fsdp_config: {
          reshard_after_forward: true,
          move_params_to_cpu: true,
          move_grads_to_cpu: true,
          mixed_precision: true,
        },
      },
    ] if options.enabled
  }
}
================================================
FILE: integration_tests/fairscale_benchmarks/run.sh
================================================
#!/bin/sh
# Launch the FairScale benchmark experiment. Run this from the repository root
# so that the config path and the tokenize-step module path below resolve.
tango run integration_tests/fairscale_benchmarks/config.jsonnet -i examples/train_lm/tokenize_step.py
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "ai2-tango"
# Version is read dynamically from tango/version.py (see [tool.setuptools.dynamic]).
dynamic = ["version"]
readme = "README.md"
description = "A library for choreographing your machine learning research."
classifiers=[
    "Intended Audience :: Science/Research",
    "Development Status :: 3 - Alpha",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
authors = [
    {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"}
]
license = {file = "LICENSE"}
requires-python = ">=3.8.1"
# Core runtime dependencies only; integration-specific dependencies live in
# [project.optional-dependencies] below.
dependencies = [
    "cached-path>=1.0,<2.0",
    "rjsonnet>=0.5.0",
    "GitPython>=3.0,<4.0",
    "PyYAML>=5.4.1,<7.0",
    "dill",
    "base58",
    "xxhash",
    "filelock>=3.4,<4.0",
    "click>=8.0,<8.1.4",
    "click-help-colors>=0.9.1,<0.10",
    "rich>=12.3,<14.0",
    "tqdm>=4.62,<5.0",
    "more-itertools>=8.0,<11.0",
    "sqlitedict",
    "glob2>=0.7",
    "petname>=2.6,<3.0",
    "pytz"
]

[project.optional-dependencies]
# Development tooling: linting, type-checking, testing, packaging, and docs.
dev = [
    "ruff",
    "mypy==1.2.0",
    "types-PyYAML",
    "types-setuptools",
    "types-pytz",
    "types-retry",
    "black==23.3.0",
    "isort==5.12.0",
    "pytest",
    "pytest-sphinx",
    "flaky",
    "twine>=1.11.0",
    "setuptools",
    "wheel",
    "build",
    "Sphinx==5.3.0",
    "furo==2023.3.27",
    "myst-parser==1.0.0",
    "sphinx-copybutton==0.5.2",
    "sphinx-autobuild==2021.3.14",
    "sphinx-autodoc-typehints<=1.23.0",
    "packaging"
]
examples = [
    "torchmetrics>=0.7.0"
]
# One extra per integration; "all" below pulls everything in.
torch = [
    "torch>=1.9,<2.1",
    "numpy",
]
transformers = [
    "torch>=1.9,<2.1",
    "numpy",
    "datasets>=1.12,<3.0",
    "transformers>=4.12.3",
    "sentencepiece==0.1.98",
    "sacremoses"
]
datasets = [
    "datasets>=1.12,<3.0"
]
fairscale = [
    "torch>=1.9,<2.1",
    "numpy",
    "fairscale>=0.4.6,<0.5"
]
flax = [
    "datasets>=1.12,<3.0",
    "jax",
    "jaxlib",
    "flax",
    "optax",
    "tensorflow-cpu>=2.9.1"
]
wandb = [
    "wandb>=0.16",
    "retry"
]
beaker = [
    "beaker-py>=1.14.0,<2.0"
]
gs = [
    "google-cloud-storage>=2.6.0",
    "google-cloud-datastore>=2.12.0"
]
all = [
    "ai2-tango[examples,torch,transformers,datasets,fairscale,flax,wandb,beaker,gs]"
]
[project.scripts]
# The ``tango`` console command.
tango = "tango.__main__:main"

[project.urls]
homepage = "https://github.com/allenai/tango"
repository = "https://github.com/allenai/tango"

[tool.setuptools.packages.find]
# Exclude tests, fixtures, docs, scripts, and examples from the wheel.
exclude = [
    "*.tests",
    "*.tests.*",
    "tests.*",
    "tests",
    "test_fixtures",
    "test_fixtures.*",
    "docs*",
    "scripts*",
    "examples*"
]

[tool.setuptools.package-data]
# py.typed marks the package as typed (PEP 561); beaker ships helper scripts.
tango = ["py.typed"]
"tango.integrations.beaker" = ["*.sh"]

[tool.setuptools.dynamic]
version = {attr = "tango.version.VERSION"}

[tool.black]
line-length = 100
include = '\.pyi?$'
exclude = '''
(
      __pycache__
    | \.git
    | \.mypy_cache
    | \.pytest_cache
    | \.vscode
    | \.venv
    | \bdist\b
    | \bdoc\b
)
'''

[tool.isort]
profile = "black"
multi_line_output = 3

[tool.ruff]
line-length = 115
select = ["E"]
exclude = [
    ".venv",
    ".git",
    "__pycache__",
    ".mypy_cache",
    "docs/build",
    "dist"
]

[tool.ruff.per-file-ignores]
# Re-exports in __init__ files trigger F401 (unused import); ignore it there.
"__init__.py" = ["F401"]
"*/**/**/__init__.py" = ["F401","E501"]

[tool.mypy]
ignore_missing_imports = true
no_site_packages = false
allow_redefinition = true
check_untyped_defs = true

[[tool.mypy.overrides]]
# Tests are checked more leniently than the library itself.
module = "tests.*"
strict_optional = false
disable_error_code = [
    "var-annotated",
    "no-redef",
    "dict-item"
]
allow_redefinition = true

[tool.pytest.ini_options]
testpaths = "tests/"
python_classes = [
    "Test*",
    "*Test"
]
log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
log_level = "DEBUG"
markers = [
    "gpu: marks tests that need GPUs"
]
filterwarnings = [
    'ignore:.*Consider increasing the value of the `num_workers` argument.*:UserWarning:pytorch_lightning\.trainer\.data_loading',
    'ignore:.*you defined a validation_step but have no val_dataloader.*:UserWarning:pytorch_lightning\.trainer\.configuration_validator',
    'ignore::UserWarning:tango\.*',
    'ignore::DeprecationWarning:pkg_resources',
    'ignore::DeprecationWarning:google\.rpc'
]
doctest_optionflags = "NORMALIZE_WHITESPACE"
================================================
FILE: scripts/entrypoint.sh
================================================
#!/bin/bash

# Container entrypoint for CI/integration jobs: installs tango at the commit
# given by $COMMIT_SHA, then runs the supplied command, teeing its output to
# /results/out.log so it can be collected as a job artifact.

# Exit script if any commands fail.
set -e
set -o pipefail

# Check that the environment variable has been set correctly
if [ -z "$COMMIT_SHA" ]; then
  echo >&2 'error: missing COMMIT_SHA environment variable'
  exit 1
fi

# Upgrade pip
/opt/conda/bin/pip install --upgrade pip

# Clone and install tango.
git clone https://github.com/allenai/tango.git
cd tango
git checkout --quiet "$COMMIT_SHA"
/opt/conda/bin/pip install --no-cache-dir '.[dev,all]'

# Create directory for results.
mkdir -p /results

# Execute the arguments to this script as commands themselves, piping output into a log file.
exec "$@" 2>&1 | tee /results/out.log
================================================
FILE: scripts/hash_extras.py
================================================
"""
Used in CI to create a unique ID for any set of install extras.
"""
import sys
def main(extras=None):
    """
    Print a deterministic ID for a set of install extras.

    The ID is the sorted, ``-``-joined list of extras, so the same set of
    extras always yields the same ID regardless of the order they were given
    in (which is what makes it usable as a CI cache key).

    :param extras: a comma-separated extras string, e.g. ``"torch,datasets"``.
        Defaults to ``sys.argv[1]`` so the existing CLI usage is unchanged.
    """
    if extras is None:
        extras = sys.argv[1]
    print("-".join(sorted(extras.split(","))))


if __name__ == "__main__":
    main()
================================================
FILE: scripts/prepare_changelog.py
================================================
from datetime import datetime
from pathlib import Path
from tango.version import VERSION
def main():
    """
    Insert a release header for the current ``VERSION`` into CHANGELOG.md.

    Finds the "## Unreleased" section and inserts a "## [vX.Y.Z]" header
    (with today's date) directly after it. If the changelog already contains
    a header for this version, the file is left untouched.

    :raises RuntimeError: if the changelog has no "Unreleased" section.
    """
    changelog = Path("CHANGELOG.md")
    with changelog.open() as f:
        lines = f.readlines()

    insert_index = None
    for i, line in enumerate(lines):
        if line.startswith("## Unreleased"):
            insert_index = i + 1
        elif line.startswith(f"## [v{VERSION}]"):
            print("CHANGELOG already up-to-date")
            return
        elif line.startswith("## [v"):
            # Reached the previous release's section; stop scanning.
            break

    # BUG FIX: the original used a for/else, which raised even when the
    # "Unreleased" header *was* found but no earlier "## [v" section followed
    # it (i.e. the very first release). Check the index explicitly instead.
    if insert_index is None:
        raise RuntimeError("Couldn't find 'Unreleased' section")

    lines.insert(insert_index, "\n")
    lines.insert(
        insert_index + 1,
        f"## [v{VERSION}](https://github.com/allenai/tango/releases/tag/v{VERSION}) - "
        f"{datetime.now().strftime('%Y-%m-%d')}\n",
    )

    with changelog.open("w") as f:
        f.writelines(lines)


if __name__ == "__main__":
    main()
================================================
FILE: scripts/prepare_citation_cff.py
================================================
from datetime import datetime
from pathlib import Path
from tango.version import VERSION
def main():
    """
    Update CITATION.cff with the current release version and today's date.

    Rewrites the ``version:`` and ``date-released:`` lines in place; all other
    lines are left untouched.
    """
    citation = Path("CITATION.cff")
    with citation.open() as f:
        lines = f.readlines()

    # Idiom fix: enumerate instead of indexing with range(len(...)).
    for i, line in enumerate(lines):
        if line.startswith("version:"):
            lines[i] = f'version: "{VERSION}"\n'
        elif line.startswith("date-released:"):
            lines[i] = f'date-released: "{datetime.now().strftime("%Y-%m-%d")}"\n'

    with citation.open("w") as f:
        f.writelines(lines)


if __name__ == "__main__":
    main()
================================================
FILE: scripts/release.sh
================================================
#!/bin/bash
# Cut a new release: update CHANGELOG.md and CITATION.cff for the version in
# tango/version.py, commit and push, then create and push the release tag.
set -e

TAG=$(python -c 'from tango.version import VERSION; print("v" + VERSION)')

read -p "Creating new release for $TAG. Do you want to continue? [Y/n] " prompt

if [[ $prompt == "y" || $prompt == "Y" || $prompt == "yes" || $prompt == "Yes" ]]; then
    python scripts/prepare_changelog.py
    python scripts/prepare_citation_cff.py
    git add -A
    # "|| true" keeps going when there is nothing to commit, so the push and
    # tagging still happen.
    git commit -m "Prepare for release $TAG" || true && git push
    echo "Creating new git tag $TAG"
    git tag "$TAG" -m "$TAG"
    git push --tags
else
    echo "Cancelled"
    exit 1
fi
================================================
FILE: scripts/release_notes.py
================================================
# encoding: utf-8
"""
Prepares markdown release notes for GitHub releases.
"""
import os
from typing import List, Optional
import packaging.version
TAG = os.environ["TAG"]
ADDED_HEADER = "### Added 🎉"
CHANGED_HEADER = "### Changed ⚠️"
FIXED_HEADER = "### Fixed ✅"
REMOVED_HEADER = "### Removed 👋"
def get_change_log_notes() -> str:
    """
    Pull this release's section out of CHANGELOG.md and return it as a
    markdown "What's new" block, with emoji added to the subsection headers.
    """
    collecting = False
    collected: List[str] = []
    with open("CHANGELOG.md") as changelog:
        for line in changelog:
            if line.startswith("## "):
                # Top-level headers drive the state machine: skip "Unreleased",
                # start collecting at this release's header, and stop at any
                # other release's header.
                if line.startswith("## Unreleased"):
                    continue
                if line.startswith(f"## [{TAG}]"):
                    collecting = True
                    continue
                break
            if not collecting:
                continue
            # Swap the plain subsection headers for their emoji versions.
            for plain, fancy in (
                ("### Added", ADDED_HEADER),
                ("### Changed", CHANGED_HEADER),
                ("### Fixed", FIXED_HEADER),
                ("### Removed", REMOVED_HEADER),
            ):
                if line.startswith(plain):
                    line = fancy + "\n"
                    break
            collected.append(line)
    assert collected
    return "## What's new\n\n" + "".join(collected).strip() + "\n"
def get_commit_history() -> str:
    """
    Build a markdown "Commits" section listing every commit between the last
    released tag and the new tag ``TAG``.

    Pre-release tags are skipped when looking for the previous release unless
    the new version is itself a pre-release. If no previous tag is found (the
    first release), the whole history is listed.
    """
    new_version = packaging.version.parse(TAG)

    # BUG FIX: ``os.popen`` starts the command and returns immediately;
    # without reading the pipe there was no guarantee the fetch had finished
    # before we listed the tags below.
    os.popen("git fetch --tags").read()

    # Get all tags sorted by version, latest first.
    all_tags = os.popen("git tag -l --sort=-version:refname 'v*'").read().split("\n")

    # Out of `all_tags`, find the latest previous version so that we can collect all
    # commits between that version and the new version we're about to publish.
    # Note that we ignore pre-releases unless the new version is also a pre-release.
    last_tag: Optional[str] = None
    for tag in all_tags:
        if not tag.strip():  # could be blank line
            continue
        version = packaging.version.parse(tag)
        if new_version.pre is None and version.pre is not None:
            continue
        if version < new_version:
            last_tag = tag
            break

    if last_tag is not None:
        commits = os.popen(f"git log {last_tag}..{TAG}^ --oneline --first-parent").read()
    else:
        commits = os.popen("git log --oneline --first-parent").read()
    return "## Commits\n\n" + commits
def main():
    """Print the full release notes: changelog section, then commit history."""
    for section in (get_change_log_notes(), get_commit_history()):
        print(section)


if __name__ == "__main__":
    main()
================================================
FILE: tango/__init__.py
================================================
"""
A Python library for choreographing your machine learning research.
"""
# Public API of the top-level ``tango`` package; everything listed here is
# re-exported from the submodules imported below.
__all__ = [
    "cleanup_cli",
    "DillFormat",
    "DillFormatIterator",
    "execute_step_graph",
    "Executor",
    "Format",
    "initialize_cli",
    "JsonFormat",
    "JsonFormatIterator",
    "load_settings",
    "prepare_executor",
    "prepare_workspace",
    "Run",
    "RunInfo",
    "RunSort",
    "SqliteDictFormat",
    "Step",
    "step",
    "StepCache",
    "StepGraph",
    "StepInfo",
    "StepInfoSort",
    "StepResources",
    "StepState",
    "tango_cli",
    "Workspace",
]
from .cli import (
cleanup_cli,
execute_step_graph,
initialize_cli,
load_settings,
prepare_executor,
prepare_workspace,
tango_cli,
)
from .executor import Executor
from .format import (
DillFormat,
DillFormatIterator,
Format,
JsonFormat,
JsonFormatIterator,
SqliteDictFormat,
)
from .step import Step, StepResources, step
from .step_cache import StepCache
from .step_graph import StepGraph
from .step_info import StepInfo, StepState
from .workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace
================================================
FILE: tango/__main__.py
================================================
"""
The Tango CLI is the recommended tool to run experiments with.
It also comes with several other useful commands.
You can see the list of all available commands by running:
.. code-block::
$ tango --help
.. testcode::
:hide:
import subprocess
output = subprocess.run("tango --help".split(" "), capture_output=True)
output.check_returncode()
print(output.stdout.decode().replace("\\n\\n", "\\n").strip())
.. testoutput::
Usage: tango [OPTIONS] COMMAND [ARGS]...
Options:
--version Show the version and exit.
--settings FILE Path to a global tango.yml settings file.
--log-level [debug|info|warning|error]
Set the global log level.
--file-friendly-logging Outputs progress bar status on separate lines and slows refresh rate.
--start-method [fork|spawn|forkserver]
Set the multiprocessing start method.
--help Show this message and exit.
Commands:
info Get info about the current tango installation.
run Run a tango experiment.
settings Commands for initializing and updating global settings.
To see all of the available arguments and options for a particular command, run
.. code-block::
$ tango [COMMAND] --help
For example,
.. code-block::
$ tango run --help
``tango run``
-------------
The ``run`` command is used to execute a tango experiment from an experiment configuration file.
See the `Configuration files `_ section in the overview
for a quick introduction to the format.
``tango info``
--------------
The ``info`` command just prints out some useful information about the current tango installation,
such as which integrations are available.
``tango settings``
------------------
The ``settings`` group of commands can be used to initialize a :class:`~tango.settings.TangoGlobalSettings`
file or update fields in it.
"""
import os
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Sequence, Union
import click
from click_help_colors import HelpColorsCommand, HelpColorsGroup
from tango.cli import (
cleanup_cli,
execute_step_graph,
initialize_cli,
load_settings,
prepare_executor,
prepare_workspace,
)
from tango.common.exceptions import CliRunError, IntegrationMissingError
from tango.common.logging import cli_logger, initialize_logging
from tango.common.params import Params
from tango.common.util import (
find_integrations,
import_extra_module,
import_module_and_submodules,
)
from tango.settings import TangoGlobalSettings
from tango.step_graph import StepGraph
from tango.version import VERSION
from tango.workspace import Workspace
# Shared keyword arguments for click groups and commands so that every tango
# (sub)command gets the same colorized help output and terminal width.
_CLICK_GROUP_DEFAULTS = {
    "cls": HelpColorsGroup,
    "help_options_color": "green",
    "help_headers_color": "yellow",
    "context_settings": {"max_content_width": 115},
}

_CLICK_COMMAND_DEFAULTS = {
    "cls": HelpColorsCommand,
    "help_options_color": "green",
    "help_headers_color": "yellow",
    "context_settings": {"max_content_width": 115},
}
# The object stashed in click's ``ctx.obj`` and passed to every subcommand.
class SettingsObject(NamedTuple):
    # Global settings, possibly already modified by top-level CLI options.
    settings: TangoGlobalSettings
    # True when this process was launched by an executor rather than a user.
    called_by_executor: bool
# Root of the ``tango`` CLI: loads global settings, applies CLI-level
# overrides, and initializes logging/multiprocessing before any subcommand
# runs. NOTE: deliberately has no docstring — click would show it in
# ``tango --help``, and that output is pinned by the module-level doctest.
@click.group(name=None, **_CLICK_GROUP_DEFAULTS)
@click.version_option(version=VERSION)
@click.option(
    "--settings",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
    help="Path to a global tango.yml settings file.",
)
@click.option(
    "--log-level",
    help="Set the global log level.",
    type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
    show_choices=True,
)
@click.option(
    "--file-friendly-logging",
    is_flag=True,
    help="Outputs progress bar status on separate lines and slows refresh rate.",
)
@click.option(
    "--start-method",
    help="Set the multiprocessing start method.",
    type=click.Choice(["fork", "spawn", "forkserver"], case_sensitive=True),
    show_choices=True,
)
@click.option(
    "--called-by-executor",
    is_flag=True,
    hidden=True,
)
@click.pass_context
def main(
    ctx,
    settings: Optional[str] = None,
    log_level: Optional[str] = None,
    file_friendly_logging: bool = False,
    start_method: Optional[str] = None,
    called_by_executor: bool = False,
):
    # Re-bind ``settings`` from an optional path to the loaded settings object.
    settings: TangoGlobalSettings = load_settings(settings)
    # CLI options take precedence over values from the settings file.
    if start_method is not None:
        settings.multiprocessing_start_method = start_method
    if log_level is not None:
        settings.log_level = log_level
    if file_friendly_logging:
        settings.file_friendly_logging = file_friendly_logging
    # Make the settings available to all subcommands via ``ctx.obj``.
    ctx.obj = SettingsObject(settings, called_by_executor)
    initialize_cli(settings=settings, called_by_executor=called_by_executor)
# Runs after every (sub)command finishes to tear down whatever
# ``initialize_cli`` set up (e.g. logging state).
@main.result_callback()
def cleanup(*args, **kwargs):
    cleanup_cli()
# ``tango run``: the main user-facing command for executing an experiment.
@main.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument(
    "experiment",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
)
@click.option(
    "-w",
    "--workspace",
    type=click.Path(file_okay=False),
    help="""A workspace path or URL. If not specified, the workspace from any global tango
    settings file will be used, if found, otherwise an ephemeral MemoryWorkspace.""",
    default=None,
)
@click.option(
    "-d",
    "--workspace-dir",
    type=click.Path(file_okay=False),
    default=None,
    hidden=True,
)
@click.option(
    "-o",
    "--overrides",
    type=str,
    help="""A JSON(NET) string used to override fields in the experiment config.
    Use dot syntax to specify nested fields.""",
)
@click.option(
    "-i",
    "--include-package",
    type=str,
    help="Python packages or modules to import for tango components.",
    multiple=True,
)
@click.option(
    "-j",
    "--parallelism",
    type=int,
    help="""The maximum number of steps to run in parallel (for executors that support this).
    The exact behavior depends on the executor. If you're using the default executors,
    a value of 0 (or left unspecified) means each step is run in the main process using the default executor,
    otherwise the multicore executor is used.""",
)
@click.option(
    "-s",
    "--step-name",
    help="Execute a particular step (and its dependencies) in the experiment.",
    multiple=True,
)
@click.option(
    "-n",
    "--name",
    type=str,
    help="""Specify the name for this run.""",
)
@click.option(
    "-D",
    "--ext-var",
    type=str,
    help="""JSONNET external variables to use when loading the experiment config.
    For example, --ext-var 'pretrained_model=gpt2'.""",
    multiple=True,
)
@click.pass_obj
def run(
    obj: SettingsObject,
    experiment: str,
    workspace: Optional[str] = None,
    workspace_dir: Optional[Union[str, os.PathLike]] = None,
    overrides: Optional[str] = None,
    include_package: Optional[Sequence[str]] = None,
    parallelism: Optional[int] = None,
    step_name: Optional[Sequence[str]] = None,
    name: Optional[str] = None,
    ext_var: Optional[Sequence[str]] = None,
):
    """
    Run a tango experiment.

    EXPERIMENT is the path to experiment's JSON/Jsonnet/YAML configuration file.
    """
    # -d/--workspace-dir is a deprecated alias: translate it into the
    # equivalent "local://" workspace URL.
    if workspace_dir is not None:
        import warnings

        warnings.warn(
            "-d/--workspace-dir option is deprecated. Please use -w/--workspace instead.",
            DeprecationWarning,
        )
        if workspace is not None:
            raise click.ClickException(
                "-w/--workspace is mutually exclusive with -d/--workspace-dir"
            )
        workspace = "local://" + str(workspace_dir)
    # The actual work happens in the private ``_run`` helper (defined later
    # in this module).
    _run(
        obj.settings,
        experiment,
        workspace_url=workspace,
        overrides=overrides,
        include_package=include_package,
        parallelism=parallelism,
        step_names=step_name,
        name=name,
        called_by_executor=obj.called_by_executor,
        ext_var=ext_var,
    )
# Hidden command invoked by the BeakerExecutor inside a Beaker job to run a
# single step of an experiment.
@main.command(hidden=True)
@click.argument(
    "experiment",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
)
@click.argument(
    "step_name",
    type=str,
)
@click.argument(
    "workspace_url",
    type=str,
)
@click.option(
    "-i",
    "--include-package",
    type=str,
    help="Python packages or modules to import for tango components.",
    multiple=True,
)
@click.option(
    "--log-level",
    help="Set the global log level.",
    type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
    show_choices=True,
)
def beaker_executor_run(
    experiment: str,
    step_name: str,
    workspace_url: str,
    include_package: Optional[Sequence[str]] = None,
    log_level: str = "debug",
):
    """
    This command is only used internally by the BeakerExecutor.
    """
    from tango.executor import Executor

    # Import user modules first so custom steps are registered before the
    # step graph is loaded.
    if include_package:
        for package_name in include_package:
            import_extra_module(package_name)

    # Load step graph and step.
    step_graph = StepGraph.from_file(experiment)
    step = step_graph[step_name]

    # Initialize workspace and executor.
    # NOTE: We use the default executor here because we're just running the step
    # locally in the main process.
    workspace = Workspace.from_url(workspace_url)
    executor = Executor(workspace=workspace, include_package=include_package)

    # Initialize logging.
    initialize_logging(log_level=log_level, enable_cli_logs=True, file_friendly_logging=True)

    # Run step.
    executor.execute_step(step)
# ``tango info``: report version, loaded settings, and which integrations are
# importable in the current environment.
@main.command(**_CLICK_COMMAND_DEFAULTS)
@click.pass_obj
def info(obj: SettingsObject):
    """
    Get info about the current tango installation.
    """
    import platform

    cli_logger.info("Tango version %s (python %s)", VERSION, platform.python_version())
    cli_logger.info("")

    # Show info about settings.
    if obj.settings.path is not None:
        cli_logger.info("[underline]Settings:[/]")
        cli_logger.info("[green] \N{check mark} Loaded from %s[/]", obj.settings.path)
        if obj.settings.include_package:
            cli_logger.info(" Included packages:")
            for package in obj.settings.include_package:
                # Probe each configured package by actually importing it.
                is_found = True
                try:
                    import_module_and_submodules(package)
                except (ModuleNotFoundError, ImportError):
                    is_found = False
                if is_found:
                    cli_logger.info("   [green]\N{check mark} %s[/]", package)
                else:
                    cli_logger.info("   [red]\N{ballot x} %s (not found)[/]", package)
        cli_logger.info("")

    # Show info about integrations.
    cli_logger.info("[underline]Integrations:[/]")
    for integration in find_integrations():
        name = integration.split(".")[-1]
        # An integration counts as installed if its module imports cleanly.
        is_installed = True
        try:
            import_module_and_submodules(integration, recursive=False)
        except (IntegrationMissingError, ModuleNotFoundError, ImportError):
            is_installed = False
        if is_installed:
            cli_logger.info(" [green]\N{check mark} %s[/]", name)
        else:
            cli_logger.info(" [yellow]\N{ballot x} %s (not installed)[/]", name)
# ``tango settings``: parent group for the settings subcommands below.
@main.group(**_CLICK_GROUP_DEFAULTS)
@click.pass_obj
def settings(ctx):
    """
    Commands for initializing and updating global settings.
    """
# ``tango settings init``: write the current settings to a new settings file.
@settings.command(**_CLICK_COMMAND_DEFAULTS)
@click.option(
    "-p",
    "--path",
    type=click.Path(exists=False, dir_okay=False, resolve_path=True),
    default=None,
    help="""The path to write the settings to.""",
)
@click.option(
    "-f",
    "--force",
    is_flag=True,
    help="""Force overwrite the file if it exists.""",
)
@click.pass_obj
def init(obj: SettingsObject, path: Optional[str] = None, force: bool = False):
    """
    Initialize the settings file.
    """
    path_to_write = Path(path or TangoGlobalSettings._DEFAULT_LOCATION)
    # Refuse to clobber an existing file unless -f/--force was given.
    if path_to_write.is_file() and not force:
        raise click.ClickException("Settings file already exists! Use -f/--force to overwrite it.")
    obj.settings.to_file(path_to_write)
    cli_logger.info(
        "[green]\N{check mark} Settings file written to [bold]%s[/bold][/green]", path_to_write
    )
# ``tango settings set``: parent group for commands that modify individual
# fields of an existing settings file.
@settings.group(name="set", **_CLICK_GROUP_DEFAULTS)
@click.pass_obj
def set_setting(obj: SettingsObject):
    """
    Set a value in the settings file.
    """
    # All subcommands mutate an existing file, so one must exist.
    if obj.settings.path is None:
        raise click.ClickException(
            "Settings file not found! Did you forget to call 'tango settings init'?"
        )
# Each ``settings set`` subcommand returns the updated TangoGlobalSettings;
# this callback persists it back to the settings file afterwards.
@set_setting.result_callback()
def save_settings(settings: TangoGlobalSettings):
    settings.save()
# ``tango settings set workspace``: record the default workspace.
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument(
    "workspace",
    type=str,
)
@click.option(
    "--validate/--no-validate",
    type=bool,
    help="Validate that the workspace can be initialized.",
    default=True,
)
@click.pass_obj
def workspace(obj: SettingsObject, workspace: str, validate: bool = True) -> TangoGlobalSettings:
    """
    Set the default workspace path or URL.
    """
    from urllib.parse import urlparse

    # A bare path (no URL scheme) is treated as a local workspace directory.
    if not urlparse(workspace).scheme:
        obj.settings.workspace = {"type": "local", "dir": str(Path(workspace).resolve())}
    else:
        obj.settings.workspace = {"type": "from_url", "url": workspace}
    if validate:
        # Import user modules first in case the workspace type is custom.
        for package_name in obj.settings.include_package or []:
            import_extra_module(package_name)
        # Instantiating the workspace verifies the new setting actually works.
        Workspace.from_params(obj.settings.workspace.copy())
    return obj.settings
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument("packages", type=str, nargs=-1)
@click.option(
    "-a",
    "--append",
    is_flag=True,
    help="Appends packages instead of overwriting.",
)
@click.option(
    "--validate/--no-validate",
    type=bool,
    help="Validate that the workspace can be initialized.",
    default=True,
)
@click.pass_obj
def include_package(
    obj: SettingsObject,
    packages: List[str],
    append: bool = False,
    validate: bool = True,
) -> TangoGlobalSettings:
    """
    Set or add modules to automatically import on 'tango run'.
    """
    # When appending, start from the existing list; otherwise start fresh.
    new_include: List[str] = (obj.settings.include_package or []) if append else []
    for name in packages:
        if name not in new_include:
            new_include.append(name)
    obj.settings.include_package = new_include

    if validate:
        # Verify every configured package is importable before saving.
        for name in obj.settings.include_package:
            try:
                import_module_and_submodules(name)
            except (ModuleNotFoundError, ImportError):
                raise click.ClickException(f"Failed to import '{name}'")
    return obj.settings
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument(
    "level",
    type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
)
@click.pass_obj
def log_level(obj: SettingsObject, level: str) -> TangoGlobalSettings:
    """
    Set the log level.
    """
    # Normalize to lowercase so the stored value is canonical.
    settings = obj.settings
    settings.log_level = level.lower()
    return settings
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument("value", type=bool)
@click.pass_obj
def file_friendly_logging(obj: SettingsObject, value: bool) -> TangoGlobalSettings:
    """
    Toggle file friendly logging mode.
    """
    settings = obj.settings
    settings.file_friendly_logging = value
    return settings
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument(
    "start_method",
    type=click.Choice(["fork", "spawn", "forkserver"], case_sensitive=True),
)
@click.pass_obj
def multiprocessing_start_method(obj: SettingsObject, start_method: str) -> TangoGlobalSettings:
    """
    Set the Python multiprocessing start method.
    """
    settings = obj.settings
    settings.multiprocessing_start_method = start_method
    return settings
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument("key", type=str)
@click.argument("value", type=str)
@click.pass_obj
def env(obj: SettingsObject, key: str, value: str) -> TangoGlobalSettings:
    """
    Add or update an environment variable.
    """
    from tango.common.aliases import EnvVarNames

    # These environment variables should not be set this way since they'll be ignored.
    if key in EnvVarNames.values():
        raise click.ClickException(
            f"Cannot add environment variable '{key}' to settings. "
            f"Please set the corresponding settings field instead."
        )

    env_vars = obj.settings.environment
    if env_vars is None:
        env_vars = {}
        obj.settings.environment = env_vars
    env_vars[key] = value
    return obj.settings
def _run(
    settings: TangoGlobalSettings,
    experiment: str,
    workspace_url: Optional[str] = None,
    overrides: Optional[str] = None,
    include_package: Optional[Sequence[str]] = None,
    step_names: Optional[Sequence[str]] = None,
    parallelism: Optional[int] = None,
    multicore: Optional[bool] = None,
    name: Optional[str] = None,
    called_by_executor: bool = False,
    ext_var: Optional[Sequence[str]] = None,
) -> str:
    """
    Shared implementation behind 'tango run': load the experiment config,
    build the step graph, and execute it in a workspace.

    Returns the name of the run.
    """
    # Read params. Each --ext-var entry must be a single 'KEY=VALUE' pair,
    # passed through to evaluation of the experiment file.
    ext_vars: Dict[str, str] = {}
    for var in ext_var or []:
        try:
            key, value = var.split("=")
        except ValueError:
            raise CliRunError(f"Invalid --ext-var '{var}'")
        ext_vars[key] = value
    params = Params.from_file(experiment, params_overrides=overrides or "", ext_vars=ext_vars)
    # Import included packages to find registered components.
    # NOTE: The Executor imports these as well because it's meant to be used
    # directly, but we also need to import here in case the user is using a
    # custom Executor, StepCache, or Workspace.
    # Packages come from three places: the CLI flag, the experiment config,
    # and the global settings file.
    include_package: List[str] = list(include_package or [])
    include_package += params.pop("include_package", [])
    include_package += settings.include_package or []
    for package_name in include_package:
        import_extra_module(package_name)
    # Initialize step graph.
    step_graph: StepGraph = StepGraph.from_params(params.pop("steps"))
    # Anything left in params at this point is an unrecognized top-level key.
    params.assert_empty("'tango run'")
    if step_names:
        for step_name in step_names:
            assert step_name in step_graph, (
                f"You want to run a step called '{step_name}', but it cannot be found in the experiment config. "
                f"The config contains: {list(step_graph.keys())}."
            )
        # Narrow the graph to just the requested steps.
        step_graph = step_graph.sub_graph(*step_names)
    # Execute step graph in workspace
    workspace = prepare_workspace(settings=settings, workspace_url=workspace_url)
    executor = prepare_executor(
        settings=settings,
        workspace=workspace,
        include_package=include_package,
        parallelism=parallelism,
        multicore=multicore,
        called_by_executor=called_by_executor,
    )
    run_name = execute_step_graph(
        step_graph=step_graph,
        workspace=workspace,
        executor=executor,
        name=name,
        called_by_executor=called_by_executor,
        step_names=step_names,
    )
    return run_name
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
================================================
FILE: tango/cli.py
================================================
import logging
import multiprocessing as mp
import os
import sys
import warnings
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Optional, Sequence, Union
from tango.common.exceptions import CliRunError
from tango.common.logging import (
cli_logger,
initialize_logging,
initialize_prefix_logging,
teardown_logging,
)
from tango.common.params import Params
from tango.executor import Executor
from tango.settings import TangoGlobalSettings
from tango.step_graph import StepGraph
from tango.workspace import Workspace
if TYPE_CHECKING:
from tango.executor import ExecutorOutput
from tango.workspace import Run
logger = logging.getLogger(__name__)
def load_settings(settings: Union[str, Params, dict, None] = None) -> TangoGlobalSettings:
    """
    Resolve ``settings`` into a :class:`TangoGlobalSettings` instance.

    A string is treated as a file path, a ``Params`` or ``dict`` is parsed
    directly, and ``None`` falls back to the global defaults.
    """
    if isinstance(settings, str):
        return TangoGlobalSettings.from_file(settings)
    if isinstance(settings, (Params, dict)):
        return TangoGlobalSettings.from_params(settings)
    return TangoGlobalSettings.default()
@contextmanager
def tango_cli(settings: Union[TangoGlobalSettings, str, Params, dict, None] = None):
    """
    Context manager that sets up the tango CLI environment on entry and
    always tears it down on exit, even when the body raises.
    """
    resolved = settings if isinstance(settings, TangoGlobalSettings) else load_settings(settings)
    try:
        initialize_cli(settings=resolved, called_by_executor=False)
        yield
    finally:
        cleanup_cli()
def initialize_cli(
    settings: Optional[TangoGlobalSettings] = None,
    called_by_executor: bool = False,
):
    """
    Prepare the current process for running tango CLI commands: apply
    settings-defined environment variables, set the multiprocessing start
    method, and (unless this process is an executor child) initialize logging.
    """
    if settings is None:
        settings = TangoGlobalSettings.default()
    if not sys.warnoptions:
        # Show DeprecationWarnings by default unless the user configured
        # warnings explicitly (e.g. via -W or PYTHONWARNINGS).
        warnings.simplefilter("default", category=DeprecationWarning)
    if settings.environment:
        from tango.common.aliases import EnvVarNames

        # These environment variables should not be set this way since they'll be ignored.
        blocked_env_variable_names = EnvVarNames.values()
        for key, value in settings.environment.items():
            if key not in blocked_env_variable_names:
                os.environ[key] = value
            else:
                warnings.warn(
                    f"Ignoring environment variable '{key}' from settings file. "
                    f"Please use the corresponding settings field instead.",
                    UserWarning,
                )
    # NOTE(review): this raises if a start method was already set elsewhere
    # in the process — confirm callers only invoke initialize_cli once.
    mp.set_start_method(settings.multiprocessing_start_method)
    if not called_by_executor:
        # Executor children configure their own prefixed logging instead
        # (see execute_step_graph).
        initialize_logging(
            log_level=settings.log_level,
            file_friendly_logging=settings.file_friendly_logging,
            enable_cli_logs=True,
        )
def cleanup_cli():
    # Counterpart to initialize_cli(): tear down the logging set up there.
    teardown_logging()
def prepare_workspace(
    settings: Optional[TangoGlobalSettings] = None,
    workspace_url: Optional[str] = None,
) -> Workspace:
    """
    Build the workspace to run in.

    Precedence: an explicit ``workspace_url`` wins, then the workspace
    configured in ``settings``, and finally the package's default workspace.
    """
    from tango.workspaces import default_workspace

    if settings is None:
        settings = TangoGlobalSettings.default()

    if workspace_url is not None:
        return Workspace.from_url(workspace_url)
    if settings.workspace is not None:
        return Workspace.from_params(settings.workspace)
    return default_workspace
def prepare_executor(
    workspace: Workspace,
    settings: Optional[TangoGlobalSettings] = None,
    include_package: Optional[Sequence[str]] = None,
    parallelism: Optional[int] = None,
    multicore: Optional[bool] = None,
    called_by_executor: bool = False,
) -> Executor:
    """
    Build the executor that will run the step graph.

    An executor configured in ``settings`` takes precedence (unless this
    process was itself launched by an executor); otherwise a
    ``MulticoreExecutor`` or plain ``Executor`` is chosen automatically.
    """
    from tango.executors import MulticoreExecutor
    from tango.workspaces import MemoryWorkspace

    if settings is None:
        settings = TangoGlobalSettings.default()

    if not called_by_executor and settings.executor is not None:
        # The settings-defined executor wins over the 'multicore' argument.
        if multicore is not None:
            logger.warning(
                "Ignoring argument 'multicore' since executor is defined in %s",
                settings.path or "setting",
            )
        extra_kwargs = {} if parallelism is None else dict(parallelism=parallelism)
        return Executor.from_params(
            settings.executor,
            workspace=workspace,
            include_package=include_package,
            **extra_kwargs,  # type: ignore
        )

    # Determine if we can use the multicore executor.
    if multicore is None:
        if isinstance(workspace, MemoryWorkspace):
            # Memory workspace does not work with multiple cores.
            multicore = False
        elif "pydevd" in sys.modules:
            # Pydevd doesn't reliably follow child processes, so we disable
            # multicore under the debugger.
            logger.warning("Debugger detected, disabling multicore.")
            multicore = False
        else:
            # No parallelism (None or 0) means a single-process executor.
            multicore = bool(parallelism)

    if multicore:
        return MulticoreExecutor(
            workspace=workspace, include_package=include_package, parallelism=parallelism
        )
    return Executor(workspace=workspace, include_package=include_package)
def execute_step_graph(
    step_graph: StepGraph,
    workspace: Optional[Workspace] = None,
    executor: Optional[Executor] = None,
    name: Optional[str] = None,
    called_by_executor: bool = False,
    step_names: Optional[Sequence[str]] = None,
) -> str:
    """
    Execute ``step_graph`` with ``executor`` inside ``workspace`` and return
    the name of the run. Missing workspace/executor are created from defaults.

    When ``called_by_executor`` is True this process is a child spawned by
    ``MulticoreExecutor`` for a single step: the run must already be
    registered, logging is prefixed with the step name, and per-run log
    capture and most console output are skipped (the parent handles those).
    """
    if workspace is None:
        # NOTE(review): when workspace is None, any passed-in executor is
        # replaced with a default one — confirm that's intended.
        workspace = prepare_workspace()
        executor = prepare_executor(workspace=workspace)
    elif executor is None:
        executor = prepare_executor(workspace=workspace)
    # Register run.
    run: "Run"
    if called_by_executor and name is not None:
        try:
            run = workspace.registered_run(name)
        except KeyError:
            raise RuntimeError(
                "The CLI was called by `MulticoreExecutor.execute_step_graph`, but "
                f"'{name}' is not already registered as a run. This should never happen!"
            )
    else:
        run = workspace.register_run((step for step in step_graph.values()), name)
    if called_by_executor:
        # Executor children are launched for exactly one step.
        assert step_names is not None and len(step_names) == 1
        from tango.common.aliases import EnvVarNames

        # We set this environment variable so that any steps that contain multiprocessing
        # and call `initialize_worker_logging` also log the messages with the `step_name` prefix.
        os.environ[EnvVarNames.LOGGING_PREFIX.value] = f"step {step_names[0]}"
        initialize_prefix_logging(prefix=f"step {step_names[0]}", main_process=False)
    # Capture logs to file.
    with workspace.capture_logs_for_run(run.name) if not called_by_executor else nullcontext():
        if not called_by_executor:
            cli_logger.info("[green]Starting new run [bold]%s[/][/]", run.name)
        executor_output: ExecutorOutput = executor.execute_step_graph(step_graph, run_name=run.name)
        if executor_output.failed:
            cli_logger.error("[red]\N{ballot x} Run [bold]%s[/] finished with errors[/]", run.name)
        elif not called_by_executor:
            cli_logger.info("[green]\N{check mark} Finished run [bold]%s[/][/]", run.name)
        # NOTE(review): this None-check looks redundant — `executor_output.failed`
        # was already accessed above, which would have raised if it were None.
        # Confirm whether `execute_step_graph` can actually return None.
        if executor_output is not None:
            if not called_by_executor:
                executor_output.display()
            if executor_output.failed:
                raise CliRunError
    return run.name
================================================
FILE: tango/common/__init__.py
================================================
from .aliases import PathOrStr
from .dataset_dict import DatasetDict, DatasetDictBase, IterableDatasetDict
from .det_hash import det_hash
from .from_params import FromParams
from .lazy import Lazy
from .params import Params
from .registrable import Registrable, RegistrableFunction, make_registrable
from .tqdm import Tqdm
from .util import filename_is_safe, threaded_generator
__all__ = [
"PathOrStr",
"DatasetDictBase",
"DatasetDict",
"IterableDatasetDict",
"det_hash",
"Params",
"FromParams",
"Registrable",
"RegistrableFunction",
"make_registrable",
"Lazy",
"Tqdm",
"filename_is_safe",
"threaded_generator",
]
================================================
FILE: tango/common/aliases.py
================================================
from enum import Enum, unique
from os import PathLike
from typing import Set, Union
PathOrStr = Union[str, PathLike]
@unique
class EnvVarNames(Enum):
    """
    Canonical names of tango-related environment variables, collected in one
    place so other modules can reference (or block) them consistently.
    """

    FILE_FRIENDLY_LOGGING = "FILE_FRIENDLY_LOGGING"
    LOG_LEVEL = "TANGO_LOG_LEVEL"
    CLI_LOGGER_ENABLED = "TANGO_CLI_LOGGER_ENABLED"
    LOGGING_HOST = "TANGO_LOGGING_HOST"
    LOGGING_PORT = "TANGO_LOGGING_PORT"
    LOGGING_PREFIX = "TANGO_LOGGING_PREFIX"
    CONSOLE_WIDTH = "TANGO_CONSOLE_WIDTH"

    @classmethod
    def values(cls) -> Set[str]:
        """Return the set of all variable name strings."""
        return {member.value for member in cls}
================================================
FILE: tango/common/dataset_dict.py
================================================
from dataclasses import dataclass, field
from typing import Any, Generic, Iterable, Iterator, Mapping, Sequence, TypeVar
T = TypeVar("T")
S = TypeVar("S")
@dataclass
class DatasetDictBase(Generic[S], Mapping[str, S]):
    """
    The base class for :class:`DatasetDict` and :class:`IterableDatasetDict`.

    Implements the read-only :class:`~collections.abc.Mapping` protocol over
    :attr:`splits`, so instances can be indexed and iterated like a dict of
    split name to split.
    """

    splits: Mapping[str, S]
    """
    A mapping of dataset split names to splits.
    """

    metadata: Mapping[str, Any] = field(default_factory=dict)
    """
    Metadata can contain anything you need.
    """

    def __getitem__(self, split: str) -> S:
        """
        Get a split in :attr:`splits`.
        """
        return self.splits[split]

    def __contains__(self, split: str) -> bool:  # type: ignore[override]
        """
        Checks if :attr:`splits` contains the given split.
        """
        return split in self.splits

    def __iter__(self) -> Iterator[str]:
        """
        Returns an iterator over the keys in :attr:`splits`.
        """
        return iter(self.splits.keys())

    def __len__(self) -> int:
        """
        Returns the number of splits in :attr:`splits`.
        """
        return len(self.splits)

    def keys(self):
        """
        Returns the split names in :attr:`splits`.
        """
        return self.splits.keys()
@dataclass
class DatasetDict(DatasetDictBase[Sequence[T]], Generic[T]):
    """
    A generic :class:`~collections.abc.Mapping` class of split names (:class:`str`) to datasets
    (``Sequence[T]``), i.e. each split supports random access and ``len()``.
    """
@dataclass
class IterableDatasetDict(DatasetDictBase[Iterable[T]], Generic[T]):
    """
    An "iterable" version of :class:`DatasetDict`, where the dataset splits have
    type ``Iterable[T]`` instead of ``Sequence[T]``. This is useful for streaming datasets.
    """
================================================
FILE: tango/common/det_hash.py
================================================
import collections
import hashlib
import io
from abc import abstractmethod
from typing import Any, MutableMapping, Optional, Type
import base58
import dill
ndarray: Optional[Type]
try:
from numpy import ndarray
except ModuleNotFoundError:
ndarray = None
TorchTensor: Optional[Type]
try:
from torch import Tensor as TorchTensor
except ModuleNotFoundError:
TorchTensor = None
class CustomDetHash:
    """
    Mixin that lets a class control what :func:`det_hash()` hashes.

    By default :func:`det_hash()` pickles an object and hashes the pickled
    bytes. Deriving from this class and implementing
    :meth:`det_hash_object()` makes :func:`det_hash()` pickle that method's
    return value instead. Returning ``None`` restores the default behavior
    (the object itself is pickled).
    """

    @abstractmethod
    def det_hash_object(self) -> Any:
        """
        Return the stand-in object to hash in place of ``self``.
        """
        raise NotImplementedError()
class DetHashFromInitParams(CustomDetHash):
    """
    Add this class as a mixin base class to make sure your class's det_hash is derived
    exclusively from the parameters passed to ``__init__()``.
    """

    # The (args, kwargs) tuple captured at construction time; see __new__.
    _det_hash_object: Any

    def __new__(cls, *args, **kwargs):
        super_new = super(DetHashFromInitParams, cls).__new__
        # object.__new__ rejects extra arguments when __init__ is overridden,
        # so only forward *args/**kwargs when a custom __new__ exists upstream.
        if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
            instance = super_new(cls)
        else:
            instance = super_new(cls, *args, **kwargs)
        # Capture the constructor arguments before __init__ runs; these become
        # the object hashed by det_hash() (see det_hash_object below).
        instance._det_hash_object = (args, kwargs)
        return instance

    def det_hash_object(self) -> Any:
        """Returns a copy of the parameters that were passed to the class instance's ``__init__()`` method."""
        return self._det_hash_object
class DetHashWithVersion(CustomDetHash):
    """
    Mixin that lets you bump a class's det_hash by editing a static
    ``VERSION`` member.

    For example, when training a model with Tango's built-in
    :class:`~tango.integrations.torch.TorchTrainStep`, you can't change that
    step's :attr:`~tango.step.Step.VERSION` yourself. Instead, give your model
    a version by deriving from this class:

    .. code-block:: Python

        class MyModel(DetHashWithVersion):
            VERSION = "001"

            def __init__(self, ...):
                ...

    Changing ``VERSION`` then tells Tango the step's output should be
    recomputed.
    """

    # Bump this string whenever the class's behavior changes.
    VERSION: Optional[str] = None

    def det_hash_object(self) -> Any:
        """
        Return ``(VERSION, self)`` when a version is set, else ``None``.
        """
        if self.VERSION is None:
            # Falling back to `None` means det_hash() just hashes the object.
            return None
        return self.VERSION, self
_PICKLE_PROTOCOL = 4
class _DetHashPickler(dill.Pickler):
    """
    A dill Pickler specialized for deterministic hashing: objects that opt in
    via :class:`CustomDetHash` are replaced by their ``det_hash_object()``,
    and classes/callables/arrays are reduced to stable representations.
    """

    def __init__(self, buffer: io.BytesIO):
        super().__init__(buffer, protocol=_PICKLE_PROTOCOL)
        # We keep track of how deeply we are nesting the pickling of an object.
        # If a class returns `self` as part of `det_hash_object()`, it causes an
        # infinite recursion, because we try to pickle the `det_hash_object()`, which
        # contains `self`, which returns a `det_hash_object()`, etc.
        # So we keep track of how many times recursively we are trying to pickle the
        # same object. We only call `det_hash_object()` the first time. We assume that
        # if `det_hash_object()` returns `self` in any way, we want the second time
        # to just pickle the object as normal. `DetHashWithVersion` takes advantage
        # of this ability.
        self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()

    def save(self, obj, save_persistent_id=True):
        # Track nesting depth per object id around the superclass save call.
        self.recursively_pickled_ids[id(obj)] += 1
        super().save(obj, save_persistent_id)
        self.recursively_pickled_ids[id(obj)] -= 1

    def persistent_id(self, obj: Any) -> Any:
        # Returning a non-None value here makes pickle emit that value instead
        # of serializing `obj` itself; returning None pickles `obj` normally.
        if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:
            det_hash_object = obj.det_hash_object()
            if det_hash_object is not None:
                return obj.__class__.__module__, obj.__class__.__qualname__, det_hash_object
            else:
                return None
        elif isinstance(obj, type):
            # Classes hash by their fully-qualified name only.
            return obj.__module__, obj.__qualname__
        elif callable(obj):
            if hasattr(obj, "__module__") and hasattr(obj, "__qualname__"):
                return obj.__module__, obj.__qualname__
            else:
                return None
        elif ndarray is not None and isinstance(obj, ndarray):
            # It's unclear why numpy arrays don't pickle in a consistent way.
            return obj.dumps()
        elif TorchTensor is not None and isinstance(obj, TorchTensor):
            # It's unclear why torch tensors don't pickle in a consistent way.
            import torch

            with io.BytesIO() as buffer:
                torch.save(obj, buffer, pickle_protocol=_PICKLE_PROTOCOL)
                return buffer.getvalue()
        else:
            return None
def det_hash(o: Any) -> str:
    """
    Compute a deterministic hash string of an arbitrary Python object.

    The object is pickled with :class:`_DetHashPickler` and the blake2b
    digest of the pickled bytes is returned, base58-encoded. To customize
    what gets hashed, derive from :class:`CustomDetHash` and implement
    :meth:`CustomDetHash.det_hash_object()`.
    """
    with io.BytesIO() as buffer:
        _DetHashPickler(buffer).dump(o)
        digest = hashlib.blake2b(buffer.getbuffer()).digest()
    return base58.b58encode(digest).decode()
================================================
FILE: tango/common/exceptions.py
================================================
from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, Union
if TYPE_CHECKING:
from tango.step import Step
from tango.step_info import StepInfo, StepState
class TangoError(Exception):
    """
    Base class for Tango exceptions. Catch this to handle any error raised
    by the tango package.
    """
class ConfigurationError(TangoError):
    """
    The exception raised when a Tango object fails to initialize from a config
    that's misconfigured (e.g. missing properties, invalid properties, unknown properties).
    """

    def __init__(self, message: str):
        super().__init__()
        self.message = message

    def __str__(self):
        return self.message

    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
        # Custom pickling support: rebuild from the message alone, since
        # __init__ takes a single positional argument.
        return type(self), (self.message,)
class RegistryKeyError(ConfigurationError):
    """
    A configuration error that is raised when attempting to get a class by a registered name
    that doesn't exist in the registry.
    """
class CancellationError(TangoError):
    """
    Base class for errors raised due to manual cancellation of a run or step.
    """
class SigTermReceived(CancellationError):
    """
    Raised when a SIGTERM is caught.
    """
class StepCancelled(CancellationError):
    """
    Raised when a step is cancelled (see :class:`CancellationError`).
    """

    pass
class RunCancelled(CancellationError):
    """
    Raised when a run is cancelled (see :class:`CancellationError`).
    """

    pass
class CliRunError(TangoError):
    """
    Raised when `tango run` command fails.
    """
class IntegrationMissingError(TangoError):
    """
    Raised when an integration can't be used due to missing dependencies.

    :param integration: the name of the integration (e.g. ``"torch"``).
    :param dependencies: the missing package names; defaults to a set
        containing just ``integration``.
    """

    def __init__(self, integration: str, dependencies: Optional[Set[str]] = None):
        self.integration = integration
        self.dependencies = dependencies or {integration}
        super().__init__(
            f"'{self.integration}' integration can't be used due to "
            f"missing dependencies ({', '.join(self.dependencies)})"
        )
class StepStateError(TangoError):
    """
    Raised when a step is in an unexpected state.

    :param step: the step (or its info object) in the bad state; only its
        ``unique_id`` attribute is read.
    :param step_state: the state the step was found in.
    :param context: optional extra text appended to the error message.
    """

    def __init__(
        self,
        step: Union["Step", "StepInfo"],
        step_state: "StepState",
        context: Optional[str] = None,
    ):
        self.step_state = step_state
        self.step_id = step.unique_id
        message = f"Step '{self.step_id}' is in unexpected state '{self.step_state.value}'"
        if context is not None:
            message = f"{message} {context}"
        super().__init__(message)
class DirtyRepoError(TangoError):
    """
    Raised when a repository is in a dirty state.
    """
class ExecutorError(TangoError):
    """
    A base class for executor-specific errors.
    """
================================================
FILE: tango/common/file_lock.py
================================================
import os
import warnings
from typing import Optional
from filelock import AcquireReturnProxy
from filelock import FileLock as _FileLock
from .aliases import PathOrStr
class FileLock(_FileLock):  # type: ignore[valid-type,misc]
    """
    This is just a subclass of the `FileLock` class from the `filelock` library, except that
    it adds an additional argument to the `__init__` method: `read_only_ok`.

    By default this flag is `False`, which an exception will be thrown when a lock
    can't be acquired due to lack of write permissions.
    But if this flag is set to `True`, a warning will be emitted instead of an error when
    the lock already exists but the lock can't be acquired because write access is blocked.
    """

    def __init__(self, lock_file: PathOrStr, timeout=-1, read_only_ok: bool = False) -> None:
        super().__init__(str(lock_file), timeout=timeout)
        self._read_only_ok = read_only_ok

    def acquire(  # type: ignore[override]
        self,
        timeout: Optional[float] = None,
        poll_interval: float = 0.05,
    ) -> AcquireReturnProxy:
        """
        Acquire the lock, downgrading permission errors to a warning when
        ``read_only_ok`` was set and the lock file already exists.
        """
        try:
            return super().acquire(timeout=timeout, poll_interval=poll_interval)
        except OSError as err:
            # OSError could be a lot of different things, but what we're looking
            # for in particular are permission errors, such as:
            #  - errno 1  - EPERM  - "Operation not permitted"
            #  - errno 13 - EACCES - "Permission denied"
            #  - errno 30 - EROFS  - "Read-only file system"
            if err.errno not in (1, 13, 30):
                raise

            if os.path.isfile(self._lock_file) and self._read_only_ok:  # type: ignore
                warnings.warn(
                    f"Lacking permissions required to obtain lock '{self._lock_file}'. "  # type: ignore
                    "Race conditions are possible if other processes are writing to the same resource.",
                    UserWarning,
                )
                return AcquireReturnProxy(self)
            else:
                raise

    def acquire_with_updates(self, desc: Optional[str] = None) -> AcquireReturnProxy:
        """
        Same as :meth:`acquire()`, except that when the lock cannot be immediately acquired,
        it will keep trying and print status updates as it goes.
        """
        # Fast path: try briefly before falling back to the progress loop.
        try:
            return self.acquire(timeout=0.1)
        except TimeoutError:
            pass

        from .tqdm import Tqdm

        if desc is None:
            desc = f"acquiring lock at {self._lock_file}"  # type: ignore
        progress = Tqdm.tqdm(desc=desc, bar_format="{desc} [{elapsed}]")
        # Retry forever, ticking the progress display once per attempt.
        while True:
            progress.update()
            try:
                return self.acquire(timeout=1)
            except TimeoutError:
                continue
================================================
FILE: tango/common/from_params.py
================================================
import collections.abc
import inspect
import logging
from copy import deepcopy
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_type_hints,
)
from tango.common.det_hash import DetHashWithVersion
from tango.common.exceptions import ConfigurationError
from tango.common.lazy import Lazy
from tango.common.params import Params
try:
# For PEP 604 support (python >= 3.10)
from types import UnionType # type: ignore[attr-defined]
except ImportError:
class UnionType: # type: ignore
pass
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="FromParams")
# If a function parameter has no default value specified,
# this is what the inspect module returns.
_NO_DEFAULT = inspect.Parameter.empty
def takes_arg(obj, arg: str) -> bool:
    """
    Return True if the provided ``obj`` accepts an argument named ``arg``.

    For a class, the constructor's signature is inspected; for a function or
    method, the object itself is inspected. Anything else raises
    ``ConfigurationError``.
    """
    if inspect.isclass(obj):
        target = obj.__init__
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        target = obj
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return arg in inspect.signature(target).parameters
def takes_kwargs(obj) -> bool:
    """
    Checks whether the provided obj takes ``**kwargs`` (a variable keyword
    parameter).

    If it's a class, we're really checking whether its constructor does.
    If it's a function or method, we're checking the object itself.
    Otherwise, we raise a ``ConfigurationError``.
    """
    # (Fixed docstring: this checks for **kwargs, not positional arguments.)
    if inspect.isclass(obj):
        signature = inspect.signature(obj.__init__)
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        signature = inspect.signature(obj)
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    # VAR_KEYWORD is the parameter kind of a `**kwargs` catch-all.
    return any(
        p.kind == inspect.Parameter.VAR_KEYWORD
        for p in signature.parameters.values()
    )
def is_base_registrable(cls) -> bool:
    """
    Return True if ``cls`` is a class that directly inherits from
    ``Registrable`` (with no intermediate ``Registrable`` subclass between
    them in the MRO).
    """
    from tango.common.registrable import (
        Registrable,  # import here to avoid circular imports
    )

    if not issubclass(cls, Registrable):
        return False
    # Walk the MRO past `cls` itself; any intermediate Registrable subclass
    # other than Registrable means `cls` is not a *base* registrable.
    for ancestor in inspect.getmro(cls)[1:]:
        if issubclass(ancestor, Registrable) and ancestor is not Registrable:
            return False
    return True
def remove_optional(annotation: type):
    """
    Strip ``NoneType`` from a ``Union`` annotation.

    ``Optional[X]`` is really ``Union[X, None]``; the "Optional" wrapper
    carries no information we care about, so return the ``Union`` of the
    remaining arguments (or the annotation unchanged when it isn't a Union).
    """
    if getattr(annotation, "__origin__", None) != Union:
        return annotation
    args = getattr(annotation, "__args__", ())
    non_none = tuple(arg for arg in args if arg != type(None))  # noqa: E721
    return Union[non_none]
def infer_constructor_params(
    cls: Type[T], constructor: Optional[Union[Callable[..., T], Callable[[T], None]]] = None
) -> Dict[str, inspect.Parameter]:
    """
    Infer the parameters of ``cls``'s constructor (or of the given alternate
    ``constructor``) as a mapping of name to :class:`inspect.Parameter`.
    """
    target = cls.__init__ if constructor is None else constructor
    return infer_method_params(cls, target)
infer_params = infer_constructor_params # Legacy name
def infer_method_params(
    cls: Type[T], method: Callable, infer_kwargs: bool = True
) -> Dict[str, inspect.Parameter]:
    """
    Infer the parameters of ``method`` (typically a constructor of ``cls``)
    as a mapping of parameter name to :class:`inspect.Parameter`.

    String annotations (from ``from __future__ import annotations``) are
    resolved into real types. ``*args`` parameters are dropped. When the
    method accepts ``**kwargs`` and ``infer_kwargs`` is True, parameters from
    the first ``FromParams`` superclass are merged in (subclass parameters
    take precedence).
    """
    signature = inspect.signature(method)
    parameters = dict(signature.parameters)
    has_kwargs = False
    var_positional_key = None
    for param_name in list(parameters.keys()):
        # Ignore special private parameters.
        # This is necessary to make `FromParams` work with Pydantic, for example.
        if param_name.startswith("__"):
            del parameters[param_name]
            continue
        param = parameters[param_name]
        if param.kind == param.VAR_KEYWORD:
            has_kwargs = True
        elif param.kind == param.VAR_POSITIONAL:
            var_positional_key = param.name
        if isinstance(param.annotation, str):
            # For Python < 3.10, if the module where this class was defined used
            # `from __future__ import annotation`, the annotation will be a str,
            # so we need to resolve it using `get_type_hints` from the typing module.
            # See https://www.python.org/dev/peps/pep-0563/ for more info.
            try:
                parameters[param_name] = param.replace(
                    annotation=get_type_hints(method)[param_name]
                )
            except TypeError as e:
                if "'type' object is not subscriptable" in str(e):
                    # This can happen when someone uses a type hint like `dict[str, str]`
                    # instead of `Dict[str, str]`.
                    err_msg = (
                        f"Failed to parse the type annotation `{param.annotation}` "
                        f"from `{cls.__qualname__}.{method.__name__}()`."
                    )
                    if "[" in param.annotation:
                        # Check if there is an equivalent generic in the `typing` module.
                        import typing

                        type_, *_ = param.annotation.split("[", 1)
                        for possible_typing_equivalent in {type_, type_.title()}:
                            if hasattr(typing, possible_typing_equivalent):
                                err_msg += (
                                    f" Try using `{possible_typing_equivalent}` "
                                    "from the `typing` module instead."
                                )
                                break
                    new_e = TypeError(err_msg)
                    # FIX: this assignment was previously duplicated on two
                    # consecutive lines; setting the cause once is sufficient.
                    new_e.__cause__ = e
                    raise new_e
                else:
                    raise
    if var_positional_key:
        del parameters[var_positional_key]

    if not has_kwargs or not infer_kwargs:
        return parameters

    # "mro" is "method resolution order". The first one is the current class, the next is the
    # first superclass, and so on. We take the first superclass we find that inherits from
    # FromParams.
    super_class = None
    # We have to be a little careful here because in some cases we might not have an
    # actual class. Instead we might just have a function that returns a class instance.
    if hasattr(cls, "mro"):
        for super_class_candidate in cls.mro()[1:]:
            if issubclass(super_class_candidate, FromParams):
                super_class = super_class_candidate
                break
    if super_class:
        super_parameters = infer_params(super_class)
    else:
        super_parameters = {}

    return {**super_parameters, **parameters}  # Subclass parameters overwrite superclass ones
def create_kwargs(
    constructor: Callable[..., T],
    cls: Type[T],
    params: Params,
    extras: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Given some class, a ``Params`` object, and potentially other keyword arguments,
    create a dict of keyword args suitable for passing to the class's constructor.

    The function does this by finding the class's constructor, matching the constructor
    arguments to entries in the ``params`` object, and instantiating values for the parameters
    using the type annotation and possibly a from_params method.

    Any values that are provided in the ``extras`` will just be used as is.
    For instance, you might provide an existing ``Vocabulary`` this way.

    Note that ``params`` is consumed: entries are popped as they are matched,
    and any leftovers raise via ``params.assert_empty``.
    """
    # Get the signature of the constructor.
    kwargs: Dict[str, Any] = {}
    parameters = infer_params(cls, constructor)
    accepts_kwargs = False
    # Iterate over all the constructor parameters and their annotations.
    for param_name, param in parameters.items():
        # Skip "self". You're not *required* to call the first parameter "self",
        # so in theory this logic is fragile, but if you don't call the self parameter
        # "self" you kind of deserve what happens.
        if param_name == "self":
            continue
        if param.kind == param.VAR_KEYWORD:
            # When a class takes **kwargs, we do two things: first, we assume that the **kwargs are
            # getting passed to the super class, so we inspect super class constructors to get
            # allowed arguments (that happens in `infer_params` above). Second, we store the fact
            # that the method allows extra keys; if we get extra parameters, instead of crashing,
            # we'll just pass them as-is to the constructor, and hope that you know what you're
            # doing.
            accepts_kwargs = True
            continue
        # If the annotation is a compound type like typing.Dict[str, int],
        # it will have an __origin__ field indicating `typing.Dict`
        # and an __args__ field indicating `(str, int)`. We capture both.
        annotation = remove_optional(param.annotation)
        # Record whether the key was present *before* popping it below.
        explicitly_set = param_name in params
        constructed_arg = pop_and_construct_arg(
            cls.__name__, param_name, annotation, param.default, params, extras or {}
        )
        # If the param wasn't explicitly set in `params` and we just ended up constructing
        # the default value for the parameter, we can just omit it.
        # Leaving it in can cause issues with **kwargs in some corner cases, where you might end up
        # with multiple values for a single parameter (e.g., the default value gives you lazy=False
        # for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes
        # lazy=True - the superclass sees both lazy=True and lazy=False in its constructor).
        if explicitly_set or constructed_arg is not param.default:
            kwargs[param_name] = constructed_arg
    if accepts_kwargs:
        # Pass all remaining params straight through to the **kwargs catch-all.
        for key in list(params):
            kwargs[key] = params.pop(key, keep_as_dict=True)
    if extras:
        # Extras always win over anything constructed from params.
        for key, value in extras.items():
            kwargs[key] = value
    params.assert_empty(cls.__name__)
    return kwargs
def create_extras(cls: Type[T], extras: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filter ``extras`` down to the keyword arguments that are actually part of
    the signature of ``cls.from_params`` (or of ``cls`` itself, when the class
    has no ``from_params`` method).
    """
    # Prefer the `from_params` classmethod when the class defines one. Some rare
    # registered classes (e.g. raw pytorch modules registered directly) never got
    # a `from_params` method; for those we inspect the class constructor instead,
    # since its parameters are what the extras would be passed to.
    method_to_inspect = getattr(cls, "from_params", cls)  # type: ignore
    if takes_kwargs(method_to_inspect):
        # A **kwargs signature can absorb anything, so every extra must be
        # forwarded. For example, `BasicTextFieldEmbedder.from_params` requires a
        # Vocabulary object, but `TextFieldEmbedder.from_params` does not.
        return extras
    # Without **kwargs, any unknown keyword argument would raise a TypeError,
    # so keep only the ones the signature explicitly names.
    return {key: value for key, value in extras.items() if takes_arg(method_to_inspect, key)}
def pop_and_construct_arg(
    class_name: str,
    argument_name: str,
    annotation: Type,
    default: Any,
    params: Params,
    extras: Dict[str, Any],
) -> Any:
    """
    Does the work of actually constructing an individual argument for
    [``create_kwargs``](./#create_kwargs).

    We're in the inner loop of iterating over the parameters of a particular
    constructor, building just one of them. For that parameter we receive its
    name, type annotation, and default value, plus the full set of ``Params``
    for constructing the object (which we may mutate) and any ``extras`` the
    constructor might need.

    The annotation and default are passed separately (rather than as an
    ``inspect.Parameter``) so that ``Union`` types can be handled by recursing
    on this method with each member annotation in turn.
    """
    # The method parameter is called `argument_name` to avoid colliding with a
    # 'name' key inside `extras`, which isn't that unlikely. Inside the method
    # we can safely use the shorter spelling.
    name = argument_name

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # An extra is used directly only when the config doesn't also specify the
    # same key; otherwise params wins and we warn about the redundancy.
    if name in extras:
        if name not in params:
            return extras[name]
        logger.warning(
            f"Parameter {name} for class {class_name} was found in both "
            "**extras and in params. Using the specification found in params, "
            "but you probably put a key in a config file that you didn't need, "
            "and if it is different from what we get from **extras, you might "
            "get unexpected behavior."
        )

    try:
        if default != _NO_DEFAULT:
            popped_params = params.pop(name, default)
        else:
            popped_params = params.pop(name)
    except ConfigurationError:
        raise ConfigurationError(f'Missing key "{name}" for {class_name}')

    if popped_params is None:
        return None
    return construct_arg(class_name, name, popped_params, annotation, default)
def _params_contain_step(o: Any) -> bool:
    """Recursively check whether ``o`` is, or anywhere contains, a ``Step``."""
    from tango.step import Step

    if isinstance(o, Step):
        return True
    if isinstance(o, str):
        # Confusingly, str is an Iterable of itself, which would recurse forever,
        # so strings are cut off explicitly here.
        return False
    if isinstance(o, Params):
        return _params_contain_step(o.as_dict(quiet=True))
    if isinstance(o, dict):
        # A dict of exactly {"type": "ref", "ref": ...} is the config-file syntax
        # for a reference to another step.
        if set(o.keys()) == {"type", "ref"} and o["type"] == "ref":
            return True
        return _params_contain_step(o.values())
    if isinstance(o, Iterable):
        return any(_params_contain_step(item) for item in o)
    return False
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    try_from_step: bool = True,
) -> Any:
    """
    Construct a single constructor argument from ``popped_params`` according to the
    type ``annotation``, recursing into compound annotations (``Dict``, ``Tuple``,
    ``Set``, ``Union``, ``Lazy``, ``List``/``Iterable``) as needed.

    The first two parameters here are only used for logging if we encounter an error.

    ``try_from_step`` controls whether we first attempt to interpret the params as a
    ``Step`` (or something that contains steps) before falling back to a plain
    construction; the recursive step-attempt below passes ``False`` to avoid
    infinite recursion.
    """
    # If we have the default, we're already done :)
    if popped_params is default:
        return popped_params
    from tango.step import FunctionalStep, Step, StepIndexer, WithUnresolvedSteps
    # For a compound annotation like Dict[str, int], `__origin__` is dict and
    # `__args__` is (str, int); for plain classes both fall back to the defaults.
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])
    # Try to guess if `popped_params` might be a step, come from a step, or contain a step.
    could_be_step = (
        try_from_step
        and (
            origin == Step
            or isinstance(popped_params, Step)
            or _params_contain_step(popped_params)
            or (isinstance(popped_params, (dict, Params)) and popped_params.get("type") == "ref")
        )
        # StepInfo's `config` argument holds a raw step config and must never be
        # resolved into an actual Step here.
        and not (class_name == "StepInfo" and argument_name == "config")
    )
    if could_be_step:
        # If we think it might be a step, we try parsing as a step _first_.
        # Parsing as a non-step always succeeds, because it will fall back to returning a dict.
        # So we can't try parsing as a non-step first.
        backup_params = deepcopy(popped_params)
        try:
            return construct_arg(
                class_name,
                argument_name,
                popped_params,
                Step[annotation],  # type: ignore
                default,
                try_from_step=False,
            )
        except (ValueError, TypeError, ConfigurationError, AttributeError, IndexError):
            # Step-parsing failed; restore the (possibly mutated) params and fall
            # through to ordinary construction below.
            popped_params = backup_params
    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT
    if (inspect.isclass(annotation) and issubclass(annotation, FromParams)) or (
        inspect.isclass(origin) and issubclass(origin, FromParams)
    ):
        if origin is None and isinstance(popped_params, annotation):
            # Already an instance of the target class; nothing to construct.
            return popped_params
        elif popped_params is not None:
            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                if origin != Step:
                    # We don't allow single strings to be upgraded to steps.
                    # Since we try everything as a step first, upgrading strings to
                    # steps automatically would cause confusion every time a step
                    # name conflicts with any string anywhere in a config.
                    popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            elif not isinstance(popped_params, (Params, Step)):
                raise TypeError(
                    f"Expected a `Params` object, found `{popped_params}` instead while constructing "
                    f"parameter '{argument_name}' for `{class_name}`"
                )
            result: Union[FromParams, WithUnresolvedSteps]
            if isinstance(popped_params, Step):
                result = popped_params
            else:
                if origin != Step and _params_contain_step(popped_params):
                    # There are unresolved step references buried in the params, so
                    # construction must be deferred until they are resolved.
                    result = WithUnresolvedSteps(annotation.from_params, popped_params)
                else:
                    result = annotation.from_params(popped_params)
            if isinstance(result, Step):
                # Sanity-check the step's declared return type against the type
                # argument of the `Step[...]` annotation, when both are present.
                expected_return_type = args[0] if args else None
                if isinstance(result, FunctionalStep):
                    return_type = inspect.signature(result.WRAPPED_FUNC).return_annotation
                else:
                    return_type = inspect.signature(result.run).return_annotation
                if return_type == inspect.Signature.empty:
                    logger.warning(
                        "Step %s has no return type annotation. Those are really helpful when "
                        "debugging, so we recommend them highly.",
                        result.__class__.__name__,
                    )
                else:
                    try:
                        if expected_return_type is not None and not issubclass(
                            return_type, expected_return_type
                        ):
                            raise ConfigurationError(
                                f"Step {result.name} returns {return_type}, but "
                                f"we expected {expected_return_type}."
                            )
                    except TypeError:
                        # issubclass() raises TypeError for typing constructs
                        # (generics etc.); skip the check rather than fail.
                        pass
            return result
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default
    # For StepIndexer, we just return as-is and hope the for the best.
    # At worst, user will get an error at runtime if they are trying to index a step
    # result that can't be indexed.
    # TODO (epwalsh): we could check the return type of the wrapped step here
    # and make sure that:
    # 1. It's an index-able object,
    # 2. The item in the index-able object matches `annotation`.
    #
    # But that's complex and might have false negatives.
    # NOTE(review): this is an exact-type check, not isinstance() — presumably
    # deliberate so subclasses don't match; confirm before changing.
    elif type(popped_params) == StepIndexer:
        return popped_params
    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be {annotation.__name__}, "
                f"found {popped_params} ({type(popped_params)})."
            )
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if isinstance(popped_params, str) or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(
                f"Expected {argument_name} to be a string, found {popped_params} ({type(popped_params)})"
            )
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")
    elif annotation == Path:
        if isinstance(popped_params, (str, Path)):
            return Path(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be a str or Path, found {popped_params} ({type(popped_params)})"
            )
    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in {collections.abc.Mapping, Mapping, Dict, dict} and len(args) == 2:
        value_cls = annotation.__args__[-1]
        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object) "
                f"found {popped_params} ({type(popped_params)})."
            )
        # Construct each value recursively; keys are kept as-is.
        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
        return value_dict
    elif origin in (Tuple, tuple):
        value_list = []
        value_types = list(annotation.__args__)
        if value_types[-1] == Ellipsis:
            # Variable length tuples, e.g. 'Tuple[int, ...]', we set value_types to '[int] * len(popped_params)'.
            value_types = value_types[:-1] + [value_types[-2]] * (
                len(popped_params) - len(annotation.__args__) + 1
            )
        for i, (value_cls, value_params) in enumerate(zip(value_types, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
            value_list.append(value)
        return tuple(value_list)
    elif origin in (Set, set) and len(args) == 1:
        value_cls = annotation.__args__[0]
        value_set = set()
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
            value_set.add(value)
        return value_set
    elif origin == Union or isinstance(annotation, UnionType):
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)
        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        error_chain: Optional[Exception] = None
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError) as e:
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)
                # Chain each failure onto the previous one so the final error shows
                # every union member that was tried.
                e.args = (f"While constructing an argument of type {arg_annotation}",) + e.args
                e.__cause__ = error_chain
                error_chain = e
        # If none of them succeeded, we crash.
        config_error = ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}."
        )
        config_error.__cause__ = error_chain
        raise config_error
    elif origin == Lazy:
        # Construction is deferred: the Lazy wrapper keeps a deep copy of the raw
        # params and builds the object only when .construct() is called.
        value_cls = args[0]
        return Lazy(value_cls, params=deepcopy(popped_params))  # type: ignore
    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif origin in {collections.abc.Iterable, Iterable, List, list} and len(args) == 1:
        value_cls = annotation.__args__[0]
        value_list = []
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
            value_list.append(value)
        return value_list
    elif (inspect.isclass(annotation) or inspect.isclass(origin)) and isinstance(
        popped_params, Params
    ):
        # Constructing arbitrary classes from params
        arbitrary_class = origin or annotation
        constructor_to_inspect = arbitrary_class.__init__
        constructor_to_call = arbitrary_class
        params_contain_step = _params_contain_step(popped_params)
        kwargs = create_kwargs(constructor_to_inspect, arbitrary_class, popped_params)
        from tango.step import WithUnresolvedSteps
        if origin != Step and params_contain_step:
            # Same deferral as above: steps in the params mean we can't call the
            # constructor yet.
            return WithUnresolvedSteps(constructor_to_call, *[], **kwargs)
        else:
            return constructor_to_call(**kwargs)  # type: ignore
    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
class FromParams(DetHashWithVersion):
    """
    Mixin to give a :meth:`from_params` method to classes. We create a distinct base class for this
    because sometimes we want non :class:`~tango.common.Registrable`
    classes to be instantiatable ``from_params``.
    """
    @classmethod
    def from_params(
        cls: Type[T],
        params_: Union[Params, dict, str],
        constructor_to_call: Optional[Callable[..., T]] = None,
        constructor_to_inspect: Optional[Union[Callable[..., T], Callable[[T], None]]] = None,
        **extras,
    ) -> T:
        """
        This is the automatic implementation of ``from_params``. Any class that subclasses
        from ``FromParams`` (or :class:`~tango.common.Registrable`,
        which itself subclasses ``FromParams``) gets this
        implementation for free. If you want your class to be instantiated from params in the
        "obvious" way -- pop off parameters and hand them to your constructor with the same names --
        this provides that functionality.
        If you need more complex logic in your from ``from_params`` method, you'll have to implement
        your own method that overrides this one.
        The ``constructor_to_call`` and ``constructor_to_inspect`` arguments deal with a bit of
        redirection that we do. We allow you to register particular ``@classmethods`` on a class as
        the constructor to use for a registered name. This lets you, e.g., have a single
        ``Vocabulary`` class that can be constructed in two different ways, with different names
        registered to each constructor. In order to handle this, we need to know not just the class
        we're trying to construct (``cls``), but also what method we should inspect to find its
        arguments (``constructor_to_inspect``), and what method to call when we're done constructing
        arguments (``constructor_to_call``). These two methods are the same when you've used a
        ``@classmethod`` as your constructor, but they are ``different`` when you use the default
        constructor (because you inspect ``__init__``, but call ``cls()``).
        """
        from tango.common.registrable import (
            Registrable,  # import here to avoid circular imports
        )
        params = params_
        logger.debug(
            f"instantiating class {cls} from params {getattr(params, 'params', params)} "
            f"and extras {set(extras.keys())}"
        )
        # NOTE: despite the `-> T` annotation, `None` params yield `None`.
        if params is None:
            return None
        # A bare string is shorthand for {"type": <string>}.
        if isinstance(params, str):
            params = Params({"type": params})
        if not isinstance(params, Params):
            if isinstance(params, dict):
                params = Params(params)
            else:
                raise ConfigurationError(
                    "from_params was passed a `params` object that was not a `Params`. This probably "
                    "indicates malformed parameters in a configuration file, where something that "
                    "should have been a dictionary was actually a list, or something else. "
                    f"This happened when constructing an object of type {cls}."
                )
        if issubclass(cls, Registrable) and not constructor_to_call:
            # We know `cls` inherits from Registrable, so we'll use a cast to make mypy happy.
            as_registrable = cast(Type[Registrable], cls)
            # If the requested type isn't registered yet, try importing modules
            # that might register it.
            if "type" in params and params["type"] not in as_registrable.list_available():
                as_registrable.search_modules(params["type"])
            # Resolve the subclass and constructor.
            if is_base_registrable(cls) or "type" in params:
                default_to_first_choice = as_registrable.default_implementation is not None
                choice = params.pop_choice(
                    "type",
                    choices=as_registrable.list_available(),
                    default_to_first_choice=default_to_first_choice,
                )
                # We allow users to register methods and functions, not just classes.
                # So we have to handle both here.
                subclass_or_factory_func, constructor_name = as_registrable.resolve_class_name(
                    choice
                )
                if inspect.isclass(subclass_or_factory_func):
                    # We have an actual class.
                    subclass = subclass_or_factory_func
                    if constructor_name is not None:
                        # A @classmethod was registered as the constructor: inspect
                        # and call the same method.
                        constructor_to_inspect = cast(
                            Callable[..., T], getattr(subclass, constructor_name)
                        )
                        constructor_to_call = constructor_to_inspect
                    else:
                        constructor_to_inspect = subclass.__init__
                        constructor_to_call = subclass
                else:
                    # We have a function that returns an instance of the class.
                    factory_func = cast(Callable[..., T], subclass_or_factory_func)
                    # Use the factory's return annotation as the constructed type
                    # when it has one; otherwise fall back to `cls`.
                    return_type = inspect.signature(factory_func).return_annotation
                    if return_type == inspect.Signature.empty:
                        subclass = cls
                    else:
                        subclass = return_type
                    constructor_to_inspect = factory_func
                    constructor_to_call = factory_func
            else:
                # Must be trying to instantiate the given class directly.
                subclass = cls
                constructor_to_inspect = cls.__init__
                constructor_to_call = cast(Callable[..., T], cls)
            if hasattr(subclass, "from_params"):
                # We want to call subclass.from_params.
                extras = create_extras(subclass, extras)
                # mypy can't follow the typing redirection that we do, so we explicitly cast here.
                retyped_subclass = cast(Type[T], subclass)
                # Recurse into the resolved subclass's from_params, passing along
                # the constructor redirection decided above.
                return retyped_subclass.from_params(
                    params,
                    constructor_to_call=constructor_to_call,
                    constructor_to_inspect=constructor_to_inspect,
                    **extras,
                )
            else:
                # In some rare cases, we get a registered subclass that does _not_ have a
                # from_params method (this happens with Activations, for instance, where we
                # register pytorch modules directly). This is a bit of a hack to make those work,
                # instead of adding a `from_params` method for them somehow. We just trust that
                # you've done the right thing in passing your parameters, and nothing else needs to
                # be recursively constructed.
                kwargs = create_kwargs(constructor_to_inspect, subclass, params, extras)  # type: ignore
                return constructor_to_call(**kwargs)  # type: ignore
        else:
            # This is not a base class, so convert our params and extras into a dict of kwargs.
            # See the docstring for an explanation of what's going on here.
            if not constructor_to_inspect:
                constructor_to_inspect = cls.__init__
            if not constructor_to_call:
                constructor_to_call = cls
            if constructor_to_inspect == object.__init__:
                # This class does not have an explicit constructor, so don't give it any kwargs.
                # Without this logic, create_kwargs will look at object.__init__ and see that
                # it takes *args and **kwargs and look for those.
                kwargs: Dict[str, Any] = {}  # type: ignore[no-redef]
                params.assert_empty(cls.__name__)
            else:
                # This class has a constructor, so create kwargs for it.
                constructor_to_inspect = cast(Callable[..., T], constructor_to_inspect)
                kwargs = create_kwargs(constructor_to_inspect, cls, params, extras)
            return constructor_to_call(**kwargs)  # type: ignore
    def to_params(self) -> Params:
        """
        Returns a ``Params`` object that can be used with ``.from_params()`` to recreate an
        object just like it.
        This relies on ``_to_params()``. If you need this in your custom ``FromParams`` class,
        override ``_to_params()``, not this method.
        """
        def replace_object_with_params(o: Any) -> Any:
            # Recursively convert FromParams objects (and containers thereof)
            # into plain, JSON-serializable structures.
            if isinstance(o, FromParams):
                return o.to_params().as_dict(quiet=True)
            elif isinstance(o, (list, tuple, set)):
                # NOTE: tuples and sets are flattened to lists here.
                return [replace_object_with_params(i) for i in o]
            elif isinstance(o, dict):
                return {key: replace_object_with_params(value) for key, value in o.items()}
            elif isinstance(o, Path):
                return str(o)
            elif o is None or isinstance(o, (str, float, int, bool)):
                return o
            else:
                raise NotImplementedError(
                    f"Unexpected type encountered in to_params(): {o} ({type(o)})\n"
                    "You may need to implement a custom '_to_params()'."
                )
        return Params(replace_object_with_params(self._to_params()))
    def _to_params(self) -> Dict[str, Any]:
        """
        Returns a dictionary of parameters that, when turned into a ``Params`` object and
        then fed to ``.from_params()``, will recreate this object.
        You don't need to implement this all the time. Tango will let you know if you
        need it.
        """
        # Default implementation: assume the instance's attributes are exactly its
        # constructor parameters. Classes with __slots__ (no __dict__) must override.
        try:
            return self.__dict__
        except AttributeError:
            raise NotImplementedError(
                f"{self.__class__.__name__}._to_params() needs to be implemented"
            )
================================================
FILE: tango/common/lazy.py
================================================
import copy
import inspect
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, Union, cast
from .det_hash import CustomDetHash, DetHashWithVersion
from .params import Params
T = TypeVar("T")
class Lazy(Generic[T], CustomDetHash):
    """
    This class is for use when constructing objects using :class:`~tango.common.FromParams`,
    when an argument to a constructor has a `sequential dependency` with another argument to the same
    constructor.
    For example, in a ``Trainer`` class you might want to take a ``Model`` and an ``Optimizer`` as arguments,
    but the ``Optimizer`` needs to be constructed using the parameters from the ``Model``. You can give
    the type annotation ``Lazy[Optimizer]`` to the optimizer argument, then inside the constructor
    call ``optimizer.construct(parameters=model.parameters)``.
    This is only recommended for use when you have registered a ``@classmethod`` as the constructor
    for your class, instead of using ``__init__``. Having a ``Lazy[]`` type annotation on an argument
    to an ``__init__`` method makes your class completely dependent on being constructed using the
    ``FromParams`` pipeline, which is not a good idea.
    The actual implementation here is incredibly simple; the logic that handles the lazy
    construction is actually found in ``FromParams``, where we have a special case for a ``Lazy`` type
    annotation.
    Examples
    --------
    ::
        @classmethod
        def my_constructor(
            cls,
            some_object: Lazy[MyObject],
            optional_object: Lazy[MyObject] = None,
            # or:
            # optional_object: Optional[Lazy[MyObject]] = None,
            optional_object_with_default: Optional[Lazy[MyObject]] = Lazy(MyObjectDefault),
            required_object_with_default: Lazy[MyObject] = Lazy(MyObjectDefault),
        ) -> MyClass:
            obj1 = some_object.construct()
            obj2 = None if optional_object is None else optional_object.construct()
            obj3 = None if optional_object_with_default is None else optional_object_with_default.construct()
            obj4 = required_object_with_default.construct()
    """
    def __init__(
        self,
        constructor: Union[Type[T], Callable[..., T]],
        params: Optional[Params] = None,
        constructor_extras: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # The class or factory callable to (eventually) invoke.
        self._constructor = constructor
        # Raw params saved for deferred construction; never None after this.
        self._params = params or Params({})
        # Extra kwargs collected now and merged into the eventual call.
        self._constructor_extras = constructor_extras or {}
        self._constructor_extras.update(kwargs)
    @property
    def constructor(self) -> Callable[..., T]:
        """
        The callable that ``construct()`` will invoke. For ``FromParams`` classes
        this wraps ``from_params`` with a deep copy of the saved params (so the
        Lazy can be constructed more than once); otherwise it is the raw callable.
        """
        from tango.common.from_params import FromParams
        if inspect.isclass(self._constructor) and issubclass(self._constructor, FromParams):
            def constructor_to_use(**kwargs):
                return self._constructor.from_params(  # type: ignore[union-attr]
                    copy.deepcopy(self._params),
                    **kwargs,
                )
            return constructor_to_use
        else:
            return self._constructor
    def construct(self, **kwargs) -> T:
        """
        Call the constructor to create an instance of ``T``.
        """
        # If there are duplicate keys between self._constructor_extras and kwargs,
        # this will overwrite the ones in self._constructor_extras with what's in kwargs.
        constructor_kwargs = {**self._constructor_extras, **kwargs}
        return self.constructor(**constructor_kwargs)
    def det_hash_object(self) -> Any:
        """
        Determines the object used for deterministic hashing. Resolves the
        registered subclass (or factory return type) the params refer to, so the
        hash can include that class's ``VERSION`` when it defines one.
        """
        from tango.common.from_params import FromParams
        class_to_construct: Union[Type[T], Callable[..., T]] = self._constructor
        if isinstance(class_to_construct, type) and issubclass(class_to_construct, FromParams):
            params = copy.deepcopy(self._params)
            # NOTE(review): __init__ guarantees `_params` is a Params object, so the
            # None/str/dict normalization below looks defensive/unreachable — confirm.
            if params is None:
                params = Params({})
            elif isinstance(params, str):
                params = Params({"type": params})
            elif isinstance(params, dict):
                params = Params(params)
            elif not isinstance(params, Params):
                return None
            from tango.common import Registrable
            if issubclass(class_to_construct, Registrable):
                as_registrable = cast(Type[Registrable], class_to_construct)
                if "type" in params and params["type"] not in as_registrable.list_available():
                    as_registrable.search_modules(params["type"])
                # Resolve the subclass and constructor.
                from .from_params import is_base_registrable
                if is_base_registrable(class_to_construct) or "type" in params:
                    default_to_first_choice = as_registrable.default_implementation is not None
                    choice = params.pop_choice(
                        "type",
                        choices=as_registrable.list_available(),
                        default_to_first_choice=default_to_first_choice,
                    )
                    subclass_or_factory_func, _ = as_registrable.resolve_class_name(choice)
                    if inspect.isclass(subclass_or_factory_func):
                        class_to_construct = subclass_or_factory_func
                    else:
                        # We have a function that returns an instance of the class.
                        factory_func = cast(Callable[..., T], subclass_or_factory_func)
                        return_type = inspect.signature(factory_func).return_annotation
                        if return_type != inspect.Signature.empty:
                            class_to_construct = return_type
        if isinstance(class_to_construct, type) and issubclass(
            class_to_construct, DetHashWithVersion
        ):
            # Include the class VERSION so bumping it changes the hash.
            return class_to_construct.VERSION, self
        else:
            return self
================================================
FILE: tango/common/logging.py
================================================
"""
Tango makes heavy use of the :mod:`logging` module from the standard library to convey information to users.
When you're writing your own :class:`~tango.step.Step` implementations we encourage you to also use standard
Python logging as opposed to :func:`print` or other functions that write directly to ``stdout`` or ``stderr``.
This is easy enough since each :class:`~tango.step.Step` class already comes with its own logger:
:attr:`Step.logger <tango.step.Step.logger>`.
When using the `Tango CLI <./commands.html>`_ you can set the log level in several different ways:
1. Through a Tango `global settings <./commands.html#global-settings>`_ file.
2. With the environment variable ``TANGO_LOG_LEVEL``.
3. Or with the ``--log-level`` command-line option.
In some cases (like when running on `Beaker <https://beaker.org>`_) you may also want
to enable `"file friendly logging" <#tango.common.logging.FILE_FRIENDLY_LOGGING>`_.
Configuring logging in your own CLI
-----------------------------------
If you're writing your own CLI that uses tango, you can utilize the :func:`initialize_logging()`
function to easily configure logging properly.
For example,
.. testcode::
from tango.common.logging import initialize_logging, teardown_logging
initialize_logging(log_level="info")
logger = logging.getLogger()
logger.info("Running script!")
teardown_logging()
.. testoutput::
:options: +ELLIPSIS
[...] INFO Running script! ...
If you want to have logs written to a file, you can use the :func:`file_handler` context manager.
Logging from worker processes or threads
----------------------------------------
If you have steps or other functions that spawn workers, and you want to enable logging within
those workers, you can call the :func:`initialize_worker_logging()` function to configure
logging within each worker. This assumes that you've called :func:`initialize_logging()` from the
main process (the tango CLI does this for you).
For example,
.. testcode::
import logging
import multiprocessing as mp
from tango import Step
from tango.common.logging import initialize_worker_logging
@Step.register("multiprocessing_step_example")
class MultiprocessingStep(Step):
def run(self, num_proc: int = 2) -> bool: # type: ignore
workers = []
for i in range(num_proc):
worker = mp.Process(target=_worker_function, args=(i,))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
return True
def _worker_function(worker_id: int):
initialize_worker_logging(worker_rank=worker_id)
logger = logging.getLogger(MultiprocessingStep.__name__)
logger.info("Hello from worker %d!", worker_id)
"""
import logging
import logging.handlers
import os
import pickle
import socketserver
import struct
import sys
import threading
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Callable, ClassVar, ContextManager, Generator, List, Optional, Union
import rich
from rich.console import Console, ConsoleRenderable, Group
from rich.highlighter import NullHighlighter
from rich.padding import Padding
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
from .aliases import EnvVarNames, PathOrStr
from .exceptions import CancellationError, CliRunError, SigTermReceived
from .util import _parse_bool, _parse_optional_int
FILE_FRIENDLY_LOGGING: bool = _parse_bool(
os.environ.get(EnvVarNames.FILE_FRIENDLY_LOGGING.value, False)
)
"""
If this flag is set to ``True``, we remove special styling characters from log messages,
add newlines to :class:`~tango.common.tqdm.Tqdm` output even on an interactive terminal, and we slow
down :class:`~tango.common.tqdm.Tqdm`'s output to only once every 10 seconds.
.. attention::
Unfortunately this won't affect ``tqdm`` output from other libraries that don't use
Tango's :class:`~tango.common.tqdm.Tqdm` wrapper.
By default, it is set to ``False``. It can be changed by setting the corresponding environment
variable (``FILE_FRIENDLY_LOGGING``) or field in a :class:`~tango.__main__.TangoGlobalSettings`
file (``file_friendly_logging``) to "true" or "false",
or from the command line with the ``--file-friendly-logging`` flag.
For example,
.. code-block::
$ tango --file-friendly-logging run ...
"""
TANGO_LOG_LEVEL: Optional[str] = os.environ.get(EnvVarNames.LOG_LEVEL.value, None)
"""
The log level to use globally. The value can be set from the corresponding environment variable
(``TANGO_LOG_LEVEL``) or field in a :class:`~tango.__main__.TangoGlobalSettings` file (``log_level``),
or from the command line with the ``--log-level`` option.
Possible values are "debug", "info", "warning", or "error" (not case sensitive).
For example,
.. code-block::
$ tango --log-level info run ...
.. note::
This does not affect the :data:`~tango.common.logging.cli_logger`
or logs from :class:`~tango.common.Tqdm` progress bars.
"""
TANGO_CONSOLE_WIDTH: Optional[int] = _parse_optional_int(
os.environ.get(EnvVarNames.CONSOLE_WIDTH.value, None)
)
# Click logger disabled by default in case nobody calls initialize_logging().
TANGO_CLI_LOGGER_ENABLED: bool = _parse_bool(
os.environ.get(EnvVarNames.CLI_LOGGER_ENABLED.value, False)
)
# Keep track of exceptions logged so we don't log duplicates from our custom excepthook.
_EXCEPTIONS_LOGGED: List[BaseException] = []
class LevelFilter(logging.Filter):
    """
    A ``logging.Filter`` that only passes records whose level is at most
    ``max_level`` (and, when ``min_level`` is given, at least ``min_level``).

    Meant for a stdout handler when a separate stderr handler exists, so that
    WARNING/ERROR records don't get emitted by both handlers.
    """

    def __init__(self, max_level: int, min_level: Optional[int] = None, name=""):
        self.max_level = max_level
        self.min_level = min_level
        super().__init__(name)

    def filter(self, record):
        if self.min_level is None:
            return record.levelno <= self.max_level
        return self.min_level <= record.levelno <= self.max_level
class CliFilter(logging.Filter):
    """
    Separates records emitted through the tango CLI logger (``tango.__main__``)
    from everything else: with ``filter_out=True`` CLI records are dropped, with
    ``filter_out=False`` only CLI records pass.
    """

    def __init__(self, filter_out: bool):
        self.filter_out = filter_out

    def filter(self, record):
        is_cli_record = record.name == "tango.__main__"
        return not is_cli_record if self.filter_out else is_cli_record
class WorkerLogFilter(logging.Filter):
    """
    Prefixes every log message with ``[rank N]`` when a non-negative worker rank
    was given; the default rank of ``-1`` leaves messages untouched. Always
    passes the record through.
    """

    def __init__(self, rank=-1):
        super().__init__()
        self._rank = rank

    def filter(self, record):
        if self._rank == -1:
            return True
        record.msg = f"[rank {self._rank}] {record.msg}"
        return True
class PrefixLogFilter(logging.Filter):
    """
    Prepends ``[prefix] `` to every string log message; non-string payloads
    (e.g. rich renderables) pass through untouched. For records on the CLI
    logger (``tango.__main__``) the prefix is rich-escaped so it isn't
    interpreted as console markup. Always passes the record through.
    """

    def __init__(self, prefix):
        super().__init__()
        self._prefix = prefix

    def filter(self, record):
        message = record.msg
        if not isinstance(message, str):
            return True
        tag = f"[{self._prefix}] "
        if record.name == "tango.__main__":
            from rich.markup import escape

            tag = escape(tag)
        record.msg = tag + message
        return True
class LogRecordStreamHandler(socketserver.StreamRequestHandler):
    """
    Handler for a streaming logging request.

    Each request consists of repeated frames: a 4-byte big-endian length
    followed by a pickled ``logging.LogRecord`` attribute dict. Every record
    received is re-emitted through the local logging configuration.

    Adapted from the Python logging cookbook.
    """

    def handle(self):
        """
        Read length-prefixed pickled log records off the connection until the
        peer closes, handing each one to :meth:`handleLogRecord`.
        """
        while True:
            header = self.connection.recv(4)
            if len(header) < 4:
                # Short read on the length prefix means the peer hung up.
                break
            (payload_len,) = struct.unpack(">L", header)
            payload = self.connection.recv(payload_len)
            # recv() may return fewer bytes than requested; keep reading.
            while len(payload) < payload_len:
                payload = payload + self.connection.recv(payload_len - len(payload))
            record = logging.makeLogRecord(self.unPickle(payload))
            self.handleLogRecord(record)

    def unPickle(self, data):
        # NOTE(review): pickle on a socket is only safe because this server is
        # fed by our own worker processes on localhost -- never expose it.
        return pickle.loads(data)

    def handleLogRecord(self, record):
        # N.B. EVERY record gets logged. This is because Logger.handle
        # is normally called AFTER logger-level filtering. If you want
        # to do filtering, do it at the client end to save wasting
        # cycles and network bandwidth!
        logging.getLogger(record.name).handle(record)
class LogRecordSocketReceiver(socketserver.ThreadingTCPServer):
    """
    Simple TCP socket-based logging receiver that pairs with
    :class:`LogRecordStreamHandler`.

    Adapted from the Python logging cookbook.
    """

    # Allow quick restarts without waiting for the OS to release the port.
    allow_reuse_address = True

    def __init__(self, host: str, port: int = 0):
        # Port 0 asks the OS for any free port; read the actual port back
        # from ``self.server_address`` after construction.
        super().__init__((host, port), LogRecordStreamHandler)
        self.abort = False
        self.timeout = 0.2

    def serve_until_stopped(self):
        """Poll for incoming connections until ``self.abort`` is set to True."""
        import select

        while not self.abort:
            ready, _, _ = select.select([self.socket.fileno()], [], [], self.timeout)
            if ready:
                self.handle_request()
# Optional prefix prepended to every log message (see PrefixLogFilter);
# inherited from the environment so child processes pick it up.
_LOGGING_PREFIX: str = os.environ.get(EnvVarNames.LOGGING_PREFIX.value, "")
# Host and port of the TCP logging socket used to ship log records from worker
# processes back to the main process. The port is assigned by the OS when the
# main process starts its receiver (see _initialize_logging()).
_LOGGING_HOST: str = os.environ.get(EnvVarNames.LOGGING_HOST.value, "localhost")
_LOGGING_PORT: Optional[int] = _parse_optional_int(
    os.environ.get(EnvVarNames.LOGGING_PORT.value, None)
)
# The receiving server and its background thread; only set in the main process.
_LOGGING_SERVER: Optional[LogRecordSocketReceiver] = None
_LOGGING_SERVER_THREAD: Optional[threading.Thread] = None
class RichHandler(logging.Handler):
    """
    A ``logging.Handler`` that renders records with `rich`, printing
    time / level / message / source-path columns to a rich ``Console``.

    Adapted from
    https://github.com/Textualize/rich/blob/master/rich/logging.py
    """

    # Words highlighted with the "logging.keyword" style wherever they appear
    # in a message (HTTP verbs by default).
    KEYWORDS: ClassVar[Optional[List[str]]] = [
        "GET",
        "POST",
        "HEAD",
        "PUT",
        "DELETE",
        "OPTIONS",
        "TRACE",
        "PATCH",
    ]

    def __init__(
        self,
        level: Union[int, str] = logging.NOTSET,
        console: Optional[Console] = None,
        *,
        markup: bool = False,
        log_time_format: Union[str, Callable[[datetime], str]] = "[%x %X]",
        keywords: Optional[List[str]] = None,
        show_time: bool = True,
        show_level: bool = True,
        show_path: bool = True,
    ) -> None:
        """
        :param level: minimum level this handler emits.
        :param console: console to print to; defaults to rich's global console.
        :param markup: interpret rich markup in messages by default
            (overridable per record via a ``markup`` attribute).
        :param log_time_format: strftime string or callable for the time column.
        :param keywords: words to highlight; defaults to :attr:`KEYWORDS`.
        :param show_time: show the time column.
        :param show_level: show the level column.
        :param show_path: show the right-aligned ``path:lineno`` column.
        """
        super().__init__(level=level)
        self.console = console or rich.get_console()
        self.highlighter = NullHighlighter()
        self.time_format = log_time_format
        self.markup = markup
        self.keywords = keywords or self.KEYWORDS
        self.show_time = show_time
        self.show_level = show_level
        self.show_path = show_path

    def emit(self, record: logging.LogRecord) -> None:
        """Render and print ``record`` to the console."""
        if isinstance(record.msg, (Syntax, Table)):
            # Syntax/Table renderables logged directly get printed with padding.
            self.console.print(Padding(record.msg, (1, 0, 1, 1)))
        elif hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
            # Any other rich-renderable object: print as-is, no extra columns.
            self.console.print(record.msg)
        else:
            # Plain message: format, wrap in Text, and assemble the columns.
            message = self.format(record)
            message_renderable = self.render_message(record, message)
            log_renderable = self.render(record=record, message_renderable=message_renderable)
            try:
                self.console.print(log_renderable)
            except Exception:
                self.handleError(record)

    def render_message(self, record: logging.LogRecord, message: str) -> ConsoleRenderable:
        """Turn the formatted message string into a (possibly highlighted) ``Text``."""
        # A per-record `markup` attribute overrides the handler-level default.
        use_markup = getattr(record, "markup", self.markup)
        message_text = Text.from_markup(message) if use_markup else Text(message)
        if self.show_path and record.exc_info is None:
            # Leave a separator before the right-aligned path column.
            message_text.end = " "
        # A per-record `highlighter` attribute overrides the default (null) one.
        highlighter = getattr(record, "highlighter", self.highlighter)
        if highlighter:
            message_text = highlighter(message_text)
        if self.keywords is None:
            self.keywords = self.KEYWORDS
        if self.keywords:
            message_text.highlight_words(self.keywords, "logging.keyword")
        return message_text

    def get_time_text(self, record: logging.LogRecord) -> Text:
        """Render the record's creation time for the time column."""
        log_time = datetime.fromtimestamp(record.created)
        time_str: str
        if callable(self.time_format):
            time_str = self.time_format(log_time)
        else:
            time_str = log_time.strftime(self.time_format)
        return Text(time_str, style="log.time", end=" ")

    def get_level_text(self, record: logging.LogRecord) -> Text:
        """Render the record's level name, padded to a fixed width of 8."""
        level_name = record.levelname
        level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
        level_text.style = "log.level"
        level_text.end = " "
        return level_text

    def get_path_text(self, record: logging.LogRecord, length_so_far: int) -> Text:
        """
        Render ``path:lineno`` of the record's origin, right-justified so it
        lands at the console's right edge given ``length_so_far`` characters
        already printed on the line.
        """
        path = Path(record.pathname)
        # Show the path relative to the first sys.path entry that contains it.
        for package_root in sys.path:
            try:
                path = path.relative_to(Path(package_root))
                break
            except ValueError:
                continue
        text = f"{path}:{record.lineno}"
        # Account for the message having soft-wrapped across console lines.
        length_after_wrap = length_so_far % self.console.width
        return Text(
            text.rjust(self.console.width - length_after_wrap - 3),
            style="log.path",
        )

    def render(
        self,
        *,
        record: logging.LogRecord,
        message_renderable: ConsoleRenderable,
    ) -> ConsoleRenderable:
        """Assemble the time/level/message/path columns into a single renderable."""
        components: List[ConsoleRenderable] = []
        if self.show_time:
            components.append(self.get_time_text(record))
        if self.show_level:
            components.append(self.get_level_text(record))
        components.append(message_renderable)
        if self.show_path and record.exc_info is None:
            # Some renderables have no len(); skip the path column in that case.
            try:
                length_so_far = sum(len(x) for x in components)  # type: ignore
            except TypeError:
                pass
            else:
                components.append(self.get_path_text(record, length_so_far))
        return Group(*components)
def get_handler(
    level: int,
    stderr: bool = False,
    enable_markup: bool = False,
    show_time: bool = True,
    show_level: bool = True,
    show_path: bool = True,
) -> logging.Handler:
    """
    Build a :class:`RichHandler` writing to its own rich ``Console``.

    :param level: log level for the handler.
    :param stderr: write to stderr instead of stdout.
    :param enable_markup: interpret rich markup in messages.
    :param show_time: show the time column.
    :param show_level: show the level column.
    :param show_path: show the source-path column.
    """
    console = Console(
        color_system=None if FILE_FRIENDLY_LOGGING else "auto",
        stderr=stderr,
        width=TANGO_CONSOLE_WIDTH,
        soft_wrap=True,
    )
    # When not attached to a terminal and no explicit width was configured,
    # pick a sane fixed width instead of rich's non-terminal default.
    if TANGO_CONSOLE_WIDTH is None and not console.is_terminal:
        console.width = 160
    return RichHandler(
        level=level,
        console=console,
        markup=enable_markup,
        show_time=show_time,
        show_level=show_level,
        show_path=show_path,
    )
cli_logger = logging.getLogger("tango.__main__")
"""
A logger that emits messages directly to stdout/stderr using
`rich <https://github.com/Textualize/rich>`_'s
:class:`~rich.console.Console` class.

This provides a convenient way for command-line apps to log pretty, styled messages
using the `markup style <https://rich.readthedocs.io/en/latest/markup.html>`_ provided by `rich`.
"""
cli_logger.propagate = False
# Disabled unless explicitly enabled via initialize_logging() or the
# CLI_LOGGER_ENABLED environment variable. BUGFIX: this assignment was
# previously `cli_logger.disabled = TANGO_CLI_LOGGER_ENABLED`, which *enabled*
# the logger when the flag was False -- the opposite of the documented default
# and of the `disabled = not enable_cli_logs` logic in _initialize_logging().
cli_logger.disabled = not TANGO_CLI_LOGGER_ENABLED
def excepthook(exctype, value, traceback):
    """
    Replacement for ``sys.excepthook`` that routes uncaught exceptions through
    the logging system (see :func:`log_exc_info`).
    """
    log_exc_info(exctype, value, traceback)
def log_exception(exc: Optional[BaseException] = None, logger: Optional[logging.Logger] = None):
    """
    Log an exception. When ``exc`` is None, the exception currently being
    handled (from ``sys.exc_info()``) is logged instead.
    """
    if exc is not None:
        log_exc_info(type(exc), exc, exc.__traceback__, logger=logger)
    else:
        et, ev, tb = sys.exc_info()
        log_exc_info(et, ev, tb, logger=logger)
def log_exc_info(exctype, value, traceback, logger: Optional[logging.Logger] = None):
    """
    Log an exception info triple (as produced by ``sys.exc_info()``), skipping
    any exception object that was already logged once.

    :param logger: the logger to use; defaults to the root logger.
    """
    global _EXCEPTIONS_LOGGED
    if value in _EXCEPTIONS_LOGGED:
        return
    _EXCEPTIONS_LOGGED.append(value)
    logger = logger or logging.getLogger()
    if isinstance(value, CliRunError):
        # CLI errors are user-facing: show just the message, no traceback.
        msg = str(value)
        if msg:
            cli_logger.error(msg)
    elif isinstance(value, (KeyboardInterrupt, CancellationError)):
        # Interrupts and cancellations are expected; a one-liner is enough.
        logger.error("%s: %s", exctype.__name__, value)
    else:
        logger.error(
            "Uncaught exception",
            exc_info=(exctype, value, traceback),
            extra={"highlighter": rich.highlighter.ReprHighlighter()},
        )
def initialize_logging(
    *,
    log_level: Optional[str] = None,
    enable_cli_logs: Optional[bool] = None,
    file_friendly_logging: Optional[bool] = None,
):
    """
    Initialize logging, which includes setting the global log level, format, and configuring
    handlers.

    .. tip::
        This should be called as early on in your script as possible.

    .. tip::
        You should also call :func:`teardown_logging()` as the end of your script.

    .. tip::
        For worker threads/processes, use :func:`initialize_worker_logging()` instead.

    :param log_level:
        Can be one of "debug", "info", "warning", "error". Defaults to the value
        of :data:`TANGO_LOG_LEVEL`, if set, or "error".
    :param enable_cli_logs:
        Set to ``True`` to enable messages from the :data:`cli_logger`.
    :param file_friendly_logging:
        Enable or disable file friendly logging. Defaults to the value of :data:`FILE_FRIENDLY_LOGGING`.
    """
    import multiprocessing as mp

    # Python 3.8+ exposes parent_process(); older versions fall back to
    # checking the canonical main-process name.
    is_main_process: bool
    if hasattr(mp, "parent_process"):
        is_main_process = mp.parent_process() is None  # type: ignore
    else:
        is_main_process = mp.current_process().name == "MainProcess"

    _initialize_logging(
        log_level=log_level,
        enable_cli_logs=enable_cli_logs,
        file_friendly_logging=file_friendly_logging,
        main_process=is_main_process,
    )
def initialize_worker_logging(worker_rank: Optional[int] = None):
    """
    Initialize logging in a worker thread/process.

    :param worker_rank:
        The rank/ID of the worker. ``None`` or ``-1`` means no per-worker
        prefix is added to log messages.
    """
    prefix: Optional[str] = None
    if worker_rank is not None and worker_rank != -1:
        prefix = f"rank {worker_rank}"
    return initialize_prefix_logging(prefix=prefix, main_process=False)
def initialize_prefix_logging(
    *, log_level: Optional[str] = None, prefix: Optional[str] = None, main_process: bool = False
):
    """
    Initialize logging with a prefix prepended to each log message.

    :param log_level:
        Can be one of "debug", "info", "warning", "error". Defaults to the value
        of :data:`TANGO_LOG_LEVEL`, if set, or "error".
    :param prefix:
        The string prefix to add to the log message.
    :param main_process:
        Whether this is being called from the main process (as opposed to a
        worker thread/process).
    """
    return _initialize_logging(log_level=log_level, prefix=prefix, main_process=main_process)
def _initialize_logging(
    *,
    log_level: Optional[str] = None,
    enable_cli_logs: Optional[bool] = None,
    file_friendly_logging: Optional[bool] = None,
    prefix: Optional[str] = None,
    main_process: bool = True,
):
    """
    Shared implementation behind :func:`initialize_logging`,
    :func:`initialize_worker_logging`, and :func:`initialize_prefix_logging`.

    In the main process this installs rich console handlers and starts a TCP
    socket server that receives log records from workers; in a child process it
    replaces all handlers with a single ``SocketHandler`` pointing back at the
    main process's server.
    """
    global FILE_FRIENDLY_LOGGING, TANGO_LOG_LEVEL, TANGO_CLI_LOGGER_ENABLED
    global _LOGGING_HOST, _LOGGING_PORT, _LOGGING_SERVER, _LOGGING_SERVER_THREAD, _LOGGING_PREFIX

    # Fall back to the environment-derived defaults, then to "warning".
    if log_level is None:
        log_level = TANGO_LOG_LEVEL
    if log_level is None:
        log_level = "warning"
    if file_friendly_logging is None:
        file_friendly_logging = FILE_FRIENDLY_LOGGING
    if enable_cli_logs is None:
        enable_cli_logs = TANGO_CLI_LOGGER_ENABLED
    # Combine any globally-configured prefix with the one passed in.
    if prefix:
        prefix = _LOGGING_PREFIX + " " + prefix if _LOGGING_PREFIX else prefix
    else:
        prefix = _LOGGING_PREFIX

    level = logging._nameToLevel[log_level.upper()]

    # Update global flags and corresponding environment variables, if necessary,
    # so that child processes can read the environment variables to determine the right
    # settings.
    TANGO_LOG_LEVEL = log_level
    os.environ[EnvVarNames.LOG_LEVEL.value] = log_level
    if file_friendly_logging is not None:
        FILE_FRIENDLY_LOGGING = file_friendly_logging
        os.environ[EnvVarNames.FILE_FRIENDLY_LOGGING.value] = str(file_friendly_logging).lower()
    if enable_cli_logs is not None:
        TANGO_CLI_LOGGER_ENABLED = enable_cli_logs
        os.environ[EnvVarNames.CLI_LOGGER_ENABLED.value] = str(enable_cli_logs).lower()

    from .tqdm import logger as tqdm_logger

    # Handle special cases for specific loggers:
    # These loggers emit too many messages, so we tell them to be quiet unless they have something
    # important to say.
    for loud_logger in {"filelock", "sqlitedict"}:
        logging.getLogger(loud_logger).setLevel(max(level, logging.WARNING))
    # We always want to see all CLI messages if we're running from the command line, and none otherwise.
    cli_logger.setLevel(logging.DEBUG)
    cli_logger.disabled = not enable_cli_logs
    # We also want to enable the tqdm logger so that the progress bar lines end up in the log file.
    tqdm_logger.setLevel(logging.DEBUG)

    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.handlers.clear()

    if main_process:
        # Create stdout and stderr handlers so that we can route DEBUG and INFO
        # messages to stdout, and WARNING and ERROR messages to stderr.
        stdout_handler = get_handler(level)
        stdout_handler.addFilter(LevelFilter(logging.INFO))
        stderr_handler = get_handler(max(level, logging.WARNING), stderr=True)
        stderr_handler.addFilter(LevelFilter(logging.CRITICAL, min_level=logging.WARNING))
        root_logger.addHandler(stdout_handler)
        root_logger.addHandler(stderr_handler)
        # Configure cli_logger so that if log level <= INFO, it will behave
        # like a regular logger, otherwise it prints directly to stdout.
        # One handler per level, so the time/level/path columns only appear
        # when the configured root level would have shown them anyway.
        cli_logger.handlers.clear()
        if enable_cli_logs:
            for handler_level in (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR):
                cli_handler = get_handler(
                    handler_level,
                    stderr=handler_level >= logging.WARNING,
                    enable_markup=True,
                    show_time=level <= handler_level,
                    show_level=(level <= handler_level) or handler_level >= logging.WARNING,
                    show_path=level <= handler_level,
                )
                cli_handler.addFilter(LevelFilter(handler_level))
                cli_logger.addHandler(cli_handler)
        # Add prefix.
        if prefix:
            for logger in (root_logger, cli_logger, tqdm_logger):
                for handler in logger.handlers:
                    handler.addFilter(PrefixLogFilter(prefix))
        # Main process: set formatter and handlers, initialize logging socket and server.
        # Set up logging socket to emit log records from worker processes/threads.
        # Inspired by:
        # https://docs.python.org/3.8/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network
        _LOGGING_SERVER = LogRecordSocketReceiver(_LOGGING_HOST, 0)
        # Port 0 above means the OS picked a free port; record the actual one
        # in the environment so child processes can connect back to it.
        _LOGGING_PORT = _LOGGING_SERVER.server_address[1]
        os.environ[EnvVarNames.LOGGING_PORT.value] = str(_LOGGING_PORT)
        _LOGGING_SERVER_THREAD = threading.Thread(
            target=_LOGGING_SERVER.serve_until_stopped, daemon=True
        )
        _LOGGING_SERVER_THREAD.start()
    else:
        # Child process: set handler and level, no need to set formatting since only raw log records
        # will be sent to the logging socket.
        if _LOGGING_PORT is None:
            raise ValueError(
                "missing logging socket configuration, "
                "did you forget to call 'initialize_logging()' from the main process?"
            )
        socket_handler = logging.handlers.SocketHandler(_LOGGING_HOST, _LOGGING_PORT)
        if prefix:
            socket_handler.addFilter(PrefixLogFilter(prefix))
        for logger in (root_logger, cli_logger, tqdm_logger):
            logger.handlers.clear()
            logger.addHandler(socket_handler)

    # Write uncaught exceptions to the logs.
    sys.excepthook = excepthook

    # Ensure warnings issued by the 'warnings' module will be redirected to the logging system.
    logging.captureWarnings(True)
def teardown_logging():
    """
    Cleanup any logging fixtures created from :func:`initialize_logging()`. Should
    be called at the end of your script.
    """
    global _LOGGING_SERVER, _LOGGING_SERVER_THREAD

    # Ask the socket receiver's polling loop to stop, then wait for its thread.
    if _LOGGING_SERVER is not None:
        _LOGGING_SERVER.abort = True
    if _LOGGING_SERVER_THREAD is not None:
        _LOGGING_SERVER_THREAD.join()
        _LOGGING_SERVER_THREAD = None
    _LOGGING_SERVER = None

    # Restore Python's default exception hook.
    sys.excepthook = sys.__excepthook__  # type: ignore[assignment]
@contextmanager
def insert_handlers(*handlers: logging.Handler) -> Generator[None, None, None]:
    """
    A context manager that temporarily attaches the given handlers to the root,
    CLI, and tqdm loggers, removing them again on exit. Exceptions raised inside
    the block (other than a few expected, user-initiated kinds) are logged with
    a traceback before being re-raised.
    """
    global _EXCEPTIONS_LOGGED

    from .tqdm import logger as tqdm_logger

    root_logger = logging.getLogger()
    target_loggers = (root_logger, cli_logger, tqdm_logger)
    for target in target_loggers:
        for handler in handlers:
            target.addHandler(handler)
    try:
        yield None
    except BaseException as e:
        # Don't need tracebacks for expected terminations.
        if not isinstance(e, (CliRunError, KeyboardInterrupt, SigTermReceived)):
            log_exception(e)
            _EXCEPTIONS_LOGGED.append(e)
        raise
    finally:
        for target in target_loggers:
            for handler in handlers:
                target.removeHandler(handler)
def file_handler(filepath: PathOrStr) -> ContextManager[None]:
    """
    A context manager that routes logs to a file by temporarily adding
    :class:`RichHandler` instances (writing to that file) to the root, CLI,
    and tqdm loggers.

    For example,

    .. code-block::

        from tango.common.logging import initialize_logging, file_handler, teardown_logging

        initialize_logging(log_level="info")

        logger = logging.getLogger()
        logger.info("Hi!")

        with file_handler("log.out"):
            logger.info("This message should also go into 'log.out'")

        teardown_logging()
    """

    @contextmanager
    def _file_handler() -> Generator[None, None, None]:
        # BUGFIX: the file is now opened inside a `with` block so the handle is
        # closed when the context exits -- previously it was opened eagerly and
        # never closed (a resource leak).
        with open(filepath, "w") as log_file:
            console = Console(
                color_system=None,
                file=log_file,
                force_terminal=False,
                width=TANGO_CONSOLE_WIDTH or 160,
            )
            handlers: List[logging.Handler] = []
            # One handler for CLI-logger records (markup enabled) and one for
            # everything else, so every record lands in the file exactly once.
            for is_cli_handler in (True, False):
                handler = RichHandler(
                    console=console,
                    markup=is_cli_handler,
                )
                handler.addFilter(CliFilter(filter_out=not is_cli_handler))
                handlers.append(handler)
            with insert_handlers(*handlers):
                yield None

    return _file_handler()
================================================
FILE: tango/common/params.py
================================================
import copy
import json
import logging
import os
import zlib
from collections import OrderedDict
from collections.abc import MutableMapping
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, TypeVar, Union
import yaml
from rjsonnet import evaluate_file, evaluate_snippet
from .aliases import PathOrStr
from .exceptions import ConfigurationError
from .util import could_be_class_name
logger = logging.getLogger(__name__)
def infer_and_cast(value: Any):
    """
    In some cases we'll be feeding params dicts to functions we don't own;
    for example, PyTorch optimizers. In that case we can't use ``pop_int``
    or similar to force casts (which means you can't specify ``int`` parameters
    using environment variables). This function takes something that looks JSON-like
    and recursively casts things that look like (bool, int, float) to (bool, int, float).
    """
    if isinstance(value, (int, float, bool)):
        # Already one of our desired types, so leave as is.
        return value
    if isinstance(value, list):
        return [infer_and_cast(item) for item in value]
    if isinstance(value, dict):
        return {key: infer_and_cast(item) for key, item in value.items()}
    if isinstance(value, str):
        # Strings that look like bools become bools (case-insensitive).
        lowered = value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        # Then try int, then float; fall back to the original string.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value
    raise ValueError(f"cannot infer type of {value}")
def _is_encodable(value: str) -> bool:
"""
We need to filter out environment variables that can't
be unicode-encoded to avoid a "surrogates not allowed"
error in jsonnet.
"""
# Idiomatically you'd like to not check the != b""
# but mypy doesn't like that.
return (value == "") or (value.encode("utf-8", "ignore") != b"")
def _environment_variables() -> Dict[str, str]:
    """
    A copy of ``os.environ`` with any non-encodable values filtered out
    (see :func:`_is_encodable`).
    """
    return {name: val for name, val in os.environ.items() if _is_encodable(val)}
T = TypeVar("T", dict, list)


def with_overrides(original: T, overrides_dict: Dict[str, Any], prefix: str = "") -> T:
    """
    Return a deep-ish copy of ``original`` (a dict or list) with ``overrides_dict`` applied.

    Override keys address nested values with dotted paths (e.g. ``"model.layers.0"``).
    A key that exactly names an existing position replaces that value wholesale,
    while a dotted key recurses into it; top-level keys not present in a dict
    ``original`` are added. A ``ValueError`` is raised if any override key is
    left unused.

    :param prefix: dotted path of ``original`` within the root object, used for
        error messages during recursion.
    """
    merged: T
    keys: Union[Iterable[str], Iterable[int]]
    if isinstance(original, list):
        merged = [None] * len(original)
        keys = range(len(original))
    elif isinstance(original, dict):
        merged = {}
        # Also visit top-level override keys that introduce brand-new entries.
        keys = chain(
            original.keys(), (k for k in overrides_dict if "." not in k and k not in original)
        )
    elif prefix:
        raise ValueError(
            f"overrides for '{prefix[:-1]}.*' expected list or dict in original, "
            f"found {type(original)} instead"
        )
    else:
        raise ValueError(f"expected list or dict, found {type(original)} instead")

    consumed: Set[str] = set()
    for key in keys:
        key_str = str(key)
        if key_str in overrides_dict:
            # Exact match: the override replaces this value entirely.
            merged[key] = copy.deepcopy(overrides_dict[key_str])
            consumed.add(key_str)
            continue
        # Collect dotted overrides that descend into this key.
        dotted = f"{key}."
        sub_overrides = {}
        for o_key in overrides_dict:
            if o_key.startswith(dotted):
                sub_overrides[o_key[len(dotted):]] = overrides_dict[o_key]
                consumed.add(o_key)
        if sub_overrides:
            merged[key] = with_overrides(original[key], sub_overrides, prefix=prefix + dotted)
        else:
            merged[key] = copy.deepcopy(original[key])

    unused = [prefix + key for key in set(overrides_dict.keys()) - consumed]
    if unused:
        raise ValueError(f"overrides dict contains unused keys: {unused}")
    return merged
def parse_overrides(
    serialized_overrides: str, ext_vars: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Parse a serialized (Jsonnet) overrides string into a dict. Environment
    variables are made available as Jsonnet external variables, with any
    explicitly-passed ``ext_vars`` taking priority. An empty string yields
    an empty dict.
    """
    if not serialized_overrides:
        return {}
    merged_vars = {**_environment_variables(), **(ext_vars or {})}
    return json.loads(evaluate_snippet("", serialized_overrides, ext_vars=merged_vars))
def _is_dict_free(obj: Any) -> bool:
"""
Returns False if obj is a dict, or if it's a list with an element that _has_dict.
"""
if isinstance(obj, dict):
return False
elif isinstance(obj, list):
return all(_is_dict_free(item) for item in obj)
else:
return True
def pop_choice(
    params: Dict[str, Any],
    key: str,
    choices: List[Any],
    default_to_first_choice: bool = False,
    history: str = "?.",
    allow_class_names: bool = True,
) -> Any:
    """
    Performs the same function as ``Params.pop_choice``, but is required in order to deal with
    places that the Params object is not welcome, such as inside Keras layers. See the docstring
    of that method for more detail on how this function works.

    This method adds a ``history`` parameter, in the off-chance that you know it, so that we can
    reproduce ``Params.pop_choice`` exactly. We default to using "?." if you don't know the
    history, so you'll have to fix that in the log if you want to actually recover the logged
    parameters.
    """
    # Delegate to Params so that validation and logging behave identically.
    return Params(params, history).pop_choice(
        key, choices, default_to_first_choice, allow_class_names=allow_class_names
    )
def _replace_none(params: Any) -> Any:
if isinstance(params, str) and params == "None":
return None
elif isinstance(params, (dict, Params)):
if isinstance(params, Params):
params = params.as_dict(quiet=True)
for key, value in params.items():
params[key] = _replace_none(value)
return params
elif isinstance(params, list):
return [_replace_none(value) for value in params]
return params
def remove_keys_from_params(params: "Params", keys: Optional[List[str]] = None):
    """
    Recursively remove the given keys from ``params`` (in place), descending
    into nested ``Params`` objects and into lists of ``Params``.

    :param params: the parameters to prune. Non-``Params`` values (the model
        could possibly be a string, for example) are left untouched.
    :param keys: the keys to remove; defaults to ``["pretrained_file", "initializer"]``.
    """
    # Use None as the default instead of a mutable list literal so the default
    # object can't be shared (and accidentally mutated) across calls.
    if keys is None:
        keys = ["pretrained_file", "initializer"]
    if isinstance(params, Params):  # The model could possibly be a string, for example.
        param_keys = params.keys()
        for key in keys:
            if key in param_keys:
                del params[key]
        for value in params.values():
            if isinstance(value, Params):
                remove_keys_from_params(value, keys)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, Params):
                        remove_keys_from_params(item, keys)
class Params(MutableMapping):
"""
A :class:`~collections.abc.MutableMapping` that represents a parameter dictionary with a history,
and contains other functionality around parameter passing and validation for AI2 Tango.
There are currently two benefits of a ``Params`` object over a plain dictionary for parameter
passing:
1. We handle a few kinds of parameter validation, including making sure that parameters
representing discrete choices actually have acceptable values, and making sure no extra
parameters are passed.
2. We log all parameter reads, including default values. This gives a more complete
specification of the actual parameters used than is given in a JSON file, because
those may not specify what default values were used, whereas this will log them.
.. important::
The convention for using a ``Params`` object in Tango is that you will consume the parameters
as you read them, so that there are none left when you've read everything you expect. This
lets us easily validate that you didn't pass in any ``extra`` parameters, just by making sure
that the parameter dictionary is empty. You should do this when you're done handling
parameters, by calling :meth:`Params.assert_empty()`.
"""
# This allows us to check for the presence of "None" as a default argument,
# which we require because we make a distinction between passing a value of "None"
# and passing no value to the default parameter of "pop".
DEFAULT = object()
def __init__(self, params: "MutableMapping[str, Any]", history: str = "") -> None:
if isinstance(params, Params):
self.params: MutableMapping = params.params
else:
self.params = _replace_none(params)
self.history = history
def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:
"""
Performs the functionality associated with ``dict.pop(key)``, along with checking for
returned dictionaries, replacing them with Param objects with an updated history
(unless keep_as_dict is True, in which case we leave them as dictionaries).
If ``key`` is not present in the dictionary, and no default was specified, we raise a
:class:`~tango.common.exceptions.ConfigurationError`, instead of the typical ``KeyError``.
"""
if default is self.DEFAULT:
try:
value = self.params.pop(key)
except KeyError:
msg = f'key "{key}" is required'
if self.history:
msg += f' at location "{self.history}"'
raise ConfigurationError(msg)
else:
value = self.params.pop(key, default)
logger.debug(f"{self.history}{key} = {value}")
if keep_as_dict or _is_dict_free(value):
return value
else:
return self._check_is_dict(key, value)
def pop_int(self, key: str, default: Any = DEFAULT) -> Optional[int]:
"""
Performs a pop and coerces to an int.
"""
value = self.pop(key, default)
if value is None:
return None
else:
return int(value)
def pop_float(self, key: str, default: Any = DEFAULT) -> Optional[float]:
"""
Performs a pop and coerces to a float.
"""
value = self.pop(key, default)
if value is None:
return None
else:
return float(value)
def pop_bool(self, key: str, default: Any = DEFAULT) -> Optional[bool]:
"""
Performs a pop and coerces to a bool.
"""
value = self.pop(key, default)
if value is None:
return None
elif isinstance(value, bool):
return value
elif value == "true":
return True
elif value == "false":
return False
else:
raise ValueError("Cannot convert variable to bool: " + value)
def get(self, key: str, default: Any = DEFAULT):
"""
Performs the functionality associated with ``dict.get(key)`` but also checks for returned
dicts and returns a ``Params`` object in their place with an updated history.
"""
default = None if default is self.DEFAULT else default
value = self.params.get(key, default)
return self._check_is_dict(key, value)
def pop_choice(
self,
key: str,
choices: List[Any],
default_to_first_choice: bool = False,
allow_class_names: bool = True,
) -> Any:
"""
Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
the given choices. Note that this ``pops`` the key from params, modifying the dictionary,
consistent with how parameters are processed in this codebase.
:param key:
Key to get the value from in the param dictionary
:param choices:
A list of valid options for values corresponding to ``key``. For example, if you're
specifying the type of encoder to use for some part of your model, the choices might be
the list of encoder classes we know about and can instantiate. If the value we find in
the param dictionary is not in ``choices``, we raise a
:class:`~tango.common.exceptions.ConfigurationError`, because
the user specified an invalid value in their parameter file.
:param default_to_first_choice:
If this is ``True``, we allow the ``key`` to not be present in the parameter
dictionary. If the key is not present, we will use the return as the value the first
choice in the ``choices`` list. If this is ``False``, we raise a
:class:`~tango.common.exceptions.ConfigurationError`, because
specifying the ``key`` is required (e.g., you ``have`` to
specify your model class when running an experiment, but you can feel free to use
default settings for encoders if you want).
:param allow_class_names:
If this is ``True``, then we allow unknown choices that look like fully-qualified class names.
This is to allow e.g. specifying a model type as ``my_library.my_model.MyModel``
and importing it on the fly. Our check for "looks like" is extremely lenient
and consists of checking that the value contains a '.'.
"""
default = choices[0] if default_to_first_choice else self.DEFAULT
value = self.pop(key, default)
ok_because_class_name = allow_class_names and could_be_class_name(value)
if value not in choices and not ok_because_class_name:
key_str = self.history + key
message = (
f"'{value}' not in acceptable choices for {key_str}: {choices}. "
"You should either use the --include-package flag to make sure the correct module "
"is loaded, or use a fully qualified class name in your config file like "
"""{"model": "my_module.models.MyModel"} to have it imported automatically."""
)
raise ConfigurationError(message)
return value
def as_dict(self, quiet: bool = False, infer_type_and_cast: bool = False):
"""
Sometimes we need to just represent the parameters as a dict, for instance when we pass
them to PyTorch code.
:param quiet:
Whether to log the parameters before returning them as a dict.
:param infer_type_and_cast:
If ``True``, we infer types and cast (e.g. things that look like floats to floats).
"""
if infer_type_and_cast:
params_as_dict = infer_and_cast(self.params)
else:
params_as_dict = self.params
if quiet:
return params_as_dict
def log_recursively(parameters, history):
for key, value in parameters.items():
if isinstance(value, dict):
new_local_history = history + key + "."
log_recursively(value, new_local_history)
else:
logger.debug(f"{history}{key} = {value}")
log_recursively(self.params, self.history)
return params_as_dict
def as_flat_dict(self) -> Dict[str, Any]:
"""
Returns the parameters of a flat dictionary from keys to values.
Nested structure is collapsed with periods.
"""
flat_params = {}
def recurse(parameters, path):
for key, value in parameters.items():
newpath = path + [key]
if isinstance(value, dict):
recurse(value, newpath)
else:
flat_params[".".join(newpath)] = value
recurse(self.params, [])
return flat_params
def duplicate(self) -> "Params":
"""
Uses ``copy.deepcopy()`` to create a duplicate (but fully distinct)
copy of these Params.
"""
return copy.deepcopy(self)
def assert_empty(self, name: str):
"""
Raises a :class:`~tango.common.exceptions.ConfigurationError` if ``self.params`` is not empty.
We take ``name`` as an argument so that the error message gives some idea of where an error
happened, if there was one. For example, ``name`` could be the name of the ``calling`` class
that got extra parameters (if there are any).
"""
if self.params:
raise ConfigurationError("Extra parameters passed to {}: {}".format(name, self.params))
    def __getitem__(self, key):
        # Mapping access; dict values come back wrapped as Params with an
        # extended history (see _check_is_dict). Missing keys raise KeyError
        # (not ConfigurationError, unlike pop()).
        if key in self.params:
            return self._check_is_dict(key, self.params[key])
        else:
            raise KeyError(str(key))

    def __setitem__(self, key, value):
        # Write straight through to the underlying mapping.
        self.params[key] = value

    def __delitem__(self, key):
        del self.params[key]

    def __iter__(self):
        # Iterating a Params yields its keys, like a dict.
        return iter(self.params)

    def __len__(self):
        return len(self.params)
def _check_is_dict(self, new_history, value):
if isinstance(value, dict):
new_history = self.history + new_history + "."
return Params(value, history=new_history)
if isinstance(value, list):
value = [self._check_is_dict(f"{new_history}.{i}", v) for i, v in enumerate(value)]
return value
@classmethod
def from_file(
    cls,
    params_file: PathOrStr,
    params_overrides: Union[str, Dict[str, Any]] = "",
    ext_vars: Optional[dict] = None,
) -> "Params":
    """
    Load a ``Params`` object from a configuration file.

    :param params_file:
        The path to the configuration file to load. Can be JSON, Jsonnet, or YAML.
    :param params_overrides:
        A dict of overrides that can be applied to final object.
        e.g. ``{"model.embedding_dim": 10}`` will change the value of "embedding_dim"
        within the "model" object of the config to 10. If you wanted to override the entire
        "model" object of the config, you could do ``{"model": {"type": "other_type", ...}}``.
    :param ext_vars:
        Our config files are Jsonnet, which allows specifying external variables
        for later substitution. Typically we substitute these using environment
        variables; however, you can also specify them here, in which case they
        take priority over environment variables.
        e.g. ``{"HOME_DIR": "/Users/allennlp/home"}``
    """
    if ext_vars is None:
        ext_vars = {}

    # redirect to cache, if necessary
    from cached_path import cached_path

    params_file: Path = Path(cached_path(params_file))
    if not params_file.is_file():
        raise FileNotFoundError(params_file)

    file_dict: Dict[str, Any]
    if params_file.suffix in {".yml", ".yaml"}:
        # YAML configs are parsed directly; no external-variable substitution happens here.
        with open(params_file) as f:
            file_dict = yaml.safe_load(f)
    else:
        # Fall back to JSON/Jsonnet.
        # Explicitly-passed ext_vars take priority over environment variables.
        ext_vars = {**_environment_variables(), **ext_vars}
        json_str = evaluate_file(params_file.name, str(params_file.parent), ext_vars=ext_vars)
        file_dict = json.loads(json_str)

    if isinstance(params_overrides, dict):
        # parse_overrides expects a JSON string, so serialize dict overrides first.
        params_overrides = json.dumps(params_overrides)
    overrides_dict = parse_overrides(params_overrides, ext_vars=ext_vars)

    if overrides_dict:
        param_dict = with_overrides(file_dict, overrides_dict)
    else:
        param_dict = file_dict

    return cls(param_dict)
def to_file(
    self, params_file: PathOrStr, preference_orders: Optional[List[List[str]]] = None
) -> None:
    """
    Write the params to file as indented JSON, ordering keys via
    :meth:`as_ordered_dict` with the given ``preference_orders``.
    """
    with open(params_file, "w") as handle:
        json.dump(self.as_ordered_dict(preference_orders), handle, indent=4)
def as_ordered_dict(self, preference_orders: Optional[List[List[str]]] = None) -> OrderedDict:
    """
    Returns an ``OrderedDict`` of ``Params`` from list of partial order preferences.

    :param preference_orders:
        ``preference_orders`` is list of partial preference orders. ["A", "B", "C"] means
        "A" > "B" > "C". For multiple preference_orders first will be considered first.
        Keys not found, will have last but alphabetical preference.
        When no orders are given, the single default preference ``["type"]`` is
        used, which sorts any ``type`` key first.
    """
    params_dict = self.as_dict(quiet=True)
    if not preference_orders:
        preference_orders = []
        preference_orders.append(["type"])

    def order_func(key):
        # Makes a tuple to use for ordering. The tuple is an index into each of the `preference_orders`,
        # followed by the key itself. This gives us integer sorting if you have a key in one of the
        # `preference_orders`, followed by alphabetical ordering if not.
        order_tuple = [
            order.index(key) if key in order else len(order) for order in preference_orders  # type: ignore
        ]
        return order_tuple + [key]

    def order_dict(dictionary, order_func):
        # Recursively orders dictionary according to scoring order_func
        result = OrderedDict()
        for key, val in sorted(dictionary.items(), key=lambda item: order_func(item[0])):
            result[key] = order_dict(val, order_func) if isinstance(val, dict) else val
        return result

    return order_dict(params_dict, order_func)
def get_hash(self) -> str:
    """
    Returns a hash code representing the current state of this ``Params`` object. We don't
    want to implement ``__hash__`` because that has deeper python implications (and this is a
    mutable object), but this will give you a representation of the current state.

    We use ``zlib.adler32`` instead of Python's builtin ``hash`` because the random seed for the
    latter is reset on each new program invocation, as discussed here:
    https://stackoverflow.com/questions/27954892/deterministic-hashing-in-python-3.
    """
    # Sorting keys makes the serialization — and therefore the hash — deterministic.
    serialized = json.dumps(self.params, sort_keys=True).encode()
    return str(zlib.adler32(serialized))
def __str__(self) -> str:
    # Prefix with the parameter history so nested Params are identifiable.
    return f"{self.history}Params({self.params})"
================================================
FILE: tango/common/registrable.py
================================================
"""
:class:`Registrable` is a "mixin" for endowing
any base class with a named registry for its subclasses and a decorator
for registering them.
"""
import importlib
import logging
from collections import defaultdict
from typing import (
Callable,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
cast,
)
from .exceptions import ConfigurationError, IntegrationMissingError, RegistryKeyError
from .from_params import FromParams
from .util import (
could_be_class_name,
find_integrations,
find_submodules,
import_module_and_submodules,
)
logger = logging.getLogger(__name__)
_T = TypeVar("_T")
_RegistrableT = TypeVar("_RegistrableT", bound="Registrable")
_SubclassRegistry = Dict[str, Tuple[type, Optional[str]]]
class Registrable(FromParams):
    """
    Any class that inherits from ``Registrable`` gains access to a named registry for its
    subclasses. To register them, just decorate them with the classmethod
    ``@BaseClass.register(name)``.

    After which you can call ``BaseClass.list_available()`` to get the keys for the
    registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass.

    Note that the registry stores the subclasses themselves; not class instances.
    In most cases you would then call :meth:`~tango.common.from_params.FromParams.from_params()`
    on the returned subclass.

    You can specify a default by setting ``BaseClass.default_implementation``.
    If it is set, it will be the first element of :meth:`list_available()`.

    Note that if you use this class to implement a new ``Registrable`` abstract class,
    you must ensure that all subclasses of the abstract class are loaded when the module is
    loaded, because the subclasses register themselves in their respective files. You can
    achieve this by having the abstract class and all subclasses in the ``__init__.py`` of the
    module in which they reside (as this causes any import of either the abstract class or
    a subclass to load all other subclasses and the abstract class).
    """

    # One sub-registry per Registrable base class, mapping a registered name to
    # a (subclass, constructor-method-name) pair.
    _registry: ClassVar[DefaultDict[type, _SubclassRegistry]] = defaultdict(dict)

    # Optional name of the registered subclass to treat as the default
    # (listed first by list_available()).
    default_implementation: Optional[str] = None

    @classmethod
    def register(
        cls, name: str, constructor: Optional[str] = None, exist_ok: bool = False
    ) -> Callable[[Type[_T]], Type[_T]]:
        """
        Register a class under a particular name.

        :param name:
            The name to register the class under.
        :param constructor:
            The name of the method to use on the class to construct the object. If this is given,
            we will use this method (which must be a ``@classmethod``) instead of the default
            constructor.
        :param exist_ok:
            If True, overwrites any existing models registered under ``name``. Else,
            throws an error if a model is already registered under ``name``.

        Examples
        --------

        To use this class, you would typically have a base class that inherits from ``Registrable``::

            class Vocabulary(Registrable):
                ...

        Then, if you want to register a subclass, you decorate it like this::

            @Vocabulary.register("my-vocabulary")
            class MyVocabulary(Vocabulary):
                def __init__(self, param1: int, param2: str):
                    ...

        Registering a class like this will let you instantiate a class from a config file, where you
        give ``"type": "my-vocabulary"``, and keys corresponding to the parameters of the ``__init__``
        method (note that for this to work, those parameters must have type annotations).

        If you want to have the instantiation from a config file call a method other than the
        constructor, either because you have several different construction paths that could be
        taken for the same object (as we do in ``Vocabulary``) or because you have logic you want to
        happen before you get to the constructor (as we do in ``Embedding``), you can register a
        specific ``@classmethod`` as the constructor to use, like this::

            @Vocabulary.register("my-vocabulary-from-instances", constructor="from_instances")
            @Vocabulary.register("my-vocabulary-from-files", constructor="from_files")
            class MyVocabulary(Vocabulary):
                def __init__(self, some_params):
                    ...

                @classmethod
                def from_instances(cls, some_other_params) -> MyVocabulary:
                    ...  # construct some_params from instances
                    return cls(some_params)

                @classmethod
                def from_files(cls, still_other_params) -> MyVocabulary:
                    ...  # construct some_params from files
                    return cls(some_params)
        """
        # The name "ref" is reserved for steps: in configs it denotes a reference
        # to another step, so a step registered under it could never be resolved.
        if _cls_is_step(cls) and name == "ref":
            raise ConfigurationError(
                "You cannot use the name 'ref' to register a step. This name is reserved."
            )

        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass: Type[_T]) -> Type[_T]:
            # Add to registry, raise an error if key has already been used.
            if name in registry:
                already_in_use_for = registry[name][0]
                if already_in_use_for.__module__ == "__main__":
                    # Sometimes the same class shows up under module.submodule.Class and __main__.Class, and we
                    # don't want to make a fuss in that case. We prefer the class without __main__, so we go
                    # ahead and overwrite the entry.
                    pass
                elif subclass.__module__ == "__main__":
                    # We don't want to overwrite the entry because the new one comes from the __main__ module.
                    return already_in_use_for
                elif exist_ok:
                    message = (
                        f"Registering {_fullname(subclass)} as a {_fullname(cls)} under the name {name} "
                        f"overwrites existing entry {_fullname(already_in_use_for)}, which is fine because "
                        "you said exist_ok=True."
                    )
                    logger.info(message)
                else:
                    message = (
                        f"Attempting to register {_fullname(subclass)} as a {_fullname(cls)} under the name "
                        f"'{name}' failed. {_fullname(already_in_use_for)} is already registered under that name."
                    )
                    raise ConfigurationError(message)
            registry[name] = (subclass, constructor)
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls: Type[_RegistrableT], name: str) -> Callable[..., _RegistrableT]:
        """
        Returns a callable function that constructs an argument of the registered class. Because
        you can register particular functions as constructors for specific names, this isn't
        necessarily the ``__init__`` method of some class.
        """
        logger.debug(f"instantiating registered subclass {name} of {cls}")
        subclass, constructor = cls.resolve_class_name(name)
        if not constructor:
            # No constructor method registered: the class itself is the callable.
            return cast(Type[_RegistrableT], subclass)
        else:
            return cast(Callable[..., _RegistrableT], getattr(subclass, constructor))

    @classmethod
    def search_modules(cls: Type[_RegistrableT], name: str):
        """
        Search for and import modules where ``name`` might be registered.
        """
        # Nothing to do for fully-qualified class names, names already
        # registered, or the reserved step name "ref".
        if (
            could_be_class_name(name)
            or name in Registrable._registry[cls]
            or (_cls_is_step(cls) and name == "ref")
        ):
            return None

        def try_import(module, recursive: bool = True):
            # Best-effort import: a missing integration or the module's own
            # ImportError is swallowed; unrelated import failures propagate.
            try:
                import_module_and_submodules(module, recursive=recursive)
            except IntegrationMissingError:
                pass
            except ImportError as e:
                if e.name != module:
                    raise

        # Map an integration's short name (e.g. "torch") to its full module path.
        integrations = {m.split(".")[-1]: m for m in find_integrations()}
        integrations_imported: Set[str] = set()
        if name in integrations:
            try_import(integrations[name], recursive=False)
            integrations_imported.add(name)
            if name in Registrable._registry[cls]:
                return None

        if "::" in name:
            # Try to guess the integration that it comes from.
            maybe_integration = name.split("::")[0]
            if maybe_integration in integrations:
                try_import(integrations[maybe_integration], recursive=False)
                integrations_imported.add(maybe_integration)
                if name in Registrable._registry[cls]:
                    return None

        # Check Python files and modules in the current directory.
        from glob import glob
        from pathlib import Path

        for pyfile in glob("*.py"):
            module = str(Path(pyfile).with_suffix(""))
            if module == "setup":
                continue
            try:
                try_import(module)
                if name in Registrable._registry[cls]:
                    return None
            except:  # noqa: E722
                continue

        for pyinit in glob("**/__init__.py"):
            module = str(Path(pyinit).parent)
            if module == "tango" or module.startswith("test"):
                continue
            try:
                try_import(module)
                if name in Registrable._registry[cls]:
                    return None
            except:  # noqa: E722
                continue

        # Search all other modules in Tango.
        for module in find_submodules(exclude={"tango.integrations*"}, recursive=False):
            try_import(module)
            if name in Registrable._registry[cls]:
                return None

        # Try importing all other integrations.
        for integration_name, module in integrations.items():
            if integration_name not in integrations_imported:
                try_import(module, recursive=False)
                integrations_imported.add(integration_name)
                if name in Registrable._registry[cls]:
                    return None

    @classmethod
    def resolve_class_name(
        cls: Type[_RegistrableT],
        name: str,
        search_modules: bool = True,
    ) -> Tuple[Type[_RegistrableT], Optional[str]]:
        """
        Returns the subclass that corresponds to the given ``name``, along with the name of the
        method that was registered as a constructor for that ``name``, if any.

        This method also allows ``name`` to be a fully-specified module name, instead of a name that
        was already added to the ``Registry``. In that case, you cannot use a separate function as
        a constructor (as you need to call ``cls.register()`` in order to tell us what separate
        function to use).

        If the ``name`` given is not in the registry and ``search_modules`` is ``True``,
        it will search for and import modules where the class might be defined according to
        :meth:`search_modules()`.
        """
        if name in Registrable._registry[cls]:
            subclass, constructor = Registrable._registry[cls][name]
            return subclass, constructor
        elif could_be_class_name(name):
            # This might be a fully qualified class name, so we'll try importing its "module"
            # and finding it there.
            parts = name.split(".")
            submodule = ".".join(parts[:-1])
            class_name = parts[-1]

            try:
                module = importlib.import_module(submodule)
            except ModuleNotFoundError:
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but unable to import module {submodule}"
                )

            try:
                subclass = getattr(module, class_name)
                constructor = None
                return subclass, constructor
            except AttributeError:
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but unable to find class {class_name} in {submodule}"
                )
        else:
            # is not a qualified class name
            if search_modules:
                # Importing candidate modules registers their classes as a side
                # effect; retry exactly once afterwards.
                cls.search_modules(name)
                return cls.resolve_class_name(name, search_modules=False)
            available = cls.list_available()
            suggestion = _get_suggestion(name, available)
            raise RegistryKeyError(
                (
                    f"'{name}' is not a registered name for '{cls.__name__}'"
                    + (". " if not suggestion else f", did you mean '{suggestion}'? ")
                )
                + "If your registered class comes from custom code, you'll need to import "
                "the corresponding modules. If you're using Tango or AllenNLP from the command-line, "
                "this is done by using the '--include-package' flag, or by specifying your imports "
                "in a 'tango.yml' settings file. "
                "Alternatively, you can specify your choices "
                """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
                "in which case they will be automatically imported correctly."
            )

    @classmethod
    def list_available(cls) -> List[str]:
        """List default first if it exists"""
        keys = list(Registrable._registry[cls].keys())
        default = cls.default_implementation

        if default is None:
            return keys

        if default not in keys:
            # The default might live in a module that hasn't been imported yet.
            cls.search_modules(default)
            keys = list(Registrable._registry[cls].keys())

        if default not in keys:
            raise ConfigurationError(f"Default implementation '{default}' is not registered")
        else:
            return [default] + [k for k in keys if k != default]
class RegistrableFunction(Registrable):
    """
    A registrable stand-in for a plain ``Callable``, so that functions can be
    referred to by a registered name in tango configurations.
    """

    # The plain function this class wraps; assigned by ``make_registrable``.
    WRAPPED_FUNC: ClassVar[Callable]

    def __call__(self, *args, **kwargs):
        # Delegate straight to the wrapped function.
        return type(self).WRAPPED_FUNC(*args, **kwargs)
def make_registrable(name: Optional[str] = None, *, exist_ok: bool = False):
    """
    A decorator to create a :class:`RegistrableFunction` from a function.

    :param name: A name to register the function under. By default the name of the function is used.
    :param exist_ok:
        If True, overwrites any existing function registered under the same ``name``. Else,
        throws an error if a function is already registered under ``name``.
    """

    def decorator(func):
        registered_name = name or func.__name__

        @RegistrableFunction.register(registered_name, exist_ok=exist_ok)
        class WrapperFunc(RegistrableFunction):
            WRAPPED_FUNC = func

        # Return an *instance*, so the result is directly callable like the
        # original function.
        return WrapperFunc()

    return decorator
def _get_suggestion(name: str, available: List[str]) -> Optional[str]:
# Check for simple mistakes like using '-' instead of '_', or vice-versa.
for ch, repl_ch in (("_", "-"), ("-", "_")):
suggestion = name.replace(ch, repl_ch)
if suggestion in available:
return suggestion
return None
def _fullname(c: type) -> str:
return f"{c.__module__}.{c.__qualname__}"
def _cls_is_step(c: type) -> bool:
    # Returns True when ``c`` is tango's Step class.
    # NOTE (epwalsh): importing the actual Step class here would result in a circular
    # import, even though the import wouldn't be at the top of the module (believe me, I've tried).
    # So instead we just check the fully qualified name of the class.
    return _fullname(c) == "tango.step.Step"
================================================
FILE: tango/common/remote_utils.py
================================================
import logging
from typing import Union
from tango.step import Step
from tango.step_info import StepInfo
logger = logging.getLogger(__name__)
class RemoteConstants:
    """
    Common constants to be used as prefixes and filenames in remote workspaces.
    """

    RUN_ARTIFACT_PREFIX: str = "tango-run-"
    RUN_DATA_FNAME: str = "run.json"
    STEP_ARTIFACT_PREFIX: str = "tango-step-"
    STEP_INFO_FNAME: str = "step_info.json"
    STEP_RESULT_DIR: str = "result"
    STEP_GRAPH_ARTIFACT_PREFIX: str = "tango-step-graph-"
    STEP_EXPERIMENT_PREFIX: str = "tango-step-"
    STEP_GRAPH_FILENAME: str = "config.json"
    GITHUB_TOKEN_SECRET_NAME: str = "TANGO_GITHUB_TOKEN"
    RESULTS_DIR: str = "/tango/output"
    INPUT_DIR: str = "/tango/input"
    LOCK_ARTIFACT_SUFFIX: str = "-lock"

    @classmethod
    def step_artifact_name(cls, step: Union[str, StepInfo, Step]) -> str:
        """Artifact name for a step, accepting either a unique-id string or an object carrying one."""
        unique_id = step if isinstance(step, str) else step.unique_id
        return f"{cls.STEP_ARTIFACT_PREFIX}{unique_id}"

    @classmethod
    def step_lock_artifact_name(cls, step: Union[str, StepInfo, Step]) -> str:
        """Name of the lock artifact guarding a step's artifact."""
        return cls.step_artifact_name(step) + cls.LOCK_ARTIFACT_SUFFIX

    @classmethod
    def run_artifact_name(cls, name: str) -> str:
        """Artifact name for a run with the given ``name``."""
        return cls.RUN_ARTIFACT_PREFIX + name
================================================
FILE: tango/common/sequences.py
================================================
import bisect
import os
import random
import shutil
from collections import abc
from os import PathLike
from typing import Any, Callable, Iterable, MutableSequence, Optional, Sequence, Union
class ShuffledSequence(abc.Sequence):
    """
    Produces a shuffled view of a sequence, such as a list.

    This assumes that the inner sequence never changes. If it does, the results
    are undefined.

    :param inner_sequence: the inner sequence that's being shuffled
    :param indices: Optionally, you can specify a list of indices here. If you don't, we'll just shuffle the
        inner sequence randomly. If you do specify indices, element ``n`` of the output sequence
        will be ``inner_sequence[indices[n]]``. This gives you great flexibility. You can repeat
        elements, leave them out completely, or slice the list. A Python :class:`slice` object is
        an acceptable input for this parameter, and so are other sequences from this module.

    Example:

    .. testcode::
        :hide:

        import random
        random.seed(42)

    .. testcode::

        from tango.common.sequences import ShuffledSequence

        l = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        shuffled_l = ShuffledSequence(l)

        print(shuffled_l[0])
        print(shuffled_l[1])
        print(shuffled_l[2])
        assert len(shuffled_l) == len(l)

    This will print something like the following:

    .. testoutput::

        4
        7
        8
    """

    def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
        self.inner = inner_sequence
        self.indices: Sequence[int]
        if indices is not None:
            # Caller supplied an explicit index mapping; use it verbatim.
            self.indices = indices
        else:
            # No mapping given: shuffle the full index range in place.
            self.indices = list(range(len(inner_sequence)))
            random.shuffle(self.indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: Union[int, slice]):
        if isinstance(i, int):
            return self.inner[self.indices[i]]
        # Slicing the view just slices the index mapping; no elements are copied.
        return ShuffledSequence(self.inner, self.indices[i])

    def __contains__(self, item) -> bool:
        # Only indices in the view count — elements sliced out are not "in" it.
        return any(self.inner[index] == item for index in self.indices)
class SlicedSequence(ShuffledSequence):
    """
    Produces a sequence that's a slice into another sequence, without copying the elements.

    This assumes that the inner sequence never changes. If it does, the results
    are undefined.

    :param inner_sequence: the inner sequence that's being shuffled
    :param s: the :class:`~slice` to slice the input with.

    Example:

    .. testcode::

        from tango.common.sequences import SlicedSequence

        l = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        sliced_l = SlicedSequence(l, slice(1, 4))

        print(sliced_l[0])
        print(sliced_l[1])
        print(sliced_l[2])
        assert len(sliced_l) == 3

    This will print the following:

    .. testoutput::

        2
        3
        4
    """

    def __init__(self, inner_sequence: Sequence, s: slice):
        # Normalize the slice against the sequence length and reuse the
        # index-mapping machinery of ShuffledSequence with a lazy range.
        start, stop, step = s.indices(len(inner_sequence))
        super().__init__(inner_sequence, range(start, stop, step))
class ConcatenatedSequence(abc.Sequence):
    """
    Produces a sequence that's the lazy concatenation of multiple other sequences. It does not copy
    any of the elements of the original sequences.

    This assumes that the inner sequences never change. If they do, the results are undefined.

    :param sequences: the inner sequences to concatenate

    Example:

    .. testcode::

        from tango.common.sequences import ConcatenatedSequence

        l1 = [1, 2, 3]
        l2 = [4, 5]
        l3 = [6]
        cat_l = ConcatenatedSequence(l1, l2, l3)

        assert len(cat_l) == 6
        for i in cat_l:
            print(i)

    This will print the following:

    .. testoutput::

        1
        2
        3
        4
        5
        6
    """

    def __init__(self, *sequences: Sequence):
        self.sequences = sequences
        # cumulative_sequence_lengths[k] is the combined length of the first k
        # sequences, so the total length is the last entry.
        offsets = [0]
        for sequence in sequences:
            offsets.append(offsets[-1] + len(sequence))
        self.cumulative_sequence_lengths = offsets

    def __len__(self):
        return self.cumulative_sequence_lengths[-1]

    def __getitem__(self, i: Union[int, slice]):
        if not isinstance(i, int):
            return SlicedSequence(self, i)
        index = i + len(self) if i < 0 else i
        if not 0 <= index < len(self):
            raise IndexError("list index out of range")
        # Binary-search the cumulative lengths to find which inner sequence
        # holds the element, then offset into it.
        sequence_index = bisect.bisect_right(self.cumulative_sequence_lengths, index) - 1
        return self.sequences[sequence_index][
            index - self.cumulative_sequence_lengths[sequence_index]
        ]

    def __contains__(self, item) -> bool:
        return any(item in sequence for sequence in self.sequences)
class MappedSequence(abc.Sequence):
    """
    Produces a sequence that applies a function to every element of another sequence.

    This is similar to Python's :func:`map`, but it returns a sequence instead of a :class:`map` object.

    :param fn: the function to apply to every element of the inner sequence. The function should take
        one argument.
    :param inner_sequence: the inner sequence to map over

    Example:

    .. testcode::

        from tango.common.sequences import MappedSequence

        def square(x):
            return x * x

        l = [1, 2, 3, 4]
        map_l = MappedSequence(square, l)

        assert len(map_l) == len(l)
        for i in map_l:
            print(i)

    This will print the following:

    .. testoutput::

        1
        4
        9
        16
    """

    def __init__(self, fn: Callable, inner_sequence: Sequence):
        self.inner = inner_sequence
        self.fn = fn

    def __getitem__(self, item):
        if not isinstance(item, slice):
            # Apply the mapping lazily, one element at a time.
            return self.fn(self.inner[item])
        sliced_inner = None
        try:
            # special case for a special library
            from datasets import Dataset

            if isinstance(self.inner, Dataset):
                sliced_inner = self.inner.select(range(*item.indices(len(self.inner))))
        except ImportError:
            pass
        if sliced_inner is None:
            sliced_inner = self.inner[item]
        return MappedSequence(self.fn, sliced_inner)

    def __len__(self):
        return len(self.inner)

    def __contains__(self, item):
        # Compares against *mapped* values, so the function runs per element.
        return any(element == item for element in self)
class SqliteSparseSequence(MutableSequence[Any]):
    """
    This is a sparse sequence that pickles elements to a Sqlite database.

    When you read from the sequence, elements are retrieved and unpickled lazily. That means creating/opening
    a sequence is very fast and does not depend on the length of the sequence.

    This is a "sparse sequence" because you can set element ``n`` before you set element ``n-1``:

    .. testcode::
        :hide:

        from tango.common.sequences import SqliteSparseSequence
        import tempfile
        dir = tempfile.TemporaryDirectory()
        from pathlib import Path
        filename = Path(dir.name) / "test.sqlite"

    .. testcode::

        s = SqliteSparseSequence(filename)
        element = "Big number, small database."
        s[2**32] = element
        assert len(s) == 2**32 + 1
        assert s[2**32] == element
        assert s[1000] is None
        s.close()

    .. testcode::
        :hide:

        dir.cleanup()

    You can use a ``SqliteSparseSequence`` from multiple processes at the same time. This is useful, for example,
    if you're filling out a sequence and you are partitioning ranges to processes.

    :param filename: the filename at which to store the data
    :param read_only: Set this to ``True`` if you only want to read.
    """

    def __init__(self, filename: Union[str, PathLike], read_only: bool = False):
        from sqlitedict import SqliteDict

        # Elements are stored under their stringified index; the logical length
        # lives separately under the reserved key "_len".
        self.table = SqliteDict(filename, "sparse_sequence", flag="r" if read_only else "c")

    def __del__(self):
        # Force-close so the underlying connection isn't leaked on garbage collection.
        if self.table is not None:
            self.table.close(force=True)
            self.table = None

    def __getitem__(self, i: Union[int, slice]) -> Any:
        if isinstance(i, int):
            try:
                return self.table[str(i)]
            except KeyError:
                current_length = len(self)
                if i >= current_length or current_length <= 0:
                    raise IndexError("list index out of range")
                elif i < 0 < current_length:
                    # Negative index: wrap around and look up again.
                    return self.__getitem__(i % current_length)
                else:
                    # In range but never assigned — the "sparse" case.
                    return None
        elif isinstance(i, slice):
            return SlicedSequence(self, i)
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def __setitem__(self, i: Union[int, slice], value: Any):
        if isinstance(i, int):
            current_length = len(self)
            if i < 0:
                i %= current_length
            self.table[str(i)] = value
            # Writing past the end implicitly grows the sequence.
            self.table["_len"] = max(i + 1, current_length)
            self.table.commit()
        else:
            raise TypeError(f"list indices must be integers, not {i.__class__.__name__}")

    def __delitem__(self, i: Union[int, slice]):
        current_length = len(self)
        if isinstance(i, int):
            if i < 0:
                i %= current_length
            if i >= current_length:
                raise IndexError("list assignment index out of range")
            # Shift every following element down by one, then drop the last slot.
            for index in range(i + 1, current_length):
                self.table[str(index - 1)] = self.table.get(str(index))
            del self.table[str(current_length - 1)]
            self.table["_len"] = current_length - 1
            self.table.commit()
        elif isinstance(i, slice):
            # This isn't very efficient for continuous slices.
            for index in reversed(range(*i.indices(current_length))):
                del self[index]
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def extend(self, values: Iterable[Any]) -> None:
        current_length = len(self)
        index = -1
        for index, value in enumerate(values):
            self.table[str(index + current_length)] = value
        if index < 0:
            # ``values`` was empty; nothing was written, so nothing to commit.
            return
        self.table["_len"] = current_length + index + 1
        self.table.commit()

    def insert(self, i: int, value: Any) -> None:
        current_length = len(self)
        # Shift elements at and after position i up by one to make room.
        for index in reversed(range(i, current_length)):
            self.table[str(index + 1)] = self.table.get(str(index))
        self.table[str(i)] = value
        self.table["_len"] = max(i + 1, current_length + 1)
        self.table.commit()

    def __len__(self) -> int:
        try:
            return self.table["_len"]
        except KeyError:
            # "_len" is only written on first mutation, so a fresh table is empty.
            return 0

    def clear(self) -> None:
        """
        Clears the entire sequence
        """
        self.table.clear()
        self.table.commit()

    def close(self) -> None:
        """
        Closes the underlying Sqlite table. Do not use this sequence afterwards!
        """
        if self.table is not None:
            self.table.close()
            self.table = None

    def copy_to(self, target: Union[str, PathLike]):
        """
        Make a copy of this sequence at a new location.

        :param target: the location of the copy

        This will attempt to make a hardlink, which is very fast, but only works on Linux and if ``target`` is
        on the same drive. If making a hardlink fails, it falls back to making a regular copy. As a result,
        there is no guarantee whether you will get a hardlink or a copy. If you get a hardlink, future edits
        in the source sequence will also appear in the target sequence. This is why we recommend to not use
        :meth:`copy_to()` until you are done with the sequence. This is not ideal, but it is a compromise we make
        for performance.
        """
        try:
            os.link(self.table.filename, target)
        except OSError as e:
            if e.errno == 18:  # Cross-device link
                shutil.copy(self.table.filename, target)
            else:
                raise
================================================
FILE: tango/common/testing/__init__.py
================================================
import logging
import os
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, cast
from tango.common.aliases import EnvVarNames, PathOrStr
from tango.common.logging import initialize_logging, teardown_logging
from tango.common.params import Params
from tango.settings import TangoGlobalSettings
class TangoTestCase:
    """
    A custom testing class that

    * disables some of the more verbose logging,
    * creates and destroys a temp directory as a test fixture, and
    * restores the internal state of the `Registrable` class at the end of each test method.
    """

    PROJECT_ROOT = (Path(__file__).parent / ".." / ".." / "..").resolve()
    """
    Root of the git repository.
    """

    # to run test suite with finished package, which does not contain
    # tests & fixtures, we must be able to look them up somewhere else
    PROJECT_ROOT_FALLBACK = (
        # users wanting to run test suite for installed package
        Path(os.environ["TANGO_SRC_DIR"])
        if "TANGO_SRC_DIR" in os.environ
        else (
            # fallback for conda packaging
            Path(os.environ["SRC_DIR"])
            if "CONDA_BUILD" in os.environ
            # stay in-tree
            else PROJECT_ROOT
        )
    )

    MODULE_ROOT = PROJECT_ROOT_FALLBACK / "tango"
    """
    Root of the tango module.
    """

    TESTS_ROOT = PROJECT_ROOT_FALLBACK / "tests"
    """
    Root of the tests directory.
    """

    FIXTURES_ROOT = PROJECT_ROOT_FALLBACK / "test_fixtures"
    """
    Root of the test fixtures directory.
    """

    def setup_method(self):
        # Pytest-style per-test setup: configure logging and create a scratch directory.
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG
        )
        # Disabling some of the more verbose logging statements that typically aren't very helpful
        # in tests.
        logging.getLogger("urllib3.connectionpool").disabled = True

        # Create a temporary scratch directory.
        self.TEST_DIR = Path(tempfile.mkdtemp(prefix="tango_tests"))
        os.makedirs(self.TEST_DIR, exist_ok=True)

        # Set an artificial console width so logs are not mangled.
        os.environ[EnvVarNames.CONSOLE_WIDTH.value] = str(300)

    def teardown_method(self):
        # Remove the scratch directory and undo the console-width override.
        shutil.rmtree(self.TEST_DIR)
        if EnvVarNames.CONSOLE_WIDTH.value in os.environ:
            del os.environ[EnvVarNames.CONSOLE_WIDTH.value]

    def run(
        self,
        config: Union[PathOrStr, Dict[str, Any], Params],
        overrides: Optional[Union[Dict[str, Any], str]] = None,
        include_package: Optional[List[str]] = None,
        workspace_url: Optional[str] = None,
        step_name: Optional[str] = None,
        parallelism: Optional[int] = 1,
        multicore: Optional[bool] = False,
        name: Optional[str] = None,
        settings: Optional[TangoGlobalSettings] = None,
    ) -> Path:
        """
        Run a tango experiment from ``config`` (a path, dict, or ``Params``)
        against a workspace inside this test case's scratch directory, and
        return the path to the resulting run directory.
        """
        from tango.__main__ import _run

        if isinstance(config, (dict, Params)):
            # _run takes a file path, so materialize in-memory configs first.
            params = config if isinstance(config, Params) else Params(config)
            config = self.TEST_DIR / "config.json"
            params.to_file(cast(Path, config))

        if isinstance(overrides, dict):
            # _run expects overrides as a JSON string.
            import json

            overrides = json.dumps(overrides)

        run_name = _run(
            settings or TangoGlobalSettings(),
            str(config),
            workspace_url=workspace_url or "local://" + str(self.TEST_DIR / "workspace"),
            overrides=overrides,
            include_package=include_package,
            step_names=None if not step_name else [step_name],
            parallelism=parallelism,
            multicore=multicore,
            name=name,
        )
        return self.TEST_DIR / "workspace" / "runs" / run_name
@contextmanager
def run_experiment(
    config: Union[PathOrStr, Dict[str, Any], Params],
    overrides: Optional[Union[Dict[str, Any], str]] = None,
    file_friendly_logging: bool = True,
    include_package: Optional[List[str]] = None,
    workspace_url: Optional[str] = None,
    parallelism: Optional[int] = 1,
    multicore: Optional[bool] = False,
    name: Optional[str] = None,
    settings: Optional[TangoGlobalSettings] = None,
):
    """
    A context manager to make testing experiments easier. On ``__enter__`` it runs
    the experiment and returns the path to the run directory, a temporary directory that will be
    cleaned up on ``__exit__``.
    """
    initialize_logging(enable_cli_logs=True, file_friendly_logging=file_friendly_logging)
    case = TangoTestCase()
    try:
        case.setup_method()
        run_dir = case.run(
            config,
            overrides=overrides,
            include_package=include_package,
            workspace_url=workspace_url,
            parallelism=parallelism,
            multicore=multicore,
            name=name,
            settings=settings,
        )
        yield run_dir
    finally:
        # Always clean up the scratch directory and logging state, even when
        # the experiment (or the caller's with-body) raises.
        case.teardown_method()
        teardown_logging()
def requires_gpus(test_method):
    """
    Decorator to indicate that a test requires multiple GPU devices.

    The test is marked ``gpu`` and skipped when fewer than two CUDA devices
    are available.
    """
    # Imported lazily so merely importing this module doesn't require pytest/torch.
    import pytest
    import torch

    skip_marker = pytest.mark.skipif(
        torch.cuda.device_count() < 2, reason="2 or more GPUs required."
    )
    return pytest.mark.gpu(skip_marker(test_method))
================================================
FILE: tango/common/testing/steps.py
================================================
import logging
import multiprocessing as mp
import random
import time
from string import ascii_letters
from typing import List
import tango.common.logging as common_logging
from tango import Step
from tango.common import Tqdm
@Step.register("float")
class FloatStep(Step):
    """Identity step: returns the float it is given."""

    DETERMINISTIC = True
    CACHEABLE = True

    def run(self, result: float) -> float:  # type: ignore
        return result
@Step.register("string")
class StringStep(Step):
    """Identity step: returns the string it is given."""

    DETERMINISTIC = True
    CACHEABLE = True

    def run(self, result: str) -> str:  # type: ignore
        return result
@Step.register("concat_strings")
class ConcatStringsStep(Step):
    """Joins two strings, placing ``join_with`` between them."""

    DETERMINISTIC = True
    CACHEABLE = True

    def run(self, string1: str, string2: str, join_with: str = " ") -> str:  # type: ignore
        return string1 + join_with + string2
@Step.register("noisy_step")
class NoisyStep(Step):
    """Emits a log record at every level (on both the step logger and the CLI
    logger), then optionally raises, to exercise logging/error paths in tests."""

    DETERMINISTIC = True
    CACHEABLE = True

    def run(self, raise_error: bool = False) -> None:  # type: ignore
        for level in ("debug", "info", "warning", "error"):
            getattr(self.logger, level)(f"{level} message")
            getattr(common_logging.cli_logger, level)(f"{level} message from cli_logger")
        if raise_error:
            raise ValueError("Oh no!")
@Step.register("random_string")
class RandomStringStep(Step):
    """Produces a random string of ``length`` ASCII letters (deliberately not deterministic)."""

    def run(self, length: int = 10) -> str:  # type: ignore
        return "".join(random.choice(ascii_letters) for _ in range(length))
@Step.register("add_numbers")
class AddNumbersStep(Step):
    """Returns the sum of its two integer inputs."""

    CACHEABLE = True
    DETERMINISTIC = True

    def run(self, a_number: int, b_number: int) -> int:  # type: ignore
        return a_number + b_number
@Step.register("sleep-print-maybe-fail")
class SleepPrintMaybeFail(Step):
    """Sleeps for ``seconds``, prints ``string``, optionally raises, then returns ``string``.

    Used to test long-running and failing steps.
    """

    CACHEABLE = True
    DETERMINISTIC = True

    def run(self, string: str, seconds: int = 5, fail: bool = False) -> str:  # type: ignore
        time.sleep(seconds)
        self.logger.info(f"Step {self.name} is awake.")
        print(string)
        if fail:
            raise RuntimeError("Step had to fail!")
        return string
@Step.register("logging-step")
class LoggingStep(Step):
    """Logs ``num_log_lines`` messages (with a progress bar) and returns ``string``."""

    CACHEABLE = True
    DETERMINISTIC = True

    def run(self, string: str, num_log_lines: int = 50) -> str:  # type: ignore
        line_numbers = list(range(num_log_lines))
        for i in Tqdm.tqdm(line_numbers, desc="log progress"):
            time.sleep(0.1)
            self.logger.info(f"{i} - {string}")
        return string
@Step.register("make_number")
class MakeNumber(Step):
    """Identity step: returns the integer it is given."""

    CACHEABLE = True
    DETERMINISTIC = True

    def run(self, what_number: int) -> int:  # type: ignore
        return what_number
@Step.register("store_number_in_file")
class StoreNumberInFile(Step):
    """Writes ``number`` to ``file_name`` as text.

    Note: this is only for testing whether the uncacheable step ran in the
    multicore setting. Normally, a step like this would be marked as CACHEABLE.
    """

    DETERMINISTIC = True
    CACHEABLE = False

    def run(self, number: int, file_name: str) -> None:  # type: ignore
        with open(file_name, "w") as out_file:
            out_file.write(str(number))
@Step.register("multiprocessing_step")
class MultiprocessingStep(Step):
    """
    Mainly used to test that logging works properly in child processes.
    """

    def run(self, num_proc: int = 2) -> bool:  # type: ignore
        # Show some progress from the parent process first.
        for _ in Tqdm.tqdm(list(range(10)), desc="progress from main process"):
            time.sleep(0.1)
        # Spawn `num_proc` workers, then wait for all of them to finish.
        workers = [mp.Process(target=_worker_function, args=(i,)) for i in range(num_proc)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        return True
@Step.register("range_step")
class RangeOutput(Step):
    """Returns ``[start, start + 1, ..., end - 1]``."""

    def run(self, start: int, end: int) -> List[int]:  # type: ignore
        return [*range(start, end)]
def _worker_function(worker_id: int):
    """Entry point for child processes spawned by ``MultiprocessingStep``."""
    # Hook the worker into the parent's logging setup before emitting anything.
    common_logging.initialize_worker_logging(worker_id)
    worker_logger = logging.getLogger(MultiprocessingStep.__name__)
    worker_logger.info("Hello from worker %d!", worker_id)
    common_logging.cli_logger.info("Hello from the cli logger in worker %d!", worker_id)
    # Only worker 0 shows a progress bar, so bars don't clobber each other.
    for _ in Tqdm.tqdm(list(range(10)), desc="progress from worker", disable=worker_id > 0):
        time.sleep(0.1)
================================================
FILE: tango/common/tqdm.py
================================================
"""
Copied over from ``allennlp.common.tqdm.Tqdm``.
Wraps tqdm so we can add configurable global defaults for certain tqdm parameters.
"""
import logging
import sys
from contextlib import contextmanager
from time import time
from typing import Optional
# Detect whether we're running inside a Jupyter notebook: `get_ipython` only
# exists in IPython environments, and the ZMQ shell type indicates a notebook
# kernel (in which case we use the notebook-friendly tqdm variant).
try:
    SHELL = str(type(get_ipython()))  # type:ignore # noqa: F821
except:  # noqa: E722
    SHELL = ""
if "zmqshell.ZMQInteractiveShell" in SHELL:
    from tqdm import tqdm_notebook as _tqdm
else:
    from tqdm import tqdm as _tqdm
from tango.common import logging as common_logging

# This is necessary to stop tqdm from hanging
# when exceptions are raised inside iterators.
# It should have been fixed in 4.2.1, but it still
# occurs.
# TODO(Mark): Remove this once tqdm cleans up after itself properly.
# https://github.com/tqdm/tqdm/issues/469
_tqdm.monitor_interval = 0

# Progress messages are mirrored to this logger by `TqdmToLogsWriter`; don't
# propagate to the root logger so they only go where we explicitly send them.
logger = logging.getLogger("tqdm")
logger.propagate = False
def replace_cr_with_newline(message: str) -> str:
    """
    TQDM and requests use carriage returns to get the training line to update for each batch
    without adding more lines to the terminal output. Displaying those in a file won't work
    correctly, so we'll just make sure that each batch shows up on its one line.
    """
    # Strip carriage returns, the extra new-line characters produced by nested
    # progress bars, and the special "[A" control sequence that tells the
    # terminal to move the cursor one line up.
    for control in ("\r", "\n", "[A"):
        message = message.replace(control, "")
    # Terminate non-empty messages with a single newline.
    if message and not message.endswith("\n"):
        message += "\n"
    return message
class TqdmToLogsWriter:
    """
    A file-like object handed to tqdm as its output stream. It forwards progress
    output to stderr (in a file-friendly form when configured) and periodically
    mirrors it to the "tqdm" logger.
    """

    def __init__(self):
        self.last_message_written_time = 0.0

    def write(self, message):
        file_friendly_message: Optional[str] = None
        if common_logging.FILE_FRIENDLY_LOGGING:
            file_friendly_message = replace_cr_with_newline(message)
            if file_friendly_message.strip():
                sys.stderr.write(file_friendly_message)
        else:
            sys.stderr.write(message)
        # Every 10 seconds (or on completion) we also log the message.
        now = time()
        if now - self.last_message_written_time >= 10 or "100%" in message:
            if file_friendly_message is None:
                file_friendly_message = replace_cr_with_newline(message)
            for line in file_friendly_message.split("\n"):
                line = line.strip()
                if line:
                    logger.info(line)
            self.last_message_written_time = now

    def flush(self):
        sys.stderr.flush()
class Tqdm:
    """
    A `tqdm `_ wrapper that respects
    :data:`~tango.common.logging.FILE_FRIENDLY_LOGGING` and other Tango logging configurations.
    """

    @staticmethod
    def tqdm(*args, **kwargs):
        return _tqdm(*args, **Tqdm.get_updated_kwargs(**kwargs))

    @staticmethod
    @contextmanager
    def wrapattr(*args, **kwargs):
        with _tqdm.wrapattr(*args, **Tqdm.get_updated_kwargs(**kwargs)) as wrapped:
            yield wrapped

    @staticmethod
    def get_updated_kwargs(**kwargs):
        # Defaults can be overridden by the caller's kwargs. Use a slower update
        # interval when FILE_FRIENDLY_LOGGING is set so log files stay small.
        defaults = {
            "file": TqdmToLogsWriter(),
            "mininterval": 2.0 if common_logging.FILE_FRIENDLY_LOGGING else 0.1,
        }
        defaults.update(kwargs)
        return defaults

    @staticmethod
    def set_lock(lock):
        _tqdm.set_lock(lock)

    @staticmethod
    def get_lock():
        return _tqdm.get_lock()
================================================
FILE: tango/common/util.py
================================================
import importlib
import pkgutil
import signal
import string
import sys
import traceback
from collections import OrderedDict
from dataclasses import asdict, is_dataclass
from datetime import datetime, tzinfo
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Optional, Set, Tuple, Union
import pytz
from .exceptions import SigTermReceived
def tango_cache_dir() -> Path:
    """
    Returns a directory suitable for caching things from Tango, defaulting
    to ``$HOME/.cache/tango``. The directory (and any missing parents) is
    created if it doesn't already exist.
    """
    path = Path.home() / ".cache" / "tango"
    path.mkdir(parents=True, exist_ok=True)
    return path
def _handle_sigterm(sig, frame):
    # Convert SIGTERM into a Python exception so `finally` blocks and context
    # managers get a chance to clean up before the process exits.
    raise SigTermReceived
def install_sigterm_handler():
    """Install a SIGTERM handler that raises :class:`SigTermReceived`."""
    signal.signal(signal.SIGTERM, _handle_sigterm)
# Names of modules that were imported on the user's behalf via
# `import_extra_module` (e.g. from the `--include-package` CLI option).
_extra_imported_modules: Set[str] = set()


def get_extra_imported_modules() -> Set[str]:
    """Return the set of module names imported via :func:`import_extra_module`."""
    return _extra_imported_modules
def import_extra_module(package_name: str) -> None:
    """Import ``package_name`` (with submodules) and record that we did so."""
    # No `global` declaration is needed: the module-level set is only mutated,
    # never rebound.
    import_module_and_submodules(package_name)
    _extra_imported_modules.add(package_name)
def resolve_module_name(package_name: str) -> Tuple[str, Path]:
    """
    Resolve a filesystem path (e.g. ``my/package/module.py``) into a dotted
    module name plus the base path that imports should be rooted at.
    Raises ``ValueError`` if the path doesn't exist or resolves to an empty name.
    """
    base_path = Path(".")
    package_path = Path(package_name)
    if not package_path.exists():
        raise ValueError(f"'{package_path}' looks like a path, but the path does not exist")
    # Walk up the directory tree: the first ancestor *without* an `__init__.py`
    # is outside the package, so it becomes the base path for importing.
    parent = package_path.parent
    while parent != parent.parent:
        if (parent / "__init__.py").is_file():
            parent = parent.parent
        else:
            base_path = parent
            break
    # Convert the path relative to the base into dotted-module form.
    package_name = str(package_path.relative_to(base_path)).replace("/", ".")
    if package_path.is_file():
        if package_path.name == "__init__.py":
            # If `__init__.py` file, resolve to the parent module.
            package_name = package_name[: -len(".__init__.py")]
        elif package_name.endswith(".py"):
            # Drop the ".py" suffix to get the module name.
            package_name = package_name[:-3]
    if not package_name:
        raise ValueError(f"invalid package path '{package_path}'")
    return package_name, base_path
def import_module_and_submodules(
    package_name: str, exclude: Optional[Set[str]] = None, recursive: bool = True
) -> None:
    """
    Import all submodules under the given package.
    Primarily useful so that people using tango can specify their own custom packages
    and have their custom classes get loaded and registered.

    ``package_name`` may also be a filesystem path, in which case it is resolved
    to a module name first. Modules whose names appear in ``exclude`` are skipped.
    """
    # If `package_name` is in the form of a path, convert to the module format.
    if "/" in package_name or package_name.endswith(".py"):
        package_name, base_path = resolve_module_name(package_name)
    else:
        base_path = Path(".")
    base_path = base_path.resolve()
    if exclude and package_name in exclude:
        return
    importlib.invalidate_caches()
    # Ensure `base_path` is first in `sys.path`.
    if str(base_path) not in sys.path:
        sys.path.insert(0, str(base_path))
    else:
        # Already on the path: move it to the front rather than duplicating it.
        sys.path.insert(0, sys.path.pop(sys.path.index(str(base_path))))
    # Certain packages might mess with sys.excepthook which we don't like since
    # we mess with sys.excepthook ourselves. If it looks like the package is overriding
    # the hook for a reason, we'll leave it be but also make sure our hook is still called.
    excepthook = sys.excepthook
    # Import at top level
    try:
        module = importlib.import_module(package_name)
    finally:
        if sys.excepthook != excepthook:
            if sys.excepthook.__module__.startswith("rich"):
                # We definitely don't want rich's traceback hook because that will mess
                # with our exception handling.
                sys.excepthook = excepthook
            else:
                new_hook = sys.excepthook

                def excepthook_wrapper(exctype, value, traceback):
                    # Run our original hook first, then the package's hook.
                    excepthook(exctype, value, traceback)
                    new_hook(exctype, value, traceback)

                sys.excepthook = excepthook_wrapper
    path = getattr(module, "__path__", [])
    path_string = "" if not path else path[0]
    if recursive:
        # walk_packages only finds immediate children, so need to recurse.
        for module_finder, name, _ in pkgutil.walk_packages(path):
            # Sometimes when you import third-party libraries that are on your path,
            # `pkgutil.walk_packages` returns those too, so we need to skip them.
            if path_string and module_finder.path != path_string:  # type: ignore[union-attr]
                continue
            subpackage = f"{package_name}.{name}"
            import_module_and_submodules(subpackage, exclude=exclude)
def _parse_bool(value: Union[bool, str]) -> bool:
if isinstance(value, bool):
return value
if value in {"1", "true", "True", "TRUE"}:
return True
return False
def _parse_optional_int(value: Optional[str]) -> Optional[int]:
if value is not None:
return int(value)
return None
def find_submodules(
    module: Optional[str] = None,
    match: Optional[Set[str]] = None,
    exclude: Optional[Set[str]] = None,
    recursive: bool = True,
) -> Iterable[str]:
    """
    Find tango submodules, yielding fully-qualified names like ``tango.common.util``.
    ``match`` and ``exclude`` are sets of ``fnmatch``-style patterns applied to
    each candidate name; ``exclude`` wins over ``match``.
    """
    from fnmatch import fnmatch

    # Start at the `tango` package directory (this file lives in tango/common/).
    root = Path(__file__).parent / ".."
    if not module:
        module = "tango"
    else:
        if module.startswith("tango."):
            module = module.replace("tango.", "", 1)
        for part in module.split("."):
            root = root / part
        module = f"tango.{module}"
    for child in root.iterdir():
        # Skip private modules and hidden files.
        if child.name[0] in {"_", "."}:
            continue
        if child.is_dir():
            stem = child.name
        elif child.suffix == ".py":
            stem = child.name[:-3]
        else:
            continue
        qualified = f"{module}.{stem}"
        if exclude and any(fnmatch(qualified, pattern) for pattern in exclude):
            continue
        if match and not any(fnmatch(qualified, pattern) for pattern in match):
            continue
        yield qualified
        if recursive and child.is_dir():
            yield from find_submodules(qualified, match=match, exclude=exclude)
def find_integrations() -> Iterable[str]:
    """
    Find all tango integration modules.
    Yields the fully-qualified name of each direct submodule of ``tango.integrations``.
    """
    yield from find_submodules("tango.integrations", recursive=False)
# Characters that are safe to use in filenames across platforms:
# ASCII letters, digits, hyphen, underscore, and dot.
SAFE_FILENAME_CHARS = frozenset("-_." + string.ascii_letters + string.digits)


def filename_is_safe(filename: str) -> bool:
    """Return ``True`` if ``filename`` consists only of safe filename characters."""
    return set(filename) <= SAFE_FILENAME_CHARS
def make_safe_filename(name: str) -> str:
    """
    Turn an arbitrary string into a safe filename. Already-safe names pass
    through unchanged; unsafe names are sanitized and suffixed with a short
    deterministic hash so distinct inputs stay distinguishable.
    """
    if filename_is_safe(name):
        return name
    from tango.common.det_hash import det_hash

    name_hash = det_hash(name)
    sanitized = name.replace(" ", "-").replace("/", "--")
    kept = "".join(c for c in sanitized if c in SAFE_FILENAME_CHARS)
    return f"{kept}-{name_hash[:7]}"
def could_be_class_name(name: str) -> bool:
if "." in name and not name.endswith("."):
return all([_is_valid_python_name(part) for part in name.split(".")])
else:
return False
def _is_valid_python_name(name: str) -> bool:
return bool(name and name[0].isalpha() and name.replace("_", "").isalnum())
def threaded_generator(g, queue_size: int = 16):
    """
    Puts the generating side of a generator into its own thread

    Let's say you have a generator that reads records from disk, and something that consumes the
    generator that spends most of its time in PyTorch. Wouldn't it be great if you could read more
    records while the PyTorch code runs? If you wrap your record-reading generator with
    ``threaded_generator(inner)``, that's exactly what happens. The reading code will run in a new thread,
    while the consuming code runs in the main thread as normal. ``threaded_generator()`` uses a queue
    to hand off items.

    :param queue_size: the maximum queue size for hand-offs between the main thread and the generator thread
    """
    from queue import Queue
    from threading import Thread

    handoff: Queue = Queue(maxsize=queue_size)
    end_of_stream = object()

    def producer():
        try:
            for item in g:
                handoff.put(item)
        finally:
            # Always signal completion, even when the generator raised.
            handoff.put(end_of_stream)

    producer_thread = Thread(name=repr(g), target=producer, daemon=True)
    producer_thread.start()
    yield from iter(handoff.get, end_of_stream)
    producer_thread.join()
def exception_to_string(e: BaseException) -> str:
    """
    Generates a string that contains an exception plus stack frames based on an exception.
    This became trivial in Python 3.10, but we need to run on Python 3.8 as well.
    """
    # Python >= 3.10 accepts the exception alone; older versions need the
    # (type, value, traceback) triple.
    if sys.version_info >= (3, 10):
        args = (e,)
    else:
        args = (type(e), e, e.__traceback__)
    return "".join(traceback.format_exception(*args))
def utc_now_datetime() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    # `datetime.utcnow()` is deprecated (Python 3.12+) and returns a *naive*
    # datetime that then needs `.replace(tzinfo=...)`; `datetime.now(tz=...)`
    # yields the same instant, already timezone-aware.
    return datetime.now(tz=pytz.utc)
def local_timezone() -> Optional[tzinfo]:
    """Return the system's local timezone (``None`` if it cannot be determined)."""
    return datetime.now().astimezone().tzinfo
def replace_steps_with_unique_id(o: Any):
    """
    Recursively replace every :class:`Step` (or :class:`StepIndexer`) found in ``o``
    with a JSON-friendly ``{"type": "ref", ...}`` dict referencing the step's unique id.
    Lists, tuples, sets, and dicts are rebuilt with replaced contents; any other
    value is returned unchanged.
    """
    from tango.step import Step, StepIndexer

    if isinstance(o, Step):
        return {"type": "ref", "ref": o.unique_id}
    if isinstance(o, StepIndexer):
        return {"type": "ref", "ref": o.step.unique_id, "key": o.key}
    if isinstance(o, (list, tuple, set)):
        return o.__class__(replace_steps_with_unique_id(item) for item in o)
    if isinstance(o, dict):
        return {key: replace_steps_with_unique_id(value) for key, value in o.items()}
    return o
def jsonify(o: Any) -> Any:
    """
    Transform an object into a JSON-serializable equivalent (if there is one)
    in a deterministic way. For example, tuples and sets are turned into lists,
    dictionaries are turned into ordered dictionaries where the order depends on the sorting
    of the keys, and datetimes are turned into formatted strings.
    """
    if isinstance(o, (list, tuple, set)):
        # FIX: lists were previously returned untouched, so sets/tuples/datetimes
        # nested inside a list escaped conversion; recurse into lists as well.
        return [jsonify(x) for x in o]
    elif isinstance(o, dict):
        # Sort by key so the output ordering is deterministic.
        return OrderedDict((k, jsonify(v)) for k, v in sorted(o.items(), key=lambda x: x[0]))
    elif isinstance(o, datetime):
        return o.strftime("%Y-%m-%dT%H:%M:%S")
    elif is_dataclass(o):
        return jsonify(asdict(o))
    elif isinstance(o, Path):
        return str(o)
    else:
        return o
class StrEnum(str, Enum):
    """
    An ``Enum`` whose members are also ``str`` instances and whose ``str()``
    form is the member's value (convenient for serialization and comparison).
    """

    def __str__(self) -> str:
        return self.value
================================================
FILE: tango/executor.py
================================================
import logging
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar
from rich import get_console
from rich.table import Table
from .common.logging import cli_logger, log_exception
from .common.registrable import Registrable
from .common.util import import_extra_module
from .step_graph import StepGraph
from .workspace import Workspace
if TYPE_CHECKING:
from .step import Step
logger = logging.getLogger(__name__)
T = TypeVar("T")
@dataclass
class ExecutionMetadata:
    """Locations (paths or URLs) associated with a single step's execution."""

    logs_location: Optional[str] = None
    """
    Path or URL to the logs for the step's execution.
    """
    result_location: Optional[str] = None
    """
    Path or URL to the result of the step's execution.
    """
@dataclass
class ExecutorOutput:
    """
    Describes the outcome of the execution.
    """

    successful: Dict[str, ExecutionMetadata] = field(default_factory=dict)
    """Steps which ran successfully or were found in the cache."""
    failed: Dict[str, ExecutionMetadata] = field(default_factory=dict)
    """Steps that failed."""
    not_run: Dict[str, ExecutionMetadata] = field(default_factory=dict)
    """Steps that were ignored (usually because of failed dependencies)."""

    def display(self) -> None:
        """Render a one-row-per-step summary table of outcomes via rich."""
        table = Table(caption_style="")
        table.add_column("Step Name", justify="left", style="cyan")
        table.add_column("Status", justify="left")
        table.add_column("Results", justify="left")
        # Merge every step we know about; the branch order below gives the
        # precedence failed > not_run > successful if a name appears twice.
        all_steps = dict(self.successful)
        all_steps.update(self.failed)
        all_steps.update(self.not_run)
        for step_name in sorted(all_steps):
            status_str: str
            result_str: str = "[grey62]N/A[/]"
            if step_name in self.failed:
                status_str = "[red]\N{ballot x} failed[/]"
                execution_metadata = self.failed[step_name]
                # For failures, point the user at the logs if we have them.
                if execution_metadata.logs_location is not None:
                    result_str = f"[cyan]{execution_metadata.logs_location}[/]"
            elif step_name in self.not_run:
                status_str = "[yellow]- not run[/]"
            elif step_name in self.successful:
                status_str = "[green]\N{check mark} succeeded[/]"
                execution_metadata = self.successful[step_name]
                # Prefer the result location, falling back to the logs.
                if execution_metadata.result_location is not None:
                    result_str = f"[cyan]{execution_metadata.result_location}[/]"
                elif execution_metadata.logs_location is not None:
                    result_str = f"[cyan]{execution_metadata.logs_location}[/]"
            else:
                continue
            table.add_row(step_name, status_str, result_str)
        # Build a one-line caption summarizing the counts.
        caption_parts: List[str] = []
        if self.failed:
            caption_parts.append(f"[red]\N{ballot x}[/] [italic]{len(self.failed)} failed[/]")
        if self.successful:
            caption_parts.append(
                f"[green]\N{check mark}[/] [italic]{len(self.successful)} succeeded[/]"
            )
        if self.not_run:
            caption_parts.append(f"[italic]{len(self.not_run)} not run[/]")
        table.caption = ", ".join(caption_parts)
        # Route the table to whichever logger is enabled, falling back to the console.
        if logger.isEnabledFor(logging.INFO):
            logger.info(table)
        elif cli_logger.isEnabledFor(logging.INFO):
            cli_logger.info(table)
        else:
            get_console().print(table)
class Executor(Registrable):
    """
    An ``Executor`` is a class that is responsible for running steps and caching their results.
    This is the base class and default implementation, registered as "default".

    .. note::
        The ``parallelism`` parameter has no effect with this default :class:`Executor`,
        but is part of the API because most subclass implementations allow configuring
        parallelism.
    """

    default_implementation = "default"

    def __init__(
        self,
        workspace: Workspace,
        include_package: Optional[Sequence[str]] = None,
        parallelism: Optional[int] = None,
    ) -> None:
        self.workspace = workspace
        self.include_package = include_package
        self.parallelism = parallelism

    def execute_step(self, step: "Step") -> None:
        """Run (or look up) a single step, caching its result when the step is cacheable."""
        # Import included packages to find registered components.
        for package_name in self.include_package or []:
            import_extra_module(package_name)
        if step.cache_results:
            step.ensure_result(self.workspace)
        else:
            step.result(self.workspace)

    def execute_step_graph(
        self, step_graph: StepGraph, run_name: Optional[str] = None
    ) -> ExecutorOutput:
        """
        Execute a :class:`~tango.step_graph.StepGraph`. This attempts to execute
        every step in order. If a step fails, its dependent steps are not run,
        but unrelated steps are still executed. Step failures will be logged, but
        no exceptions will be raised.
        """
        if self.parallelism is not None:
            warnings.warn(
                "The 'parallelism' parameter has no effect with the default Executor. "
                "If you want to run steps in parallel, consider using the MulticoreExecutor.",
                UserWarning,
            )
        outcome = ExecutorOutput()
        uncacheable_leaves = step_graph.uncacheable_leaf_steps()
        for step in step_graph.values():
            if not step.cache_results and step not in uncacheable_leaves:
                # An uncacheable step that another step depends on runs as part
                # of the downstream step's execution instead.
                continue
            if any(dep.name in outcome.failed for dep in step.recursive_dependencies):
                # A dependency failed, so this step can never succeed.
                outcome.not_run[step.name] = ExecutionMetadata()
                continue
            try:
                self.execute_step(step)
                outcome.successful[step.name] = ExecutionMetadata(
                    result_location=self.workspace.step_info(step).result_location
                )
            except Exception as exc:
                outcome.failed[step.name] = ExecutionMetadata()
                log_exception(exc, logger)
        return outcome

    # NOTE: The reason for having this method instead of just using `execute_step()` to run
    # a single step is that certain executors, such as the BeakerExecutor, need to
    # serialize steps somehow, and the easiest way to serialize a step is by serializing the
    # whole step config (which can be accessed via the step graph).
    def execute_sub_graph_for_steps(
        self, step_graph: StepGraph, *step_names: str, run_name: Optional[str] = None
    ) -> ExecutorOutput:
        """
        Execute the sub-graph associated with a particular step in a
        :class:`~tango.step_graph.StepGraph`.
        """
        return self.execute_step_graph(step_graph.sub_graph(*step_names), run_name=run_name)


Executor.register("default")(Executor)
================================================
FILE: tango/executors/__init__.py
================================================
"""
Built-in :class:`~tango.executor.Executor` implementations.
"""
from .multicore_executor import MulticoreExecutor
================================================
FILE: tango/executors/multicore_executor.py
================================================
import logging
import os
import subprocess
import time
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, OrderedDict, Sequence, Set, TypeVar
from tango.executor import ExecutionMetadata, Executor, ExecutorOutput
from tango.step import Step
from tango.step_graph import StepGraph
from tango.step_info import StepState
from tango.workspace import Workspace
logger = logging.getLogger(__name__)
T = TypeVar("T")
@Executor.register("multicore")
class MulticoreExecutor(Executor):
    """
    A ``MulticoreExecutor`` runs the steps in parallel and caches their results.
    """

    def __init__(
        self,
        workspace: Workspace,
        include_package: Optional[Sequence[str]] = None,
        parallelism: Optional[int] = 1,
        num_tries_to_sync_states: int = 3,
        wait_seconds_to_sync_states: int = 3,
    ) -> None:
        super().__init__(workspace, include_package=include_package, parallelism=parallelism or 1)
        assert self.parallelism is not None
        # A negative `parallelism` means "use as many processes as we have cores",
        # capped at 32.
        if self.parallelism < 0:
            self.parallelism = min(32, os.cpu_count() or 1)
        # Perhaps there's a better way to do this without these being passed as args.
        self._num_tries_to_sync_states = num_tries_to_sync_states
        self._wait_seconds_to_sync_states = wait_seconds_to_sync_states

    def execute_step_graph(
        self, step_graph: StepGraph, run_name: Optional[str] = None
    ) -> ExecutorOutput:
        """
        Execute a :class:`tango.step_graph.StepGraph`. This attempts to execute steps in parallel.
        If a step fails, its dependent steps are not run, but unrelated steps are still executed.
        Step failures will be logged, but no exceptions will be raised.
        """
        # Book-keeping for this run: live child processes keyed by step name,
        # the outcome maps, and names waiting in the queue.
        _running: OrderedDict[str, subprocess.Popen] = OrderedDict({})
        _successful: Dict[str, ExecutionMetadata] = {}
        _failed: Dict[str, ExecutionMetadata] = {}
        _queued_steps: List[str] = []
        uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps()

        def _sync_step_states() -> Dict[str, StepState]:
            """
            Update the StepState info.
            Although, this is not really elegant. The issue is as follows: The main multicore executor process
            queues a step, and starts step execution in a different process. If we try to read the StepState
            before that process has had time to update the StepState, the Workspace will throw the out of sync
            error (IOError: process should be running but it's considered incomplete...).
            Hence, we try to read a few times, so that the child process has time to update the step's state.
            """
            attempts = 0
            while attempts < self._num_tries_to_sync_states:
                attempts += 1
                try:
                    step_states = {step.name: self._get_state(step) for step in step_graph.values()}
                    break
                except IOError:
                    # Re-raise only after exhausting all retries.
                    if attempts == self._num_tries_to_sync_states:
                        raise
                    step_states = {}
                    time.sleep(self._wait_seconds_to_sync_states)
            return step_states

        def _has_incomplete_steps(step_states: Dict[str, StepState]) -> bool:
            """
            Are there any steps in the graph that are currently:
            1) running, or
            2) queued, or
            3) incomplete (with no failed dependencies).
            If there are any failed dependencies for a step, it will never manage to run.
            """

            def _failed_dependencies(step: Step) -> bool:
                # A dependency counts as failed if the workspace says so, or if
                # it failed within this run.
                for dependency in step.recursive_dependencies:
                    if (
                        step_states[dependency.name] == StepState.FAILED
                        or dependency.name in _failed
                    ):
                        return True
                return False

            uncacheable_leaf_step_names = {step.name for step in uncacheable_leaf_steps}
            for step_name, step_state in step_states.items():
                if (
                    step_name in _running
                    or step_name in _queued_steps
                    or (
                        # If the workspace already has a previous run, we disregard the failure.
                        step_state in [StepState.INCOMPLETE, StepState.FAILED]
                        and not _failed_dependencies(step_graph[step_name])
                        # We check for failures in this run.
                        and step_name not in _failed
                    )
                    or (
                        # Uncacheable leaf steps need to run, but their StepState will always be UNCACHEABLE.
                        step_name in uncacheable_leaf_step_names
                        and step_name not in _successful
                        and step_name not in _failed
                        and not _failed_dependencies(step_graph[step_name])
                    )
                ):
                    return True
            return False

        def _update_running_steps(step_states: Dict[str, StepState]) -> List[str]:
            """
            Check the running processes for status. We use poll_status to check if the process ended,
            but the StepState for checking completion/failure status, because after the process ends,
            the lock release etc. sometimes takes a beat longer.
            """
            done = []
            errors = []
            for step_name, process in _running.items():
                poll_status = process.poll()
                if poll_status is not None:
                    # The step may have finished since we synced step states.
                    if step_states[step_name] == StepState.RUNNING:
                        step_states[step_name] = self._get_state(step_graph[step_name])
                    if step_states[step_name] == StepState.UNCACHEABLE:
                        # Uncacheable steps never record completion in the workspace,
                        # so the process exit code is the only success signal.
                        if poll_status == 0:
                            done.append(step_name)
                        else:
                            errors.append(step_name)
                    elif step_states[step_name] == StepState.COMPLETED:
                        done.append(step_name)
                    elif (
                        step_states[step_name] == StepState.FAILED
                        or step_states[step_name] == StepState.INCOMPLETE
                    ):
                        # TODO: look into why the step status changes from running back to incomplete sometimes.
                        # Possibly it's due to the workspace being aggressive in marking it as incomplete when
                        # it thinks that the process is not running.
                        errors.append(step_name)
                    else:
                        raise RuntimeError(
                            f"Step '{step_name}' has the state {step_states[step_name]}, "
                            "but the corresponding process has ended!"
                        )
            # Remove finished processes from the running table, then record outcomes.
            for step_name in done + errors:
                _running.pop(step_name)
            for step_name in done:
                step = step_graph[step_name]
                _successful[step_name] = ExecutionMetadata(
                    result_location=None
                    if not step.cache_results
                    else self.workspace.step_info(step).result_location
                )
            for step_name in errors:
                _failed[step_name] = ExecutionMetadata()
            return errors

        def _get_steps_to_run(step_states: Dict[str, StepState]) -> Set[str]:
            """
            Returns the steps that can be queued to run immediately.
            Criteria:
            1) All dependencies are available.
            2) Step is not already running or queued.
            3) Step has not run in the past and failed.
            4) Step's state is INCOMPLETE (or FAILED from a previous run), or
               step's state is UNCACHEABLE and it is a leaf step.
               (We only run uncacheable steps if they are needed for another step downstream,
               as part of the downstream step).
            """

            def _are_dependencies_available(step: Step) -> bool:
                # A dependency is "available" when it completed and was cached,
                # or is uncacheable (it will re-run within the dependent step).
                for dependency in step.dependencies:
                    if step_states[dependency.name] not in [
                        StepState.COMPLETED,
                        StepState.UNCACHEABLE,
                    ]:
                        return False
                return True

            to_run: Set[str] = set()
            for step in step_graph.values():
                if (
                    _are_dependencies_available(step)
                    and step.name not in _running  # Not already running.
                    and step.name not in _queued_steps  # Not queued to run.
                    and step.name not in _failed  # Not already failed.
                    # See comment in _has_incomplete_steps
                    and (
                        step_states[step.name] in [StepState.INCOMPLETE, StepState.FAILED]
                        or (
                            step_states[step.name] == StepState.UNCACHEABLE
                            and step in uncacheable_leaf_steps
                            and step.name not in _successful
                        )
                    )
                ):
                    to_run.add(step.name)
            return to_run

        def _queue_step(step_name: str) -> None:
            _queued_steps.append(step_name)
            logger.debug(f"Step {step_name} added to the queue for execution.")

        def _try_to_execute_next_step(config_path: str, run_name: Optional[str] = None) -> None:
            """
            If there are queued steps, try to start processes for them (limited by `parallelism`).
            """
            if len(_queued_steps) == 0:
                logger.debug("No steps in queue!")
                return
            if len(_running) < (self.parallelism or 1):
                step_name = _queued_steps.pop(0)
                # Each step runs as its own `tango run -s <step>` subprocess
                # against the shared workspace.
                command: List[str] = [
                    "tango",
                    "--called-by-executor",
                    "run",
                    config_path,
                    "-s",
                    step_name,
                    "-w",
                    self.workspace.url,
                ]
                if self.include_package is not None:
                    for package in self.include_package:
                        command += ["-i", package]
                if run_name is not None:
                    command += ["-n", run_name]
                process = subprocess.Popen(command, shell=False)
                _running[step_name] = process
            else:
                logger.debug(
                    f"{self.parallelism or 1} steps are already running. Will attempt to execute later."
                )

        # Creates a temporary file in which to store the config. This is passed as a command line
        # argument to child step processes.
        with NamedTemporaryFile(prefix="step-graph-to-file-run", suffix=".jsonnet") as file_ref:
            step_graph.to_file(file_ref.name, include_unique_id=True)
            assert os.path.exists(file_ref.name)
            # Main scheduling loop: keep going while any step is running,
            # queued, or still eligible to run.
            step_states = _sync_step_states()
            while _has_incomplete_steps(step_states):
                # Cleanup previously running steps.
                _update_running_steps(step_states)
                # Get steps that are ready to run.
                to_run = _get_steps_to_run(step_states)
                if to_run:
                    logger.debug(f"Steps ready to run: {to_run}")
                    for step_name in to_run:
                        _queue_step(step_name)
                # Begin processes for any queued steps (if not enough processes are already running).
                while len(_queued_steps) > 0 and len(_running) < (self.parallelism or 1):
                    _try_to_execute_next_step(config_path=file_ref.name, run_name=run_name)
                # Re-sync the StepState info.
                step_states = _sync_step_states()
        assert not _running and not _queued_steps
        # Classify every step that wasn't executed directly by this run.
        _not_run: Dict[str, ExecutionMetadata] = {}
        for step_name, step in step_graph.items():
            if step_name in _successful or step_name in _failed:
                # tried to execute directly
                continue
            elif not step.cache_results and step not in uncacheable_leaf_steps:
                # uncacheable interior step; didn't execute directly.
                continue
            elif (
                step.cache_results
                and step_name in step_states
                and step_states[step_name] == StepState.COMPLETED
            ):
                # step result was found in cache.
                # NOTE: since neither `Step.result()` nor `Step.ensure_result()` will have been
                # called, we invoke the CLI logger here to let users know that we didn't run this
                # step because we found it in the cache.
                step.log_cache_hit()
                _successful[step_name] = ExecutionMetadata(
                    result_location=self.workspace.step_info(step_graph[step_name]).result_location
                )
            else:
                # step wasn't executed because parents failed, or
                # step is uncacheable leaf step, so we do care about what happened to it.
                _not_run[step_name] = ExecutionMetadata()
        return ExecutorOutput(successful=_successful, failed=_failed, not_run=_not_run)

    def _get_state(self, step: Step) -> StepState:
        """
        Returns the StepState as determined by the workspace.
        """
        return self.workspace.step_info(step).state
================================================
FILE: tango/format.py
================================================
import bz2
import dataclasses
import gzip
import importlib
import json
import logging
import lzma
from abc import abstractmethod
from os import PathLike
from pathlib import Path
from typing import (
IO,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
import dill
from tango.common import DatasetDict, filename_is_safe
from tango.common.aliases import PathOrStr
from tango.common.exceptions import ConfigurationError
from tango.common.registrable import Registrable
from tango.common.sequences import SqliteSparseSequence
T = TypeVar("T")
class Format(Registrable, Generic[T]):
    """
    A ``Format`` knows how to serialize objects to a directory and read them back out again.
    Within Tango, the objects written by formats are usually the result of a
    :class:`~tango.step.Step`.
    """
    VERSION: str = NotImplemented
    """
    Formats can have versions. Versions are part of a step's unique signature, part of
    :attr:`~tango.step.Step.unique_id`, so when a step's format changes,
    that will cause the step to be recomputed.
    """
    default_implementation = "dill"
    @abstractmethod
    def write(self, artifact: T, dir: PathOrStr):
        """Writes the ``artifact`` to the directory at ``dir``."""
        raise NotImplementedError()
    @abstractmethod
    def read(self, dir: PathOrStr) -> T:
        """Reads an artifact from the directory at ``dir`` and returns it."""
        raise NotImplementedError()
    def _to_params(self) -> Dict[str, Any]:
        # Start from the base-class params, then drop keys that shouldn't be serialized.
        params_dict = super()._to_params()
        params_dict.pop("logger", None)
        params_dict.pop("__orig_class__", None)
        # Record the fully-qualified class name so the format can be re-resolved later.
        params_dict["type"] = self.__module__ + "." + self.__class__.__qualname__
        return params_dict
# Maps a user-facing compression name (the ``compress`` argument accepted by the
# formats below) to the matching ``open``-style function. ``None`` and its string
# spellings all mean "no compression".
_OPEN_FUNCTIONS: Dict[Optional[str], Callable[[PathLike, str], IO]] = {
    None: open,
    "None": open,
    "none": open,
    "null": open,
    "gz": gzip.open, # type: ignore
    "gzip": gzip.open, # type: ignore
    "bz": bz2.open, # type: ignore
    "bz2": bz2.open, # type: ignore
    "bzip": bz2.open, # type: ignore
    "bzip2": bz2.open, # type: ignore
    "lzma": lzma.open,
}
# Maps each open function back to the file suffix its artifacts carry on disk.
_SUFFIXES: Dict[Callable, str] = {
    open: "",
    gzip.open: ".gz",
    bz2.open: ".bz2",
    lzma.open: ".xz",
}
def _open_compressed(filename: PathOrStr, mode: str) -> IO:
    """Open ``filename`` with the open function implied by its suffix (plain ``open`` otherwise)."""
    name = str(filename)
    open_fn: Callable = open
    for candidate, suffix in _SUFFIXES.items():
        # Only non-empty suffixes identify a compression scheme.
        if suffix and name.endswith(suffix):
            open_fn = candidate
            break
    return open_fn(name, mode)
@Format.register("dill")
class DillFormat(Format[T], Generic[T]):
    """
    This format writes the artifact as a single file called "data.dill" using dill
    (a drop-in replacement for pickle). Optionally, it can compress the data.
    This is very flexible, but not always the fastest.
    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.
    """
    VERSION = "001"
    def __init__(self, compress: Optional[str] = None):
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress
    def write(self, artifact: T, dir: PathOrStr):
        """Write ``artifact`` to ``dir`` as a (possibly compressed) dill file."""
        filename = self._get_artifact_path(dir)
        open_method = _OPEN_FUNCTIONS[self.compress]
        with open_method(filename, "wb") as f:
            pickler = dill.Pickler(file=f)
            # On-disk layout: VERSION, is-iterator flag, then the payload
            # (a stream of items for iterators, a single object otherwise).
            pickler.dump(self.VERSION)
            if hasattr(artifact, "__next__"):
                pickler.dump(True)
                for item in cast(Iterable, artifact):
                    pickler.dump(item)
            else:
                pickler.dump(False)
                pickler.dump(artifact)
    def read(self, dir: PathOrStr) -> T:
        """Read an artifact back from ``dir``.
        Returns a lazy :class:`DillFormatIterator` if the artifact was written from an iterator.
        :raises ValueError: If the file was written by a newer version of this format.
        """
        filename = self._get_artifact_path(dir)
        open_method = _OPEN_FUNCTIONS[self.compress]
        with open_method(filename, "rb") as f:
            unpickler = dill.Unpickler(file=f)
            version = unpickler.load()
            if version > self.VERSION:
                # Include the offending file in the message so users can locate it.
                raise ValueError(
                    f"File {filename} is too recent for this version of {self.__class__}."
                )
            iterator = unpickler.load()
            if iterator:
                return DillFormatIterator(filename)  # type: ignore
            else:
                return unpickler.load()
    def _get_artifact_path(self, dir: PathOrStr) -> Path:
        # The suffix depends on the configured compression (e.g. ".gz").
        return Path(dir) / ("data.dill" + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]])
class DillFormatIterator(Iterator[T], Generic[T]):
    """
    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.DillFormat.read`.
    """
    def __init__(self, filename: PathOrStr):
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rb")
        self.unpickler = dill.Unpickler(self.f)
        # Validate the header (version + is-iterator flag) written by DillFormat.write.
        version = self.unpickler.load()
        if version > DillFormat.VERSION:
            # Name the file in the message (it used to say "(unknown)") so users can locate it.
            raise ValueError(f"File {filename} is too recent for this version of {self.__class__}.")
        iterator = self.unpickler.load()
        if not iterator:
            raise ValueError(
                f"Tried to open {filename} as an iterator, but it does not store an iterator."
            )
    def __iter__(self) -> Iterator[T]:
        return self
    def __next__(self) -> T:
        """Return the next stored item, closing the file at end of stream."""
        if self.f is None:
            # Already exhausted on an earlier call.
            raise StopIteration()
        try:
            return self.unpickler.load()
        except EOFError:
            self.f.close()
            self.f = None
            raise StopIteration()
@Format.register("json")
class JsonFormat(Format[T], Generic[T]):
    """This format writes the artifact as a single file in json format.
    Optionally, it can compress the data. This is very flexible, but not always the fastest.
    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.
    """
    VERSION = "002"
    def __init__(self, compress: Optional[str] = None):
        self.logger = logging.getLogger(self.__class__.__name__)
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress
    @staticmethod
    def _encoding_fallback(unencodable: Any):
        """``default`` hook for ``json.dump``: handles 0-dim torch tensors and dataclasses."""
        try:
            import torch
            if isinstance(unencodable, torch.Tensor):
                if len(unencodable.shape) == 0:
                    # A scalar tensor serializes as its plain Python value.
                    return unencodable.item()
                else:
                    raise TypeError(
                        "Tensors must have 1 element and no dimensions to be JSON serializable."
                    )
        except ImportError:
            pass
        if dataclasses.is_dataclass(unencodable):
            result = dataclasses.asdict(unencodable)
            module = type(unencodable).__module__
            qualname = type(unencodable).__qualname__
            # Record where the dataclass lives so _decoding_fallback can rebuild it.
            if module == "builtins":
                result["_dataclass"] = qualname
            else:
                result["_dataclass"] = [module, qualname]
            return result
        raise TypeError(f"Object of type {type(unencodable)} is not JSON serializable")
    @staticmethod
    def _decoding_fallback(o: Dict) -> Any:
        """``object_hook`` for ``json.load``: re-instantiates dataclasses written by
        :meth:`_encoding_fallback`."""
        if "_dataclass" in o:
            classname: Union[str, List[str]] = o.pop("_dataclass")
            if isinstance(classname, list) and len(classname) == 2:
                # [module, qualified name]: import the module and walk to the class.
                module, classname = classname
                constructor: Callable = importlib.import_module(module)  # type: ignore
                for item in classname.split("."):
                    constructor = getattr(constructor, item)
            elif isinstance(classname, str):
                # A bare name refers to a class reachable from this module's globals.
                constructor = globals()[classname]
            else:
                raise RuntimeError(f"Could not parse {classname} as the name of a dataclass.")
            return constructor(**o)
        return o
    def write(self, artifact: T, dir: PathOrStr):
        """Write ``artifact`` to ``dir``; iterators are written as JSON lines."""
        open_method = _OPEN_FUNCTIONS[self.compress]
        if hasattr(artifact, "__next__"):
            filename = self._get_artifact_path(dir, iterator=True)
            with open_method(filename, "wt") as f:
                for item in cast(Iterable, artifact):
                    json.dump(item, f, default=self._encoding_fallback)
                    f.write("\n")
        else:
            filename = self._get_artifact_path(dir, iterator=False)
            with open_method(filename, "wt") as f:
                json.dump(artifact, f, default=self._encoding_fallback)
    def read(self, dir: PathOrStr) -> T:
        """Read an artifact back from ``dir``.
        Returns a lazy :class:`JsonFormatIterator` if only the JSON-lines file exists.
        :raises IOError: If neither the iterator nor the non-iterator file exists.
        """
        iterator_filename = self._get_artifact_path(dir, iterator=True)
        iterator_exists = iterator_filename.exists()
        non_iterator_filename = self._get_artifact_path(dir, iterator=False)
        non_iterator_exists = non_iterator_filename.exists()
        if iterator_exists and non_iterator_exists:
            self.logger.warning(
                "Both %s and %s exist. Ignoring %s.",
                iterator_filename,
                non_iterator_filename,
                iterator_filename,
            )
            iterator_exists = False
        if not iterator_exists and not non_iterator_exists:
            # BUG FIX: IOError does not %-format its arguments, so the old
            # two-argument form produced an unformatted message tuple.
            raise IOError(f"Attempting to read non-existing data from {dir}")
        if iterator_exists and not non_iterator_exists:
            return JsonFormatIterator(iterator_filename)  # type: ignore
        elif not iterator_exists and non_iterator_exists:
            open_method = _OPEN_FUNCTIONS[self.compress]
            with open_method(non_iterator_filename, "rt") as f:
                return json.load(f, object_hook=self._decoding_fallback)
        else:
            raise RuntimeError("This should be impossible.")
    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:
        # "data.jsonl" (JSON lines) for iterator artifacts, "data.json" otherwise,
        # plus the compression suffix.
        return Path(dir) / (
            ("data.jsonl" if iterator else "data.json") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        )
class JsonFormatIterator(Iterator[T], Generic[T]):
    """
    Lazy iterator returned by :meth:`tango.format.JsonFormat.read` for artifacts
    stored as JSON lines.
    """
    def __init__(self, filename: PathOrStr):
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rt")
    def __iter__(self) -> Iterator[T]:
        return self
    def __next__(self) -> T:
        if self.f is None:
            # Already exhausted on an earlier call.
            raise StopIteration()
        line = self.f.readline()
        if line:
            return json.loads(line, object_hook=JsonFormat._decoding_fallback)
        # An empty read means end-of-file: close the handle and end the iteration.
        self.f.close()
        self.f = None
        raise StopIteration()
@Format.register("text")
class TextFormat(Format[Union[str, Iterable[str]]]):
    """This format writes the artifact as a single file in text format.
    Optionally, it can compress the data. This is very flexible, but not always the fastest.
    This format can only write strings, or iterable of strings.
    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.
        Be aware that if your strings contain newlines, you will read out more strings than you wrote.
        For this reason, it's often advisable to use :class:`JsonFormat` instead. With :class:`JsonFormat`,
        all special characters are escaped, strings are quoted, but it's all still human-readable.
    """
    VERSION = "001"
    def __init__(self, compress: Optional[str] = None):
        self.logger = logging.getLogger(self.__class__.__name__)
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress
    def write(self, artifact: Union[str, Iterable[str]], dir: PathOrStr):
        """Write ``artifact`` to ``dir``; iterators are written one item per line."""
        open_method = _OPEN_FUNCTIONS[self.compress]
        if hasattr(artifact, "__next__"):
            filename = self._get_artifact_path(dir, iterator=True)
            with open_method(filename, "wt") as f:
                for item in cast(Iterable, artifact):
                    f.write(str(item))
                    f.write("\n")
        else:
            filename = self._get_artifact_path(dir, iterator=False)
            with open_method(filename, "wt") as f:
                f.write(str(artifact))
    def read(self, dir: PathOrStr) -> Union[str, Iterable[str]]:
        """Read an artifact back from ``dir``.
        Returns a lazy :class:`TextFormatIterator` if only the line-oriented file exists.
        :raises IOError: If neither the iterator nor the non-iterator file exists.
        """
        iterator_filename = self._get_artifact_path(dir, iterator=True)
        iterator_exists = iterator_filename.exists()
        non_iterator_filename = self._get_artifact_path(dir, iterator=False)
        non_iterator_exists = non_iterator_filename.exists()
        if iterator_exists and non_iterator_exists:
            self.logger.warning(
                "Both %s and %s exist. Ignoring %s.",
                iterator_filename,
                non_iterator_filename,
                iterator_filename,
            )
            iterator_exists = False
        if not iterator_exists and not non_iterator_exists:
            # BUG FIX: IOError does not %-format its arguments, so the old
            # two-argument form produced an unformatted message tuple.
            raise IOError(f"Attempting to read non-existing data from {dir}")
        if iterator_exists and not non_iterator_exists:
            return TextFormatIterator(iterator_filename)  # type: ignore
        elif not iterator_exists and non_iterator_exists:
            open_method = _OPEN_FUNCTIONS[self.compress]
            with open_method(non_iterator_filename, "rt") as f:
                return f.read()
        else:
            raise RuntimeError("This should be impossible.")
    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:
        # "texts.txt" (one string per line) vs "text.txt" (single blob), plus compression suffix.
        return Path(dir) / (
            ("texts.txt" if iterator else "text.txt") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        )
class TextFormatIterator(Iterator[str]):
    """
    Lazy iterator returned by :meth:`tango.format.TextFormat.read` for artifacts
    stored as one string per line.
    """
    def __init__(self, filename: PathOrStr):
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rt")
    def __iter__(self) -> Iterator[str]:
        return self
    def __next__(self) -> str:
        if self.f is None:
            # Already exhausted on an earlier call.
            raise StopIteration()
        line = self.f.readline()
        if not line:
            # An empty read means end-of-file: close the handle and end the iteration.
            self.f.close()
            self.f = None
            raise StopIteration()
        # Strip the single trailing newline that ``TextFormat.write`` appended.
        return line[:-1] if line.endswith("\n") else line
@Format.register("sqlite_sequence")
class SqliteSequenceFormat(Format[Sequence[T]]):
    """Stores a sequence as a single Sqlite database file, read back lazily."""
    VERSION = "003"
    FILENAME = "data.sqlite"
    def write(self, artifact: Sequence[T], dir: Union[str, PathLike]):
        target = Path(dir) / self.FILENAME
        # Remove any stale database left over from a previous run.
        try:
            target.unlink()
        except FileNotFoundError:
            pass
        if isinstance(artifact, SqliteSparseSequence):
            # Already sqlite-backed: just copy the underlying database.
            artifact.copy_to(target)
        else:
            SqliteSparseSequence(target).extend(artifact)
    def read(self, dir: Union[str, PathLike]) -> Sequence[T]:
        # Open read-only; items are only fetched from disk on access.
        return SqliteSparseSequence(Path(dir) / self.FILENAME, read_only=True)
@Format.register("sqlite")
class SqliteDictFormat(Format[DatasetDict]):
    """This format works specifically on results of type :class:`~tango.common.DatasetDict`. It writes those
    datasets into Sqlite databases.
    During reading, the advantage is that the dataset can be read lazily. Reading a result that is stored
    in :class:`SqliteDictFormat` takes milliseconds. No actual reading takes place until you access individual
    instances.
    During writing, you have to take some care to take advantage of the same trick. Recall that
    :class:`~tango.DatasetDict` is basically a map, mapping split names to lists of instances. If you ensure
    that those lists of instances are of type :class:`~tango.common.sequences.SqliteSparseSequence`, then writing
    the results in :class:`SqliteDictFormat` can in many cases be instantaneous.
    Here is an example of the pattern to use to make writing fast:
    .. code-block:: Python
        @Step.register("my_step")
        class MyStep(Step[DatasetDict]):
            FORMAT: Format = SqliteDictFormat()
            VERSION = "001"
            def run(self, ...) -> DatasetDict:
                result: Dict[str, Sequence] = {}
                for split_name in my_list_of_splits:
                    output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite")
                    for instance in instances:
                        output_split.append(instance)
                    result[split_name] = output_split
                metadata = {}
                return DatasetDict(result, metadata)
    Observe how for each split, we create a :class:`~tango.common.sequences.SqliteSparseSequence` in the step's
    work directory (accessible with :meth:`~tango.step.Step.work_dir`). This has the added advantage that if the
    step fails and you have to re-run it, the previous results that were already written to the
    :class:`~tango.common.sequences.SqliteSparseSequence` are still there. You could replace the inner ``for``
    loop like this to take advantage:
    .. code-block:: Python
        output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite")
        for instance in instances[len(output_split):]:  # <-- here is the difference
            output_split.append(instance)
        result[split_name] = output_split
    This works because when you re-run the step, the work directory will still be there, so ``output_split`` is
    not empty when you open it.
    """
    VERSION = "003"
    def write(self, artifact: DatasetDict, dir: Union[str, PathLike]):
        """Write ``artifact``'s metadata (as compressed dill) and one Sqlite file per split."""
        base = Path(dir)
        with gzip.open(base / "metadata.dill.gz", "wb") as f:
            dill.dump(artifact.metadata, f)
        for split_name, split in artifact.splits.items():
            filename = f"{split_name}.sqlite"
            # The split name becomes a filename, so it must be filesystem-safe.
            if not filename_is_safe(filename):
                raise ValueError(f"{split_name} is not a valid name for a split.")
            target = base / filename
            try:
                # Remove any stale database left over from a previous run.
                target.unlink()
            except FileNotFoundError:
                pass
            if isinstance(split, SqliteSparseSequence):
                # Already sqlite-backed: copying the database is (near) instantaneous.
                split.copy_to(target)
            else:
                SqliteSparseSequence(target).extend(split)
    def read(self, dir: Union[str, PathLike]) -> DatasetDict:
        """Read the metadata blob and lazily open every ``*.sqlite`` file as a split."""
        base = Path(dir)
        with gzip.open(base / "metadata.dill.gz", "rb") as f:
            metadata = dill.load(f)
        splits = {
            path.stem: SqliteSparseSequence(path, read_only=True)
            for path in base.glob("*.sqlite")
        }
        return DatasetDict(metadata=metadata, splits=splits)
================================================
FILE: tango/integrations/__init__.py
================================================
"""
In :mod:`tango.integrations` we provide many ready-to-use `component <../components/index.html>`_
implementations for leveraging the functionality from popular libraries.
.. tip::
All registered components will be registered under a name that starts with the name of the integration module,
possibly followed by a double colon ("::") and another identifier if there are multiple registered
components of a given type.
    For example, the :class:`~tango.integrations.datasets.LoadDataset` step in the
    `🤗 Datasets <https://huggingface.co/docs/datasets/>`_
    integration is registered under the name "datasets::load", and the
    :class:`~tango.integrations.torch.TorchFormat` format in the
    `PyTorch <https://pytorch.org/>`_ integration
    is registered under the name "torch".
"""
================================================
FILE: tango/integrations/beaker/__init__.py
================================================
"""
.. important::
To use this integration you should install ``tango`` with the "beaker" extra
    (e.g. ``pip install tango[beaker]``) or just install the
    `beaker-py <https://beaker-py.readthedocs.io/>`_
    library after the fact (e.g. ``pip install beaker-py``).
Components for Tango integration with `Beaker <https://beaker.org>`_.
"""
from tango.common.exceptions import IntegrationMissingError
try:
from beaker import Beaker
except (ModuleNotFoundError, ImportError):
raise IntegrationMissingError("beaker", dependencies={"beaker-py"})
from .executor import (
BeakerExecutor,
BeakerScheduler,
ResourceAssignment,
ResourceAssignmentError,
SimpleBeakerScheduler,
UnrecoverableResourceAssignmentError,
)
from .step_cache import BeakerStepCache
from .workspace import BeakerWorkspace
__all__ = [
"BeakerStepCache",
"BeakerWorkspace",
"BeakerExecutor",
"BeakerScheduler",
"SimpleBeakerScheduler",
"ResourceAssignment",
"ResourceAssignmentError",
"UnrecoverableResourceAssignmentError",
]
================================================
FILE: tango/integrations/beaker/common.py
================================================
import atexit
import json
import logging
import os.path
import tempfile
import time
import urllib
import urllib.parse
from pathlib import Path
from typing import Any, Dict, Optional, Union
from beaker import Beaker
from beaker import Dataset as BeakerDataset
from beaker import DatasetConflict, DatasetNotFound, Experiment, ExperimentNotFound
from tango.common.remote_utils import RemoteConstants
from tango.step import Step
from tango.step_info import StepInfo
from tango.version import VERSION
logger = logging.getLogger(__name__)
class Constants(RemoteConstants):
    """Beaker-specific constants, extending the shared remote-workspace constants."""
    # Prefix for the Beaker dataset that holds a step's entrypoint script.
    ENTRYPOINT_DATASET_PREFIX = "tango-entrypoint-"
    # Names of the Beaker secrets holding auth tokens for executor jobs.
    BEAKER_TOKEN_SECRET_NAME: str = "BEAKER_TOKEN"
    GOOGLE_TOKEN_SECRET_NAME: str = "GOOGLE_TOKEN"
    # Default location of Google application-default credentials on the local machine.
    DEFAULT_GOOGLE_CREDENTIALS_FILE: str = os.path.expanduser(
        os.path.join("~", ".config", "gcloud", "application_default_credentials.json")
    )
    # Directory and filename used for the entrypoint script in Beaker jobs
    # (presumably mounted into the container by the executor — see executor.py).
    ENTRYPOINT_DIR: str = "/tango/entrypoint"
    ENTRYPOINT_FILENAME: str = "entrypoint.sh"
def get_client(beaker_workspace: Optional[str] = None, **kwargs) -> Beaker:
    """Build a ``Beaker`` client from the environment, tagged with this tango version.
    :param beaker_workspace: If given, used as the client's default workspace.
    :param kwargs: Forwarded verbatim to :meth:`Beaker.from_env`.
    """
    user_agent = f"tango v{VERSION}"
    if beaker_workspace is None:
        return Beaker.from_env(session=True, user_agent=user_agent, **kwargs)
    return Beaker.from_env(
        default_workspace=beaker_workspace,
        session=True,
        user_agent=user_agent,
        **kwargs,
    )
def dataset_url(beaker: Beaker, dataset: Optional[str] = None) -> str:
    """Build a browser URL for the workspace's datasets page, optionally filtered to ``dataset``.
    This only constructs a string; it performs no API calls beyond fetching the workspace URL.
    """
    base = beaker.workspace.url()
    if not dataset:
        return base
    query = urllib.parse.urlencode({"text": dataset, "committed": "false"})
    return f"{base}/datasets?{query}"
class BeakerStepLock:
    """
    A distributed lock for a single step, implemented as an uncommitted Beaker dataset:
    creating the dataset acquires the lock (creation raises ``DatasetConflict`` if it
    already exists), and deleting the dataset releases it.
    """
    # Name of the file inside the lock dataset that records who holds the lock.
    METADATA_FNAME = "metadata.json"
    def __init__(
        self,
        beaker: Beaker,
        step: Union[str, StepInfo, Step],
        current_beaker_experiment: Optional[Experiment] = None,
    ):
        self._beaker = beaker
        self._step_id = step if isinstance(step, str) else step.unique_id
        self._lock_dataset_name = RemoteConstants.step_lock_artifact_name(step)
        # Non-None only while this instance holds the lock.
        self._lock_dataset: Optional[BeakerDataset] = None
        self._current_beaker_experiment = current_beaker_experiment
        self.lock_dataset_url = dataset_url(beaker, self._lock_dataset_name)
    @property
    def metadata(self) -> Dict[str, Any]:
        # Written into the lock dataset so other processes can tell which Beaker
        # experiment (if any) acquired the lock.
        return {
            "beaker_experiment": None
            if not self._current_beaker_experiment
            else self._current_beaker_experiment.id
        }
    def _last_metadata(self) -> Optional[Dict[str, Any]]:
        """Fetch the metadata stored in the existing lock dataset, or ``None`` if unavailable."""
        try:
            metadata_bytes = self._beaker.dataset.get_file(
                self._lock_dataset_name, self.METADATA_FNAME, quiet=True
            )
            metadata = json.loads(metadata_bytes)
            return metadata
        except (DatasetNotFound, FileNotFoundError):
            return None
    def _acquiring_job_is_done(self) -> bool:
        """
        Return ``True`` if the Beaker job that created the current lock dataset has
        finished (or was preempted/deleted), meaning the stale lock can be deleted.
        """
        last_metadata = self._last_metadata()
        if last_metadata is None:
            return False
        last_experiment_id = last_metadata.get("beaker_experiment")
        if last_experiment_id is None:
            return False
        try:
            last_experiment = self._beaker.experiment.get(last_experiment_id)
            if (
                self._current_beaker_experiment is not None
                and self._current_beaker_experiment.id == last_experiment_id
            ):
                # This means a previous job for this experiment was preempted and
                # it didn't clean up after itself.
                return True
            else:
                job = self._beaker.experiment.latest_job(last_experiment)
                return False if job is None else job.is_done
        except ExperimentNotFound:
            # Experiment must have been deleted.
            return True
        except ValueError:
            return False
    def acquire(self, timeout: Optional[float] = None, poll_interval: float = 2.0, log_interval: float = 30.0) -> None:
        """
        Block until the lock is acquired, polling every ``poll_interval`` seconds.
        :raises TimeoutError: If the lock can't be acquired within ``timeout`` seconds.
        """
        if self._lock_dataset is not None:
            # We already hold the lock.
            return
        start = time.monotonic()
        last_logged = None
        while timeout is None or (time.monotonic() - start < timeout):
            try:
                # Dataset creation fails with DatasetConflict if the lock
                # dataset already exists, i.e. someone else holds the lock.
                self._lock_dataset = self._beaker.dataset.create(
                    self._lock_dataset_name, commit=False
                )
                # Make sure the lock is released even if the process exits abnormally.
                atexit.register(self.release)
                # Write metadata.
                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    tmp_dir = Path(tmp_dir_name)
                    metadata_path = tmp_dir / self.METADATA_FNAME
                    with open(metadata_path, "w") as f:
                        json.dump(self.metadata, f)
                    self._beaker.dataset.sync(self._lock_dataset, metadata_path, quiet=True)
            except DatasetConflict:
                # Check if existing lock was created from a Beaker experiment.
                # If it was, and the experiment is no-longer running, we can safely
                # delete it.
                if self._acquiring_job_is_done():
                    self._beaker.dataset.delete(self._lock_dataset_name)
                    continue
                now = time.monotonic()
                # Only log a warning every ``log_interval`` seconds to avoid spamming.
                if last_logged is None or now - last_logged >= log_interval:
                    logger.warning(
                        "Waiting to acquire lock dataset for step '%s':\n\n%s\n\n"
                        "This probably means the step is being run elsewhere, but if you're sure it isn't "
                        "you can just delete the lock dataset.",
                        self._step_id,
                        self.lock_dataset_url,
                    )
                    last_logged = now
                time.sleep(poll_interval)
                continue
            else:
                # Lock acquired successfully.
                break
        else:
            # The while-loop condition expired without acquiring the lock.
            raise TimeoutError(
                f"Timeout error occurred while waiting to acquire dataset lock for step '{self._step_id}':\n\n"
                f"{self.lock_dataset_url}\n\n"
                f"This probably means the step is being run elsewhere, but if you're sure it isn't you can "
                f"just delete the lock dataset."
            )
    def release(self):
        """Release the lock by deleting the lock dataset (no-op if not held)."""
        if self._lock_dataset is not None:
            try:
                self._beaker.dataset.delete(self._lock_dataset)
            except DatasetNotFound:
                # Dataset must have been manually deleted.
                pass
            self._lock_dataset = None
            # No need to release again at interpreter exit.
            atexit.unregister(self.release)
    def __del__(self):
        # Best-effort release when the lock object is garbage collected.
        self.release()
================================================
FILE: tango/integrations/beaker/entrypoint.sh
================================================
#!/bin/bash
#
# This is the entrypoint script that the Beaker Executor uses when it runs a step
# on Beaker.
# It will work on any Docker image that has bash and conda / miniconda installed.

set -eo pipefail

# Ensure we have all the environment variables we need.
# We iterate over the variable *names* and use indirect expansion ("${!name}")
# so the error message can say exactly which variable is missing.
for env_var_name in GITHUB_TOKEN GITHUB_REPO GIT_REF; do
    if [[ -z "${!env_var_name}" ]]; then
        echo >&2 "error: required environment variable '$env_var_name' is empty"
        exit 1
    fi
done

# Initialize conda for bash.
# See https://stackoverflow.com/a/58081608/4151392
eval "$(command conda 'shell.bash' 'hook' 2> /dev/null)"

echo "
[TANGO] [1/3] Installing prerequisites...
"

# Install GitHub CLI.
if ! command -v gh &> /dev/null; then
    conda install gh --channel conda-forge
fi

# Configure git to use GitHub CLI as a credential helper so that we can clone private repos.
gh auth setup-git

echo "
[TANGO] [2/3] Cloning source code from '$GITHUB_REPO'...
"

# Clone the repo and checkout the target commit.
gh repo clone "$GITHUB_REPO" src
cd src
git checkout "$GIT_REF"

echo "
[TANGO] [3/3] Reconstructing Python env...
"

if [[ -z "$VENV_NAME" ]]; then
    VENV_NAME=venv
fi
if [[ -z "$CONDA_ENV_FILE" ]]; then
    # shellcheck disable=SC2296
    CONDA_ENV_FILE="environment.yml"
fi
if [[ -z "$PIP_REQUIREMENTS_FILE" ]]; then
    # shellcheck disable=SC2296
    PIP_REQUIREMENTS_FILE="requirements.txt"
fi

# "$VENV_NAME" is quoted so names with spaces don't word-split.
if conda activate "$VENV_NAME" &>/dev/null; then
    echo "[TANGO] Using existing conda environment '$VENV_NAME'"
    # The virtual environment already exists. Possibly update it based on an environment file.
    if [[ -f "$CONDA_ENV_FILE" ]]; then
        echo "[TANGO] Updating environment from conda env file '$CONDA_ENV_FILE'..."
        conda env update -f "$CONDA_ENV_FILE"
    fi
else
    # The virtual environment doesn't exist yet. Create it.
    if [[ -f "$CONDA_ENV_FILE" ]]; then
        # Create from the environment file.
        echo "[TANGO] Initializing environment from conda env file '$CONDA_ENV_FILE'..."
        conda env create -n "$VENV_NAME" -f "$CONDA_ENV_FILE"
    elif [[ -z "$PYTHON_VERSION" ]]; then
        # Create a new empty environment with whatever the default Python version is.
        echo "[TANGO] Initializing environment with default Python version..."
        conda create -n "$VENV_NAME" pip
    else
        # Create a new empty environment with the specific Python version.
        echo "[TANGO] Initializing environment with Python $PYTHON_VERSION..."
        conda create -n "$VENV_NAME" "python=$PYTHON_VERSION" pip
    fi
    conda activate "$VENV_NAME"
fi

# Every time Beaker changes their APIs, we need to upgrade beaker-py. This happens all the
# time, so we make sure we have the latest.
# We do this when the conda environment is up, but before the requirements, so that
# requirements can request a particular beaker-py version if they want.
pip install --upgrade beaker-py

if [[ -z "$INSTALL_CMD" ]]; then
    # Check for a 'requirements.txt' and/or 'setup.py/pyproject.toml/setup.cfg' file.
    # NOTE: log prefixes below were "[GANTRY]" (copy-paste); normalized to "[TANGO]".
    if ( [[ -f 'setup.py' ]] || [[ -f 'pyproject.toml' ]] || [[ -f 'setup.cfg' ]] ) && [[ -f "$PIP_REQUIREMENTS_FILE" ]]; then
        echo "[TANGO] Installing local project and packages from '$PIP_REQUIREMENTS_FILE'..."
        pip install . -r "$PIP_REQUIREMENTS_FILE"
    elif ( [[ -f 'setup.py' ]] || [[ -f 'pyproject.toml' ]] || [[ -f 'setup.cfg' ]] ); then
        echo "[TANGO] Installing local project..."
        pip install .
    elif [[ -f "$PIP_REQUIREMENTS_FILE" ]]; then
        echo "[TANGO] Installing packages from '$PIP_REQUIREMENTS_FILE'..."
        pip install -r "$PIP_REQUIREMENTS_FILE"
    fi
else
    echo "[TANGO] Installing packages with given command: $INSTALL_CMD"
    eval "$INSTALL_CMD"
fi

# Make the checked-out repo importable by the step subprocesses.
PYTHONPATH="$(pwd)"
export PYTHONPATH

echo "
Environment info:
"
echo "Using $(python --version) from $(which python)"
echo "Packages:"
if which sed >/dev/null; then
    pip freeze | sed 's/^/- /'
else
    pip freeze
fi

echo "
[TANGO] Setup complete ✓
"

# Execute the arguments to this script as commands themselves.
exec "$@"
================================================
FILE: tango/integrations/beaker/executor.py
================================================
import json
import logging
import os
import threading
import time
import uuid
import warnings
from abc import abstractmethod
from typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
from beaker import (
Beaker,
DataMount,
Dataset,
DatasetConflict,
DatasetNotFound,
Digest,
EnvVar,
Experiment,
ExperimentNotFound,
ExperimentSpec,
JobFailedError,
JobTimeoutError,
NodeResources,
Priority,
TaskResources,
TaskSpec,
TaskStoppedError,
)
from git import Git, GitCommandError, InvalidGitRepositoryError, Repo
from tango.common.exceptions import (
CancellationError,
ConfigurationError,
ExecutorError,
RunCancelled,
)
from tango.common.logging import cli_logger, log_exception
from tango.common.registrable import Registrable
from tango.executor import ExecutionMetadata, Executor, ExecutorOutput
from tango.step import Step
from tango.step_graph import StepGraph
from tango.step_info import GitMetadata
from tango.version import VERSION
from tango.workspace import Workspace
from .common import Constants, get_client
logger = logging.getLogger(__name__)
class StepFailedError(ExecutorError):
    """Raised when a step fails on Beaker; ``experiment_url`` points at the failed experiment."""
    def __init__(self, msg: str, experiment_url: str):
        self.experiment_url = experiment_url
        super().__init__(msg)
class ResourceAssignmentError(ExecutorError):
    """
    Raised when a scheduler can't find enough free resources at the moment to run a step.
    Contrast with :class:`UnrecoverableResourceAssignmentError`, which signals that the
    condition will never resolve.
    """
class UnrecoverableResourceAssignmentError(ExecutorError):
    """
    An unrecoverable version of :class:`ResourceAssignmentError`. Raising this
    from a :class:`BeakerScheduler` will cause the executor to fail.
    """
class ResourceAssignment(NamedTuple):
    """
    Resources assigned to a step by a :class:`BeakerScheduler`.
    """
    cluster: Union[str, List[str]]
    """
    The cluster(s) to use to execute the step.
    """
    resources: TaskResources
    """
    The compute resources on the cluster to allocate for execution of the step.
    """
    priority: Union[str, Priority]
    """
    The priority to execute the step with.
    """
class BeakerScheduler(Registrable):
    """
    A :class:`BeakerScheduler` is responsible for determining which resources and priority to
    assign to the execution of a step.
    """
    default_implementation = "simple"
    """
    The default implementation is :class:`SimpleBeakerScheduler`.
    """
    def __init__(self):
        # The Beaker client is injected via the ``beaker`` setter after construction.
        self._beaker: Optional[Beaker] = None
    @property
    def beaker(self) -> Beaker:
        """The :class:`~beaker.Beaker` client; must be assigned before scheduling.
        :raises ValueError: If accessed before a client has been assigned.
        """
        if self._beaker is None:
            # Fixed grammar of this message ("has not be" -> "has not been").
            raise ValueError("'beaker' client has not been assigned to scheduler yet!")
        return self._beaker
    @beaker.setter
    def beaker(self, beaker: Beaker) -> None:
        self._beaker = beaker
    @abstractmethod
    def schedule(self, step: Step) -> ResourceAssignment:
        """
        Determine the :class:`ResourceAssignment` for a step.
        :raises ResourceAssignmentError: If the scheduler can't find enough free
            resources at the moment to run the step.
        """
        raise NotImplementedError()
@BeakerScheduler.register("simple")
class SimpleBeakerScheduler(BeakerScheduler):
    """
    The :class:`SimpleBeakerScheduler` just searches the given clusters for one
    with enough resources to match what's specified by the step's required resources.
    """
    def __init__(self, clusters: List[str], priority: Union[str, Priority]):
        super().__init__()
        self.clusters = clusters
        self.priority = priority
        # Lazily-populated cache of per-cluster node resource limits.
        self._node_resources: Optional[Dict[str, List[NodeResources]]] = None
        if not self.clusters:
            raise ConfigurationError("At least one cluster is required in 'clusters'")
    @property
    def node_resources(self) -> Dict[str, List[NodeResources]]:
        """Per-cluster node resource limits, fetched from Beaker once and then cached."""
        if self._node_resources is None:
            self._node_resources = {
                cluster: [node.limits for node in self.beaker.cluster.nodes(cluster)]
                for cluster in self.clusters
            }
        return self._node_resources
    def schedule(self, step: Step) -> ResourceAssignment:
        """Translate the step's declared resources into a :class:`ResourceAssignment`."""
        step_resources = step.resources
        task_resources = TaskResources(
            cpu_count=step_resources.cpu_count,
            gpu_count=step_resources.gpu_count,
            memory=step_resources.memory,
            shared_memory=step_resources.shared_memory,
        )
        candidate_clusters = self.clusters
        if step_resources.gpu_type is not None:
            # Keep only clusters where *every* node has the requested GPU type.
            candidate_clusters = [
                cluster
                for cluster, nodes in self.node_resources.items()
                if all(node.gpu_type == step_resources.gpu_type for node in nodes)
            ]
            if not candidate_clusters:
                raise UnrecoverableResourceAssignmentError(
                    f"Could not find cluster with nodes that have GPU type '{step_resources.gpu_type}'"
                )
        return ResourceAssignment(
            cluster=candidate_clusters, resources=task_resources, priority=self.priority
        )
@Executor.register("beaker")
class BeakerExecutor(Executor):
    """
    This is a :class:`~tango.executor.Executor` that runs steps on `Beaker`_.
    Each step is run as its own Beaker experiment.

    .. tip::
        Registered as an :class:`~tango.executor.Executor` under the name "beaker".

    .. important::
        The :class:`BeakerExecutor` requires that you run Tango within a GitHub repository and you push
        all of your changes prior to each ``tango run`` call. It also requires that you have
        a `GitHub personal access token `_
        with at least the "repo" scope set to the environment variable ``GITHUB_TOKEN``
        (you can also set it using the ``github_token`` parameter, see below).

        This is because :class:`BeakerExecutor` has to be able to clone your code from Beaker.

    .. important::
        The :class:`BeakerExecutor` will try to recreate your Python environment on Beaker
        every time a step is run, so it's important that you specify all of your dependencies
        in a PIP ``requirements.txt`` file, ``setup.py`` file, or a conda ``environment.yml`` file.
        Alternatively you could provide the ``install_cmd`` argument.

    .. important::
        The :class:`BeakerExecutor` takes no responsibility for saving the results of steps that
        it runs on Beaker. That's the job of your workspace. So make sure you're using the
        right type of workspace or your results will be lost.

        For example, any "remote" workspace (like the :class:`BeakerWorkspace`) would work,
        or in some cases you could use a :class:`~tango.workspaces.LocalWorkspace` on an NFS drive.

    .. important::
        If you're running a step that requires special hardware, e.g. a GPU, you should
        specify that in the ``step_resources`` parameter to the step, or by overriding
        the step's :meth:`.resources() ` property method.

    :param workspace: The :class:`~tango.workspace.Workspace` to use.
    :param clusters: A list of Beaker clusters that the executor may use to run steps.
        If ``scheduler`` is specified, this argument is ignored.
    :param include_package: A list of Python packages to import before running steps.
    :param beaker_workspace: The name or ID of the Beaker workspace to use.
    :param github_token: You can use this parameter to set a GitHub personal access token instead of using
        the ``GITHUB_TOKEN`` environment variable.
    :param google_token: You can use this parameter to set a Google Cloud token instead of using
        the ``GOOGLE_TOKEN`` environment variable.
    :param beaker_image: The name or ID of a Beaker image to use for running steps on Beaker.
        The image must come with bash and `conda `_
        installed (Miniconda is okay).
        This is mutually exclusive with the ``docker_image`` parameter. If neither ``beaker_image``
        nor ``docker_image`` is specified, the :data:`DEFAULT_BEAKER_IMAGE` will be used.
    :param docker_image: The name of a publicly-available Docker image to use for running
        steps on Beaker. The image must come with bash and `conda `_
        installed (Miniconda is okay).
        This is mutually exclusive with the ``beaker_image`` parameter.
    :param datasets: External data sources to mount into the Beaker job for each step. You could use
        this to mount an NFS drive, for example.
    :param env_vars: Environment variables to set in the Beaker job for each step.
    :param venv_name: The name of the conda virtual environment to use or create on the image.
        If you're using your own image that already has a conda environment you want to use,
        you should set this variable to the name of that environment.
        You can also set this to "base" to use the base environment.
    :param parallelism: Control the maximum number of steps run in parallel on Beaker.
    :param install_cmd: Override the command used to install your code and its dependencies
        in each Beaker job.
        For example, you could set ``install_cmd="pip install .[dev]"``.
    :param priority: The default task priority to assign to jobs run on Beaker.
        If ``scheduler`` is specified, this argument is ignored.
    :param scheduler: A :class:`BeakerScheduler` to use for assigning resources to steps.
        If not specified the :class:`SimpleBeakerScheduler` is used with the given
        ``clusters`` and ``priority``.
    :param allow_dirty: By default, the Beaker Executor requires that your git working directory has no uncommitted
        changes. If you set this to ``True``, we skip this check.
    :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env() `.

    .. attention::
        Certain parameters should not be included in the :data:`~tango.settings.TangoGlobalSettings.executor`
        part of your ``tango.yml`` file, namely ``workspace`` and ``include_package``.
        Instead use the top-level :data:`~tango.settings.TangoGlobalSettings.workspace`
        and :data:`~tango.settings.TangoGlobalSettings.include_package` fields, respectively.

    :examples:

    **Minimal tango.yaml file**

    You can use this executor by specifying it in your ``tango.yml`` settings file:

    .. code:: yaml

        executor:
          type: beaker
          beaker_workspace: ai2/my-workspace
          clusters:
            - ai2/general-cirrascale

    **Using GPUs**

    If you have a step that requires a GPU, there are two things you need to do:

    1. First, you'll need to ensure that the :class:`BeakerExecutor` can install your dependencies the right way
       to support the GPU hardware. There are usually two ways to do this: use a Docker image that comes
       with a proper installation of your hardware-specific dependencies (e.g. PyTorch), or add a conda
       ``environment.yml`` file to your project that specifies the proper version of those dependencies.

       If you go with the first option you don't necessarily need to build your own Docker image.
       If PyTorch is the only hardware-specific dependency you have, you could just use
       one of AI2's pre-built PyTorch images. Just add these lines to your ``tango.yml`` file:

       .. code:: diff

            executor:
              type: beaker
              beaker_workspace: ai2/my-workspace
        +     docker_image: ghcr.io/allenai/pytorch:1.12.0-cuda11.3-python3.9
        +     venv_name: base
              clusters:
                - ai2/general-cirrascale

       The ``venv_name: base`` line tells the :class:`BeakerExecutor` to use the existing
       conda environment called "base" on the image instead of creating a new one.

       Alternatively, you could use the :data:`default image `
       and just add a conda ``environment.yml`` file to the root of your project
       that looks like this:

       .. code:: yaml

           name: torch-env
           channels:
             - pytorch
           dependencies:
             - python=3.9
             - cudatoolkit=11.3
             - numpy
             - pytorch
             - ...

    2. And second, you'll need to specify the GPUs required by each step in the config for that step under
       the :class:`step_resources ` parameter. For example,

       .. code:: json

           "steps": {
               "train": {
                   "type": "torch::train",
                   "step_resources": {
                       "gpu_count": 1
                   }
               }
           }

    """

    DEFAULT_BEAKER_IMAGE: str = "ai2/conda"
    """
    The default image. Used if neither ``beaker_image`` nor ``docker_image`` are set.
    """

    # If a LocalWorkspace lives under this path, the drive is auto-mounted into
    # each Beaker job so results stay reachable.
    DEFAULT_NFS_DRIVE = "/net/nfs.cirrascale"

    # Minimum number of seconds between "waiting on Beaker resources" warnings.
    RESOURCE_ASSIGNMENT_WARNING_INTERVAL = 60 * 5
def __init__(
    self,
    workspace: Workspace,
    clusters: Optional[List[str]] = None,
    include_package: Optional[Sequence[str]] = None,
    beaker_workspace: Optional[str] = None,
    github_token: Optional[str] = None,
    google_token: Optional[str] = None,
    beaker_image: Optional[str] = None,
    docker_image: Optional[str] = None,
    datasets: Optional[List[DataMount]] = None,
    env_vars: Optional[List[EnvVar]] = None,
    venv_name: Optional[str] = None,
    parallelism: Optional[int] = None,
    install_cmd: Optional[str] = None,
    priority: Optional[Union[str, Priority]] = None,
    allow_dirty: bool = False,
    scheduler: Optional[BeakerScheduler] = None,
    budget: Optional[str] = None,
    **kwargs,
):
    # Pre-validate arguments.
    if beaker_image is None and docker_image is None:
        beaker_image = self.DEFAULT_BEAKER_IMAGE
    elif (beaker_image is None) == (docker_image is None):
        # We only reach this branch when at least one image was given, so
        # equality of the two "is None" checks means *both* were given.
        raise ConfigurationError(
            "Either 'beaker_image' or 'docker_image' must be specified for BeakerExecutor, but not both."
        )

    if budget is None:
        raise ConfigurationError("You must specify a budget to use the beaker executor.")
    else:
        self._budget = budget

    from tango.workspaces import LocalWorkspace, MemoryWorkspace

    if isinstance(workspace, MemoryWorkspace):
        # The executor itself never saves step results, so an in-memory
        # workspace would silently lose everything.
        raise ConfigurationError(
            "You cannot use the `MemoryWorkspace` with the `BeakerExecutor`! "
            "Please specify a different workspace."
        )
    elif isinstance(workspace, LocalWorkspace):
        if str(workspace.dir).startswith(self.DEFAULT_NFS_DRIVE):
            # Mount the NFS drive if not mount already.
            datasets = datasets or []
            if not datasets or not any(
                [
                    dm.source.host_path is not None
                    and dm.source.host_path.startswith(self.DEFAULT_NFS_DRIVE)
                    for dm in datasets
                ]
            ):
                nfs_mount = DataMount.new(
                    self.DEFAULT_NFS_DRIVE, host_path=self.DEFAULT_NFS_DRIVE
                )
                datasets.append(nfs_mount)
        else:
            warnings.warn(
                "It appears that you're using a `LocalWorkspace` on a directory that is not an NFS drive. "
                "If the `BeakerExecutor` cannot access this directory from Beaker, your results will be lost.",
                UserWarning,
            )

    super().__init__(workspace, include_package=include_package, parallelism=parallelism)

    # min(32, cpus + 4) matches ThreadPoolExecutor's default worker count.
    self.max_thread_workers = self.parallelism or min(32, (os.cpu_count() or 1) + 4)
    self.beaker = get_client(beaker_workspace=beaker_workspace, **kwargs)
    self.beaker_image = beaker_image
    self.docker_image = docker_image
    self.datasets = datasets
    self.env_vars = env_vars
    self.venv_name = venv_name
    self.install_cmd = install_cmd
    self.allow_dirty = allow_dirty

    self.scheduler: BeakerScheduler
    if scheduler is None:
        if clusters is None:
            raise ConfigurationError(
                "Either 'scheduler' or 'clusters' argument to BeakerExecutor is required"
            )
        self.scheduler = SimpleBeakerScheduler(clusters, priority=priority or Priority.normal)
    else:
        if clusters is not None:
            warnings.warn(
                "The 'clusters' parameter will be ignored since you specified a 'scheduler'",
                UserWarning,
            )
        if priority is not None:
            warnings.warn(
                "The 'priority' parameter will be ignored since you specified a 'scheduler'",
                UserWarning,
            )
        self.scheduler = scheduler
    # Give the scheduler access to our Beaker client.
    self.scheduler.beaker = self.beaker

    self._is_cancelled = threading.Event()
    self._logged_git_info = False
    self._last_resource_assignment_warning: Optional[float] = None
    # Count of Beaker jobs currently in flight; used to scale polling intervals.
    self._jobs = 0

    try:
        self.github_token: str = github_token or os.environ["GITHUB_TOKEN"]
    except KeyError:
        raise ConfigurationError(
            "A GitHub personal access token with the 'repo' scope is required. "
            "This can be set with the 'github_token' argument to the BeakerExecutor, "
            "or as the environment variable 'GITHUB_TOKEN'."
        )

    self.google_token = google_token or os.environ.get("GOOGLE_TOKEN")
    # Check if google auth credentials are in the default location
    if self.google_token is None and os.path.exists(Constants.DEFAULT_GOOGLE_CREDENTIALS_FILE):
        self.google_token = Constants.DEFAULT_GOOGLE_CREDENTIALS_FILE

    # If credentials are provided in the form of a file path, load the credentials
    # so that they can be used in beaker. Do this only if required, i.e., only if GSWorkspace
    # is being used.
    if self.google_token is not None and self.google_token.endswith(".json"):
        from tango.integrations.gs import GSWorkspace

        if isinstance(workspace, GSWorkspace):
            with open(self.google_token) as f:
                self.google_token = f.read()

    if self.google_token is None:
        self.google_token = "default"

    # Ensure entrypoint dataset exists.
    self._ensure_entrypoint_dataset()

    # Get repo info and make sure we're in a GitHub-hosted repository.
    git = GitMetadata.check_for_repo()
    if (
        git is None
        or git.commit is None
        or git.remote is None
        or "github.com" not in git.remote
    ):
        raise ExecutorError(
            f"Missing git data. "
            f"BeakerExecutor requires a git repository with a GitHub remote."
        )
    self._github_account, self._github_repo = self._parse_git_remote(git.remote)
    self._git_commit = git.commit
def check_repo_state(self):
    """
    Verify that the local git working tree is clean and fully pushed.

    Does nothing when the ``allow_dirty`` option was given.

    :raises ExecutorError: If there are uncommitted or unpushed changes, or
        if we're not inside a valid git repository.
    """
    if self.allow_dirty:
        return
    try:
        repository = Repo(".")
        # Any staged or unstaged modifications?
        if repository.is_dirty():
            raise ExecutorError(
                "You have uncommitted changes! Commit your changes or use the 'allow_dirty' option."
            )
        # Any local commits that haven't reached a remote yet?
        upstream = repository.remote().name
        unpushed = Git(".").log([f"{upstream}..HEAD", "--not", "--remotes", "--oneline"])
        if unpushed:
            raise ExecutorError(
                "You have unpushed changes! Push your changes or use the 'allow_dirty' option."
            )
    except InvalidGitRepositoryError:
        raise ExecutorError(
            "It appears you're not in a valid git repository. "
            "The Beaker executor requires a git repository."
        )
    except GitCommandError:
        # `git log` can fail for reasons unrelated to repo cleanliness; treat
        # the unpushed-changes check as best-effort.
        pass
def execute_step_graph(
    self, step_graph: StepGraph, run_name: Optional[str] = None
) -> ExecutorOutput:
    """
    Execute a step graph by submitting each runnable step as its own Beaker
    experiment from a thread pool, resubmitting steps that couldn't get
    resources, and skipping steps whose dependencies failed.

    :param step_graph: The step graph to execute.
    :param run_name: Optional name of the run (part of the executor interface).
    :returns: An :class:`~tango.executor.ExecutorOutput` recording which steps
        succeeded, failed, or were never run.
    """
    import concurrent.futures

    self.check_repo_state()

    self._is_cancelled.clear()

    # These will hold the final results which we'll update along the way.
    successful: Dict[str, ExecutionMetadata] = {}
    failed: Dict[str, ExecutionMetadata] = {}
    not_run: Dict[str, ExecutionMetadata] = {}

    # Keeps track of steps that are next up to run on Beaker.
    steps_to_run: Set[str] = set()
    # These are steps that have been submitted to Beaker but haven't completed yet.
    submitted_steps: Set[str] = set()
    # Futures for tracking the Beaker runs for each step.
    step_futures: List[concurrent.futures.Future] = []

    uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps()

    # These are all of the steps that still need to be run at some point.
    steps_left_to_run = uncacheable_leaf_steps | {
        step for step in step_graph.values() if step.cache_results
    }

    def update_steps_to_run():
        # Refresh `steps_to_run`/`not_run` from the current success/failure state.
        nonlocal steps_to_run, not_run
        for step_name, step in step_graph.items():
            if (
                step_name in submitted_steps
                or step_name in successful
                or step_name in failed
                or step_name in not_run
            ):
                # Make sure step is no longer in queue.
                steps_to_run.discard(step_name)  # This does NOT raise KeyError if not found
            else:
                # Check dependencies.
                for dependency in step.dependencies:
                    if dependency.name not in successful and dependency.cache_results:
                        if dependency.name in failed or dependency.name in not_run:
                            # A dependency failed or can't be run, so this step can't be run.
                            not_run[step_name] = ExecutionMetadata()
                            steps_to_run.discard(step_name)
                            steps_left_to_run.discard(step)
                        # A cacheable dependency hasn't succeeded (yet), so this
                        # step isn't ready either; skip the for-else below.
                        break
                else:
                    # Dependencies are OK, so we can run this step now.
                    if step.cache_results or step in uncacheable_leaf_steps:
                        steps_to_run.add(step_name)

    def make_future_done_callback(step_name: str):
        # Build the done-callback that records the outcome for `step_name`.
        def future_done_callback(future: concurrent.futures.Future):
            nonlocal successful, failed, steps_left_to_run

            self._jobs = max(0, self._jobs - 1)
            step = step_graph[step_name]
            try:
                exc = future.exception()
                if exc is None:
                    successful[step_name] = ExecutionMetadata(
                        result_location=None
                        if not step.cache_results
                        else self.workspace.step_info(step).result_location,
                        logs_location=future.result(),
                    )
                    steps_left_to_run.discard(step)
                elif isinstance(exc, ResourceAssignmentError):
                    # Not enough Beaker resources right now: "un-submit" the step
                    # so it gets queued again on the next update.
                    submitted_steps.discard(step_name)
                    self._emit_resource_assignment_warning()
                elif isinstance(exc, StepFailedError):
                    failed[step_name] = ExecutionMetadata(logs_location=exc.experiment_url)
                    steps_left_to_run.discard(step)
                elif isinstance(exc, (ExecutorError, CancellationError)):
                    failed[step_name] = ExecutionMetadata()
                    steps_left_to_run.discard(step)
                else:
                    log_exception(exc, logger)
                    failed[step_name] = ExecutionMetadata()
                    steps_left_to_run.discard(step)
            except concurrent.futures.TimeoutError as exc:
                log_exception(exc, logger)
                failed[step_name] = ExecutionMetadata()
                steps_left_to_run.discard(step)

        return future_done_callback

    last_progress_update = time.monotonic()

    def log_progress():
        # Log what we're waiting on, at most once every 2 minutes.
        nonlocal last_progress_update
        now = time.monotonic()
        if now - last_progress_update >= 60 * 2:
            last_progress_update = now

            waiting_for = [
                step_name
                for step_name in submitted_steps
                if step_name not in failed and step_name not in successful
            ]
            if len(waiting_for) > 5:
                logger.info(
                    "Waiting for %d steps...",
                    len(waiting_for),
                )
            elif len(waiting_for) > 1:
                logger.info(
                    "Waiting for %d steps (%s)...",
                    len(waiting_for),
                    "'" + "', '".join(waiting_for) + "'",
                )
            elif len(waiting_for) == 1:
                logger.info("Waiting for 1 step ('%s')...", list(waiting_for)[0])

            still_to_run = [
                step.name for step in steps_left_to_run if step.name not in submitted_steps
            ]
            if len(still_to_run) > 5:
                logger.info(
                    "Still waiting to submit %d more steps...",
                    len(still_to_run),
                )
            elif len(still_to_run) > 1:
                logger.info(
                    "Still waiting to submit %d more steps (%s)...",
                    len(still_to_run),
                    "'" + "', '".join(still_to_run) + "'",
                )
            elif len(still_to_run) == 1:
                logger.info("Still waiting to submit 1 more step ('%s')...", still_to_run[0])

    update_steps_to_run()
    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_workers) as pool:
            while steps_left_to_run:
                # Submit steps left to run.
                for step_name in steps_to_run:
                    future = pool.submit(
                        self._execute_sub_graph_for_step, step_graph, step_name, True
                    )
                    future.add_done_callback(make_future_done_callback(step_name))
                    self._jobs += 1
                    step_futures.append(future)
                    submitted_steps.add(step_name)

                if step_futures:
                    # Wait for something to happen.
                    _, not_done = concurrent.futures.wait(
                        step_futures,
                        return_when=concurrent.futures.FIRST_COMPLETED,
                        timeout=2.0,
                    )

                    # Update the list of running futures.
                    step_futures.clear()
                    step_futures = list(not_done)
                else:
                    time.sleep(2.0)

                # Update the step queue.
                update_steps_to_run()
                log_progress()
    except (KeyboardInterrupt, CancellationError):
        if step_futures:
            cli_logger.warning("Received interrupt, canceling steps...")
            self._is_cancelled.set()
            concurrent.futures.wait(step_futures)
        raise
    finally:
        self._is_cancelled.clear()

    # NOTE: The 'done callback' added to each future is executed in a thread,
    # and so might not complete before the last 'update_steps_to_run()' is called
    # in the loop above. Therefore we have to call 'update_steps_to_run()'
    # one last time here to ensure the 'not_run' set is up-to-date.
    update_steps_to_run()

    return ExecutorOutput(successful=successful, failed=failed, not_run=not_run)
def _emit_resource_assignment_warning(self):
    """Log a rate-limited warning that steps are waiting on Beaker resources."""
    now = time.monotonic()
    last = self._last_resource_assignment_warning
    # Stay quiet if we warned within the last RESOURCE_ASSIGNMENT_WARNING_INTERVAL seconds.
    if last is not None and now - last <= self.RESOURCE_ASSIGNMENT_WARNING_INTERVAL:
        return
    self._last_resource_assignment_warning = now
    logger.warning(
        "Some steps can't be run yet - waiting on more Beaker resources "
        "to become available..."
    )
def _check_if_cancelled(self):
    """Raise :class:`RunCancelled` if cancellation has been signaled."""
    if not self._is_cancelled.is_set():
        return
    raise RunCancelled
def _execute_sub_graph_for_step(
    self,
    step_graph: StepGraph,
    step_name: str,
    in_thread: bool = False,
) -> Optional[str]:
    """
    Run a single step as a Beaker experiment (or locally, for steps pinned to
    ``machine="local"``), waiting for it to finish.

    :param step_graph: The full step graph (the step's sub-graph is shipped to Beaker).
    :param step_name: Name of the step to run.
    :param in_thread: True when called from the executor's thread pool, in which
        case the shared cancellation flag is honored rather than reset.
    :returns: The Beaker experiment URL, or ``None`` if nothing ran on Beaker
        (cached result or local execution).
    :raises StepFailedError: If the Beaker job fails or is stopped.
    """
    if not in_thread:
        self._is_cancelled.clear()
    else:
        self._check_if_cancelled()

    step = step_graph[step_name]

    if step.cache_results and step in self.workspace.step_cache:
        cli_logger.info(
            '[green]\N{check mark} Found output for step [bold]"%s"[/] in cache...[/]',
            step_name,
        )
        return None

    if step.resources.machine == "local":
        # Run the step in this process instead of submitting it to Beaker.
        if step.cache_results:
            step.ensure_result(self.workspace)
        else:
            result = step.result(self.workspace)
            if hasattr(result, "__next__"):
                from collections import deque

                # Exhaust the iterator without accumulating its items.
                deque(result, maxlen=0)
        return None

    experiment: Optional[Experiment] = None
    experiment_url: Optional[str] = None
    ephemeral_datasets: List[Dataset] = []

    # Try to find any existing experiments for this step that are still running.
    if step.cache_results:
        for exp in self.beaker.workspace.experiments(
            match=f"{Constants.STEP_EXPERIMENT_PREFIX}{step.unique_id}-"
        ):
            self._check_if_cancelled()
            try:
                latest_job = self.beaker.experiment.latest_job(exp)
            except (ValueError, ExperimentNotFound):
                continue
            if latest_job is not None and not latest_job.is_done:
                experiment = exp
                experiment_url = self.beaker.experiment.url(exp)
                cli_logger.info(
                    "[blue]\N{black rightwards arrow} Found existing Beaker experiment [b]%s[/] for "
                    'step [b]"%s"[/] that is still running...[/]',
                    experiment_url,
                    step_name,
                )
                break

    # Otherwise we submit a new experiment...
    if experiment is None:
        # Initialize experiment and task spec.
        experiment_name, spec, ephemeral_datasets = self._build_experiment_spec(
            step_graph, step_name
        )
        self._check_if_cancelled()

        step.log_starting()

        # Create experiment.
        experiment = self.beaker.experiment.create(experiment_name, spec)
        experiment_url = self.beaker.experiment.url(experiment)
        cli_logger.info(
            '[blue]\N{black rightwards arrow} Submitted Beaker experiment [b]%s[/] for step [b]"%s"[/]...[/]',
            experiment_url,
            step_name,
        )

    assert experiment is not None
    assert experiment_url is not None

    # Follow the experiment until it completes.
    try:
        while True:
            # Scale the polling interval with the number of in-flight jobs so we
            # don't hammer the Beaker API when many steps are running.
            poll_interval = min(60, 5 * min(self._jobs, self.max_thread_workers))
            try:
                self._check_if_cancelled()
                self.beaker.experiment.wait_for(
                    experiment,
                    strict=True,
                    quiet=True,
                    timeout=poll_interval + 2,
                    poll_interval=poll_interval,
                )
                break
            except JobTimeoutError:
                # Not done yet; sleep and poll again (also re-checks cancellation).
                time.sleep(poll_interval)
                continue
    except (JobFailedError, TaskStoppedError):
        cli_logger.error(
            '[red]\N{ballot x} Step [b]"%s"[/] failed. You can check the logs at [b]%s[/][/]',
            step_name,
            experiment_url,
        )
        raise StepFailedError(
            f'Beaker job for step "{step_name}" failed. '
            f"You can check the logs at {experiment_url}",
            experiment_url,
        )
    except (KeyboardInterrupt, CancellationError):
        cli_logger.warning(
            'Stopping Beaker experiment [cyan]%s[/] for step [b]"%s"[/] (%s)',
            experiment_url,
            step_name,
            step.unique_id,
        )
        self.beaker.experiment.stop(experiment)
        raise
    else:
        step.log_finished()
    finally:
        # Remove ephemeral datasets.
        result_dataset = self.beaker.experiment.results(experiment)
        if result_dataset is not None:
            ephemeral_datasets.append(result_dataset)
        for dataset in ephemeral_datasets:
            try:
                self.beaker.dataset.delete(dataset)
            except DatasetNotFound:
                pass

    return experiment_url
@staticmethod
def _parse_git_remote(url: str) -> Tuple[str, str]:
    """
    Parse a git remote URL into a GitHub (account, repo) pair.

    Handles both HTTPS ("https://github.com/account/repo.git") and SSH
    ("git@github.com:account/repo.git") remotes, with or without the
    trailing ".git" and/or a trailing slash.

    :raises ValueError: If the remainder of the URL is not exactly
        "account/repo".
    """
    path = url
    for prefix in ("https://github.com/", "git@github.com:"):
        if prefix in path:
            path = path.split(prefix)[-1]
    # Only strip ".git" as a *suffix*. The previous implementation split on
    # ".git" anywhere, which mangled account/repo names containing that
    # substring (e.g. "my.github-tools" -> "my").
    path = path.rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    account, repo = path.split("/")
    return account, repo
def _ensure_entrypoint_dataset(self) -> Dataset:
    """
    Get (creating if necessary) the Beaker dataset holding the executor's
    entrypoint script, and verify its contents against the local copy.

    The dataset name encodes the workspace ID plus a hash of the local script,
    so a new dataset is created whenever the script changes.

    :raises ExecutorError: If the remote dataset's checksum doesn't match the
        local entrypoint file.
    """
    import hashlib
    from importlib.resources import read_binary

    import tango.integrations.beaker

    workspace_id = self.beaker.workspace.get().id

    # Get hash of the local entrypoint source file.
    sha256_hash = hashlib.sha256()
    contents = read_binary(tango.integrations.beaker, Constants.ENTRYPOINT_FILENAME)
    sha256_hash.update(contents)

    entrypoint_dataset_name = (
        f"{Constants.ENTRYPOINT_DATASET_PREFIX}{workspace_id}-{sha256_hash.hexdigest()[:6]}"
    )
    tmp_entrypoint_dataset_name = (
        f"{Constants.ENTRYPOINT_DATASET_PREFIX}{str(uuid.uuid4())}-tmp"
    )

    # Ensure entrypoint dataset exists.
    entrypoint_dataset: Dataset
    try:
        entrypoint_dataset = self.beaker.dataset.get(entrypoint_dataset_name)
    except DatasetNotFound:
        # Create it.
        logger.debug(f"Creating entrypoint dataset '{entrypoint_dataset_name}'")
        try:
            # Upload to a uniquely-named temporary dataset first and rename at
            # the end, so a partial upload never lands under the final name.
            tmp_entrypoint_dataset = self.beaker.dataset.create(
                tmp_entrypoint_dataset_name, quiet=True, commit=False
            )
            self.beaker.dataset.upload(
                tmp_entrypoint_dataset, contents, Constants.ENTRYPOINT_FILENAME, quiet=True
            )
            self.beaker.dataset.commit(tmp_entrypoint_dataset)
            entrypoint_dataset = self.beaker.dataset.rename(
                tmp_entrypoint_dataset, entrypoint_dataset_name
            )
        except DatasetConflict:  # could be in a race with another `tango` process.
            time.sleep(1.0)
            entrypoint_dataset = self.beaker.dataset.get(entrypoint_dataset_name)

    # Verify contents.
    err_msg = (
        f"Checksum failed for entrypoint dataset {self.beaker.dataset.url(entrypoint_dataset)}\n"
        f"This could be a bug, or it could mean someone has tampered with the dataset.\n"
        f"If you're sure no one has tampered with it, you can delete the dataset from "
        f"the Beaker dashboard and try again."
    )
    file_info = self.beaker.dataset.file_info(entrypoint_dataset, Constants.ENTRYPOINT_FILENAME)
    if file_info.digest is not None and file_info.digest != Digest.from_decoded(
        sha256_hash.digest(), "SHA256"
    ):
        raise ExecutorError(err_msg)

    return entrypoint_dataset
def _ensure_step_graph_dataset(self, step_graph: StepGraph) -> Dataset:
    """
    Upload ``step_graph`` (serialized as JSON) to a fresh Beaker dataset.

    The dataset gets a random unique name; on a name conflict (a race with
    another `tango` process) the existing dataset is fetched instead.
    """
    name = f"{Constants.STEP_GRAPH_ARTIFACT_PREFIX}{str(uuid.uuid4())}"
    try:
        dataset = self.beaker.dataset.create(name, quiet=True, commit=False)
        serialized = json.dumps({"steps": step_graph.to_config(include_unique_id=True)})
        self.beaker.dataset.upload(
            dataset,
            serialized.encode(),
            Constants.STEP_GRAPH_FILENAME,
            quiet=True,
        )
        self.beaker.dataset.commit(dataset)
    except DatasetConflict:  # could be in a race with another `tango` process.
        time.sleep(1.0)
        dataset = self.beaker.dataset.get(name)
    return dataset
def _build_experiment_spec(
    self, step_graph: StepGraph, step_name: str
) -> Tuple[str, ExperimentSpec, List[Dataset]]:
    """
    Build the Beaker experiment spec for running ``step_name``.

    :returns: A tuple of (experiment name, experiment spec, ephemeral datasets
        to be cleaned up once the experiment finishes).
    """
    from tango.common.logging import TANGO_LOG_LEVEL

    step = step_graph[step_name]
    sub_graph = step_graph.sub_graph(step_name)
    step_info = self.workspace.step_info(step)

    experiment_name = (
        f"{Constants.STEP_EXPERIMENT_PREFIX}{step.unique_id}-{str(uuid.uuid4())[:8]}"
    )

    github_account, github_repo, git_ref = (
        self._github_account,
        self._github_repo,
        self._git_commit,
    )

    if not self._logged_git_info:
        # Only announce the source-code location once per executor instance.
        self._logged_git_info = True
        cli_logger.info(
            "[blue]Using source code from "
            "[b]https://github.com/%s/%s/commit/%s[/] to run steps on Beaker[/]",
            github_account,
            github_repo,
            git_ref,
        )

    # Get cluster, resources, and priority to use.
    clusters, task_resources, priority = self.scheduler.schedule(step)
    self._check_if_cancelled()

    # Ensure dataset with the entrypoint script exists and get it.
    entrypoint_dataset = self._ensure_entrypoint_dataset()
    self._check_if_cancelled()

    # Create dataset for step graph.
    step_graph_dataset = self._ensure_step_graph_dataset(sub_graph)
    self._check_if_cancelled()

    # Write the GitHub token secret.
    self.beaker.secret.write(Constants.GITHUB_TOKEN_SECRET_NAME, self.github_token)
    self._check_if_cancelled()

    # Write the Beaker token secret.
    self.beaker.secret.write(Constants.BEAKER_TOKEN_SECRET_NAME, self.beaker.config.user_token)
    self._check_if_cancelled()

    # Write the Google Cloud token secret.
    if self.google_token is not None:
        self.beaker.secret.write(Constants.GOOGLE_TOKEN_SECRET_NAME, self.google_token)
    self._check_if_cancelled()

    # Build Tango command to run.
    command = [
        "tango",
        "--log-level",
        "debug",
        "--called-by-executor",
        "beaker-executor-run",
        Constants.INPUT_DIR + "/" + Constants.STEP_GRAPH_FILENAME,
        step.name,
        self.workspace.url,
    ]
    if self.include_package is not None:
        for package in self.include_package:
            command += ["-i", package, "--log-level", TANGO_LOG_LEVEL or "debug"]
    self._check_if_cancelled()

    # Ignore the patch version.
    # E.g. '3.9.7' -> '3.9'
    python_version = step_info.environment.python
    python_version = python_version[: python_version.find(".", python_version.find(".") + 1)]

    # Build task spec.
    task_spec = (
        TaskSpec.new(
            step.unique_id,
            beaker_image=self.beaker_image,
            docker_image=self.docker_image,
            result_path=Constants.RESULTS_DIR,
            command=["bash", Constants.ENTRYPOINT_DIR + "/" + Constants.ENTRYPOINT_FILENAME],
            arguments=command,
            resources=task_resources,
            datasets=self.datasets,
            env_vars=self.env_vars,
            priority=priority,
        )
        .with_constraint(cluster=[clusters] if isinstance(clusters, str) else clusters)
        .with_env_var(name="TANGO_VERSION", value=VERSION)
        .with_env_var(name="GITHUB_TOKEN", secret=Constants.GITHUB_TOKEN_SECRET_NAME)
        .with_env_var(name="BEAKER_TOKEN", secret=Constants.BEAKER_TOKEN_SECRET_NAME)
        .with_env_var(name="GOOGLE_TOKEN", secret=Constants.GOOGLE_TOKEN_SECRET_NAME)
        .with_env_var(name="GITHUB_REPO", value=f"{github_account}/{github_repo}")
        .with_env_var(name="GIT_REF", value=git_ref)
        .with_env_var(name="PYTHON_VERSION", value=python_version)
        .with_env_var(name="BEAKER_EXPERIMENT_NAME", value=experiment_name)
        .with_dataset(Constants.ENTRYPOINT_DIR, beaker=entrypoint_dataset.id)
        .with_dataset(Constants.INPUT_DIR, beaker=step_graph_dataset.id)
    )
    if self.venv_name is not None:
        task_spec = task_spec.with_env_var(name="VENV_NAME", value=self.venv_name)
    if self.install_cmd is not None:
        task_spec = task_spec.with_env_var(name="INSTALL_CMD", value=self.install_cmd)

    return (
        experiment_name,
        ExperimentSpec(
            tasks=[task_spec],
            description=f'Tango step "{step_name}" ({step.unique_id})',
            budget=self._budget,
        ),
        # The step-graph dataset is ephemeral and cleaned up after the run.
        [step_graph_dataset],
    )
================================================
FILE: tango/integrations/beaker/step_cache.py
================================================
import logging
from pathlib import Path
from typing import Optional, Union
from beaker import Beaker
from beaker import Dataset as BeakerDataset
from beaker import DatasetConflict, DatasetNotFound, DatasetWriteError
from tango import Step
from tango.common import PathOrStr
from tango.common.exceptions import ConfigurationError
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.integrations.beaker.common import Constants, get_client
from tango.step_cache import StepCache
from tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache
from tango.step_info import StepInfo
# Module-level logger for this file.
logger = logging.getLogger(__name__)
@StepCache.register("beaker")
class BeakerStepCache(RemoteStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`BeakerWorkspace`.
    It stores the results of steps on Beaker as datasets.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching
    a step's result subsequent times should be fast.

    .. tip::
        Registered as a :class:`~tango.step_cache.StepCache` under the name "beaker".

    :param beaker_workspace: The name or ID of the Beaker workspace to use.
    :param beaker: The Beaker client to use.
    """

    Constants = Constants

    def __init__(self, beaker_workspace: Optional[str] = None, beaker: Optional[Beaker] = None):
        self.beaker: Beaker
        if beaker is None:
            self.beaker = get_client(beaker_workspace=beaker_workspace)
        else:
            self.beaker = beaker
            if beaker_workspace is not None:
                self.beaker.config.default_workspace = beaker_workspace
                self.beaker.workspace.ensure(beaker_workspace)
        default_workspace = self.beaker.config.default_workspace
        if default_workspace is None:
            raise ConfigurationError("Beaker default workspace must be set")
        super().__init__(
            tango_cache_dir() / "beaker_cache" / make_safe_filename(default_workspace)
        )

    def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[BeakerDataset]:
        """
        Return the Beaker dataset holding the step's result, or ``None``.

        Only committed (finalized) datasets count as results.
        """
        try:
            dataset = self.beaker.dataset.get(self.Constants.step_artifact_name(step))
        except DatasetNotFound:
            return None
        return dataset if dataset.committed is not None else None

    def _upload_step_remote(self, step: Step, objects_dir: Path) -> BeakerDataset:
        """
        Sync the step's output directory into a Beaker dataset and commit it.
        """
        name = self.Constants.step_artifact_name(step)
        try:
            self.beaker.dataset.create(name, commit=False)
        except DatasetConflict:
            # Dataset already exists (possibly created by another process); reuse it.
            pass
        try:
            self.beaker.dataset.sync(name, objects_dir, quiet=True)
            self.beaker.dataset.commit(name)
        except DatasetWriteError:
            # Best-effort: presumably a concurrent writer got there first — confirm.
            pass
        return self.beaker.dataset.get(name)

    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
        """
        Fetch the step's result dataset into ``target_dir``.
        """
        try:
            self.beaker.dataset.fetch(step_result, target_dir, quiet=True)
        except DatasetNotFound:
            raise RemoteNotFoundError()

    def __len__(self):
        """
        Return the number of committed step outputs present in the remote location.
        """
        # NOTE: lock datasets should not count here.
        prefix = self.Constants.STEP_ARTIFACT_PREFIX
        lock_suffix = self.Constants.LOCK_ARTIFACT_SUFFIX
        count = 0
        for ds in self.beaker.workspace.iter_datasets(
            match=prefix, uncommitted=False, results=False
        ):
            name = ds.name
            if name is not None and name.startswith(prefix) and not name.endswith(lock_suffix):
                count += 1
        return count
================================================
FILE: tango/integrations/beaker/workspace.py
================================================
import json
import logging
import os
import random
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Type, TypeVar, Union, cast
from urllib.parse import ParseResult
import petname
from beaker import Dataset
from beaker import Dataset as BeakerDataset
from beaker import (
DatasetConflict,
DatasetNotFound,
DatasetSort,
Digest,
Experiment,
ExperimentNotFound,
)
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.step import Step
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace
from tango.workspaces.remote_workspace import RemoteWorkspace
from .common import BeakerStepLock, Constants, dataset_url, get_client
from .step_cache import BeakerStepCache
T = TypeVar("T")
U = TypeVar("U", Run, StepInfo)
logger = logging.getLogger(__name__)
@Workspace.register("beaker")
class BeakerWorkspace(RemoteWorkspace):
    """
    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on `Beaker`_.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "beaker".

    :param workspace: The name or ID of the Beaker workspace to use.
    :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env()`.
    """

    # Max number of StepInfo/Run objects kept in the in-memory LRU cache.
    STEP_INFO_CACHE_SIZE = 512
    Constants = Constants
    # Thread count used when fetching run/step info concurrently.
    NUM_CONCURRENT_WORKERS = 9

    def __init__(self, workspace: str, max_workers: Optional[int] = None, **kwargs):
        self.beaker = get_client(beaker_workspace=workspace, **kwargs)
        self._cache = BeakerStepCache(beaker=self.beaker)
        self._locks: Dict[Step, BeakerStepLock] = {}
        super().__init__()
        self.max_workers = max_workers
        # On-disk cache of serialized StepInfo/Run objects, keyed by file digest.
        self._disk_cache_dir = tango_cache_dir() / "beaker_cache" / "_objects"
        # In-memory LRU cache layered on top of the disk cache.
        self._mem_cache: "OrderedDict[Digest, Union[StepInfo, Run]]" = OrderedDict()

    @property
    def cache(self):
        # The step cache backed by Beaker datasets.
        return self._cache

    @property
    def locks(self):
        # Step locks currently held by this workspace instance.
        return self._locks

    @property
    def steps_dir_name(self):
        return "beaker_workspace"

    @property
    def url(self) -> str:
        # e.g. "beaker://ai2/my-workspace"
        return f"beaker://{self.beaker.workspace.get().full_name}"

    def _step_location(self, step: Step) -> str:
        # URL of the Beaker dataset that holds this step's output.
        return dataset_url(self.beaker, self.Constants.step_artifact_name(step))

    @classmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
        """
        Construct a :class:`BeakerWorkspace` from a parsed ``beaker://...`` URL.
        """
        workspace: str
        if parsed_url.netloc and parsed_url.path:
            # e.g. "beaker://ai2/my-workspace"
            workspace = parsed_url.netloc + parsed_url.path
        elif parsed_url.netloc:
            # e.g. "beaker://my-workspace"
            workspace = parsed_url.netloc
        else:
            raise ValueError(f"Bad URL for Beaker workspace '{parsed_url}'")
        return cls(workspace)

    @property
    def current_beaker_experiment(self) -> Optional[Experiment]:
        """
        When the workspace is being used within a Beaker experiment that was submitted
        by the Beaker executor, this will return the `Experiment` object.
        """
        # The executor is assumed to set this env var on submitted jobs.
        experiment_name = os.environ.get("BEAKER_EXPERIMENT_NAME")
        if experiment_name is not None:
            try:
                return self.beaker.experiment.get(experiment_name)
            except ExperimentNotFound:
                return None
        else:
            return None

    def _remote_lock(self, step: Step) -> BeakerStepLock:
        # The lock used to guard concurrent execution of ``step``.
        return BeakerStepLock(
            self.beaker, step, current_beaker_experiment=self.current_beaker_experiment
        )

    def _get_object_from_cache(self, digest: Digest, o_type: Type[U]) -> Optional[U]:
        """
        Look up a cached ``StepInfo``/``Run`` object by file digest, checking the
        in-memory LRU cache first, then the on-disk cache. Returns ``None`` on a miss.
        """
        cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
        if digest in self._mem_cache:
            cached = self._mem_cache.pop(digest)
            # Move to end.
            self._mem_cache[digest] = cached
            # Only count it as a hit when the cached object has the requested type.
            return cached if isinstance(cached, o_type) else None
        elif cache_path.is_file():
            try:
                with cache_path.open("r+t") as f:
                    json_dict = json.load(f)
                    cached = o_type.from_json_dict(json_dict)
            except Exception as exc:
                # Corrupt/unreadable cache file: warn, remove it (best effort),
                # and treat as a miss.
                logger.warning("Error while loading object from workspace cache: %s", str(exc))
                try:
                    os.remove(cache_path)
                except FileNotFoundError:
                    pass
                return None
            # Add to in-memory cache.
            self._mem_cache[digest] = cached
            # Evict least-recently-used entries beyond the size limit.
            while len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:
                self._mem_cache.popitem(last=False)
            return cached  # type: ignore
        else:
            return None

    def _add_object_to_cache(self, digest: Digest, o: U):
        """
        Write a ``StepInfo``/``Run`` object to both the disk and in-memory caches.
        """
        self._disk_cache_dir.mkdir(parents=True, exist_ok=True)
        cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
        self._mem_cache[digest] = o
        with cache_path.open("w+t") as f:
            json.dump(o.to_json_dict(), f)
        # Evict least-recently-used entries beyond the size limit.
        while len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:
            self._mem_cache.popitem(last=False)

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Fetch the :class:`~tango.step_info.StepInfo` for a step or its unique ID.

        :raises KeyError: If only a unique ID was given and no info exists remotely.
        """
        try:
            dataset = self.beaker.dataset.get(self.Constants.step_artifact_name(step_or_unique_id))
            return self._get_step_info_from_dataset(dataset)
        except (DatasetNotFound, FileNotFoundError):
            if not isinstance(step_or_unique_id, Step):
                raise KeyError(step_or_unique_id)
            # For an actual Step we can create fresh info and register it remotely.
            step_info = StepInfo.new_from_step(step_or_unique_id)
            self._update_step_info(step_info)
            return step_info

    def _get_step_info_from_dataset(self, dataset: Dataset) -> StepInfo:
        # Read the step-info JSON file out of the dataset, going through the
        # digest-keyed cache when possible.
        file_info = self.beaker.dataset.file_info(dataset, Constants.STEP_INFO_FNAME)
        step_info: StepInfo
        cached = (
            None
            if file_info.digest is None
            else self._get_object_from_cache(file_info.digest, StepInfo)
        )
        if cached is not None:
            step_info = cached
        else:
            step_info_bytes = self.beaker.dataset.get_file(dataset, file_info, quiet=True)
            step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
            if file_info.digest is not None:
                self._add_object_to_cache(file_info.digest, step_info)
        return step_info

    def _save_run(
        self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None
    ) -> Run:
        # Create a remote dataset that represents this run. The dataset will just contain
        # a JSON file that maps step names to step unique IDs.
        run_dataset: BeakerDataset
        if name is None:
            # Find a unique name to use.
            while True:
                name = petname.generate() + str(random.randint(0, 100))
                try:
                    run_dataset = self.beaker.dataset.create(
                        self.Constants.run_artifact_name(cast(str, name)), commit=False
                    )
                except DatasetConflict:
                    # Name collision — generate another one.
                    continue
                else:
                    break
        else:
            try:
                run_dataset = self.beaker.dataset.create(
                    self.Constants.run_artifact_name(name), commit=False
                )
            except DatasetConflict:
                raise ValueError(f"Run name '{name}' is already in use")
        # Upload run data to dataset.
        # NOTE: We don't commit the dataset here since we'll need to upload the logs file
        # after the run.
        self.beaker.dataset.upload(
            run_dataset, json.dumps(run_data).encode(), self.Constants.RUN_DATA_FNAME, quiet=True
        )
        return Run(name=cast(str, name), steps=steps, start_date=run_dataset.created)

    def registered_runs(self) -> Dict[str, Run]:
        """
        Fetch all registered runs in the workspace, keyed by run name.
        """
        import concurrent.futures

        runs: Dict[str, Run] = {}
        # Fetch runs concurrently since building each one requires several API calls.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.NUM_CONCURRENT_WORKERS,
            thread_name_prefix="BeakerWorkspace.registered_runs()-",
        ) as executor:
            run_futures = []
            for dataset in self.beaker.workspace.iter_datasets(
                match=self.Constants.RUN_ARTIFACT_PREFIX, uncommitted=True, results=False
            ):
                run_futures.append(executor.submit(self._get_run_from_dataset, dataset))
            for future in concurrent.futures.as_completed(run_futures):
                run = future.result()
                if run is not None:
                    runs[run.name] = run
        return runs

    def search_registered_runs(
        self,
        *,
        sort_by: Optional[RunSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        start: Optional[int] = None,
        stop: Optional[int] = None,
    ) -> List[RunInfo]:
        """
        List runs in the workspace, optionally filtered by name prefix,
        sorted, and paginated via ``start``/``stop``.
        """
        # Restrict the dataset search to run artifacts.
        if match is None:
            match = Constants.RUN_ARTIFACT_PREFIX
        else:
            match = Constants.RUN_ARTIFACT_PREFIX + match
        # Map workspace sort options onto Beaker dataset sort options.
        if sort_by is None or sort_by == RunSort.START_DATE:
            sort = DatasetSort.created
        elif sort_by == RunSort.NAME:
            sort = DatasetSort.dataset_name
        else:
            raise NotImplementedError
        runs = []
        for dataset in self.beaker.workspace.iter_datasets(
            match=match,
            results=False,
            cursor=start or 0,
            limit=None if stop is None else stop - (start or 0),
            sort_by=sort,
            descending=sort_descending,
        ):
            if dataset.name is not None and dataset.name.startswith(
                self.Constants.RUN_ARTIFACT_PREFIX
            ):
                # Strip the artifact prefix to recover the run's name.
                run_name = dataset.name[len(self.Constants.RUN_ARTIFACT_PREFIX) :]
                runs.append(RunInfo(name=run_name, start_date=dataset.created))
        return runs

    def num_registered_runs(self, *, match: Optional[str] = None) -> int:
        """
        Count registered runs, optionally restricted by a name prefix.
        """
        if match is None:
            match = Constants.RUN_ARTIFACT_PREFIX
        else:
            match = Constants.RUN_ARTIFACT_PREFIX + match
        count = 0
        for dataset in self.beaker.workspace.iter_datasets(
            match=match,
            results=False,
        ):
            if dataset.name is not None and dataset.name.startswith(Constants.RUN_ARTIFACT_PREFIX):
                count += 1
        return count

    def search_step_info(
        self,
        *,
        sort_by: Optional[StepInfoSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        state: Optional[StepState] = None,
        start: int = 0,
        stop: Optional[int] = None,
    ) -> List[StepInfo]:
        """
        List step info objects in the workspace, optionally sorted and paginated.

        :raises NotImplementedError: If ``state`` filtering is requested.
        """
        # Filtering by state would require fetching every step's info first.
        if state is not None:
            raise NotImplementedError(
                f"{self.__class__.__name__} cannot filter steps efficiently by state"
            )
        if match is None:
            match = Constants.STEP_ARTIFACT_PREFIX
        else:
            match = Constants.STEP_ARTIFACT_PREFIX + match
        sort: Optional[DatasetSort] = None
        if sort_by is None or sort_by == StepInfoSort.START_TIME:
            sort = DatasetSort.created
        elif sort_by == StepInfoSort.UNIQUE_ID:
            sort = DatasetSort.dataset_name
        elif sort_by is not None:
            raise NotImplementedError
        steps = []
        for dataset in self.beaker.workspace.iter_datasets(
            match=match,
            results=False,
            cursor=start or 0,
            limit=None if stop is None else stop - (start or 0),
            sort_by=sort or DatasetSort.created,
            descending=sort_descending,
        ):
            try:
                steps.append(self._get_step_info_from_dataset(dataset))
            except (DatasetNotFound, FileNotFoundError):
                # Dataset vanished or lacks a step-info file; skip it.
                continue
        return steps

    def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:
        """
        Count steps in the workspace, optionally restricted by a name prefix.

        :raises NotImplementedError: If ``state`` filtering is requested.
        """
        if state is not None:
            raise NotImplementedError(
                f"{self.__class__.__name__} cannot filter steps efficiently by state"
            )
        if match is None:
            match = Constants.STEP_ARTIFACT_PREFIX
        else:
            match = Constants.STEP_ARTIFACT_PREFIX + match
        count = 0
        for dataset in self.beaker.workspace.iter_datasets(
            match=match,
            results=False,
        ):
            if dataset.name is not None and dataset.name.startswith(Constants.STEP_ARTIFACT_PREFIX):
                count += 1
        return count

    def registered_run(self, name: str) -> Run:
        """
        Fetch a registered run by name.

        :raises KeyError: If no run with ``name`` exists in this workspace.
        """
        err_msg = f"Run '{name}' not found in workspace"
        try:
            dataset_for_run = self.beaker.dataset.get(self.Constants.run_artifact_name(name))
            # Make sure the run is in our workspace.
            if dataset_for_run.workspace_ref.id != self.beaker.workspace.get().id:  # type: ignore # TODO
                raise DatasetNotFound
        except DatasetNotFound:
            raise KeyError(err_msg)
        run = self._get_run_from_dataset(dataset_for_run)
        if run is None:
            raise KeyError(err_msg)
        else:
            return run

    def _save_run_log(self, name: str, log_file: Path):
        # Upload the run's log file into the run's dataset and finalize (commit) it.
        run_dataset = self.Constants.run_artifact_name(name)
        self.beaker.dataset.sync(run_dataset, log_file, quiet=True)
        self.beaker.dataset.commit(run_dataset)

    def _get_run_from_dataset(self, dataset: BeakerDataset) -> Optional[Run]:
        """
        Reconstruct a :class:`~tango.workspace.Run` from its Beaker dataset,
        returning ``None`` when the dataset doesn't look like a valid run artifact.
        """
        if dataset.name is None:
            return None
        if not dataset.name.startswith(self.Constants.RUN_ARTIFACT_PREFIX):
            return None
        try:
            run_name = dataset.name[len(self.Constants.RUN_ARTIFACT_PREFIX) :]
            steps_info_bytes = self.beaker.dataset.get_file(
                dataset, self.Constants.RUN_DATA_FNAME, quiet=True
            )
            steps_info = json.loads(steps_info_bytes)
        except (DatasetNotFound, FileNotFoundError):
            return None
        steps: Dict[str, StepInfo] = {}
        import concurrent.futures

        # Fetch each step's info concurrently.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.NUM_CONCURRENT_WORKERS,
            thread_name_prefix="BeakerWorkspace._get_run_from_dataset()-",
        ) as executor:
            step_info_futures = []
            for unique_id in steps_info.values():
                step_info_futures.append(executor.submit(self.step_info, unique_id))
            for future in concurrent.futures.as_completed(step_info_futures):
                step_info = future.result()
                assert step_info.step_name is not None
                steps[step_info.step_name] = step_info
        return Run(name=run_name, start_date=dataset.created, steps=steps)

    def _update_step_info(self, step_info: StepInfo):
        # Write (or overwrite) the step-info JSON file in the step's dataset.
        dataset_name = self.Constants.step_artifact_name(step_info)
        step_info_dataset: BeakerDataset
        try:
            self.beaker.dataset.create(dataset_name, commit=False)
        except DatasetConflict:
            # Dataset already exists; we'll just upload into it.
            pass
        step_info_dataset = self.beaker.dataset.get(dataset_name)
        self.beaker.dataset.upload(
            step_info_dataset,  # folder name
            json.dumps(step_info.to_json_dict()).encode(),  # step info dict.
            self.Constants.STEP_INFO_FNAME,  # step info filename
            quiet=True,
        )

    def _remove_step_info(self, step_info: StepInfo) -> None:
        # remove dir from beaker workspace
        dataset_name = self.Constants.step_artifact_name(step_info)
        step_dataset = self.beaker.dataset.get(dataset_name)
        if step_dataset is not None:
            self.beaker.dataset.delete(step_dataset)
================================================
FILE: tango/integrations/datasets/__init__.py
================================================
"""
.. important::
To use this integration you should install ``tango`` with the "datasets" extra
(e.g. ``pip install tango[datasets]``) or just install the ``datasets`` library after the fact
(e.g. ``pip install datasets``).
Components for Tango integration with `🤗 Datasets `_.
Example: loading and combining
------------------------------
Here's an example config that uses the built-in steps from this integration to load,
concatenate, and interleave datasets from HuggingFace:
.. literalinclude:: ../../../../test_fixtures/integrations/datasets/config.json
You could run this with:
.. code-block::
tango run config.json
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, TypeVar, Union, overload
from tango.common.aliases import PathOrStr
from tango.common.dataset_dict import DatasetDict, IterableDatasetDict
from tango.common.exceptions import ConfigurationError, IntegrationMissingError
from tango.format import Format
from tango.step import Step
try:
import datasets as ds
except ModuleNotFoundError:
raise IntegrationMissingError("datasets")
__all__ = [
"LoadDataset",
"LoadStreamingDataset",
"DatasetsFormat",
"convert_to_tango_dataset_dict",
"InterleaveDatasets",
"ConcatenateDatasets",
"DatasetRemixStep",
]
@overload
def convert_to_tango_dataset_dict(hf_dataset_dict: ds.DatasetDict) -> DatasetDict:
    ...


@overload
def convert_to_tango_dataset_dict(hf_dataset_dict: ds.IterableDatasetDict) -> IterableDatasetDict:  # type: ignore
    ...


def convert_to_tango_dataset_dict(hf_dataset_dict):
    """
    Convert a HuggingFace :class:`~datasets.DatasetDict` or
    :class:`~datasets.IterableDatasetDict` into the corresponding native Tango
    :class:`~tango.common.DatasetDict` or :class:`~tango.common.IterableDatasetDict`.

    This conversion matters for caching whenever the dataset dict is used as
    input to another step.
    """
    wrapper_class = (
        IterableDatasetDict
        if isinstance(hf_dataset_dict, ds.IterableDatasetDict)
        else DatasetDict
    )
    return wrapper_class(splits=hf_dataset_dict)
# Type of artifact handled by ``DatasetsFormat``: a single dataset or a dict of splits.
T = Union[ds.Dataset, ds.DatasetDict]
@Format.register("datasets")
class DatasetsFormat(Format[T]):
    """
    This format writes a :class:`datasets.Dataset` or :class:`datasets.DatasetDict` to disk
    using :meth:`datasets.Dataset.save_to_disk()`.

    It is the default :class:`~tango.format.Format` for the :class:`LoadDataset` step.
    """

    VERSION = "001"

    def write(self, artifact: T, dir: PathOrStr):
        # The artifact always lives in a "data" subdirectory of the step's output dir.
        target = Path(dir) / "data"
        artifact.save_to_disk(str(target))

    def read(self, dir: PathOrStr) -> T:
        source = Path(dir) / "data"
        return ds.load_from_disk(str(source))
@Step.register("datasets::load")
class LoadDataset(Step):
    """
    This step loads a (non-streaming) `HuggingFace dataset <https://huggingface.co/datasets>`_.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::load".

    .. important::
        If you are loading an :class:`~datasets.IterableDataset` or
        :class:`~datasets.IterableDatasetDict` you need to use the
        :class:`LoadStreamingDataset` step instead.
    """

    DETERMINISTIC = True
    VERSION = "001"
    # HuggingFace datasets has its own caching mechanism, but caching with tango's
    # mechanism is still worthwhile: some datasets ("bigscience/P3", for example)
    # take a very long time just to query from HuggingFace, which tango's cache avoids.
    CACHEABLE = True
    FORMAT = DatasetsFormat()

    def run(self, path: str, **kwargs) -> Union[ds.DatasetDict, ds.Dataset]:  # type: ignore
        """
        Load the HuggingFace dataset specified by ``path``.

        ``path`` is the canonical name or path to the dataset. Additional key word
        arguments are passed as-is to :func:`datasets.load_dataset()`.
        """
        loaded = ds.load_dataset(path, **kwargs)
        if isinstance(loaded, (ds.Dataset, ds.DatasetDict)):
            return loaded
        raise ConfigurationError(
            f"{self.__class__.__name__} can only be used with non-streaming datasets. "
            f"For streaming datasets, use the 'LoadStreamingDataset' ('datasets::load_streaming') step instead."
        )
@Step.register("datasets::load_streaming")
class LoadStreamingDataset(Step):
    """
    This step loads an iterable/streaming `HuggingFace dataset <https://huggingface.co/datasets>`_.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::load_streaming".
    """

    DETERMINISTIC = True
    VERSION = "001"
    # Can't be cached with `DatasetsFormat`, and might be really inefficient anyway.
    CACHEABLE = False

    def run(  # type: ignore
        self, path: str, **kwargs
    ) -> Union[ds.IterableDatasetDict, ds.IterableDataset]:
        """
        Load the HuggingFace streaming dataset specified by ``path``.

        ``path`` is the canonical name or path to the dataset. Additional key word
        arguments are passed as-is to :func:`datasets.load_dataset()`.
        """
        loaded = ds.load_dataset(path, **kwargs)
        if isinstance(loaded, (ds.IterableDataset, ds.IterableDatasetDict)):
            return loaded
        raise ConfigurationError(
            f"{self.__class__.__name__} can only be used with streaming datasets. "
            f"For non-streaming datasets, use the 'LoadDataset' ('datasets::load') step instead."
        )
# Constrained TypeVar so steps like ``InterleaveDatasets`` preserve whether their
# inputs are regular or iterable (streaming) datasets.
DatasetType = TypeVar("DatasetType", ds.Dataset, ds.IterableDataset)
@Step.register("datasets::interleave")
class InterleaveDatasets(Step):
    """
    This steps interleaves multiple datasets using :func:`~datasets.interleave_datasets()`.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::interleave".
    """

    DETERMINISTIC = True
    VERSION = "001"
    # Not worth caching: interleaving is cheap relative to (de)serializing the result.
    CACHEABLE = False

    def run(  # type: ignore[override]
        self,
        datasets: List[DatasetType],
        probabilities: Optional[List[float]] = None,
        seed: Optional[int] = None,
    ) -> DatasetType:
        """
        Interleave the list of datasets.
        """
        interleaved = ds.interleave_datasets(datasets, probabilities=probabilities, seed=seed)
        return interleaved
@Step.register("datasets::concatenate")
class ConcatenateDatasets(Step):
    """
    This step concatenates multiple datasets using :func:`~datasets.concatenate_datasets()`.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::concatenate".
    """

    DETERMINISTIC = True
    VERSION = "001"
    # Not worth caching: concatenation is cheap relative to (de)serializing the result.
    CACHEABLE = False

    def run(  # type: ignore[override]
        self,
        datasets: List[ds.Dataset],
        info: Optional[Any] = None,
        split: Optional[Any] = None,
        axis: int = 0,
    ) -> ds.Dataset:
        """
        Concatenate the list of datasets.
        """
        combined = ds.concatenate_datasets(datasets, info=info, split=split, axis=axis)
        return combined
@Step.register("datasets::dataset_remix")
class DatasetRemixStep(Step):
    """
    This step can remix splits in a :class:`~datasets.DatasetDict` into new splits.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::dataset_remix".

    Examples
    --------

    .. testcode::
        :hide:

        from tango.common.logging import initialize_logging
        initialize_logging(enable_cli_logs=True)
        import datasets

    .. testcode::

        input = datasets.load_dataset("lhoestq/test")
        new_splits = {
            "all": "train + validation",
            "crossval_train": "train[:1] + validation[1:]",
            "crossval_test": "train[1:] + validation[:1]",
        }
        step = DatasetRemixStep()
        remixed_dataset = step.run(input=input, new_splits=new_splits)

    .. testoutput::
        :hide:
        :options: +ELLIPSIS

        ...
    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"

    def run(  # type: ignore
        self,
        input: ds.DatasetDict,
        new_splits: Dict[str, str],
        keep_old_splits: bool = True,
        shuffle_before: bool = False,
        shuffle_after: bool = False,
        random_seed: int = 1532637578,
    ) -> ds.DatasetDict:
        """
        Remixes and shuffles a dataset. This is done eagerly with native 🤗 Datasets features.

        :param input:
            The input dataset that will be remixed.
        :param new_splits:
            Specifies the new splits that the output dataset should have. Keys are the
            names of the new splits; values refer to the original splits, in any of these forms:

            * An original split name, to copy it to a new name.
            * An original split name with Python slicing syntax to select part of that
              split's instances, e.g. ``"train[:1000]"`` takes the first 1000 instances
              of ``"train"``.
            * ``"instances + instances"`` to concatenate instances into one split.

            These forms can be combined.
        :param keep_old_splits:
            Whether to keep the splits from the input dataset in addition to the new ones
            given by ``new_splits``.
        :param shuffle_before:
            Whether to shuffle the input splits before creating the new ones. Use this if
            you need shuffled instances and you're not sure the input is properly shuffled.
        :param shuffle_after:
            Whether to shuffle the output splits after creating them. Use this if you need
            shuffled instances and you're slicing or concatenating splits. To be safe,
            shuffle both before and after.
        :param random_seed:
            Random seed, affects shuffling.
        :returns:
            Returns a new dataset that is appropriately remixed.
        """
        if shuffle_before:
            input = input.shuffle(random_seed)

        def resolve_slice(spec: str) -> ds.Dataset:
            # ``spec`` is either a bare split name or "name[start:stop]".
            slice_match = re.match(r"(.*)\[(-?[0-9]*:-?[0-9]*)\]", spec)
            if slice_match is None:
                return input[spec]
            name = slice_match[1]
            bounds = [int(part) if len(part) > 0 else None for part in slice_match[2].split(":")]
            # Normalize the slice against the split's length, then select by index.
            indices = range(*slice(*bounds).indices(len(input[name])))
            return input[name].select(indices)

        def resolve_spec(spec: str):
            pieces = [resolve_slice(chunk.strip()) for chunk in spec.split("+")]
            if len(pieces) > 1:
                return ds.concatenate_datasets(pieces)
            return pieces[0]

        result = ds.DatasetDict(input.items()) if keep_old_splits else ds.DatasetDict()
        result.update(
            {new_name: resolve_spec(spec) for new_name, spec in new_splits.items()}
        )

        if shuffle_after:
            result = result.shuffle(random_seed)
        return result
================================================
FILE: tango/integrations/fairscale/__init__.py
================================================
"""
.. important::
To use this integration you should install ``tango`` with the "fairscale" extra
(e.g. ``pip install tango[fairscale]``) or just install FairScale after the fact.
This integration also depends on `PyTorch `_, so make sure you
install the correct version of torch *first* given your operating system and supported
CUDA version. Check `pytorch.org/get-started/locally/ `_
for more details.
Components for Tango integration with `FairScale `_.
Overview
--------
FairScale is a PyTorch library for large scale training. Among other things, it implements
the main memory-savings techniques for distributed data-parallel training (DDP) that came from the paper
`ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
`_.
The main part of this Tango integration is the :class:`FairScaleTrainingEngine`.
This is a :class:`~tango.integrations.torch.TrainingEngine` implementation that utilizes
FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` (FSDP) for substantial memory savings
during distributed training.
For the best performance you should also use :func:`with_wrapped_modules()` to wrap the inner modules
of your :class:`~tango.integrations.torch.Model`. When used with FSDP this will dramatically reduce
the memory required to load your model.
"""
from tango.common.exceptions import IntegrationMissingError
try:
import fairscale
except ModuleNotFoundError:
raise IntegrationMissingError("fairscale")
__all__ = [
"FairScaleTrainingEngine",
"FSDPConfig",
"with_wrapped_modules",
]
from .fsdp_config import FSDPConfig
from .module_wrapper import with_wrapped_modules
from .training_engine import FairScaleTrainingEngine
================================================
FILE: tango/integrations/fairscale/fsdp_config.py
================================================
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from tango.common import FromParams
@dataclass
class FSDPConfig(FromParams):
    """
    Collects the configurable options for FairScale's
    :class:`~fairscale.nn.FullyShardedDataParallel` in one place.

    .. seealso::
        `Best practices for FullyShardedDataParallel <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html>`_
        from the FairScale docs.
    """  # noqa: E501

    reshard_after_forward: bool = True
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
    """

    move_params_to_cpu: bool = False
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
    """

    move_grads_to_cpu: Optional[bool] = None
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.

    .. seealso::
        :data:`move_params_to_cpu`

    .. warning::
        At the moment we recommend that you don't mess with this parameter, or only explicitly
        set it to the same value as :data:`move_params_to_cpu`. If you leave it as ``None``
        (the default), it will automatically be set to match :data:`move_params_to_cpu` by FairScale.
        Currently training seems to crash if you set this ``False`` while :data:`move_params_to_cpu` is ``True``.
        We're tracking `fairscale#918 <https://github.com/facebookresearch/fairscale/issues/918>`_,
        which may be related.
    """

    mixed_precision: bool = False
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.

    .. important::
        We recommend setting this to the same value as the ``amp`` parameter in
        :class:`FairScaleTrainingEngine`.

        Based on our experiments, if you're training with AMP enabled (``amp=True``)
        you might see a small additional speedup in training time along with a small
        additional decrease in GPU memory utilization without any performance penalty
        (with respect to convergence) by setting this to ``True``.
        But if you're *not* training with AMP, setting this ``True`` could impact the
        model's ability to converge.
    """

    def as_kwargs(self) -> Dict[str, Any]:
        """
        Convert to the appropriate ``kwargs`` for :class:`~fairscale.nn.FullyShardedDataParallel`.
        """
        return asdict(self)

    def wrap(self, module: torch.nn.Module):
        """
        Convenience method that wraps ``module`` in
        :class:`~fairscale.nn.FullyShardedDataParallel` using the options in this config.

        .. seealso::
            Internally this is what :func:`with_wrapped_modules()` calls.
        """
        fsdp_kwargs = self.as_kwargs()
        return FSDP(module, **fsdp_kwargs)
================================================
FILE: tango/integrations/fairscale/module_wrapper.py
================================================
import re
from typing import Optional, Set
import torch
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper
from tango.integrations.torch import Model
from .fsdp_config import FSDPConfig
@Model.register("fairscale::with_wrapped_modules")  # type: ignore[arg-type]
def with_wrapped_modules(
    model: Model,
    modules_to_wrap: Set[str],
    fsdp_config: Optional[FSDPConfig] = None,
    activation_checkpointing: bool = False,
) -> Model:
    """
    A :class:`~tango.integrations.torch.Model` wrapper that wraps inner modules of a
    model with FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` wrapper
    and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.

    .. tip::
        Registered as a :class:`~tango.integrations.torch.Model` constructor under the name
        "fairscale::with_wrapped_modules".

    .. important::
        This is meant to be used with the :class:`FairScaleTrainingEngine`.

    :param model:
        The model to wrap.
    :param modules_to_wrap:
        The names of submodule to wrap. Can be regular expressions.
    :param fsdp_config:
        The ``FullyShardedDataParallel`` configuration to use when wrapping the modules.
        If not specified, the modules will NOT be wrapped with FSDP.
    :param activation_checkpointing:
        Whether to wrap the modules with FairScale's
        :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.

    Examples
    --------

    You can use this as a :class:`~tango.integrations.torch.Model` constructor from a
    config/params like this:

    .. testcode::

        import torch.nn as nn

        from tango.integrations.torch import Model


        class FeedForward(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(4, 4)
                self.activation = nn.ReLU()

            def forward(self, x):
                return self.activation(self.linear(x))


        @Model.register("simple_regression_model")
        class SimpleRegressionModel(Model):
            def __init__(self):
                super().__init__()
                self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])
                self.regression_head = nn.Linear(4, 1)
                self.loss_fcn = nn.MSELoss()

            def forward(self, x, y):
                output = self.blocks(x)
                output = self.regression_head(output)
                loss = self.loss_fcn(output, y)
                return {"loss": loss}


        model = Model.from_params({
            "type": "fairscale::with_wrapped_modules",
            "model": {
                "type": "simple_regression_model",
            },
            "modules_to_wrap": [r"blocks\\.[0-9]+", "regression_head"],
            "activation_checkpointing": True,
        })
    """

    def _wrap(submodule: nn.Module) -> nn.Module:
        # Apply activation checkpointing first so FSDP wraps the checkpointed module.
        if activation_checkpointing:
            submodule = checkpoint_wrapper(submodule, offload_to_cpu=True)
        # Only shard when a config was given AND we're actually in a distributed run.
        if fsdp_config is not None and torch.distributed.is_initialized():
            submodule = fsdp_config.wrap(submodule)
        return submodule

    all_module_names: Set[str] = {name for name, _ in model.named_modules() if name}
    # Names of submodules matched (via full regex match) by at least one pattern.
    actual_modules_to_wrap: Set[str] = {
        name
        for name in all_module_names
        if any(re.fullmatch(pattern, name) for pattern in modules_to_wrap)
    }
    # Patterns that matched nothing indicate a user mistake — fail loudly.
    unmatched_patterns: Set[str] = {
        pattern
        for pattern in modules_to_wrap
        if not any(re.fullmatch(pattern, name) for name in all_module_names)
    }
    if unmatched_patterns:
        raise ValueError(
            f"Some patterns in 'modules_to_wrap' did not match actual module names ({unmatched_patterns})"
        )

    for full_name in actual_modules_to_wrap:
        # Re-register the wrapped module on its parent under its original leaf name.
        if "." in full_name:
            *parent_path, leaf_name = full_name.split(".")
            parent_module = model.get_submodule(".".join(parent_path))
        else:
            parent_module, leaf_name = model, full_name
        parent_module.add_module(leaf_name, _wrap(parent_module.get_submodule(leaf_name)))

    return model
================================================
FILE: tango/integrations/fairscale/training_engine.py
================================================
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler
from tango.common import Lazy
from tango.common.exceptions import ConfigurationError
from tango.integrations.torch import (
LRScheduler,
Model,
Optimizer,
TorchTrainingEngine,
TrainConfig,
TrainingEngine,
)
from .fsdp_config import FSDPConfig
@TrainingEngine.register("fairscale")
class FairScaleTrainingEngine(TorchTrainingEngine):
"""
A :class:`~tango.integrations.torch.TrainingEngine` that leverages FairScale's
:class:`~fairscale.nn.FullyShardedDataParallel` for use within
:class:`~tango.integrations.torch.TorchTrainStep`.
.. tip::
Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name
"fairscale".
.. tip::
To get the best performance out of :class:`FairScaleTrainingEngine` you should
wrap individual layers of your model with :class:`~fairscale.nn.FullyShardedDataParallel`
and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`
while instantiating them. You can use :class:`with_wrapped_modules()` to accomplish this.
.. important::
Only the parameters listed below should be defined in a configuration
file. The other parameters will be automatically passed to the constructor
within :class:`~tango.integrations.torch.TorchTrainStep`.
.. warning::
:class:`~FairScaleTrainingEngine` can only be used in distributed training, i.e.
when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`.
For maximum memory savings, we recommend training with AMP enabled and the following
:class:`FSDPConfig`:
.. testcode::
from tango.integrations.fairscale import FSDPConfig
fsdp_config = FSDPConfig(
reshard_after_forward=True,
move_params_to_cpu=True,
move_grads_to_cpu=True,
mixed_precision=True,
)
For maximum training *speed*, we recommend training with AMP enabled and the following
:class:`FSDPConfig`:
.. testcode::
from tango.integrations.fairscale import FSDPConfig
fsdp_config = FSDPConfig(
reshard_after_forward=False,
move_params_to_cpu=False,
move_grads_to_cpu=False,
mixed_precision=True,
)
:param amp:
Use automatic mixed precision (AMP). Default is ``False``.
:param max_grad_norm:
If set, gradients will be clipped to have this max norm. Default is ``None``.
:param amp_use_bfloat16:
Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.
Only applicable when ``amp=True``. If not specified, the default behavior will be
to use ``bfloat16`` when training with AMP on CPU, otherwise not.
:param fsdp_config:
The options for :class:`~fairscale.nn.FullyShardedDataParallel`.
If not specified, the default options will be used.
"""
    def __init__(
        self,
        train_config: TrainConfig,
        model: Lazy[Model],
        optimizer: Lazy[Optimizer],
        *,
        lr_scheduler: Optional[Lazy[LRScheduler]] = None,
        amp: bool = False,
        max_grad_norm: Optional[float] = None,
        amp_use_bfloat16: Optional[bool] = None,
        fsdp_config: Optional[FSDPConfig] = None,
    ) -> None:
        """
        Initialize the FairScale training engine.

        :raises ConfigurationError: if ``train_config`` is not set up for
            distributed training — FSDP sharding requires multiple workers.
        """
        if not train_config.is_distributed:
            raise ConfigurationError(
                f"{self.__class__.__name__} can only be used with distributed training"
            )
        # NOTE(review): self.fsdp_config is deliberately assigned *before*
        # super().__init__(), presumably because the base constructor calls
        # _construct_model(), which reads self.fsdp_config — confirm against
        # the base TrainingEngine class.
        self.fsdp_config = fsdp_config or FSDPConfig()
        self.logger = logging.getLogger(self.__class__.__name__)
        super().__init__(
            train_config,
            model,
            optimizer,
            lr_scheduler=lr_scheduler,
            amp=amp,
            max_grad_norm=max_grad_norm,
            amp_use_bfloat16=amp_use_bfloat16,
        )
        if amp:
            # Sharded parameters/gradients need FairScale's sharded-aware grad
            # scaler; this overrides whatever scaler the base class set up.
            self.grad_scaler = ShardedGradScaler()
    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
        """
        Construct the model (if given lazily) and wrap it with FairScale's
        :class:`~fairscale.nn.FullyShardedDataParallel` using ``self.fsdp_config``.
        """
        if isinstance(model, Lazy):
            model = model.construct()
        if not self.fsdp_config.move_params_to_cpu:
            # Only move the model onto this worker's device when FSDP isn't
            # configured to keep the parameters on CPU anyway.
            model.to(self.train_config.worker_local_default_device)
        return FSDP(model, **self.fsdp_config.as_kwargs())
def clip_grad_norm(self) -> None:
if self.max_grad_norm is not None:
self.model.clip_grad_norm_(self.max_grad_norm) # type: ignore
def get_model_state(self) -> Dict[str, torch.Tensor]:
return {
"weights": self.model.local_state_dict(), # type: ignore
"metadata": self.model.local_metadata_dict(), # type: ignore
}
def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.model.load_local_state_dict(state_dict["weights"]) # type: ignore
    def save_complete_weights_from_checkpoint(
        self, checkpoint_dir: Path, weights_path: Path
    ) -> None:
        """
        Combine the per-worker sharded checkpoint files found in
        ``checkpoint_dir`` into one full state dict and save it to
        ``weights_path``.
        """
        self.logger.info("Consolidating sharded checkpoint weights...")
        sharded_weights: List[Dict[str, torch.Tensor]] = []
        sharded_metadata: List[Dict[str, Any]] = []
        # NOTE(review): glob() order is filesystem-dependent, so shards are not
        # necessarily visited in worker-rank order — presumably
        # consolidate_shard_weights() pairs each weights dict with its own
        # metadata entry and tolerates any order; verify against FairScale docs.
        for path in checkpoint_dir.resolve().glob("worker*_model.pt"):
            sharded_state = torch.load(path, map_location="cpu")
            sharded_weights.append(sharded_state["weights"])
            sharded_metadata.append(sharded_state["metadata"])
        full_state = FSDP.consolidate_shard_weights(sharded_weights, sharded_metadata)
        # Drop the shard lists before serializing to keep peak memory down.
        del sharded_weights
        del sharded_metadata
        torch.save(full_state, weights_path)
================================================
FILE: tango/integrations/flax/__init__.py
================================================
from tango.common.exceptions import IntegrationMissingError

# Fail fast with a tango-specific error if the optional flax dependency is
# not installed, instead of letting a bare ImportError escape later.
try:
    import flax
except ModuleNotFoundError:
    raise IntegrationMissingError("flax")

# Public API of the flax integration.
__all__ = [
    "DataLoader",
    "FlaxDataLoader",
    "LRScheduler",
    "Model",
    "Optimizer",
    "FlaxTrainStep",
    "FlaxFormat",
    "TrainCallback",
    "EvalCallback",
    "FlaxWrapper",
    "TrainConfig",
    "FlaxEvalStep",
]

# These imports are intentionally placed after the availability check above so
# that a missing flax install surfaces as IntegrationMissingError.
from .data import DataLoader, FlaxDataLoader
from .eval import FlaxEvalStep
from .eval_callback import EvalCallback
from .format import FlaxFormat
from .model import Model
from .optim import LRScheduler, Optimizer
from .train import FlaxTrainStep
from .train_callback import TrainCallback
from .train_config import TrainConfig
from .wrapper import FlaxWrapper
================================================
FILE: tango/integrations/flax/data.py
================================================
import logging
from typing import Generic, TypeVar
import jax.random
import numpy as np
from datasets import Dataset
from flax.training.common_utils import shard
from tango.common.registrable import Registrable
T = TypeVar("T")
class DataLoader(Generic[T], Registrable):
    """
    A :class:`~tango.common.Registrable` version of a ``Flax DataLoader``.

    A ``Flax DataLoader`` accepts a :class:`~datasets.Dataset` object, and
    implementations yield batches as numpy arrays.
    """
@DataLoader.register("flax::dataloader")
class FlaxDataLoader(DataLoader):
    """
    The default flax data loader: iterates a :class:`~datasets.Dataset`,
    yielding fixed-size numpy batches, optionally shuffled and sharded
    across devices.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 8,
        drop_last: bool = True,
        shuffle: bool = True,
    ):
        self.dataset = dataset
        self.dataset_size = dataset.num_rows
        self.batch_size = batch_size
        self.drop_last = drop_last
        if not drop_last:
            raise NotImplementedError(
                "With Jax you have to drop the last incomplete batch, because the batch size is compiled into the "
                "model."
            )
        self.shuffle = shuffle
        self.logger = logging.getLogger(FlaxDataLoader.__name__)

    def __call__(self, rng: jax._src.random.KeyArrayLike, do_distributed: bool):
        num_batches = self.dataset_size // self.batch_size

        if self.shuffle:
            shuffled = jax.random.permutation(rng, self.dataset_size)
            # using jax arrays for indexing is a bottleneck on TPUs.
            indices = np.asarray(shuffled)
        else:
            indices = np.arange(self.dataset_size)

        self.logger.info("Skipping last incomplete batch")
        # Truncate to a whole number of batches, then view as a batch grid.
        indices = indices[: num_batches * self.batch_size]
        indices = indices.reshape((num_batches, self.batch_size))

        for batch_indices in indices:
            batch = self.dataset[batch_indices]
            if do_distributed:
                batch = shard(batch)
            yield batch
================================================
FILE: tango/integrations/flax/eval.py
================================================
import logging
from collections import defaultdict
from itertools import islice
from typing import Dict, List, Optional, Sequence
import jax
from flax import jax_utils
from flax.training.train_state import TrainState
from tango.common.dataset_dict import DatasetDictBase
from tango.common.exceptions import ConfigurationError
from tango.common.lazy import Lazy
from tango.common.tqdm import Tqdm
from tango.format import Format, JsonFormat
from tango.step import Step
from .data import FlaxDataLoader
from .eval_callback import EvalCallback
from .util import get_PRNGkey
from .wrapper import FlaxWrapper
@Step.register("flax::eval")
class FlaxEvalStep(Step):
    """
    A Flax evaluation loop that pairs well with :class:`FlaxTrainStep`.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "flax::eval".

    .. important::
        The evaluation loop will use a GPU/TPU automatically if one is available.
        You can control which GPU it uses with the environment variable ``CUDA_VISIBLE_DEVICES``.
        For example, set ``CUDA_VISIBLE_DEVICES=1`` to force ``FlaxEvalStep`` to only use
        the GPU with ID 1.

    .. warning::
        By default the metrics specified by the ``metric_names`` parameter
        are aggregated by simply averaging across batches.
        This behavior is usually correct for metrics like "loss" or "accuracy",
        for example, but may not be correct for other metrics like "F1".

        If this is not correct for your metric you will need to handle the aggregation
        internally in your model or with an :class:`EvalCallback`
        using the :meth:`EvalCallback.post_batch()` method.
        Then set the parameter ``auto_aggregate_metrics`` to ``False``.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT: Format = JsonFormat()
    SKIP_ID_ARGUMENTS = {"log_every"}

    def run(  # type: ignore[override]
        self,
        state: TrainState,
        dataset: DatasetDictBase,
        dataloader: Lazy[FlaxDataLoader],
        wrapper: FlaxWrapper,
        test_split: str = "test",
        seed: int = 42,
        log_every: int = 1,
        do_distributed: bool = False,
        eval_steps: Optional[int] = None,
        metric_names: Sequence[str] = ("loss",),
        auto_aggregate_metrics: bool = True,
        callbacks: Optional[List[Lazy[EvalCallback]]] = None,
    ) -> Dict[str, float]:
        """
        Evaluate the ``model``.

        :param state:
            The state of the model to evaluate. This contains the parameters.
        :param dataset:
            Should contain the test data.
        :param dataloader:
            The data loader that generates test batches. The batches should be :class:`dict`
            objects.
        :param wrapper:
            The wrapper should define :meth:`eval_metrics`.
        :param test_split:
            The name of the data split used for evaluation in the ``dataset_dict``.
            Default is "test".
        :param seed:
            Used to set the PRNG states at the beginning of the evaluation loop.
        :param log_every:
            Log every this many steps. Default is ``1``.
        :param do_distributed:
            Whether to do distributed evaluation. Default is ``False``. Note that this
            is forced to ``True`` when more than one device is visible.
        :param eval_steps:
            The number of steps to evaluate for. If not specified evaluation will
            stop after a complete iteration through the ``dataloader``.
        :param metric_names:
            The names of the metrics to track and aggregate. Default is ``("loss",)``.
        :param auto_aggregate_metrics:
            If ``True`` (the default), the metrics will be averaged across batches.
            This may not be the correct behavior for some metrics (such as F1),
            in which you should set this to ``False`` and handle the aggregation
            internally in your model or with an :class:`EvalCallback`
            (using :meth:`EvalCallback.post_batch()`).
        :param callbacks:
            A list of :class:`EvalCallback`.
        """
        logger = logging.getLogger(FlaxEvalStep.__name__)

        # Construct the dataloader.
        # Fix: Dataset.set_format() operates in place and returns None, so it
        # must not be chained into construct() (that would pass dataset=None).
        # This now mirrors how the train step prepares its datasets.
        test_dataset = dataset[test_split]
        test_dataset.set_format("numpy")
        dataloader: FlaxDataLoader = dataloader.construct(dataset=test_dataset)

        steps: int
        try:
            dataloader_len = dataloader.dataset_size
            steps = dataloader_len if eval_steps is None else min(dataloader_len, eval_steps)
        except TypeError:
            if eval_steps is None:
                raise ConfigurationError(
                    "You must set 'eval_steps' for streaming/iterable datasets"
                )
            else:
                steps = eval_steps

        if do_distributed:
            devices = jax.devices()
            if len(devices) <= 1:
                raise ConfigurationError(
                    "You have set distributed training=True but there is only one device."
                )

        # Initialize callbacks
        callbacks: List[EvalCallback] = [
            callback.construct(
                step_id=self.unique_id,
                work_dir=self.work_dir,
                dataset_dict=dataset,
                dataloader=dataloader,
            )
            for callback in (callbacks or [])
        ]
        for callback in callbacks:
            callback.pre_eval_loop()

        rng = get_PRNGkey(seed)

        # Distributed evaluation is switched on automatically whenever more
        # than one device is visible.
        devices = jax.devices()
        if len(devices) > 1:
            do_distributed = True

        def eval_step(state, batch):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=state.params, train=False)[0]
            metrics = wrapper.eval_metrics(batch, logits, labels)
            if do_distributed:
                metrics = jax.lax.pmean(metrics, axis_name="batch")
            return logits, metrics

        if do_distributed:
            state = jax_utils.replicate(state)
            parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

        eval_batches = enumerate(islice(dataloader(rng, do_distributed), steps))

        running_metrics: Dict[str, float] = defaultdict(float)
        aggregated_metrics: Dict[str, float] = defaultdict(float)
        with Tqdm.tqdm(eval_batches, desc="Evaluating", total=steps) as batch_iter:
            for step, batch in batch_iter:
                should_log_this_step = step % log_every == 0 or step == steps - 1
                for callback in callbacks:
                    callback.pre_batch(step, batch)
                if do_distributed:
                    logits, metrics = parallel_eval_step(state, batch)
                    metrics = jax_utils.unreplicate(metrics)
                else:
                    logits, metrics = eval_step(state, batch)
                for callback in callbacks:
                    callback.post_batch(step, logits)
                if auto_aggregate_metrics:
                    for key, val in metrics.items():
                        if key in metric_names:
                            # Running mean over the batches seen so far.
                            running_metrics[key] += metrics[key].item()
                            aggregated_metrics[key] = running_metrics[key] / (step + 1)
                else:
                    aggregated_metrics.update(metrics)
                if should_log_this_step:
                    batch_iter.set_postfix(**aggregated_metrics)
                del batch

        logger.info("Evaluation Metrics:")
        for key, val in aggregated_metrics.items():
            # Fix: Logger.info() takes a printf-style format string; the old
            # print()-style call logger.info(key, ":", val) raised formatting
            # errors whenever `key` contained a '%'.
            logger.info("%s: %s", key, val)

        for callback in callbacks:
            callback.post_eval_loop(aggregated_metrics)

        return aggregated_metrics
================================================
FILE: tango/integrations/flax/eval_callback.py
================================================
from pathlib import Path
from typing import Any, Dict
from tango.common.dataset_dict import DatasetDictBase
from tango.common.registrable import Registrable
from tango.workspace import Workspace
from .data import FlaxDataLoader
class EvalCallback(Registrable):
    """
    An ``EvalCallback`` is a :class:`~tango.common.Registrable` class that can be used
    within :class:`FlaxEvalStep` to customize the behavior of the evaluation loop,
    similar to how :class:`TrainCallback` is used to customize the behavior of the training
    loop.

    .. tip::
        All of the parameters to this base class will be automatically set within
        the training loop, so you shouldn't include them in your config for your callbacks.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar str step_id: The unique ID of the step.
    :ivar pathlib.Path work_dir: The working directory of the step
    :ivar DatasetDictBase dataset_dict: The dataset dict containing the evaluation split.
    :ivar DataLoader dataloader: The data loader used to load the evaluation split data.
    """

    def __init__(
        self,
        workspace: Workspace,
        step_id: str,
        work_dir: Path,
        dataset_dict: DatasetDictBase,
        dataloader: FlaxDataLoader,
    ) -> None:
        # Stash everything the evaluation loop hands us; subclasses read these.
        self.workspace = workspace
        self.step_id = step_id
        self.work_dir = work_dir
        self.dataset_dict = dataset_dict
        self.dataloader = dataloader

    def pre_eval_loop(self) -> None:
        """
        Hook invoked once, right before the first batch is processed.
        """

    def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None:
        """
        Hook invoked once after the evaluation loop completes, with the final
        aggregated metrics. This is the last hook to run, so do any cleanup here.
        """

    def pre_batch(self, step: int, batch: Dict[str, Any]) -> None:
        """
        Hook invoked immediately before each batch is processed.
        """

    def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None:
        """
        Hook invoked immediately after each batch with the batch's outputs.

        .. tip::
            This method can be used to modify ``batch_outputs`` in place, which is useful
            in scenarios where you might need to aggregate metrics
            in a special way other than a simple average. If that's the case, make sure
            to set ``auto_aggregate_metrics`` to ``False`` in :class:`FlaxEvalStep`.
        """
================================================
FILE: tango/integrations/flax/format.py
================================================
from pathlib import Path
from typing import Generic, TypeVar
from flax.training import checkpoints
from tango.common.aliases import PathOrStr
from tango.format import Format
T = TypeVar("T")
@Format.register("flax")
class FlaxFormat(Format[T], Generic[T]):
    """
    This format writes the artifact using flax's checkpointing utilities.

    .. tip::
        Registered as a :class:`~tango.format.Format` under the name "flax".
    """

    VERSION = "002"

    def write(self, artifact: T, dir: PathOrStr) -> None:
        target_dir = Path(dir)
        # The artifact is always stored as checkpoint step 0.
        checkpoints.save_checkpoint(target_dir, artifact, step=0)

    def read(self, dir: PathOrStr) -> T:
        # target=None makes flax return a plain dict rather than re-hydrating
        # a structured object.
        return checkpoints.restore_checkpoint(dir, target=None)
================================================
FILE: tango/integrations/flax/model.py
================================================
from flax import linen as nn
from tango.common.registrable import Registrable
class Model(nn.Module, Registrable):
    """
    This is a :class:`~tango.common.Registrable` mixin class that inherits from
    :class:`flax.linen.Module`.

    Its :meth:`~flax.linen.Module.setup()` can be used to register submodules,
    variables, and parameters you will need in your model.
    Its :meth:`~flax.linen.Module.__call__()` returns the output of the model
    for a given input.
    """
================================================
FILE: tango/integrations/flax/optim.py
================================================
from inspect import isfunction
from typing import Callable, Type
import optax
from tango.common.registrable import Registrable
class Optimizer(Registrable):
    """
    A :class:`~tango.common.Registrable` version of Optax optimizers.

    All `built-in Optax optimizers
    `_
    are registered according to their class name (e.g. "optax::adam").

    .. tip::
        You can see a list of all available optimizers by running

        .. testcode::

            from tango.integrations.flax import Optimizer
            for name in sorted(Optimizer.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            optax::adabelief
            optax::adadelta
            optax::adafactor
            optax::adagrad
            optax::adam
            ...
    """

    def __init__(self, optimizer: Callable) -> None:
        # The wrapped optax function that builds the actual optimizer.
        self.optimizer = optimizer

    def __call__(self, **kwargs) -> optax.GradientTransformation:
        """Build the underlying optax :class:`~optax.GradientTransformation`."""
        build = self.optimizer
        return build(**kwargs)
class LRScheduler(Registrable):
    """
    A :class:`~tango.common.Registrable` version of an Optax learning
    rate scheduler.

    All `built-in Optax learning rate schedulers
    `_
    are registered according to their class name (e.g. "optax::linear_schedule").

    .. tip::
        You can see a list of all available schedulers by running

        .. testcode::

            from tango.integrations.flax import LRScheduler
            for name in sorted(LRScheduler.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            optax::constant_schedule
            optax::cosine_decay_schedule
            optax::cosine_onecycle_schedule
            optax::exponential_decay
            ...
    """

    def __init__(self, scheduler: Callable) -> None:
        # The wrapped optax function that builds the actual schedule.
        self.scheduler = scheduler

    def __call__(self, **kwargs):
        """Build the underlying optax schedule."""
        build = self.scheduler
        return build(**kwargs)
def optimizer_factory(optim_method: Callable) -> Optimizer:
    """
    Wrap a raw optax optimizer function in a registrable :class:`Optimizer` instance.

    :param optim_method: an optax function (e.g. ``optax.adam``) that returns an
        :class:`optax.GradientTransformation` when called.
    """
    # Fix: the previous version declared a Type[Callable] return annotation but
    # actually returned an Optimizer instance (it defined a zero-arg inner
    # factory and immediately *called* it). The annotation is corrected and the
    # redundant inner function removed; the returned value is identical.
    return Optimizer(optim_method)
def scheduler_factory(scheduler_method: Callable) -> LRScheduler:
    """
    Wrap a raw optax schedule function in a registrable :class:`LRScheduler` instance.

    :param scheduler_method: an optax schedule function
        (e.g. ``optax.linear_schedule``).
    """
    # Fix: mirrors optimizer_factory — the Type[Callable] annotation was wrong
    # (an LRScheduler instance is returned) and the immediately-invoked inner
    # factory added nothing. Behavior is unchanged.
    return LRScheduler(scheduler_method)
# Register all optimizers.
# NOTE(review): this iterates a private optax module (optax._src.alias); the
# module layout may change between optax releases — verify on upgrade.
for name, cls in optax._src.alias.__dict__.items():
    # Public, annotated functions are taken to be the optimizer constructors.
    if isfunction(cls) and not name.startswith("_") and cls.__annotations__:
        factory_func = optimizer_factory(cls)
        Optimizer.register("optax::" + name)(factory_func)

# Register all learning rate schedulers.
for name, cls in optax.schedules.__dict__.items():
    if isfunction(cls) and not name.startswith("_") and cls.__annotations__:
        factory_func = scheduler_factory(cls)
        LRScheduler.register("optax::" + name)(factory_func)

# TODO: Handle inject_hyperparams.
# Refer: https://optax.readthedocs.io/en/latest/api.html?highlight=inject%20hyperparam
================================================
FILE: tango/integrations/flax/train.py
================================================
import logging
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, DefaultDict, Dict, List, Optional
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training import checkpoints
from flax.training.train_state import TrainState
from tango.common.dataset_dict import DatasetDictBase
from tango.common.exceptions import ConfigurationError
from tango.common.lazy import Lazy
from tango.common.tqdm import Tqdm
from tango.format import Format
from tango.step import Step
from tango.workspace import Workspace
from .data import FlaxDataLoader
from .format import FlaxFormat
from .model import Model
from .optim import LRScheduler, Optimizer
from .train_callback import TrainCallback
from .train_config import TrainConfig
from .util import get_multiple_keys, get_PRNGkey
from .wrapper import FlaxWrapper
PyTree = Any
@Step.register("flax::train")
class FlaxTrainStep(Step):
"""
A Flax training step that supports distributed training with configurable dataloaders, callbacks,
optimizer.
.. tip::
Registered as a :class:`~tango.step.Step` under the name "flax::train".
.. important::
To train on GPUs and TPUs, installation of jax[cuda] or jax[tpu] is required. Follow the
instructions here: https://github.com/google/jax to set up jax for GPUs and TPUs.
Note: CUDA and cuDNN installation is required to run jax on NVidia GPUs. It is recommended to
install cuDNN in your conda environment using: ``conda install -c anaconda cudnn``.
Distributed data parallel training is activated when the ``device_count`` is greater than 1.
You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``.
For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1``
(and ``device_count`` to 2).
.. warning::
During validation, the validation metric (specified by the ``val_metric_name`` parameter)
is aggregated by simply averaging across validation batches and distributed processes.
This behavior is usually correct when your validation metric is "loss" or "accuracy",
for example, but may not be correct for other metrics like "F1".
If this is not correct for your metric you will need to handle the aggregation
internally in your model or with a :class:`TrainCallback`
using the :meth:`TrainCallback.post_val_batch()` method.
Then set the parameter ``auto_aggregate_val_metric`` to ``False``.
Jax pre-allocates 90% of GPU memory. If you run into out-of-memory (OOM) issues, please refer
to this: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
"""
DETERMINISTIC = True
CACHEABLE = True
FORMAT: Format = FlaxFormat()
SKIP_ID_ARGUMENTS = {"log_every"}
METADATA = {"artifact_kind": "model"}
    def run(  # type: ignore[override]
        self,
        model: Model,
        dataset: DatasetDictBase,
        optimizer: Lazy[Optimizer],
        train_dataloader: Lazy[FlaxDataLoader],
        *,
        wrapper: FlaxWrapper,
        seed: int = 42,
        keep_checkpoints: int = 5,
        lr_scheduler: Optional[Lazy[LRScheduler]] = None,
        train_split: str = "train",
        validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,
        validation_split: Optional[str] = None,
        train_steps: Optional[int] = None,
        train_epoch: Optional[int] = None,
        validation_steps: Optional[int] = None,
        log_every: int = 10,
        checkpoint_every: int = 100,
        validate_every: Optional[int] = None,
        val_metric_name: str = "loss",
        minimize_val_metric: bool = True,
        auto_aggregate_val_metric: bool = True,
        callbacks: Optional[List[Lazy[TrainCallback]]] = None,
        remove_stale_checkpoints: bool = True,
    ) -> PyTree:
        """
        Run a basic training loop to train the ``model``.

        :param model:
            The flax model to train. It should define ``__call__()``. Defining ``setup()`` is Optional.
        :param dataset:
            The train and optional validation dataset.
        :param optimizer:
            The name of the optax Optimizer to use for training.
        :param train_dataloader:
            The dataloader object that generates training batches.
        :param wrapper:
            A wrapper class that defines the loss and metric functions used by the
            trainer (e.g. ``train_loss()``, ``train_metrics()``, ``val_metrics()``).
        :param seed:
            Used to set the PRNG state. By default, ``seed=42``
        :param keep_checkpoints:
            An integer which denotes how many previous checkpoints should be stored while training.
            By default, ``keep_checkpoints=5``
        :param lr_scheduler:
            The name of the learning rate scheduler.
        :param train_split:
            The name of the data split used for training in the ``dataset_dict``.
            Default is "train".
        :param validation_dataloader:
            An optional data loader for generating validation batches. The batches should be
            :class:`dict` objects. If not specified, but ``validation_split`` is given,
            the validation ``DataLoader`` will be constructed from the same parameters
            as the train ``DataLoader``.
        :param validation_split:
            Optional name of the validation split in the ``dataset_dict``. Default is ``None``,
            which means no validation.
        :param train_steps:
            The number of steps to train for. If not specified training will
            stop after a complete iteration through the ``train_dataloader``.
        :param train_epoch:
            The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epoch``
            at the same time.
        :param validation_steps:
            The number of steps to validate for. If not specified validation
            will stop after a complete iteration through the ``validation_dataloader``.
        :param log_every:
            Log every this many steps.
        :param checkpoint_every:
            Save a checkpoint every this many steps.
        :param validate_every:
            Run the validation loop every this many steps.
        :param val_metric_name:
            The name of the validation metric, i.e. the key of the metric in the dictionary
            returned by the forward pass of the model. Default is "loss".
        :param minimize_val_metric:
            Whether the validation metric is meant to be minimized (such as the loss).
            Default is ``True``. When using a metric such as accuracy, you should set
            this to ``False``.
        :param auto_aggregate_val_metric:
            If ``True`` (the default), the validation metric will be averaged across
            validation batches and distributed processes. This may not be the correct
            behavior for some metrics (such as F1), in which you should set this to
            ``False`` and handle the aggregation internally in your model
            or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`).
        :param callbacks:
            A list of :class:`TrainCallback`.
        :param remove_stale_checkpoints:
            If ``True`` (the default), stale checkpoints will be removed throughout training so that
            only the latest and best checkpoints are kept.

        :returns:
            The trained model with the last checkpoint loaded.
        """
        # Thin public entry point: all work is delegated to _train(), which
        # normalizes the 'train_epoch' parameter name to 'train_epochs'.
        return self._train(
            dataset=dataset,
            model=model,
            optimizer=optimizer,
            train_dataloader=train_dataloader,
            train_wrapper=wrapper,
            seed=seed,
            keep_checkpoints=keep_checkpoints,
            lr_scheduler=lr_scheduler,
            train_split=train_split,
            validation_split=validation_split,
            validation_dataloader=validation_dataloader,
            train_steps=train_steps,
            train_epochs=train_epoch,
            validation_steps=validation_steps,
            log_every=log_every,
            checkpoint_every=checkpoint_every,
            validate_every=validate_every,
            val_metric_name=val_metric_name,
            minimize_val_metric=minimize_val_metric,
            auto_aggregate_val_metric=auto_aggregate_val_metric,
            callbacks=callbacks,
            remove_stale_checkpoints=remove_stale_checkpoints,
        )
def _train(
self,
model: Model,
optimizer: Lazy[Optimizer],
dataset: DatasetDictBase,
train_dataloader: Lazy[FlaxDataLoader],
*,
train_wrapper: FlaxWrapper,
seed: int = 42,
keep_checkpoints: int = 5,
lr_scheduler: Optional[Lazy[LRScheduler]],
train_split: str = "train",
validation_split: Optional[str] = None,
validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,
train_steps: Optional[int] = None,
train_epochs: Optional[int] = None,
validation_steps: Optional[int] = None,
log_every: int = 10,
checkpoint_every: int = 100,
validate_every: Optional[int] = None,
val_metric_name: str = "loss",
minimize_val_metric: bool = True,
auto_aggregate_val_metric: bool = True,
callbacks: Optional[List[Lazy[TrainCallback]]] = None,
remove_stale_checkpoints: bool = True,
) -> PyTree:
if validate_every is not None and validation_split is None:
raise ConfigurationError(
"You have set a validation interval, but no validation split. "
"That's probably unintentional."
)
if (train_steps is not None) and (train_epochs is not None):
raise ConfigurationError(
"One of 'train_steps' or 'train_epochs' needs to be specified, but not both."
)
if isinstance(dataset, DatasetDictBase) and train_split is None:
raise ConfigurationError("Specify the train split for Datasets object.")
config = TrainConfig(
self.unique_id,
self.work_dir,
step_name=self.name,
train_split=train_split,
validation_split=validation_split,
seed=seed,
train_steps=train_steps,
train_epochs=train_epochs,
log_every=log_every,
checkpoint_every=checkpoint_every,
validate_every=validate_every,
validation_steps=validation_steps,
val_metric_name=val_metric_name,
minimize_val_metric=minimize_val_metric,
auto_aggregate_val_metric=auto_aggregate_val_metric,
remove_stale_checkpoints=remove_stale_checkpoints,
)
optimizer = self._construct_optimizer(optimizer)
lr_scheduler_: Optional[LRScheduler] = None
if lr_scheduler is not None:
lr_scheduler_ = self._construct_lr_scheduler(lr_scheduler)
lr_scheduler = lr_scheduler_
final_model: Model
final_model = self.train_helper(
self.workspace,
config,
model,
optimizer,
keep_checkpoints,
lr_scheduler,
train_wrapper,
dataset,
train_dataloader,
validation_dataloader,
callbacks,
)
assert final_model is not None
return final_model
def train_helper(
self,
workspace: Workspace,
config: TrainConfig,
model: Model,
optimizer: Optimizer,
keep_checkpoints: int,
lr_scheduler: Optional[LRScheduler],
train_wrapper: FlaxWrapper,
dataset: DatasetDictBase,
train_dataloader: Lazy[FlaxDataLoader],
validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,
callbacks: Optional[List[Lazy[TrainCallback]]] = None,
) -> PyTree:
if lr_scheduler is not None:
raise NotImplementedError(
"Learning rate scheduling is not supported by the flax trainer. "
"Please voice your support for this feature at "
"https://github.com/allenai/tango/issues/477."
)
logger = logging.getLogger(FlaxTrainStep.__name__)
# construct data loaders
validation_dataloader_: Optional[FlaxDataLoader] = None
if config.validation_split is not None:
validation_dataset = dataset[config.validation_split]
validation_dataset.set_format("numpy")
if validation_dataloader is not None:
validation_dataloader_ = validation_dataloader.construct(dataset=validation_dataset)
else:
validation_dataloader_ = train_dataloader.construct(dataset=validation_dataset)
validation_dataloader = validation_dataloader_
train_dataset = dataset[config.train_split]
train_dataset.set_format("numpy") # type:ignore
train_dataloader: FlaxDataLoader = train_dataloader.construct(dataset=train_dataset)
devices = self._get_devices()
do_distributed: bool = False
if len(devices) > 1:
do_distributed = True
if validation_dataloader is not None:
validation_dataloader.batch_size *= len(devices)
train_dataloader.batch_size *= len(devices)
rng = get_PRNGkey(config.seed)
if hasattr(model, "params"):
params = model.params
else:
# TODO: Find better way to init the shape
shape = list(train_dataset["x"].shape)
shape[0] = 1
x = jnp.ones(shape)
params = model.init(rng, x)["params"]
state = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer)
initial_state: Optional[Dict[str, Any]] = None
if config.state_path.exists():
logger.info("Recovering from previous run at %s" % config.state_path)
state = self.load_checkpoint(config.state_path, state)
if config.train_epochs is None:
assert config.train_steps is not None
try:
train_epochs = len(train_dataloader.dataset) // train_dataloader.batch_size
except TypeError:
raise ConfigurationError(
"You must set train_epochs for streaming/iterable datasets"
)
config.train_epochs = train_epochs
assert config.train_epochs is not None
if validation_dataloader is not None:
if config.validation_steps is None:
try:
config.validation_steps = len(validation_dataloader.dataset)
except TypeError:
raise ConfigurationError(
"You must set 'validation_steps' for streaming/iterable datasets"
)
val_metric: Optional[float] = None
best_val_metric: Optional[float] = None
start_step: int = 0
if initial_state is not None:
val_metric = initial_state[f"val_{config.val_metric_name}"]
best_val_metric = initial_state[f"best_{config.val_metric_name}"]
start_step = initial_state["training_epochs"]
# Initialize callbacks
callbacks: List[TrainCallback] = [
callback.construct(
workspace=workspace,
train_config=config,
dataset=dataset,
train_dataloader=train_dataloader,
model=model,
optimizer=optimizer,
validation_dataloader=validation_dataloader,
)
for callback in (callbacks or [])
]
if initial_state:
for callback, state in zip(callbacks, initial_state["callbacks"]):
callback.load_state_dict(state)
del initial_state
if start_step > 0:
with Tqdm.tqdm(
train_dataloader,
desc=f"Catching dataloader up to step {start_step}",
total=start_step - 1,
) as batch_iter:
for step, batch in enumerate(batch_iter):
del batch
if step >= start_step - 1:
break
def train_step(state, batch, dropout_rng):
# if transformer model
labels = batch.pop("labels")
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
grad_fn = jax.value_and_grad(train_wrapper.train_loss)
loss, grad = grad_fn(state.params, state, batch, dropout_rng, labels)
if do_distributed:
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
other_metrics = train_wrapper.train_metrics(state, batch, labels=labels)
metrics = {"loss": loss}
metrics.update(other_metrics)
if do_distributed:
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics, new_dropout_rng
def val_step(state, batch):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
metrics = train_wrapper.val_metrics(batch, logits, labels)
if do_distributed:
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics
if do_distributed:
# NOTE: The trainer currently handles only data parallelism.
state = jax_utils.replicate(state)
dropout_rngs = get_multiple_keys(rng, jax.local_device_count())
parallel_train_step = jax.pmap(train_step, axis_name="batch")
parallel_val_step = jax.pmap(val_step, axis_name="batch")
step_per_epoch = train_dataloader.dataset_size // train_dataloader.batch_size
config.train_steps = step_per_epoch * config.train_epochs
assert config.train_steps is not None # for mypy
for callback in callbacks:
callback.pre_train_loop()
logger.info("***** Running training *****")
logger.info(f" Num examples = {train_dataloader.dataset_size}")
logger.info(f" Num Epochs = {config.train_epochs}")
logger.info(
f" Total train batch size (w. parallel & distributed) = {train_dataloader.batch_size}"
)
logger.info(f" Total optimization steps = {config.train_steps}")
step = start_step
epochs = Tqdm.tqdm(
range(config.train_epochs), desc=f"Epoch (1/{config.train_epochs})", position=0
)
for epoch in epochs:
start = time.time()
train_metrics = []
for callback in callbacks:
callback.pre_epoch(step, epoch)
train_loader = train_dataloader(rng, do_distributed)
for _ in Tqdm.tqdm(range(step_per_epoch), desc="Training", position=1):
batch = next(train_loader)
for callback in callbacks:
callback.pre_batch(step, epoch, batch)
if do_distributed:
state, train_metric, dropout_rngs = parallel_train_step(
state, batch, dropout_rngs
)
else:
state, train_metric, rng = train_step(state, batch, rng)
train_metrics.append(train_metric)
for callback in callbacks:
callback.post_batch(step, epoch, train_metric)
if config.should_log_this_step(step):
for callback in callbacks:
callback.log_batch(step, epoch, train_metric)
if config.should_checkpoint_this_step(step):
self.save_checkpoint(config.state_path, state, step, keep_checkpoints)
step += 1
# check if we need to do validation
if config.validation_split is None:
# If we can't validate, we don't.
should_validate = False
elif step == config.train_steps - 1:
# If we're at the end of the training run, we always validate.
should_validate = True
elif config.validate_every is not None and step % config.validate_every == 0:
# If validate_every is given, we use that to decide.
should_validate = True
else:
# Otherwise, we don't validate.
should_validate = False
if should_validate:
assert validation_dataloader is not None
assert config.validation_steps is not None
val_metrics: DefaultDict = defaultdict(list)
epoch_eval_metrics: DefaultDict = defaultdict(float)
val_dataloader = validation_dataloader(rng, do_distributed)
valid_step = 0
total_val_steps = len(validation_dataset) // validation_dataloader.batch_size
for callback in callbacks:
callback.pre_val_loop(step, valid_step, state)
for _ in Tqdm.tqdm(range(total_val_steps), desc="Evaluating", position=2):
batch = next(val_dataloader)
for callback in callbacks:
callback.pre_val_batch(step, valid_step, epoch, batch)
if do_distributed:
metrics = parallel_val_step(state, batch)
metrics = jax_utils.unreplicate(metrics)
else:
metrics = val_step(state, batch)
for key, value in metrics.items():
val_metrics[key].append(value.item())
for callback in callbacks:
callback.post_val_batch(step, valid_step, epoch, val_metrics)
valid_step += 1
for key, value in val_metrics.items():
if config.auto_aggregate_val_metric:
epoch_eval_metrics[key] = jax.tree_map(
jnp.mean, jnp.array(value)
).item()
else:
epoch_eval_metrics[key] = metrics[key].item()
for key, value in epoch_eval_metrics.items():
print("Validation %s : %.5f" % (key, value))
val_metric = epoch_eval_metrics[config.val_metric_name]
assert val_metric is not None
if best_val_metric is None:
best_val_metric = val_metric
elif config.minimize_val_metric and val_metric <= best_val_metric:
best_val_metric = val_metric
elif not config.minimize_val_metric and val_metric >= best_val_metric:
best_val_metric = val_metric
for callback in callbacks:
callback.post_val_loop(step, epoch, val_metric, best_val_metric)
if do_distributed:
train_metric = jax_utils.unreplicate(train_metric)
for key, value in train_metric.items():
print("Train %s : %.2f" % (key, value))
for callback in callbacks:
callback.post_epoch(step, epoch)
end = time.time()
train_time = (end - start) / 60
desc = f"Epoch... ({epoch + 1}/{config.train_epochs} | Time taken (mins): {train_time})"
epochs.write(desc)
epochs.desc = desc
for callback in callbacks:
callback.post_train_loop(step, epoch)
if do_distributed:
state = jax_utils.unreplicate(state)
return state
def save_checkpoint(self, dir: Path, target: PyTree, step: int, keep_checkpoints: int):
    """
    Persist ``target`` as a checkpoint for ``step`` under ``dir``.

    Delegates to ``flax.training.checkpoints.save_checkpoint``, keeping at most
    ``keep_checkpoints`` checkpoints and overwriting any existing one for this step.
    """
    result = checkpoints.save_checkpoint(
        dir,
        target,
        step,
        prefix="checkpoint_",
        keep=keep_checkpoints,
        overwrite=True,
    )
    return result
def load_checkpoint(self, dir: Path, target: PyTree):
    """
    Restore the latest ``checkpoint_``-prefixed checkpoint found in ``dir`` into ``target``.
    """
    restored = checkpoints.restore_checkpoint(dir, target, prefix="checkpoint_")
    return restored
def _construct_optimizer(self, optimizer):
self.optimizer = optimizer.construct()
return self.optimizer
def _construct_lr_scheduler(self, scheduler):
self.lr_scheduler = scheduler.construct()
return self.lr_scheduler
def _get_devices(self) -> List[Any]:
device_type = jax.default_backend()
self.devices = jax.devices()
device_count = len(self.devices)
print("Training on %d %s" % (device_count, device_type))
return self.devices
================================================
FILE: tango/integrations/flax/train_callback.py
================================================
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from tango.common.dataset_dict import DatasetDictBase
from tango.common.registrable import Registrable
from tango.workspace import Workspace
from .data import DataLoader
from .model import Model
from .optim import Optimizer
from .train_config import TrainConfig
class TrainCallback(Registrable):
    """
    A :class:`TrainCallback` is a :class:`~tango.common.Registrable` class
    that can be used within :class:`FlaxTrainStep` to customize behavior in the training
    loop. You can set the training callbacks with the ``callbacks`` parameter to :class:`FlaxTrainStep`.

    .. tip::
        All of the parameters to this base class will be automatically set within
        the training loop, so you shouldn't include them in your config for your callbacks.

    .. tip::
        You can access the model being trained through :attr:`self.model`.

    .. important::
        The ``step`` argument to callback methods is the total/overall number of training steps
        so far, independent of the current epoch.

    .. seealso::
        See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example
        implementation.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar TrainConfig train_config: The training config.
    :ivar tango.common.DatasetDictBase dataset: The dataset dict containing train and
        optional validation splits.
    :ivar DataLoader train_dataloader: The dataloader used for the training split.
    :ivar Model model: The flax model being trained.
    :ivar Optimizer optimizer: The optimizer being used for training.
    :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split.
    """

    def __init__(
        self,
        workspace: Workspace,
        train_config: TrainConfig,
        dataset: DatasetDictBase,
        train_dataloader: DataLoader,
        model: Model,
        optimizer: Optimizer,
        validation_dataloader: Optional[DataLoader] = None,
    ) -> None:
        self.workspace = workspace
        self.train_config = train_config
        self.dataset = dataset
        self.train_dataloader = train_dataloader
        self.model = model
        self.optimizer = optimizer
        self.validation_dataloader = validation_dataloader
        # Logger named after the concrete callback subclass.
        self.logger = logging.getLogger(self.__class__.__name__)

    @property
    def step_id(self) -> str:
        """
        The unique ID of the current :class:`~tango.Step`.
        """
        return self.train_config.step_id

    @property
    def step_name(self) -> Optional[str]:
        """
        The name of the current :class:`~tango.Step`.
        """
        return self.train_config.step_name

    @property
    def work_dir(self) -> Path:
        """
        The working directory of the current train step.
        """
        return self.train_config.work_dir

    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.

        Some callbacks need to maintain state across restarts. This is the callback's opportunity to
        save its state. It will be restored using :meth:`load_state_dict`.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Load the state on a restart.

        Some callbacks need to maintain state across restarts. This is the callback's opportunity to
        restore its state. It gets saved using :meth:`state_dict`.
        """
        pass

    def pre_train_loop(self) -> None:
        """
        Called right before the first batch is processed, or after a restart.
        """
        pass

    def post_train_loop(self, step: int, epoch: int) -> None:
        """
        Called after the training loop completes.

        This is the last method that is called, so any cleanup can be done in this method.
        """
        pass

    def pre_epoch(self, step: int, epoch: int) -> None:
        """
        Called before start of an epoch. Epochs start at 0.
        """
        pass

    def post_epoch(self, step: int, epoch: int) -> None:
        """
        Called after an epoch is completed. Epochs start at 0.
        """
        pass

    def pre_batch(self, step: int, epoch: int, batch) -> None:
        """
        Called directly before processing a batch.
        """

    def post_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:
        """
        Called directly after processing a batch, but before unscaling gradients,
        clipping gradients, and taking an optimizer step.

        .. note::
            The ``train_metrics`` here is the dictionary with train metrics of the
            current batch. If doing distributed training, use ``jax_utils.unreplicate(train_metrics)``
            before using train_metrics.

            If you need the average loss, use :meth:`log_batch()`.
        """
        pass

    def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:
        """
        Called after the optimizer step. Here ``train_metrics`` is the average metrics across
        all distributed workers. If doing distributed training, use
        ``jax_utils.unreplicate(train_metrics)`` before using train_metrics.

        .. note::
            This callback method is not necessarily called on every step.
            The frequency depends on the value of the ``log_every`` parameter of
            :class:`FlaxTrainStep`.
        """
        pass

    def pre_val_loop(self, step: int, val_step: int, state) -> None:
        """
        Called right before the validation loop starts.
        """
        pass

    def pre_val_batch(self, step: int, val_step: int, epoch: int, val_batch) -> None:
        """
        Called right before a validation batch is processed.
        """
        pass

    def post_val_batch(self, step: int, val_step: int, epoch: int, val_metrics: Dict) -> None:
        """
        Called right after a validation batch is processed with the outputs of the batch.

        .. tip::
            This method can be used to modify ``val_metrics`` in place, which is useful
            in scenarios like distributed training where you might need to aggregate metrics
            in a special way other than a simple average. If that's the case, make sure
            to set ``auto_aggregate_val_metric`` to ``False`` in :class:`FlaxTrainStep`.
        """
        pass

    def post_val_loop(
        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]
    ) -> None:
        """
        Called right after the evaluation loop finishes.
        """
        pass
================================================
FILE: tango/integrations/flax/train_config.py
================================================
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Optional
@dataclass
class TrainConfig:
    """
    Encapsulates the parameters of :class:`FlaxTrainStep`. This is used to pass all the training
    options to :class:`TrainCallback`.
    """

    step_id: str
    """
    The unique ID of the current step.
    """

    work_dir: Path
    """
    The working directory for the training run.
    """

    step_name: Optional[str] = None
    """
    The name of the current step.
    """

    train_split: str = "train"
    """
    The name of the training split.
    """

    validation_split: Optional[str] = None
    """
    The name of the validation split.
    """

    seed: int = 42
    """
    The random seed used for pseudo-random number generation during training.
    """

    train_steps: Optional[int] = None
    """
    The number of steps to train for.
    """

    train_epochs: Optional[int] = None
    """
    The number of epochs to train for.

    You cannot specify `train_steps` and `train_epochs` at the same time.
    """

    validation_steps: Optional[int] = None
    """
    The number of validation steps.
    """

    log_every: int = 10
    """
    Controls the frequency of log updates.
    """

    checkpoint_every: int = 100
    """
    Controls the frequency of checkpoints.
    """

    validate_every: Optional[int] = None
    """
    Controls the frequency of the validation loop.
    """

    is_distributed: bool = False
    """
    Whether or not the training job is distributed.
    """

    val_metric_name: str = "loss"
    """
    The name of the validation metric to track.
    """

    minimize_val_metric: bool = True
    """
    Should be ``True`` when the validation metric being tracked should be minimized.
    """

    auto_aggregate_val_metric: bool = True
    """
    Controls automatic aggregation of validation metric.
    """

    remove_stale_checkpoints: bool = True
    """
    Controls removal of stale checkpoints.
    """

    @property
    def state_path(self) -> Path:
        """
        The path to the latest state checkpoint file.
        """
        return self.work_dir / "checkpoint_state_latest"

    @property
    def best_state_path(self) -> Path:
        """
        The path to the best state checkpoint file according to the validation metric or training
        loss (if no validation split is given).
        """
        return self.work_dir / "checkpoint_state_best"

    def should_log_this_step(self, step: int) -> bool:
        """
        Return ``True`` if training metrics should be logged at ``step``.

        Logs on the very first step, every ``log_every`` steps (1-based), and on
        the final step. Requires ``train_steps`` to have been resolved already.
        """
        assert self.train_steps is not None  # narrowing for mypy; set by the trainer
        return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1

    def should_checkpoint_this_step(self, step: int) -> bool:
        """
        Return ``True`` if a checkpoint should be written at ``step``.

        Checkpoints every ``checkpoint_every`` steps (1-based) and on the final step.
        """
        assert self.train_steps is not None  # narrowing for mypy; set by the trainer
        return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1

    def should_log_this_val_step(self, val_step: int) -> bool:
        """
        Return ``True`` if validation metrics should be logged at ``val_step``.

        Logs every ``log_every`` validation steps (0-based) and on the final one.
        """
        assert self.validation_steps is not None  # narrowing for mypy; set by the trainer
        return val_step % self.log_every == 0 or val_step == self.validation_steps - 1

    def as_dict(self) -> Dict[str, Any]:
        """
        Return the config as a plain dictionary, omitting any private (underscore) fields.
        """
        return {k: v for k, v in asdict(self).items() if not k.startswith("_")}
================================================
FILE: tango/integrations/flax/util.py
================================================
from typing import Any, Union
import jax
def get_PRNGkey(seed: int = 42) -> Any:
    """
    Utility function to create a pseudo-random number generator key
    given a seed.

    :param seed: The seed to fold into the key.
    :return: The result of ``jax.random.PRNGKey(seed)``.
    """
    # Annotated as ``Any``: the old ``jax._src.random.KeyArray`` annotation used a
    # private module that newer jax releases removed, which made importing this
    # module crash. ``Union[Any, X]`` also collapsed to ``Any`` anyway.
    return jax.random.PRNGKey(seed)
def get_multiple_keys(key, multiple: int = 1) -> Any:
    """
    Utility function to split a PRNG key into multiple new keys.
    Used in distributed training.

    :param key: The PRNG key to split.
    :param multiple: How many new keys to derive.
    :return: The result of ``jax.random.split(key, multiple)``.
    """
    # Annotated as ``Any``: the old ``jax._src.random.KeyArray`` annotation used a
    # private module removed from newer jax releases (import-time crash).
    return jax.random.split(key, multiple)
================================================
FILE: tango/integrations/flax/wrapper.py
================================================
from abc import abstractmethod
from typing import Dict
from tango.common.registrable import Registrable
class FlaxWrapper(Registrable):
    """
    Base class collecting the user-defined functions required by the
    ``flax::train`` and ``flax::eval`` steps.
    """

    def train_metrics(self, state, batch, labels) -> Dict:
        """
        Compute training metrics other than the loss.

        The default implementation reports nothing; override to add metrics.
        """
        # No extra metrics unless a subclass provides them.
        return {}

    @abstractmethod
    def train_loss(self, params, state, batch, dropout_rng, labels):
        """
        Run the forward pass and return the loss for the batch as a jax device
        array. The gradient of this function is what drives training.
        """
        raise NotImplementedError()

    @abstractmethod
    def val_metrics(self, batch, logits, labels) -> Dict:
        """
        Compute and return the validation metrics as a Dict.
        """
        raise NotImplementedError()

    @abstractmethod
    def eval_metrics(self, batch, logits, labels) -> Dict:
        """
        Compute and return the evaluation metrics as a Dict.
        """
        raise NotImplementedError()
================================================
FILE: tango/integrations/gs/__init__.py
================================================
"""
.. important::
To use this integration you should install ``tango`` with the "gs" extra
(e.g. ``pip install tango[gs]``) or just install the `gcsfs `_
library after the fact (e.g. ``pip install gcsfs``).
Components for Tango integration with `GS `_.
"""
from tango.common.exceptions import IntegrationMissingError
try:
from google.cloud import datastore, storage
except (ModuleNotFoundError, ImportError):
raise IntegrationMissingError("gs", dependencies={"google-cloud-storage"})
from .step_cache import GSStepCache
from .workspace import GSWorkspace
__all__ = [
"GSStepCache",
"GSWorkspace",
]
================================================
FILE: tango/integrations/gs/common.py
================================================
"""
Classes and utility functions for GSWorkspace and GSStepCache.
"""
import atexit
import datetime
import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
import google.auth
from google.api_core import exceptions
from google.auth.credentials import Credentials
from google.cloud import storage
from google.oauth2.credentials import Credentials as OAuth2Credentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from tango.common.aliases import PathOrStr
from tango.common.exceptions import TangoError
from tango.common.remote_utils import RemoteConstants
from tango.step import Step
from tango.step_info import StepInfo
logger = logging.getLogger(__name__)
def get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:
    """
    Split a folder name into its bucket name and subfolder prefix (empty if absent).
    """
    bucket, _, prefix = folder_name.partition("/")
    return bucket, prefix
def empty_bucket_folder(folder_name: str):
    """
    Remove every tango-related blob from the given bucket folder.
    Used for testing.
    """
    creds, project = google.auth.default()
    gs_client = storage.Client(project=project, credentials=creds)
    bucket_name, subfolder = get_bucket_and_prefix(folder_name)
    # Only blobs created by tango (the "tango-" prefix) are deleted.
    tango_prefix = f"{subfolder}/tango-" if subfolder else "tango-"
    bucket = gs_client.bucket(bucket_name)
    try:
        bucket.delete_blobs(list(bucket.list_blobs(prefix=tango_prefix)))
    except exceptions.NotFound:
        # Nothing there to delete.
        pass
def empty_datastore(folder_name: str):
    """
    Remove every tango-related entity from the given namespace subfolder in datastore.
    Used for testing.
    """
    from google.cloud import datastore

    creds, project = google.auth.default()
    namespace, subfolder = get_bucket_and_prefix(folder_name)
    run_kind = f"{subfolder}/run" if subfolder else "run"
    stepinfo_kind = f"{subfolder}/stepinfo" if subfolder else "stepinfo"
    ds_client = datastore.Client(project=project, credentials=creds, namespace=namespace)
    # Collect the keys of all run and stepinfo entities, then delete them in one call.
    keys = []
    for kind in (run_kind, stepinfo_kind):
        query = ds_client.query(kind=kind)
        query.keys_only()
        keys.extend(entity.key for entity in query.fetch())
    ds_client.delete_multi(keys)
@dataclass
class GSArtifact:
    """
    Represents a single storage object ("artifact") in google cloud storage.
    """

    name: str
    """
    Name of the artifact.
    """

    artifact_path: str
    """
    Remote location url for the artifact.
    """

    created: datetime.datetime
    """
    Time of creation.
    """

    committed: bool
    """
    ``True`` means the remote artifact is finalized and immutable;
    ``False`` means it is still under construction.
    """
class GSArtifactConflict(TangoError):
    """
    Raised when the storage artifact being created already exists remotely.
    """
class GSArtifactNotFound(TangoError):
    """
    Raised when the requested storage artifact does not exist remotely.
    """
class GSArtifactWriteError(TangoError):
    """
    Raised when writing an artifact to the remote storage fails.
    """
def join_path(*args) -> str:
    """
    Join path segments with ``/`` and strip outer slashes.

    Needed because ``os.path.join`` is not suitable for cloud storage paths.
    """
    joined = "/".join(args)
    return joined.strip("/")
class GSClient:
    """
    A client for interacting with Google Cloud Storage. The authorization works by
    providing OAuth2 credentials.

    :param folder_name: The name of the Google Cloud bucket folder to use.
    :param credentials: OAuth2 credentials can be provided. If not provided, default
        gcloud credentials are inferred.
    :param project: Optionally, the project ID can be provided. This is not essential
        for `google.cloud.storage` API, since buckets are at the account level, rather
        than the project level.
    """

    placeholder_file = ".placeholder"
    """
    The placeholder file is used for creation of a folder in the cloud bucket folder,
    as empty folders are not allowed. It is also used as a marker for the creation
    time of the folder, hence we use a separate file to mark the artifact as
    uncommitted.
    """

    uncommitted_file = ".uncommitted"
    """
    The uncommitted file is used to denote an artifact under construction.
    """

    settings_file = "settings.json"
    """
    This file is for storing metadata like version information, etc.
    """

    # Thread-pool size used by upload() and download().
    NUM_CONCURRENT_WORKERS: int = 9

    def __init__(
        self,
        folder_name: str,
        credentials: Optional[Credentials] = None,
        project: Optional[str] = None,
    ):
        if not credentials:
            # Fall back to application-default credentials from the environment.
            credentials, project = google.auth.default()
        self.storage = storage.Client(project=project, credentials=credentials)
        self.folder_name = folder_name
        self.bucket_name, self.prefix = get_bucket_and_prefix(folder_name)
        settings_file = self._gs_path(self.settings_file)
        blob = self.storage.bucket(self.bucket_name).blob(settings_file)  # no HTTP request yet
        try:
            # First request happens here; NotFound means the folder was never initialized.
            with blob.open("r") as file_ref:
                json.load(file_ref)
        except exceptions.NotFound:
            # First use of this folder: write an initial settings file with the version.
            settings = {"version": 1}
            with blob.open("w") as file_ref:
                json.dump(settings, file_ref)

    def url(self, artifact: Optional[str] = None) -> str:
        """
        Returns the remote url of the storage artifact.
        """
        path = f"gs://{self.folder_name}"
        if artifact is not None:
            path = f"{path}/{artifact}"
        return path

    def _convert_blobs_to_artifact(self, blobs: List[storage.Blob]) -> GSArtifact:
        """
        Converts a list of `google.cloud.storage.Blob` to a `GSArtifact`.
        """
        name: str
        artifact_path: str
        created: datetime.datetime
        committed: bool = True
        for blob in blobs:
            if blob.name.endswith(self.placeholder_file):
                # The placeholder blob marks the folder itself and carries its creation time.
                created = blob.time_created
                name = blob.name.replace("/" + self.placeholder_file, "")
                if self.prefix:
                    name = name.replace(self.prefix + "/", "")
                artifact_path = name  # does not contain bucket info here.
            elif blob.name.endswith(self.uncommitted_file):
                # Presence of the uncommitted marker means construction is unfinished.
                committed = False
        # NOTE(review): if no placeholder blob is present, `name` is unbound and this
        # line raises UnboundLocalError rather than the AssertionError the message
        # suggests — confirm whether that is intended.
        assert name is not None, "Folder is not a GSArtifact, should not have happened."
        return GSArtifact(name, artifact_path, created, committed)

    @classmethod
    def from_env(cls, folder_name: str):
        """
        Constructs the client object from the environment, using default credentials.
        """
        return cls(folder_name)

    def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
        """
        Returns a `GSArtifact` object created by fetching the artifact's information
        from remote location.

        :raises GSArtifactNotFound: If no blobs exist under the artifact's path.
        """
        if isinstance(artifact, str):
            path = artifact
        else:
            # We have an artifact, and we recreate it with refreshed info.
            path = artifact.artifact_path
        prefix = self._gs_path(path)
        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=prefix))
        if len(blobs) > 0:
            return self._convert_blobs_to_artifact(blobs)
        else:
            raise GSArtifactNotFound()

    def _gs_path(self, *args) -> str:
        """
        Returns path within google cloud storage bucket.
        """
        return join_path(self.prefix, *args)

    def create(self, artifact: str):
        """
        Creates a new artifact in the remote location. By default, it is uncommitted.

        :raises GSArtifactConflict: If the artifact already exists.
        """
        bucket = self.storage.bucket(self.bucket_name)
        # gives refreshed information
        artifact_path = self._gs_path(artifact, self.placeholder_file)
        # NOTE(review): exists()-then-upload is not atomic; two concurrent creators
        # could both pass these checks — confirm this race is acceptable.
        if bucket.blob(artifact_path).exists():
            raise GSArtifactConflict(f"{artifact} already exists!")
        else:
            # Additional safety check
            if bucket.blob(self._gs_path(artifact, self.uncommitted_file)).exists():
                raise GSArtifactConflict(f"{artifact} already exists!")
            bucket.blob(self._gs_path(artifact, self.placeholder_file)).upload_from_string("")
            bucket.blob(self._gs_path(artifact, self.uncommitted_file)).upload_from_string("")
        return self._convert_blobs_to_artifact(
            list(bucket.list_blobs(prefix=self._gs_path(artifact)))
        )

    def delete(self, artifact: GSArtifact):
        """
        Removes the artifact from the remote location.
        """
        bucket = self.storage.bucket(self.bucket_name)
        prefix = self._gs_path(artifact.artifact_path)
        blobs = list(bucket.list_blobs(prefix=prefix))
        bucket.delete_blobs(blobs)

    def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
        """
        Writes the contents of objects_dir to the remote artifact location.

        :raises GSArtifactWriteError: If any upload fails.
        """
        if isinstance(artifact, str):
            folder_path = artifact
        else:
            folder_path = artifact.artifact_path
        source_path = str(objects_dir)

        def _sync_blob(source_file_path: str, target_file_path: str):
            # Upload one local file to its corresponding remote blob.
            blob = self.storage.bucket(self.bucket_name).blob(self._gs_path(target_file_path))
            blob.upload_from_filename(source_file_path)

        import concurrent.futures

        try:
            # TODO: google-cloud-storage==2.7.0 has added a preview feature called `transfer_manager`
            # which allows for concurrent uploads and downloads. We should upgrade to this once
            # it is more robustly supported. Also update in `download()`.
            if os.path.isdir(source_path):
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix="GSClient.upload()-"
                ) as executor:
                    upload_futures = []
                    for dirpath, _, filenames in os.walk(source_path):
                        for filename in filenames:
                            source_file_path = os.path.join(dirpath, filename)
                            # Remote path mirrors the local path relative to source_path.
                            target_file_path = join_path(
                                folder_path, source_file_path.replace(source_path + "/", "")
                            )
                            upload_futures.append(
                                executor.submit(_sync_blob, source_file_path, target_file_path)
                            )
                    for future in concurrent.futures.as_completed(upload_futures):
                        future.result()  # re-raises any upload exception
            else:
                # Single-file upload.
                source_file_path = source_path
                target_file_path = join_path(folder_path, os.path.basename(source_file_path))
                _sync_blob(source_file_path, target_file_path)
        except Exception:
            raise GSArtifactWriteError()

    def commit(self, artifact: Union[str, GSArtifact]):
        """
        Marks the artifact as committed. No further changes to the artifact are allowed.

        :raises GSArtifactNotFound: If the artifact does not exist at all.
        """
        if isinstance(artifact, str):
            folder_path = artifact
        else:
            folder_path = artifact.artifact_path
        bucket = self.storage.bucket(self.bucket_name)
        try:
            # Committing == deleting the uncommitted marker file.
            bucket.delete_blob(self._gs_path(folder_path, self.uncommitted_file))
        except exceptions.NotFound:
            if not bucket.blob(self._gs_path(folder_path, self.placeholder_file)).exists():
                raise GSArtifactNotFound()
            # Otherwise, already committed. No change.

    def download(self, artifact: GSArtifact, target_dir: PathOrStr):
        """
        Writes the contents of the remote artifact to the `target_dir`.

        :raises GSArtifactWriteError: If a blob disappears mid-download.
        """
        assert (
            self.storage.bucket(self.bucket_name)
            .blob(self._gs_path(artifact.artifact_path, self.placeholder_file))
            .exists()
        )

        def _fetch_blob(blob: storage.Blob):
            # Local path mirrors the blob path relative to the artifact root.
            source_path = blob.name.replace(artifact.artifact_path + "/", "")
            target_path = os.path.join(target_dir, source_path)
            # NOTE(review): os.mkdir creates only one directory level, and two threads
            # could race on the same parent — os.makedirs(..., exist_ok=True) looks
            # intended for blobs nested more than one directory deep; confirm.
            if not os.path.exists(os.path.dirname(target_path)):
                os.mkdir(os.path.dirname(target_path))
            blob.download_to_filename(target_path)

        import concurrent.futures

        bucket = self.storage.bucket(self.bucket_name)
        # We may not need updates that frequently, with list_blobs(prefix).
        # bucket.update()
        try:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix="GSClient.download()-"
            ) as executor:
                download_futures = []
                prefix = self._gs_path(artifact.artifact_path)
                for blob in bucket.list_blobs(prefix=prefix):
                    download_futures.append(executor.submit(_fetch_blob, blob))
                for future in concurrent.futures.as_completed(download_futures):
                    future.result()  # re-raises any download exception
        except exceptions.NotFound:
            raise GSArtifactWriteError()

    def artifacts(self, prefix: str, uncommitted: bool = True) -> List[GSArtifact]:
        """
        Lists all the artifacts within the remote storage, based on
        `match` and `uncommitted` criteria. These can include steps and runs.
        """
        list_of_artifacts = []
        prefix = self._gs_path(prefix)
        # NOTE(review): `_get_next_page_response()` is a private pager API and returns
        # only the FIRST page of "prefixes" — folders beyond the first page would be
        # silently skipped; confirm folder counts stay within one page.
        for folder_name in self.storage.list_blobs(
            self.bucket_name, prefix=prefix, delimiter="/"
        )._get_next_page_response()["prefixes"]:
            artifact = self._convert_blobs_to_artifact(
                list(self.storage.list_blobs(self.bucket_name, prefix=folder_name))
            )
            if not uncommitted:
                # Caller asked for committed artifacts only.
                if not artifact.committed:
                    continue
            list_of_artifacts.append(artifact)
        return list_of_artifacts
def get_credentials(credentials: Optional[Union[str, Credentials]] = None) -> Credentials:
    """
    :param credentials:
        * if OAuth2 credentials are provided, they are returned.
        * if `str`, it can be either a file path or a json string of credentials dict.
        * if `None`, credentials are inferred from the environment.

    More details on Google Cloud credentials can be found here:
    https://googleapis.dev/python/google-auth/latest/user-guide.html#service-account-private-key-files,
    and https://googleapis.dev/python/google-api-core/latest/auth.html
    """
    # BeakerExecutor uses GOOGLE_TOKEN
    # NOTE(review): when GOOGLE_TOKEN is set it overrides any explicitly passed
    # credentials argument — confirm that precedence is intended.
    credentials = os.environ.get("GOOGLE_TOKEN", credentials)
    if credentials is not None:
        # Path to the credentials file has been provided
        if isinstance(credentials, str) and credentials.endswith(".json"):
            with open(credentials) as file_ref:
                credentials = file_ref.read()
        try:
            # If credentials dict has been passed as a json string
            credentials_dict = json.loads(credentials)
            if credentials_dict.pop("type", None) == "service_account":
                credentials = ServiceAccountCredentials.from_service_account_info(credentials_dict)
            else:
                # sometimes the credentials dict may not contain `token` and `token_uri` keys,
                # but `Credentials()` needs the parameter.
                token = credentials_dict.pop("token", None)
                token_uri = credentials_dict.pop("token_uri", "https://oauth2.googleapis.com/token")
                credentials = OAuth2Credentials(
                    token=token, token_uri=token_uri, **credentials_dict
                )
        except (json.decoder.JSONDecodeError, TypeError, ValueError):
            # It is not a json string.
            # We use this string because BeakerExecutor cannot write a None secret.
            if credentials == "default":
                credentials = None
    if not credentials:
        # Infer default credentials
        credentials, _ = google.auth.default()
    return credentials
def get_client(
    folder_name: str,
    credentials: Optional[Union[str, Credentials]] = None,
    project: Optional[str] = None,
) -> GSClient:
    """
    Construct a `GSClient` for a google cloud bucket folder, first resolving
    the given credentials via :func:`get_credentials`.
    """
    resolved_credentials = get_credentials(credentials)
    return GSClient(folder_name, credentials=resolved_credentials, project=project)
class Constants(RemoteConstants):
    # GS workspaces reuse the shared remote-workspace naming constants unchanged;
    # GS-specific additions would go here.
    pass
class GCSStepLock:
    """
    A mutual-exclusion lock for a step, implemented as a GS artifact.

    Google Cloud offers consistency https://cloud.google.com/storage/docs/consistency,
    so we can use lock files.
    """

    def __init__(
        self,
        client: GSClient,
        step: Union[str, StepInfo, Step],
    ):
        self._client = client
        self._step_id = step if isinstance(step, str) else step.unique_id
        self._lock_artifact_name = RemoteConstants.step_lock_artifact_name(step)
        # Non-None only while this instance holds the lock.
        self._lock_artifact: Optional[GSArtifact] = None
        self.lock_artifact_url = self._client.url(self._lock_artifact_name)

    def acquire(self, timeout=None, poll_interval: float = 2.0, log_interval: float = 30.0) -> None:
        """
        Block until the lock artifact can be created, polling every ``poll_interval``
        seconds.

        :raises TimeoutError: If the lock cannot be acquired within ``timeout`` seconds.
        """
        if self._lock_artifact is not None:
            # Already held by this instance; acquire is idempotent.
            return
        start = time.monotonic()
        last_logged = None
        while timeout is None or (time.monotonic() - start < timeout):
            try:
                # Creating the artifact IS acquiring the lock; a conflict means
                # someone else holds it.
                self._lock_artifact = self._client.create(self._lock_artifact_name)
                # Release the lock even if the process exits abnormally.
                atexit.register(self.release)
            except GSArtifactConflict:
                # NOTE(review): `last_logged - start` measures time from start to the
                # previous warning, not time since it — once elapsed time exceeds
                # log_interval this warns on every poll; `time.monotonic() - last_logged`
                # looks intended. Confirm.
                if last_logged is None or last_logged - start >= log_interval:
                    logger.warning(
                        "Waiting to acquire lock artifact for step '%s':\n\n%s\n\n"
                        "This probably means the step is being run elsewhere, but if you're sure it isn't "
                        "you can just delete the lock artifact, using the command: \n`gsutil rm -r %s`",
                        self._step_id,
                        self.lock_artifact_url,
                        self.lock_artifact_url,
                    )
                    last_logged = time.monotonic()
                time.sleep(poll_interval)
                continue
            else:
                break
        else:
            # Loop condition failed: timed out without acquiring the lock.
            raise TimeoutError(
                f"Timeout error occurred while waiting to acquire artifact lock for step '{self._step_id}':\n\n"
                f"{self.lock_artifact_url}\n\n"
                f"This probably means the step is being run elsewhere, but if you're sure it isn't you can "
                f"just delete the lock, using the command: \n`gsutil rm -r {self.lock_artifact_url}`"
            )

    def release(self):
        """
        Delete the lock artifact. No-op if this instance does not hold the lock.
        """
        if self._lock_artifact is not None:
            try:
                self._client.delete(self._lock_artifact)
            except GSArtifactNotFound:
                # Artifact must have been manually deleted.
                pass
            self._lock_artifact = None
            atexit.unregister(self.release)

    def __del__(self):
        # Best-effort release when the lock object is garbage-collected.
        self.release()
================================================
FILE: tango/integrations/gs/step_cache.py
================================================
import logging
from pathlib import Path
from typing import Optional, Union
from tango.common import PathOrStr
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.integrations.gs.common import (
Constants,
GSArtifact,
GSArtifactConflict,
GSArtifactNotFound,
GSArtifactWriteError,
GSClient,
get_bucket_and_prefix,
)
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache
from tango.step_info import StepInfo
logger = logging.getLogger(__name__)
@StepCache.register("gs")
class GSStepCache(RemoteStepCache):
"""
This is a :class:`~tango.step_cache.StepCache` that's used by :class:`GSWorkspace`.
It stores the results of steps on Google cloud buckets as blobs.
It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
step's resulting subsequent times should be fast.
.. tip::
Registered as a :class:`~tango.step_cache.StepCache` under the name "gs".
:param folder_name: The name of the google cloud bucket folder to use.
:param client: The google cloud storage client to use.
"""
Constants = Constants
def __init__(self, folder_name: str, client: Optional[GSClient] = None):
if client is not None:
bucket_name, _ = get_bucket_and_prefix(folder_name)
assert (
bucket_name == client.bucket_name
), "Assert that bucket name is same as client bucket until we do better"
self.folder_name = folder_name
self._client = client
else:
self._client = GSClient(folder_name)
super().__init__(tango_cache_dir() / "gs_cache" / make_safe_filename(folder_name))
@property
def client(self):
return self._client
def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[GSArtifact]:
"""
Returns a `GSArtifact` object containing the details of the step.
This only returns if the step has been finalized (committed).
"""
try:
artifact = self.client.get(self.Constants.step_artifact_name(step))
return artifact if artifact.committed else None
except GSArtifactNotFound:
return None
def _upload_step_remote(self, step: Step, objects_dir: Path) -> GSArtifact:
"""
Uploads the step's output to remote location.
"""
artifact_name = self.Constants.step_artifact_name(step)
try:
self.client.create(artifact_name)
except GSArtifactConflict:
pass
try:
self.client.upload(artifact_name, objects_dir)
self.client.commit(artifact_name)
except GSArtifactWriteError:
pass
return self.client.get(artifact_name)
def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
"""
Downloads the step's output from remote location.
"""
try:
self.client.download(step_result, target_dir)
except GSArtifactNotFound:
raise RemoteNotFoundError()
def __len__(self):
"""
Returns the number of committed step outputs present in the remote location.
"""
# NOTE: lock files should not count here.
return sum(
1
for ds in self.client.artifacts(
prefix=self.Constants.STEP_ARTIFACT_PREFIX, uncommitted=False
)
if ds.name is not None
and ds.name.startswith(self.Constants.STEP_ARTIFACT_PREFIX)
and not ds.name.endswith(self.Constants.LOCK_ARTIFACT_SUFFIX)
)
================================================
FILE: tango/integrations/gs/workspace.py
================================================
import json
import random
from pathlib import Path
from typing import (
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from urllib.parse import ParseResult
import petname
from google.auth.credentials import Credentials
from google.cloud import datastore
from tango.common.util import utc_now_datetime
from tango.integrations.gs.common import (
Constants,
GCSStepLock,
get_bucket_and_prefix,
get_client,
get_credentials,
)
from tango.integrations.gs.step_cache import GSStepCache
from tango.step import Step
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace
from tango.workspaces.remote_workspace import RemoteWorkspace
T = TypeVar("T")
@Workspace.register("gs")
class GSWorkspace(RemoteWorkspace):
"""
This is a :class:`~tango.workspace.Workspace` that stores step artifacts on Google Cloud Storage.
.. tip::
Registered as a :class:`~tango.workspace.Workspace` under the name "gs".
:param workspace: The name or ID of the Google Cloud bucket folder to use.
:param project: The Google project ID. This is required for the datastore. If not provided,
it will be inferred from the Google cloud credentials.
.. important::
Credentials can be provided in the following ways:
- Using the `credentials` keyword argument:
- You can specify the path to the credentials json file.
- You can specify the `google.oauth2.credentials.Credentials()` object.
- You can specify the json string of credentials dict.
- Using the default credentials: You can use your default google cloud credentials by running
`gcloud auth application-default login`. If you are using `GSWorkspace` with
:class:`~tango.integrations.beaker.BeakerExecutor`, you will need to set the environment variable
`GOOGLE_TOKEN` to the credentials json file. The default location is usually
`~/.config/gcloud/application_default_credentials.json`.
"""
Constants = Constants
NUM_CONCURRENT_WORKERS = 32
def __init__(
self,
workspace: str,
project: Optional[str] = None,
credentials: Optional[Union[str, Credentials]] = None,
):
credentials = get_credentials(credentials)
self.client = get_client(folder_name=workspace, credentials=credentials, project=project)
self.client.NUM_CONCURRENT_WORKERS = self.NUM_CONCURRENT_WORKERS
self._cache = GSStepCache(workspace, client=self.client)
self._locks: Dict[Step, GCSStepLock] = {}
super().__init__()
project = project or self.client.storage.project or credentials.quota_project_id
self.bucket_name, self.prefix = get_bucket_and_prefix(workspace)
self._ds = datastore.Client(
namespace=self.bucket_name, project=project, credentials=credentials
)
@property
def cache(self):
return self._cache
@property
def locks(self):
return self._locks
@property
def steps_dir_name(self):
return "gs_workspace"
@classmethod
def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
workspace: str
if parsed_url.netloc and parsed_url.path:
# e.g. "gs://ai2/my-workspace"
workspace = parsed_url.netloc + parsed_url.path
elif parsed_url.netloc:
# e.g. "gs://my-workspace"
workspace = parsed_url.netloc
else:
raise ValueError(f"Bad URL for GS workspace '{parsed_url}'")
return cls(workspace)
@property
def url(self) -> str:
return self.client.url()
def _remote_lock(self, step: Step) -> GCSStepLock:
return GCSStepLock(self.client, step)
def _step_location(self, step: Step) -> str:
return self.client.url(self.Constants.step_artifact_name(step))
@property
def _run_key(self):
return self.client._gs_path("run")
@property
def _stepinfo_key(self):
return self.client._gs_path("stepinfo")
def _save_run(
self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None
) -> Run:
if name is None:
while True:
name = petname.generate() + str(random.randint(0, 100))
if not self._ds.get(self._ds.key(self._run_key, name)):
break
else:
if self._ds.get(self._ds.key(self._run_key, name)):
raise ValueError(f"Run name '{name}' is already in use")
run_entity = self._ds.entity(
key=self._ds.key(self._run_key, name), exclude_from_indexes=("steps",)
)
# Even though the run's name is part of the key, we add this as a
# field so we can index on it and order asc/desc (indices on the key field don't allow ordering).
run_entity["name"] = name
run_entity["start_date"] = utc_now_datetime()
run_entity["steps"] = json.dumps(run_data).encode()
self._ds.put(run_entity)
return Run(name=cast(str, name), steps=steps, start_date=run_entity["start_date"])
def _get_run_from_entity(self, run_entity: datastore.Entity) -> Optional[Run]:
try:
steps_info_bytes = run_entity["steps"]
steps_info = json.loads(steps_info_bytes)
except KeyError:
return None
import concurrent.futures
steps: Dict[str, StepInfo] = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.NUM_CONCURRENT_WORKERS,
thread_name_prefix="GSWorkspace._get_run_from_dataset()-",
) as executor:
step_info_futures = []
for unique_id in steps_info.values():
step_info_futures.append(executor.submit(self.step_info, unique_id))
for future in concurrent.futures.as_completed(step_info_futures):
step_info = future.result()
assert step_info.step_name is not None
steps[step_info.step_name] = step_info
return Run(name=run_entity.key.name, start_date=run_entity["start_date"], steps=steps)
def registered_runs(self) -> Dict[str, Run]:
import concurrent.futures
runs: Dict[str, Run] = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.NUM_CONCURRENT_WORKERS,
thread_name_prefix="GSWorkspace.registered_runs()-",
) as executor:
run_futures = []
for run_entity in self._ds.query(kind=self._run_key).fetch():
run_futures.append(executor.submit(self._get_run_from_entity, run_entity))
for future in concurrent.futures.as_completed(run_futures):
run = future.result()
if run is not None:
runs[run.name] = run
return runs
def search_registered_runs(
self,
*,
sort_by: Optional[RunSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
start: int = 0,
stop: Optional[int] = None,
) -> List[RunInfo]:
run_entities = self._fetch_run_entities(
sort_by=sort_by, sort_descending=sort_descending, match=match, start=start, stop=stop
)
return [
RunInfo(name=e.key.name, start_date=e["start_date"], steps=json.loads(e["steps"]))
for e in run_entities
]
def num_registered_runs(self, *, match: Optional[str] = None) -> int:
count = 0
for _ in self._fetch_run_entities(match=match):
count += 1
return count
def _fetch_run_entities(
self,
*,
sort_by: Optional[RunSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
start: int = 0,
stop: Optional[int] = None,
) -> Generator[datastore.Entity, None, None]:
from itertools import islice
# Note: we can't query or order by multiple fields without a suitable
# composite index. So in that case we have to apply remaining filters
# or slice and order locally. We'll default to using 'match' in the query.
# But if 'match' is null we can sort with the query.
sort_locally = bool(match)
sort_field: Optional[str] = None
if sort_by == RunSort.START_DATE:
sort_field = "start_date"
elif sort_by == RunSort.NAME:
sort_field = "name"
elif sort_by is not None:
raise NotImplementedError(sort_by)
order: List[str] = []
if sort_field is not None and not sort_locally:
order = [sort_field if not sort_descending else f"-{sort_field}"]
query = self._ds.query(kind=self._run_key, order=order)
if match:
# HACK: Datastore has no direct string matching functionality,
# but this comparison is equivalent to checking if 'name' starts with 'match'.
query.add_filter("name", ">=", match)
query.add_filter("name", "<=", match[:-1] + chr(ord(match[-1]) + 1))
entity_iter: Iterable[datastore.Entity] = query.fetch(
offset=0 if sort_locally else start,
limit=None if (stop is None or sort_locally) else stop - start,
)
if sort_field is not None and sort_locally:
entity_iter = sorted(
entity_iter, key=lambda entity: entity[sort_field], reverse=sort_descending
)
if sort_locally:
entity_iter = islice(entity_iter, start, stop)
for entity in entity_iter:
yield entity
def search_step_info(
self,
*,
sort_by: Optional[StepInfoSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
state: Optional[StepState] = None,
start: int = 0,
stop: Optional[int] = None,
) -> List[StepInfo]:
step_info_entities = self._fetch_step_info_entities(
sort_by=sort_by,
sort_descending=sort_descending,
match=match,
state=state,
start=start,
stop=stop,
)
return [
StepInfo.from_json_dict(json.loads(e["step_info_dict"])) for e in step_info_entities
]
def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:
count = 0
for _ in self._fetch_step_info_entities(match=match, state=state):
count += 1
return count
def _fetch_step_info_entities(
self,
*,
sort_by: Optional[StepInfoSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
state: Optional[StepState] = None,
start: int = 0,
stop: Optional[int] = None,
) -> Generator[datastore.Entity, None, None]:
from itertools import islice
# Note: we can't query or order by multiple fields without a suitable
# composite index. So in that case we have to apply remaining filters
# or slice and order locally. We'll default to using 'match' in the query.
# But if 'match' is null, we'll use 'state' to filter in the query.
# If 'state' is also null, we can sort with the query.
sort_locally = sort_by is not None and (match is not None or state is not None)
filter_locally = state is not None and match is not None
slice_locally = sort_locally or filter_locally
sort_field: Optional[str] = None
if sort_by == StepInfoSort.START_TIME:
sort_field = "start_time"
elif sort_by == StepInfoSort.UNIQUE_ID:
sort_field = "step_id"
elif sort_by is not None:
raise NotImplementedError(sort_by)
order: List[str] = []
if sort_field is not None and not sort_locally:
order = [sort_field if not sort_descending else f"-{sort_field}"]
query = self._ds.query(kind=self._stepinfo_key, order=order)
if match is not None:
# HACK: Datastore has no direct string matching functionality,
# but this comparison is equivalent to checking if 'step_id' starts with 'match'.
query.add_filter("step_id", ">=", match)
query.add_filter("step_id", "<=", match[:-1] + chr(ord(match[-1]) + 1))
elif state is not None and not filter_locally:
query.add_filter("state", "=", str(state.value))
entity_iter: Iterable[datastore.Entity] = query.fetch(
offset=0 if slice_locally else start,
limit=None if (stop is None or slice_locally) else stop - start,
)
if state is not None and filter_locally:
entity_iter = filter(lambda entity: entity["state"] == state, entity_iter)
if sort_field is not None and sort_locally:
entity_iter = sorted(
entity_iter, key=lambda entity: entity[sort_field], reverse=sort_descending
)
if slice_locally:
entity_iter = islice(entity_iter, start, stop)
for entity in entity_iter:
yield entity
def registered_run(self, name: str) -> Run:
err_msg = f"Run '{name}' not found in workspace"
run_entity = self._ds.get(key=self._ds.key(self._run_key, name))
if not run_entity:
raise KeyError(err_msg)
run = self._get_run_from_entity(run_entity)
if run is None:
raise KeyError(err_msg)
else:
return run
def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
unique_id = (
step_or_unique_id if isinstance(step_or_unique_id, str) else step_or_unique_id.unique_id
)
step_info_entity = self._ds.get(key=self._ds.key(self._stepinfo_key, unique_id))
if step_info_entity is not None:
step_info_bytes = step_info_entity["step_info_dict"]
step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
return step_info
else:
if not isinstance(step_or_unique_id, Step):
raise KeyError(step_or_unique_id)
step_info = StepInfo.new_from_step(step_or_unique_id)
self._update_step_info(step_info)
return step_info
def _step_info_multiple(
self, step_or_unique_ids: Union[List[Step], List[str]]
) -> List[StepInfo]:
"""
This method is to combine all calls to the datastore api in a single transaction.
"""
all_unique_id_keys = []
for step_or_unique_id in step_or_unique_ids:
unique_id = (
step_or_unique_id
if isinstance(step_or_unique_id, str)
else step_or_unique_id.unique_id
)
key = self._ds.key(self._stepinfo_key, unique_id)
all_unique_id_keys.append(key)
missing: List = []
step_info_entities = self._ds.get_multi(keys=all_unique_id_keys, missing=missing)
missing_steps = [entity.key.name for entity in missing]
step_infos = []
for step_info_entity in step_info_entities:
step_info_bytes = step_info_entity["step_info_dict"]
step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
step_infos.append(step_info)
for step_or_unique_id in step_or_unique_ids:
step_id = (
step_or_unique_id
if isinstance(step_or_unique_id, str)
else step_or_unique_id.unique_id
)
if step_id in missing_steps:
if not isinstance(step_or_unique_id, Step):
raise KeyError(step_or_unique_id)
step_info = StepInfo.new_from_step(step_or_unique_id)
self._update_step_info(step_info)
step_infos.append(step_info)
return step_infos
def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]:
all_steps = set(targets)
for step in targets:
all_steps |= step.recursive_dependencies
steps: Dict[str, StepInfo] = {}
run_data: Dict[str, str] = {}
all_valid_steps = [step for step in all_steps if step.name is not None]
step_infos = self._step_info_multiple(all_valid_steps)
for step_info in step_infos:
assert step_info.step_name is not None
steps[step_info.step_name] = step_info
run_data[step_info.step_name] = step_info.unique_id
return steps, run_data
def _update_step_info(self, step_info: StepInfo):
step_info_entity = self._ds.entity(
key=self._ds.key(self._stepinfo_key, step_info.unique_id),
exclude_from_indexes=("step_info_dict",),
)
# Even though the step's unique ID is part of the key, we add this as a
# field so we can index on it and order asc/desc (indices on the key field don't allow ordering).
step_info_entity["step_id"] = step_info.unique_id
step_info_entity["step_name"] = step_info.step_name
step_info_entity["start_time"] = step_info.start_time
step_info_entity["end_time"] = step_info.end_time
step_info_entity["state"] = str(step_info.state.value)
step_info_entity["updated"] = utc_now_datetime()
step_info_entity["step_info_dict"] = json.dumps(step_info.to_json_dict()).encode()
self._ds.put(step_info_entity)
def _remove_step_info(self, step_info: StepInfo) -> None:
# remove dir from bucket
step_artifact = self.client.get(self.Constants.step_artifact_name(step_info))
if step_artifact is not None:
self.client.delete(step_artifact)
# remove datastore entities
self._ds.delete(key=self._ds.key("stepinfo", step_info.unique_id))
def _save_run_log(self, name: str, log_file: Path):
"""
The logs are stored in the bucket. The Run object details are stored in
the remote database.
"""
run_dataset = self.Constants.run_artifact_name(name)
self.client.upload(run_dataset, log_file)
================================================
FILE: tango/integrations/torch/__init__.py
================================================
# -*- coding: UTF-8 -*-
"""
.. important::
To use this integration you should install ``tango`` with the "torch" extra
(e.g. ``pip install tango[torch]``) or just install PyTorch after the fact.
Make sure you install the correct version of torch given your operating system
and supported CUDA version. Check
`pytorch.org/get-started/locally/ `_
for more details.
Components for Tango integration with `PyTorch `_.
These include a training loop :class:`~tango.step.Step` and registrable versions
of many ``torch`` classes, such :class:`torch.optim.Optimizer` and :class:`torch.utils.data.DataLoader`.
Example: training a model
-------------------------
Let's look a simple example of training a model.
We'll make a basic regression model and generate some fake data to train on.
First, the setup:
.. testcode::
import torch
import torch.nn as nn
from tango.common.dataset_dict import DatasetDict
from tango.step import Step
from tango.integrations.torch import Model
Now let's build and register our model:
.. testcode::
@Model.register("basic_regression")
class BasicRegression(Model):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)
self.sigmoid = nn.Sigmoid()
self.mse = nn.MSELoss()
def forward(self, x, y=None):
pred = self.sigmoid(self.linear(x))
out = {"pred": pred}
if y is not None:
out["loss"] = self.mse(pred, y)
return out
def _to_params(self):
return {}
Lastly, we'll need a step to generate data:
.. testcode::
@Step.register("generate_data")
class GenerateData(Step):
DETERMINISTIC = True
CACHEABLE = False
def run(self) -> DatasetDict:
torch.manual_seed(1)
return DatasetDict(
{
"train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)],
"validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)],
}
)
You could then run this experiment with a config that looks like this:
.. literalinclude:: ../../../../test_fixtures/integrations/torch/train.jsonnet
.. testcode::
:hide:
from tango.common.testing import run_experiment
from tango.common.registrable import Registrable
# Pickling the model fails because the class is defined ad hoc, not in a module.
# So we put in this hack to pickle a 0 instead of the Model.
def _return_zero(self):
return (int, (0,))
BasicRegression.__reduce__ = _return_zero
with run_experiment(
"test_fixtures/integrations/torch/train.jsonnet", name="boss-alien"
) as run_dir:
assert (run_dir / "train").is_dir(), "Output for the 'train' step was not produced."
# Restore state of registry.
del Registrable._registry[Step]["generate_data"]
del Registrable._registry[Model]["basic_regression"]
For example,
.. code-block::
tango run train.jsonnet -i my_package -d /tmp/train
would produce the following output:
.. testoutput::
:options: +ELLIPSIS
Starting new run boss-alien
● Starting step "data" (needed by "train")...
✓ Finished step "data"
● Starting step "train"...
✓ Finished step "train"
✓ Finished run boss-alien
...
Tips
----
Debugging
~~~~~~~~~
When debugging a training loop that's causing errors on a GPU, you should set the environment variable
``CUDA_LAUNCH_BLOCKING=1``. This will ensure that the stack traces shows where the error actually happened.
You could also use a custom :class:`TrainCallback` to log each batch before they are passed into the model
so that you can see the exact inputs that are causing the issue.
Stopping early
~~~~~~~~~~~~~~
You can stop the "torch::train" step early using a custom :class:`TrainCallback`. Your callback just
needs to raise the :class:`StopEarly` exception.
"""
from tango.common.exceptions import IntegrationMissingError
# Fail fast with a helpful error if the "torch" extra isn't installed.
try:
    import torch
except ModuleNotFoundError:
    raise IntegrationMissingError("torch")
# Public API of the torch integration.
__all__ = [
    "TorchFormat",
    "TorchTrainStep",
    "TorchEvalStep",
    "Optimizer",
    "LRScheduler",
    "Model",
    "DataLoader",
    "DataCollator",
    "Sampler",
    "ConcatTensorDictsCollator",
    "TrainCallback",
    "EvalCallback",
    "TrainConfig",
    "StopEarlyCallback",
    "StopEarly",
    "TrainingEngine",
    "TorchTrainingEngine",
]
from .data import ConcatTensorDictsCollator, DataCollator, DataLoader, Sampler
from .eval import TorchEvalStep
from .eval_callback import EvalCallback
from .exceptions import StopEarly
from .format import TorchFormat
from .model import Model
from .optim import LRScheduler, Optimizer
from .train import TorchTrainStep
from .train_callback import StopEarlyCallback, TrainCallback
from .train_config import TrainConfig
from .training_engine import TorchTrainingEngine, TrainingEngine
================================================
FILE: tango/integrations/torch/data.py
================================================
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
import torch
from tango.common.lazy import Lazy
from tango.common.registrable import Registrable
T = TypeVar("T")
class DataCollator(Generic[T], Registrable):
    """
    A :class:`~tango.common.Registrable` version of a ``collate_fn``
    for a ``DataLoader``.

    Subclasses just need to implement :meth:`__call__()`.
    """

    default_implementation = "concat_tensor_dicts"
    """
    The default implementation is :class:`ConcatTensorDictsCollator`.
    """

    def __call__(self, items: List[T]) -> Dict[str, Any]:
        """
        Takes a list of items from a dataset and combines them into a batch.
        """
        # Abstract method: subclasses must override. (This previously raised
        # `NotADirectoryError`, an apparent typo.)
        raise NotImplementedError
@DataCollator.register("concat_tensor_dicts")
class ConcatTensorDictsCollator(DataCollator[Dict[str, Any]]):
"""
A simple ``collate_fn`` that expects items to be dictionaries of tensors.
The tensors are just concatenated together.
.. tip::
Registered as a :class:`DataCollator` under the name "concat_tensor_dicts".
"""
def __call__(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
out = {}
keys = items[0].keys()
for key in keys:
if isinstance(items[0][key], torch.Tensor):
out[key] = torch.cat([item[key].unsqueeze(0) for item in items])
elif isinstance(items[0][key], (int, float)):
out[key] = torch.tensor([item[key] for item in items])
else:
out[key] = [item[key] for item in items] # type: ignore[assignment]
return out
class Sampler(torch.utils.data.Sampler, Registrable):
    """
    A :class:`~tango.common.Registrable` version of a PyTorch
    :class:`~torch.utils.data.Sampler`.

    All `built-in PyTorch samplers
    <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`_
    are registered under their corresponding class name prefixed with "torch::"
    (e.g. "torch::RandomSampler") by the registration loop below.
    """
@Sampler.register("torch::BatchSampler")
class BatchSampler(torch.utils.data.BatchSampler, Sampler):
def __init__(
self,
dataset: torch.utils.data.Dataset,
sampler: Union[Lazy[Sampler], Sampler],
batch_size: int,
drop_last: bool,
) -> None:
super().__init__(
sampler.construct(data_source=dataset, dataset=dataset)
if isinstance(sampler, Lazy)
else sampler,
batch_size,
drop_last,
)
# Register every sampler class that ships with torch.utils.data (except the
# abstract Sampler base itself) under a "torch::"-prefixed name, skipping any
# name that has already been registered above.
for attr_name, attr_value in torch.utils.data.__dict__.items():
    if not isinstance(attr_value, type):
        continue
    if not issubclass(attr_value, torch.utils.data.Sampler):
        continue
    if attr_value == torch.utils.data.Sampler:
        continue
    qualified_name = "torch::" + attr_name
    if qualified_name in Sampler.list_available():
        continue
    Sampler.register(qualified_name)(attr_value)
class DataLoader(torch.utils.data.DataLoader, Registrable):
    """
    A :class:`~tango.common.Registrable` version of a PyTorch
    :class:`~torch.utils.data.DataLoader`.
    """

    # Registered below as the "default" implementation of itself.
    default_implementation = "default"

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        # NOTE: this default is a single ConcatTensorDictsCollator instance shared
        # by every construction; the collator holds no state, so sharing is safe.
        # Passing an explicit ``None`` falls through to torch's default collation.
        collate_fn: Optional[DataCollator] = ConcatTensorDictsCollator(),
        sampler: Optional[Union[Lazy[Sampler], Sampler]] = None,
        **kwargs,
    ):
        super().__init__(
            dataset,
            collate_fn=collate_fn,
            # A Lazy sampler is constructed here so it can receive the dataset.
            sampler=sampler.construct(data_source=dataset, dataset=dataset)
            if isinstance(sampler, Lazy)
            else sampler,
            **kwargs,
        )

# Make this class itself available as the "default" DataLoader implementation.
DataLoader.register("default")(DataLoader)
================================================
FILE: tango/integrations/torch/eval.py
================================================
from collections import defaultdict
from itertools import islice
from typing import Dict, List, Optional, Sequence
import torch
from tango.common.dataset_dict import DatasetDictBase
from tango.common.exceptions import ConfigurationError
from tango.common.lazy import Lazy
from tango.common.tqdm import Tqdm
from tango.format import Format, JsonFormat
from tango.step import Step, StepResources
from .data import DataLoader
from .eval_callback import EvalCallback
from .model import Model
from .util import check_dataset, move_to_device, resolve_device, set_seed_all
@Step.register("torch::eval")
class TorchEvalStep(Step):
"""
A PyTorch evaluation loop that pairs well with :class:`TorchTrainStep`.
.. tip::
Registered as a :class:`~tango.step.Step` under the name "torch::eval".
.. important::
The evaluation loop will use a GPU automatically if one is available.
You can control which GPU it uses with the environment variable ``CUDA_VISIBLE_DEVICES``.
For example, set ``CUDA_VISIBLE_DEVICES=1`` to force ``TorchEvalStep`` to only use
the GPU with ID 1.
.. warning::
By default the metrics specified by the ``metric_names`` parameter
are aggregated by simply averaging across batches.
This behavior is usually correct for metrics like "loss" or "accuracy",
for example, but may not be correct for other metrics like "F1".
If this is not correct for your metric you will need to handle the aggregation
internally in your model or with an :class:`EvalCallback`
using the :meth:`EvalCallback.post_batch()` method.
Then set the parameter ``auto_aggregate_metrics`` to ``False``.
"""
DETERMINISTIC = True
CACHEABLE = True
FORMAT: Format = JsonFormat()
SKIP_ID_ARGUMENTS = {"log_every"}
@property
def resources(self) -> StepResources:
return self.step_resources or StepResources(gpu_count=1)
def run( # type: ignore[override]
self,
model: Model,
dataset_dict: DatasetDictBase,
dataloader: Lazy[DataLoader],
test_split: str = "test",
seed: int = 42,
eval_steps: Optional[int] = None,
log_every: int = 1,
metric_names: Sequence[str] = ("loss",),
auto_aggregate_metrics: bool = True,
callbacks: Optional[List[Lazy[EvalCallback]]] = None,
) -> Dict[str, float]:
"""
Evaluate the ``model``.
:param model:
The model to evaluate. It should return a ``dict`` from its ``forward()`` method
that includes all of the metrics in ``metric_names`` .
:param dataset_dict:
Should contain the test data.
:param dataloader:
The data loader that generates test batches. The batches should be :class:`dict`
objects.
:param test_split:
The name of the data split used for evaluation in the ``dataset_dict``.
Default is "test".
:param seed:
Used to set the RNG states at the beginning of the evaluation loop.
:param eval_steps:
The number of steps to evaluate for. If not specified evaluation will
stop after a complete iteration through the ``dataloader``.
:param log_every:
Log every this many steps. Default is ``1``.
:param metric_names:
The names of the metrics to track and aggregate. Default is ``("loss",)``.
:param auto_aggregate_metrics:
If ``True`` (the default), the metrics will be averaged across batches.
This may not be the correct behavior for some metrics (such as F1),
in which you should set this to ``False`` and handle the aggregation
internally in your model or with an :class:`EvalCallback`
(using :meth:`EvalCallback.post_batch()`).
:param callbacks:
A list of :class:`EvalCallback`.
"""
set_seed_all(seed)
check_dataset(dataset_dict, test_split)
# Resolve device.
device = resolve_device()
# Prep model.
model = model.eval().to(device)
# Construct dataloader.
dataloader: DataLoader = dataloader.construct(dataset=dataset_dict[test_split])
steps: int
try:
dataloader_len = len(dataloader)
steps = dataloader_len if eval_steps is None else min(dataloader_len, eval_steps)
except TypeError:
if eval_steps is None:
raise ConfigurationError(
"You must set 'eval_steps' for streaming/iterable datasets"
)
else:
steps = eval_steps
# Initialize callbacks.
callbacks: List[EvalCallback] = [
callback.construct(
workspace=self.workspace,
step_id=self.unique_id,
work_dir=self.work_dir,
model=model,
dataset_dict=dataset_dict,
dataloader=dataloader,
)
for callback in (callbacks or [])
]
for callback in callbacks:
callback.pre_eval_loop()
eval_batches = enumerate(islice(dataloader, steps))
running_metrics: Dict[str, float] = defaultdict(float)
aggregated_metrics: Dict[str, float] = {}
with Tqdm.tqdm(eval_batches, desc="Evaluating", total=steps) as batch_iter:
for step, batch in batch_iter:
should_log_this_step = step % log_every == 0 or step == steps - 1
for callback in callbacks:
callback.pre_batch(step, batch)
batch = move_to_device(batch, device)
with torch.inference_mode():
outputs = model(**batch)
for callback in callbacks:
callback.post_batch(step, outputs)
# Gather metrics we want to track.
batch_metrics = {
k: outputs[k].item() if isinstance(outputs[k], torch.Tensor) else outputs[k]
for k in metric_names
}
# Aggregate metrics.
if auto_aggregate_metrics:
for k in batch_metrics:
running_metrics[k] += batch_metrics[k]
aggregated_metrics[k] = running_metrics[k] / (step + 1)
else:
aggregated_metrics.update(batch_metrics)
# Update progress bar.
if should_log_this_step:
batch_iter.set_postfix(**aggregated_metrics)
# Clean up to help garbage collector. Hopefully this saves memory.
del batch
del outputs
del batch_metrics
for callback in callbacks:
callback.post_eval_loop(aggregated_metrics)
return aggregated_metrics
================================================
FILE: tango/integrations/torch/eval_callback.py
================================================
from pathlib import Path
from typing import Any, Dict
from tango.common.dataset_dict import DatasetDictBase
from tango.common.registrable import Registrable
from tango.workspace import Workspace
from .data import DataLoader
from .model import Model
class EvalCallback(Registrable):
    """
    A :class:`~tango.common.Registrable` base class for hooking custom behavior into the
    evaluation loop of :class:`TorchEvalStep`, analogous to the way :class:`TrainCallback`
    hooks into the training loop.

    .. tip::
        Every constructor parameter of this base class is filled in automatically by the
        evaluation loop, so don't list any of them in your callback's config.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar str step_id: The unique ID of the step.
    :ivar pathlib.Path work_dir: The working directory of the step
    :ivar Model model: The model being evaluated.
    :ivar DatasetDictBase dataset_dict: The dataset dict containing the evaluation split.
    :ivar DataLoader dataloader: The data loader used to load the evaluation split data.
    """

    def __init__(
        self,
        workspace: Workspace,
        step_id: str,
        work_dir: Path,
        model: Model,
        dataset_dict: DatasetDictBase,
        dataloader: DataLoader,
    ) -> None:
        # All of these are injected by TorchEvalStep when it constructs the callback.
        self.dataloader = dataloader
        self.dataset_dict = dataset_dict
        self.model = model
        self.work_dir = work_dir
        self.step_id = step_id
        self.workspace = workspace

    def pre_eval_loop(self) -> None:
        """
        Hook invoked once, just before the first batch is processed.
        """

    def pre_batch(self, step: int, batch: Dict[str, Any]) -> None:
        """
        Hook invoked immediately before each batch is processed.
        """

    def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None:
        """
        Hook invoked immediately after each batch, with the outputs of that batch.

        .. tip::
            ``batch_outputs`` may be modified in place here, which is handy when metrics
            need to be aggregated in some way other than a simple average. In that case,
            remember to set ``auto_aggregate_metrics`` to ``False`` in :class:`TorchEvalStep`.
        """

    def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None:
        """
        Hook invoked once after the loop finishes, with the final aggregated metrics.

        This is the last callback method called, so any cleanup belongs here.
        """
================================================
FILE: tango/integrations/torch/exceptions.py
================================================
from tango.common.exceptions import TangoError
class StopEarly(TangoError):
    """
    An exception that a callback can raise to end training gracefully ahead of schedule,
    without the run being treated as a failure.

    .. important::
        With distributed training, every worker has to raise this exception at the exact
        same point in the training loop, otherwise the workers will deadlock.
    """
================================================
FILE: tango/integrations/torch/format.py
================================================
from pathlib import Path
from typing import Generic, TypeVar
import dill
import torch
from tango.common.aliases import PathOrStr
from tango.format import Format
T = TypeVar("T")
@Format.register("torch")
class TorchFormat(Format[T], Generic[T]):
    """
    This format writes the artifact using ``torch.save()``.

    Unlike :class:`tango.format.DillFormat`, this has no special support for iterators.

    .. tip::
        Registered as a :class:`~tango.format.Format` under the name "torch".
    """

    # Bumped when the on-disk layout changes. Zero-padded so lexicographic string
    # comparison in read() orders versions correctly.
    VERSION = "002"

    def write(self, artifact: T, dir: PathOrStr):
        """
        Serialize ``artifact``, tagged with the format version, to ``data.pt`` in ``dir``.
        """
        filename = Path(dir) / "data.pt"
        with open(filename, "wb") as f:
            # dill can pickle objects (e.g. closures) that the default pickle module can't.
            torch.save((self.VERSION, artifact), f, pickle_module=dill)

    def read(self, dir: PathOrStr) -> T:
        """
        Load and return the artifact from ``data.pt`` in ``dir``.

        :raises ValueError:
            If the file was written by a newer version of this format.
        """
        filename = Path(dir) / "data.pt"
        with open(filename, "rb") as f:
            version, artifact = torch.load(f, pickle_module=dill, map_location=torch.device("cpu"))
            if version > self.VERSION:
                # Fix: the message previously contained the literal text "(unknown)"
                # instead of interpolating the offending file's path, which made the
                # error useless for diagnosing which artifact was incompatible.
                raise ValueError(
                    f"File {filename} is too recent for this version of {self.__class__}."
                )
            return artifact
================================================
FILE: tango/integrations/torch/model.py
================================================
import torch
from tango.common.registrable import Registrable
class Model(torch.nn.Module, Registrable):
    """
    A :class:`torch.nn.Module` that is also a :class:`~tango.common.Registrable`
    mixin class.

    The :meth:`~torch.nn.Module.forward()` method is expected to return a :class:`dict`
    containing the ``loss`` while training, plus any metrics being tracked during
    validation.
    """
================================================
FILE: tango/integrations/torch/optim.py
================================================
from typing import Type
import torch
from tango.common.registrable import Registrable
class Optimizer(torch.optim.Optimizer, Registrable):
    """
    A :class:`~tango.common.Registrable` wrapper around a PyTorch
    :class:`torch.optim.Optimizer`.

    Every `optimizer built into PyTorch
    <https://pytorch.org/docs/stable/optim.html>`_
    is registered under its class name prefixed with "torch::" (e.g. "torch::Adam").

    .. tip::

        You can see a list of all available optimizers by running

        .. testcode::

            from tango.integrations.torch import Optimizer

            for name in sorted(Optimizer.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            torch::ASGD
            torch::Adadelta
            torch::Adagrad
            torch::Adam
            torch::AdamW
            ...

    """
class LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable):
    """
    A :class:`~tango.common.Registrable` wrapper around a PyTorch learning rate
    scheduler.

    Every `learning rate scheduler built into PyTorch
    <https://pytorch.org/docs/stable/optim.html>`_
    is registered under its class name prefixed with "torch::" (e.g. "torch::StepLR").

    .. tip::

        You can see a list of all available schedulers by running

        .. testcode::

            from tango.integrations.torch import LRScheduler

            for name in sorted(LRScheduler.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            torch::ChainedScheduler
            torch::ConstantLR
            torch::CosineAnnealingLR
            ...

    """
# Register all optimizers.
for name, cls in torch.optim.__dict__.items():
    if (
        isinstance(cls, type)
        and issubclass(cls, torch.optim.Optimizer)
        # Identity check (`is not`) is the idiomatic way to exclude the base class
        # itself; `not cls == ...` relied on equality semantics for class objects.
        and cls is not torch.optim.Optimizer
    ):
        Optimizer.register("torch::" + name)(cls)

# Note: This is a hack. Remove after we upgrade the torch version.
# Newer torch versions expose the public `LRScheduler` base class; older ones only
# have the private `_LRScheduler`, so fall back to that when the public name is absent.
base_class: Type
try:
    base_class = torch.optim.lr_scheduler.LRScheduler
except AttributeError:
    base_class = torch.optim.lr_scheduler._LRScheduler

# Register all learning rate schedulers.
for name, cls in torch.optim.lr_scheduler.__dict__.items():
    if isinstance(cls, type) and issubclass(cls, base_class) and cls is not base_class:
        LRScheduler.register("torch::" + name)(cls)
================================================
FILE: tango/integrations/torch/train.py
================================================
import logging
import math
import os
import shutil
from itertools import islice
from typing import Any, Dict, List, Optional, Set, Union, cast
import more_itertools
import torch
import torch.distributed as dist
from more_itertools import chunked
from torch.utils.data import DistributedSampler
from tango.common.dataset_dict import DatasetDictBase
from tango.common.exceptions import ConfigurationError
from tango.common.lazy import Lazy
from tango.common.tqdm import Tqdm
from tango.common.util import get_extra_imported_modules, import_extra_module
from tango.format import Format
from tango.step import Step, StepResources
from tango.workspace import Workspace
from .data import DataLoader
from .exceptions import StopEarly
from .format import TorchFormat
from .model import Model
from .train_callback import TrainCallback
from .train_config import TrainConfig
from .training_engine import TrainingEngine
from .util import check_dataloader, check_dataset, set_seed_all
@Step.register("torch::train")
class TorchTrainStep(Step):
    """
    A PyTorch training loop step that supports gradient accumulation, distributed training,
    and AMP, with configurable dataloaders, callbacks, optimizer, and LR scheduler.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "torch::train".

    .. important::
        The training loop will use GPU(s) automatically when available, as long as at least
        ``device_count`` CUDA devices are available.

        Distributed data parallel training is activated when the ``device_count`` is greater than 1.

        You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``.
        For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1``
        (and ``device_count`` to 2).

    .. warning::
        During validation, the validation metric (specified by the ``val_metric_name`` parameter)
        is aggregated by simply averaging across validation batches and distributed processes.
        This behavior is usually correct when your validation metric is "loss" or "accuracy",
        for example, but may not be correct for other metrics like "F1".

        If this is not correct for your metric you will need to handle the aggregation
        internally in your model or with a :class:`TrainCallback`
        using the :meth:`TrainCallback.post_val_batch()` method.
        Then set the parameter ``auto_aggregate_val_metric`` to ``False``.

        Note that correctly aggregating your metric during distributed training will
        involve distributed communication.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT: Format = TorchFormat()
    # These arguments don't affect the result, so they are excluded from the step's
    # unique ID hash (re-running with a different port or log interval hits the cache).
    SKIP_ID_ARGUMENTS = {"distributed_port", "log_every"}
    METADATA = {"artifact_kind": "model"}

    @property
    def resources(self) -> StepResources:
        # Default to requesting one GPU per distributed worker unless the user set
        # explicit step resources.
        return self.step_resources or StepResources(gpu_count=self.kwargs["device_count"])

    def run(  # type: ignore[override]
        self,
        model: Union[Lazy[Model], Model],  # Lazy has to come first
        training_engine: Lazy[TrainingEngine],
        dataset_dict: DatasetDictBase,
        train_dataloader: Lazy[DataLoader],
        *,
        train_split: str = "train",
        validation_split: Optional[str] = None,
        validation_dataloader: Optional[Lazy[DataLoader]] = None,
        seed: int = 42,
        train_steps: Optional[int] = None,
        train_epochs: Optional[int] = None,
        validation_steps: Optional[int] = None,
        grad_accum: int = 1,
        log_every: int = 10,
        checkpoint_every: int = 100,
        validate_every: Optional[int] = None,
        device_count: int = 1,
        distributed_port: int = 54761,
        val_metric_name: str = "loss",
        minimize_val_metric: bool = True,
        auto_aggregate_val_metric: bool = True,
        callbacks: Optional[List[Lazy[TrainCallback]]] = None,
        remove_stale_checkpoints: bool = True,
    ) -> Model:
        """
        Run a basic training loop to train the ``model``.

        :param model:
            The model to train. It should return a ``dict`` that includes the ``loss``
            during training and the ``val_metric_name`` during validation.
        :param training_engine:
            The :class:`TrainingEngine` to use to train the model.
        :param dataset_dict:
            The train and optional validation data.
        :param train_dataloader:
            The data loader that generates training batches. The batches should be :class:`dict`
            objects that will be used as ``kwargs`` for the model's ``forward()`` method.
        :param train_split:
            The name of the data split used for training in the ``dataset_dict``.
            Default is "train".
        :param validation_split:
            Optional name of the validation split in the ``dataset_dict``. Default is ``None``,
            which means no validation.
        :param validation_dataloader:
            An optional data loader for generating validation batches. The batches should be
            :class:`dict` objects. If not specified, but ``validation_split`` is given,
            the validation ``DataLoader`` will be constructed from the same parameters
            as the train ``DataLoader``.
        :param seed:
            Used to set the RNG states at the beginning of training.
        :param train_steps:
            The number of steps to train for. If not specified training will
            stop after a complete iteration through the ``train_dataloader``.
        :param train_epochs:
            The number of epochs to train for. You cannot specify ``train_steps`` and
            ``train_epochs`` at the same time.
        :param validation_steps:
            The number of steps to validate for. If not specified validation
            will stop after a complete iteration through the ``validation_dataloader``.
        :param grad_accum:
            The number of gradient accumulation steps. Defaults to 1.

            .. note::
                This parameter - in conjunction with the settings of your data loader
                and the number of distributed workers -
                determines the *effective batch size* of your training run.

        :param log_every:
            Log every this many steps.
        :param checkpoint_every:
            Save a checkpoint every this many steps.
        :param validate_every:
            Run the validation loop every this many steps.
        :param device_count:
            The number of devices to train on, i.e. the number of distributed data parallel workers.
        :param distributed_port:
            The port of the distributed process group. Default = "54761".
        :param val_metric_name:
            The name of the validation metric, i.e. the key of the metric in the dictionary
            returned by the forward pass of the model. Default is "loss".
        :param minimize_val_metric:
            Whether the validation metric is meant to be minimized (such as the loss).
            Default is ``True``. When using a metric such as accuracy, you should set
            this to ``False``.
        :param auto_aggregate_val_metric:
            If ``True`` (the default), the validation metric will be averaged across
            validation batches and distributed processes. This may not be the correct
            behavior for some metrics (such as F1), in which you should set this to
            ``False`` and handle the aggregation internally in your model
            or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`).
        :param callbacks:
            A list of :class:`TrainCallback`.
        :param remove_stale_checkpoints:
            If ``True`` (the default), stale checkpoints will be removed throughout training so that
            only the latest and best checkpoints are kept.

        :returns:
            The trained model on CPU with the weights from the best checkpoint loaded.
        """
        devices = self._get_devices(device_count)
        return self._train(
            model=model,
            training_engine=training_engine,
            dataset_dict=dataset_dict,
            train_dataloader=train_dataloader,
            train_split=train_split,
            validation_split=validation_split,
            validation_dataloader=validation_dataloader,
            seed=seed,
            train_steps=train_steps,
            train_epochs=train_epochs,
            validation_steps=validation_steps,
            grad_accum=grad_accum,
            log_every=log_every,
            checkpoint_every=checkpoint_every,
            validate_every=validate_every,
            devices=devices,
            distributed_port=distributed_port,
            val_metric_name=val_metric_name,
            minimize_val_metric=minimize_val_metric,
            auto_aggregate_val_metric=auto_aggregate_val_metric,
            callbacks=callbacks,
            remove_stale_checkpoints=remove_stale_checkpoints,
        )

    def _get_devices(self, device_count: int) -> List[int]:
        """
        Validates the device count, and returns the list of devices.

        :raises ConfigurationError: if ``device_count`` is not at least 1.
        """
        # Validate device(s).
        if device_count <= 0:
            raise ConfigurationError("Invalid value for 'device_count'. Must be at least 1.")
        devices: List[int]
        if torch.cuda.is_available() and torch.cuda.device_count() >= device_count:
            devices = list(range(device_count))
            self.logger.info("Training on %d GPU%s", device_count, "s" if device_count > 1 else "")
        else:
            # -1 is the sentinel for "CPU" — one entry per worker.
            devices = [-1] * device_count
            self.logger.info(
                "Training on CPU with %d worker%s", device_count, "s" if device_count > 1 else ""
            )
        return devices

    def _train(
        self,
        model: Union[Model, Lazy[Model]],
        training_engine: Lazy[TrainingEngine],
        dataset_dict: DatasetDictBase,
        train_dataloader: Lazy[DataLoader],
        *,
        train_split: str = "train",
        validation_split: Optional[str] = None,
        validation_dataloader: Optional[Lazy[DataLoader]] = None,
        seed: int = 42,
        train_steps: Optional[int] = None,
        train_epochs: Optional[int] = None,
        validation_steps: Optional[int] = None,
        grad_accum: int = 1,
        log_every: int = 10,
        checkpoint_every: int = 100,
        validate_every: Optional[int] = None,
        devices: Optional[List[int]] = None,
        distributed_port: int = 54761,
        val_metric_name: str = "loss",
        minimize_val_metric: bool = True,
        auto_aggregate_val_metric: bool = True,
        callbacks: Optional[List[Lazy[TrainCallback]]] = None,
        remove_stale_checkpoints: bool = True,
    ) -> Model:
        """
        Validate the configuration, build the :class:`TrainConfig`, and run the training
        loop — spawning one worker subprocess per device when distributed.
        """
        is_distributed = False
        num_workers = 1
        if devices and len(devices) > 1:
            is_distributed = True
            num_workers = len(devices)

        if validate_every is not None and validation_split is None:
            raise ConfigurationError(
                "You have set a validation interval, but no validation split. "
                "That's probably unintentional."
            )

        if (train_steps is not None) == (train_epochs is not None):
            raise ConfigurationError(
                "One of 'train_steps' or 'train_epochs' needs to be specified, but not both."
            )

        # The two intervals must align so that every validation result can be associated
        # with some checkpoint (or vice versa).
        if validate_every is not None and checkpoint_every is not None:
            if checkpoint_every % validate_every != 0 and validate_every % checkpoint_every != 0:
                raise ConfigurationError(
                    "'checkpoint_every' needs to be multiple of 'validate_every' or vice versa"
                )

        config = TrainConfig(
            self.unique_id,
            self.work_dir,
            step_name=self.name,
            train_split=train_split,
            validation_split=validation_split,
            seed=seed,
            train_steps=train_steps,
            train_epochs=train_epochs,
            grad_accum=grad_accum,
            log_every=log_every,
            checkpoint_every=checkpoint_every,
            validate_every=validate_every,
            validation_steps=validation_steps,
            is_distributed=is_distributed,
            devices=devices,
            distributed_port=distributed_port,
            val_metric_name=val_metric_name,
            minimize_val_metric=minimize_val_metric,
            auto_aggregate_val_metric=auto_aggregate_val_metric,
            remove_stale_checkpoints=remove_stale_checkpoints,
            world_size=num_workers,
        )

        final_model: Model

        if is_distributed:
            import torch.multiprocessing as mp

            # Spawn one `_train` worker process per device. mp.spawn supplies each
            # worker's index as the first argument to `_train`.
            mp.spawn(
                _train,
                args=(
                    self.workspace,
                    config,
                    model,
                    training_engine,
                    dataset_dict,
                    train_dataloader,
                    validation_dataloader,
                    callbacks,
                    get_extra_imported_modules(),
                ),
                nprocs=num_workers,
            )
            # The workers can't return the model across process boundaries, so
            # construct a fresh copy here; the best weights are loaded below.
            self.logger.info("Constructing final model")
            if isinstance(model, Lazy):
                final_model = model.construct()
            else:
                final_model = model
        else:
            final_model = _train(  # type: ignore[assignment]
                0,
                self.workspace,
                config,
                model,
                training_engine,
                dataset_dict,
                train_dataloader,
                validation_dataloader=validation_dataloader,
                callbacks=callbacks,
            )
            assert final_model is not None
            final_model = final_model.cpu()

        # Load best checkpoint before returning model.
        if config.final_weights_path.is_file():
            self.logger.info(
                f"Loading best weights from {str(config.final_weights_path.resolve())}"
            )
            state = torch.load(config.final_weights_path, map_location="cpu")
            # We use `strict=False` because there might be missing keys due to weight tying.
            final_model.load_state_dict(state, strict=False)

        return final_model
def _train(
    worker_id: int,
    workspace: Workspace,
    config: TrainConfig,
    model: Union[Model, Lazy[Model]],
    training_engine: Lazy[TrainingEngine],
    dataset_dict: DatasetDictBase,
    train_dataloader: Lazy[DataLoader],
    validation_dataloader: Optional[Lazy[DataLoader]] = None,
    callbacks: Optional[List[Lazy[TrainCallback]]] = None,
    include_package: Optional[Set[str]] = None,
) -> Optional[Model]:
    """
    The per-worker training loop entry point.

    Runs in the main process for single-device training, or as a spawned subprocess
    (one per device) for distributed training. Returns the trained model for
    non-distributed runs; distributed workers return ``None`` (the step reconstructs
    the model and loads the final weights afterwards).
    """
    # Set random seeds.
    set_seed_all(config.seed)

    config.worker_id = worker_id

    if config.is_distributed and include_package:
        # During distributed training we need to import `include_package` modules again
        # in order to initialize the lazy objects.
        for package_name in include_package:
            import_extra_module(package_name)

    if config.is_distributed:
        import tango.common.logging as common_logging

        common_logging.initialize_worker_logging(config.worker_id)
    logger = logging.getLogger(TorchTrainStep.__name__)

    training_engine: TrainingEngine = training_engine.construct(
        train_config=config,
        model=model,
    )

    # Check working directory to see if we should recover from a previous run.
    initial_state: Optional[Dict[str, Any]] = None
    if config.state_path.exists():
        if config.is_local_main_process:
            logger.info(f"Recovering from previous run at {str(config.state_path.resolve())}")
        initial_state = training_engine.load_checkpoint(config.state_path)

    device = config.worker_local_default_device

    # Construct data loaders.
    validation_dataloader_: Optional[DataLoader] = None
    if config.validation_split is not None:
        validation_dataset = dataset_dict[config.validation_split]
        check_dataset(validation_dataset, config.validation_split)
        if validation_dataloader is not None:
            validation_dataloader_ = validation_dataloader.construct(dataset=validation_dataset)
        else:
            # Fall back to building the validation loader from the train loader's params.
            validation_dataloader_ = train_dataloader.construct(dataset=validation_dataset)
    validation_dataloader: Optional[DataLoader] = validation_dataloader_
    train_dataset = dataset_dict[config.train_split]
    check_dataset(train_dataset, config.train_split)
    train_dataloader: DataLoader = train_dataloader.construct(dataset=train_dataset)

    if config.train_steps is None:
        # `train_epochs` was given instead, so derive the total step count from the
        # dataloader length (not available for streaming/iterable datasets).
        assert config.train_epochs is not None
        try:
            steps_per_epoch = len(train_dataloader)
        except TypeError:
            raise ConfigurationError("You must set 'train_steps' for streaming/iterable datasets")
        config.train_steps = math.ceil(
            steps_per_epoch * (config.train_epochs or 1) / config.grad_accum
        )

    assert config.train_steps is not None  # for mypy

    if validation_dataloader is not None:
        if config.validation_steps is None:
            try:
                config.validation_steps = len(validation_dataloader)
            except TypeError:
                raise ConfigurationError(
                    "You must set 'validation_steps' for streaming/iterable datasets"
                )

    # Make sure we're using a DistributedSampler during distributed training.
    if config.is_distributed:
        check_dataloader(train_dataloader)
        if validation_dataloader is not None:
            check_dataloader(validation_dataloader)

    # The (training) loss for each batch, updated every training batch.
    batch_loss: float = 0.0
    # The value of the validation metric (could be loss), updated after every validation loop.
    val_metric: Optional[float] = None
    # The best validation metric over all validation set passes.
    best_val_metric: Optional[float] = None
    # The best validation metric over all validation set passes that correspond to a checkpoint.
    # Could be different from `best_val_metric` if `checkpoint_every` > `validate_every`.
    best_val_metric_checkpointed: Optional[float] = None
    # The step to start training from.
    start_step: int = 0
    # The current training step.
    step: int = start_step
    # If we should do a validation pass after the current training batch.
    should_validate_this_step: bool = False

    # Load state from checkpoint.
    if initial_state is not None:
        val_metric = initial_state[f"val_{config.val_metric_name}"]
        best_val_metric = initial_state[f"best_{config.val_metric_name}"]
        best_val_metric_checkpointed = initial_state[f"best_{config.val_metric_name}_checkpointed"]
        start_step = initial_state["training_steps"]

    # Initialize callbacks.
    callbacks: List[TrainCallback] = [
        callback.construct(
            workspace=workspace,
            train_config=config,
            training_engine=training_engine,
            dataset_dict=dataset_dict,
            train_dataloader=train_dataloader,
            validation_dataloader=validation_dataloader,
        )
        for callback in (callbacks or [])
    ]
    if initial_state:
        # Restore callback state from the checkpoint, pairing by position.
        for callback, state in zip(callbacks, initial_state["callbacks"]):
            callback.load_state_dict(state)
    del initial_state

    training_engine.model.train()
    # Each item is (step_index, (epoch, list_of_grad_accum_micro_batches)), capped
    # at `train_steps` total optimizer steps.
    training_batches = enumerate(
        islice(
            _cycle_through_epochs(train_dataloader, config.is_distributed, config.grad_accum),
            config.train_steps,
        )
    )

    def is_best_checkpoint() -> bool:
        """
        A closure that we'll call when saving checkpoints to check if we should link
        the best checkpoint path to the current checkpoint file.
        """
        if val_metric is not None:
            if best_val_metric_checkpointed is not None:
                return (
                    config.minimize_val_metric and val_metric <= best_val_metric_checkpointed
                ) or (not config.minimize_val_metric and val_metric >= best_val_metric_checkpointed)
            else:
                return False
        else:
            # Without a validation loop we always treat the most recent checkpoint as the best.
            return True

    def save_state(step: int):
        """
        A closure that we'll call every `checkpoint_every` steps in the train loop to
        save model and training state.
        """
        # Update best loss/metric trackers.
        nonlocal best_val_metric_checkpointed
        if should_validate_this_step and val_metric is not None:
            if (
                best_val_metric_checkpointed is None
                or (config.minimize_val_metric and val_metric <= best_val_metric_checkpointed)
                or (not config.minimize_val_metric and val_metric >= best_val_metric_checkpointed)
            ):
                best_val_metric_checkpointed = val_metric

        train_state = {
            "training_steps": step + 1,
            f"val_{config.val_metric_name}": val_metric,
            f"best_{config.val_metric_name}": best_val_metric,
            f"best_{config.val_metric_name}_checkpointed": best_val_metric_checkpointed,
            "callbacks": [
                callback.state_dict() for callback in callbacks  # type: ignore[union-attr]
            ],
        }

        # For some reason mypy can't figure out that `training_engine` is a `TrainingEngine`
        # in this closure, and not a `Lazy[TrainingEngine]`.
        cast(TrainingEngine, training_engine).save_checkpoint(
            config.state_path_for_step(step), train_state
        )

        # Link to most recent state path.
        # NOTE: While hard linking would be preferable to creating symlinks, some train engines
        # require a whole directory to save their state instead of a single file, which
        # means state_path_for_step will be a directory, so a hard link won't work.
        if config.is_local_main_process:
            if config.state_path.is_symlink():
                config.state_path.unlink()
            config.state_path.symlink_to(
                config.state_path_for_step(step).relative_to(config.work_dir)
            )

            # Link to best state path.
            if is_best_checkpoint():
                if config.best_state_path.is_symlink():
                    config.best_state_path.unlink()
                config.best_state_path.symlink_to(
                    config.state_path_for_step(step).relative_to(config.work_dir)
                )

            # Clean up stale checkpoints.
            if config.remove_stale_checkpoints:
                checkpoints_to_keep = {
                    config.best_state_path.resolve(),
                    config.state_path.resolve(),
                }
                for path in config.work_dir.glob("checkpoint_state_step*"):
                    path = path.resolve()
                    if path not in checkpoints_to_keep:
                        if path.is_file():
                            path.unlink()
                        else:
                            shutil.rmtree(path)

        if config.is_distributed:
            dist.barrier()

    # Catch data loader up to where we left off before.
    current_epoch: int = -1
    if start_step > 0:
        with Tqdm.tqdm(
            training_batches,
            desc=f"Catching dataloader up to step {start_step}",
            total=start_step - 1,
            disable=not config.is_local_main_process,
        ) as batch_iter:
            for step, (current_epoch, batch) in batch_iter:
                del batch
                if step >= start_step - 1:
                    break

    if config.is_distributed:
        dist.barrier()

    for callback in callbacks:
        callback.pre_train_loop()

    train_batch_iterator_tqdm = Tqdm.tqdm(
        training_batches,
        desc="Training",
        initial=start_step,
        total=config.train_steps,
        disable=not config.is_local_main_process,
    )
    # `peekable` lets us look at the next item's epoch number without consuming it,
    # which is how epoch boundaries are detected below when `validate_every` is unset.
    train_batch_iterator = more_itertools.peekable(train_batch_iterator_tqdm)
    try:
        for step, (epoch, batch) in train_batch_iterator:
            if epoch != current_epoch:
                # Start of new epoch.
                if epoch > 0:
                    # Call post-epoch callbacks for the last epoch.
                    for callback in callbacks:
                        callback.post_epoch(step, current_epoch)
                for callback in callbacks:
                    callback.pre_epoch(step, epoch)
                current_epoch = epoch

            # Pre-batch callback.
            for callback in callbacks:
                callback.pre_batch(step, current_epoch, batch)
            batch_loss = 0.0
            batch_outputs = []
            # `batch` is a list of `grad_accum` micro-batches; gradients accumulate
            # across them before the single optimizer step below.
            for micro_batch_idx, micro_batch in enumerate(batch):
                # Get loss.
                micro_batch_loss, micro_batch_outputs = training_engine.forward_train(
                    micro_batch, micro_batch_idx, len(batch)
                )
                if torch.isnan(micro_batch_loss):
                    raise ValueError("nan loss encountered")
                batch_loss += micro_batch_loss.detach().item()
                batch_outputs.append(
                    {
                        key: output.detach() if isinstance(output, torch.Tensor) else output
                        for key, output in micro_batch_outputs.items()
                    }
                )

                # Calculate gradients.
                training_engine.backward(micro_batch_loss)

                # Clean up in case it saves memory.
                del micro_batch
                del micro_batch_loss
                del micro_batch_outputs

            # Post-batch callback.
            for callback in callbacks:
                callback.post_batch(step, current_epoch, batch_loss, batch_outputs)
            del batch

            training_engine.step()

            # Find out whether we should validate
            if config.validation_split is None:
                # If we can't validate, we don't.
                should_validate_this_step = False
            elif step == config.train_steps - 1:
                # If we're at the end of the training run, we always validate.
                should_validate_this_step = True
            elif config.validate_every is not None and (step + 1) % config.validate_every == 0:
                # If validate_every is given, we use that to decide.
                should_validate_this_step = True
            elif config.validate_every is None and epoch != train_batch_iterator.peek()[1][0]:
                # If validate_every is not given, we validate at the end of the epoch.
                should_validate_this_step = True
            else:
                # Otherwise, we don't validate.
                should_validate_this_step = False

            # Gather average loss across all workers.
            if (
                config.should_log_this_step(step) or should_validate_this_step
            ) and config.is_distributed:
                batch_loss_tensor = torch.tensor(batch_loss, device=device)
                dist.all_reduce(batch_loss_tensor)
                batch_loss = batch_loss_tensor.item() / config.world_size

            if config.should_log_this_step(step):
                # Callbacks.
                for callback in callbacks:
                    callback.log_batch(step, current_epoch, batch_loss, batch_outputs)

                # Update progress bar.
                metrics_to_log: Dict[str, float] = {"batch_loss": batch_loss}
                if val_metric is not None:
                    metrics_to_log[f"val_{config.val_metric_name}"] = val_metric
                if best_val_metric is not None:
                    metrics_to_log[f"best_val_{config.val_metric_name}"] = best_val_metric
                if config.is_local_main_process:
                    train_batch_iterator_tqdm.set_postfix(**metrics_to_log)

            # Validate.
            if should_validate_this_step:
                assert validation_dataloader is not None
                assert config.validation_steps is not None

                # Prepare model for validation.
                training_engine.model.eval()

                running_metric = 0.0
                with Tqdm.tqdm(
                    islice(validation_dataloader, config.validation_steps),
                    desc="Validating",
                    total=config.validation_steps,
                    leave=False,
                    disable=not config.is_local_main_process,
                ) as val_batch_iterator:
                    for val_step, val_batch in enumerate(val_batch_iterator):
                        for callback in callbacks:
                            callback.pre_val_batch(step, val_step, current_epoch, val_batch)

                        # Get metric.
                        outputs = training_engine.forward_eval(val_batch)

                        for callback in callbacks:
                            callback.post_val_batch(step, val_step, current_epoch, outputs)
                        metric = outputs[config.val_metric_name]

                        if config.auto_aggregate_val_metric:
                            # Running mean over the validation batches seen so far.
                            running_metric += metric if isinstance(metric, float) else metric.item()
                            val_metric = running_metric / (val_step + 1)
                        else:
                            val_metric = metric if isinstance(metric, float) else metric.item()

                        # Average metric across all workers.
                        if (
                            config.is_distributed
                            and config.should_log_this_val_step(val_step)
                            and config.auto_aggregate_val_metric
                        ):
                            val_metric_tensor = torch.tensor(val_metric, device=device)
                            dist.all_reduce(val_metric_tensor)
                            val_metric = val_metric_tensor.item() / config.world_size

                        # Update progress bar.
                        if config.is_local_main_process and config.should_log_this_val_step(
                            val_step
                        ):
                            val_batch_iterator.set_postfix(**{config.val_metric_name: val_metric})

                        # Clean up.
                        del val_batch
                        del outputs
                        del metric

                assert val_metric is not None

                # Reset model to train mode.
                training_engine.model.train()

                if (
                    best_val_metric is None
                    or (config.minimize_val_metric and val_metric <= best_val_metric)
                    or (not config.minimize_val_metric and val_metric >= best_val_metric)
                ):
                    best_val_metric = val_metric

                # Checkpoint.
                if config.should_checkpoint_this_step(step):
                    save_state(step)

                # Post validation callback.
                for callback in callbacks:
                    callback.post_val_loop(step, current_epoch, val_metric, best_val_metric)

                # Reset model to train mode again in case the callbacks messed with it.
                if callbacks:
                    training_engine.model.train()

                # Update progress bar again.
                metrics_to_log = {
                    "batch_loss": batch_loss,
                    f"val_{config.val_metric_name}": val_metric,
                    f"best_{config.val_metric_name}": best_val_metric,
                }
                if config.is_local_main_process:
                    train_batch_iterator_tqdm.set_postfix(**metrics_to_log)
            else:
                # Checkpoint.
                if config.should_checkpoint_this_step(step):
                    save_state(step)

        # End train loop

        # Final post-epoch callback.
        for callback in callbacks:
            callback.post_epoch(step, current_epoch)
    except StopEarly:
        if config.is_local_main_process:
            logger.info("Stopping early!")
    finally:
        train_batch_iterator_tqdm.close()

    if config.is_distributed:
        dist.barrier()

    # If we haven't saved a checkpoint yet, do it now.
    if not config.best_state_path.exists():
        save_state(step)

    for callback in callbacks:
        callback.post_train_loop(step, current_epoch)

    if config.is_local_main_process:
        # It's possible this file already exists if the step previously failed after
        # already saving the final weights.
        if config.final_weights_path.is_file():
            os.remove(config.final_weights_path)
        training_engine.save_complete_weights_from_checkpoint(
            config.best_state_path, config.final_weights_path
        )

    if not config.is_distributed:
        return training_engine.model
    else:
        return None
def _cycle_through_epochs(dataloader: DataLoader, is_distributed: bool, grad_accum: int):
    """
    Yield ``(epoch, batch)`` pairs indefinitely, where each ``batch`` is a list of up
    to ``grad_accum`` micro-batches drawn from ``dataloader``.
    """
    current_epoch = 0
    while True:
        sampler = dataloader.sampler
        if is_distributed and isinstance(sampler, DistributedSampler):
            # Re-seed the sampler each epoch so shuffling differs across epochs.
            sampler.set_epoch(current_epoch)
        for micro_batches in chunked(dataloader, grad_accum):
            yield current_epoch, micro_batches
        current_epoch += 1
================================================
FILE: tango/integrations/torch/train_callback.py
================================================
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from tango.common.dataset_dict import DatasetDictBase
from tango.common.registrable import Registrable
from tango.workspace import Workspace
from .data import DataLoader
from .exceptions import StopEarly
from .model import Model
from .train_config import TrainConfig
from .training_engine import TrainingEngine
class TrainCallback(Registrable):
    """
    A :class:`TrainCallback` is a :class:`~tango.common.Registrable` class
    that can be used within :class:`TorchTrainStep` to customize behavior in the training
    loop. You can set the training callbacks with the ``callbacks`` parameter to :class:`TorchTrainStep`.

    .. tip::
        All of the parameters to this base class will be automatically set within
        the training loop, so you shouldn't include them in your config for your callbacks.

    .. tip::
        You can access the model being trained through :attr:`self.model`.

    .. important::
        The ``step`` argument to callback methods is the total/overall number of training steps
        so far, independent of the current epoch.

    .. seealso::
        See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example
        implementation.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar TrainConfig train_config: The training config.
    :ivar TrainingEngine training_engine: The engine used to train the model.
    :ivar tango.common.DatasetDictBase dataset_dict: The dataset dict containing train and
        optional validation splits.
    :ivar DataLoader train_dataloader: The dataloader used for the training split.
    :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split.
    """

    def __init__(
        self,
        workspace: Workspace,
        train_config: TrainConfig,
        training_engine: TrainingEngine,
        dataset_dict: DatasetDictBase,
        train_dataloader: DataLoader,
        validation_dataloader: Optional[DataLoader] = None,
    ) -> None:
        self.workspace = workspace
        self.train_config = train_config
        self.training_engine = training_engine
        self.dataset_dict = dataset_dict
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        # Per-callback logger named after the concrete subclass for clearer log output.
        self.logger = logging.getLogger(self.__class__.__name__)

    @property
    def step_id(self) -> str:
        """
        The unique ID of the current :class:`~tango.Step`.
        """
        return self.train_config.step_id

    @property
    def step_name(self) -> Optional[str]:
        """
        The name of the current :class:`~tango.Step`.
        """
        return self.train_config.step_name

    @property
    def work_dir(self) -> Path:
        """
        The working directory of the current train step.
        """
        return self.train_config.work_dir

    @property
    def is_local_main_process(self) -> bool:
        """
        This is ``True`` if the current worker is the main distributed worker of the current node, or if
        we are not using distributed training.
        """
        return self.train_config.is_local_main_process

    @property
    def model(self) -> Model:
        """
        The :class:`Model` being trained.
        """
        return self.training_engine.model

    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.

        Some callbacks need to maintain state across restarts. This is the callback's opportunity to
        save its state. It will be restored using :meth:`load_state_dict`.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state on a restart.

        Some callbacks need to maintain state across restarts. This is the callback's opportunity to
        restore its state. It gets saved using :meth:`state_dict`.
        """
        pass

    def pre_train_loop(self) -> None:
        """
        Called right before the first batch is processed, or after a restart.
        """
        pass

    def post_train_loop(self, step: int, epoch: int) -> None:
        """
        Called after the training loop completes.

        This is the last method that is called, so any cleanup can be done in this method.
        """
        pass

    def pre_epoch(self, step: int, epoch: int) -> None:
        """
        Called right before the start of an epoch. Epochs start at 0.
        """
        pass

    def post_epoch(self, step: int, epoch: int) -> None:
        """
        Called after an epoch is completed. Epochs start at 0.
        """
        pass

    def pre_batch(self, step: int, epoch: int, batch: List[Dict[str, Any]]) -> None:
        """
        Called directly before processing a batch.

        .. note::
            The type of ``batch`` is a list because with gradient accumulation there will
            be more than one "micro batch" in the batch.
        """
        pass

    def post_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Called directly after processing a batch, but before unscaling gradients,
        clipping gradients, and taking an optimizer step.

        .. note::
            The ``batch_loss`` here is the loss local to the current worker, not the
            overall (average) batch loss across distributed workers.
            If you need the average loss, use :meth:`log_batch()`.

        .. note::
            The type of ``batch_outputs`` is a list because with gradient accumulation there will
            be more than one "micro batch" in the batch.
        """
        pass

    def log_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Called after the optimizer step. Here ``batch_loss`` is the average loss across
        all distributed workers.

        .. note::
            This callback method is not necessarily called on every step.
            The frequency depends on the value of the ``log_every`` parameter of
            :class:`TorchTrainStep`.

        .. note::
            The type of ``batch_outputs`` is a list because with gradient accumulation there will
            be more than one "micro batch" in the batch.
        """
        pass

    def pre_val_batch(
        self, step: int, val_step: int, epoch: int, val_batch: Dict[str, Any]
    ) -> None:
        """
        Called right before a validation batch is processed.
        """
        pass

    def post_val_batch(
        self, step: int, val_step: int, epoch: int, val_batch_outputs: Dict[str, Any]
    ) -> None:
        """
        Called right after a validation batch is processed with the outputs of the batch.

        .. tip::
            This method can be used to modify ``val_batch_outputs`` in place, which is useful
            in scenarios like distributed training where you might need to aggregate metrics
            in a special way other than a simple average. If that's the case, make sure
            to set ``auto_aggregate_val_metric`` to ``False`` in :class:`TorchTrainStep`.
        """
        pass

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        """
        Called right after the validation loop finishes.
        """
        pass
@TrainCallback.register("torch::stop_early")
class StopEarlyCallback(TrainCallback):
    """
    A :class:`TrainCallback` for early stopping. Training is stopped early after
    ``patience`` steps without an improvement to the validation metric.

    .. tip::
        Registered as a :class:`TrainCallback` under the name "torch::stop_early".

    :param patience:
        The number of steps to wait for an improvement to the validation metric
        before stopping training early.
    """

    def __init__(self, *args, patience: int = 10000, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.patience = patience
        # The step at which the best validation metric so far was observed.
        self.best_step = 0
        # The best validation metric observed so far (``None`` until the first val loop).
        self.best_val_metric: Optional[float] = None

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        # We can't rely on the best_val_metric parameter, because then we can't detect when the metric stays
        # the same for many steps.
        if self.best_val_metric is None:
            improved = True
        elif self.train_config.minimize_val_metric:
            # Fix: respect ``minimize_val_metric`` — for metrics like loss, *lower* is
            # better. The old code always treated a higher metric as an improvement,
            # which made early stopping fire on improvements when minimizing.
            improved = val_metric < self.best_val_metric
        else:
            improved = val_metric > self.best_val_metric

        if improved:
            self.best_step = step
            self.best_val_metric = val_metric
        elif step > self.best_step + self.patience:
            raise StopEarly

    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.
        """
        return {
            "patience": self.patience,
            "best_step": self.best_step,
            "best_val_metric": self.best_val_metric,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state on a restart.
        """
        self.patience = state_dict["patience"]
        self.best_step = state_dict["best_step"]
        self.best_val_metric = state_dict["best_val_metric"]
================================================
FILE: tango/integrations/torch/train_config.py
================================================
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
@dataclass
class TrainConfig:
    """
    Encapsulates the parameters of :class:`TorchTrainStep`. This is used to pass all the training
    options to :class:`TrainCallback`.
    """

    step_id: str
    """
    The unique ID of the current step.
    """
    work_dir: Path
    """
    The working directory for the training run.
    """
    step_name: Optional[str] = None
    """
    The name of the current step.
    .. note::
        The same step can be run under different names.
    """
    worker_id: int = 0
    """
    The ID of the distributed worker.
    """
    train_split: str = "train"
    """
    The name of the training split.
    """
    validation_split: Optional[str] = None
    """
    The name of the validation split.
    """
    seed: int = 42
    """
    The random seed.
    """
    train_steps: Optional[int] = None
    """
    The number of steps to train for.
    """
    train_epochs: Optional[int] = None
    """
    The number of epochs to train for.
    You cannot specify `train_steps` and `train_epochs` at the same time.
    """
    validation_steps: Optional[int] = None
    """
    The number of validation steps.
    The default is to validate on the entire validation set.
    """
    grad_accum: int = 1
    """
    The number of micro-batches per gradient accumulation mini-batch.
    """
    log_every: int = 10
    """
    Controls the frequency of log updates, in number of optimizer steps
    """
    checkpoint_every: int = 100
    """
    Controls the frequency of checkpoints, in number of optimizer steps
    """
    validate_every: Optional[int] = None
    """
    Controls the frequency of the validation loop, in number of optimizer steps
    """
    is_distributed: bool = False
    """
    Whether or not the training job is distributed.
    """
    devices: Optional[List[int]] = None
    """
    The devices used (for distributed jobs).
    """
    distributed_address: str = "127.0.0.1"
    """
    The IP address of the main distributed process.
    """
    distributed_port: int = 54761
    """
    The port of the main distributed process.
    """
    val_metric_name: str = "loss"
    """
    The name of the validation metric to track.
    """
    minimize_val_metric: bool = True
    """
    Should be ``True`` when the validation metric being tracked should be minimized.
    """
    auto_aggregate_val_metric: bool = True
    """
    Controls automatic aggregation of validation metric.
    """
    remove_stale_checkpoints: bool = True
    """
    Controls removal of stale checkpoints.
    """
    world_size: int = 1
    """
    The number of distributed workers.
    """
    # Private caches, lazily populated by the properties below. Excluded from `as_dict()`.
    _worker_local_default_device: Optional[torch.device] = None
    _device_type: Optional[str] = None  # either "cuda" or "cpu"

    @property
    def worker_local_default_device(self) -> torch.device:
        """
        The default ``torch`` device for the current worker.
        """
        # Cached so the device only has to be resolved once per worker.
        if self._worker_local_default_device is not None:
            return self._worker_local_default_device
        else:
            if self.devices:
                # A non-negative device ID selects that CUDA device; negative means CPU.
                device_id = self.devices[self.worker_id]
                if device_id >= 0:
                    device = torch.device(f"cuda:{device_id}")
                else:
                    device = torch.device("cpu")
            elif torch.cuda.is_available():
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")
            self._worker_local_default_device = device
            return device

    @property
    def device_type(self) -> str:
        """
        The type of :attr:`worker_local_default_device` — either ``"cuda"`` or ``"cpu"``.
        """
        if self._device_type is None:
            device_type = (
                "cpu" if self.worker_local_default_device == torch.device("cpu") else "cuda"
            )
            self._device_type = device_type
            return device_type
        else:
            return self._device_type

    @property
    def is_local_main_process(self) -> bool:
        """
        Whether the local process is the main distributed worker.
        """
        return self.worker_id == 0

    @property
    def state_path(self) -> Path:
        """
        The path to the latest state checkpoint file.
        """
        return self.work_dir / "checkpoint_state_latest"

    @property
    def best_state_path(self) -> Path:
        """
        The path to the best state checkpoint file according to the validation metric or training
        loss (if no validation split is given).
        """
        return self.work_dir / "checkpoint_state_best"

    def state_path_for_step(self, step: int) -> Path:
        """
        The path to the checkpoint file for the given (0-indexed) training step.
        """
        # Checkpoint file names are 1-indexed with respect to the step number.
        return self.work_dir / f"checkpoint_state_step{step + 1}"

    @property
    def final_weights_path(self) -> Path:
        """
        The path where the final model weights are saved.
        """
        return self.work_dir / "weights.pt"

    def should_log_this_step(self, step: int) -> bool:
        """
        Whether metrics should be logged on the given (0-indexed) training step:
        the first step, every ``log_every`` steps, and the final step.
        """
        assert self.train_steps is not None
        return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1

    def should_checkpoint_this_step(self, step: int) -> bool:
        """
        Whether a checkpoint should be saved on the given (0-indexed) training step:
        every ``checkpoint_every`` steps and on the final step.
        """
        assert self.train_steps is not None
        return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1

    def should_log_this_val_step(self, val_step: int) -> bool:
        """
        Whether progress should be logged on the given (0-indexed) validation step:
        every ``log_every`` steps (including step 0) and on the final validation step.
        """
        assert self.validation_steps is not None
        return val_step % self.log_every == 0 or val_step == self.validation_steps - 1

    def as_dict(self) -> Dict[str, Any]:
        """
        Return the config as a dictionary, excluding private (underscore-prefixed) fields.
        """
        return {k: v for k, v in asdict(self).items() if not k.startswith("_")}
================================================
FILE: tango/integrations/torch/training_engine.py
================================================
import os
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, cast
import torch
import torch.distributed as dist
import torch.nn as nn
from tango.common import Lazy, Registrable, Tqdm
from .model import Model
from .optim import LRScheduler, Optimizer
from .train_config import TrainConfig
from .util import move_to_device
class TrainingEngine(Registrable):
    """
    A :class:`TrainingEngine` defines and drives the strategy for training a model
    in :class:`TorchTrainStep`.

    :ivar TrainConfig train_config: The training config.
    :ivar Model model: The model being trained.
    :ivar Optimizer optimizer: The optimizer being used to train the model.
    :ivar LRScheduler lr_scheduler: The optional learning rate scheduler.
    """

    default_implementation = "torch"
    """
    The default implementation is :class:`TorchTrainingEngine`.
    """

    def __init__(
        self,
        train_config: TrainConfig,
        model: Union[Model, Lazy[Model]],
        optimizer: Lazy[Optimizer],
        *,
        lr_scheduler: Optional[Lazy[LRScheduler]] = None,
    ) -> None:
        self.train_config = train_config
        # Build the model first: the optimizer needs its parameters and the
        # scheduler needs the optimizer.
        self.model = self._construct_model(model)
        self.optimizer = self._construct_optimizer(optimizer)
        self.lr_scheduler: Optional[LRScheduler] = (
            None if lr_scheduler is None else self._construct_lr_scheduler(lr_scheduler)
        )

    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
        # Materialize the model if it was given lazily, then move it onto this
        # worker's default device.
        constructed = model.construct() if isinstance(model, Lazy) else model
        return constructed.to(self.train_config.worker_local_default_device)

    def _construct_optimizer(self, optimizer: Lazy[Optimizer]) -> Optimizer:
        # The optimizer is lazy because it can only be built once the model's
        # parameters exist.
        return optimizer.construct(params=self.model.parameters())

    def _construct_lr_scheduler(self, lr_scheduler: Lazy[LRScheduler]) -> LRScheduler:
        # Likewise, the scheduler can only be built once the optimizer exists.
        return lr_scheduler.construct(optimizer=self.optimizer)

    @abstractmethod
    def forward_train(
        self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Run a forward training pass on the model.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run a forward evaluation pass on the model.
        """
        raise NotImplementedError

    @abstractmethod
    def backward(self, loss: torch.Tensor) -> None:
        """
        Run a backwards pass on the model. This will always be called after :meth:`forward_train()`.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self) -> None:
        """
        Take an optimization step. This will always be called after :meth:`backward()`.
        """
        raise NotImplementedError

    @abstractmethod
    def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:
        """
        Save a training checkpoint with model state, optimizer state, etc., as well
        as the arbitrary ``client_state`` to the given ``checkpoint_dir``.
        """
        raise NotImplementedError

    @abstractmethod
    def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:
        """
        Load a checkpoint to resume training. Should return the same ``client_state`` saved
        in :meth:`save_checkpoint()`.
        """
        raise NotImplementedError

    @abstractmethod
    def save_complete_weights_from_checkpoint(
        self, checkpoint_dir: Path, weights_path: Path
    ) -> None:
        """
        Gather the final weights from the best checkpoint and save to the file at ``weights_path``.
        """
        raise NotImplementedError
@TrainingEngine.register("torch")
class TorchTrainingEngine(TrainingEngine):
    """
    This train engine only uses native PyTorch functionality to provide
    vanilla distributed data parallel training and AMP.

    .. tip::
        Registered as a :class:`TrainingEngine` under the name "torch".

    .. important::
        Only the parameters listed below should be defined in a configuration
        file. The other parameters will be automatically passed to the constructor
        within :class:`TorchTrainStep`.

    :param amp:
        Use automatic mixed precision. Default is ``False``.
    :param max_grad_norm:
        If set, gradients will be clipped to have this max norm. Default is ``None``.
    :param amp_use_bfloat16:
        Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.
        Only applicable when ``amp=True``. If not specified, the default behavior will be
        to use ``bfloat16`` when training with AMP on CPU, otherwise not.
    """

    def __init__(
        self,
        train_config: TrainConfig,
        model: Union[Model, Lazy[Model]],
        optimizer: Lazy[Optimizer],
        *,
        lr_scheduler: Optional[Lazy[LRScheduler]] = None,
        amp: bool = False,
        max_grad_norm: Optional[float] = None,
        amp_use_bfloat16: Optional[bool] = None,
    ) -> None:
        self.device = train_config.worker_local_default_device
        if amp_use_bfloat16 is None:
            # Default to bfloat16 for AMP on CPU, float16 otherwise.
            amp_use_bfloat16 = True if train_config.device_type == "cpu" else False
        self.amp = amp
        self.amp_dtype = torch.bfloat16 if amp_use_bfloat16 else torch.float16
        self.max_grad_norm = max_grad_norm
        # A gradient scaler is only needed when AMP is enabled.
        self.grad_scaler: Optional[torch.cuda.amp.GradScaler] = (
            None if not amp else torch.cuda.amp.GradScaler()
        )

        if train_config.is_distributed:
            # Initialize distributed process group.
            backend: str
            if train_config.device_type != "cpu":
                torch.cuda.set_device(self.device)
                backend = "nccl"
            else:
                backend = "gloo"
            dist.init_process_group(
                backend=backend,
                init_method=f"tcp://{train_config.distributed_address}:{train_config.distributed_port}",
                world_size=train_config.world_size,
                rank=train_config.worker_id,
            )

        # The base constructor builds the model (possibly wrapping it with DDP),
        # so the process group must be initialized before calling it.
        super().__init__(train_config, model, optimizer, lr_scheduler=lr_scheduler)

    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
        if isinstance(model, Lazy):
            model = model.construct()
        model.to(self.train_config.worker_local_default_device)
        # Wrap model with DDP wrapper.
        if self.train_config.is_distributed:
            model = cast(Model, nn.parallel.DistributedDataParallel(model))
        return model

    def forward_train(
        self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Run a forward training pass on one micro batch, returning the scaled loss
        and the model's raw outputs.
        """
        if micro_batch_idx == 0:
            # First micro batch of a new mini-batch: clear accumulated gradients.
            self.optimizer.zero_grad(set_to_none=True)

        # Move tensors to right device.
        micro_batch = move_to_device(micro_batch, self.device)

        with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype):
            outputs = self.model(**micro_batch)
            # Scale the loss so accumulated gradients match a single full-batch pass.
            micro_batch_loss = outputs["loss"] / num_micro_batches

        return micro_batch_loss, outputs

    def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run a forward evaluation pass (no gradient tracking) on the model.
        """
        # Move tensors to right device.
        batch = move_to_device(batch, self.device)

        with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype):
            with torch.inference_mode():
                outputs = self.model(**batch)

        return outputs

    def backward(self, loss: torch.Tensor) -> None:
        """
        Run a backwards pass, scaling the loss first when AMP is enabled.
        """
        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()

    def clip_grad_norm(self) -> None:
        """
        Clip gradients to ``max_grad_norm``, if set.
        """
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

    def step(self) -> None:
        """
        Take an optimizer step: unscale gradients (AMP), clip them, step the
        optimizer, and advance the LR schedule.
        """
        # Unscale gradients.
        if self.grad_scaler is not None:
            self.grad_scaler.unscale_(self.optimizer)

        # Clip gradients.
        self.clip_grad_norm()

        # Take optimizer step.
        if self.grad_scaler is not None:
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            self.optimizer.step()

        # Adjust LR schedule.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    def get_model_state(self) -> Dict[str, torch.Tensor]:
        """
        Return the model's state dict, unwrapping the DDP wrapper if needed.
        """
        if self.train_config.is_distributed:
            return self.model.module.state_dict()  # type: ignore[union-attr]
        else:
            return self.model.state_dict()

    def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """
        Load a state dict into the model, unwrapping the DDP wrapper if needed.
        """
        if self.train_config.is_distributed:
            self.model.module.load_state_dict(state_dict)  # type: ignore
        else:
            self.model.load_state_dict(state_dict)  # type: ignore

    def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:
        """
        Save model, optimizer, and (when present) LR scheduler and grad scaler
        state, plus the arbitrary ``client_state``, to ``checkpoint_dir``.
        """
        checkpoint_dir.mkdir(exist_ok=True)

        def save_state(state: Dict[str, Any], name: str):
            # Write to a temp file in the same directory, then atomically rename it
            # into place, so a crash mid-write can't leave a corrupt checkpoint.
            temp_state_file = tempfile.NamedTemporaryFile(
                "w+b", dir=checkpoint_dir, delete=False, suffix=".pt"
            )
            try:
                with Tqdm.wrapattr(
                    temp_state_file,
                    "write",
                    desc=f"Saving {name} state",
                    leave=False,
                    disable=not self.train_config.is_local_main_process,
                ) as f:
                    torch.save(state, f)
                temp_state_file.close()
                os.replace(
                    temp_state_file.name,
                    checkpoint_dir / f"worker{self.train_config.worker_id}_{name}.pt",
                )
            finally:
                if os.path.exists(temp_state_file.name):
                    os.remove(temp_state_file.name)

        save_state(self.get_model_state(), "model")
        # Fix: removed a stray trailing comma here that turned this call into a
        # discarded one-element tuple expression.
        save_state(self.optimizer.state_dict(), "optimizer")
        if self.lr_scheduler is not None:
            save_state(self.lr_scheduler.state_dict(), "lr_scheduler")
        if self.grad_scaler is not None:
            save_state(self.grad_scaler.state_dict(), "grad_scaler")
        save_state(client_state, "trainer")

    def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:
        """
        Restore model/optimizer/scheduler/scaler state from ``checkpoint_dir`` and
        return the ``client_state`` that was passed to :meth:`save_checkpoint()`.
        """
        # NOTE(review): these `torch.load` calls use the default map_location, so
        # tensors land on the devices they were saved from — confirm this is intended
        # when resuming on a different device layout.
        self.load_model_state(
            torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_model.pt")
        )
        self.optimizer.load_state_dict(
            torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_optimizer.pt")
        )
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(
                torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_lr_scheduler.pt")
            )
        if self.grad_scaler is not None:
            self.grad_scaler.load_state_dict(
                torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_grad_scaler.pt")
            )
        return torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_trainer.pt")

    def save_complete_weights_from_checkpoint(
        self, checkpoint_dir: Path, weights_path: Path
    ) -> None:
        """
        Publish the complete model weights from a checkpoint to ``weights_path``.
        """
        # For this engine, worker 0's saved model state is already the complete set
        # of weights, so a hard link avoids copying the file.
        os.link(checkpoint_dir.resolve() / "worker0_model.pt", weights_path)
================================================
FILE: tango/integrations/torch/util.py
================================================
import random
import warnings
from collections import UserDict
from typing import Dict, Optional, TypeVar, Union
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler, IterableDataset
from .data import DataLoader
T = TypeVar("T")
def move_to_device(o: T, device: torch.device) -> T:
    """
    Recursively move every tensor inside ``o`` onto ``device``.

    Handles tensors, dicts (including ``UserDict``), lists, and tuples; any other
    value is returned unchanged. Note that dict-like and tuple containers are
    rebuilt as plain ``dict``/``tuple`` instances.
    """
    if isinstance(o, torch.Tensor):
        return o.to(device)  # type: ignore[return-value]
    if isinstance(o, (dict, UserDict)):
        return {key: move_to_device(value, device) for key, value in o.items()}  # type: ignore[return-value]
    if isinstance(o, list):
        return [move_to_device(item, device) for item in o]  # type: ignore[return-value]
    if isinstance(o, tuple):
        return tuple(move_to_device(item, device) for item in o)  # type: ignore[return-value]
    return o
def check_dataset(dataset, split: str):
    """
    Warn when ``dataset`` has no ``len()`` (i.e. it looks like a streaming dataset)
    but is not an instance of ``torch.utils.data.IterableDataset``.
    """
    try:
        len(dataset)
    except TypeError:
        if isinstance(dataset, IterableDataset):
            # A proper iterable dataset — the DataLoader knows how to handle it.
            return
        warnings.warn(
            f"Dataset for {split} split appears to be a streaming/iterable dataset, "
            "but is not an instance of 'torch.utils.data.IterableDataset'. This could cause issues "
            "within the DataLoader.",
            UserWarning,
        )
def check_dataloader(dataloader: DataLoader):
    """
    Warn when a dataloader used for distributed training iterates a map-style
    dataset without a ``DistributedSampler``.
    """
    # Streaming/iterable datasets shard themselves, so only map-style datasets
    # need a `DistributedSampler`.
    uses_iterable_dataset = isinstance(dataloader.dataset, IterableDataset)
    has_distributed_sampler = isinstance(dataloader.sampler, DistributedSampler)
    if not uses_iterable_dataset and not has_distributed_sampler:
        warnings.warn(
            "DistributedSampler is required for dataloader during distributed training, "
            f"found {type(dataloader.sampler)} instead.",
            UserWarning,
        )
def set_seed_all(seed: int):
    """
    Seed Python's, PyTorch's, and (when installed) NumPy's random number generators.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        import numpy
    except ModuleNotFoundError:
        # NumPy is an optional dependency — nothing more to seed.
        return
    numpy.random.seed(seed)
def resolve_device(device: Optional[Union[int, str, torch.device]] = None) -> torch.device:
    """
    Normalize ``device`` (``None``, int, string, or ``torch.device``) into a
    ``torch.device``.

    ``None`` selects CUDA when available, otherwise CPU. A non-negative int picks
    that CUDA device; a negative int means CPU.
    """
    if device is None:
        # TODO (epwalsh, dirkgr): automatically pick which GPU to use when there are multiple
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if isinstance(device, torch.device):
        return device
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device(f"cuda:{device}") if device >= 0 else torch.device("cpu")
    raise TypeError(f"unexpected type for 'device': '{device}'")
def peak_gpu_memory(reset: bool = False) -> Dict[int, int]:
    """
    Get the peak GPU memory usage in MiB by distributed worker rank.

    :param reset:
        If ``True``, the peak memory stats are reset after reading them.

    :returns:
        Keys are rank ids as integers (from 0 up to world size - 1).
        Values are memory usage as integers in MiB.
        Returns an empty `dict` if GPUs are not available.
    """
    if not torch.cuda.is_available():
        return {}

    device = torch.device("cuda")
    results_dict: Dict[int, int] = {}
    if dist.is_available() and dist.is_initialized():
        # If the backend is not 'nccl', we're training on CPU.
        if dist.get_backend() != "nccl":
            return {}

        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        peak_mb = torch.cuda.max_memory_allocated(device) // 1048576
        peak_mb_tensor = torch.tensor([global_rank, peak_mb], device=device)
        # All of these tensors will be gathered into this list.
        gather_results = [torch.tensor([0, 0], device=device) for _ in range(world_size)]
        dist.all_gather(gather_results, peak_mb_tensor)
        for peak_mb_tensor in gather_results:
            results_dict[int(peak_mb_tensor[0])] = int(peak_mb_tensor[1])
    else:
        # Fix: convert bytes to MiB here too, so this branch agrees with the
        # distributed branch (and with the documented return value).
        results_dict = {0: torch.cuda.max_memory_allocated(device) // 1048576}

    if reset:
        # Reset peak stats.
        torch.cuda.reset_max_memory_allocated(device)

    return results_dict
================================================
FILE: tango/integrations/transformers/__init__.py
================================================
"""
.. important::
To use this integration you should install ``tango`` with the "transformers" extra
(e.g. ``pip install tango[transformers]``) or just install the ``transformers`` library after the fact
(e.g. ``pip install transformers``).
Components for Tango integration with `🤗 Transformers `_.
This integration provides some useful steps and also registers PyTorch components from the transformers
library under the corresponding class from the `torch `_ integration, such as:
- :class:`~tango.integrations.torch.Model`: All transformers "auto" model classes are registered
according to their class names (e.g. "transformers::AutoModelForCausalLM::from_pretrained"
or "transformers::AutoModelForCausalLM::from_config").
For example, to instantiate a pretrained transformer model from params:
.. testcode::
from tango.integrations.torch import Model
model = Model.from_params({
"type": "transformers::AutoModel::from_pretrained",
"pretrained_model_name_or_path": "epwalsh/bert-xsmall-dummy",
})
Or to instantiate a transformer model from params without loading pretrained weights:
.. testcode::
from tango.integrations.torch import Model
model = Model.from_params({
"type": "transformers::AutoModel::from_config",
"config": {"pretrained_model_name_or_path": "epwalsh/bert-xsmall-dummy"},
})
.. tip::
You can see a list of all of the available auto model constructors from transformers by running:
.. testcode::
from tango.integrations.torch import Model
from tango.integrations.transformers import *
available_models = []
for name in sorted(Model.list_available()):
if name.startswith("transformers::AutoModel"):
available_models.append(name)
- :class:`~tango.integrations.torch.Optimizer`: All optimizers from transformers are registered according
to their class names (e.g. "transformers::AdaFactor").
.. tip::
You can see a list of all of the available optimizers from transformers by running:
.. testcode::
from tango.integrations.torch import Optimizer
from tango.integrations.transformers import *
for name in sorted(Optimizer.list_available()):
if name.startswith("transformers::"):
print(name)
.. testoutput::
transformers::Adafactor
transformers::AdamW
transformers::LayerWiseDummyOptimizer
- :class:`~tango.integrations.torch.LRScheduler`: All learning rate scheduler function from transformers
are registered according to their type name (e.g. "transformers::linear").
.. tip::
You can see a list of all of the available scheduler functions from transformers by running:
.. testcode::
from tango.integrations.torch import LRScheduler
from tango.integrations.transformers import *
for name in sorted(LRScheduler.list_available()):
if name.startswith("transformers::"):
print(name)
.. testoutput::
transformers::constant
transformers::constant_with_warmup
transformers::cosine
transformers::cosine_with_min_lr
transformers::cosine_with_restarts
transformers::inverse_sqrt
transformers::linear
transformers::polynomial
transformers::reduce_lr_on_plateau
- :class:`~tango.integrations.torch.DataCollator`: All data collators from transformers
are registered according to their class name (e.g. "transformers::DefaultDataCollator").
You can instantiate any of these from a config / params like so:
.. testcode::
from tango.integrations.torch import DataCollator
collator = DataCollator.from_params({
"type": "transformers::DataCollatorWithPadding",
"tokenizer": {
"pretrained_model_name_or_path": "epwalsh/bert-xsmall-dummy",
},
})
.. tip::
You can see a list of all of the available data collators from transformers by running:
.. testcode::
from tango.integrations.torch import DataCollator
from tango.integrations.transformers import *
for name in sorted(DataCollator.list_available()):
if name.startswith("transformers::"):
print(name)
.. testoutput::
transformers::DataCollatorForLanguageModeling
transformers::DataCollatorForPermutationLanguageModeling
transformers::DataCollatorForSOP
transformers::DataCollatorForSeq2Seq
transformers::DataCollatorForTokenClassification
transformers::DataCollatorForWholeWordMask
transformers::DataCollatorWithPadding
transformers::DefaultDataCollator
"""
from tango.common.exceptions import IntegrationMissingError
try:
import transformers
except ModuleNotFoundError:
raise IntegrationMissingError("transformers")
__all__ = [
"RunGeneration",
"RunGenerationDataset",
"Tokenizer",
"Config",
"add_soft_prompt",
"FinetuneWrapper",
"FinetuneStep",
"TokenizeText2TextData",
]
from .config import Config
from .data import * # noqa: F403
from .finetune import FinetuneStep, FinetuneWrapper, TokenizeText2TextData
from .model import * # noqa: F403
from .optim import * # noqa: F403
from .run_generation import RunGeneration, RunGenerationDataset
from .soft_prompt import add_soft_prompt
from .tokenizer import Tokenizer
================================================
FILE: tango/integrations/transformers/config.py
================================================
from transformers import AutoConfig, PretrainedConfig
from tango.common import Registrable
class Config(PretrainedConfig, Registrable):
    """
    A :class:`~tango.common.Registrable` version of transformers'
    :class:`~transformers.PretrainedConfig`.
    """

    default_implementation = "auto"
    """
    The default registered implementation just calls
    :meth:`transformers.AutoConfig.from_pretrained()`.
    """


# Register `AutoConfig` under the default "auto" name, using its `from_pretrained`
# classmethod as the constructor.
Config.register("auto", constructor="from_pretrained")(AutoConfig)
================================================
FILE: tango/integrations/transformers/data.py
================================================
from dataclasses import fields, is_dataclass
from typing import Callable
from transformers.data import data_collator as transformers_data_collator
from tango.integrations.torch.data import DataCollator
from .tokenizer import Tokenizer
# Some data collators take a tokenizer, so in order to instantiate those collators from params,
# we need to use a factory function that takes our registrable version of a tokenizer as
# an argument.
def data_collator_with_tokenizer_factory(cls) -> Callable[..., DataCollator]:
    """
    Build a factory that constructs ``cls`` from a registrable :class:`Tokenizer`.

    Some collators require a tokenizer at construction time, so registering this
    factory (instead of ``cls`` itself) lets the tokenizer be built from params.
    """

    def construct(tokenizer: Tokenizer, **extra_args) -> DataCollator:
        return cls(tokenizer=tokenizer, **extra_args)

    return construct
# Scan transformers' data_collator module and register every dataclass-based,
# callable collator under "transformers::<name>". Collators that declare a
# `tokenizer` field are registered through the factory above so their tokenizer
# can be constructed from params; all others are registered directly.
for name, cls in transformers_data_collator.__dict__.items():
    if (
        isinstance(cls, type)
        and is_dataclass(cls)
        and "DataCollator" in name
        and hasattr(cls, "__call__")
    ):
        for field in fields(cls):
            if field.name == "tokenizer":
                factory_func = data_collator_with_tokenizer_factory(cls)
                DataCollator.register("transformers::" + name)(factory_func)  # type: ignore
                break
        else:
            # for/else: no `tokenizer` field was found, register the class as-is.
            DataCollator.register("transformers::" + name)(cls)
================================================
FILE: tango/integrations/transformers/finetune.py
================================================
import logging
from os import PathLike
from typing import List, Optional, Union, cast
import datasets as ds
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
DataCollatorForSeq2Seq,
PreTrainedModel,
)
from tango.common import Lazy, Params
from tango.format import Format
from tango.integrations.datasets import DatasetsFormat, convert_to_tango_dataset_dict
from tango.integrations.torch import (
DataCollator,
DataLoader,
Model,
TorchFormat,
TrainCallback,
TrainingEngine,
)
from tango.integrations.torch.train import TorchTrainStep
from tango.integrations.transformers import Tokenizer
from tango.step import Step
logger = logging.getLogger(__name__)
SEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys() # type: ignore
CAUSAL = AutoModelForCausalLM._model_mapping.keys() # type: ignore
class FinetuneWrapper(PreTrainedModel):
    """
    Wrapper ``PreTrainedModel`` class that returns either a ``Seq2SeqLM`` or ``CausalLM`` model.
    """

    @classmethod
    def from_pretrained(  # type: ignore
        cls,
        pretrained_model_name_or_path: Union[str, PathLike],
        num_tokens: Optional[int] = None,
        **kwargs,
    ) -> PreTrainedModel:
        """
        :param pretrained_model_name_or_path:
            The name of the model to return. Any name that works in the transformers library works here.
        :param num_tokens:
            The number of token embeddings to have.
        """
        # Try the seq2seq architecture first; transformers raises ValueError for
        # checkpoints that aren't compatible, in which case fall back to causal LM.
        try:
            pretrained = AutoModelForSeq2SeqLM.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            )
        except ValueError:
            pretrained = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            )
        if num_tokens is not None:
            # Resize the embedding matrix, e.g. after special tokens were added.
            pretrained.resize_token_embeddings(num_tokens)
        return pretrained


# Register the wrapper so configs can use `type: "transformers::finetune::from_pretrained"`.
Model.register("transformers::finetune::from_pretrained", constructor="from_pretrained")(
    FinetuneWrapper
)
def _add_special_tokens(tokenizer: Tokenizer) -> None:
    """Ensure the tokenizer has pad, sep, and eos special tokens, adding defaults if needed."""
    # Prefer reusing the EOS token as padding before inventing a new token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    for attr, default_token in (
        ("pad_token", "[PAD]"),
        ("sep_token", "[SEP]"),
        ("eos_token", "[EOS]"),
    ):
        if getattr(tokenizer, attr) is None:
            tokenizer.add_special_tokens({attr: default_token})
def tokenize_data(
    data: ds.DatasetDict,
    tokenizer: Tokenizer,
    num_workers: int = 1,
    source_field: str = "source",
    target_field: str = "target",
    max_source_length: Optional[int] = 1024,
    max_target_length: Optional[int] = 1024,
    pad_to_max_length: bool = False,
    ignore_pad_token_for_loss: bool = True,
    concat_source_target: bool = False,
) -> ds.DatasetDict:
    """
    Returns a `DatasetDict` with tokenized source and target fields.

    :param data:
        The original dataset dict containing the source and target fields.
    :param tokenizer:
        The tokenizer to use.
    :param num_workers:
        The number of workers to use for processing the data.
    :param source_field:
        The string name of the field containing the source sequence.
    :param target_field:
        The string name of the field containing the target sequence.
    :param max_source_length:
        The maximum number of tokens in the source sequence.
    :param max_target_length:
        The maximum number of tokens in the target sequence.
    :param pad_to_max_length:
        Whether to pad to the maximum length when tokenizing.
    :param ignore_pad_token_for_loss:
        Whether to ignore the padded tokens for calculating loss.
        If set to True, all the pad tokens in the labels are replaced
        by -100, which is ignored by the loss function.
    :param concat_source_target:
        If the downstream model is decoder-only, like "gpt2", the source
        and target sequences need to be concatenated and fed to the model
        together.
    """
    padding = "max_length" if pad_to_max_length else False
    # Make sure pad/sep/eos tokens exist before tokenizing; the sep and eos
    # tokens are required by the concatenated decoder-only format below.
    _add_special_tokens(tokenizer)

    def preprocess_function(examples):
        # remove pairs where at least one record is None
        inputs, targets = [], []
        input_lengths = []
        for i in range(len(examples[source_field])):
            if examples[source_field][i] is not None and examples[target_field][i] is not None:
                if not concat_source_target:
                    inputs.append(examples[source_field][i])
                    targets.append(examples[target_field][i])
                else:
                    # Decoder-only models see "source [SEP] target [EOS]" as one
                    # sequence that serves as both input and (masked) labels.
                    text = (
                        examples[source_field][i]
                        + tokenizer.sep_token
                        + examples[target_field][i]
                        + tokenizer.eos_token
                    )
                    inputs.append(text)
                    targets.append(text)
                # NOTE(review): this is the *character* length of the raw source,
                # not a token count, and is not used further in this function.
                input_lengths.append(len(examples[source_field][i]))

        model_inputs = tokenizer(
            inputs, max_length=max_source_length, padding=padding, truncation=True
        )

        if not concat_source_target:
            # Setup the tokenizer for targets
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets, max_length=max_target_length, padding=padding, truncation=True
                )
        else:
            # Labels mirror the input ids, with everything up to and including
            # the separator masked out with -100 so only the target part
            # contributes to the loss.
            labels = {"input_ids": []}
            for input_ids in model_inputs["input_ids"]:
                label_start_idx = input_ids.index(tokenizer.sep_token_id)
                label_ids = [-100] * len(input_ids)
                label_ids[label_start_idx + 1 :] = input_ids[label_start_idx + 1 :]
                labels["input_ids"].append(label_ids)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
        # when we want to ignore padding in the loss.
        if padding == "max_length" and ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(lb if lb != tokenizer.pad_token_id else -100) for lb in label]
                for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    data = data.map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=list(data.column_names.values())[0],  # remove all old columns
        desc="Tokenizing dataset",
    )
    return data
@Step.register("transformers::tokenize_text2text")
class TokenizeText2TextData(Step):
    """
    A step that tokenizes data containing source and target sequences.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "transformers::tokenize_text2text".
    """

    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT = DatasetsFormat()

    def run(  # type: ignore[override]
        self,
        data: ds.DatasetDict,
        tokenizer: Tokenizer,
        num_workers: int = 1,
        source_field: str = "source",
        target_field: str = "target",
        max_source_length: Optional[int] = 1024,
        max_target_length: Optional[int] = 1024,
        pad_to_max_length: bool = False,
        ignore_pad_token_for_loss: bool = True,
        concat_source_target: bool = False,
    ) -> ds.DatasetDict:
        """
        Tokenize the ``source_field`` and ``target_field`` columns of ``data`` and
        return the resulting `DatasetDict`.

        :param data:
            The original dataset dict containing the source and target fields.
        :param tokenizer:
            The tokenizer to use.
        :param num_workers:
            How many workers to use when processing the data.
        :param source_field:
            Name of the column holding the source sequence.
        :param target_field:
            Name of the column holding the target sequence.
        :param max_source_length:
            Maximum number of tokens kept from the source sequence.
        :param max_target_length:
            Maximum number of tokens kept from the target sequence.
        :param pad_to_max_length:
            Whether sequences are padded to the maximum length during tokenization.
        :param ignore_pad_token_for_loss:
            When ``True``, pad tokens in the labels are replaced by -100 so the
            loss function ignores them.
        :param concat_source_target:
            Set to ``True`` for decoder-only models (e.g. "gpt2"), which need the
            source and target concatenated into a single sequence.

        .. tip::
            If concat_source_target is set to True, we pad all sequences to max
            length here. Otherwise, we leave it to the appropriate
            :class:`~tango.integrations.torch.DataCollator` object.
        """
        # This step is a thin wrapper; all of the work happens in `tokenize_data`.
        options = dict(
            tokenizer=tokenizer,
            num_workers=num_workers,
            source_field=source_field,
            target_field=target_field,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            pad_to_max_length=pad_to_max_length,
            ignore_pad_token_for_loss=ignore_pad_token_for_loss,
            concat_source_target=concat_source_target,
        )
        return tokenize_data(data, **options)
@Step.register("transformers::finetune")
class FinetuneStep(TorchTrainStep):
    """
    Mostly similar to :class:`~tango.integrations.torch.train.TorchTrainStep` with additional
    preprocessing for data.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "transformers::finetune".

    .. important::
        The training loop will use GPU(s) automatically when available, as long as at least
        ``device_count`` CUDA devices are available.

        Distributed data parallel training is activated when the ``device_count`` is greater than 1.

        You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``.
        For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1``
        (and ``device_count`` to 2).

    .. warning::
        During validation, the validation metric (specified by the ``val_metric_name`` parameter)
        is aggregated by simply averaging across validation batches and distributed processes.
        This behavior is usually correct when your validation metric is "loss" or "accuracy",
        for example, but may not be correct for other metrics like "F1".

        If this is not correct for your metric you will need to handle the aggregation
        internally in your model or with a :class:`TrainCallback`
        using the :meth:`TrainCallback.post_val_batch()` method.
        Then set the parameter ``auto_aggregate_val_metric`` to ``False``.

        Note that correctly aggregating your metric during distributed training will
        involve distributed communication.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT: Format = TorchFormat()
    # These arguments don't affect the result, so they are excluded from the
    # step's unique ID computation.
    SKIP_ID_ARGUMENTS = {"distributed_port", "log_every"}

    def run(  # type: ignore[override]
        self,
        model: Lazy[Model],
        tokenizer: Tokenizer,
        training_engine: Lazy[TrainingEngine],
        dataset_dict: ds.DatasetDict,
        train_dataloader: Lazy[DataLoader],
        *,
        train_split: str = "train",
        validation_split: Optional[str] = None,
        validation_dataloader: Optional[Lazy[DataLoader]] = None,
        source_field: str = "source",
        target_field: str = "target",
        max_source_length: Optional[int] = 1024,
        max_target_length: Optional[int] = 1024,
        seed: int = 42,
        train_steps: Optional[int] = None,
        train_epochs: Optional[int] = None,
        validation_steps: Optional[int] = None,
        grad_accum: int = 1,
        log_every: int = 10,
        checkpoint_every: int = 100,
        validate_every: Optional[int] = None,
        device_count: int = 1,
        distributed_port: int = 54761,
        val_metric_name: str = "loss",
        minimize_val_metric: bool = True,
        auto_aggregate_val_metric: bool = True,
        callbacks: Optional[List[Lazy[TrainCallback]]] = None,
        remove_stale_checkpoints: bool = True,
    ) -> Model:
        """
        Run a basic training loop to train the ``model``.

        :param model:
            The model to train. It should return a ``dict`` that includes the ``loss``
            during training and the ``val_metric_name`` during validation.
        :param tokenizer:
            The tokenizer to use for tokenizing source and target sequences.
        :param training_engine:
            The :class:`TrainingEngine` to use to train the model.
        :param dataset_dict:
            The train and optional validation data.
        :param train_dataloader:
            The data loader that generates training batches. The batches should be :class:`dict`
            objects that will be used as ``kwargs`` for the model's ``forward()`` method.
        :param train_split:
            The name of the data split used for training in the ``dataset_dict``.
            Default is "train".
        :param validation_split:
            Optional name of the validation split in the ``dataset_dict``. Default is ``None``,
            which means no validation.
        :param validation_dataloader:
            An optional data loader for generating validation batches. The batches should be
            :class:`dict` objects. If not specified, but ``validation_split`` is given,
            the validation ``DataLoader`` will be constructed from the same parameters
            as the train ``DataLoader``.
        :param source_field:
            The string name of the field containing the source sequence.
        :param target_field:
            The string name of the field containing the target sequence.
        :param max_source_length:
            The maximum number of tokens in the source sequence.
        :param max_target_length:
            The maximum number of tokens in the target sequence.
        :param seed:
            Used to set the RNG states at the beginning of training.
        :param train_steps:
            The number of steps to train for. If not specified training will
            stop after a complete iteration through the ``train_dataloader``.
        :param train_epochs:
            The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs``
            at the same time.
        :param validation_steps:
            The number of steps to validate for. If not specified validation
            will stop after a complete iteration through the ``validation_dataloader``.
        :param grad_accum:
            The number of gradient accumulation steps. Defaults to 1.

            .. note::
                This parameter - in conjuction with the settings of your data loader
                and the number distributed workers -
                determines the *effective batch size* of your training run.

        :param log_every:
            Log every this many steps.
        :param checkpoint_every:
            Save a checkpoint every this many steps.
        :param validate_every:
            Run the validation loop every this many steps.
        :param device_count:
            The number of devices to train on, i.e. the number of distributed data parallel workers.
        :param distributed_port:
            The port of the distributed process group. Default = "54761".
        :param val_metric_name:
            The name of the validation metric, i.e. the key of the metric in the dictionary
            returned by the forward pass of the model. Default is "loss".
        :param minimize_val_metric:
            Whether the validation metric is meant to be minimized (such as the loss).
            Default is ``True``. When using a metric such as accuracy, you should set
            this to ``False``.
        :param auto_aggregate_val_metric:
            If ``True`` (the default), the validation metric will be averaged across
            validation batches and distributed processes. This may not be the correct
            behavior for some metrics (such as F1), in which you should set this to
            ``False`` and handle the aggregation internally in your model
            or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`).
        :param callbacks:
            A list of :class:`TrainCallback`.
        :param remove_stale_checkpoints:
            If ``True`` (the default), stale checkpoints will be removed throughout training so that
            only the latest and best checkpoints are kept.

        :returns:
            The trained model on CPU with the weights from the best checkpoint loaded.
        """
        devices = self._get_devices(device_count)
        is_distributed = False
        if devices and len(devices) > 1:
            is_distributed = True

        # Setup the tokenizer
        _add_special_tokens(tokenizer)

        # Hacky way to deal with resizing the model embeddings.
        model_params_dict = model._params.as_dict()
        if "fairscale" in model_params_dict["type"]:
            # Fairscale-wrapped models nest the actual model params under "model".
            model_params_dict["model"]["num_tokens"] = len(tokenizer)  # type: ignore
        else:
            model_params_dict["num_tokens"] = len(tokenizer)  # type: ignore
        # Rebuild the Lazy model with the patched params.
        model = Lazy(
            model._constructor,
            Params(model_params_dict),
            constructor_extras=model._constructor_extras,
        )

        # Get the config to check in order to check if the model is seq2seq or causal.
        # NOTE(review): this assumes the tokenizer's `name_or_path` matches the
        # model checkpoint being fine-tuned — confirm for custom tokenizers.
        config = AutoConfig.from_pretrained(tokenizer.name_or_path)
        seq2seq: bool = type(config) in SEQ2SEQ

        # Decoder-only (causal) models get the concatenated source+target format.
        dataset_dict = tokenize_data(
            dataset_dict,
            tokenizer=tokenizer,
            source_field=source_field,
            target_field=target_field,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            concat_source_target=not seq2seq,
        )

        if is_distributed:
            from torch.utils.data.distributed import DistributedSampler

            # Each worker gets a disjoint shard of the training data.
            sampler = Lazy(DistributedSampler, drop_last=True, shuffle=True)
            train_dataloader = Lazy(
                train_dataloader._constructor,
                train_dataloader._params,
                constructor_extras=train_dataloader._constructor_extras,
                sampler=sampler,
            )

        collate_fn: DataCollator
        collate_fn = cast(DataCollator, DataCollatorForSeq2Seq(tokenizer=tokenizer))
        # Rebuild the train dataloader Lazy with the collator attached.
        # NOTE(review): the validation dataloader is *not* given this collate_fn;
        # confirm whether that is intended.
        train_dataloader = Lazy(
            train_dataloader._constructor,
            train_dataloader._params,
            constructor_extras=train_dataloader._constructor_extras,
            collate_fn=collate_fn,
        )

        # Delegate the actual training loop to TorchTrainStep._train.
        return self._train(
            model=model,
            training_engine=training_engine,
            dataset_dict=convert_to_tango_dataset_dict(dataset_dict),
            train_dataloader=train_dataloader,
            train_split=train_split,
            validation_split=validation_split,
            validation_dataloader=validation_dataloader,
            seed=seed,
            train_steps=train_steps,
            train_epochs=train_epochs,
            validation_steps=validation_steps,
            grad_accum=grad_accum,
            log_every=log_every,
            checkpoint_every=checkpoint_every,
            validate_every=validate_every,
            devices=devices,
            distributed_port=distributed_port,
            val_metric_name=val_metric_name,
            minimize_val_metric=minimize_val_metric,
            auto_aggregate_val_metric=auto_aggregate_val_metric,
            callbacks=callbacks,
            remove_stale_checkpoints=remove_stale_checkpoints,
        )
================================================
FILE: tango/integrations/transformers/ia3.py
================================================
import re
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_utils import Conv1D
@dataclass
class WithIA3Config:
    """
    A class for configuring which layers to modify with IA3 adaptors.

    :param ia3_param_names:
        A string used as the name for all ia3 parameters
    :param attention_modules:
        A regex that matches all attention modules which are parents to the keys and value layers to modify.
    :param mlp_modules:
        A regex that matches all modules that are parents to the feed forward layer to modify.
    :param mlp_layers:
        A regex that matches the feed forward layer in the modules specified by `mlp_modules`.
    :param fused_qkv_layers:
        A regex that matches the combined query, key, and value layer in the modules specified
        by `attention_modules`.
    :param k_layers:
        A regex that matches the key layer in the modules specified by `attention_modules`.
    :param v_layers:
        A regex that matches the value layer in the modules specified by `attention_modules`.
    """

    ia3_param_names: str
    attention_modules: str
    mlp_modules: str
    mlp_layers: str
    # Exactly one style applies per architecture: either a fused q/k/v projection,
    # or separate k and v projections.
    fused_qkv_layers: Optional[str] = None
    k_layers: Optional[str] = None
    v_layers: Optional[str] = None
# Pre-made IA3 configurations for common architectures.

# GPT-J: separate k/v projections inside `*.attn`; feed-forward entry is `fc_in` under `*.mlp`.
GPT_J_IA3_CONFIG = WithIA3Config(
    attention_modules=".*attn",
    k_layers="k_proj",
    v_layers="v_proj",
    mlp_modules=".*mlp",
    mlp_layers="fc_in",
    ia3_param_names="ia3",
)

# GPT-2: q, k, and v are fused into a single `c_attn` projection.
GPT_2_IA3_CONFIG = WithIA3Config(
    attention_modules=".*attn",
    fused_qkv_layers="c_attn",
    mlp_modules=".*mlp",
    mlp_layers="c_fc",
    ia3_param_names="ia3",
)

# OPT: separate k/v projections; the feed-forward `fc1` sits directly under each decoder layer.
OPT_IA3_CONFIG = WithIA3Config(
    attention_modules=".*self_attn",
    k_layers="k_proj",
    v_layers="v_proj",
    mlp_modules=r".*layers\.\d*",
    mlp_layers="fc1",
    ia3_param_names="ia3",
)

# BLOOM: fused `query_key_value` projection.
BLOOM_IA3_CONFIG = WithIA3Config(
    attention_modules=".*self_attention",
    fused_qkv_layers="query_key_value",
    mlp_modules=".*mlp",
    mlp_layers="dense_h_to_4h",
    ia3_param_names="ia3",
)

# Maps Hugging Face model names to the IA3 configuration for their architecture,
# used by `modify_with_ia3` when no explicit config is given.
MODEL_NAME_TO_CONFIG = {
    "sshleifer/tiny-gpt2": GPT_2_IA3_CONFIG,
    "gpt2": GPT_2_IA3_CONFIG,
    "gpt2-medium": GPT_2_IA3_CONFIG,
    "gpt2-large": GPT_2_IA3_CONFIG,
    "gpt2-xl": GPT_2_IA3_CONFIG,
    "bigscience/bloom-560m": BLOOM_IA3_CONFIG,
    "bigscience/bloom-1b1": BLOOM_IA3_CONFIG,
    "bigscience/bloom-1b7": BLOOM_IA3_CONFIG,
    "bigscience/bloom-3b": BLOOM_IA3_CONFIG,
    "bigscience/bloom-7b1": BLOOM_IA3_CONFIG,
    "bigscience/bloom": BLOOM_IA3_CONFIG,
    "facebook/opt-125m": OPT_IA3_CONFIG,
    "facebook/opt-350m": OPT_IA3_CONFIG,
    "facebook/opt-1.3b": OPT_IA3_CONFIG,
    "facebook/opt-2.7b": OPT_IA3_CONFIG,
    "facebook/opt-6.7b": OPT_IA3_CONFIG,
    "facebook/opt-13b": OPT_IA3_CONFIG,
    "facebook/opt-30b": OPT_IA3_CONFIG,
    "facebook/opt-66b": OPT_IA3_CONFIG,
    "EleutherAI/gpt-j-6B": GPT_J_IA3_CONFIG,
}
class WithIA3(nn.Module):
    """
    Base module that owns a learned IA3 scaling vector (initialized to ones)
    and applies it to activations via :meth:`scale_by_ia3`.
    """

    def __init__(self, ia3_param_names: str, unfuse_size: Optional[int] = None):
        super().__init__()
        self.ia3_param_names = ia3_param_names
        # When (q, k, v) are fused into a single projection, IA3 scales only the
        # k and v portions (not q), hence `unfuse_size * 2` scales.
        # NOTE(review): the non-fused branch reads `self.out_features`, which
        # subclasses must assign *before* calling this __init__.
        n_scales = unfuse_size * 2 if unfuse_size is not None else self.out_features  # type: ignore
        setattr(self, ia3_param_names, nn.Parameter(torch.ones(n_scales, 1)))

    def scale_by_ia3(self, x):
        """Scale ``x`` by the IA3 parameters (only while they are trainable)."""
        scales = getattr(self, self.ia3_param_names)
        if not scales.requires_grad:
            # Scaling is skipped when the IA3 parameters are frozen.
            return x
        if self.unfuse_size is None:
            return x * scales.flatten()
        # Fused q/k/v: leave the q slice untouched and scale only k and v.
        q = x[:, :, : self.unfuse_size]  # type: ignore
        non_q = x[:, :, self.unfuse_size :] * scales.flatten()  # type: ignore
        return torch.cat([q, non_q], dim=2)
class LinearWithIA3(WithIA3):
    def __init__(
        self, linear_layer: nn.Linear, ia3_param_names: str, unfuse_size: Optional[int] = None
    ):
        """
        A replacement for :class:`~torch.nn.Linear` modified with an IA3 adaptor.

        :param linear_layer:
            A :class:`~torch.nn.Linear` layer to adapt.
        :param ia3_param_names:
            A `str` to use as the name of ia3 parameters.
        :param unfuse_size:
            An `int` indicating hidden dimension of the query, key, and value vectors.
            To be used only when the layer to modify is a fused projection of query,
            key, and value vectors in an attention mechanism.
        """
        # A fused q/k/v projection must have exactly 3 * unfuse_size outputs.
        assert unfuse_size is None or (linear_layer.out_features == unfuse_size * 3)
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.unfuse_size = unfuse_size
        # Must run after the assignments above: the parent reads `out_features`.
        super().__init__(ia3_param_names, unfuse_size)
        # Share (not copy) the wrapped layer's parameters.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias

    def forward(self, x):
        projected = F.linear(x, self.weight, self.bias)
        return self.scale_by_ia3(projected)
class Conv1DWithIA3(WithIA3):
    def __init__(
        self, conv1d_layer: Conv1D, ia3_param_names: str, unfuse_size: Optional[int] = None
    ):
        """
        A replacement for :class:`~transformers.modeling_utils.Conv1D` modified with an IA3 adaptor.

        :param conv1d_layer:
            A :class:`~transformers.modeling_utils.Conv1D` layer to adapt.
        :param ia3_param_names:
            A `str` to use as the name of ia3 parameters.
        :param unfuse_size:
            An `int` indicating hidden dimension of the query, key, and value vectors.
            To be used only when the layer to modify is a fused projection of query,
            key, and value vectors in an attention mechanism.
        """
        # A fused q/k/v projection must have exactly 3 * unfuse_size output features.
        assert unfuse_size is None or (conv1d_layer.nf == unfuse_size * 3)
        # nf: number of output features; nx: number of input features
        self.out_features = conv1d_layer.nf
        self.unfuse_size = unfuse_size
        # The parent __init__ reads `self.out_features`, so it must run after
        # the assignments above.
        super().__init__(ia3_param_names, unfuse_size)
        # Share (not copy) the wrapped layer's parameters.
        self.weight = conv1d_layer.weight
        self.bias = conv1d_layer.bias

    def forward(self, x):
        # Copied from the original Conv1D implementation, followed by the IA3 scaling.
        size_out = x.size()[:-1] + (self.out_features,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)  # ... * self.nf
        return self.scale_by_ia3(x)
def modify_with_ia3(
    transformer: PreTrainedModel,
    *,
    config: Optional[WithIA3Config] = None,
    only_ia3_requires_grad: bool = True,
) -> PreTrainedModel:
    """
    A function to add ia3 adaptors to the given transformer. Code modified from
    `t-few <https://github.com/r-three/t-few>`_ and Qinyuan Ye.

    :param transformer:
        A :class:`~transformers.PreTrainedModel` to modify.
    :param config:
        A :class:`~tango.integrations.transformers.ia3.WithIA3Config` that specifies the layers
        to modify. If not given, a pre-made configuration is looked up from the model's name.
    :param only_ia3_requires_grad:
        A `bool`, `True` if `requires_grad` should only be set on ia3 parameters in the output model.

    Examples
    --------

    You can use the provided configurations:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG)

    Or you can write your own configuration with regex matching the layers to modify and their parents:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import modify_with_ia3, WithIA3Config

        my_config = WithIA3Config(
            attention_modules=".*attn",
            fused_qkv_layers="c_attn",
            mlp_modules=".*mlp",
            mlp_layers="c_fc",
            ia3_param_names="ia3",
        )

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=my_config)
    """
    if config is None:
        # Fall back to a pre-made configuration keyed by the model's name.
        model_name = transformer.config._name_or_path  # type: ignore
        assert (
            model_name in MODEL_NAME_TO_CONFIG
        ), f"{model_name} does not have a pre made configuration; please make your own."
        config = MODEL_NAME_TO_CONFIG[model_name]
    for m_name, module in dict(transformer.named_modules()).items():  # type: ignore
        # Only look inside attention and MLP parent modules.
        if re.fullmatch(config.attention_modules, m_name) or re.fullmatch(
            config.mlp_modules, m_name
        ):
            attn_layers = [
                regex
                for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers)
                if regex is not None
            ]
            # Attention parents target the qkv/k/v layers; MLP parents target
            # the feed-forward layer.
            layers_to_change = (
                "|".join(attn_layers)
                if re.fullmatch(config.attention_modules, m_name)
                else config.mlp_layers
            )
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(layers_to_change, c_name):
                    assert isinstance(layer, Conv1D) or isinstance(
                        layer, nn.Linear
                    ), "This code only supports Conv1D and nn.Linear"
                    adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3
                    # `unfuse_size` is only set when replacing a fused q/k/v projection.
                    new_module = adaptor_class(
                        layer,
                        config.ia3_param_names,
                        unfuse_size=transformer.config.hidden_size  # type: ignore
                        if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name)
                        else None,
                    )
                    setattr(module, c_name, new_module)
    if only_ia3_requires_grad:
        # Freeze everything, then re-enable gradients on the IA3 parameters only.
        transformer.requires_grad_(False)  # type: ignore
        for p_name, v in dict(transformer.named_parameters()).items():  # type: ignore
            if re.fullmatch(".*" + config.ia3_param_names + ".*", p_name):
                v.requires_grad_(True)
    return transformer
================================================
FILE: tango/integrations/transformers/model.py
================================================
from typing import Optional, Type
from transformers.models.auto import modeling_auto
from tango.common.exceptions import IntegrationMissingError
from tango.integrations.torch.model import Model
from .config import Config
def auto_model_wrapper_factory(cls: type) -> Type[Model]:
    """
    Create a subclass of ``cls`` that is also a tango :class:`Model`, exposing
    ``from_pretrained`` and ``from_config`` classmethods that defer to ``cls``.
    """

    class AutoModelWrapper(cls, Model):  # type: ignore
        @classmethod
        def from_pretrained(
            cls, pretrained_model_name_or_path: str, config: Optional[Config] = None, **kwargs
        ) -> Model:
            return super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)

        @classmethod
        def from_config(cls, config: Config, **kwargs) -> Model:
            return super().from_config(config, **kwargs)

    return AutoModelWrapper


# Register every transformers `AutoModel*` class under both a "::from_pretrained"
# and a "::from_config" name.
for name, cls in modeling_auto.__dict__.items():
    if isinstance(cls, type) and name.startswith("AutoModel"):
        wrapped_cls = auto_model_wrapper_factory(cls)
        Model.register(
            "transformers::" + name + "::from_pretrained", constructor="from_pretrained"
        )(wrapped_cls)
        Model.register("transformers::" + name + "::from_config", constructor="from_config")(
            wrapped_cls
        )
# Mirror the torch registrations for flax models, but only if the flax
# integration (and its dependencies) are available.
try:
    from transformers.models.auto import modeling_flax_auto

    from tango.integrations.flax.model import Model as FlaxModel

    def flax_auto_model_wrapper_factory(cls: type) -> Type[FlaxModel]:
        """Same idea as `auto_model_wrapper_factory`, but for flax models."""

        class AutoModelWrapper(cls, FlaxModel):  # type: ignore
            @classmethod
            def from_pretrained(
                cls, pretrained_model_name_or_path: str, config: Optional[Config] = None, **kwargs
            ) -> FlaxModel:
                return super().from_pretrained(
                    pretrained_model_name_or_path, config=config, **kwargs
                )

            @classmethod
            def from_config(cls, config: Config, **kwargs) -> FlaxModel:
                return super().from_config(config, **kwargs)

        return AutoModelWrapper

    # Register every transformers `FlaxAutoModel*` class.
    for name, cls in modeling_flax_auto.__dict__.items():
        if isinstance(cls, type) and name.startswith("FlaxAutoModel"):
            wrapped_cls_ = flax_auto_model_wrapper_factory(cls)
            FlaxModel.register(
                "transformers::" + name + "::from_pretrained", constructor="from_pretrained"
            )(wrapped_cls_)
            FlaxModel.register(
                "transformers::" + name + "::from_config", constructor="from_config"
            )(wrapped_cls_)
except ModuleNotFoundError:
    # flax (or a transitive dependency) isn't installed; skip the flax registrations.
    pass
except IntegrationMissingError:
    pass
================================================
FILE: tango/integrations/transformers/optim.py
================================================
import torch
from transformers import optimization as transformers_optim
from tango.integrations.torch.optim import LRScheduler, Optimizer
# Register all transformers optimizers.
for name, cls in transformers_optim.__dict__.items():
    if (
        isinstance(cls, type)
        and issubclass(cls, torch.optim.Optimizer)
        # Exclude the abstract base class itself; identity comparison is the
        # idiomatic way to compare classes (was `not cls == ...`).
        and cls is not torch.optim.Optimizer
    ):
        Optimizer.register("transformers::" + name)(cls)

# Register all transformers scheduler factory functions under their
# `SchedulerType` enum value (e.g. "linear", "cosine").
for scheduler_type, scheduler_func in transformers_optim.TYPE_TO_SCHEDULER_FUNCTION.items():
    name = scheduler_type.value
    LRScheduler.register("transformers::" + name)(scheduler_func)  # type: ignore
================================================
FILE: tango/integrations/transformers/run_generation.py
================================================
import logging
import typing
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union, cast
import more_itertools
import torch
from datasets import Dataset
from datasets import DatasetDict as HfDatasetDict
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
CTRLLMHeadModel,
CTRLTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
OpenAIGPTLMHeadModel,
OpenAIGPTTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TransfoXLLMHeadModel,
TransfoXLTokenizer,
XLMTokenizer,
XLMWithLMHeadModel,
XLNetLMHeadModel,
XLNetTokenizer,
)
from tango import Format, JsonFormat, SqliteDictFormat, Step
from tango.common import DatasetDict
from tango.common.sequences import MappedSequence, SqliteSparseSequence
from tango.common.tqdm import Tqdm
from tango.integrations.torch import Model
from tango.integrations.torch.util import resolve_device, set_seed_all
logger = logging.getLogger(__name__)
#
# A lot of the code in this step is stolen from the run_generation.py script in transformers. Unfortunately their
# examples don't ship when you `pip install transformers`, so we have to duplicate it here.
#
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

# Model families that need dedicated (model class, tokenizer class) pairs.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. """

# Keys of the auto-model mappings: the config classes supported by the seq2seq
# and causal auto model classes, respectively.
SEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys()  # type: ignore
CAUSAL = AutoModelForCausalLM._model_mapping.keys()  # type: ignore
def adjust_length_to_model(length, model):
    """
    Clamp the requested generation ``length`` to what ``model`` supports.

    A negative ``length`` means "as long as the model allows"; lengths beyond the
    model's maximum position embeddings are capped at that maximum.
    """
    model_max = (
        model.config.max_position_embeddings
        if hasattr(model.config, "max_position_embeddings")
        else MAX_LENGTH
    )
    if length < 0 and model_max > 0:
        return model_max
    if 0 < model_max < length:
        # No generation bigger than model size.
        return model_max
    if length < 0:
        # Avoid an infinite generation loop.
        return MAX_LENGTH
    return length
@typing.no_type_check  # mypy has somehow lost the ability to tell what PreTrainedTokenizer and Model are.
def _generate(
    model: Model,
    # TODO: Change type to `Tokenizer` once HF includes `convert_tokens_to_ids` in `PretrainedTokenizerBase` class.
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompts: Iterable[str],
    *,
    batch_size: int = 4,
    max_length: int = 20,
    temperature: float = 1.0,
    repetition_penalty: float = 1.0,
    k: int = 0,
    p: float = 0.9,
    prefix: str = "",
    xlm_language: str = "",
    seed: int = 42,
    num_return_sequences: int = 1,
    fp16: bool = False,
) -> Iterable[List[str]]:
    """
    Lazily generate text for ``prompts``, yielding one list of
    ``num_return_sequences`` strings per prompt.

    This is a generator: prompts are consumed in chunks of ``batch_size``, so
    results for early prompts are available before later prompts are processed.
    """
    # Only architectures covered by the seq2seq/causal auto classes are supported.
    if not isinstance(model.config, tuple(SEQ2SEQ + CAUSAL)):
        raise NotImplementedError(
            "This function is only defined for huggingface models seq2seq/causal models."
        )
    device = resolve_device()
    set_seed_all(seed)

    tokenizer_kwargs: Dict[str, Any] = {}
    # Left-padding so generation continues directly from the real prompt tokens.
    tokenizer.padding_side = "left"
    # Ensure the tokenizer has pad/eos tokens, adding new special tokens only
    # as a last resort.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})
    eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    # Seq2Seq models don't return their own prefix.
    seq2seq_model = model.config_class in SEQ2SEQ

    # HF does not do this? WTF?
    model.eval()
    model.to(device)
    if fp16:
        model.half()

    def prepare_batch_without_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:
        # Tokenize a batch of prompts and move the tensors to the target device.
        result = tokenizer.batch_encode_plus(
            prompts,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            **tokenizer_kwargs,
        )
        result = {key: tensor.to(device) for key, tensor in result.items()}
        return result

    def prepare_batch_with_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:
        # Same as above, but prepends ``prefix`` to every prompt first.
        if len(prefix) > 0:
            prompts = [f"{prefix} {t}" for t in prompts]
        return prepare_batch_without_prefix(prompts)

    prepare_batch_fn = prepare_batch_with_prefix
    num_prefix_tokens: Optional[int] = None

    # transformer model-specific exceptions
    if isinstance(model, PreTrainedModel) and model.config_class:
        if model.config_class.model_type == "xlm":
            use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
            if hasattr(model.config, "lang2id") and use_lang_emb:
                model.config.lang_id = xlm_language
            # Original HF code ignores the prefix, but it looks like a bug?
            prepare_batch_fn = prepare_batch_without_prefix
            num_prefix_tokens = 0
        elif model.config_class.model_type in {"xlnet", "transfo-xl"}:
            # These models need a long dummy prefix to behave on short prompts.
            prefix = prefix if prefix else PREFIX
            if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
                # This actually doesn't work in the current version of transformers, which is probably a bug in the
                # transformers library.
                tokenizer_kwargs = {"add_space_before_punct_symbol": True}

    # The prefix is not part of the requested output; remember how many tokens
    # to strip from the front of every causal generation.
    if num_prefix_tokens is None:
        num_prefix_tokens = len(tokenizer.tokenize(prefix))

    batches = more_itertools.chunked(Tqdm.tqdm(prompts, desc="Pre-processing prompts"), batch_size)
    encoded_batches = map(prepare_batch_fn, batches)
    for encoded_batch in Tqdm.tqdm(encoded_batches, desc="Processing batches"):
        if seq2seq_model:
            length = max_length
        else:
            # For causal models the prompt counts towards ``max_length``, so
            # extend the budget by the prompt length (clamped to the model).
            length = adjust_length_to_model(max_length + encoded_batch["input_ids"].size(1), model)
        with torch.inference_mode():
            generated_sequences: torch.Tensor = model.generate(  # type: ignore
                **encoded_batch,
                max_length=length,
                temperature=temperature,
                top_k=k,
                top_p=p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                synced_gpus=False,  # Needs to be True if we have more than one GPU running.
            )
        # Reshape to (prompts, num_return_sequences, seq_len) and move off the device.
        generated_sequences = generated_sequences.view(
            -1, num_return_sequences, *generated_sequences.shape[1:]
        ).to("cpu")

        def strip_special_tokens(t: torch.Tensor) -> torch.Tensor:
            # amazing that torch has no capability for this
            # Trim leading/trailing pad/eos (and 0) token ids from a 1-D tensor.
            start = 0
            while start < len(t) and int(t[start]) in {0, eos_token_id, pad_token_id}:
                start += 1
            end = len(t)
            while int(t[end - 1]) in {0, eos_token_id, pad_token_id} and end > start:
                end -= 1
            return t[start:end]

        # strip padding
        generated_sequences_list = [
            [strip_special_tokens(sequence) for sequence in per_prompt_sequences]
            for per_prompt_sequences in generated_sequences
        ]

        # strip prefix
        if not seq2seq_model:
            generated_sequences_list = [
                [sequence[num_prefix_tokens:] for sequence in per_prompt_sequences]
                for per_prompt_sequences in generated_sequences_list
            ]

        texts = [
            tokenizer.batch_decode(per_prompt_sequences, clean_up_tokenization_spaces=True)
            for per_prompt_sequences in generated_sequences_list
        ]
        yield from texts
def _generate_with_model_name(model_name: str, *args, **kwargs) -> Iterable[List[str]]:
    """
    Convenience wrapper around :func:`_generate` that loads a model and its
    tokenizer from a pretrained checkpoint name.

    The checkpoint is first tried as a seq2seq model; if that fails it is
    loaded as a causal LM instead.
    """
    try:
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    except ValueError:
        # Not a seq2seq architecture; assume causal LM.
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return _generate(loaded_model, loaded_tokenizer, *args, **kwargs)
@Step.register("transformers::run_generation")
class RunGeneration(Step[Iterable[List[str]]]):
    """
    A step that runs seq2seq Huggingface models in inference mode.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "transformers::run_generation".
    """

    FORMAT: Format = JsonFormat("gz")
    VERSION = "001"
    SKIP_ID_ARGUMENTS = {"batch_size"}

    # TODO: multiple GPUs

    def run(  # type: ignore
        self,
        model: Union[str, Model],
        prompts: Iterable[str],
        *,
        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
        batch_size: int = 4,
        max_length: int = 20,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        k: int = 0,
        p: float = 0.9,
        prefix: str = "",
        xlm_language: str = "",
        seed: int = 42,
        num_return_sequences: int = 1,
        fp16: bool = False,
    ) -> Iterable[List[str]]:
        """
        Run a Huggingface seq2seq model in inference mode.

        :param model:
            The name of the model to run. Any name that works in the transformers library works here.
            Or, you can directly provide the model to run.
        :param prompts:
            The prompts to run through the model. You can specify prompts directly in the config, but
            more commonly the prompts are produced by another step that reads a dataset, for example.
        :param tokenizer:
            The tokenizer to run.
        :param batch_size:
            The number of sequences to process at one time. This has no bearing on the output, so
            you can change this number without invalidating cached results.
        :param max_length:
            The maximum number of tokens/word pieces that the model will generate. For models that
            extend the prompt, the prefix does not count towards this limit.
        :param temperature:
            Passed directly to transformer's ``generate()`` method.
            The value used to model the next token probabilities.
        :param repetition_penalty:
            Passed directly to transformer's ``generate()`` method.
            The parameter for repetition penalty. 1.0 means no penalty.
        :param k:
            Passed directly to transformer's ``generate()`` method.
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        :param p:
            Passed directly to transformer's ``generate()`` method.
            If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher
            are kept for generation.
        :param prefix:
            A prefix that gets pre-pended to all prompts.
        :param xlm_language:
            For the XLM model, this is a way to specify the language you want to use.
        :param seed:
            Random seed
        :param num_return_sequences:
            The number of generations to return for each prompt.
        :param fp16:
            Whether to use 16-bit floats.
        :returns:
            Returns an iterator of lists of string. Each list contains the predictions for one prompt.
        """
        # Resolve a checkpoint name into an actual model: try the seq2seq auto
        # class first, then fall back to causal LM.
        if isinstance(model, str):
            try:
                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))
            except ValueError:
                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))
        # Default the tokenizer to the one matching the model's checkpoint.
        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path)

        generate_kwargs = dict(
            batch_size=batch_size,
            max_length=max_length,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            k=k,
            p=p,
            prefix=prefix,
            xlm_language=xlm_language,
            seed=seed,
            num_return_sequences=num_return_sequences,
            fp16=fp16,
        )
        return _generate(model, tokenizer, prompts, **generate_kwargs)
@Step.register("transformers::run_generation_dataset")
class RunGenerationDataset(Step[DatasetDict]):
    """
    A step that runs seq2seq Huggingface models in inference mode.

    This is similar to :class:`RunGeneration`, but it takes a dataset as input and produces
    a new dataset as output, which contains the predictions in a new field.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "transformers::run_generation_dataset".
    """

    FORMAT: Format = SqliteDictFormat()
    VERSION = "002"
    SKIP_ID_ARGUMENTS = {"batch_size"}

    def run(  # type: ignore
        self,
        model: Union[str, Model],
        input: Union[DatasetDict, HfDatasetDict],
        prompt_field: str,
        *,
        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
        output_field: Optional[str] = None,
        splits: Optional[Union[str, Set[str]]] = None,
        batch_size: int = 4,
        max_length: int = 20,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        k: int = 0,
        p: float = 0.9,
        prefix: str = "",
        xlm_language: str = "",
        seed: int = 42,
        num_return_sequences: int = 1,
        fp16: bool = False,
    ) -> DatasetDict:
        """
        Augment an input dataset with generations from a Huggingface seq2seq model.

        :param model:
            The name of the model to run. Any name that works in the transformers library works here.
            Or, you can directly provide the model to run.
        :param input:
            The input dataset.
        :param prompt_field:
            The field in the dataset that contains the text of the prompts.
        :param tokenizer:
            The tokenizer to run.
        :param output_field:
            The field in the dataset that we will write the predictions into. In the result, this field
            will contain ``List[str]``.
        :param splits:
            A split, or set of splits, to process. If this is not specified, we will process all splits.
        :param batch_size:
            The number of sequences to process at one time. This has no bearing on the output, so
            you can change this number without invalidating cached results.
        :param max_length:
            The maximum number of tokens/word pieces that the model will generate. For models that
            extend the prompt, the prefix does not count towards this limit.
        :param temperature:
            Passed directly to transformer's `generate()` method.
            The value used to model the next token probabilities.
        :param repetition_penalty:
            Passed directly to transformer's `generate()` method.
            The parameter for repetition penalty. 1.0 means no penalty.
        :param k:
            Passed directly to transformer's `generate()` method.
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        :param p:
            Passed directly to transformer's `generate()` method.
            If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher
            are kept for generation.
        :param prefix:
            A prefix that gets pre-pended to all prompts.
        :param xlm_language:
            For the XLM model, this is a way to specify the language you want to use.
        :param seed:
            Random seed
        :param num_return_sequences:
            The number of generations to return for each prompt.
        :param fp16:
            Whether to use 16-bit floats.
        :returns:
            Returns a dataset with an extra field containing the predictions.
        """
        # Resolve a checkpoint name into an actual model, trying seq2seq first.
        if isinstance(model, str):
            try:
                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))
            except ValueError:
                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))
        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path)
        # Normalize the input and the requested splits.
        if isinstance(input, HfDatasetDict):
            input = DatasetDict(input, {})
        if splits is None:
            splits = input.keys()
        elif isinstance(splits, str):
            splits = {splits}
        result: Dict[str, Sequence] = {}
        for split_name, input_split in input.items():
            if split_name in splits:
                # Predictions are appended to a sqlite-backed sequence in the
                # step's work dir, so an interrupted run can resume where it
                # left off instead of regenerating everything.
                output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite")
                if len(output_split) > 0:
                    logger.info(
                        "Found %d items already generated. Will generate %d more.",
                        len(output_split),
                        len(input_split) - len(output_split),
                    )
                if len(output_split) > 0:
                    # Skip the instances we already have output for.
                    if isinstance(input_split, Dataset):
                        input_split = input_split.select(range(len(output_split), len(input_split)))
                    else:
                        input_split = input_split[len(output_split) :]
                prompts = MappedSequence(lambda i: i[prompt_field], input_split)
                generations = _generate(
                    model,
                    tokenizer,
                    prompts,
                    batch_size=batch_size,
                    max_length=max_length,
                    temperature=temperature,
                    repetition_penalty=repetition_penalty,
                    k=k,
                    p=p,
                    prefix=prefix,
                    xlm_language=xlm_language,
                    seed=seed,
                    num_return_sequences=num_return_sequences,
                    fp16=fp16,
                )
                # ``_generate`` is lazy, so results are persisted as they come in.
                for instance, generation in zip(input_split, generations):
                    output_split.append(
                        {**instance, **{output_field or prompt_field + "_generated": generation}}
                    )
                result[split_name] = output_split
            else:
                # Splits we were not asked to process pass through unchanged.
                result[split_name] = input_split
        return DatasetDict(result, input.metadata)
================================================
FILE: tango/integrations/transformers/soft_prompt.py
================================================
import inspect
import logging
import random
from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
CausalLMOutputWithCrossAttentions,
Seq2SeqModelOutput,
)
from tango.integrations.torch import Model
logger = logging.getLogger(__name__)
def _get_bound_args_with_decorators(fn, *args, **kwargs):
while True:
try:
fn = fn.__wrapped__
except AttributeError:
break
signature = inspect.Signature.from_callable(fn)
return signature.bind(*args, **kwargs)
def add_soft_prompt(
    model: Model,
    prompt_length: int,
    *,
    only_prompt_is_trainable: bool = True,
    initialize_from_top_embeddings: Optional[int] = 5000,
    random_seed: int = 1940,
) -> None:
    """
    Takes a regular huggingface transformer, and equips it with a soft prompt.

    Example:

    .. testcode::

        import transformers

        model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
        tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
        generated = model.generate(tokenizer.encode("It was the best of times.", return_tensors="pt"))
        original_output = tokenizer.decode(generated[0])

        add_soft_prompt(model, prompt_length=3)
        generated = model.generate(tokenizer.encode("It was the best of times.", return_tensors="pt"))
        prompted_output = tokenizer.decode(generated[0])

    :param model: the original huggingface transformer. This model is augmented in-place!
    :param prompt_length: the length of the soft prompt, in tokens
    :param only_prompt_is_trainable: freezes the original model's weights, leaving only the prompt trainable
    :param initialize_from_top_embeddings: Prompt embeddings are initialized from a random selection of the top n
        word piece embeddings from the original model. This is how you set n.
    :param random_seed: random seed used to initialize the prompt embeddings
    """
    assert isinstance(model, PreTrainedModel)
    original_embedding: nn.Embedding = model.get_input_embeddings()  # type: ignore
    # The prompt lives as a (1, prompt_length, embedding_dim) parameter that is
    # prepended (after expansion to the batch size) to the input embeddings.
    prompt_embedding = nn.Parameter(
        torch.empty(
            1,
            prompt_length,
            original_embedding.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )
    )
    # Initialize the prompt from a deterministic random sample of existing
    # word-piece embeddings.
    r = random.Random(random_seed)
    if initialize_from_top_embeddings is None:
        initialize_from_top_embeddings = original_embedding.num_embeddings
    indices = torch.tensor(r.sample(range(initialize_from_top_embeddings), prompt_length))
    with torch.no_grad():
        prompt_embedding.copy_(original_embedding(indices).unsqueeze(0))

    if only_prompt_is_trainable:
        # Freeze everything; the prompt parameter registered below keeps
        # requires_grad=True.
        for param in model.parameters():
            param.requires_grad = False

    # find unique parameter name
    parameter_name = "prompt_embedding"
    parameter_name_index = 0
    while True:
        try:
            model.get_parameter(parameter_name)
        except AttributeError:
            break
        parameter_name_index += 1
        parameter_name = f"prompt_embedding_{parameter_name_index}"
    model.register_parameter(parameter_name, prompt_embedding)

    def patch_tensor(kwargs: Dict[str, torch.Tensor], key: str, value: Any = 0) -> None:
        # Prepend ``prompt_length`` columns filled with ``value`` to kwargs[key],
        # if that tensor is present.
        t = kwargs.get(key)
        if t is None:
            return
        prefix = t.new_full((t.size(0), prompt_length) + t.shape[2:], value)
        kwargs[key] = torch.cat([prefix, t], dim=1)

    def patch_tensor_with_indices(
        kwargs: Dict[str, torch.Tensor], key: str, offset: int = 0
    ) -> None:
        # Prepend positions 0..prompt_length-1 and shift the original index
        # tensor by ``offset``, if that tensor is present.
        t = kwargs.get(key)
        if t is None:
            return
        kwargs[key] = torch.cat(
            [
                torch.arange(0, prompt_length, dtype=t.dtype)
                .unsqueeze(0)
                .expand(t.size(0), prompt_length),
                t + offset,
            ],
            dim=1,
        )

    old_forward = model.forward

    def new_forward(*args, **kwargs):
        # Massage the input to include the prompt
        if kwargs.get("past_key_values") is not None:
            # If we have already been running this model, we don't need to do anything with the prefix now.
            return old_forward(*args, **kwargs)
        if kwargs.get("encoder_outputs") is not None:
            # For encoder/decoder models, this runs only on the encoder. If we already have encoder outputs,
            # we don't have to do anything.
            return old_forward(*args, **kwargs)

        # Convert input_ids to embeddings so the prompt embeddings can be
        # concatenated in front of them.
        inputs_embeds: Optional[torch.Tensor] = None
        input_ids = kwargs.pop("input_ids", None)
        if input_ids is not None:
            inputs_embeds = original_embedding(input_ids)
        inputs_embeds = kwargs.get("inputs_embeds", inputs_embeds)
        if inputs_embeds is not None:
            kwargs["inputs_embeds"] = torch.cat(
                [prompt_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1
            )
        # Keep the auxiliary tensors aligned with the lengthened input.
        patch_tensor(kwargs, "labels")
        patch_tensor(kwargs, "attention_mask", 1)
        patch_tensor(kwargs, "token_type_ids")
        patch_tensor_with_indices(kwargs, "position_ids", prompt_length)

        # Run the model
        result = old_forward(*args, **kwargs)

        # Massage the output to look like the prompt was never there
        unpatch_tensor = lambda t: t[:, prompt_length:]  # noqa: E731
        unpatch_attention_tensor = lambda t: t[:, :, prompt_length:]  # noqa: E731
        unpatch_kv_tensor = unpatch_attention_tensor
        if isinstance(result, CausalLMOutputWithCrossAttentions):
            if result.logits is not None:
                result.logits = unpatch_tensor(result.logits)
            if result.hidden_states is not None:
                result.hidden_states = tuple(map(unpatch_tensor, result.hidden_states))
            if result.attentions is not None:
                result.attentions = tuple(map(unpatch_attention_tensor, result.attentions))
            if result.cross_attentions is not None:
                result.cross_attentions = tuple(
                    map(unpatch_attention_tensor, result.cross_attentions)
                )
            return result
        elif isinstance(result, Seq2SeqModelOutput):
            if result.last_hidden_state is not None:
                result.last_hidden_state = unpatch_tensor(result.last_hidden_state)
            if result.past_key_values is not None:
                result.past_key_values = tuple(map(unpatch_kv_tensor, result.past_key_values))
            if result.encoder_hidden_states is not None:
                result.hidden_states = tuple(map(unpatch_tensor, result.hidden_states))
            if result.encoder_attentions is not None:
                result.attentions = tuple(map(unpatch_attention_tensor, result.attentions))
            if result.cross_attentions is not None:
                result.cross_attentions = tuple(
                    map(unpatch_attention_tensor, result.cross_attentions)
                )
            return result
        else:
            # Unknown output type: pass it through unmodified, but warn, since
            # downstream code will see the prompt positions.
            logger.warning(
                "Unexpected result type from the transformer in soft_prompt_transformer: `%s`",
                result.__class__,
            )
            return result

    model.forward = new_forward  # type: ignore

    # For encoder/decoder models, HF doesn't call `forward()` like it should when you use `generate()`. Instead, it
    # calls the encoder separately, and then passes the results into `forward()`. So in that case, we have to patch
    # this too.
    if model.config.is_encoder_decoder:
        old_generate = model.generate

        def new_generate(*args, **kwargs):
            # ``old_generate`` is a bound method; prepend the model so the
            # unwrapped signature (which includes ``self``) binds correctly.
            args = (model,) + args
            ba = _get_bound_args_with_decorators(old_generate, *args, **kwargs)
            del ba.arguments["self"]

            if "encoder_outputs" in ba.arguments:
                # For encoder/decoder models, this runs only on the encoder. If we already have encoder outputs,
                # we don't have to do anything.
                return old_generate(*ba.args, **ba.kwargs)

            # As in new_forward: turn ids into embeddings and prepend the prompt.
            inputs_embeds: Optional[torch.Tensor] = None
            inputs = ba.arguments.pop("inputs", None)
            if inputs is not None:
                inputs_embeds = original_embedding(inputs)
            inputs_embeds = ba.arguments.pop("inputs_embeds", inputs_embeds)
            if inputs_embeds is not None:
                inputs_embeds = torch.cat(
                    [prompt_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1
                )

            # Run the encoder ourselves on the prompted embeddings, then let the
            # original generate() proceed with the precomputed encoder outputs.
            assert callable(model.get_encoder)
            encoder = model.get_encoder()
            kwargs = ba.kwargs
            kwargs["encoder_outputs"] = encoder(inputs_embeds=inputs_embeds, return_dict=True)
            return old_generate(*ba.args, **kwargs)

        model.generate = new_generate  # type: ignore
def _with_soft_prompt(
    model: Model,
    prompt_length: int,
    *,
    only_prompt_is_trainable: bool = True,
    initialize_from_top_embeddings: Optional[int] = 5000,
    random_seed: int = 1940,
) -> Model:
    """
    Variant of :func:`add_soft_prompt` that returns the (in-place modified)
    model, so it can be used as a Registrable factory from a config file.
    """
    prompt_options = {
        "only_prompt_is_trainable": only_prompt_is_trainable,
        "initialize_from_top_embeddings": initialize_from_top_embeddings,
        "random_seed": random_seed,
    }
    add_soft_prompt(model, prompt_length, **prompt_options)
    return model


Model.register("transformers::with_soft_prompt")(_with_soft_prompt)  # type: ignore
================================================
FILE: tango/integrations/transformers/tokenizer.py
================================================
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from tango.common import Registrable
class Tokenizer(PreTrainedTokenizerBase, Registrable):
    """
    A :class:`~tango.common.Registrable` version of transformers'
    :class:`~transformers.PreTrainedTokenizerBase`.
    """

    default_implementation = "auto"
    """
    The default registered implementation just calls
    :meth:`transformers.AutoTokenizer.from_pretrained()`.
    """


# Register ``AutoTokenizer`` (constructed via ``from_pretrained``) under the
# name "auto", which is also the default implementation declared above.
Tokenizer.register("auto", constructor="from_pretrained")(AutoTokenizer)
================================================
FILE: tango/integrations/wandb/__init__.py
================================================
"""
.. important::
To use this integration you should install ``tango`` with the "wandb" extra
(e.g. ``pip install tango[wandb]``) or just install the ``wandb`` library after the fact
(e.g. ``pip install wandb``).
Components for Tango integration with `Weights & Biases <https://wandb.ai/>`_.
Overview
--------
The main components provided by this integration are the :class:`WandbWorkspace` and
the :class:`WandbTrainCallback`.
The :class:`WandbWorkspace` is a :class:`~tango.workspace.Workspace` implementation that is
great for collaboration. It tracks Tango runs and steps in the W&B project of your choosing
and uses W&B Artifacts to cache step results in the cloud so that they're accessible anywhere.
And if you're training PyTorch models via the :class:`~tango.integrations.torch.TorchTrainStep`,
you can use the :class:`WandbTrainCallback` to track metrics throughout the run.
"""
from tango.common.exceptions import IntegrationMissingError

# Fail fast with a helpful error if the 'wandb' extra is not installed.
try:
    import wandb
except ModuleNotFoundError:
    raise IntegrationMissingError("wandb")

__all__ = ["WandbWorkspace", "WandbStepCache"]

from .step_cache import WandbStepCache
from .workspace import WandbWorkspace

# The torch train callback is only exported when torch is installed.
try:
    import torch
except ModuleNotFoundError:
    pass
else:
    from .torch_train_callback import WandbTrainCallback

    __all__.append("WandbTrainCallback")

# Likewise, the flax callback requires the full flax/jax stack.
try:
    import flax
    import jax
    import tensorflow  # flax has a tensorflow dependency
except ModuleNotFoundError:
    pass
else:
    from .flax_train_callback import WandbFlaxTrainCallback

    __all__.append("WandbFlaxTrainCallback")
================================================
FILE: tango/integrations/wandb/flax_train_callback.py
================================================
from typing import Any, Dict, List, Optional
import jax
import wandb
from flax import jax_utils
from tango.common.exceptions import ConfigurationError
from tango.integrations.flax.train_callback import TrainCallback
from .workspace import WandbWorkspace
@TrainCallback.register("wandb::log_flax")
class WandbFlaxTrainCallback(TrainCallback):
    """
    A flax :class:`~tango.integrations.flax.TrainCallback` for use with
    the :class:`~tango.integrations.flax.FlaxTrainStep` that logs training and
    validation metrics to W&B.

    This can be used with any :class:`~tango.workspace.Workspace` implementation,
    including :class:`WandbWorkspace`.

    .. tip::
        Registered as a :class:`~tango.integrations.flax.TrainCallback`
        under the name "wandb::log_flax".

    .. important::
        When this callback is used with the :class:`WandbWorkspace` it will log metrics
        to the same W&B project that the workspace uses. The ``group`` and ``name``
        parameters will also automatically be set, so a :class:`~tango.common.exceptions.ConfigurationError`
        will be raised if any of ``project``, ``entity``, ``group``, or ``name`` are set in this callback.

    :param project:
        W&B project to associated this run with.
    :param entity:
        W&B entity (user or organization) to associated this run with.
    :param group:
        W&B group to associated this run with.
    :param name:
        Set the name of the run in W&B. If not set, the default will be the name of the step.
    :param notes:
        Arbitrary notes to add in W&B to this run.
    :param tags:
        Arbitrary tags to add in W&B to this run.
    :param watch_model:
        If ``True``, ``wandb.watch()`` is called to collect gradients and other information
        about the model throughout training.
        See `docs.wandb.ai/ref/python/watch <https://docs.wandb.ai/ref/python/watch>`_.
    :param wandb_config:
        Arbitrary configuration fields to set in W&B for this run.
        See `docs.wandb.ai/guides/track/config <https://docs.wandb.ai/guides/track/config>`_.
    """

    def __init__(
        self,
        *args,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        tags: Optional[List[str]] = None,
        watch_model: bool = False,
        wandb_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        # When the workspace (or an ambient W&B run) already determines the run
        # identity, these parameters must not be overridden here.
        if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None:
            err_msg_template = "Cannot set '{var_name}' in WandbTrainCallback "
            if isinstance(self.workspace, WandbWorkspace):
                err_msg_template += "since it has already been set from the WandbWorkspace."
            else:
                err_msg_template += "since a W&B run has already been initialized."
            for var, var_name in [
                (project, "project"),
                (entity, "entity"),
                (group, "group"),
                (name, "name"),
            ]:
                if var is not None:
                    raise ConfigurationError(err_msg_template.format(var_name=var_name))

        # Prefer the workspace's project/entity when a WandbWorkspace is in use.
        self.project = (
            project if not isinstance(self.workspace, WandbWorkspace) else self.workspace.project
        )
        self.entity = (
            entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity
        )
        self.group = group or self.step_id
        self.notes = notes
        self.tags = tags
        self.watch_model = watch_model
        # Start from the training config and layer any user-provided fields on top.
        self.wandb_config = self.train_config.as_dict()
        if wandb_config is not None:
            self.wandb_config.update(wandb_config)
        if wandb.run is None:
            self.wandb_config["job_type"] = "train_metrics"
        self.run_name: str = name or self.step_name or "train"
        self.run_id: str = (
            wandb.run.id if wandb.run is not None else self.step_id  # type: ignore[attr-defined]
        )
        self.resume: Optional[str] = None
        self.should_finalize_run: bool = (
            wandb.run is None
        )  # if we have to start our own W&B run, we need to finish it

    def state_dict(self) -> Dict[str, Any]:
        # Nothing to persist; resuming is handled via ``load_state_dict`` below.
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Being restored from a checkpoint means a W&B run with our ID may
        # already exist; allow resuming it.
        self.resume = "allow"

    def pre_train_loop(self) -> None:
        if wandb.run is None:
            # No ambient run: this callback owns (and must later finish) the run.
            if self.run_id is None:
                self.run_id = self.step_id
            wandb.init(
                id=self.run_id,
                dir=str(self.work_dir),
                project=self.project,
                entity=self.entity,
                group=self.group,
                name=self.run_name,
                notes=self.notes,
                config=self.wandb_config,
                tags=self.tags,
                job_type="train_metrics",
            )
        else:
            # We are already running inside of a W&B run, possibly because
            # we're using the WandbWorkspace.
            wandb.config.update(self.wandb_config)
            if self.tags:
                wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags)
            if self.notes:
                wandb.run.notes = self.notes
        if self.watch_model:
            wandb.watch(self.model)

    def post_train_loop(self, step: int, epoch: int) -> None:
        # Only finish the run if we started it ourselves in pre_train_loop.
        if self.should_finalize_run:
            wandb.finish()

    def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:
        # With multiple devices, metrics are replicated per device; collapse
        # them to a single copy before logging.
        if len(jax.devices()) > 1:
            train_metrics = jax_utils.unreplicate(train_metrics)
        metrics = {"train/loss": train_metrics["loss"], "epoch": epoch}
        wandb.log(metrics, step=step + 1)

    def post_val_loop(
        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]
    ) -> None:
        wandb.log(
            {
                f"val/{self.train_config.val_metric_name}": val_metric,
                f"val/best_{self.train_config.val_metric_name}": best_val_metric,
                "epoch": epoch,
            },
            step=step + 1,
        )
================================================
FILE: tango/integrations/wandb/step_cache.py
================================================
import logging
from typing import Any, Optional, Union
import wandb
from retry import retry
from wandb.errors import Error as WandbError
from tango.common.aliases import PathOrStr
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache
from tango.step_info import StepInfo
from .util import ArtifactKind, check_environment, is_missing_artifact_error
logger = logging.getLogger(__name__)
@StepCache.register("wandb")
class WandbStepCache(RemoteStepCache):
"""
This is a :class:`~tango.step_cache.StepCache` that's used by :class:`WandbWorkspace`.
It stores the results of steps on W&B as Artifacts.
It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
step's resulting subsequent times should be fast.
:param project: The W&B project to use.
:param entity: The W&B entity (user or organization account) to use.
.. tip::
Registered as :class:`~tango.step_cache.StepCache` under the name "wandb".
"""
def __init__(self, project: str, entity: str):
check_environment()
super().__init__(
tango_cache_dir()
/ "wandb_cache"
/ make_safe_filename(entity)
/ make_safe_filename(project)
)
self.project = project
self.entity = entity
@property
def wandb_client(self) -> wandb.Api:
return wandb.Api(overrides={"entity": self.entity, "project": self.project})
@property
def client(self):
"""
To maintain compatibility
"""
return self.wandb_client
@property
def wandb_project_url(self) -> str:
"""
The URL of the W&B project this workspace uses.
"""
app_url = self.wandb_client.client.app_url
app_url = app_url.rstrip("/")
return f"{app_url}/{self.entity}/{self.project}"
def _step_artifact_name(self, step: Union[Step, StepInfo]) -> str:
if isinstance(step, Step):
return step.class_name
else:
return step.step_class_name
def _step_result_remote( # type: ignore
self, step: Union[Step, StepInfo]
) -> Optional[wandb.Artifact]:
artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value)
try:
return self.wandb_client.artifact(
f"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}",
type=artifact_kind,
)
except WandbError as exc:
if is_missing_artifact_error(exc):
return None
else:
raise
def create_step_result_artifact(self, step: Step, objects_dir: Optional[PathOrStr] = None):
self._upload_step_remote(step, objects_dir)
def get_step_result_artifact(self, step: Union[Step, StepInfo]) -> Optional[wandb.Artifact]:
artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value)
try:
return self.wandb_client.artifact(
f"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}",
type=artifact_kind,
)
except WandbError as exc:
if is_missing_artifact_error(exc):
return None
else:
raise
def _upload_step_remote(self, step: Step, objects_dir: Optional[PathOrStr] = None) -> Any:
"""
Create an artifact for the result of a step.
"""
artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value)
artifact = wandb.Artifact(self._step_artifact_name(step), type=artifact_kind)
# Add files
if objects_dir is not None:
artifact.add_dir(str(objects_dir))
# Log/persist the artifact to W&B.
artifact.save()
artifact.wait()
# Add an alias for the step's unique ID.
# Only after we've logged the artifact can we add an alias.
artifact.aliases.append(step.unique_id)
artifact.save()
artifact.wait()
def get_step_result_artifact_url(self, step: Union[Step, StepInfo]) -> str:
artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value)
return (
f"{self.wandb_project_url}/artifacts/{artifact_kind}"
f"/{self._step_artifact_name(step)}/{step.unique_id}"
)
@retry(exceptions=(wandb.errors.CommError,), delay=10, backoff=2, max_delay=120)
def use_step_result_artifact(self, step: Union[Step, StepInfo]) -> None:
    """
    "Use" the artifact corresponding to the result of a step.

    Retried with exponential backoff because W&B communication errors are
    often transient.

    :raises RuntimeError: If called outside of an active W&B run.
    """
    if wandb.run is None:
        raise RuntimeError("This can only be called from within a W&B run")
    wandb.run.use_artifact(
        f"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}"
    )
def _download_step_remote(self, step_result, target_dir: PathOrStr):
    """
    Download ``step_result`` (a W&B artifact) into ``target_dir``.

    :raises RemoteNotFoundError: If the artifact can't be downloaded.
    """
    try:
        step_result.download(root=target_dir)
    except (WandbError, ValueError) as exc:
        # Chain the original exception so the underlying W&B failure is
        # preserved in the traceback instead of being silently discarded.
        raise RemoteNotFoundError() from exc
def __len__(self) -> int:
    """
    Count the cached step results, i.e. the finished W&B runs for cacheable
    steps in this project.
    """
    run_filters = {
        "config.job_type": "step",
        "config.cacheable": True,
        "state": "finished",
    }
    completed_runs = self.wandb_client.runs(
        f"{self.entity}/{self.project}",
        filters=run_filters,  # type: ignore
    )
    return sum(1 for _ in completed_runs)
================================================
FILE: tango/integrations/wandb/torch_train_callback.py
================================================
from typing import Any, Dict, List, Optional
import torch
import wandb
from tango.common.exceptions import ConfigurationError
from tango.integrations.torch.train_callback import TrainCallback
from tango.integrations.torch.util import peak_gpu_memory
from .util import check_environment
from .workspace import WandbWorkspace
@TrainCallback.register("wandb::log")
class WandbTrainCallback(TrainCallback):
    """
    A torch :class:`~tango.integrations.torch.TrainCallback` for use with
    the :class:`~tango.integrations.torch.TorchTrainStep` that logs training and
    validation metrics to W&B.

    This can be used with any :class:`~tango.workspace.Workspace` implementation,
    including :class:`WandbWorkspace`.

    .. tip::
        Registered as a :class:`~tango.integrations.torch.TrainCallback`
        under the name "wandb::log".

    .. important::
        When this callback is used with the :class:`WandbWorkspace` it will log metrics
        to the same W&B project that the workspace uses. The ``group`` and ``name``
        parameters will also automatically be set, so a :class:`~tango.common.exceptions.ConfigurationError`
        will be raised if any of ``project``, ``entity``, ``group``, or ``name`` are set in this callback.

    :param project:
        W&B project to associate this run with.
    :param entity:
        W&B entity (user or organization) to associate this run with.
    :param group:
        W&B group to associate this run with.
    :param name:
        Set the name of the run in W&B. If not set, the default will be the name of the step.
    :param notes:
        Arbitrary notes to add in W&B to this run.
    :param tags:
        Arbitrary tags to add in W&B to this run.
    :param watch_model:
        If ``True``, ``wandb.watch()`` is called to collect gradients and other information
        about the model throughout training.
        See https://docs.wandb.ai/ref/python/watch
    :param wandb_config:
        Arbitrary configuration fields to set in W&B for this run.
        See https://docs.wandb.ai/guides/track/config
    """

    def __init__(
        self,
        *args,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        tags: Optional[List[str]] = None,
        watch_model: bool = False,
        wandb_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        if self.is_local_main_process:
            check_environment()
        # When running under a WandbWorkspace, or inside an already-initialized
        # W&B run, the run's identity is determined elsewhere, so setting any of
        # project/entity/group/name here would be ambiguous.
        if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None:
            err_msg_template = "Cannot set '{var_name}' in WandbTrainCallback "
            if isinstance(self.workspace, WandbWorkspace):
                err_msg_template += "since it has already been set from the WandbWorkspace."
            else:
                err_msg_template += "since a W&B run has already been initialized."
            for var, var_name in [
                (project, "project"),
                (entity, "entity"),
                (group, "group"),
                (name, "name"),
            ]:
                if var is not None:
                    raise ConfigurationError(err_msg_template.format(var_name=var_name))
        self.project = (
            project if not isinstance(self.workspace, WandbWorkspace) else self.workspace.project
        )
        self.entity = (
            entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity
        )
        # Group the (possibly distributed) worker runs under the step's ID by default.
        self.group = group or self.step_id
        self.notes = notes or self._get_default_notes()
        self.tags = tags
        self.watch_model = watch_model
        self.wandb_config = self.train_config.as_dict()
        # 'worker_id' differs across distributed workers, so it would make the
        # config inconsistent between runs in the same group.
        del self.wandb_config["worker_id"]
        if wandb_config is not None:
            self.wandb_config.update(wandb_config)
        if wandb.run is None:
            self.wandb_config["job_type"] = "train_metrics"
        self.run_name: str = name or self.step_name or "train"
        if self.train_config.is_distributed:
            self.run_name += f" (rank {self.train_config.worker_id})"
        self.run_id: str = (
            wandb.run.id  # type: ignore[attr-defined]
            if wandb.run is not None
            else self.step_id + f"-rank{self.train_config.worker_id}"
        )
        # Set to "allow" by load_state_dict() when restoring from a checkpoint.
        # NOTE(review): ``self.resume`` is never passed to ``wandb.init()`` in
        # ``pre_train_loop`` — presumably resuming relies on the fixed run ID;
        # confirm this is intentional.
        self.resume: Optional[str] = None
        self.should_finalize_run: bool = (
            wandb.run is None
        )  # if we have to start our own W&B run, we need to finish it

    def state_dict(self) -> Dict[str, Any]:
        # This callback has no state worth checkpointing.
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Being restored from a checkpoint means a W&B run with our run ID may
        # already exist, so allow resuming it.
        self.resume = "allow"

    def pre_train_loop(self) -> None:
        """
        Start (or attach to) the W&B run and log initial GPU memory stats.
        """
        if wandb.run is None:
            if self.run_id is None:
                self.run_id = self.step_id + f"-rank{self.train_config.worker_id}"
            # Initialize a new W&B run.
            wandb.init(
                id=self.run_id,
                dir=str(self.work_dir),
                project=self.project,
                entity=self.entity,
                group=self.group,
                name=self.run_name,
                notes=self.notes,
                config=self.wandb_config,
                tags=self.tags,
                job_type="train_metrics",
            )
        else:
            # We are already running inside of a W&B run, possibly because
            # we're using the WandbWorkspace.
            wandb.config.update(self.wandb_config)
            if self.tags:
                wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags)
            if self.notes:
                wandb.run.notes = self.notes
        if self.watch_model:
            wandb.watch(self.training_engine.model)
        # Log GPU memory statistics.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        peak_gpu_mbs = peak_gpu_memory()
        if self.is_local_main_process:
            metrics = {f"sys/worker{rank}_peak_gpu_mem": mbs for rank, mbs in peak_gpu_mbs.items()}
            metrics["epoch"] = 0
            wandb.log(metrics, step=0)

    def post_train_loop(self, step: int, epoch: int) -> None:
        # Only finish the run if we started it ourselves in pre_train_loop().
        if self.should_finalize_run:
            wandb.finish()

    def log_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Log per-batch training loss, learning rate, and peak GPU memory.
        """
        peak_gpu_mbs = peak_gpu_memory()
        if self.is_local_main_process:
            metrics = {
                "train/loss": batch_loss,
                "train/lr": self.training_engine.optimizer.param_groups[0]["lr"],
                "epoch": epoch,
            }
            metrics.update(
                {f"sys/worker{rank}_peak_gpu_mem": mbs for rank, mbs in peak_gpu_mbs.items()}
            )
            # step + 1 so batch metrics land after the step-0 metrics logged
            # in pre_train_loop().
            wandb.log(
                metrics,
                step=step + 1,
            )

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        """
        Log the current and best validation metric after each validation loop.
        """
        if self.is_local_main_process:
            wandb.log(
                {
                    f"val/{self.train_config.val_metric_name}": val_metric,
                    f"val/best_{self.train_config.val_metric_name}": best_val_metric,
                    "epoch": epoch,
                },
                step=step + 1,
            )

    def _get_default_notes(self) -> str:
        # Default run notes, linking back to the workspace's step run when
        # available.
        notes = (
            f'Metrics for Tango step "{self.step_name}" from worker {self.train_config.worker_id}.'
        )
        if isinstance(self.workspace, WandbWorkspace):
            notes += f"\nMain run for step: {self.workspace.wandb_project_url}/runs/{self.step_id}/overview"
        return notes
================================================
FILE: tango/integrations/wandb/util.py
================================================
import os
import re
import warnings
from enum import Enum
from wandb.errors import Error as WandbError
# Module-level flags ensuring each warning in check_environment() is only
# issued once per process.
_API_KEY_WARNING_ISSUED = False
_SILENCE_WARNING_ISSUED = False
def is_missing_artifact_error(err: WandbError):
    """
    Check if a specific W&B error is caused by a 404 on the artifact we're looking for.
    """
    # This is brittle, but at least we have a test for it.
    msg = err.message
    # Workaround for a bug in the wandb API that surfaces a missing artifact
    # as this AttributeError message.
    if msg == "'NoneType' object has no attribute 'get'":
        return True
    if re.search(r"^artifact '.*' not found in '.*'$", msg):
        return True
    not_found_markers = (
        "does not contain artifact",
        "Unable to fetch artifact with name",
    )
    return any(marker in msg for marker in not_found_markers)
def check_environment():
    """
    Warn (once per process) about environment variables that affect the W&B client.
    """
    global _API_KEY_WARNING_ISSUED, _SILENCE_WARNING_ISSUED
    env = os.environ
    if not _API_KEY_WARNING_ISSUED and "WANDB_API_KEY" not in env:
        warnings.warn(
            "Missing environment variable 'WANDB_API_KEY' required to authenticate to Weights & Biases.",
            UserWarning,
        )
        _API_KEY_WARNING_ISSUED = True
    if not _SILENCE_WARNING_ISSUED and "WANDB_SILENT" not in env:
        warnings.warn(
            "The Weights & Biases client may produce a lot of log messages. "
            "You can silence these by setting the environment variable 'WANDB_SILENT=true'",
            UserWarning,
        )
        _SILENCE_WARNING_ISSUED = True
class RunKind(Enum):
    """
    Values stored in a W&B run's ``config.job_type`` field to distinguish the
    kinds of runs this integration creates.
    """

    # A W&B run tracking the execution of a single Tango step.
    STEP = "step"
    # A W&B run tracking a whole Tango run.
    TANGO_RUN = "tango_run"
class ArtifactKind(Enum):
    """
    The default W&B artifact type used for step result artifacts. Steps can
    override this through an ``artifact_kind`` field in their metadata.
    """

    STEP_RESULT = "step_result"
================================================
FILE: tango/integrations/wandb/workspace.py
================================================
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, TypeVar, Union
from urllib.parse import ParseResult
import pytz
import wandb
from tango.common.exceptions import StepStateError
from tango.common.file_lock import FileLock
from tango.common.util import exception_to_string, tango_cache_dir, utc_now_datetime
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, Workspace
from .step_cache import WandbStepCache
from .util import RunKind, check_environment
# Generic type for step results handled by the workspace.
T = TypeVar("T")

logger = logging.getLogger(__name__)
@Workspace.register("wandb")
class WandbWorkspace(Workspace):
    """
    This is a :class:`~tango.workspace.Workspace` that tracks Tango runs in a W&B project.

    It also stores step results as W&B Artifacts via :class:`WandbStepCache`.

    Each Tango run with this workspace will generate multiple runs in your W&B project.
    There will always be a W&B run corresponding to each Tango run with the same name,
    which will contain some metadata about the Tango run. Then there will be one W&B run
    for each cacheable step that runs with a name corresponding to the name of the step.
    So if your Tango run includes 3 cacheable steps, that will result in a total of 4 new runs in W&B.

    :param project: The W&B project to use for the workspace.
    :param entity: The W&B entity (user or organization account) to use for the workspace.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "wandb".

    .. tip::
        If you want to change the artifact kind for step result artifacts uploaded
        to W&B, add a field called ``artifact_kind`` to the ``metadata`` of
        the :class:`~tango.step.Step` class.
        This can be useful if you want model objects to be added to the model zoo.
        In that case you would set ``artifact_kind = "model"``.

        For example, your config for the step would look like this:

        .. code-block::

            { type: "trainer", step_metadata: { artifact_kind: "model" }, ... }

        Or just add this to the ``METADATA`` class attribute:

        .. code-block::

            @Step.register("trainer")
            class TrainerStep(Step):
                METADATA = {"artifact_kind": "model"}
    """

    def __init__(self, project: str, entity: Optional[str] = None):
        check_environment()
        super().__init__()
        self.project = project
        self._entity = entity
        self.cache = WandbStepCache(project=self.project, entity=self.entity)
        # Local scratch space for step working directories and lock files.
        self.steps_dir = tango_cache_dir() / "wandb_workspace"
        # One file lock per step currently running in this process.
        self.locks: Dict[Step, FileLock] = {}
        # In-memory StepInfo for steps running in this process, so we don't
        # have to round-trip through the W&B API for them.
        self._running_step_info: Dict[str, StepInfo] = {}

    def __getstate__(self):
        """
        We override `__getstate__()` to customize how instances of this class are pickled
        since we don't want to persist certain attributes.
        """
        out = super().__getstate__()
        # File locks can't (and shouldn't) survive pickling.
        out["locks"] = {}
        return out

    @property
    def wandb_client(self) -> wandb.Api:
        # NOTE: a fresh API client is constructed on every access.
        overrides = {"project": self.project}
        if self._entity is not None:
            overrides["entity"] = self._entity
        return wandb.Api(overrides=overrides)

    @property
    def entity(self) -> str:
        # Fall back to the default entity of the authenticated W&B user.
        return self._entity or self.wandb_client.default_entity

    @property
    def url(self) -> str:
        # The workspace URL in the form understood by from_parsed_url().
        return f"wandb://{self.entity}/{self.project}"

    @classmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
        """
        Construct the workspace from a URL of the form ``wandb://<entity>/<project>``.
        """
        entity = parsed_url.netloc
        project = parsed_url.path
        if project:
            project = project.strip("/")
        return cls(project=project, entity=entity)

    @property
    def step_cache(self) -> StepCache:
        return self.cache

    @property
    def wandb_project_url(self) -> str:
        """
        The URL of the W&B project this workspace uses.
        """
        app_url = self.wandb_client.client.app_url
        app_url = app_url.rstrip("/")
        return f"{app_url}/{self.entity}/{self.project}"

    def _get_unique_id(self, step_or_unique_id: Union[Step, str]) -> str:
        # Normalize a ``Step`` or a raw unique-ID string to the unique-ID string.
        if isinstance(step_or_unique_id, Step):
            unique_id = step_or_unique_id.unique_id
        else:
            unique_id = step_or_unique_id
        return unique_id

    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:
        """
        Return (and create, if needed) the local directory for a step.
        """
        unique_id = self._get_unique_id(step_or_unique_id)
        path = self.steps_dir / unique_id
        path.mkdir(parents=True, exist_ok=True)
        return path

    def work_dir(self, step: Step) -> Path:
        """
        Return (and create, if needed) the working directory for a step.
        """
        path = self.step_dir(step) / "work"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Fetch the :class:`StepInfo` for a step.

        :raises KeyError: If the step is unknown to this workspace.
        """
        unique_id = self._get_unique_id(step_or_unique_id)
        # Prefer the in-memory copy for steps running in this process.
        if unique_id in self._running_step_info:
            return self._running_step_info[unique_id]
        step_info = self._get_updated_step_info(
            unique_id,
            step_name=step_or_unique_id.name if isinstance(step_or_unique_id, Step) else None,
        )
        if step_info is None:
            raise KeyError(step_or_unique_id)
        else:
            return step_info

    def step_starting(self, step: Step) -> None:
        """
        Mark ``step`` as running: acquire its lock and start a W&B run for it.

        :raises RuntimeError: If a W&B run is already active in this process.
        :raises StepStateError: If the step is not in a startable state.
        """
        # The W&B client only supports one active run per process.
        if wandb.run is not None:
            raise RuntimeError(
                "There is already a W&B run initialized, cannot initialize another one."
            )
        work_dir = self.work_dir(step)
        # Guard against the same step running concurrently in another process.
        lock_path = self.step_dir(step) / "lock"
        lock = FileLock(lock_path, read_only_ok=True)
        lock.acquire_with_updates(desc=f"acquiring lock for '{step.name}'")
        self.locks[step] = lock
        step_info = self._get_updated_step_info(step.unique_id) or StepInfo.new_from_step(step)
        if step_info.state not in {StepState.INCOMPLETE, StepState.FAILED, StepState.UNCACHEABLE}:
            raise StepStateError(
                step,
                step_info.state,
                context="If you are certain the step is not running somewhere else, delete the lock "
                f"file at {lock_path}.",
            )
        try:
            # Initialize W&B run for the step.
            wandb.init(
                name=step_info.step_name,
                job_type=RunKind.STEP.value,
                group=step.unique_id,
                dir=str(work_dir),
                entity=self.entity,
                project=self.project,
                # For cacheable steps we can just use the step's unique ID as the W&B run ID,
                # but not for uncacheable steps since those might be run more than once and
                # will need a unique W&B run ID each time.
                id=step.unique_id if step.cache_results else None,
                resume="allow" if step.cache_results else None,
                notes="\n".join(
                    [
                        f'Tango step "{step.name}"',
                        f"\N{bullet} type: {step_info.step_class_name}",
                        f"\N{bullet} ID: {step.unique_id}",
                    ]
                ),
                config={
                    "job_type": RunKind.STEP.value,
                    "_run_suite_id": self._generate_run_suite_id(),  # used for testing only
                },
            )
            assert wandb.run is not None
            logger.info(
                "Tracking '%s' step on Weights and Biases: %s/runs/%s/overview",
                step.name,
                self.wandb_project_url,
                wandb.run.id,
            )
            # "Use" all of the result artifacts for this step's dependencies in order to declare
            # those dependencies to W&B.
            for dependency in step.dependencies:
                self.cache.use_step_result_artifact(dependency)
            # Update StepInfo to mark as running.
            step_info.start_time = utc_now_datetime()
            step_info.end_time = None
            step_info.error = None
            step_info.result_location = None
            wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True)
            self._running_step_info[step.unique_id] = step_info
        except:  # noqa: E722
            # Release the lock on any failure so the step isn't left locked.
            lock.release()
            del self.locks[step]
            raise

    def step_finished(self, step: Step, result: T) -> T:
        """
        Finalize a successful step: cache its result and close its W&B run.

        :raises RuntimeError: If there is no active W&B run for the step.
        :raises KeyError: If the step is unknown to this workspace.
        """
        if wandb.run is None:
            raise RuntimeError(
                f"{self.__class__.__name__}.step_finished() called outside of a W&B run. "
                f"Did you forget to call {self.__class__.__name__}.step_starting() first?"
            )
        step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info(
            step.unique_id
        )
        if step_info is None:
            raise KeyError(step.unique_id)
        try:
            if step.cache_results:
                self.step_cache[step] = result
                if hasattr(result, "__next__"):
                    assert isinstance(result, Iterator)
                    # Caching the iterator will consume it, so we write it to the
                    # cache and then read from the cache for the return value.
                    result = self.step_cache[step]
                step_info.result_location = self.cache.get_step_result_artifact_url(step)
            else:
                # Create an empty artifact in order to build the DAG in W&B.
                self.cache.create_step_result_artifact(step)
            step_info.end_time = utc_now_datetime()
            wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True)
            # Finalize the step's W&B run.
            wandb.finish()
        finally:
            # Always release the lock and drop the in-memory bookkeeping.
            self.locks[step].release()
            del self.locks[step]
            if step.unique_id in self._running_step_info:
                del self._running_step_info[step.unique_id]
        return result

    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        Mark ``step`` as failed and close its W&B run with a non-zero exit code.

        :raises RuntimeError: If there is no active W&B run for the step.
        :raises KeyError: If the step is unknown to this workspace.
        :raises StepStateError: If the step was not in the RUNNING state.
        """
        if wandb.run is None:
            raise RuntimeError(
                f"{self.__class__.__name__}.step_failed() called outside of a W&B run. "
                f"Did you forget to call {self.__class__.__name__}.step_starting() first?"
            )
        step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info(
            step.unique_id
        )
        if step_info is None:
            raise KeyError(step.unique_id)
        try:
            # Update StepInfo, marking the step as failed.
            if step_info.state != StepState.RUNNING:
                raise StepStateError(step, step_info.state)
            step_info.end_time = utc_now_datetime()
            step_info.error = exception_to_string(e)
            wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True)
            # Finalize the step's W&B run.
            wandb.finish(exit_code=1)
        finally:
            # Always release the lock and drop the in-memory bookkeeping.
            self.locks[step].release()
            del self.locks[step]
            if step.unique_id in self._running_step_info:
                del self._running_step_info[step.unique_id]

    def remove_step(self, step_unique_id: str):
        """
        Removes cached step using the given unique step id

        :raises KeyError: If there is no step with the given name.
        """
        # Not supported for W&B-backed workspaces yet.
        raise NotImplementedError()

    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """
        Register a new Tango run with ``targets`` as its target steps, creating a
        dedicated W&B run that records metadata for every step involved.
        """
        # Gather the targets plus everything they transitively depend on.
        all_steps = set(targets)
        for step in targets:
            all_steps |= step.recursive_dependencies
        wandb_run_id: str
        wandb_run_name: str
        with tempfile.TemporaryDirectory() as temp_dir_name:
            with wandb.init(  # type: ignore[union-attr]
                job_type=RunKind.TANGO_RUN.value,
                entity=self.entity,
                project=self.project,
                name=name,
                dir=temp_dir_name,
                config={
                    "job_type": RunKind.TANGO_RUN.value,  # need this in the config so we can filter runs by this
                    "_run_suite_id": self._generate_run_suite_id(),  # used for testing only
                },
            ) as wandb_run:
                wandb_run_id = wandb_run.id
                wandb_run_name = wandb_run.name  # type: ignore[assignment]
                logger.info("Registering run %s with Weights and Biases", wandb_run.name)
                logger.info(
                    "View run at: %s/runs/%s/overview", self.wandb_project_url, wandb_run_id
                )
                # Collect step info for all steps.
                step_ids: Dict[str, bool] = {}
                step_name_to_info: Dict[str, Dict[str, Any]] = {}
                for step in all_steps:
                    step_info = StepInfo.new_from_step(step)
                    step_name_to_info[step.name] = {
                        k: v for k, v in step_info.to_json_dict().items() if v is not None
                    }
                    step_ids[step.unique_id] = True
                # Update config with step info.
                wandb_run.config.update({"steps": step_name_to_info, "_step_ids": step_ids})
                # Update notes.
                notes = "Tango run\n--------------"
                cacheable_steps = {step for step in all_steps if step.cache_results}
                if cacheable_steps:
                    notes += "\nCacheable steps:\n"
                    for step in sorted(cacheable_steps, key=lambda step: step.name):
                        notes += f"\N{bullet} {step.name}"
                        dependencies = step.dependencies
                        if dependencies:
                            notes += ", depends on: " + ", ".join(
                                sorted(
                                    [f"'{dep.name}'" for dep in dependencies],
                                )
                            )
                        notes += "\n \N{rightwards arrow with hook} "
                        notes += f"{self.wandb_project_url}/runs/{step.unique_id}/overview\n"
                wandb_run.notes = notes
        return self.registered_run(wandb_run_name)

    def _generate_run_suite_id(self) -> str:
        # Random ID used to correlate the runs created together (testing only).
        return wandb.util.generate_id()

    def registered_runs(self) -> Dict[str, Run]:
        """
        Fetch all Tango runs registered in this W&B project, keyed by name.
        """
        runs: Dict[str, Run] = {}
        matching_runs = list(
            self.wandb_client.runs(
                f"{self.entity}/{self.project}",
                filters={"config.job_type": RunKind.TANGO_RUN.value},  # type: ignore
            )
        )
        for wandb_run in matching_runs:
            runs[wandb_run.name] = self._get_run_from_wandb_run(wandb_run)
        return runs

    def registered_run(self, name: str) -> Run:
        """
        Fetch the Tango run with the given name.

        :raises KeyError: If no run with that name exists.
        :raises ValueError: If more than one run with that name exists.
        """
        matching_runs = list(
            self.wandb_client.runs(
                f"{self.entity}/{self.project}",
                filters={"display_name": name, "config.job_type": RunKind.TANGO_RUN.value},  # type: ignore
            )
        )
        if not matching_runs:
            raise KeyError(f"Run '{name}' not found in workspace")
        elif len(matching_runs) > 1:
            raise ValueError(f"Found more than one run named '{name}' in W&B project")
        return self._get_run_from_wandb_run(matching_runs[0])

    def _get_run_from_wandb_run(
        self,
        wandb_run: wandb.apis.public.Run,
    ) -> Run:
        # Convert a W&B run object into a Tango Run, refreshing each cacheable
        # step's info from W&B where possible.
        step_name_to_info = {}
        for step_name, step_info_dict in wandb_run.config["steps"].items():
            step_info = StepInfo.from_json_dict(step_info_dict)
            if step_info.cacheable:
                updated_step_info = self._get_updated_step_info(
                    step_info.unique_id, step_name=step_name
                )
                if updated_step_info is not None:
                    step_info = updated_step_info
            step_name_to_info[step_name] = step_info
        return Run(
            name=wandb_run.name,
            steps=step_name_to_info,
            # W&B reports a naive timestamp; presumably it's UTC, so attach
            # that tzinfo explicitly.
            start_date=datetime.strptime(wandb_run.created_at, "%Y-%m-%dT%H:%M:%S").replace(
                tzinfo=pytz.utc
            ),
        )

    def _get_updated_step_info(
        self, step_id: str, step_name: Optional[str] = None
    ) -> Optional[StepInfo]:
        """
        Pull the latest :class:`StepInfo` for a step from W&B, or return
        ``None`` if the step is unknown to this project.
        """
        # First try to find the W&B run corresponding to the step. This will only
        # work if the step execution was started already.
        filters = {
            "config.job_type": RunKind.STEP.value,
            "config.step_info.unique_id": step_id,
        }
        if step_name is not None:
            filters["display_name"] = step_name
        for wandb_run in self.wandb_client.runs(
            f"{self.entity}/{self.project}",
            filters=filters,  # type: ignore
        ):
            step_info = StepInfo.from_json_dict(wandb_run.config["step_info"])
            # Might need to fix the step info if the step failed and we failed to update the config.
            if step_info.start_time is None:
                step_info.start_time = datetime.strptime(
                    wandb_run.created_at, "%Y-%m-%dT%H:%M:%S"
                ).replace(tzinfo=pytz.utc)
            if wandb_run.state in {"failed", "finished"}:
                if step_info.end_time is None:
                    # Use the run's last heartbeat as a best-effort end time.
                    step_info.end_time = datetime.strptime(
                        wandb_run.heartbeatAt, "%Y-%m-%dT%H:%M:%S"
                    ).replace(tzinfo=pytz.utc)
                if wandb_run.state == "failed" and step_info.error is None:
                    step_info.error = "Exception"
            return step_info
        # If the step hasn't been started yet, we'll have to pull the step info from the
        # registered run.
        filters = {
            "config.job_type": RunKind.TANGO_RUN.value,
            f"config._step_ids.{step_id}": True,
        }
        if step_name is not None:
            filters[f"config.steps.{step_name}.unique_id"] = step_id
        for wandb_run in self.wandb_client.runs(
            f"{self.entity}/{self.project}",
            filters=filters,  # type: ignore
        ):
            if step_name is not None:
                step_info_data = wandb_run.config["steps"][step_name]
            else:
                step_info_data = next(
                    d for d in wandb_run.config["steps"].values() if d["unique_id"] == step_id
                )
            step_info = StepInfo.from_json_dict(step_info_data)
            return step_info
        return None
================================================
FILE: tango/py.typed
================================================
================================================
FILE: tango/settings.py
================================================
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional
import yaml
from .common.aliases import PathOrStr
from .common.from_params import FromParams
from .common.params import Params
@dataclass
class TangoGlobalSettings(FromParams):
    """
    Defines global settings for tango.
    """

    workspace: Optional[Dict[str, Any]] = None
    """
    Parameters to initialize a :class:`~tango.workspace.Workspace` with.
    """

    executor: Optional[Dict[str, Any]] = None
    """
    Parameters to initialize an :class:`~tango.executor.Executor` with.
    """

    include_package: Optional[List[str]] = None
    """
    A list of modules where custom registered steps or classes can be found.
    """

    log_level: Optional[str] = None
    """
    The log level to use. Options are "debug", "info", "warning", and "error".

    .. note::
        This does not affect the :data:`~tango.common.logging.cli_logger`
        or logs from :class:`~tango.common.Tqdm` progress bars.
    """

    file_friendly_logging: Optional[bool] = None
    """
    If this flag is set to ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
    down tqdm's output to only once every 10 seconds.
    """

    multiprocessing_start_method: str = "spawn"
    """
    The ``start_method`` to use when starting new multiprocessing workers. Can be "fork", "spawn",
    or "forkserver". Default is "spawn".

    See :func:`multiprocessing.set_start_method()` for more details.
    """

    environment: Optional[Dict[str, str]] = None
    """
    Environment variables that will be set each time ``tango`` is run.
    """

    # The file this config was read from, if any (set by from_file()).
    _path: Optional[Path] = None

    # Default user-level config file location.
    _DEFAULT_LOCATION: ClassVar[Path] = Path.home() / ".config" / "tango.yml"

    @classmethod
    def default(cls) -> "TangoGlobalSettings":
        """
        Initialize the config from files by checking the default locations
        in order, or just return the default if none of the files can be found.
        """
        # Check the current directory first, then the user-level config.
        for directory in (Path("."), cls._DEFAULT_LOCATION.parent):
            for extension in ("yml", "yaml"):
                path = directory / f"tango.{extension}"
                if path.is_file():
                    return cls.from_file(path)
        return cls()

    @classmethod
    def find_or_default(cls, path: Optional[PathOrStr] = None) -> "TangoGlobalSettings":
        """
        Initialize the config from a given configuration file, or falls back to returning
        the default configuration if no file is given.

        :raises FileNotFoundError: If ``path`` is given but doesn't exist.
        """
        if path is not None:
            path = Path(path)
            if not path.is_file():
                raise FileNotFoundError(path)
            return cls.from_file(path)
        else:
            return cls.default()

    @property
    def path(self) -> Optional[Path]:
        """
        The path to the file the config was read from.
        """
        return self._path

    @classmethod
    def from_file(cls, path: PathOrStr) -> "TangoGlobalSettings":
        """
        Read settings from a file.
        """
        params = Params.from_file(path)
        # Remember where the settings came from so save() can write back to it.
        params["_path"] = Path(path).resolve()
        return cls.from_params(params)

    def to_file(self, path: PathOrStr) -> None:
        """
        Save the settings to a file.
        """
        # Drop private fields (like ``_path``) before serializing.
        data = {
            k: v for k, v in self.to_params().as_dict(quiet=True).items() if not k.startswith("_")
        }
        with open(path, "w") as settings_file:
            yaml.safe_dump(data, settings_file)

    def save(self) -> None:
        """
        Save the settings to the file it was read from.

        :raises ValueError: If the settings was not read from a file.
        """
        if self.path is None:
            raise ValueError("No path given, use .to_file() instead")
        self.to_file(self.path)
================================================
FILE: tango/step.py
================================================
import inspect
import itertools
import logging
import random
import re
import warnings
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generic,
Iterable,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)
from tango.common.det_hash import CustomDetHash, det_hash
from tango.common.exceptions import ConfigurationError, StepStateError
from tango.common.from_params import (
FromParams,
infer_constructor_params,
infer_method_params,
pop_and_construct_arg,
)
from tango.common.lazy import Lazy
from tango.common.logging import cli_logger, log_exception
from tango.common.params import Params
from tango.common.registrable import Registrable
from tango.format import DillFormat, Format
try:
    from typing import get_args, get_origin  # type: ignore
except ImportError:
    # Fallbacks for Python versions where typing.get_args/get_origin
    # are not available.

    def get_origin(tp):  # type: ignore
        return getattr(tp, "__origin__", None)

    def get_args(tp):  # type: ignore
        return getattr(tp, "__args__", ())


if TYPE_CHECKING:
    # Imported for type annotations only; avoids a circular import at runtime.
    from tango.workspace import Workspace

# Valid step versions: one or more ASCII letters/digits.
_version_re = re.compile("""^[a-zA-Z0-9]+$""")

T = TypeVar("T")

# Dedicated RNG instance, presumably so generating step names doesn't disturb
# the global random state — confirm against usage further down the file.
_random_for_step_names = random.Random()
@dataclass
class StepResources(FromParams):
    """
    ``StepResources`` describe minimum external hardware requirements which must be available for a
    step to run.
    """

    machine: Optional[str] = None
    """
    This is an executor-dependent option.
    With the Beaker executor, for example, you can set this to "local" to force
    the executor to run the step locally instead of on Beaker.
    """

    cpu_count: Optional[float] = None
    """
    Minimum number of logical CPU cores. It may be fractional.

    Examples: ``4``, ``0.5``.
    """

    gpu_count: Optional[int] = None
    """
    Minimum number of GPUs. It must be non-negative.
    """

    gpu_type: Optional[str] = None
    """
    The type of GPU that the step requires.

    The exact string you should use to define a GPU type depends on the executor.
    With the Beaker executor, for example, you should use the same strings you
    see in the Beaker UI, such as 'NVIDIA A100-SXM-80GB'.
    """

    memory: Optional[str] = None
    """
    Minimum available system memory as a number with unit suffix.

    Examples: ``2.5GiB``, ``1024m``.
    """

    shared_memory: Optional[str] = None
    """
    Size of ``/dev/shm`` as a number with unit suffix.

    Examples: ``2.5GiB``, ``1024m``.
    """
class Step(Registrable, Generic[T]):
"""
This class defines one step in your experiment. To write your own step, derive from this class
and overwrite the :meth:`run()` method. The :meth:`run()` method must have parameters with type hints.
``Step.__init__()`` takes all the arguments we want to run the step with. They get passed
to :meth:`run()` (almost) as they are. If the arguments are other instances of ``Step``, those
will be replaced with the step's results before calling :meth:`run()`. Further, there are four special
parameters:
:param step_name: contains an optional human-readable name for the step. This name is used for
error messages and the like, and has no consequence on the actual computation.
:param cache_results: specifies whether the results of this step should be cached. If this is
``False``, the step is recomputed every time it is needed. If this is not set at all,
and :attr:`CACHEABLE` is ``True``, we cache if the step is marked as :attr:`DETERMINISTIC`,
and we don't cache otherwise.
:param step_format: gives you a way to override the step's default format (which is given in :attr:`FORMAT`).
:param step_config: is the original raw part of the experiment config corresponding to this step.
This can be accessed via the :attr:`config` property within each step's :meth:`run()` method.
:param step_unique_id_override: overrides the construction of the step's unique id using the hash
of inputs.
:param step_resources: gives you a way to set the minimum compute resources required
to run this step. Certain executors require this information.
:param step_metadata: use this to specify additional metadata for your step.
This is added to the :attr:`METADATA` class variable to form the ``self.metadata`` attribute.
Values in ``step_metadata`` take precedence over ``METADATA``.
:param step_extra_dependencies: use this to force a dependency on other steps. Normally dependencies
between steps are determined by the inputs and outputs of the steps, but you can use this
parameter to force that other steps run before this step even if this step doesn't
explicitly depend on the outputs of those steps.
.. important::
Overriding the unique id means that the step will always map to this value, regardless of the inputs,
and therefore, the step cache will only hold a single copy of the step's output (from the last execution).
Thus, in most cases, this should not be used when constructing steps. We include this option for the case
when the executor creates subprocesses, which also need to access the *same* ``Step`` object.
"""
DETERMINISTIC: bool = True
"""This describes whether this step can be relied upon to produce the same results every time
when given the same inputs. If this is ``False``, you can still cache the output of the step,
but the results might be unexpected. Tango will print a warning in this case."""
CACHEABLE: Optional[bool] = None
"""This provides a direct way to turn off caching. For example, a step that reads a HuggingFace
dataset doesn't need to be cached, because HuggingFace datasets already have their own caching
mechanism. But it's still a deterministic step, and all following steps are allowed to cache.
If it is ``None``, the step figures out by itself whether it should be cacheable or not."""
VERSION: Optional[str] = None
"""This is optional, but recommended. Specifying a version gives you a way to tell Tango that
a step has changed during development, and should now be recomputed. This doesn't invalidate
the old results, so when you revert your code, the old cache entries will stick around and be
picked up."""
FORMAT: Format = DillFormat("gz")
"""This specifies the format the results of this step will be serialized in. See the documentation
for :class:`~tango.format.Format` for details."""
SKIP_ID_ARGUMENTS: Set[str] = set()
"""If your :meth:`run()` method takes some arguments that don't affect the results, list them here.
Arguments listed here will not be used to calculate this step's unique ID, and thus changing those
arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about
the model output, not about how many outputs you can produce at the same time.
"""
SKIP_DEFAULT_ARGUMENTS: Dict[str, Any] = {}
"""Sometimes, you want to add another argument to your :meth:`run()` method, but you don't want to
invalidate the cache when this new argument is set to its default value. If that is the case, add
the argument to this dictionary with the default value that should be ignored."""
METADATA: Dict[str, Any] = {}
"""
Arbitrary metadata about the step.
"""
_UNIQUE_ID_SUFFIX: Optional[str] = None
"""
Used internally for testing.
"""
def __init__(
    self,
    step_name: Optional[str] = None,
    cache_results: Optional[bool] = None,
    step_format: Optional[Format] = None,
    step_config: Optional[Union[Dict[str, Any], Params]] = None,
    step_unique_id_override: Optional[str] = None,
    step_resources: Optional[StepResources] = None,
    step_metadata: Optional[Dict[str, Any]] = None,
    step_extra_dependencies: Optional[Iterable["Step"]] = None,
    **kwargs,
):
    """
    :param step_name: A name for this step instance; defaults to the step's unique ID.
    :param cache_results: Whether to cache this step's results. If ``None``, the decision
        is derived from the class-level ``DETERMINISTIC`` and ``CACHEABLE`` flags.
    :param step_format: Overrides the class-level ``FORMAT`` used to serialize results.
    :param step_config: The raw configuration this step was constructed from, if any.
    :param step_unique_id_override: Forces a specific unique ID instead of computing one.
    :param step_resources: The minimum compute resources this step requires.
    :param step_metadata: Extra metadata, merged over a copy of the class-level ``METADATA``.
    :param step_extra_dependencies: Steps this step depends on that do not appear among
        its arguments.
    :param kwargs: The arguments that will be passed to :meth:`run()`.
    """
    if self.VERSION is not None:
        assert _version_re.match(
            self.VERSION
        ), f"Invalid characters in version '{self.VERSION}'"

    # Fold run()'s declared defaults into kwargs so the stored argument set (and
    # therefore the unique-ID hash) is complete and stable.
    run_defaults = {
        k: v.default
        for k, v in inspect.signature(self.run).parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    self.kwargs = self.massage_kwargs({**run_defaults, **kwargs})

    if step_format is None:
        self.format = self.FORMAT
        # FORMAT may be given as a Format class rather than an instance; instantiate it.
        if isinstance(self.format, type):
            self.format = self.format()
    else:
        self.format = step_format

    # Seeding the cache effectively overrides the computed unique ID.
    self.unique_id_cache = step_unique_id_override
    if step_name is None:
        self.name = self.unique_id
    else:
        self.name = step_name
    # TODO: It is bad design to have the step_name in the Step class. The same step can be part of multiple
    # runs at the same time, and they can have different names in different runs. Step names are
    # a property of the run, not of the step.

    if cache_results is True:
        if not self.CACHEABLE:
            raise ConfigurationError(
                f"Step {self.name} is configured to use the cache, but it's not a cacheable step."
            )
        if not self.DETERMINISTIC:
            warnings.warn(
                f"Step {self.name} is going to be cached despite not being deterministic.",
                UserWarning,
            )
        self.cache_results = True
    elif cache_results is False:
        self.cache_results = False
    elif cache_results is None:
        # Decide from the (DETERMINISTIC, CACHEABLE) class flags: an explicit
        # CACHEABLE setting wins; otherwise exactly the deterministic steps cache.
        c = (self.DETERMINISTIC, self.CACHEABLE)
        if c == (False, None):
            self.cache_results = False
        elif c == (True, None):
            self.cache_results = True
        elif c == (False, False):
            self.cache_results = False
        elif c == (True, False):
            self.cache_results = False
        elif c == (False, True):
            warnings.warn(
                f"Step {self.name} is set to be cacheable despite not being deterministic.",
                UserWarning,
            )
            self.cache_results = True
        elif c == (True, True):
            self.cache_results = True
        else:
            assert False, "Step.DETERMINISTIC or step.CACHEABLE are set to an invalid value."
    else:
        raise ConfigurationError(
            f"Step {self.name}'s cache_results parameter is set to an invalid value."
        )

    self._workspace: Optional["Workspace"] = None
    self.work_dir_for_run: Optional[
        Path
    ] = None  # This is set only while the run() method runs.

    if isinstance(step_config, Params):
        self._config = step_config.as_dict(quiet=True)
    else:
        self._config = step_config

    assert step_resources is None or isinstance(step_resources, StepResources)
    self.step_resources = step_resources

    # Copy METADATA so per-instance updates never mutate the class-level dict.
    self.metadata = deepcopy(self.METADATA)
    if step_metadata:
        self.metadata.update(step_metadata)

    self.extra_dependencies = set(step_extra_dependencies) if step_extra_dependencies else set()
@property
def class_name(self) -> str:
    """The name of this step's concrete class."""
    return type(self).__name__
@classmethod
def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Override this method in your step if you want to change the step's arguments before they are passed to the
    :meth:`run()` method.

    This can be useful if you want to normalize arguments that are passed to your step. For example,
    you might not care about the case of a string that's passed in. You can lowercase the string in this
    method, and the step will function as if it had been created with a lowercase string from the start.
    This way you can make sure that the step's unique ID does not change when the case of the input changes.

    .. note::
        When the input to a step is another step, this method will see the step in the input, not the other
        step's result.

    .. warning::
        This is an advanced feature of Tango that you won't need most of the time.

    By default, this method does nothing and just returns its input unchanged.

    :param kwargs: The original kwargs that were passed to the step during construction.
    :return: New kwargs that will be passed to the step's :meth:`run()` method.
    """
    return kwargs
@property
def logger(self) -> logging.Logger:
    """
    A :class:`logging.Logger` named after this step's class, for use inside :meth:`run()`.
    """
    return logging.getLogger(type(self).__name__)
@classmethod
def from_params(  # type: ignore[override]
    cls: Type["Step"],
    params: Union[Params, dict, str],
    constructor_to_call: Optional[Callable[..., "Step"]] = None,
    constructor_to_inspect: Optional[
        Union[Callable[..., "Step"], Callable[["Step"], None]]
    ] = None,
    step_name: Optional[str] = None,
    **extras,
) -> "Step":
    """
    Construct a :class:`Step` from a ``Params`` object (or a plain ``dict``, or just
    the registered type name as a string).

    :param params: The configuration for the step; its ``"type"`` key selects the
        registered :class:`Step` subclass.
    :param constructor_to_call: Not supported for steps; must be ``None``.
    :param constructor_to_inspect: Not supported for steps; must be ``None``.
    :param step_name: The name to give the constructed step.
    :param extras: Extra objects made available when constructing arguments.
    :raises ConfigurationError: If the params are malformed or name a non-``Step`` class.
    """
    # Why do we need a custom from_params? Step classes have a run() method that takes all the
    # parameters necessary to perform the step. The __init__() method of the step takes those
    # same parameters, but each of them could be wrapped in another Step instead of being
    # supplied directly. from_params() doesn't know anything about these shenanigans, so
    # we have to supply the necessary logic here.
    if constructor_to_call is not None:
        raise ConfigurationError(
            f"{cls.__name__}.from_params cannot be called with a constructor_to_call."
        )
    if constructor_to_inspect is not None:
        raise ConfigurationError(
            f"{cls.__name__}.from_params cannot be called with a constructor_to_inspect."
        )

    if isinstance(params, str):
        params = Params({"type": params})

    if not isinstance(params, Params):
        if isinstance(params, dict):
            params = Params(params)
        else:
            raise ConfigurationError(
                "from_params was passed a ``params`` object that was not a ``Params``. This probably "
                "indicates malformed parameters in a configuration file, where something that "
                "should have been a dictionary was actually a list, or something else. "
                f"This happened when constructing an object of type {cls}."
            )

    # Build up a raw step config
    def replace_steps_with_refs(o: Any) -> Any:
        # Recursively replaces Step instances with {"type": "ref", ...} dicts so
        # the config can be stored as plain data.
        if isinstance(o, (list, tuple, set)):
            return o.__class__(replace_steps_with_refs(i) for i in o)
        elif isinstance(o, (dict, Params)):
            result = {key: replace_steps_with_refs(value) for key, value in o.items()}
            if isinstance(o, dict):
                return result
            elif isinstance(o, Params):
                return Params(result, history=o.history)
        elif isinstance(o, Step):
            return {"type": "ref", "ref": o.name}
        else:
            return deepcopy(o)

    raw_step_config = replace_steps_with_refs(params.as_dict(quiet=True))

    as_registrable = cast(Type[Registrable], cls)
    if "type" in params and params["type"] not in as_registrable.list_available():
        # The requested type isn't registered yet; try importing modules that
        # might register it.
        as_registrable.search_modules(params["type"])
    choice = params.pop_choice(
        "type", choices=as_registrable.list_available(), default_to_first_choice=False
    )
    subclass, constructor_name = as_registrable.resolve_class_name(choice)
    if not issubclass(subclass, Step):
        # This can happen if `choice` is a fully qualified name.
        raise ConfigurationError(
            f"Tried to make a Step of type {choice}, but ended up with a {subclass}."
        )

    if issubclass(subclass, FunctionalStep):
        # Functional steps take their parameters from the wrapped function, not run().
        parameters = infer_method_params(subclass, subclass.WRAPPED_FUNC, infer_kwargs=False)
        if subclass.BIND:
            if "self" not in parameters:
                raise ConfigurationError(
                    f"Functional step for {subclass.WRAPPED_FUNC} is bound but is missing argument 'self'"
                )
            else:
                del parameters["self"]
    else:
        parameters = infer_method_params(subclass, subclass.run, infer_kwargs=False)
        del parameters["self"]

    # The constructor's own parameters (step_name, cache_results, ...) are reserved
    # and may not collide with run()'s parameters.
    init_parameters = infer_constructor_params(subclass)
    del init_parameters["self"]
    del init_parameters["kwargs"]
    parameter_overlap = parameters.keys() & init_parameters.keys()
    assert len(parameter_overlap) <= 0, (
        f"If this assert fails it means that you wrote a Step with a run() method that takes one of the "
        f"reserved parameters ({', '.join(init_parameters.keys())})"
    )
    parameters.update(init_parameters)

    kwargs: Dict[str, Any] = {}
    accepts_kwargs = False
    for param_name, param in parameters.items():
        if param.kind == param.VAR_KEYWORD:
            # When a class takes **kwargs we store the fact that the method allows extra keys; if
            # we get extra parameters, instead of crashing, we'll just pass them as-is to the
            # constructor, and hope that you know what you're doing.
            accepts_kwargs = True
            continue

        explicitly_set = param_name in params
        constructed_arg = pop_and_construct_arg(
            subclass.__name__, param_name, param.annotation, param.default, params, extras
        )

        # If the param wasn't explicitly set in `params` and we just ended up constructing
        # the default value for the parameter, we can just omit it.
        # Leaving it in can cause issues with **kwargs in some corner cases, where you might end up
        # with multiple values for a single parameter (e.g., the default value gives you lazy=False
        # for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes
        # lazy=True - the superclass sees both lazy=True and lazy=False in its constructor).
        if explicitly_set or constructed_arg is not param.default:
            kwargs[param_name] = constructed_arg

    if accepts_kwargs:
        kwargs.update(params)
    else:
        params.assert_empty(subclass.__name__)

    return subclass(step_name=step_name, step_config=raw_step_config, **kwargs)
@abstractmethod
def run(self, **kwargs) -> T:
    """
    Execute the step's action.

    This method needs to be implemented when creating a ``Step`` subclass, but
    it shouldn't be called directly. Instead, call :meth:`result()`.

    :return: The step's output, which may be cached (serialized with the step's
        format) depending on the step's settings.
    """
    raise NotImplementedError()
def _run_with_work_dir(self, workspace: "Workspace", needed_by: Optional["Step"] = None) -> T:
    """
    Run the step inside a working directory provided by ``workspace``, notifying the
    workspace of start/finish/failure, and return the result.

    :param workspace: The workspace that provides the working directory and receives
        lifecycle notifications.
    :param needed_by: The step (if any) that requested this one; used only for logging.
    :raises RuntimeError: If this step is already running.
    """
    if self.work_dir_for_run is not None:
        raise RuntimeError("You can only run a Step's run() method once at a time.")

    if self.DETERMINISTIC:
        # Seed the global RNG so steps declared deterministic that use `random`
        # actually produce reproducible results.
        random.seed(784507111)

    self._workspace = workspace
    if self.cache_results:
        # Cached steps get a durable work dir from the workspace; uncacheable
        # steps get a temp dir that is cleaned up afterwards.
        self.work_dir_for_run = workspace.work_dir(self)
        dir_for_cleanup = None
    else:
        dir_for_cleanup = TemporaryDirectory(prefix=f"{self.unique_id}-", suffix=".step_dir")
        self.work_dir_for_run = Path(dir_for_cleanup.name)

    try:
        # Make sure all dependencies have produced their results before we start.
        self._replace_steps_with_results(self.extra_dependencies, workspace)
        kwargs = self._replace_steps_with_results(self.kwargs, workspace)
        self.log_starting(needed_by=needed_by)
        workspace.step_starting(self)
        try:
            result = self.run(**kwargs)
            # The workspace may replace the result (e.g. after caching it).
            result = workspace.step_finished(self, result)
        except BaseException as e:
            self.log_failure(e)
            workspace.step_failed(self, e)
            raise
        self.log_finished()
        return result
    finally:
        # Always reset transient state, even on failure.
        self._workspace = None
        self.work_dir_for_run = None
        if dir_for_cleanup is not None:
            dir_for_cleanup.cleanup()
@property
def work_dir(self) -> Path:
    """
    The working directory a step may use while its :meth:`run()` method executes.

    This is a convenience property meant to be called from inside :meth:`run()`.
    The directory survives restarts: it is not guaranteed to be empty when the
    step runs, so it can hold information that helps a step resume after being
    killed part-way through.

    :raises RuntimeError: If the step is not currently running.
    """
    if self.work_dir_for_run is not None:
        return self.work_dir_for_run
    raise RuntimeError(
        "You can only call this method while the step is running with a working directory. "
        "Did you call '.run()' directly? You should only run a step with '.result()'."
    )
@property
def workspace(self) -> "Workspace":
    """
    The :class:`~tango.workspace.Workspace` the step is currently running in.

    This is a convenience property meant to be called from inside :meth:`run()`.

    :raises RuntimeError: If the step is not currently running.
    """
    current = self._workspace
    if current is None:
        raise RuntimeError(
            "You can only call this method while the step is running with a workspace. "
            "Did you call '.run()' directly? You should only run a step with '.result()'."
        )
    return current
@property
def config(self) -> Dict[str, Any]:
    """
    The configuration parameters this step was constructed from.

    :raises ValueError: If the step was not constructed from a configuration file
        (i.e. no config was assigned).
    """
    if self._config is not None:
        return self._config
    raise ValueError(f"No config has been assigned to this step! ('{self.name}')")
def det_hash_object(self) -> Any:
    # For deterministic hashing, a step is represented by its unique ID.
    return self.unique_id
@property
def resources(self) -> StepResources:
    """
    The minimum compute resources required to run this step; some executors use
    this to allocate resources per step.

    Set it with the ``step_resources`` argument to :class:`Step`, or override this
    property to compute the requirements automatically. Falls back to a default
    :class:`StepResources` when nothing was supplied.
    """
    return self.step_resources if self.step_resources else StepResources()
@property
def unique_id(self) -> str:
    """Returns the unique ID for this step.

    Unique IDs are of the shape ``$class_name-$version-$hash``, where the hash is the hash of the
    inputs for deterministic steps, and a random string of characters for non-deterministic ones.
    """
    # Computed lazily and memoized; the cache may also be pre-seeded via the
    # constructor's step_unique_id_override argument.
    if self.unique_id_cache is None:
        self.unique_id_cache = self.class_name
        if self.VERSION is not None:
            self.unique_id_cache += "-"
            self.unique_id_cache += self.VERSION

        self.unique_id_cache += "-"
        if self.DETERMINISTIC:
            # Arguments listed in SKIP_ID_ARGUMENTS, or equal to their entry in
            # SKIP_DEFAULT_ARGUMENTS, are excluded from the hash so changing them
            # does not invalidate the cache.
            hash_kwargs = {
                key: value
                for key, value in self.kwargs.items()
                if (key not in self.SKIP_ID_ARGUMENTS)
                and (
                    (
                        key not in self.SKIP_DEFAULT_ARGUMENTS
                        or self.SKIP_DEFAULT_ARGUMENTS[key] != value
                    )
                )
            }
            # The format's identity and version participate in the hash, since a
            # different serialization produces a different cache entry.
            self.unique_id_cache += det_hash(
                (
                    (self.format.__class__.__module__, self.format.__class__.__qualname__),
                    self.format.VERSION,
                    hash_kwargs,
                )
            )[:32]
        else:
            # Non-deterministic steps get a random suffix drawn from
            # _random_for_step_names instead of an input hash.
            self.unique_id_cache += det_hash(
                _random_for_step_names.getrandbits((58**32).bit_length())
            )[:32]
        if self._UNIQUE_ID_SUFFIX is not None:
            self.unique_id_cache += f"-{self._UNIQUE_ID_SUFFIX}"

    return self.unique_id_cache
def __str__(self):
    # A step's string representation is its unique ID.
    return self.unique_id
def __hash__(self):
    """
    A step's hash is just its unique ID.

    This is consistent with :meth:`__eq__`, which also compares unique IDs.
    """
    return hash(self.unique_id)
def __eq__(self, other):
    """
    Two steps are considered identical if and only if they share the same unique ID.
    Anything that is not a :class:`Step` compares unequal.
    """
    return isinstance(other, Step) and self.unique_id == other.unique_id
def _replace_steps_with_results(self, o: Any, workspace: "Workspace"):
    # Recursively walks `o` and replaces every step (or step-like wrapper) with
    # its computed result, resolving through `workspace`. Branch order matters:
    # the wrapper types must be checked before the generic container types.
    if isinstance(o, (Step, StepIndexer)):
        return o.result(workspace=workspace, needed_by=self)
    elif isinstance(o, Lazy):
        # Steps can hide inside a Lazy's params and constructor extras; resolve
        # them there while keeping the object lazy.
        return Lazy(
            o._constructor,
            params=Params(
                self._replace_steps_with_results(o._params.as_dict(quiet=True), workspace)
            ),
            constructor_extras=self._replace_steps_with_results(
                o._constructor_extras, workspace
            ),
        )
    elif isinstance(o, WithUnresolvedSteps):
        return o.construct(workspace)
    elif isinstance(o, (list, tuple, set)):
        return o.__class__(self._replace_steps_with_results(i, workspace) for i in o)
    elif isinstance(o, dict):
        return {
            key: self._replace_steps_with_results(value, workspace) for key, value in o.items()
        }
    else:
        return o
def result(
    self, workspace: Optional["Workspace"] = None, needed_by: Optional["Step"] = None
) -> T:
    """Returns the result of this step. If the results are cached, it returns those. Otherwise it
    runs the step and returns the result from there.

    If necessary, this method will first produce the results of all steps it depends on.

    :param workspace: The workspace to resolve and run in; defaults to the default workspace.
    :param needed_by: The step (if any) that requested this result; used only for logging.
    """
    if workspace is None:
        from tango.workspaces import default_workspace

        workspace = default_workspace

    from tango.step_info import StepState

    if not self.cache_results or self not in workspace.step_cache:
        # Try running the step. It might get completed by a different tango process
        # if there is a race, so we catch "StepStateError" and check if it's "COMPLETED"
        # at that point.
        try:
            return self._run_with_work_dir(workspace, needed_by=needed_by)
        except StepStateError as exc:
            if exc.step_state != StepState.COMPLETED or not self.cache_results:
                raise
            elif self not in workspace.step_cache:
                raise StepStateError(
                    self, exc.step_state, "because it's not found in the cache"
                )
            else:
                # Step has been completed (and cached) by a different process, so we're done.
                pass

    # Either the result was cached all along, or the race above was resolved in
    # our favor; read it from the cache.
    self.log_cache_hit(needed_by=needed_by)
    return workspace.step_cache[self]
def ensure_result(
    self,
    workspace: Optional["Workspace"] = None,
) -> None:
    """Make sure the result of this step is present in the cache, computing it if
    necessary. Unlike :meth:`result()`, nothing is returned.

    :param workspace: The workspace to use; defaults to the default workspace.
    :raises RuntimeError: If this step does not cache its results.
    """
    if not self.cache_results:
        raise RuntimeError(
            "It does not make sense to call ensure_result() on a step that's not cacheable."
        )

    if workspace is None:
        from tango.workspaces import default_workspace

        workspace = default_workspace

    if self in workspace.step_cache:
        self.log_cache_hit()
        return
    self.result(workspace)
def _ordered_dependencies(self) -> Iterable["Step"]:
    # Yields this step's direct dependencies (possibly with duplicates) by
    # scanning the extra dependencies and then the run() arguments.
    def dependencies_internal(o: Any) -> Iterable[Step]:
        # Branch order matters: the step-like wrapper types must be handled
        # before the generic container/Iterable branches.
        if isinstance(o, Step):
            yield o
        elif isinstance(o, Lazy):
            yield from dependencies_internal(o._params.as_dict(quiet=True))
        elif isinstance(o, WithUnresolvedSteps):
            yield from dependencies_internal(o.args)
            yield from dependencies_internal(o.kwargs)
        elif isinstance(o, StepIndexer):
            yield o.step
        elif isinstance(o, str):
            return  # Confusingly, str is an Iterable of itself, resulting in infinite recursion.
        elif isinstance(o, (dict, Params)):
            yield from dependencies_internal(o.values())
        elif isinstance(o, Iterable):
            yield from itertools.chain(*(dependencies_internal(i) for i in o))
        else:
            return

    yield from self.extra_dependencies
    yield from dependencies_internal(self.kwargs.values())
@property
def dependencies(self) -> Set["Step"]:
    """
    Returns a set of steps that this step depends on. This does not return recursive dependencies.
    """
    # De-duplicates the ordered dependency stream; ordering is irrelevant here.
    return set(self._ordered_dependencies())
@property
def recursive_dependencies(self) -> Set["Step"]:
    """
    All steps this step depends on, directly or transitively, gathered with an
    iterative traversal of the dependency graph.
    """
    visited: Set["Step"] = set()
    frontier = list(self.dependencies)
    while frontier:
        current = frontier.pop()
        if current in visited:
            continue
        visited.add(current)
        frontier.extend(current.dependencies)
    return visited
def log_cache_hit(self, needed_by: Optional["Step"] = None) -> None:
    """Log that this step's output was found in the cache, mentioning the
    requesting step when one is given."""
    if needed_by is None:
        cli_logger.info(
            '[green]\N{check mark} Found output for step [bold]"%s"[/] in cache...[/]',
            self.name,
        )
    else:
        cli_logger.info(
            '[green]\N{check mark} Found output for step [bold]"%s"[/bold] in cache '
            '(needed by "%s")...[/green]',
            self.name,
            needed_by.name,
        )
def log_starting(self, needed_by: Optional["Step"] = None) -> None:
    """Log that this step is starting, mentioning the requesting step when one is given."""
    if needed_by is None:
        cli_logger.info(
            '[blue]\N{black circle} Starting step [bold]"%s"[/]...[/]',
            self.name,
        )
    else:
        cli_logger.info(
            '[blue]\N{black circle} Starting step [bold]"%s"[/] (needed by "%s")...[/]',
            self.name,
            needed_by.name,
        )
def log_finished(self, run_name: Optional[str] = None) -> None:
    """Log that this step has finished, mentioning the run name when one is given."""
    if run_name is None:
        cli_logger.info(
            '[green]\N{check mark} Finished step [bold]"%s"[/][/]',
            self.name,
        )
    else:
        cli_logger.info(
            '[green]\N{check mark} Finished run for step [bold]"%s"[/] (%s)[/]',
            self.name,
            run_name,
        )
def log_failure(self, exception: Optional[BaseException] = None) -> None:
    # Log the traceback first (when an exception is given), then the failure banner.
    if exception is not None:
        log_exception(exception, logger=self.logger)
    cli_logger.error('[red]\N{ballot x} Step [bold]"%s"[/] failed[/]', self.name)
class FunctionalStep(Step):
    # The plain function this step wraps (set by the ``@step`` decorator).
    WRAPPED_FUNC: ClassVar[Callable]
    # If True, the wrapped function is called like an instance method, receiving
    # the step instance as its first argument.
    BIND: ClassVar[bool] = False

    @property
    def class_name(self) -> str:
        # Report the wrapped function's name instead of the generated class name,
        # so the unique ID reflects the function the user wrote.
        return self.WRAPPED_FUNC.__name__

    def run(self, *args, **kwargs):
        if self.BIND:
            # Accessing WRAPPED_FUNC through the instance triggers the function's
            # descriptor protocol, binding it like a method so it receives this
            # step as ``self``.
            return self.WRAPPED_FUNC(*args, **kwargs)
        else:
            # Accessing it through the class retrieves the plain, unbound function.
            return self.__class__.WRAPPED_FUNC(*args, **kwargs)
def step(
    name: Optional[str] = None,
    *,
    exist_ok: bool = False,
    bind: bool = False,
    deterministic: bool = True,
    cacheable: Optional[bool] = None,
    version: Optional[str] = None,
    format: Format = DillFormat("gz"),
    skip_id_arguments: Optional[Set[str]] = None,
    skip_default_arguments: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
):
    """
    A decorator to create a :class:`Step` from a function.

    :param name: A name to register the step under. By default the name of the function is used.
    :param exist_ok:
        If True, overwrites any existing step registered under the same ``name``. Else,
        throws an error if a step is already registered under ``name``.
    :param bind: If ``True``, the first argument passed to the step function will
        be the underlying :class:`Step` instance, i.e. the function will be called as an instance method.
        In this case you must name the first argument 'self' or you will get a
        :class:`~tango.common.exceptions.ConfigurationError` when instantiating the class.
    :param skip_default_arguments: Default argument values that should not invalidate
        the cache, used as the generated step's ``SKIP_DEFAULT_ARGUMENTS``.

    See the :class:`Step` class for an explanation of the other parameters.

    Example
    -------

    .. testcode::

        from tango import step

        @step(version="001")
        def add(a: int, b: int) -> int:
            return a + b

        @step(bind=True)
        def bound_step(self) -> None:
            assert self.work_dir.is_dir()

    """

    def step_wrapper(step_func):
        # Build and register a FunctionalStep subclass whose class-level settings
        # come from the decorator arguments.
        @Step.register(name or step_func.__name__, exist_ok=exist_ok)
        class WrapperStep(FunctionalStep):
            DETERMINISTIC = deterministic
            CACHEABLE = cacheable
            VERSION = version
            FORMAT = format
            SKIP_ID_ARGUMENTS = skip_id_arguments or set()
            SKIP_DEFAULT_ARGUMENTS = skip_default_arguments or {}
            METADATA = metadata or {}
            WRAPPED_FUNC = step_func
            BIND = bind

        return WrapperStep

    return step_wrapper
class StepIndexer(CustomDetHash):
    """
    Wraps a :class:`Step` together with a key, so that a single indexable entry of
    the step's result can serve as the input to another step.
    """

    def __init__(self, step: Step, key: Union[str, int]):
        self.step = step
        self.key = key

    def result(
        self, workspace: Optional["Workspace"] = None, needed_by: Optional["Step"] = None
    ) -> Any:
        """Resolve the wrapped step, then index into its result with the stored key."""
        full_result = self.step.result(workspace=workspace, needed_by=needed_by)
        return full_result[self.key]

    def det_hash_object(self) -> Any:
        """Hash on the wrapped step's unique ID together with the key."""
        return (self.step.unique_id, self.key)
class WithUnresolvedSteps(CustomDetHash):
    """
    This is a helper class for some scenarios where steps depend on other steps.

    Let's say we have two steps, :class:`ConsumeDataStep` and :class:`ProduceDataStep`. The easiest way to make
    :class:`ConsumeDataStep` depend on :class:`ProduceDataStep` is to specify ``Produce`` as one of the arguments
    to the step. This works when ``Consume`` takes the output of ``Produce`` directly, or if it takes
    it inside standard Python container, like a list, set, or dictionary.

    But what if the output of :class:`ConsumeDataStep` needs to be added to a complex, custom data
    structure? :class:`WithUnresolvedSteps` takes care of this scenario.

    For example, this works without any help:

    .. code-block:: Python

        class ProduceDataStep(Step[MyDataClass]):
            def run(self, ...) -> MyDataClass
                ...
                return MyDataClass(...)

        class ConsumeDataStep(Step):
            def run(self, input_data: MyDataClass):
                ...

        produce = ProduceDataStep()
        consume = ConsumeDataStep(input_data = produce)

    This scenario needs help:

    .. code-block:: Python

        @dataclass
        class DataWithTimestamp:
            data: MyDataClass
            timestamp: float

        class ProduceDataStep(Step[MyDataClass]):
            def run(self, ...) -> MyDataClass
                ...
                return MyDataClass(...)

        class ConsumeDataStep(Step):
            def run(self, input_data: DataWithTimestamp):
                ...

        produce = ProduceDataStep()
        consume = ConsumeDataStep(
            input_data = DataWithTimestamp(produce, time.now())
        )

    That does not work, because :class:`DataWithTimestamp` needs an object of type :class:`MyDataClass`, but we're
    giving it an object of type :class:`Step[MyDataClass]`. Instead, we change the last line to this:

    .. code-block:: Python

        consume = ConsumeDataStep(
            input_data = WithUnresolvedSteps(
                DataWithTimestamp, produce, time.now()
            )
        )

    :class:`WithUnresolvedSteps` will delay calling the constructor of ``DataWithTimestamp`` until
    the :meth:`run()` method runs. Tango will make sure that the results from the ``produce`` step
    are available at that time, and replaces the step in the arguments with the step's results.

    :param function: The function to call after resolving steps to their results.
    :param args: The args to pass to the function. These may contain steps, which will be resolved before the
        function is called.
    :param kwargs: The kwargs to pass to the function. These may contain steps, which will be resolved before the
        function is called.
    """

    def __init__(self, function: Callable, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs

    @classmethod
    def with_resolved_steps(
        cls,
        o: Any,
        workspace: "Workspace",
    ):
        """
        Recursively goes through a Python object and replaces all instances of :class:`.Step` with the results of
        that step.

        :param o: The Python object to go through
        :param workspace: The workspace in which to resolve all steps
        :return: A new object that's a copy of the original object, with all instances of :class:`.Step` replaced
            with the results of the step.
        """
        # Branch order matters: step-like wrappers before generic containers.
        if isinstance(o, (Step, StepIndexer)):
            return o.result(workspace=workspace)
        elif isinstance(o, Lazy):
            # Steps can hide inside a Lazy's params and extras; resolve them there too.
            return Lazy(
                o._constructor,
                params=Params(cls.with_resolved_steps(o._params.as_dict(quiet=True), workspace)),
                constructor_extras=cls.with_resolved_steps(o._constructor_extras, workspace),
            )
        elif isinstance(o, cls):
            # A nested WithUnresolvedSteps: resolve its args and call its function.
            return o.construct(workspace)
        elif isinstance(o, (dict, Params)):
            return o.__class__(
                {key: cls.with_resolved_steps(value, workspace) for key, value in o.items()}
            )
        elif isinstance(o, (list, tuple, set)):
            return o.__class__(cls.with_resolved_steps(item, workspace) for item in o)
        else:
            return o

    def construct(self, workspace: "Workspace"):
        """
        Replaces all steps in the args that are stored in this object, and calls the function with those args.

        :param workspace: The :class:`.Workspace` in which to resolve all the steps.
        :return: The result of calling the function.
        """
        resolved_args = self.with_resolved_steps(self.args, workspace)
        resolved_kwargs = self.with_resolved_steps(self.kwargs, workspace)
        return self.function(*resolved_args, **resolved_kwargs)

    def det_hash_object(self) -> Any:
        # Hash on the function's qualified name plus the (unresolved) arguments,
        # so the wrapper hashes like the call it will eventually make.
        return self.function.__qualname__, self.args, self.kwargs
================================================
FILE: tango/step_cache.py
================================================
import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, TypeVar, Union
from .common.from_params import FromParams
from .common.registrable import Registrable
from .format import Format
from .step import Step
from .step_info import StepInfo
logger = logging.getLogger(__name__)
T = TypeVar("T")
class StepCache(Registrable):
    """
    A mapping from :class:`~tango.step.Step` instances to the results of those steps.

    Implementations of :class:`StepCache` are generally used internally by
    :class:`~tango.workspace.Workspace` implementations.
    """

    default_implementation = "memory"
    """
    The default implementation is :class:`.MemoryStepCache`.
    """

    def __contains__(self, step: Any) -> bool:
        """This is a generic implementation of ``__contains__``. If you are writing your own
        ``StepCache``, you might want to write a faster one yourself."""
        if not isinstance(step, (Step, StepInfo)):
            return False
        try:
            self[step]
        except KeyError:
            return False
        else:
            return True

    @abstractmethod
    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """Returns the results for the given step."""
        raise NotImplementedError()

    @abstractmethod
    def __setitem__(self, step: Step, value: Any) -> None:
        """Writes the results for the given step. Throws an exception if the step is already cached."""
        raise NotImplementedError()

    @abstractmethod
    def __delitem__(self, step_unique_id: Union[Step, StepInfo]) -> None:
        """Removes a step from step cache"""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        """Returns the number of results saved in this cache."""
        raise NotImplementedError()
@dataclass
class CacheMetadata(FromParams):
    """
    Metadata stored alongside each cached step result (written to
    ``cache-metadata.json`` by :class:`.LocalStepCache`), recording which step
    produced the result and how it was serialized.
    """

    step: str
    """
    The step name.
    """
    format: Format
    """
    The format used to serialize the step's result.
    """
================================================
FILE: tango/step_caches/__init__.py
================================================
"""
Built-in :class:`~tango.step_cache.StepCache` implementations.
"""
from .local_step_cache import LocalStepCache
from .memory_step_cache import MemoryStepCache, default_step_cache
================================================
FILE: tango/step_caches/local_step_cache.py
================================================
import collections
import logging
import os
import shutil
import warnings
import weakref
from pathlib import Path
from typing import Any, MutableMapping, Optional, OrderedDict, Union, cast
from tango.common.aliases import PathOrStr
from tango.common.params import Params
from tango.step import Step
from tango.step_cache import CacheMetadata, StepCache
from tango.step_info import StepInfo
logger = logging.getLogger(__name__)
@StepCache.register("local")
class LocalStepCache(StepCache):
"""
This is a :class:`.StepCache` that stores its results on disk, in the location given in ``dir``.
Every cached step gets a directory under ``dir`` with that step's :attr:`~tango.step.Step.unique_id`.
In that directory we store the results themselves in some format according to the step's
:attr:`~tango.step.Step.FORMAT`, and we also write a ``cache-metadata.json`` file that
stores the :class:`.CacheMetadata`.
The presence of ``cache-metadata.json`` signifies that the cache entry is complete and
has been written successfully.
.. tip::
Registered as :class:`.StepCache` under the name "local".
"""
LRU_CACHE_MAX_SIZE = 8
METADATA_FILE_NAME = "cache-metadata.json"
def __init__(self, dir: PathOrStr):
self.dir = Path(dir)
self.dir.mkdir(parents=True, exist_ok=True)
# We keep an in-memory cache as well so we don't have to de-serialize stuff
# we happen to have in memory already.
self.weak_cache: MutableMapping[str, Any]
# Not all Python objects can be referenced weakly, and even if they can they
# might get removed too quickly, so we also keep an LRU cache.
self.strong_cache: OrderedDict[str, Any]
self._init_mem_caches()
def _init_mem_caches(self):
self.weak_cache = weakref.WeakValueDictionary()
self.strong_cache = collections.OrderedDict()
def __getstate__(self):
"""
We override `__getstate__()` to customize how instances of this class are pickled
since we don't want to persist values in the weak and strong in-memory caches
during pickling. And `WeakValueDictionary` can't be pickled anyway.
"""
return {"dir": self.dir}
def __setstate__(self, state):
for k, v in state.items():
setattr(self, k, v)
self._init_mem_caches()
def _add_to_cache(self, key: str, o: Any) -> None:
if hasattr(o, "__next__"):
# We never cache iterators, because they are mutable, storing their current position.
return
self.strong_cache[key] = o
self.strong_cache.move_to_end(key)
while len(self.strong_cache) > self.LRU_CACHE_MAX_SIZE:
del self.strong_cache[next(iter(self.strong_cache))]
try:
self.weak_cache[key] = o
except TypeError:
pass # Many native Python objects cannot be referenced weakly, and they throw TypeError when you try
def _get_from_cache(self, key: str) -> Optional[Any]:
result = self.strong_cache.get(key)
if result is not None:
self.strong_cache.move_to_end(key)
return result
try:
return self.weak_cache[key]
except KeyError:
return None
def _remove_from_cache(self, key: str) -> None:
# check and remove from strong cache
if key in self.strong_cache:
del self.strong_cache[key]
assert key not in self.strong_cache
# check and remove from weak cache
if key in self.weak_cache:
del self.weak_cache[key]
assert key not in self.weak_cache
def _metadata_path(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:
return self.step_dir(step_or_unique_id) / self.METADATA_FILE_NAME
def __contains__(self, step: object) -> bool:
if (isinstance(step, Step) and step.cache_results) or (
isinstance(step, StepInfo) and step.cacheable
):
key = step.unique_id
if key in self.strong_cache:
return True
if key in self.weak_cache:
return True
return self._metadata_path(
cast(Union[Step, StepInfo], step) # cast is for mypy :/
).exists()
else:
return False
def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
key = step.unique_id
result = self._get_from_cache(key)
if result is None:
if step not in self:
raise KeyError(step)
metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))
result = metadata.format.read(self.step_dir(step))
self._add_to_cache(key, result)
return result
def __setitem__(self, step: Step, value: Any) -> None:
if not step.cache_results:
warnings.warn(
f"Tried to cache step '{step.name}' despite being marked as uncacheable",
UserWarning,
)
return
location = self.step_dir(step)
location.mkdir(parents=True, exist_ok=True)
metadata_location = self._metadata_path(step)
if metadata_location.exists():
raise ValueError(f"{metadata_location} already exists! Will not overwrite.")
temp_metadata_location = metadata_location.with_suffix(".temp")
try:
step.format.write(value, location)
metadata = CacheMetadata(step=step.unique_id, format=step.format)
metadata.to_params().to_file(temp_metadata_location)
self._add_to_cache(step.unique_id, value)
temp_metadata_location.rename(metadata_location)
except: # noqa: E722
try:
temp_metadata_location.unlink()
except FileNotFoundError:
pass
raise
def __delitem__(self, step: Union[Step, StepInfo]) -> None:
    """
    Delete the cached result of ``step`` from disk and evict it from the
    in-memory caches.

    :raises OSError: if the step's cache directory does not exist or cannot
        be removed.
    """
    # Build the path via ``step_dir`` (pathlib) instead of concatenating
    # strings, for consistency with the rest of the class and OS-correct
    # separators. Passing the unique id (a ``str``) avoids ``step_dir``'s
    # cacheability check.
    location = self.step_dir(step.unique_id)
    try:
        shutil.rmtree(location)
        self._remove_from_cache(step.unique_id)
    except OSError as exc:
        # Chain the original error so the underlying cause isn't lost.
        raise OSError(
            f"Step cache folder for '{step.unique_id}' not found. Cannot be deleted."
        ) from exc
def __len__(self) -> int:
    # Count cached steps by their metadata files. A metadata file is only
    # renamed into place after a result has been fully written (see
    # ``__setitem__``), so this reflects the number of completed, cached
    # steps on disk.
    return sum(1 for _ in self.dir.glob(f"*/{self.METADATA_FILE_NAME}"))
def step_dir(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:
    """Returns the directory that contains the results of the step.

    You can use this even for a step that's not cached yet. In that case it will
    return the directory where the results will be written.

    :raises RuntimeError: if a ``Step``/``StepInfo`` is given whose results are
        not cacheable (uncacheable steps have no step directory).
    """
    # A plain string is taken to be the unique id itself.
    if not isinstance(step_or_unique_id, (Step, StepInfo)):
        return self.dir / step_or_unique_id
    is_step = isinstance(step_or_unique_id, Step)
    cacheable = step_or_unique_id.cache_results if is_step else step_or_unique_id.cacheable
    if not cacheable:
        class_name = (
            step_or_unique_id.class_name if is_step else step_or_unique_id.step_class_name
        )
        raise RuntimeError(
            f"Uncacheable steps (like '{class_name}') don't have step directories."
        )
    return self.dir / step_or_unique_id.unique_id
================================================
FILE: tango/step_caches/memory_step_cache.py
================================================
import logging
import warnings
from typing import Any, Dict, Union
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_info import StepInfo
logger = logging.getLogger(__name__)
@StepCache.register("memory")
class MemoryStepCache(StepCache):
    """
    A :class:`.StepCache` that keeps step results in memory — essentially a
    thin wrapper around a Python dictionary keyed by step unique id.

    .. tip::
        Registered as :class:`.StepCache` under the name "memory".
    """

    def __init__(self):
        # Maps a step's unique id to its result object.
        self.cache: Dict[str, Any] = {}

    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        return self.cache[step.unique_id]

    def __setitem__(self, step: Step, value: Any) -> None:
        if step in self:
            raise ValueError(f"{step.unique_id} is already cached! Will not overwrite.")
        if not step.cache_results:
            # Uncacheable steps are silently skipped (with a warning) rather
            # than stored.
            warnings.warn(
                f"Tried to cache step '{step.name}' despite being marked as uncacheable.",
                UserWarning,
            )
            return
        self.cache[step.unique_id] = value

    def __delitem__(self, step: Union[Step, StepInfo]) -> None:
        try:
            del self.cache[step.unique_id]
        except KeyError:
            raise KeyError(f"{step.unique_id} not present in the memory cache. Cannot be deleted.")

    def __contains__(self, step: object) -> bool:
        # Only Step/StepInfo objects can be looked up; anything else is absent.
        return isinstance(step, (Step, StepInfo)) and step.unique_id in self.cache

    def __len__(self) -> int:
        return len(self.cache)


default_step_cache = MemoryStepCache()
================================================
FILE: tango/step_caches/remote_step_cache.py
================================================
import logging
import os
import shutil
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Any, Union
from tango.common.aliases import PathOrStr
from tango.common.exceptions import TangoError
from tango.common.file_lock import FileLock
from tango.common.params import Params
from tango.common.remote_utils import RemoteConstants
from tango.step import Step
from tango.step_cache import CacheMetadata
from tango.step_caches.local_step_cache import LocalStepCache
from tango.step_info import StepInfo
logger = logging.getLogger(__name__)
class RemoteNotFoundError(TangoError):
    """
    Raised by :class:`RemoteStepCache` subclasses (from ``_download_step_remote``)
    when a step result object cannot be found in remote storage.
    """
# This class inherits from `LocalStepCache` to benefit from its in-memory "weak cache" and "strong cache",
# but it handles saving artifacts to disk a little differently.
class RemoteStepCache(LocalStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`RemoteWorkspace`.
    It stores the results of steps on some RemoteWorkspace.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
    step's result on subsequent occasions should be fast.

    .. tip::
        All remote step caches inherit from this.
    """

    # Shared naming conventions (e.g. the name of the result subdirectory) used
    # by both the cache and the corresponding remote workspace.
    Constants = RemoteConstants

    def __init__(self, local_dir: Path):
        # ``local_dir`` is where the on-disk backup copies of results live.
        super().__init__(local_dir)

    @abstractmethod
    def _step_result_remote(self, step: Union[Step, StepInfo]):
        """Return a handle to the step's result in remote storage, or ``None``
        if the result does not exist remotely."""
        raise NotImplementedError()

    @abstractmethod
    def _upload_step_remote(self, step: Step, objects_dir: Path):
        """Upload the serialized result (and metadata) in ``objects_dir`` to
        remote storage for ``step``."""
        raise NotImplementedError()

    @abstractmethod
    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
        """Download ``step_result`` into ``target_dir``. Implementations should
        raise :class:`RemoteNotFoundError` if the result object is missing."""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self):
        """Return the number of step results in the cache."""
        raise NotImplementedError()

    def _acquire_step_lock_file(self, step: Union[Step, StepInfo], read_only_ok: bool = False):
        # A per-step file lock guards the local on-disk copy against concurrent
        # readers/writers (e.g. multiple processes downloading the same result).
        return FileLock(
            self.step_dir(step).with_suffix(".lock"), read_only_ok=read_only_ok
        ).acquire_with_updates(desc=f"acquiring step cache lock for '{step.unique_id}'")

    def __contains__(self, step: Any) -> bool:
        """
        Check memory, then the local disk cache, then remote storage for the
        step's result. Uncacheable steps are never contained.
        """
        if isinstance(step, (Step, StepInfo)):
            cacheable = step.cache_results if isinstance(step, Step) else step.cacheable
            if not cacheable:
                return False
            key = step.unique_id

            # First check if we have a copy in memory.
            if key in self.strong_cache:
                return True
            if key in self.weak_cache:
                return True

            # Then check if we have a copy on disk in our cache directory.
            with self._acquire_step_lock_file(step, read_only_ok=True):
                if self.step_dir(step).is_dir():
                    return True

            # If not, check the remote location.
            return self._step_result_remote(step) is not None
        else:
            return False

    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """
        Fetch the step's result, checking in order: the in-memory caches, the
        local on-disk cache, and finally remote storage (downloading into the
        local cache for next time).

        :raises KeyError: if the result exists in none of those places.
        """
        key = step.unique_id
        step_result = self._step_result_remote(step)
        if step_result is None:
            raise KeyError(step)

        # Try getting the result from our in-memory caches first.
        result = self._get_from_cache(key)
        if result is not None:
            return result

        def load_and_return():
            # Deserialize using the format recorded in the metadata file, then
            # populate the in-memory caches.
            metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))
            result = metadata.format.read(self.step_dir(step) / self.Constants.STEP_RESULT_DIR)
            self._add_to_cache(key, result)
            return result

        # Next check our local on-disk cache.
        with self._acquire_step_lock_file(step, read_only_ok=True):
            if self.step_dir(step).is_dir():
                return load_and_return()

        # Finally, check the remote location for the corresponding dataset.
        with self._acquire_step_lock_file(step):
            # Make sure the step wasn't cached since the last time we checked (above).
            if self.step_dir(step).is_dir():
                return load_and_return()

            # We'll download the dataset to a temporary directory first, in case something goes wrong.
            temp_dir = tempfile.mkdtemp(dir=self.dir, prefix=key)
            try:
                self._download_step_remote(step_result, target_dir=temp_dir)
                # Download and extraction was successful, rename temp directory to final step result directory.
                os.replace(temp_dir, self.step_dir(step))
            except RemoteNotFoundError:
                raise KeyError(step)
            finally:
                # After a successful ``os.replace`` the temp dir no longer
                # exists; ``ignore_errors`` makes this a no-op in that case.
                shutil.rmtree(temp_dir, ignore_errors=True)

            return load_and_return()

    def __setitem__(self, step: Step, value: Any) -> None:
        """
        Serialize ``value`` locally, upload it to remote storage, and keep the
        serialized copy as the local on-disk cache entry.
        """
        if not step.cache_results:
            logger.warning("Tried to cache step %s despite being marked as uncacheable.", step.name)
            return

        with self._acquire_step_lock_file(step):
            # We'll write the step's results to temporary directory first, and try to upload to
            # remote workspace from there in case anything goes wrong.
            temp_dir = Path(tempfile.mkdtemp(dir=self.dir, prefix=step.unique_id))
            (temp_dir / self.Constants.STEP_RESULT_DIR).mkdir()
            try:
                step.format.write(value, temp_dir / self.Constants.STEP_RESULT_DIR)
                metadata = CacheMetadata(step=step.unique_id, format=step.format)
                metadata.to_params().to_file(temp_dir / self.METADATA_FILE_NAME)
                # Create the dataset and upload serialized result to it.
                self._upload_step_remote(step, temp_dir)
                # Upload successful, rename temp directory to the final step result directory.
                if self.step_dir(step).is_dir():
                    shutil.rmtree(self.step_dir(step), ignore_errors=True)
                os.replace(temp_dir, self.step_dir(step))
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

        # Finally, add to in-memory caches.
        self._add_to_cache(step.unique_id, value)
================================================
FILE: tango/step_graph.py
================================================
import logging
from typing import Any, Dict, Iterator, List, Mapping, Set, Type, Union
from tango.common import PathOrStr
from tango.common.exceptions import ConfigurationError
from tango.common.params import Params
from tango.step import Step, StepIndexer
logger = logging.getLogger(__name__)
class StepGraph(Mapping[str, Step]):
    """
    Represents an experiment as a directed graph.

    It can be treated as a :class:`~collections.abc.Mapping` of step names (``str``)
    to :class:`Step`.
    """

    def __init__(self, step_dict: Dict[str, Step]):
        # TODO: What happens with anonymous steps in here?
        is_ordered = self._is_ordered(step_dict)
        if not is_ordered:
            # Topologically re-order the steps. Note that ``ordered_steps`` also
            # assigns each step's ``name`` attribute from its dictionary key.
            self.parsed_steps = {step.name: step for step in self.ordered_steps(step_dict)}
        else:
            self.parsed_steps = {}
            for step_name, step in step_dict.items():
                step.name = step_name
                self.parsed_steps[step_name] = step
        # Sanity-check the graph
        self._sanity_check()

    @classmethod
    def _is_ordered(cls, step_dict: Dict[str, Step]):
        """Return whether ``step_dict`` is already topologically ordered, i.e. every
        step appears after all of its dependencies."""
        present = set()
        for _, step in step_dict.items():
            for dep in step.dependencies:
                if dep.name not in present:
                    return False
            present.add(step.name)
        return True

    @classmethod
    def _check_unsatisfiable_dependencies(cls, dependencies: Dict[str, Set[str]]) -> None:
        """Raise :class:`ConfigurationError` if any referenced dependency is not itself
        a step name in ``dependencies``."""
        # Check whether some of those dependencies can never be satisfied.
        unsatisfiable_dependencies = {
            dep
            for step_deps in dependencies.values()
            for dep in step_deps
            if dep not in dependencies.keys()
        }
        if len(unsatisfiable_dependencies) > 0:
            if len(unsatisfiable_dependencies) == 1:
                dep = next(iter(unsatisfiable_dependencies))
                raise ConfigurationError(
                    f"Specified dependency '{dep}' can't be found in the config."
                )
            else:
                raise ConfigurationError(
                    f"Some dependencies can't be found in the config: {', '.join(unsatisfiable_dependencies)}"
                )

    @classmethod
    def _get_ordered_steps(cls, dependencies: Dict[str, Set[str]]) -> List[str]:
        """Topologically sort step names given their dependency name sets.

        :raises ConfigurationError: when no progress can be made, which indicates
            a circular reference or a missing dependency.
        """
        done: Set[str] = set()
        todo = list(dependencies.keys())
        ordered_steps = list()
        while len(todo) > 0:
            new_todo = []
            for step_name in todo:
                # A step is ready once all of its dependencies are done.
                if len(dependencies[step_name] & done) == len(dependencies[step_name]):
                    done.add(step_name)
                    ordered_steps.append(step_name)
                else:
                    new_todo.append(step_name)
            if len(todo) == len(new_todo):
                # Nothing became ready this round, so we would loop forever.
                raise ConfigurationError(
                    "Could not make progress parsing the steps. "
                    "You probably have a circular reference between the steps, "
                    "Or a missing dependency."
                )
            todo = new_todo
        del dependencies
        del done
        del todo
        return ordered_steps

    def _sanity_check(self) -> None:
        """Warn about cacheable steps that (transitively) depend on a
        non-deterministic step, since their cached results may be misleading."""
        for step in self.parsed_steps.values():
            if step.cache_results:
                nondeterministic_dependencies = [
                    s for s in step.recursive_dependencies if not s.DETERMINISTIC
                ]
                if len(nondeterministic_dependencies) > 0:
                    nd_step = nondeterministic_dependencies[0]
                    logger.warning(
                        f"Task {step.name} is set to cache results, but depends on non-deterministic "
                        f"step {nd_step.name}. This will produce confusing results."
                    )

    @classmethod
    def from_params(cls: Type["StepGraph"], params: Dict[str, Params]) -> "StepGraph":  # type: ignore[override]
        """Construct a :class:`StepGraph` from raw step :class:`Params`, resolving
        ``{"type": "ref", ...}`` references between steps.

        Note: this consumes ``params`` (entries are ``pop``-ed as steps are built).
        """
        # Determine the order in which to create steps so that all dependent steps are available when we need them.
        # This algorithm for resolving step dependencies is O(n^2). Since we're
        # anticipating the number of steps in a single config to be in the dozens at most (#famouslastwords),
        # we choose simplicity over cleverness.
        dependencies = {
            step_name: cls._find_step_dependencies(step_params)
            for step_name, step_params in params.items()
        }
        cls._check_unsatisfiable_dependencies(dependencies)
        # We need ordered dependencies to construct the steps with refs.
        ordered_steps = cls._get_ordered_steps(dependencies)

        # Parse the steps
        step_dict: Dict[str, Step] = {}
        for step_name in ordered_steps:
            step_params = params.pop(step_name)
            if step_name in step_dict:
                raise ConfigurationError(f"Duplicate step name {step_name}")
            # Replace "ref" placeholders with the already-constructed Step objects.
            step_params = cls._replace_step_dependencies(step_params, step_dict)
            step_dict[step_name] = Step.from_params(step_params, step_name=step_name)
        return cls(step_dict)

    def sub_graph(self, *step_names: str) -> "StepGraph":
        """Return a new :class:`StepGraph` containing the given steps plus all of
        their recursive dependencies.

        :raises KeyError: if any of ``step_names`` is not in this graph.
        """
        step_dict: Dict[str, Step] = {}
        for step_name in step_names:
            if step_name not in self.parsed_steps:
                raise KeyError(
                    f"{step_name} is not a part of this StepGraph. "
                    f"Available steps are: {list(self.parsed_steps.keys())}"
                )
            step_dict.update(
                {dep.name: dep for dep in self.parsed_steps[step_name].recursive_dependencies}
            )
            step_dict[step_name] = self.parsed_steps[step_name]
        return StepGraph(step_dict)

    @staticmethod
    def _dict_is_ref(d: Union[dict, Params]) -> bool:
        """Return whether ``d`` is a reference to another step: either ``{"ref": ...}``
        alone or any dict with ``"type": "ref"``."""
        keys = set(d.keys())
        if keys == {"ref"}:
            return True
        if keys >= {"type", "ref"} and d["type"] == "ref":
            return True
        return False

    @classmethod
    def _find_step_dependencies(cls, o: Any) -> Set[str]:
        """Recursively collect the names of all steps referenced (via "ref" dicts)
        anywhere inside ``o``.

        :raises ValueError: on objects that can't appear in a step config.
        """
        dependencies: Set[str] = set()
        if isinstance(o, (list, tuple, set)):
            for item in o:
                dependencies = dependencies | cls._find_step_dependencies(item)
        elif isinstance(o, (dict, Params)):
            if cls._dict_is_ref(o):
                dependencies.add(o["ref"])
            else:
                for value in o.values():
                    dependencies = dependencies | cls._find_step_dependencies(value)
        elif o is not None and not isinstance(o, (bool, str, int, float)):
            raise ValueError(o)
        return dependencies

    @classmethod
    def _replace_step_dependencies(cls, o: Any, existing_steps: Mapping[str, Step]) -> Any:
        """Recursively replace "ref" dicts inside ``o`` with the corresponding
        :class:`Step` objects from ``existing_steps`` (or a :class:`StepIndexer`
        when the ref also carries a ``"key"``). Returns a structure of the same
        shape; containers are rebuilt, not mutated.

        :raises ValueError: on objects that can't appear in a step config.
        """
        if isinstance(o, (list, tuple, set)):
            return o.__class__(cls._replace_step_dependencies(i, existing_steps) for i in o)
        elif isinstance(o, (dict, Params)):
            if cls._dict_is_ref(o):
                if "key" in o:
                    # Reference to one item within a step's result.
                    return StepIndexer(existing_steps[o["ref"]], o["key"])
                else:
                    return existing_steps[o["ref"]]
            else:
                result = {
                    key: cls._replace_step_dependencies(value, existing_steps)
                    for key, value in o.items()
                }
                if isinstance(o, dict):
                    return result
                elif isinstance(o, Params):
                    return Params(result, history=o.history)
                else:
                    raise RuntimeError(f"Object {o} is of unexpected type {o.__class__}.")
        elif o is not None and not isinstance(o, (bool, str, int, float)):
            raise ValueError(o)
        return o

    def __getitem__(self, name: str) -> Step:
        """
        Get the step with the given name.
        """
        return self.parsed_steps[name]

    def __len__(self) -> int:
        """
        The number of steps in the experiment.
        """
        return len(self.parsed_steps)

    def __iter__(self) -> Iterator[str]:
        """
        The names of the steps in the experiment.
        """
        return iter(self.parsed_steps)

    @classmethod
    def ordered_steps(cls, step_dict: Dict[str, Step]) -> List[Step]:
        """
        Returns the steps in this step graph in an order that can be executed one at a time.

        This does not take into account which steps may be cached. It simply returns an executable
        order of steps.

        Note: this assigns each step's ``name`` attribute from its dictionary key.
        """
        dependencies = {
            step_name: set([dep.name for dep in step.dependencies])
            for step_name, step in step_dict.items()
        }
        result: List[Step] = []
        for step_name in cls._get_ordered_steps(dependencies):
            step_dict[step_name].name = step_name
            result.append(step_dict[step_name])
        return result

    def uncacheable_leaf_steps(self) -> Set[Step]:
        """Return the steps that no other step depends on ("leaves") and whose
        results are not cached — i.e. steps whose output would otherwise be lost."""
        interior_steps: Set[Step] = set()
        for _, step in self.parsed_steps.items():
            for dependency in step.dependencies:
                interior_steps.add(dependency)
        uncacheable_leaf_steps = {
            step for step in set(self.values()) - interior_steps if not step.cache_results
        }
        return uncacheable_leaf_steps

    @classmethod
    def from_file(cls, filename: PathOrStr) -> "StepGraph":
        """Load a :class:`StepGraph` from a config file containing a ``"steps"`` mapping."""
        params = Params.from_file(filename)
        return StepGraph.from_params(params.pop("steps", keep_as_dict=True))

    def to_config(self, include_unique_id: bool = False) -> Dict[str, Dict]:
        """Render this graph back into a raw config dict (the inverse of
        :meth:`from_params`), turning inter-step references back into
        ``{"type": "ref", ...}`` dicts."""
        step_dict = {}

        def _to_config(o: Any):
            # Recursively convert Step/StepIndexer objects back into "ref" dicts.
            if isinstance(o, (list, tuple, set)):
                return o.__class__(_to_config(i) for i in o)
            elif isinstance(o, dict):
                return {key: _to_config(value) for key, value in o.items()}
            elif isinstance(o, Step):
                return {"type": "ref", "ref": o.name}
            elif isinstance(o, StepIndexer):
                return {"type": "ref", "ref": o.step.name, "key": o.key}
            elif o is not None and not isinstance(o, (bool, str, int, float)):
                raise ValueError(o)
            return o

        for step_name, step in self.parsed_steps.items():
            try:
                step_dict[step_name] = {
                    key: _to_config(value) for key, value in step.config.items()
                }
            except ValueError:  # step.config throws an error.
                # If the step_graph was not constructed using a config, we attempt to create
                # the config using the step object.
                step_dict[step_name] = {
                    key: _to_config(val) for key, val in step._to_params()["kwargs"].items()
                }
                step_dict[step_name]["type"] = step.__module__ + "." + step.class_name
                # We only add cache_results and format to the config if the values are different from default.
                if step.cache_results != step.CACHEABLE:
                    step_dict[step_name]["cache_results"] = step.cache_results
                if step.format != step.FORMAT:
                    step_dict[step_name]["step_format"] = _to_config(step.format._to_params())
            if include_unique_id:
                step_dict[step_name]["step_unique_id_override"] = step.unique_id
        return step_dict

    def to_file(self, filename: PathOrStr, include_unique_id: bool = False) -> None:
        """
        Note: In normal use cases, `include_unique_id` should always be False.
        We do not want to save the unique id in the config, because we want the
        output to change if we modify other kwargs in the config file. We include
        this flag for `MulticoreExecutor`, to ensure that steps have the same
        unique id in the main process and the created subprocesses.
        """
        step_dict = self.to_config(include_unique_id=include_unique_id)
        params = Params({"steps": step_dict})
        params.to_file(filename)

    def __repr__(self) -> str:
        result = [f'"{name}": {step}' for name, step in self.items()]
        result = ", ".join(result)
        return f"{self.__class__.__name__}({result})"
================================================
FILE: tango/step_info.py
================================================
import getpass
import logging
import os
import platform
import socket
import sys
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
import pytz
from .common.from_params import FromParams
from .common.logging import log_exception
from .common.util import StrEnum, jsonify, local_timezone, replace_steps_with_unique_id
from .step import Step
from .version import VERSION
logger = logging.getLogger(__name__)
def get_pip_packages() -> Optional[List[Tuple[str, str]]]:
    """
    Get the current working set of pip packages. Equivalent to running ``pip freeze``.

    :returns: a sorted list of ``(package_name, version)`` tuples (package names
        lower-cased), or ``None`` if the packages could not be determined.
    """
    # Adapted from the Weights & Biases client library:
    # github.com/wandb/client/blob/a04722575eee72eece7eef0419d0cea20940f9fe/wandb/sdk/internal/meta.py#L56-L72
    try:
        try:
            # Prefer the standard library: ``pkg_resources`` is deprecated and
            # emits a DeprecationWarning on import with recent setuptools.
            from importlib import metadata

            return sorted(
                {
                    (dist.metadata["Name"].lower(), dist.version)
                    for dist in metadata.distributions()
                    if dist.metadata["Name"] is not None
                }
            )
        except ImportError:
            # Fall back for very old Python versions.
            import pkg_resources

            return sorted([(d.key, d.version) for d in iter(pkg_resources.working_set)])
    except Exception as exc:
        logger.error("Error saving pip packages")
        log_exception(exc)
    return None
class StepState(StrEnum):
    """Describes the possible state a step can be in.

    Members are string-valued (via ``StrEnum``), so they compare equal to and
    serialize as their lowercase labels.
    """

    INCOMPLETE = "incomplete"
    """The step has not run yet."""

    RUNNING = "running"
    """The step is running right now."""

    COMPLETED = "completed"
    """The step finished running successfully."""

    FAILED = "failed"
    """The step ran, but failed."""

    UNCACHEABLE = "uncacheable"
    """The step is uncacheable. It will be executed as many times as the results are needed,
    so we don't keep track of the state."""
@dataclass
class GitMetadata(FromParams):
    """Metadata about the git repository (if any) that a step was run from."""

    commit: Optional[str] = None
    """
    The commit SHA of the current repo.
    """

    remote: Optional[str] = None
    """
    The URL of the primary remote.
    """

    @classmethod
    def check_for_repo(cls) -> Optional["GitMetadata"]:
        """
        Gather git metadata for the current working directory.

        :returns: a :class:`GitMetadata` instance, or ``None`` if the current
            directory is not inside a git repository. Missing pieces of
            information (no remote configured, no commits yet) are recorded
            as ``None`` instead of raising.
        """
        from git import InvalidGitRepositoryError, Repo

        try:
            repo = Repo(".")
        except InvalidGitRepositoryError:
            return None
        try:
            # GitPython's ``Repo.remote()`` raises ``ValueError`` when the repo
            # has no remote named "origin" (e.g. a freshly-initialized repo).
            remote_url: Optional[str] = repo.remote().url
        except ValueError:
            remote_url = None
        try:
            # ``Repo.commit()`` raises ``ValueError`` when HEAD doesn't point
            # at a commit yet (empty repository).
            commit: Optional[str] = str(repo.commit())
        except ValueError:
            commit = None
        return cls(commit=commit, remote=remote_url)
@dataclass
class TangoMetadata(FromParams):
    """Metadata about the tango installation itself."""

    version: str = VERSION
    """
    The tango release version.
    """
@dataclass
class PlatformMetadata(FromParams):
    """Metadata about the machine a step was run on, captured at instantiation time."""

    operating_system: str = field(default_factory=platform.platform)
    """
    Full operating system name.
    """

    cpu_count: Optional[int] = field(default_factory=os.cpu_count)
    """
    Numbers of CPUs on the machine.
    """

    user: str = field(default_factory=getpass.getuser)
    """
    The user that ran this step.
    """

    host: str = field(default_factory=socket.gethostname)
    """
    Name of the host machine.
    """
@dataclass
class EnvironmentMetadata(FromParams):
    """Metadata about the Python environment a step was run in, captured at
    instantiation time (including a potentially slow pip-package scan)."""

    python: str = field(default_factory=platform.python_version)
    """
    The Python version.
    """

    executable: Path = field(default_factory=lambda: Path(sys.executable))
    """
    Path to the Python executable.
    """

    command: str = field(default_factory=lambda: " ".join(sys.argv))
    """
    The exact command used.
    """

    root: Path = field(default_factory=lambda: Path(os.getcwd()))
    """
    The root directory from where the Python executable was ran.
    """

    packages: Optional[List[Tuple[str, str]]] = field(default_factory=get_pip_packages)
    """
    The current set of Python packages in the Python environment. Each entry is a tuple of strings.
    The first element is the name of the package, the second element is the version.
    """

    git: Optional[GitMetadata] = field(default_factory=GitMetadata.check_for_repo)
    """
    The :class:`GitMetadata`.
    """

    tango: Optional[TangoMetadata] = field(default_factory=TangoMetadata)
    """
    The :class:`TangoMetadata`.
    """
@dataclass
class StepInfo(FromParams):
    """Stores step information without being the :class:`.Step` itself.

    It's not always possible to get a :class:`.Step` object, because :class:`.Step` objects can't be serialized.
    But you can always serialize a :class:`.StepInfo` object.
    """

    unique_id: str
    """
    The unique ID of the step
    """

    step_class_name: str
    """
    The name of the :class:`.Step` class
    """

    dependencies: Set[str]
    """
    The unique ids of all the steps that this step depends on
    """

    cacheable: bool
    """
    Whether or not the step is cacheable.
    """

    step_name: Optional[str] = None
    """
    The name of the step, if it has one. Anonymous steps are identified only by their unique ID.
    The same step can have different names in different runs. The last run wins, so don't rely
    on this property in your code. It is just here to aid readability.
    """

    version: Optional[str] = None
    """
    The version string of the :class:`.Step`, if it has one.
    """

    start_time: Optional[datetime] = None
    """
    The time (in UTC) that this step started running.

    .. seealso::
        :meth:`start_time_local()`.
    """

    end_time: Optional[datetime] = None
    """
    The time (in UTC) that this step stopped running. This will be set whether the step succeeded or failed.

    .. seealso::
        :meth:`end_time_local()`.
    """

    error: Optional[str] = None
    """
    If the step failed, this is where the error goes.

    .. note::
        Some ``Workspace`` implementations need to serialize ``StepInfo`` (using pickle or dill, for example),
        but some exceptions can't be pickled. In those cases ``error`` will just be a string representation
        of the exception.
    """

    result_location: Optional[str] = None
    """
    Location of the result. This could be a path or a URL.
    """

    config: Optional[Dict[str, Any]] = None
    """
    The raw config of the step.
    """

    metadata: Optional[Dict[str, Any]] = None
    """
    Metadata from the step. This comes from the ``step_metadata``
    argument to the :class:`~tango.step.Step` class.
    """

    platform: PlatformMetadata = field(default_factory=PlatformMetadata)
    """
    The :class:`PlatformMetadata`.
    """

    environment: EnvironmentMetadata = field(default_factory=EnvironmentMetadata)
    """
    The :class:`EnvironmentMetadata`.
    """

    @property
    def start_time_local(self) -> Optional[datetime]:
        """
        The time the step started running with respect to the local timezone, if the timezone
        can be determined.
        """
        return None if self.start_time is None else self.start_time.astimezone(local_timezone())

    @property
    def end_time_local(self) -> Optional[datetime]:
        """
        The time the step stopped running with respect to the local timezone, if the timezone
        can be determined.
        """
        return None if self.end_time is None else self.end_time.astimezone(local_timezone())

    @property
    def duration(self) -> Optional[timedelta]:
        """
        The time it took to run this step.
        """
        if self.start_time is not None and self.end_time is not None:
            return self.end_time - self.start_time
        else:
            return None

    @property
    def state(self) -> StepState:
        """
        Returns the state of the step

        :raises RuntimeError: if a cacheable step's timestamps/error fields form
            a combination not covered below (e.g. an error recorded without an
            end time) — the fall-through ``raise`` at the bottom catches those.
        """
        if self.cacheable:
            # The four recognized combinations of (start_time, end_time, error):
            if self.start_time is None and self.end_time is None and self.error is None:
                return StepState.INCOMPLETE
            if self.start_time is not None and self.end_time is None and self.error is None:
                return StepState.RUNNING
            if self.start_time is not None and self.end_time is not None and self.error is None:
                return StepState.COMPLETED
            if self.start_time is not None and self.end_time is not None and self.error is not None:
                return StepState.FAILED
        else:
            # Uncacheable steps are re-run whenever needed, so no state is tracked.
            return StepState.UNCACHEABLE
        raise RuntimeError(f"{self.__class__.__name__} is in an invalid state.")

    def to_json_dict(self) -> Dict[str, Any]:
        """
        Generates a JSON-safe, human-readable, dictionary representation of this dataclass.
        """
        return jsonify(self)

    @classmethod
    def from_json_dict(cls, json_dict: Dict[str, Any]) -> "StepInfo":
        """
        The inverse of :meth:`to_json_dict()`.

        :param json_dict: A dictionary representation, such as the one produced by :meth:`to_json_dict()`.
        """
        # NOTE(review): timestamps are parsed with a fixed second-resolution
        # format — assumes ``jsonify`` serialized them without microseconds;
        # confirm against ``jsonify``'s datetime handling.
        step_info = cls.from_params(
            {
                k: (
                    datetime.strptime(v, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=pytz.utc)
                    if k in {"start_time", "end_time"} and v is not None
                    else v
                )
                for k, v in json_dict.items()
                if k != "config"
            }
        )
        # ``config`` is attached verbatim, bypassing ``from_params`` so it isn't
        # interpreted as constructor arguments.
        step_info.config = json_dict.get("config")
        return step_info

    @classmethod
    def new_from_step(cls, step: Step, **kwargs) -> "StepInfo":
        """Build a :class:`StepInfo` from a live :class:`.Step` object.

        Extra ``kwargs`` are forwarded to the constructor (e.g. timestamps).
        """
        try:
            config = step.config
        except ValueError:
            # Steps not constructed from a config have no raw config to record.
            config = None
        return cls(
            unique_id=step.unique_id,
            step_name=step.name,
            step_class_name=step.class_name,
            version=step.VERSION,
            dependencies={dep.unique_id for dep in step.dependencies},
            cacheable=step.cache_results,
            config=replace_steps_with_unique_id(config),
            metadata=step.metadata,
            **kwargs,
        )

    def refresh(self):
        """
        Refresh environment and platform metadata.
        """
        self.platform = PlatformMetadata()
        self.environment = EnvironmentMetadata()
================================================
FILE: tango/steps/__init__.py
================================================
"""
Built-in :class:`~tango.step.Step` implementations that are not tied to any particular
integration.
"""
__all__ = ["DatasetCombineStep", "DatasetRemixStep", "PrintStep", "ShellStep"]
from .dataset_remix import DatasetCombineStep, DatasetRemixStep
from .print import PrintStep
from .shell_step import ShellStep
================================================
FILE: tango/steps/dataset_remix.py
================================================
import collections
import random
import re
from typing import Any, Dict, List, Mapping, Sequence
from tango.common.dataset_dict import DatasetDict
from tango.common.sequences import (
ConcatenatedSequence,
ShuffledSequence,
SlicedSequence,
)
from tango.step import Step
@Step.register("dataset_remix")
class DatasetRemixStep(Step[DatasetDict]):
    """
    This step can remix splits in a :class:`~tango.common.dataset_dict.DatasetDict` into new splits.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "dataset_remix".

    Examples
    --------

    .. testcode::
        :hide:

        from tango.common.logging import initialize_logging
        initialize_logging(enable_cli_logs=True)

    .. testcode::

        input = DatasetDict({
            "train": list(range(10)),
            "dev": list(range(10, 15)),
        })
        new_splits = {
            "all": "train + dev",
            "crossval_train": "train[0:5] + train[7:]",
            "crossval_test": "train[5:7]",
        }
        remix_step = DatasetRemixStep(input=input, new_splits=new_splits)
        remixed_dataset = remix_step.result()

    .. testoutput::
        :hide:
        :options: +ELLIPSIS

        ...
    """

    DETERMINISTIC = True
    CACHEABLE = False  # This is so fast it's not worth caching.
    VERSION = "001"

    def run(  # type: ignore
        self,
        input: DatasetDict,
        new_splits: Dict[str, str],
        keep_old_splits: bool = True,
        shuffle_before: bool = False,
        shuffle_after: bool = False,
        random_seed: int = 1532637578,
    ) -> DatasetDict:
        """
        Remixes and shuffles a dataset. This is done lazily, so all operations are fast.

        :param input:
            The input dataset that will be remixed.
        :param new_splits:
            Specifies the new splits that the output dataset should have. Keys are the name of the new
            splits. Values refer to the original splits. You can refer to original splits in the following ways:

            * Mention the original split name to copy it to a new name.
            * Mention the original split name with Python's slicing syntax to select part of the original
              split's instances. For example, ``"train[:1000]"`` selects the first 1000 instances from the
              ``"train"`` split.
            * ``"instances + instances"`` concatenates the instances into one split.

            You can combine these possibilities.
        :param keep_old_splits:
            Whether to keep the splits from the input dataset in addition to the new ones given by
            ``new_splits``.
        :param shuffle_before:
            Whether to shuffle the input splits before creating the new ones.
            If you need shuffled instances and you're not sure the input is properly shuffled, use this.
        :param shuffle_after:
            Whether to shuffle the input splits after creating the new ones.
            If you need shuffled instances and you're slicing or concatenating splits, use this.
            If you want to be on the safe side, shuffle both before and after. Shuffling is a cheap operation.
        :param random_seed:
            Random seed, affects shuffling
        :returns:
            Returns a new dataset that is appropriately remixed.
        """
        random.seed(random_seed)

        if shuffle_before:
            input_splits: Mapping[str, Sequence[Any]] = {
                split_name: ShuffledSequence(split_instances)
                for split_name, split_instances in input.splits.items()
            }
        else:
            input_splits = input.splits

        def get_slice(split_name: str) -> Sequence[Any]:
            slice_match = re.match(r"(.*)\[(-?[0-9]*:-?[0-9]*)\]", split_name)
            if slice_match is None:
                # Fix: read from ``input_splits`` (which honors ``shuffle_before``)
                # instead of the raw ``input``, which silently ignored it.
                return input_splits[split_name]
            else:
                split_name = slice_match[1]
                slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(":")]
                return SlicedSequence(input_splits[split_name], slice(*slice_args))

        def parse_split_spec(split_spec: str):
            # A spec is one or more split references joined with "+".
            parts = [get_slice(name.strip()) for name in split_spec.split("+")]
            if len(parts) == 1:
                return parts[0]
            else:
                return ConcatenatedSequence(*parts)

        if keep_old_splits:
            result = dict(input_splits.items())
        else:
            result = {}
        result.update(
            {
                new_split_name: parse_split_spec(new_split_spec)
                for new_split_name, new_split_spec in new_splits.items()
            }
        )

        if shuffle_after:
            result = {
                split_name: ShuffledSequence(split_instances)
                for split_name, split_instances in result.items()
            }

        return DatasetDict(splits=result, metadata=input.metadata)
@Step.register("dataset_combine")
class DatasetCombineStep(Step[DatasetDict]):
    """
    This step combines multiple :class:`~tango.common.dataset_dict.DatasetDict` s into one.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "dataset_combine".

    Examples
    --------

    .. testcode::
        :hide:

        from tango.common.logging import initialize_logging
        initialize_logging(enable_cli_logs=True)

    .. testcode::

        input1 = DatasetDict({
            "train": list(range(10)),
            "dev": list(range(10, 15)),
        })
        input2 = DatasetDict({
            "train": list(range(15, 25)),
            "val": list(range(25, 30)),
        })

        combined = DatasetCombineStep(inputs=[input1, input2])
        combined_dataset = combined.result()

    .. testoutput::
        :hide:
        :options: +ELLIPSIS

        ...
    """

    DETERMINISTIC = True
    CACHEABLE = False  # This is so fast it's not worth caching.
    VERSION = "001"

    def run(  # type: ignore
        self,
        inputs: List[DatasetDict],
        shuffle: bool = False,
        random_seed: int = 1532637578,
    ) -> DatasetDict:
        """
        Combines multiple datasets into one. This is done lazily, so all operations are fast.

        If a split is present in more than one input dataset, the output dataset will have a split that's
        the concatenation of the input splits.

        :param inputs:
            The list of input datasets that will be combined.
        :param shuffle:
            Whether to shuffle the combined datasets. If you don't do this, the new splits will contain first
            all the instances from one dataset, and then all the instances from another dataset.
        :param random_seed:
            Random seed, affects shuffling
        :returns:
            Returns a new dataset that is the combination of the input datasets.
        """
        # Group the sequences of every dataset by split name.
        split_to_datasets: Dict[str, List[Sequence]] = collections.defaultdict(list)
        for dataset in inputs:
            for split_name, sequence in dataset.items():
                split_to_datasets[split_name].append(sequence)

        # Lazily concatenate all sequences that share a split name.
        combined: Dict[str, Sequence] = {
            split_name: ConcatenatedSequence(*sequences)
            for split_name, sequences in split_to_datasets.items()
        }

        if shuffle:
            random.seed(random_seed)
            combined = {
                split_name: ShuffledSequence(split_instances)
                for split_name, split_instances in combined.items()
            }

        return DatasetDict(combined, {})
================================================
FILE: tango/steps/print.py
================================================
import logging
from typing import Any
from tango.common.logging import cli_logger
from tango.step import Step
@Step.register("print")
class PrintStep(Step):
    """
    A step that emits its input (via a logger when possible, otherwise ``print``)
    and returns the emitted string.
    """

    DETERMINISTIC = True
    CACHEABLE = False  # so fast it's not worth caching

    def run(self, input: Any) -> str:  # type: ignore[override]
        """
        Print out the input.
        """
        rendered = str(input)
        # Prefer the step's own logger, then the CLI logger; fall back to print()
        # so the text is never silently dropped.
        for emitter in (self.logger, cli_logger):
            if emitter.isEnabledFor(logging.INFO):
                emitter.info(rendered)
                break
        else:
            print(rendered)
        return rendered
================================================
FILE: tango/steps/shell_step.py
================================================
import os
import subprocess
from typing import List, Optional, Union
from tango.common import PathOrStr, RegistrableFunction, make_registrable
from tango.step import Step
@make_registrable(exist_ok=True)
def check_path_existence(path: PathOrStr):
    """
    Default output validator for :class:`ShellStep`: fail unless something exists at ``path``.

    :param path: The filesystem location the shell command was expected to produce.
    :raises AssertionError: If no file or directory exists at ``path``.
    """
    # Raise explicitly instead of using `assert`, which is silently stripped when
    # Python runs with -O; the exception type and message are unchanged.
    if not os.path.exists(path):
        raise AssertionError(f"Output not found at {path}!")
@Step.register("shell_step")
class ShellStep(Step):
    """
    This step runs a shell command, and returns the standard output as a string.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "shell_step".

    :param shell_command: The shell command to run.
    :param output_path: The step makes no assumptions about the command being run. If your command produces some
        output, you can optionally specify the output path for recording the output location, and optionally
        validating it. See `validate_output` argument for this.
    :param validate_output: If an expected `output_path` has been specified, you can choose to validate that the
        step produced the correct output. By default, it will just check if the `output_path` exists, but you can
        pass any other validating function. For example, if your command is a script generating a model output,
        you can check if the model weights can be loaded.
    :param kwargs: Other kwargs to be passed to `subprocess.run()`. If you need to take advantage of environment
        variables, set `shell = True`.
    """

    def run(  # type: ignore[override]
        self,
        shell_command: Union[str, List[str]],
        output_path: Optional[PathOrStr] = None,
        validate_output: RegistrableFunction = check_path_existence,
        **kwargs,
    ):
        stdout = self.run_command(shell_command, **kwargs)
        self.logger.info(stdout)

        if output_path is not None:
            validate_output(output_path)
            self.logger.info(f"Output found at: {output_path}")

        # `stdout` is the raw bytes captured from the subprocess.
        return stdout.decode("utf-8")

    def run_command(self, command: Union[str, List[str]], **kwargs):
        import shlex

        if kwargs.get("shell", None):
            # A shell invocation wants one command string, not an argv list.
            if isinstance(command, list):
                command = shlex.join(command)
        elif isinstance(command, str):
            # Without a shell, subprocess expects an argument list.
            command = shlex.split(command)

        self.logger.info(f"Command: {command}")
        process = subprocess.run(command, capture_output=True, **kwargs)
        if process.returncode != 0:
            raise RuntimeError(f"The command failed with the following errors: {process.stderr}")
        return process.stdout
================================================
FILE: tango/version.py
================================================
_MAJOR = "1"
_MINOR = "3"
_PATCH = "2"
# This is mainly for pre-releases which have the suffix "rc[0-9]+".
_SUFFIX = ""

# "major.minor" and the full "major.minor.patch[suffix]" version strings.
VERSION_SHORT = f"{_MAJOR}.{_MINOR}"
VERSION = f"{VERSION_SHORT}.{_PATCH}{_SUFFIX}"
================================================
FILE: tango/workspace.py
================================================
import logging
from abc import abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
Any,
ContextManager,
Dict,
Generator,
Iterable,
List,
Optional,
TypeVar,
Union,
cast,
)
from urllib.parse import ParseResult, urlparse
import pytz
from .common import Registrable
from .common.from_params import FromParams
from .common.util import StrEnum, jsonify, utc_now_datetime
from .step import Step
from .step_cache import StepCache
from .step_info import StepInfo, StepState
logger = logging.getLogger(__name__)
T = TypeVar("T")
@dataclass
class Run(FromParams):
    """
    Stores information about a single Tango run.
    """

    name: str
    """
    The name of the run
    """

    steps: Dict[str, StepInfo]
    """
    A mapping from step names to :class:`~tango.step_info.StepInfo`, for all the target steps in the run.

    This only contains the targets of a run. Usually, that means it contains all named steps.
    Un-named dependencies (or dependencies that are not targets) are not contained in ``steps``.
    """

    start_date: datetime
    """
    The time at which the run was registered in the workspace.
    """

    def to_json_dict(self) -> Dict[str, Any]:
        """Turn this run into JSON-serializable data via :func:`jsonify`."""
        return jsonify(self)

    @classmethod
    def from_json_dict(cls, json_dict: Dict[str, Any]) -> "Run":
        """Rebuild a :class:`Run` from the dictionary form produced by :meth:`to_json_dict`."""
        params = dict(json_dict)
        # Dates are serialized as UTC strings; restore a timezone-aware datetime.
        params["start_date"] = datetime.strptime(
            params["start_date"], "%Y-%m-%dT%H:%M:%S"
        ).replace(tzinfo=pytz.utc)
        rebuilt_steps = {}
        for step_name, step_json in params["steps"].items():
            rebuilt_steps[step_name] = StepInfo.from_json_dict(step_json)
        params["steps"] = rebuilt_steps
        return cls.from_params(params)
@dataclass
class RunInfo(FromParams):
    """
    Stores partial data about a run. This is the type that you get back from
    :meth:`Workspace.search_registered_runs()`. The data here is a subset of
    the data in the :class:`Run` type because not all workspaces can fetch all
    of the data in the :class:`Run` type efficiently.
    """

    name: str
    """
    The name of the run.
    """

    steps: Optional[Dict[str, str]] = None
    """
    The steps within the run. An optional mapping of step name to step unique ID.
    May be ``None`` when a workspace cannot produce this mapping efficiently.
    """

    start_date: Optional[datetime] = None
    """
    The time at which the run was registered in the workspace.
    May be ``None`` when a workspace cannot fetch this efficiently.
    """
class RunSort(StrEnum):
    """Fields that :meth:`Workspace.search_registered_runs` can sort results by."""

    START_DATE = "start_date"
    NAME = "name"
class StepInfoSort(StrEnum):
    """Fields that :meth:`Workspace.search_step_info` can sort results by."""

    UNIQUE_ID = "unique_id"
    START_TIME = "start_time"
class Workspace(Registrable):
    """
    A workspace is a place for Tango to put the results of steps, intermediate results, and various other pieces
    of metadata. If you don't want to worry about all that, do nothing and Tango will use the default
    :class:`.LocalWorkspace` that puts everything into a directory of your choosing.

    If you want to do fancy things like store results in the cloud, share state across machines, etc., this is your
    integration point.

    If you got here solely because you want to share results between machines, consider that
    :class:`.LocalWorkspace` works fine on an NFS drive.
    """

    default_implementation = "local"

    #
    # As a general rule, workspaces can never return `Step`, only `StepInfo`, because we can't reliably serialize
    # objects of type `Step`. Doing that would require serializing the code that runs the step, and we can't
    # do that.
    #

    def __init__(self):
        # Temporary directories handed out by `work_dir()`. Holding them here keeps them
        # alive (and on disk) until this Workspace object is garbage collected.
        self._delayed_cleanup_temp_dirs: List[TemporaryDirectory] = []

    def __getstate__(self):
        """
        We override `__getstate__()` to customize how instances of this class are pickled
        since we don't want to persist certain attributes.
        """
        # TemporaryDirectory objects are not picklable, and a fresh process should not
        # inherit another process's pending cleanups anyway.
        out = {k: v for k, v in self.__dict__.items() if k not in {"_delayed_cleanup_temp_dirs"}}
        out["_delayed_cleanup_temp_dirs"] = []
        return out

    @property
    @abstractmethod
    def url(self) -> str:
        """
        Get a URL for the workspace that can be used to instantiate the same workspace
        using :meth:`.from_url()`.
        """
        raise NotImplementedError

    @classmethod
    def from_url(cls, url: str) -> "Workspace":
        """
        Initialize a :class:`Workspace` from a workspace URL or path, e.g. ``local:///tmp/workspace``
        would give you a :class:`~tango.workspaces.LocalWorkspace` in the directory ``/tmp/workspace``.

        For :class:`~tango.workspaces.LocalWorkspace`, you can also just pass in a plain
        path, e.g. ``/tmp/workspace``.

        .. tip::
            Registered as a workspace constructor under the name "from_url".
        """
        parsed = urlparse(url)
        # A plain path has no URL scheme, so default to the local workspace.
        workspace_type = parsed.scheme or "local"
        workspace_cls = cast(Workspace, cls.by_name(workspace_type))
        return workspace_cls.from_parsed_url(parsed)

    @classmethod
    @abstractmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> "Workspace":
        """
        Subclasses should override this so that can be initialized from a URL.

        :param parsed_url: The parsed URL object.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def step_cache(self) -> StepCache:
        """
        A :class:`.StepCache` to store step results in
        """
        raise NotImplementedError()

    def work_dir(self, step: Step) -> Path:
        """Steps that can be restarted (like a training job that gets interrupted half-way through)
        must save their state somewhere. A :class:`.StepCache` can help by providing a suitable location
        in this method.

        By default, the step dir is a temporary directory that gets cleaned up after every run.
        This effectively disables restartability of steps."""
        # TemporaryDirectory cleans up the directory automatically when the TemporaryDirectory object
        # gets garbage collected, so we hold on to it in the Workspace.
        dir = TemporaryDirectory(prefix=f"{step.unique_id}-", suffix=".step_dir")
        self._delayed_cleanup_temp_dirs.append(dir)
        return Path(dir.name)

    @abstractmethod
    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Returns a :class:`~tango.step_info.StepInfo` for a given step.

        :raises KeyError: If the corresponding step info cannot be found or created.
            This should never happen if you pass a :class:`~tango.step.Step` object to this method
            since a :class:`~tango.step_info.StepInfo` can always be created from a
            :class:`~tango.step.Step`.
        """
        raise NotImplementedError()

    def search_step_info(
        self,
        *,
        sort_by: Optional[StepInfoSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        state: Optional[StepState] = None,
        start: int = 0,
        stop: Optional[int] = None,
    ) -> List[StepInfo]:
        """
        Search through steps in the workspace.

        This method is primarily meant to be used to implement a UI, and workspaces don't necessarily
        need to implement all `sort_by` or filter operations. They should only implement those
        that can be done efficiently.

        :param sort_by: The field to sort the results by.
        :param sort_descending: Sort the results in descending order of the ``sort_by`` field.
        :param match: Only return steps with a unique ID matching this string.
        :param state: Only return steps that are in the given state.
        :param start: Start from a certain index in the results.
        :param stop: Stop at a certain index in the results.

        :raises NotImplementedError: If a workspace doesn't support an efficient implementation
            for the given sorting/filtering criteria.
        """
        # Default implementation: gather step info from every registered run and
        # filter/sort in memory. Subclasses may override with something more efficient.
        steps = [
            step
            for run in self.registered_runs().values()
            for step in run.steps.values()
            if (match is None or match in step.unique_id) and (state is None or step.state == state)
        ]
        if sort_by == StepInfoSort.START_TIME:
            # Steps without a start time sort as if they started just now.
            now = utc_now_datetime()
            steps = sorted(
                steps,
                key=lambda step: step.start_time or now,
                reverse=sort_descending,
            )
        elif sort_by == StepInfoSort.UNIQUE_ID:
            steps = sorted(steps, key=lambda step: step.unique_id, reverse=sort_descending)
        elif sort_by is not None:
            raise NotImplementedError
        return steps[slice(start, stop)]

    def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:
        """
        Get the total number of registered steps.

        :param match: Only count steps with a unique ID matching this string.
        :param state: Only count steps that are in the given state.
        """
        return len(self.search_step_info(match=match, state=state))

    @abstractmethod
    def step_starting(self, step: Step) -> None:
        """
        The :class:`.Step` class calls this when a step is about to start running.

        :param step: The step that is about to start.

        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).
        """
        raise NotImplementedError()

    @abstractmethod
    def step_finished(self, step: Step, result: T) -> T:
        """
        The :class:`.Step` class calls this when a step finished running.

        :param step: The step that finished.

        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).

        This method is given the result of the step's :meth:`.Step.run` method. It is expected to return that
        result. This gives it the opportunity to make changes to the result if necessary. For example, if the
        :meth:`.Step.run` method returns an iterator, that iterator would be consumed when it's written to the
        cache. So this method can handle the situation and return something other than the now-consumed iterator.
        """
        raise NotImplementedError()

    @abstractmethod
    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        The :class:`.Step` class calls this when a step failed.

        :param step: The step that failed.
        :param e: The exception thrown by the step's :meth:`.Step.run` method.

        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).
        """
        raise NotImplementedError()

    @abstractmethod
    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """
        Register a run in the workspace. A run is a set of target steps that a user wants to execute.

        :param targets: The steps that the user wants to execute. This could come from a :class:`.StepGraph`.
        :param name: A name for the run. Runs must have unique names. If not given, this method invents a name and
            returns it.
        :return: The run object
        """
        raise NotImplementedError()

    def search_registered_runs(
        self,
        *,
        sort_by: Optional[RunSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        start: int = 0,
        stop: Optional[int] = None,
    ) -> List[RunInfo]:
        """
        Search through registered runs in the workspace.

        This method is primarily meant to be used to implement a UI, and workspaces don't necessarily
        need to implement all `sort_by` or filter operations. They should only implement those
        that can be done efficiently.

        .. note::
            The data type returned in the list here is :class:`RunInfo`, which
            contains a subset of the data in the :class:`Run` type.

        :param sort_by: The field to sort the results by.
        :param sort_descending: Sort the results in descending order of the ``sort_by`` field.
        :param match: Only return results with a name matching this string.
        :param start: Start from a certain index in the results.
        :param stop: Stop at a certain index in the results.

        :raises NotImplementedError: If a workspace doesn't support an efficient implementation
            for the given sorting/filtering criteria.
        """
        # Default implementation: fetch full Run objects and project them down to RunInfo.
        runs = [
            run for run in self.registered_runs().values() if match is None or match in run.name
        ]
        if sort_by == RunSort.START_DATE:
            runs = sorted(runs, key=lambda run: run.start_date, reverse=sort_descending)
        elif sort_by == RunSort.NAME:
            runs = sorted(runs, key=lambda run: run.name, reverse=sort_descending)
        elif sort_by is not None:
            raise NotImplementedError
        return [
            RunInfo(
                name=run.name,
                start_date=run.start_date,
                steps={k: s.unique_id for k, s in run.steps.items()},
            )
            for run in runs[slice(start, stop)]
        ]

    def num_registered_runs(self, *, match: Optional[str] = None) -> int:
        """
        Get the number of registered runs.

        :param match: Only count runs with a name matching this string.
        """
        return len(self.search_registered_runs(match=match))

    @abstractmethod
    def registered_runs(self) -> Dict[str, Run]:
        """
        Returns all runs in the workspace

        :return: A dictionary mapping run names to :class:`Run` objects
        """
        raise NotImplementedError

    @abstractmethod
    def registered_run(self, name: str) -> Run:
        """
        Returns the run with the given name

        :return: A :class:`Run` object representing the named run

        :raises KeyError: If there is no run with the given name.
        """
        raise NotImplementedError()

    def step_result_for_run(self, run_name: str, step_name: str) -> Any:
        """
        Get the result of a step from a run.

        :raises KeyError: If there is no run or step with the given name.
        """
        run = self.registered_run(run_name)
        step_info = run.steps[step_name]
        try:
            return self.step_cache[step_info]
        except KeyError:
            # Re-raise with a message that names the step rather than the cache key.
            raise KeyError(f"Step result for '{step_name}' not found in workspace")

    def step_result(self, step_name: str) -> Any:
        """
        Get the result of a step from the latest run with a step by that name.

        :raises KeyError: If there is no run with the given step.
        """
        # Most recent run first, so the newest matching step wins.
        runs = sorted(self.registered_runs().values(), key=lambda run: run.start_date, reverse=True)
        for run in runs:
            if step_name in run.steps:
                return self.step_cache[run.steps[step_name]]
        raise KeyError(f"No step named '{step_name}' found in previous runs")

    @abstractmethod
    def remove_step(self, step_unique_id: str):
        """
        Removes cached step using the given unique step id

        :raises KeyError: If there is no step with the given name.
        """
        raise NotImplementedError()

    def capture_logs_for_run(self, name: str) -> ContextManager[None]:
        """
        Should return a context manager that can be used to capture the logs for a run.

        By default, this doesn't do anything.

        Examples
        --------

        The :class:`.LocalWorkspace` implementation uses this method to capture logs
        to a file in the workspace's directory using the :func:`~tango.common.logging.file_handler()`
        context manager, similar to this:

        .. testcode::

            from tango.common.logging import file_handler
            from tango.workspace import Workspace

            class MyLocalWorkspace(Workspace):
                def capture_logs_for_run(self, name: str):
                    return file_handler("/path/to/workspace/" + name + ".log")
        """

        @contextmanager
        def do_nothing() -> Generator[None, None, None]:
            yield None

        return do_nothing()
# Register `Workspace.from_url` as a named Registrable constructor so workspaces can be
# built from URLs (see the "Registered as a workspace constructor" tip on `from_url`).
Workspace.register("from_url", constructor="from_url")(Workspace)  # type: ignore
================================================
FILE: tango/workspaces/__init__.py
================================================
"""
Built-in :class:`~tango.workspace.Workspace` implementations.
"""
from .local_workspace import LocalWorkspace
from .memory_workspace import MemoryWorkspace, default_workspace
================================================
FILE: tango/workspaces/local_workspace.py
================================================
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional, Set, TypeVar, Union
from urllib.parse import ParseResult
import dill
import petname
from sqlitedict import SqliteDict
from tango.common import PathOrStr
from tango.common.exceptions import StepStateError
from tango.common.file_lock import FileLock
from tango.common.logging import file_handler
from tango.common.util import exception_to_string, utc_now_datetime
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_caches import LocalStepCache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, StepInfoSort, Workspace
logger = logging.getLogger(__name__)
T = TypeVar("T")
@Workspace.register("local")
class LocalWorkspace(Workspace):
    """
    This is a :class:`.Workspace` that keeps all its data in a local directory. This works great for single-machine
    jobs, or for multiple machines in a cluster if they can all access the same NFS drive.

    :param dir: The directory to store all the data in

    The directory will have three subdirectories, ``cache/`` for the step cache, ``runs/`` for the runs,
    and ``latest/`` for the results of the latest run. For the format of the ``cache/`` directory,
    refer to :class:`.LocalStepCache`. The ``runs/`` directory will contain one subdirectory for each
    registered run. Each one of those contains a symlink from the name of the step to the results directory
    in the step cache. Note that :class:`.LocalWorkspace` creates these symlinks even for steps that have not
    finished yet. You can tell the difference because either the symlink points to a directory that doesn't exist,
    or it points to a directory in the step cache that doesn't contain results.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "local".

        You can also instantiate this workspace from a URL with the scheme ``local://``.
        For example, ``Workspace.from_url("local:///tmp/workspace")`` gives you a :class:`LocalWorkspace`
        in the directory ``/tmp/workspace``.
    """

    def __init__(self, dir: PathOrStr):
        super().__init__()
        self.dir = Path(dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self.cache = LocalStepCache(self.dir / "cache")
        # File locks held for steps that are currently running in this process.
        self.locks: Dict[Step, FileLock] = {}
        self.runs_dir = self.dir / "runs"
        self.runs_dir.mkdir(parents=True, exist_ok=True)
        # Sqlite database mapping step unique IDs to StepInfo objects.
        self.step_info_file = self.dir / "stepinfo.sqlite"
        self.latest_dir = self.dir / "latest"

        # Check the version of the local workspace
        try:
            with open(self.dir / "settings.json", "r") as settings_file:
                settings = json.load(settings_file)
        except FileNotFoundError:
            # No settings file means this is (or predates) a version 1 workspace.
            settings = {"version": 1}

        # Upgrade to version 2: step info moves from per-step dill files into the
        # single sqlite database.
        if settings["version"] == 1:
            with SqliteDict(self.step_info_file) as d:
                for stepinfo_file in self.cache.dir.glob("*/stepinfo.dill"):
                    with stepinfo_file.open("rb") as f:
                        stepinfo = dill.load(f)

                    # The `StepInfo` class changed from one version to the next. The deserialized version
                    # ends up being a `StepInfo` instance that is missing the `cacheable` member. This
                    # hack adds it in.
                    kwargs = stepinfo.__dict__
                    kwargs[
                        "cacheable"
                    ] = True  # Only cacheable steps were saved in v1. That's what v2 fixes.
                    d[stepinfo.unique_id] = StepInfo(**kwargs)
                d.commit()

            # Only delete the old files once everything has been committed to sqlite.
            for stepinfo_file in self.cache.dir.glob("*/stepinfo.dill"):
                stepinfo_file.unlink()

            settings["version"] = 2
            with open(self.dir / "settings.json", "w") as settings_file:
                json.dump(settings, settings_file)

    def __getstate__(self):
        """
        We override `__getstate__()` to customize how instances of this class are pickled
        since we don't want to persist certain attributes.
        """
        # Locks are process-local; a pickled copy must not inherit them.
        out = super().__getstate__()
        out["locks"] = {}
        return out

    @property
    def url(self) -> str:
        return "local://" + str(self.dir)

    @classmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> "Workspace":
        # "local://foo/bar" parses with netloc "foo" and path "/bar"; recombine them.
        workspace_dir: Path
        if parsed_url.netloc:
            workspace_dir = Path(parsed_url.netloc)
            if parsed_url.path:
                workspace_dir = workspace_dir / parsed_url.path.lstrip("/")
        elif parsed_url.path:
            workspace_dir = Path(parsed_url.path)
        else:
            workspace_dir = Path(".")
        return cls(workspace_dir.resolve())

    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:
        return self.cache.step_dir(step_or_unique_id)

    @property
    def step_cache(self) -> StepCache:
        return self.cache

    def work_dir(self, step: Step) -> Path:
        # Unlike the base class, this is a persistent directory, which makes steps restartable.
        result = self.step_dir(step) / "work"
        result.mkdir(parents=True, exist_ok=True)
        return result

    @classmethod
    def guess_step_dir_state(cls, dir: Path) -> Set[StepState]:
        """
        Returns the possible states of a given step dir, to the best of our knowledge.

        :param dir: the step dir to examine
        :return: a set of possible states for the step
        """
        # If the directory doesn't exist, the step is incomplete or uncacheable.
        if not dir.exists():
            return {StepState.INCOMPLETE, StepState.UNCACHEABLE}

        # If the lock file exists and is locked, the step is running.
        lock_file = dir / "lock"
        if lock_file.exists():
            lock = FileLock(lock_file)
            try:
                lock.acquire(0.1)
                lock.release()
            except TimeoutError:
                return {StepState.RUNNING}

        # If the directory is empty except for the work dir and the lock file, the step is running, incomplete,
        # or failed. But it can't be running because then the lockfile would be locked, so it can only be
        # incomplete or failed.
        for dir_entry in dir.iterdir():
            if dir_entry.name == "work" and dir_entry.is_dir():
                continue
            if dir_entry.name == "lock" and dir_entry.is_file():
                continue
            break
        else:
            return {StepState.INCOMPLETE, StepState.FAILED}

        # Anything else in the directory: we can't narrow it down.
        return set(StepState)

    @staticmethod
    def _fix_step_info(step_info: StepInfo) -> None:
        """
        Tragically we need to run a fix-up step over StepInfo objects that are freshly read from
        the database. This is for backwards compatibility.

        This function operates on the `step_info` object in place.
        """
        # Old databases stored the raw exception object; newer code expects a string.
        if isinstance(step_info.error, BaseException):
            step_info.error = exception_to_string(step_info.error)

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        with SqliteDict(self.step_info_file) as d:

            def find_or_add_step_info(step_or_unique_id: Union[Step, str]) -> StepInfo:
                if isinstance(step_or_unique_id, Step):
                    unique_id = step_or_unique_id.unique_id
                else:
                    unique_id = step_or_unique_id

                try:
                    step_info = d[unique_id]
                except KeyError:
                    # We can only create missing info if we were given the actual Step.
                    if not isinstance(step_or_unique_id, Step):
                        raise
                    step = step_or_unique_id
                    # Recursively make sure all dependencies have info as well.
                    for dep in step.dependencies:
                        find_or_add_step_info(dep)
                    step_info = StepInfo.new_from_step(step)
                    d[unique_id] = step_info
                    del step

                # Perform some sanity checks. Sqlite and the file system can get out of sync
                # when a process dies suddenly.
                step_dir = self.step_dir(unique_id)
                # NOTE(review): `guess_step_dir_state()` never returns an empty set, so the
                # `or step_info.state` fallback (a bare StepState, not a set) looks like dead
                # code — confirm before relying on it.
                step_state_guesses = self.guess_step_dir_state(step_dir) or step_info.state
                if step_info.state not in step_state_guesses:
                    if step_info.state == StepState.RUNNING:
                        # We think the step is running, but it can't possibly be running, so we go ahead and
                        # assume the step is incomplete.
                        step_info.start_time = None
                        step_info.end_time = None
                        d[unique_id] = step_info
                    else:
                        possible_states = ", ".join(s.value for s in step_state_guesses)
                        raise IOError(
                            f"The step '{unique_id}' is marked as being {step_info.state.value}, but we "
                            f"determined it can only be one of {{{possible_states}}}. If you are positive "
                            f"this is a screw-up, delete the directory at '{step_dir}' and try again."
                        )

                return step_info

            result = find_or_add_step_info(step_or_unique_id)
            d.commit()

        self._fix_step_info(result)
        return result

    def _step_lock_file(self, step_or_unique_id: Union[Step, str]) -> Path:
        step_dir = self.step_dir(step_or_unique_id)
        step_dir.mkdir(parents=True, exist_ok=True)
        return step_dir / "lock"

    def step_starting(self, step: Step) -> None:
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return

        # Gather the existing step info first. Step info automatically fixes itself if steps are
        # marked as "running" but are not locked. This happens, for example, when a process
        # gets killed. To make sure this works, we have to get the step info before we start
        # messing with locks.
        step_info = self.step_info(step)
        if step_info.state not in {StepState.INCOMPLETE, StepState.FAILED}:
            raise StepStateError(
                step,
                step_info.state,
                context="If you are certain the step is not running somewhere else, delete the lock "
                f"file at {self._step_lock_file(step)}.",
            )

        if step_info.state == StepState.FAILED:
            # Refresh environment metadata since it might be out-of-date now.
            step_info.refresh()

        lock = FileLock(self._step_lock_file(step), read_only_ok=True)
        lock.acquire_with_updates(desc=f"acquiring lock for '{step.name}'")
        self.locks[step] = lock

        try:
            step_info.start_time = utc_now_datetime()
            step_info.end_time = None
            step_info.error = None
            step_info.result_location = None
            with SqliteDict(self.step_info_file) as d:
                d[step.unique_id] = step_info
                d.commit()
        except:  # noqa: E722
            # Don't leave a dangling lock if we failed to record the start.
            lock.release()
            del self.locks[step]
            raise

    def step_finished(self, step: Step, result: T) -> T:
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return result

        lock = self.locks[step]

        step_info = self.step_info(step)
        if step_info.state != StepState.RUNNING:
            raise StepStateError(step, step_info.state)

        self.step_cache[step] = result
        if hasattr(result, "__next__"):
            assert isinstance(result, Iterator)
            # Caching the iterator will consume it, so we write it to the cache and then read from the cache
            # for the return value.
            result = self.step_cache[step]

        # Mark the step as finished
        step_info.end_time = utc_now_datetime()
        step_info.result_location = str(self.step_dir(step).absolute())
        with SqliteDict(self.step_info_file) as d:
            d[step.unique_id] = step_info
            d.commit()

        lock.release()
        del self.locks[step]

        return result

    def step_failed(self, step: Step, e: BaseException) -> None:
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return

        lock = self.locks[step]
        try:
            step_info = self.step_info(step)
            if step_info.state != StepState.RUNNING:
                raise StepStateError(step, step_info.state)
            step_info.end_time = utc_now_datetime()
            step_info.error = exception_to_string(e)
            with SqliteDict(self.step_info_file) as d:
                d[step.unique_id] = step_info
                d.commit()
        finally:
            # Always release the lock, even if recording the failure itself fails.
            lock.release()
            del self.locks[step]

    def remove_step(self, step_unique_id: str) -> None:
        """
        Get Step unique id from the user and remove the step information from cache

        :raises KeyError: If no step with the unique name found in the cache dir
        """
        with SqliteDict(self.step_info_file) as d:
            try:
                # `self.step_info` opens its own connection to the same sqlite file,
                # independent of `d`.
                step_info = self.step_info(step_unique_id)
                del d[step_unique_id]
                d.commit()
                del self.cache[step_info]
            except KeyError:
                raise KeyError(f"No step named '{step_unique_id}' found")

    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        # sanity check targets
        targets = list(targets)

        if name is None:
            # Invent a human-readable name that isn't taken yet.
            while name is None or (self.runs_dir / name).exists():
                name = petname.generate()
        run_dir = self.runs_dir / name

        # clean any existing run directory
        if run_dir.exists():
            for filename in run_dir.iterdir():
                filename.unlink()
        else:
            run_dir.mkdir(parents=True, exist_ok=True)

        # write step info for all steps
        all_steps = set(targets)
        for step in targets:
            all_steps |= step.recursive_dependencies
        self._save_registered_run(name, all_steps)

        # write targets
        for target in targets:
            if target.cache_results:
                target_path = self.step_dir(target)
                # Relative symlink, so the workspace directory can be moved as a whole.
                (run_dir / target.name).symlink_to(os.path.relpath(target_path, run_dir))

        # Re-point the "latest" symlink at this run.
        self.latest_dir.unlink(missing_ok=True)
        self.latest_dir.symlink_to(run_dir)

        return self.registered_run(name)

    def registered_runs(self) -> Dict[str, Run]:
        return {
            str(run_dir.name): self.registered_run(run_dir.name)
            for run_dir in self.runs_dir.iterdir()
            if run_dir.is_dir()
        }

    def search_step_info(
        self,
        *,
        sort_by: Optional[StepInfoSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        state: Optional[StepState] = None,
        start: int = 0,
        stop: Optional[int] = None,
    ) -> List[StepInfo]:
        # flag="r": open the step info database read-only.
        with SqliteDict(self.step_info_file, flag="r") as d:
            steps = [
                step
                for step in d.values()
                if (match is None or match in step.unique_id)
                and (state is None or step.state == state)
            ]
        if sort_by == StepInfoSort.START_TIME:
            # Steps without a start time sort as if they started just now.
            now = utc_now_datetime()
            steps = sorted(
                steps,
                key=lambda step: step.start_time or now,
                reverse=sort_descending,
            )
        elif sort_by == StepInfoSort.UNIQUE_ID:
            steps = sorted(steps, key=lambda step: step.unique_id, reverse=sort_descending)
        elif sort_by is not None:
            raise NotImplementedError
        return steps[slice(start, stop)]

    def registered_run(self, name: str) -> Run:
        run_dir = self.runs_dir / name
        if not run_dir.is_dir():
            raise KeyError(name)
        steps_for_run = self._load_registered_run(name)
        # NOTE(review): `st_ctime` yields a naive local-time datetime, while
        # `Run.from_json_dict()` produces UTC-aware ones — confirm before comparing
        # start dates across workspace implementations.
        return Run(name, steps_for_run, datetime.fromtimestamp(run_dir.stat().st_ctime))

    def _run_step_info_file(self, name: str) -> Path:
        return self.runs_dir / name / "stepinfo.json"

    def _save_registered_run(self, name: str, all_steps: Iterable[Step]) -> None:
        step_unique_ids = {}
        with SqliteDict(self.step_info_file) as d:
            for step in all_steps:
                try:
                    step_info = d[step.unique_id]
                    # Update the stored name for steps we've seen before.
                    step_info.name = step.name
                    d[step.unique_id] = step_info
                except KeyError:
                    d[step.unique_id] = StepInfo.new_from_step(step)
                step_unique_ids[step.name] = step.unique_id
            d.commit()
        # Record the name -> unique ID mapping for the run itself.
        run_step_info_file = self._run_step_info_file(name)
        with open(run_step_info_file, "w") as file_ref:
            json.dump(step_unique_ids, file_ref)

    def _load_registered_run(self, name: str) -> Dict[str, StepInfo]:
        run_step_info_file = self._run_step_info_file(name)
        try:
            with open(run_step_info_file, "r") as file_ref:
                step_ids = json.load(file_ref)
        except FileNotFoundError:
            # for backwards compatibility: older workspaces recorded a run's steps
            # only as symlinks in the run directory.
            run_dir = self.runs_dir / name
            step_ids = {}
            for step_symlink in run_dir.iterdir():
                if not step_symlink.is_symlink():
                    continue
                step_name = str(step_symlink.name)
                unique_id = str(step_symlink.resolve().name)
                step_ids[step_name] = unique_id

        with SqliteDict(self.step_info_file, flag="r") as d:
            steps_for_run = {}
            for step_name, unique_id in step_ids.items():
                step_info = d[unique_id]
                assert isinstance(step_info, StepInfo)
                self._fix_step_info(step_info)
                steps_for_run[step_name] = step_info
            return steps_for_run

    def run_dir(self, name: str) -> Path:
        """
        Returns the directory where a given run is stored.

        :param name: The name of the run
        :return: The directory where the results of the run are stored

        If the run does not exist, this returns the directory where it will be stored if you call
        :meth:`register_run()` with that name.
        """
        return self.runs_dir / name

    def capture_logs_for_run(self, name: str):
        # Tee logs for the duration of the run into the run's directory.
        return file_handler(self.run_dir(name) / "out.log")
================================================
FILE: tango/workspaces/memory_workspace.py
================================================
import copy
from typing import Dict, Iterable, Iterator, Optional, TypeVar, Union
from urllib.parse import ParseResult
import petname
from tango.common.exceptions import StepStateError
from tango.common.util import exception_to_string, utc_now_datetime
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_caches import default_step_cache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, Workspace
T = TypeVar("T")
@Workspace.register("memory")
class MemoryWorkspace(Workspace):
    """
    This is a workspace that keeps all its data in memory. This is useful for debugging or for quick jobs, but of
    course you don't get any caching across restarts.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "memory".
    """

    def __init__(self):
        super().__init__()
        # StepInfo records for all steps this workspace has seen, keyed by unique ID.
        self.unique_id_to_info: Dict[str, StepInfo] = {}
        # All registered runs, keyed by run name.
        self.runs: Dict[str, Run] = {}

    @property
    def url(self) -> str:
        return "memory://"

    @classmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> "Workspace":
        return cls()

    @property
    def step_cache(self) -> StepCache:
        return default_step_cache

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Return the :class:`StepInfo` for a step or its unique ID.

        If the step is unknown but a :class:`Step` object was given, a fresh
        ``StepInfo`` is created from it (without being stored).

        :raises KeyError: If a unique ID was given and no step with that ID is known.
        """
        unique_id = (
            step_or_unique_id.unique_id
            if isinstance(step_or_unique_id, Step)
            else step_or_unique_id
        )
        try:
            return self.unique_id_to_info[unique_id]
        except KeyError:
            if isinstance(step_or_unique_id, Step):
                step = step_or_unique_id
                return StepInfo.new_from_step(step)
            else:
                # Fix: include the missing ID in the error instead of a bare KeyError().
                raise KeyError(unique_id)

    def step_starting(self, step: Step) -> None:
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return
        self.unique_id_to_info[step.unique_id] = StepInfo.new_from_step(
            step, start_time=utc_now_datetime()
        )

    def step_finished(self, step: Step, result: T) -> T:
        """
        Record a successful step result and cache it.

        :raises StepStateError: If the step was not in the ``RUNNING`` state.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return result
        existing_step_info = self.unique_id_to_info[step.unique_id]
        if existing_step_info.state != StepState.RUNNING:
            raise StepStateError(step, existing_step_info.state)
        existing_step_info.end_time = utc_now_datetime()
        # Fix: `step.cache_results` is necessarily True here (we returned early above
        # otherwise), so the previous redundant `if step.cache_results:` check is gone.
        self.step_cache[step] = result
        if hasattr(result, "__next__"):
            assert isinstance(result, Iterator)
            # Caching the iterator will consume it, so we write it to the cache and then read from the cache
            # for the return value.
            return self.step_cache[step]
        return result

    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        Record a step failure.

        :raises StepStateError: If the step was not in the ``RUNNING`` state.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return
        assert e is not None
        existing_step_info = self.unique_id_to_info[step.unique_id]
        if existing_step_info.state != StepState.RUNNING:
            raise StepStateError(step, existing_step_info.state)
        existing_step_info.end_time = utc_now_datetime()
        existing_step_info.error = exception_to_string(e)

    def remove_step(self, step_unique_id: str) -> None:
        """
        Get Step unique id from the user and remove the step information from memory cache

        :raises KeyError: If no step with the unique name found in the cache dir
        """
        try:
            step_info = self.step_info(step_unique_id)
            del self.unique_id_to_info[step_unique_id]
            del self.step_cache[step_info]
        except KeyError:
            raise KeyError(f"{step_unique_id} step info not found, step cache cannot be deleted")

    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """Register a new run over ``targets``, generating a random name if none is given."""
        if name is None:
            name = petname.generate()
        steps: Dict[str, StepInfo] = {}
        for step in targets:
            step_info = StepInfo.new_from_step(step)
            self.unique_id_to_info[step.unique_id] = step_info
            # NOTE(review): this map is keyed by unique ID, whereas some other
            # workspaces key a run's steps by step name — confirm intentional.
            steps[step.unique_id] = step_info
        run = Run(name, steps, utc_now_datetime())
        self.runs[name] = run
        return run

    def registered_runs(self) -> Dict[str, Run]:
        # Deep-copy so callers can't mutate workspace state through the returned runs.
        return copy.deepcopy(self.runs)

    def registered_run(self, name: str) -> Run:
        return copy.deepcopy(self.runs[name])
# Module-level default workspace instance, used when no other workspace is configured.
default_workspace = MemoryWorkspace()
================================================
FILE: tango/workspaces/remote_workspace.py
================================================
import logging
import tempfile
import warnings
from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Generator, Iterable, Iterator, Optional, Tuple, TypeVar, Union
from urllib.parse import ParseResult
from tango.common.exceptions import StepStateError
from tango.common.logging import file_handler
from tango.common.remote_utils import RemoteConstants
from tango.common.util import exception_to_string, tango_cache_dir, utc_now_datetime
from tango.step import Step
from tango.step_caches.remote_step_cache import RemoteStepCache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, Workspace
T = TypeVar("T")
logger = logging.getLogger(__name__)
class RemoteWorkspace(Workspace):
    """
    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on some remote storage location.

    .. tip::
        All remote workspaces inherit from this.
    """

    # Naming/constant conventions shared by remote workspaces.
    Constants = RemoteConstants
    # Intended degree of parallelism for remote operations — TODO confirm where
    # subclasses consume this; it is not used directly in this base class.
    NUM_CONCURRENT_WORKERS: int = 9

    @property
    @abstractmethod
    def cache(self) -> RemoteStepCache:
        # The remote step cache used to store and retrieve step results.
        raise NotImplementedError()

    @property
    @abstractmethod
    def steps_dir_name(self) -> str:
        # Name of the local directory (under the tango cache dir) for step scratch files.
        raise NotImplementedError()

    @property
    @abstractmethod
    def locks(self) -> Dict:
        # Currently-held step locks, keyed by Step.
        raise NotImplementedError()

    @property
    def steps_dir(self) -> Path:
        # Local scratch space for steps, under the shared tango cache directory.
        return tango_cache_dir() / self.steps_dir_name

    @property
    @abstractmethod
    def url(self) -> str:
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
        raise NotImplementedError()

    @property
    def step_cache(self) -> RemoteStepCache:
        return self.cache

    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:
        """Return (creating it if needed) the local directory for a step's files."""
        unique_id = (
            step_or_unique_id if isinstance(step_or_unique_id, str) else step_or_unique_id.unique_id
        )
        path = self.steps_dir / unique_id
        path.mkdir(parents=True, exist_ok=True)
        return path

    def work_dir(self, step: Step) -> Path:
        """Return (creating it if needed) the local working directory for ``step``."""
        path = self.step_dir(step) / "work"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        # Concrete remote workspaces must provide their own step-info lookup.
        raise NotImplementedError()

    @abstractmethod
    def _remote_lock(self, step: Step):
        # Returns a lock object (with acquire()/release()) guarding the step's remote artifacts.
        raise NotImplementedError()

    @abstractmethod
    def _step_location(self, step: Step) -> str:
        # Returns the remote location (e.g. dataset URL) of the step's artifacts.
        raise NotImplementedError()

    def step_starting(self, step: Step) -> None:
        """
        Acquire the step's remote lock and mark the step as running.

        The lock is held until :meth:`step_finished()` or :meth:`step_failed()`
        releases it; on any error in here the lock is released before re-raising.

        :raises StepStateError: If the step is in a state from which it cannot be started.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return
        # Get local file lock + remote dataset lock.
        lock = self._remote_lock(step)
        lock.acquire()
        self.locks[step] = lock
        step_info = self.step_info(step)
        if step_info.state == StepState.RUNNING:
            # Since we've acquired the step lock we know this step can't be running
            # elsewhere. But the step state can still say its running if the last
            # run (presumably) exited before it could update the step info — confirm.
            warnings.warn(
                f"Step info for step '{step.unique_id}' is invalid - says step is running "
                "although it shouldn't be. Ignoring and overwriting step start time.",
                UserWarning,
            )
        elif step_info.state not in {StepState.INCOMPLETE, StepState.FAILED, StepState.UNCACHEABLE}:
            # E.g. the step already completed: starting it again makes no sense.
            self.locks.pop(step).release()
            raise StepStateError(
                step,
                step_info.state,
                context=f"If you are certain the step is not running somewhere else, delete the step "
                f"datasets at {self._step_location(step)}",
            )
        if step_info.state == StepState.FAILED:
            # Refresh the environment metadata since it might be out-of-date now.
            step_info.refresh()
        # Update StepInfo to mark as running.
        try:
            step_info.start_time = utc_now_datetime()
            step_info.end_time = None
            step_info.error = None
            step_info.result_location = None
            self._update_step_info(step_info)
        except:  # noqa: E722
            # Don't leave the lock held if recording the new state failed.
            self.locks.pop(step).release()
            raise

    def step_finished(self, step: Step, result: T) -> T:
        """
        Record a successful step result, cache it remotely, and release the step lock.

        :raises StepStateError: If the step was not in the ``RUNNING`` state.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return result
        step_info = self.step_info(step)
        if step_info.state != StepState.RUNNING:
            raise StepStateError(step, step_info.state)
        # Update step info and save step execution metadata.
        # This needs to be done *before* adding the result to the cache, since adding
        # the result to the cache will commit the step dataset, making it immutable.
        step_info.end_time = utc_now_datetime()
        step_info.result_location = self._step_location(step)
        self._update_step_info(step_info)
        self.cache[step] = result
        if hasattr(result, "__next__"):
            assert isinstance(result, Iterator)
            # Caching the iterator will consume it, so we write it to the cache and then read from the cache
            # for the return value.
            result = self.cache[step]
        self.locks.pop(step).release()
        return result

    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        Record a step failure and release the step lock (even if recording fails).

        :raises StepStateError: If the step was not in the ``RUNNING`` state.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return
        try:
            step_info = self.step_info(step)
            if step_info.state != StepState.RUNNING:
                raise StepStateError(step, step_info.state)
            step_info.end_time = utc_now_datetime()
            step_info.error = exception_to_string(e)
            self._update_step_info(step_info)
        finally:
            self.locks.pop(step).release()

    def remove_step(self, step_unique_id: str) -> None:
        """
        Get Step unique id from the user and remove the step information from cache

        :raises KeyError: If no step with the unique name found in the cache dir
        """
        try:
            step_info = self.step_info(step_unique_id)
            # remove remote objects
            self._remove_step_info(step_info)
            # remove cache info
            del self.cache[step_info]
        except KeyError:
            raise KeyError(f"No step named '{step_unique_id}' found.")
        return None

    def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]:
        """
        Collect :class:`StepInfo` for ``targets`` and all of their recursive dependencies.

        :return: A tuple ``(steps, run_data)``: ``steps`` maps step name to
            ``StepInfo`` and ``run_data`` maps step name to unique ID. Steps
            without a name are skipped.
        """
        import concurrent.futures

        all_steps = set(targets)
        for step in targets:
            all_steps |= step.recursive_dependencies
        steps: Dict[str, StepInfo] = {}
        run_data: Dict[str, str] = {}
        # Collect step info.
        with concurrent.futures.ThreadPoolExecutor(
            thread_name_prefix="RemoteWorkspace._get_run_step_info()-"
        ) as executor:
            step_info_futures = []
            for step in all_steps:
                # Unnamed steps are not part of the run's step map.
                if step.name is None:
                    continue
                step_info_futures.append(executor.submit(self.step_info, step))
            for future in concurrent.futures.as_completed(step_info_futures):
                step_info = future.result()
                assert step_info.step_name is not None
                steps[step_info.step_name] = step_info
                run_data[step_info.step_name] = step_info.unique_id
        return steps, run_data

    @abstractmethod
    def _save_run(
        self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None
    ) -> Run:
        # Persists the run remotely and returns the resulting Run object.
        raise NotImplementedError()

    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """Collect step info for ``targets`` (plus dependencies) and persist the run remotely."""
        steps, run_data = self._get_run_step_info(targets)
        run = self._save_run(steps, run_data, name)
        return run

    @abstractmethod
    def _save_run_log(self, name: str, log_file: Path):
        # Uploads the given log file for the named run to remote storage.
        raise NotImplementedError()

    @contextmanager
    def capture_logs_for_run(self, name: str) -> Generator[None, None, None]:
        """Capture log output for run ``name`` to a temp file and upload it via ``_save_run_log()``."""
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            log_file = Path(tmp_dir_name) / "out.log"
            try:
                with file_handler(log_file):
                    yield None
            finally:
                # Upload the log even if the run raised.
                self._save_run_log(name, log_file)

    @abstractmethod
    def _update_step_info(self, step_info: StepInfo):
        # Writes the (possibly modified) StepInfo back to remote storage.
        raise NotImplementedError()

    @abstractmethod
    def _remove_step_info(self, step_info: StepInfo):
        # Deletes the step's remote artifacts/metadata.
        raise NotImplementedError()
================================================
FILE: test_fixtures/__init__.py
================================================
================================================
FILE: test_fixtures/beaker/nvidia_smi.yml
================================================
# Used to test that GPUs in a cluster are available. Submit this to beaker with:
# $ beaker experiment create test_fixtures/beaker/nvidia_smi.yml --workspace ai2/tango-testing --name tango-test-1
version: v2-alpha
description: NvidiaSMI
tasks:
- name: nvidia-smi
image:
docker: nvidia/cuda:11.0-base
command: [nvidia-smi]
result:
path: '/unused'
resources:
gpuCount: 2
context:
cluster: ai2/tango-gpu-tests
priority: normal
================================================
FILE: test_fixtures/common/params_example.jsonnet
================================================
{
"model": {
"type": "classifier",
"num_classes": 3,
"layers": [
{
"type": "ff",
"activation": "relu",
},
{
"type": "ff",
"activation": "softmax",
},
]
},
"data_path": "data.txt",
}
================================================
FILE: test_fixtures/common/params_example.yaml
================================================
model:
type: classifier
num_classes: 3
layers:
- type: ff
activation: relu
- type: ff
activation: softmax
data_path: data.txt
================================================
FILE: test_fixtures/experiment/hello_world.jsonnet
================================================
{
"steps": {
"hello": {"type": "string", "result": "Hello"},
"hello_world": {
"type": "concat_strings",
"string1": {"type": "ref", "ref": "hello"},
"string2": "World!",
"join_with": ", ",
},
},
}
================================================
FILE: test_fixtures/experiment/logging_check.jsonnet
================================================
{
"steps": {
"stringA": {"type": "logging-step", "string": "This is a logging test.", "num_log_lines": 5},
"stringB": {
"type": "concat_strings",
"string1": {"type": "ref", "ref": "stringA"},
"string2": "This is being logged."
},
"stringC": {"type": "logging-step", "string": "This is also a logging test.", "num_log_lines": 5},
"final_string": {
"type": "logging-step",
"string": {"type": "ref", "ref": "stringB"},
"num_log_lines": 3
},
"multiprocessing_result": {
"type": "multiprocessing_step",
}
}
}
================================================
FILE: test_fixtures/experiment/multiprocessing.jsonnet
================================================
{
"steps": {
"result": {
"type": "multiprocessing_step",
}
}
}
================================================
FILE: test_fixtures/experiment/noisy.jsonnet
================================================
{
steps: {
hello_world: { type: "string", result: "Hello, World!" },
noisy_step: { type: "noisy_step" },
}
}
================================================
FILE: test_fixtures/experiment/random.jsonnet
================================================
{
"steps": {
"rand_string1": {"type": "random_string", "length": 5},
"rand_string2": {"type": "random_string", "length": 5},
"string1": {
"type": "concat_strings",
"string1": {"type": "ref", "ref": "rand_string1"},
"string2": {"type": "ref", "ref": "rand_string2"},
},
"string2": {
"type": "string",
"result": "foo",
},
"final_string": {
"type": "concat_strings",
"string1": {"type": "ref", "ref": "string1"},
"string2": {"type": "ref", "ref": "string2"},
}
}
}
================================================
FILE: test_fixtures/integrations/__init__.py
================================================
================================================
FILE: test_fixtures/integrations/common/__init__.py
================================================
import torch
from torch.utils.data import IterableDataset
from tango import Step
from tango.common import DatasetDict, IterableDatasetDict
@Step.register("generate_data")
class GenerateData(Step):
    """Produce a small, seeded regression dataset with 'train' and 'validation' splits."""

    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> DatasetDict:  # type: ignore[override]
        torch.manual_seed(1)

        def make_examples(count: int):
            # Each example is a random 10-d input paired with a random 1-d target.
            return [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(count)]

        # Train split is generated first so the RNG draw order stays fixed.
        return DatasetDict({"train": make_examples(64), "validation": make_examples(32)})
class RandomIterableDataset(IterableDataset):
    """A thin :class:`IterableDataset` wrapper around an in-memory sequence of examples."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        yield from self.data
@Step.register("generate_streaming_data")
class GenerateStreamingData(Step):
    """Produce the same seeded regression data as ``generate_data``, wrapped in iterable datasets."""

    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> IterableDatasetDict:  # type: ignore[override]
        torch.manual_seed(1)

        def make_split(count: int) -> RandomIterableDataset:
            return RandomIterableDataset(
                [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(count)]
            )

        # Train split first so the RNG draw order matches the original.
        return IterableDatasetDict({"train": make_split(64), "validation": make_split(32)})
================================================
FILE: test_fixtures/integrations/datasets/config.json
================================================
{
"steps": {
"train_data": {
"type": "datasets::load",
"path": "lhoestq/test",
"split": "train"
},
"dev_data": {
"type": "datasets::load",
"path": "lhoestq/test",
"split": "validation"
},
"all_data": {
"type": "datasets::concatenate",
"datasets": [
{
"type": "ref",
"ref": "train_data"
},
{
"type": "ref",
"ref": "dev_data"
}
]
},
"mixed_data": {
"type": "datasets::interleave",
"datasets": [
{
"type": "ref",
"ref": "train_data"
},
{
"type": "ref",
"ref": "dev_data"
}
],
"probabilities": [0.9, 0.1]
}
}
}
================================================
FILE: test_fixtures/integrations/fairscale/__init__.py
================================================
================================================
FILE: test_fixtures/integrations/fairscale/components.py
================================================
import torch
import torch.nn as nn
from tango import Step
from tango.common import DatasetDict
from tango.integrations.torch import Model
from tango.integrations.torch.util import set_seed_all
class FeedForward(nn.Module):
    """A single 4 -> 4 linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.linear(x)
        return self.activation(hidden)
@Model.register("simple_regression_model", exist_ok=True)
class SimpleRegressionModel(Model):
    """Three stacked FeedForward blocks, a linear regression head, and an MSE loss."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(FeedForward(), FeedForward(), FeedForward())
        self.regression_head = nn.Linear(4, 1)
        self.loss_fcn = nn.MSELoss()

    def forward(self, x, y):
        hidden = self.blocks(x)
        prediction = self.regression_head(hidden)
        return {"loss": self.loss_fcn(prediction, y)}
@Step.register("simple_regression_data")
class SimpleRegressionDataStep(Step):
    """Generate a small seeded random regression DatasetDict with 'train' and 'dev' splits."""

    DETERMINISTIC = True
    CACHEABLE = False

    def run(self, seed: int = 317) -> DatasetDict:  # type: ignore
        set_seed_all(seed)

        def sample(n: int):
            # Each example: random 4-d input and 1-d target.
            return [{"x": torch.randn(4), "y": torch.randn(1)} for _ in range(n)]

        # 'train' sampled before 'dev' so the RNG draw order stays the same.
        return DatasetDict(splits={"train": sample(32), "dev": sample(16)})
================================================
FILE: test_fixtures/integrations/fairscale/config.jsonnet
================================================
local pretrained_model = "sshleifer/tiny-gpt2";
####################
# Trainer settings #
####################
local training_steps = 4;
local validate_every = 4;
local devices = 2;
local grad_accum = 1;
local batch_size = 2;
local activation_checkpointing = true;
local amp = false;
local fsdp = true;
local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.
######################
# Optimizer settings #
######################
local warmup_steps = 2;
local learning_rate = 0.005;
local fsdp_config = {
reshard_after_forward: true,
move_params_to_cpu: cpu_offloading,
move_grads_to_cpu: cpu_offloading,
mixed_precision: amp,
};
local training_engine = {
type: "fairscale",
optimizer: {
type: "torch::AdamW",
lr: learning_rate,
betas: [0.9, 0.95],
eps: 1e-6,
},
amp: amp,
fsdp_config: fsdp_config,
};
local dataloader = {
batch_size: batch_size,
sampler: {
type: "torch::DistributedSampler",
shuffle: true,
drop_last: true,
},
};
{
steps: {
regression_data: {
type: "simple_regression_data",
},
trained_model: {
type: "torch::train",
model: {
type: "fairscale::with_wrapped_modules",
model: {
type: "simple_regression_model",
},
modules_to_wrap: ["blocks\\.[0-9]+"],
fsdp_config: fsdp_config,
activation_checkpointing: activation_checkpointing,
},
training_engine: training_engine,
dataset_dict: { type: "ref", ref: "regression_data" },
train_dataloader: dataloader,
validation_split: "dev",
grad_accum: grad_accum,
train_steps: training_steps,
validate_every: training_steps,
validation_steps: 2,
checkpoint_every: training_steps,
log_every: 1,
device_count: devices,
},
}
}
================================================
FILE: test_fixtures/integrations/flax/__init__.py
================================================
================================================
FILE: test_fixtures/integrations/flax/config.jsonnet
================================================
{
"steps": {
"data_full": {
"type": "datasets::load",
"path": "iohadrubin/mini_xsum",
},
"data": {
"type": "datasets::dataset_remix",
"input": {"type": "ref", "ref": "data_full"},
"new_splits": {"train": "train[:20]", "validation": "validation[:20]"},
},
"tokenize": {
"type": "tokenize_data",
"dataset": {
"type": "ref",
"ref": "data"
}
},
"train": {
"type": "flax::train",
"model": {
"type" : "transformers::FlaxAutoModelForSeq2SeqLM::from_pretrained",
"pretrained_model_name_or_path" : "t5-small"
},
"dataset": {
"type": "ref",
"ref": "tokenize"
},
"optimizer": {
"type" : "optax::adamw",
"learning_rate" : 2e-5
},
"train_dataloader": {
"batch_size": 16,
"drop_last": true
},
"wrapper": {
"type": "xsum_wrapper"
},
"train_split": "train",
"validation_split" : "validation",
"validate_every" : 1,
"validation_dataloader": {
"batch_size": 16,
"drop_last": true
},
"train_epoch": 1,
"checkpoint_every": 1,
"log_every": 1
}
}
}
================================================
FILE: test_fixtures/integrations/flax/xsum.py
================================================
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.common_utils import onehot
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSeq2SeqLM
from tango.integrations.flax import FlaxWrapper
from tango.step import Step
"""
A minimal xsum t5-small config for testing.
"""
@Step.register("tokenize_data")
class PreProcessing(Step):
    """
    Tokenize an xsum-style dataset (``document``/``summary`` columns) for t5-small.

    Tokenizes documents as encoder inputs and summaries as labels, and derives
    ``decoder_input_ids`` via the model's ``shift_tokens_right`` helper. Returns
    the mapped dataset with the original columns removed.
    """

    DETERMINISTIC = False

    def run(self, dataset):
        tokenizer = AutoTokenizer.from_pretrained("t5-small")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
        # Fix: the fromlist previously named a non-existent "shift_tokens_tight";
        # the attribute actually fetched below is "shift_tokens_right".
        model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
        shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
        config = AutoConfig.from_pretrained("t5-small")
        MAX_SOURCE_LENGTH = 512
        MAX_TGT_LENGTH = 64

        def preprocess_function(examples):
            # Fix: dropped the pointless `[inp for inp in inputs]` copy of the inputs.
            inputs = examples["document"]
            targets = examples["summary"]
            model_inputs = tokenizer(
                inputs,
                max_length=MAX_SOURCE_LENGTH,
                padding="max_length",
                truncation=True,
                return_tensors="np",
            )
            # Setup the tokenizer for targets
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets,
                    max_length=MAX_TGT_LENGTH,
                    padding="max_length",
                    truncation=True,
                    return_tensors="np",
                )
            model_inputs["labels"] = labels["input_ids"]
            decoder_input_ids = shift_tokens_right_fn(
                labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
            )
            model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
            # We need decoder_attention_mask so we can ignore pad tokens from loss
            model_inputs["decoder_attention_mask"] = labels["attention_mask"]
            return model_inputs

        column_names = dataset["train"].column_names
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )
        return dataset
@FlaxWrapper.register("xsum_wrapper")  # type: ignore
class TransformerWrapper(FlaxWrapper):
    """FlaxWrapper for the xsum test: softmax cross-entropy loss masked over padding."""

    def train_metrics(self, state, batch, labels):
        # No extra training metrics beyond the loss.
        return {}

    def loss_helper(self, logits, labels, batch):
        smoothing = 0
        mask = batch["decoder_attention_mask"]
        n_vocab = logits.shape[-1]
        on_value = 1.0 - smoothing
        off_value = (1.0 - on_value) / (n_vocab - 1)
        # Entropy of the smoothed target distribution, subtracted so the loss
        # is zero when the prediction matches the (smoothed) target exactly.
        normalizer = -(
            on_value * jnp.log(on_value)
            + (n_vocab - 1) * off_value * jnp.log(off_value + 1e-20)
        )
        targets = onehot(labels, n_vocab, on_value=on_value, off_value=off_value)
        per_token = optax.softmax_cross_entropy(logits, targets) - normalizer
        # ignore padded tokens from loss
        masked = per_token * mask
        return masked.sum() / mask.sum()

    def train_loss(self, params, state, batch, dropout_rng, labels):
        outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
        return self.loss_helper(outputs[0], labels, batch)

    def val_metrics(self, batch, logits, labels):
        return {"loss": self.loss_helper(logits, labels, batch)}

    def eval_metrics(self, batch, logits, labels):
        return {"loss": self.loss_helper(logits, labels, batch)}
================================================
FILE: test_fixtures/integrations/torch/__init__.py
================================================
import torch.nn as nn
from tango.integrations.torch import Model
@Model.register("basic_regression")
class BasicRegression(Model):
    """A 10 -> 1 linear + sigmoid regression model; adds an MSE loss when targets are given."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        prediction = self.sigmoid(self.linear(x))
        if y is None:
            return {"pred": prediction}
        return {"pred": prediction, "loss": self.mse(prediction, y)}

    def _to_params(self):
        return {}
================================================
FILE: test_fixtures/integrations/torch/eval.jsonnet
================================================
{
"steps": {
"data": {
"type": "generate_data",
},
"eval": {
"type": "torch::eval",
"model": {
"type": "basic_regression",
},
"dataset_dict": {
"type": "ref",
"ref": "data"
},
"dataloader": {
"batch_size": 8,
"shuffle": true
},
"test_split": "validation",
"log_every": 1
}
}
}
================================================
FILE: test_fixtures/integrations/torch/train.jsonnet
================================================
{
"steps": {
"data": {
"type": "generate_data",
},
"train": {
"type": "torch::train",
"model": {
"type": "basic_regression",
},
"training_engine": {
"optimizer": {
"type": "torch::Adam",
},
},
"dataset_dict": {
"type": "ref",
"ref": "data"
},
"train_dataloader": {
"batch_size": 8,
"shuffle": true
},
"validation_split": "validation",
"validation_dataloader": {
"batch_size": 8,
"shuffle": false
},
"train_steps": 100,
"validate_every": 10,
"checkpoint_every": 10,
"log_every": 1
}
}
}
================================================
FILE: test_fixtures/integrations/torch/train_dist.jsonnet
================================================
{
"steps": {
"data": {
"type": "generate_data",
},
"train": {
"type": "torch::train",
"model": {
"type": "basic_regression",
},
"training_engine": {
"optimizer": {
"type": "torch::Adam",
},
},
"dataset_dict": {
"type": "ref",
"ref": "data",
},
"train_dataloader": {
"batch_size": 8,
"sampler": {
"type": "torch::DistributedSampler",
"shuffle": true,
"drop_last": true,
}
},
"validation_split": "validation",
"validation_dataloader": {
"batch_size": 8,
"sampler": {
"type": "torch::DistributedSampler",
"shuffle": true,
"drop_last": true,
}
},
"train_steps": 100,
"validate_every": 10,
"checkpoint_every": 10,
"log_every": 1,
"device_count": 2,
}
}
}
================================================
FILE: test_fixtures/integrations/torch/train_streaming.jsonnet
================================================
{
"steps": {
"data": {
"type": "generate_streaming_data",
},
"train": {
"type": "torch::train",
"model": {
"type": "basic_regression",
},
"training_engine": {
"optimizer": {
"type": "torch::Adam",
},
},
"dataset_dict": {
"type": "ref",
"ref": "data"
},
"train_dataloader": {
"batch_size": 8,
"shuffle": true
},
"train_steps": 100,
"checkpoint_every": 10,
"log_every": 1
}
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/cache-metadata.json
================================================
{
"step": "AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/executor-metadata.json
================================================
{
"config": {
"type": "cadd",
"a": {
"type": "ref",
"ref": "CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk"
},
"b": {
"type": "ref",
"ref": "MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf"
}
},
"duration": 0.0007,
"finished_at": 1642546363.9658601,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546363.965193,
"step": "AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/cache-metadata.json
================================================
{
"step": "CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/executor-metadata.json
================================================
{
"config": {
"type": "ccos",
"x": [
3.1415926535,
0
]
},
"duration": 0.0004,
"finished_at": 1642546350.3743181,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546350.373902,
"step": "CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/cache-metadata.json
================================================
{
"step": "ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/executor-metadata.json
================================================
{
"config": {
"type": "cexp",
"x": {
"type": "ref",
"ref": "MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae"
}
},
"duration": 0.0006,
"finished_at": 1642546361.347647,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546361.347095,
"step": "ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/cache-metadata.json
================================================
{
"step": "MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/executor-metadata.json
================================================
{
"config": {
"type": "cmul",
"a": [
0,
1
],
"b": {
"type": "ref",
"ref": "SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk"
}
},
"duration": 0.0004,
"finished_at": 1642546358.776982,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546358.776602,
"step": "MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/cache-metadata.json
================================================
{
"step": "MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/executor-metadata.json
================================================
{
"config": {
"type": "cmul",
"a": [
0,
1
],
"b": [
3.1415926535,
0
]
},
"duration": 0.0005,
"finished_at": 1642546353.7523232,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546353.751795,
"step": "MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/cache-metadata.json
================================================
{
"step": "SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/executor-metadata.json
================================================
{
"config": {
"type": "csin",
"x": [
3.1415926535,
0
]
},
"duration": 0.0004,
"finished_at": 1642546356.265017,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546356.264595,
"step": "SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/cache-metadata.json
================================================
{
"step": "SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz"
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/conda-environment.yaml
================================================
name: tango
channels:
- pytorch
- defaults
dependencies:
- appnope=0.1.2=py38hecd8cb5_1001
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2021.10.26=hecd8cb5_2
- certifi=2021.10.8=py38hecd8cb5_0
- decorator=5.1.0=pyhd3eb1b0_0
- ffmpeg=4.3=h0a44026_0
- freetype=2.11.0=hd8bbffd_0
- gettext=0.21.0=h7535e17_0
- giflib=5.2.1=haf1e3a3_0
- gmp=6.2.1=h23ab428_2
- gnutls=3.6.15=hed9c0bf_0
- icu=58.2=h0a44026_3
- intel-openmp=2021.4.0=hecd8cb5_3538
- ipython=7.29.0=py38h01d92e1_0
- jedi=0.18.0=py38hecd8cb5_1
- jpeg=9d=h9ed2024_0
- lame=3.100=h1de35cc_0
- lcms2=2.12=hf1fd2bf_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- libiconv=1.16=h1de35cc_0
- libidn2=2.3.2=h9ed2024_0
- libpng=1.6.37=ha441bb4_0
- libtasn1=4.16.0=h9ed2024_0
- libtiff=4.2.0=h87d7836_0
- libunistring=0.9.10=h9ed2024_0
- libuv=1.40.0=haf1e3a3_0
- libwebp=1.2.0=hacca55c_0
- libwebp-base=1.2.0=h9ed2024_0
- libxml2=2.9.12=hcdb78fc_0
- llvm-openmp=12.0.0=h0dcd299_1
- lz4-c=1.9.3=h23ab428_1
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mkl=2021.4.0=hecd8cb5_637
- mkl-service=2.4.0=py38h9ed2024_0
- mkl_fft=1.3.1=py38h4ab4a9b_0
- mkl_random=1.2.2=py38hb2f4e1b_0
- ncurses=6.3=hca72f7f_1
- nettle=3.7.3=h230ac6f_1
- numpy=1.21.2=py38h4b4dc7a_0
- numpy-base=1.21.2=py38he0bd621_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd9629dc_0
- openssl=1.1.1l=h9ed2024_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h98e4679_0
- pip=21.2.4=py38hecd8cb5_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pygments=2.10.0=pyhd3eb1b0_0
- python=3.8.12=h88f2d9e_0
- pytorch=1.10.0=py3.8_0
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.11=h7bc2e8c_0
- torchaudio=0.10.0=py38_cpu
- torchvision=0.11.1=py38_cpu
- traitlets=5.1.0=pyhd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- zstd=1.4.9=h322a384_0
- pip:
- absl-py==0.15.0
- aiohttp==3.8.0
- aiosignal==1.2.0
- alabaster==0.7.12
- astunparse==1.6.3
- async-timeout==4.0.0
- attrs==21.2.0
- babel==2.9.1
- base58==2.1.1
- beautifulsoup4==4.10.0
- black==21.12b0
- bleach==4.1.0
- boto3==1.19.12
- botocore==1.22.12
- cached-path==1.0.0
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- click-help-colors==0.9.1
- codecov==2.1.12
- colorama==0.4.4
- configparser==5.1.0
- coverage==6.1.1
- datasets==1.15.1
- dill==0.3.4
- docker-pycreds==0.4.0
- docutils==0.17.1
- filelock==3.4.0
- flake8==4.0.1
- flaky==3.7.0
- flatbuffers==2.0
- frozenlist==1.2.0
- fsspec==2021.11.0
- furo==2022.1.2
- future==0.18.2
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- glob2==0.7
- google-api-core==2.2.2
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-cloud-core==2.1.0
- google-cloud-storage==1.42.3
- google-crc32c==1.3.0
- google-pasta==0.2.0
- google-resumable-media==2.1.0
- googleapis-common-protos==1.53.0
- grpcio==1.41.1
- h5py==3.6.0
- huggingface-hub==0.1.1
- idna==3.3
- imagesize==1.2.0
- importlib-metadata==4.8.1
- iniconfig==1.1.1
- isort==5.10.1
- jinja2==3.0.2
- jmespath==0.10.0
- joblib==1.1.0
- jsonnet==0.17.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- keyring==23.2.1
- libclang==12.0.0
- livereload==2.6.3
- markdown==3.3.4
- markdown-it-py==1.1.0
- markupsafe==2.0.1
- mccabe==0.6.1
- mdit-py-plugins==0.3.0
- more-itertools==8.10.0
- multidict==5.2.0
- multiprocess==0.70.12.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.16.1
- nltk==3.6.7
- oauthlib==3.1.1
- opt-einsum==3.3.0
- overrides==6.1.0
- packaging==21.2
- pandas==1.3.4
- pathspec==0.9.0
- pathtools==0.1.2
- petname==2.6
- pkginfo==1.7.1
- platformdirs==2.4.0
- pluggy==1.0.0
- promise==2.3
- protobuf==3.19.1
- psutil==5.8.0
- py==1.11.0
- pyarrow==6.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.8.0
- pydeprecate==0.3.1
- pyflakes==2.4.0
- pyparsing==2.4.7
- pytest==6.2.5
- pytest-cov==3.0.0
- pytest-sphinx==0.3.1
- python-dateutil==2.8.2
- pytorch-lightning==1.5.1
- pytz==2021.3
- pyyaml==6.0
- readme-renderer==30.0
- regex==2021.11.2
- requests==2.26.0
- requests-oauthlib==1.3.0
- requests-toolbelt==0.9.1
- rfc3986==1.5.0
- rouge-score==0.0.4
- rsa==4.7.2
- s3transfer==0.5.0
- sacremoses==0.0.46
- sentencepiece==0.1.96
- sentry-sdk==1.4.3
- shortuuid==1.0.1
- smmap==5.0.0
- snowballstemmer==2.1.0
- soupsieve==2.3
- sphinx==4.3.1
- sphinx-autobuild==2021.3.14
- sphinx-copybutton==0.4.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlitedict==1.7.0
- subprocess32==3.5.4
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.7.0
- tensorflow-estimator==2.7.0
- tensorflow-io-gcs-filesystem==0.23.1
- termcolor==1.1.0
- tokenizers==0.10.3
- toml==0.10.2
- tomli==1.2.2
- torchmetrics==0.6.0
- tornado==6.1
- tqdm==4.62.3
- transformers==4.12.3
- twine==3.5.0
- types-pyyaml==6.0.0
- types-setuptools==57.4.2
- typing-utils==0.1.0
- urllib3==1.26.7
- wandb==0.12.6
- webencodings==0.5.1
- werkzeug==2.0.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.6.0
prefix: /opt/miniconda3/envs/tango
================================================
FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/executor-metadata.json
================================================
{
"config": {
"type": "csub",
"a": {
"type": "ref",
"ref": "AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP"
},
"b": {
"type": "ref",
"ref": "ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9"
}
},
"duration": 0.0005,
"finished_at": 1642546366.57007,
"git": {
"commit": "8e09b66caffbff20fd0b1504c961932b97417e8d",
"remote": "https://github.com/allenai/tango.git"
},
"platform": {
"cpu_count": 16,
"executable": "/opt/miniconda3/envs/tango/bin/python",
"host": "ip-192-168-1-194.us-west-2.compute.internal",
"operating_system": "macOS-10.16-x86_64-i386-64bit",
"python": "3.8.12",
"root": "/Users/dirkg/Documents/tango/examples/euler",
"user": "dirkg"
},
"started_at": 1642546366.569589,
"step": "SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz",
"tango": {
"command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic",
"version": "0.4.0rc4"
}
}
================================================
FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/lock
================================================
================================================
FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/requirements.txt
================================================
absl-py==0.15.0
ai2-tango==0.4.0rc1
aiohttp==3.8.0
aiosignal==1.2.0
alabaster==0.7.12
appnope==0.1.2
astunparse==1.6.3
async-timeout==4.0.0
attrs==21.2.0
babel==2.9.1
backcall==0.2.0
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
boto3==1.19.12
botocore==1.22.12
cached-path==1.0.0
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click-help-colors==0.9.1
click==8.0.3
codecov==2.1.12
colorama==0.4.4
configparser==5.1.0
coverage==6.1.1
datasets==1.15.1
decorator==5.1.0
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
filelock==3.4.0
flake8==4.0.1
flaky==3.7.0
flatbuffers==2.0
frozenlist==1.2.0
fsspec==2021.11.0
furo==2022.1.2
future==0.18.2
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
glob2==0.7
google-api-core==2.2.2
google-auth-oauthlib==0.4.6
google-auth==2.3.3
google-cloud-core==2.1.0
google-cloud-storage==1.42.3
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
grpcio==1.41.1
h5py==3.6.0
huggingface-hub==0.1.1
idna==3.3
imagesize==1.2.0
importlib-metadata==4.8.1
iniconfig==1.1.1
ipython==7.29.0
isort==5.10.1
jedi==0.18.0
jinja2==3.0.2
jmespath==0.10.0
joblib==1.1.0
jsonnet==0.17.0
keras-preprocessing==1.1.2
keras==2.7.0
keyring==23.2.1
libclang==12.0.0
livereload==2.6.3
markdown-it-py==1.1.0
markdown==3.3.4
markupsafe==2.0.1
matplotlib-inline==0.1.2
mccabe==0.6.1
mdit-py-plugins==0.3.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
more-itertools==8.10.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
mypy==0.931
myst-parser==0.16.1
nltk==3.6.7
numpy==1.21.2
oauthlib==3.1.1
olefile==0.46
opt-einsum==3.3.0
overrides==6.1.0
packaging==21.2
pandas==1.3.4
parso==0.8.2
pathspec==0.9.0
pathtools==0.1.2
petname==2.6
pexpect==4.8.0
pickleshare==0.7.5
pillow==8.4.0
pip==21.2.4
pkginfo==1.7.1
platformdirs==2.4.0
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.20
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.0
pyasn1-modules==0.2.8
pyasn1==0.4.8
pycodestyle==2.8.0
pydeprecate==0.3.1
pyflakes==2.4.0
pygments==2.10.0
pyparsing==2.4.7
pytest-cov==3.0.0
pytest-sphinx==0.3.1
pytest==6.2.5
python-dateutil==2.8.2
pytorch-lightning==1.5.1
pytz==2021.3
pyyaml==6.0
readme-renderer==30.0
regex==2021.11.2
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
requests==2.26.0
rfc3986==1.5.0
rouge-score==0.0.4
rsa==4.7.2
s3transfer==0.5.0
sacremoses==0.0.46
sentencepiece==0.1.96
sentry-sdk==1.4.3
setuptools==58.0.4
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
snowballstemmer==2.1.0
soupsieve==2.3
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.4.0
sphinx==4.3.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlitedict==1.7.0
subprocess32==3.5.4
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboard==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
tensorflow==2.7.0
termcolor==1.1.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.2
torch==1.10.0
torchaudio==0.10.0
torchmetrics==0.6.0
torchvision==0.11.1
tornado==6.1
tqdm==4.62.3
traitlets==5.1.0
transformers==4.12.3
twine==3.5.0
types-pyyaml==6.0.0
types-setuptools==57.4.2
typing-extensions==3.10.0.2
typing-utils==0.1.0
urllib3==1.26.7
wandb==0.12.6
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.0.2
wheel==0.37.0
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==3.6.0
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/common/__init__.py
================================================
================================================
FILE: tests/common/dataset_dict_test.py
================================================
from tango.common.dataset_dict import DatasetDict
def test_dataset_dict():
    """A ``DatasetDict`` should behave like a mapping over its splits."""
    splits = {"train": list(range(10)), "test": list(range(5))}
    dd = DatasetDict(splits=splits)
    # Mapping protocol: length, membership, item access, iteration.
    assert len(dd) == 2
    for name, rows in splits.items():
        assert name in dd
        assert len(dd[name]) == len(rows)
    assert set(dd) == set(dd.keys()) == {"train", "test"}
================================================
FILE: tests/common/det_hash_test.py
================================================
from tango.common.det_hash import DetHashWithVersion, det_hash
def test_normal_det_hash():
    """Without ``DetHashWithVersion``, ``det_hash`` depends only on instance state.

    The class is deliberately redefined (shadowed) with a different ``VERSION``
    attribute to show that a bare class attribute does not affect the hash.
    """

    class C:
        VERSION = 1

        def __init__(self, x: int):
            self.x = x

    c1_1 = C(10)
    c2_1 = C(10)
    c3_1 = C(20)
    # Equal state -> equal hash; different state -> different hash.
    assert det_hash(c1_1) == det_hash(c2_1)
    assert det_hash(c3_1) != det_hash(c2_1)

    # Shadow the class with a new VERSION; instances carry the same state.
    class C:
        VERSION = 2

        def __init__(self, x: int):
            self.x = x

    c1_2 = C(10)
    c2_2 = C(10)
    c3_2 = C(20)
    assert det_hash(c1_2) == det_hash(c2_2)
    assert det_hash(c3_2) != det_hash(c2_2)
    assert det_hash(c1_2) == det_hash(c1_1)  # because the version isn't taken into account
    assert det_hash(c3_2) == det_hash(c3_1)  # because the version isn't taken into account
def test_versioned_det_hash():
    """With ``DetHashWithVersion``, ``VERSION`` participates in the hash.

    Mirrors ``test_normal_det_hash`` but the redefined class now inherits from
    ``DetHashWithVersion``, so bumping ``VERSION`` changes every hash.
    """

    class C(DetHashWithVersion):
        VERSION = "1"

        def __init__(self, x: int):
            self.x = x

    c1_1 = C(10)
    c2_1 = C(10)
    c3_1 = C(20)
    assert det_hash(c1_1) == det_hash(c2_1)
    assert det_hash(c3_1) != det_hash(c2_1)

    # Shadow the class with a bumped VERSION; same instance state as above.
    class C(DetHashWithVersion):
        VERSION = "2"

        def __init__(self, x: int):
            self.x = x

    c1_2 = C(10)
    c2_2 = C(10)
    c3_2 = C(20)
    assert det_hash(c1_2) == det_hash(c2_2)
    assert det_hash(c3_2) != det_hash(c2_2)
    assert det_hash(c1_2) != det_hash(c1_1)  # because the version is taken into account
    assert det_hash(c3_2) != det_hash(c3_1)  # because the version is taken into account
================================================
FILE: tests/common/from_params_pep_563_test.py
================================================
"""
This tests `FromParams` functionality with https://www.python.org/dev/peps/pep-0563/.
"""
from __future__ import annotations
from tango.common.from_params import FromParams, infer_method_params
from tango.common.lazy import Lazy
class Foo(FromParams):
    """Leaf fixture: a ``FromParams`` class with one plain ``int`` argument."""

    def __init__(self, x: int):
        self.x = x
class Bar(FromParams):
    """Wraps ``Foo`` behind a ``Lazy`` annotation.

    With ``from __future__ import annotations`` the annotation is a string at
    runtime, which is exactly what this test module exercises.
    """

    def __init__(self, foo: Lazy[Foo]):
        # Construct eagerly so the tests can inspect the built object.
        self.foo = foo.construct()
class Baz(FromParams):
    """Second level of ``Lazy`` nesting (``Baz`` -> ``Bar`` -> ``Foo``)."""

    def __init__(self, bar: Lazy[Bar]):
        # Construct eagerly so the tests can inspect the built object.
        self.bar = bar.construct()
def test_infer_method_params():
    """PEP 563 string annotations must be resolved to real types."""
    inferred = infer_method_params(Baz, Baz.__init__)
    annotation = inferred["bar"].annotation
    assert not isinstance(annotation, str)
def test_from_params():
    """Nested ``Lazy`` objects should be constructible from a plain dict."""
    config = {"bar": {"foo": {"x": 1}}}
    instance = Baz.from_params(config)
    assert instance.bar.foo.x == 1
================================================
FILE: tests/common/from_params_test.py
================================================
import sys
from copy import deepcopy
from dataclasses import dataclass
from numbers import Number
from typing import (
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
import pytest
from tango.common import det_hash
from tango.common.det_hash import DetHashWithVersion
from tango.common.exceptions import ConfigurationError
from tango.common.from_params import (
FromParams,
create_kwargs,
is_base_registrable,
remove_optional,
takes_arg,
)
from tango.common.lazy import Lazy
from tango.common.params import Params
from tango.common.registrable import Registrable
from tango.common.testing import TangoTestCase
from tango.step import Step
class TestFromParams(TangoTestCase):
def test_takes_arg(self):
    """``takes_arg`` reports whether a callable accepts a given named argument."""

    def bare_function(some_input: int) -> int:
        return some_input + 1

    class Sample:
        total = 0

        def __init__(self, constructor_param: str) -> None:
            self.constructor_param = constructor_param

        def check_param(self, check: str) -> bool:
            return self.constructor_param == check

        @classmethod
        def set_total(cls, new_total: int) -> None:
            cls.total = new_total

    # Plain functions.
    assert takes_arg(bare_function, "some_input")
    assert not takes_arg(bare_function, "some_other_input")
    # Classes are inspected through their constructor.
    assert takes_arg(Sample, "self")
    assert takes_arg(Sample, "constructor_param")
    assert not takes_arg(Sample, "check")
    # Instance methods and class methods.
    assert takes_arg(Sample.check_param, "check")
    assert not takes_arg(Sample.check_param, "other_check")
    assert takes_arg(Sample.set_total, "new_total")
    assert not takes_arg(Sample.set_total, "total")
def test_remove_optional(self):
    """``remove_optional`` strips an ``Optional`` wrapper and is idempotent."""
    wrapped = Optional[Dict[str, str]]
    once = remove_optional(wrapped)  # type: ignore
    twice = remove_optional(once)
    assert once == Dict[str, str]
    assert twice == Dict[str, str]
    # Simple and non-optional types pass through unchanged.
    assert remove_optional(Optional[str]) == str  # type: ignore[arg-type]
    assert remove_optional(str) == str
@pytest.mark.parametrize("input_type", [dict, Params])
def test_from_params(self, input_type):
    """Both plain dicts and ``Params`` are accepted; extras arrive as kwargs."""
    config = input_type({"my_int": 10})
    instance = MyClass.from_params(config, my_bool=True)
    assert isinstance(instance, MyClass)
    assert instance.my_int == 10
    assert instance.my_bool
def test_create_kwargs(self):
    """``create_kwargs`` merges params with extras, dropping extras the
    constructor does not accept."""
    kwargs = create_kwargs(
        MyClass, MyClass, Params({"my_int": 5}), dict(my_bool=True, my_float=4.4)
    )
    # my_float should not be included because it's not a param of the MyClass constructor
    assert kwargs == {"my_int": 5, "my_bool": True}
def test_extras(self):
    """Extras given to ``from_params`` reach constructors that accept them,
    unneeded extras are dropped, and the same holds for a subclass that
    defines its own custom ``from_params``."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int, name: str) -> None:
            self.size = size
            self.name = name

    @A.register("c")
    class C(A):
        def __init__(self, size: int, name: str) -> None:
            self.size = size
            self.name = name

        # custom from params
        @classmethod
        def from_params(cls, params: Params, size: int, **extras) -> "C":  # type: ignore
            name = params.pop("name")
            return cls(size=size, name=name)

    # Check that extras get passed, even though A doesn't need them.
    params = Params({"type": "b", "size": 10})
    b: B = A.from_params(params, name="extra")  # type: ignore[assignment]
    assert b.name == "extra"
    assert b.size == 10

    # Check that extra extras don't get passed.
    params = Params({"type": "b", "size": 10})
    b = A.from_params(params, name="extra", unwanted=True)  # type: ignore[assignment]
    assert b.name == "extra"  # type: ignore[attr-defined]
    assert b.size == 10  # type: ignore[attr-defined]

    # Now the same with a custom from_params.
    params = Params({"type": "c", "name": "extra_c"})
    c: C = A.from_params(params, size=20)  # type: ignore[assignment]
    assert c.name == "extra_c"
    assert c.size == 20

    # Check that extra extras don't get passed.
    params = Params({"type": "c", "name": "extra_c"})
    c = A.from_params(params, size=20, unwanted=True)  # type: ignore[assignment]
    assert c.name == "extra_c"  # type: ignore[attr-defined]
    assert c.size == 20  # type: ignore[attr-defined]
def test_variable_length_tuple(self):
    """Variadic ``Tuple[X, ...]`` annotations accept lists of any length."""

    class Foo(FromParams):
        def __init__(self, x: Tuple[Optional[int], ...]):
            self.x = x

    cases = [
        ([None, 1, 2, 3], (None, 1, 2, 3)),
        ([1, 2], (1, 2)),
        ([1], (1,)),
    ]
    for raw, expected in cases:
        assert Foo.from_params({"x": raw}).x == expected
def test_union(self):
    """Union annotations resolve to whichever member matches the params."""

    class Scalar(FromParams):
        def __init__(self, a: Union[int, List[int]]) -> None:
            self.a = a

    class Wrapper(FromParams):
        def __init__(self, b: Union[Scalar, List[Scalar]]) -> None:
            # `self.b` may be either shape; this test exercises both.
            self.b = b

    # The inner union picks int or list-of-int as appropriate.
    assert Scalar.from_params(Params({"a": 3})).a == 3
    assert Scalar.from_params(Params({"a": [3, 4, 5]})).a == [3, 4, 5]

    # The outer union picks a single object...
    single = Wrapper.from_params(Params({"b": {"a": 3}}))
    assert isinstance(single.b, Scalar)
    assert single.b.a == 3

    # ...or a list of objects.
    many = Wrapper.from_params(Params({"b": [{"a": 3}, {"a": [4, 5]}]}))
    assert isinstance(many.b, list)
    assert many.b[0].a == 3
    assert many.b[1].a == [4, 5]
def test_non_params_object_with_params(self):
    """An already-constructed object may appear directly in the params dict."""
    bar = Bar.from_params({"foo": Foo(a=1)})
    assert bar.foo.a == 1
def test_crazy_nested_union(self):
    """Regression test: nested unions must not mutate the params they inspect."""

    class A(FromParams):
        def __init__(self, a: Union[int, List[int]]) -> None:
            self.a = a

    class B(FromParams):
        def __init__(self, b: Union[A, List[A]]) -> None:
            # Really you would want to be sure that `self.b` has a consistent type, but for
            # this test we'll ignore that.
            self.b = b

    class C(FromParams):
        def __init__(self, c: Union[A, B, Dict[str, A]]) -> None:
            # Really you would want to be sure that `self.c` has a consistent type, but for
            # this test we'll ignore that.
            self.c = c

    # This is a contrived, ugly example (why would you want to duplicate names in a nested
    # structure like this??), but it demonstrates a potential bug when dealing with mutatable
    # parameters. If you're not careful about keeping the parameters un-mutated in two
    # separate places, you'll end up with a B, or with a dict that's missing the 'b' key.
    params = Params({"c": {"a": {"a": 3}, "b": {"a": [4, 5]}}})
    c = C.from_params(params)
    assert isinstance(c.c, dict)
    assert c.c["a"].a == 3
    assert c.c["b"].a == [4, 5]
def test_union_of_castable_types(self):
    """The member order of a castable Union must not change the parsed type."""

    class IntFloat(FromParams):
        def __init__(self, a: Union[int, float]) -> None:
            self.a = a

    class FloatInt(FromParams):
        def __init__(self, a: Union[float, int]) -> None:
            self.a = a

    import json

    # The JSON literal (1 vs 1.0) — not the annotation order — decides the type.
    cases = [(int, '{"a": 1}'), (float, '{"a": 1.0}')]
    for expected_type, param_str in cases:
        for cls in (IntFloat, FloatInt):
            parsed = cls.from_params(Params(json.loads(param_str)))
            assert type(parsed.a) == expected_type  # type: ignore[attr-defined]
def test_invalid_type_conversions(self):
    """Values that merely look convertible to ``int`` must be rejected."""

    class A(FromParams):
        def __init__(self, a: int) -> None:
            self.a = a

    # Neither the string "1" nor the float 1.0 may be silently coerced.
    for bad_value in ("1", 1.0):
        with pytest.raises(TypeError):
            A.from_params(Params({"a": bad_value}))
def test_dict(self):
    """``Dict[str, A]`` parameters construct each value via its registered type."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int) -> None:
            self.size = size

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, items: Dict[str, A]) -> None:
            self.items = items

    params = Params(
        {
            "type": "d",
            "items": {"first": {"type": "b", "size": 1}, "second": {"type": "b", "size": 2}},
        }
    )
    d: D = C.from_params(params)  # type: ignore[assignment]
    # Keys stay strings; values are constructed B instances.
    assert isinstance(d.items, dict)
    assert len(d.items) == 2
    assert all(isinstance(key, str) for key in d.items.keys())
    assert all(isinstance(value, B) for value in d.items.values())
    assert d.items["first"].size == 1  # type: ignore[attr-defined]
    assert d.items["second"].size == 2  # type: ignore[attr-defined]
def test_dict_not_params(self):
    """Dict-typed arguments should arrive as plain dicts, not ``Params``."""

    class A(FromParams):
        def __init__(self, counts: Dict[str, int]) -> None:
            self.counts = counts

    parsed = A.from_params(Params({"counts": {"a": 10, "b": 20}}))
    # Plain dict in, plain dict out — no Params wrapper leaks through.
    assert isinstance(parsed.counts, dict)
    assert not isinstance(parsed.counts, Params)
def test_list(self):
    """``List[A]`` parameters construct each element via its registered type."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int) -> None:
            self.size = size

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, items: List[A]) -> None:
            self.items = items

    params = Params(
        {"type": "d", "items": [{"type": "b", "size": 1}, {"type": "b", "size": 2}]}
    )
    d: D = C.from_params(params)  # type: ignore[assignment]
    assert isinstance(d.items, list)
    assert len(d.items) == 2
    assert all(isinstance(item, B) for item in d.items)
    assert d.items[0].size == 1  # type: ignore[attr-defined]
    assert d.items[1].size == 2  # type: ignore[attr-defined]
def test_tuple(self):
    """Fixed-length ``Tuple[A, C]`` parameters construct each slot by position."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int) -> None:
            self.size = size

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, name: str) -> None:
            self.name = name

    class E(Registrable):
        pass

    @E.register("f")
    class F(E):
        def __init__(self, items: Tuple[A, C]) -> None:
            self.items = items

    params = Params(
        {"type": "f", "items": [{"type": "b", "size": 1}, {"type": "d", "name": "item2"}]}
    )
    f: F = E.from_params(params)  # type: ignore[assignment]
    # The list becomes a tuple with each element built against its slot's type.
    assert isinstance(f.items, tuple)
    assert len(f.items) == 2
    assert isinstance(f.items[0], B)
    assert isinstance(f.items[1], D)
    assert f.items[0].size == 1
    assert f.items[1].name == "item2"
def test_set(self):
    """``Set[A]`` parameters construct elements and deduplicate by equality."""

    class A(Registrable):
        def __init__(self, name: str) -> None:
            self.name = name

        # Equality/hash by name so the set can deduplicate constructed objects.
        def __eq__(self, other):
            return self.name == other.name

        def __hash__(self):
            return hash(self.name)

    @A.register("b")
    class B(A):
        pass

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, items: Set[A]) -> None:
            self.items = items

    # "item2" appears twice on purpose; the set should collapse the duplicates.
    params = Params(
        {
            "type": "d",
            "items": [
                {"type": "b", "name": "item1"},
                {"type": "b", "name": "item2"},
                {"type": "b", "name": "item2"},
            ],
        }
    )
    d: D = C.from_params(params)  # type: ignore[assignment]
    assert isinstance(d.items, set)
    assert len(d.items) == 2
    assert all(isinstance(item, B) for item in d.items)
    assert any(item.name == "item1" for item in d.items)
    assert any(item.name == "item2" for item in d.items)
def test_kwargs_with_multiple_inheritance(self):
    """Kwargs must propagate to the super class regardless of MRO order."""
    # Basic idea: have two identical classes, differing only in the order of their multiple
    # inheritance, and make sure that passing kwargs up to the super class works in both cases.

    class A(Registrable):
        def __init__(self, a: int):
            self.a = a

    @A.register("b1")  # type: ignore
    class B1(A, Number):
        def __init__(self, b: float, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    @A.register("b2")  # type: ignore
    class B2(Number, A):
        def __init__(self, b: float, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    b1 = B1.from_params(Params({"a": 4, "b": 5}))
    assert b1.b == 5
    assert b1.a == 4

    b2 = B2.from_params(Params({"a": 4, "b": 5}))
    assert b2.b == 5
    assert b2.a == 4
def test_instantiating_with_multiple_inheritance(self):
    """A registered class with multiple inheritance can be built directly or
    via a further-registered subclass."""

    class A(Registrable):
        def __init__(self, a: int):
            self.a = a

    @A.register("b")  # type: ignore
    class B(A, Number):
        def __init__(self, b: float, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    # B is itself registered, so it is not a "base" registrable.
    assert not is_base_registrable(B)

    @B.register("c")
    class C(B):
        def __init__(self, c: float, **kwargs):
            super().__init__(**kwargs)
            self.c = c

    # make sure we can instantiate B directly.
    b = B.from_params({"b": 1.0, "a": 1})
    assert isinstance(b, B)

    # and also make sure we can instantiate subclasses of B.
    c = B.from_params({"type": "c", "c": 2.0, "b": 1.0, "a": 1})
    assert isinstance(c, C)
def test_only_infer_superclass_params_if_unknown(self):
    """Parameter types and defaults come from the concrete subclass, falling
    back to the superclass only for parameters the subclass doesn't declare."""

    class BaseClass(Registrable):
        def __init__(self):
            self.x = None
            self.a = None
            self.rest = None

    @BaseClass.register("a")
    class A(BaseClass):
        def __init__(self, a: int, x: int, **kwargs):
            super().__init__()
            self.x = x
            self.a = a
            self.rest = kwargs

    @BaseClass.register("b")
    class B(A):
        def __init__(self, a: str, x: int = 42, **kwargs):
            # B redefines `a` as a str and forwards it under a new name.
            super().__init__(x=x, a=-1, raw_a=a, **kwargs)

    params = Params({"type": "b", "a": "123"})
    # The param `x` should not be required as it has default value in `B`
    # The correct type of the param `a` should be inferred from `B` as well.
    instance = BaseClass.from_params(params)
    assert instance.x == 42
    assert instance.a == -1
    assert len(instance.rest) == 1  # type: ignore
    assert isinstance(instance.rest["raw_a"], str)  # type: ignore
    assert instance.rest["raw_a"] == "123"  # type: ignore
def test_kwargs_are_passed_to_deeper_superclasses(self):
    """``**kwargs`` must be threaded through every level of a three-deep
    inheritance chain (C -> B -> A)."""

    class BaseClass(Registrable):
        def __init__(self):
            self.a = None
            self.b = None
            self.c = None

    @BaseClass.register("a")
    class A(BaseClass):
        def __init__(self, a: str):
            super().__init__()
            self.a = a

    @BaseClass.register("b")
    class B(A):
        def __init__(self, b: str, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    @BaseClass.register("c")
    class C(B):
        def __init__(self, c, **kwargs):
            super().__init__(**kwargs)
            self.c = c

    # Each constructor consumes its own parameter and forwards the rest.
    params = Params({"type": "c", "a": "a_value", "b": "b_value", "c": "c_value"})
    instance = BaseClass.from_params(params)
    assert instance.a == "a_value"
    assert instance.b == "b_value"
    assert instance.c == "c_value"
def test_lazy_construction_can_happen_multiple_times(self):
    """Calling ``Lazy.construct`` twice must work and yield equivalent objects
    (the underlying params must not be consumed by the first call)."""
    test_string = "this is a test"
    extra_string = "extra string"

    class ConstructedObject(FromParams):
        def __init__(self, string: str, extra: str):
            self.string = string
            self.extra = extra

    class Testing(FromParams):
        def __init__(self, lazy_object: Lazy[ConstructedObject]):
            # Both constructions should succeed with identical results.
            first_time = lazy_object.construct(extra=extra_string)
            second_time = lazy_object.construct(extra=extra_string)
            assert first_time.string == test_string
            assert first_time.extra == extra_string
            assert second_time.string == test_string
            assert second_time.extra == extra_string

    Testing.from_params(Params({"lazy_object": {"string": test_string}}))
def test_lazy_and_from_params_can_be_pickled(self):
    """Objects built through nested ``Lazy`` params must survive pickling."""
    import pickle

    instance = Baz.from_params(Params({"bar": {"foo": {"a": 2}}}))
    pickle.dumps(instance)
def test_optional_vs_required_lazy_objects(self):
    """Exercises the four flavors of ``Lazy`` parameters: required, defaulted,
    defaulted-to-None, and ``Optional`` with a default instance."""

    class ConstructedObject(FromParams):
        def __init__(self, a: int):
            self.a = a

    class Testing(FromParams):
        def __init__(
            self,
            lazy1: Lazy[ConstructedObject],
            lazy2: Lazy[ConstructedObject] = Lazy(ConstructedObject),
            lazy3: Lazy[ConstructedObject] = None,
            lazy4: Optional[Lazy[ConstructedObject]] = Lazy(ConstructedObject),
        ) -> None:
            self.lazy1 = lazy1.construct()
            self.lazy2 = lazy2.construct(a=2)
            self.lazy3 = None if lazy3 is None else lazy3.construct()
            self.lazy4 = None if lazy4 is None else lazy4.construct(a=1)

    # Only the required lazy1 is given; defaults fill in the rest.
    test1 = Testing.from_params(Params({"lazy1": {"a": 1}}))
    assert test1.lazy1.a == 1
    assert test1.lazy2.a == 2
    assert test1.lazy3 is None
    assert test1.lazy4 is not None

    # Overriding a defaulted lazy.
    test2 = Testing.from_params(Params({"lazy1": {"a": 1}, "lazy2": {"a": 3}}))
    assert test2.lazy1.a == 1
    assert test2.lazy2.a == 3
    assert test2.lazy3 is None
    assert test2.lazy4 is not None

    # Supplying the None-defaulted lazy and explicitly nulling the optional one.
    test3 = Testing.from_params(Params({"lazy1": {"a": 1}, "lazy3": {"a": 3}, "lazy4": None}))
    assert test3.lazy1.a == 1
    assert test3.lazy2.a == 2
    assert test3.lazy3 is not None
    assert test3.lazy3.a == 3
    assert test3.lazy4 is None

    # The required lazy1 cannot be omitted.
    with pytest.raises(ConfigurationError, match='Missing key "lazy1" for Testing'):
        Testing.from_params(Params({}))
def test_wrapper_kwargs_passed_down(self):
    """Extras given to ``from_params`` flow through ``**kwargs`` to base classes."""

    class Inner:
        def __init__(self, x: int = 1):
            self.x = x

    class Wrapper(Inner, FromParams):
        def __init__(self, y: int = 2, **kwargs):
            super().__init__(**kwargs)
            self.y = y

    # `x` is an extra: not in the params, not declared by Wrapper's __init__.
    wrapped = Wrapper.from_params(Params({"y": 3}), x=2)
    assert wrapped.x == 2
def test_iterable(self):
    """``Iterable[A]`` parameters construct each element via its registered type."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int) -> None:
            self.size = size

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, items: Iterable[A]) -> None:
            self.items = items

    params = Params(
        {"type": "d", "items": [{"type": "b", "size": 1}, {"type": "b", "size": 2}]}
    )
    d: D = C.from_params(params)  # type: ignore[assignment]
    assert isinstance(d.items, Iterable)
    # Materialize to check the constructed elements.
    items = list(d.items)
    assert len(items) == 2
    assert all(isinstance(item, B) for item in items)
    assert items[0].size == 1  # type: ignore
    assert items[1].size == 2  # type: ignore
def test_mapping(self):
    """``Mapping[str, A]`` parameters construct each value via its registered type."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int) -> None:
            self.size = size

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, items: Mapping[str, A]) -> None:
            self.items = items

    params = Params(
        {
            "type": "d",
            "items": {"first": {"type": "b", "size": 1}, "second": {"type": "b", "size": 2}},
        }
    )
    d: D = C.from_params(params)  # type: ignore[assignment]
    assert isinstance(d.items, Mapping)
    assert len(d.items) == 2
    assert all(isinstance(key, str) for key in d.items.keys())
    assert all(isinstance(value, B) for value in d.items.values())
    assert d.items["first"].size == 1  # type: ignore
    assert d.items["second"].size == 2  # type: ignore
def test_custom_abc_mapping(self):
    """A user-defined ``abc.Mapping`` subclass should be built through its own
    constructor rather than treated as a generic mapping annotation."""
    from collections import abc

    class CustomMapping(abc.Mapping):
        def __init__(self, data: Dict[str, int]):
            self.data = data

        def __getitem__(self, key):
            return self.data[key]

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

    class ClassWithCustomMapping(FromParams):
        def __init__(self, mapping: CustomMapping):
            self.mapping = mapping

    # The nested dict is passed to CustomMapping.__init__ as `data`.
    o = ClassWithCustomMapping.from_params({"mapping": {"data": {"a": 1}}})
    assert isinstance(o.mapping, CustomMapping)
    assert o.mapping["a"] == 1
def test_extra_parameters_are_not_allowed_when_there_is_no_constructor(self):
    """With only the default constructor, any supplied parameter is an error."""

    class A(FromParams):
        pass

    bogus = Params({"some_spurious": "key", "value": "pairs"})
    with pytest.raises(ConfigurationError, match="Extra parameters"):
        A.from_params(bogus)
def test_explicit_kwargs_always_passed_to_constructor(self):
    """Parameters from the config and kwargs hard-coded in a subclass's
    ``super().__init__`` call must both reach the base constructor."""

    class Base(FromParams):
        def __init__(self, lazy: bool = False, x: int = 0) -> None:
            self.lazy = lazy
            self.x = x

    class A(Base):
        def __init__(self, **kwargs) -> None:
            # The config-supplied `lazy` must show up in **kwargs.
            assert "lazy" in kwargs
            super().__init__(**kwargs)

    A.from_params(Params({"lazy": False}))

    class B(Base):
        def __init__(self, **kwargs) -> None:
            # Hard-coded kwargs in the super() call take effect too.
            super().__init__(lazy=True, **kwargs)

    b = B.from_params(Params({}))
    assert b.lazy is True
def test_raises_when_there_are_no_implementations(self):
    """A ``Registrable`` base with no registered subclasses raises clear
    ``ConfigurationError``s for unknown or missing ``type`` keys."""

    class A(Registrable):
        pass

    with pytest.raises(ConfigurationError, match="not in acceptable choices for type"):
        A.from_params("nonexistent_class")

    with pytest.raises(ConfigurationError, match='key "type" is required'):
        A.from_params(Params({"some_spurious": "key", "value": "pairs"}))

    with pytest.raises(ConfigurationError, match='key "type" is required'):
        A.from_params(Params({}))

    # Some paths through the code are different if there is a constructor here versus not. We
    # don't actually go through this logic anymore, but it's here as a regression test.
    class B(Registrable):
        def __init__(self):
            pass

    with pytest.raises(ConfigurationError, match="not in acceptable choices for type"):
        B.from_params("nonexistent_class")

    with pytest.raises(ConfigurationError, match='key "type" is required'):
        B.from_params(Params({"some_spurious": "key", "value": "pairs"}))

    with pytest.raises(ConfigurationError, match='key "type" is required'):
        B.from_params(Params({}))
def test_from_params_raises_error_on_wrong_parameter_name_in_optional_union(self):
    """A bad key inside an ``Optional[Union[...]]`` member must raise, not be
    silently swallowed while trying the other union members."""

    class NestedClass(FromParams):
        def __init__(self, varname: Optional[str] = None):
            self.varname = varname

    class WrapperClass(FromParams):
        def __init__(self, nested_class: Optional[Union[str, NestedClass]] = None):
            if isinstance(nested_class, str):
                nested_class = NestedClass(varname=nested_class)
            self.nested_class = nested_class

    # "wrong_varname" matches no constructor parameter of NestedClass.
    with pytest.raises(ConfigurationError):
        WrapperClass.from_params(Params({"nested_class": {"wrong_varname": "varstring"}}))
def test_from_params_handles_base_class_kwargs(self):
    """Unknown params flow into a catch-all ``**kwargs`` and become attributes,
    both on the class itself and through a subclass chain."""

    class Foo(FromParams):
        def __init__(self, a: int, b: str = None, **kwargs) -> None:
            self.a = a
            self.b = b
            # Catch-all: unknown params become attributes.
            for key, value in kwargs.items():
                setattr(self, key, value)

    foo = Foo.from_params(Params({"a": 2, "b": "hi"}))
    assert foo.a == 2
    assert foo.b == "hi"

    # An extra key "c" lands in **kwargs.
    foo = Foo.from_params(Params({"a": 2, "b": "hi", "c": {"2": "3"}}))
    assert foo.a == 2
    assert foo.b == "hi"
    assert foo.c == {"2": "3"}  # type: ignore[attr-defined]

    class Bar(Foo):
        def __init__(self, a: int, b: str, d: int, **kwargs) -> None:
            super().__init__(a, b=b, **kwargs)
            self.d = d

    bar = Bar.from_params(Params({"a": 2, "b": "hi", "c": {"2": "3"}, "d": 0}))
    assert bar.a == 2
    assert bar.b == "hi"
    assert bar.c == {"2": "3"}  # type: ignore[attr-defined]
    assert bar.d == 0

    class Baz(Foo):
        def __init__(self, a: int, b: Optional[str] = "a", **kwargs) -> None:
            super().__init__(a, b=b, **kwargs)

    # An explicit None overrides the default; omitting the key uses it.
    baz = Baz.from_params(Params({"a": 2, "b": None}))
    assert baz.b is None

    baz = Baz.from_params(Params({"a": 2}))
    assert baz.b == "a"
def test_from_params_base_class_kwargs_crashes_if_params_not_handled(self):
    """Params forwarded via ``**kwargs`` must match some constructor up the
    chain; a key nobody accepts raises ``TypeError``."""

    class Bar(FromParams):
        def __init__(self, c: str = None) -> None:
            self.c = c

    class Foo(Bar):
        def __init__(self, a: int, b: str = None, **kwargs) -> None:
            super().__init__(**kwargs)
            self.a = a
            self.b = b

    # "c" is handled by the base class via **kwargs.
    foo = Foo.from_params(Params({"a": 2, "b": "hi", "c": "some value"}))
    assert foo.a == 2
    assert foo.b == "hi"
    assert foo.c == "some value"

    # "invalid_key" matches no constructor anywhere in the chain.
    with pytest.raises(TypeError, match="invalid_key"):
        Foo.from_params(Params({"a": 2, "b": "hi", "invalid_key": "some value"}))
def test_from_params_handles_kwargs_in_non_from_params_registered_class(self):
    """``**kwargs`` handling also works when the registered class's other base
    is a plain (non-``FromParams``) class."""

    class Bar(Registrable):
        pass

    # Plain base class that knows nothing about FromParams.
    class Baz:
        def __init__(self, a: int) -> None:
            self.a = a

    @Bar.register("foo")
    class Foo(Baz):
        def __init__(self, a: int, b: str = None, **kwargs) -> None:
            super().__init__(a)
            self.b = b
            for key, value in kwargs.items():
                setattr(self, key, value)

    foo: Foo = Bar.from_params(Params({"type": "foo", "a": 2, "b": "hi"}))  # type: ignore[assignment]
    assert foo.a == 2
    assert foo.b == "hi"

    foo = Bar.from_params(  # type: ignore[assignment]
        Params({"type": "foo", "a": 2, "b": "hi", "c": {"2": "3"}})
    )
    assert foo.a == 2  # type: ignore[attr-defined]
    assert foo.b == "hi"  # type: ignore[attr-defined]
    assert foo.c == {"2": "3"}  # type: ignore[attr-defined]
def test_from_params_passes_extras_to_non_from_params_registered_class(self):
    """Extras given to ``from_params`` reach a plain (non-``FromParams``) base
    class through the registered subclass's ``**kwargs``."""

    class Bar(Registrable):
        pass

    # Plain base class with defaults that extras can override.
    class Baz:
        def __init__(self, a: int, c: Dict[str, str] = None, extra: str = "idk") -> None:
            self.a = a
            self.c = c
            self.extra = extra

    @Bar.register("foo")
    class Foo(Baz):
        def __init__(self, a: int, b: str = None, **kwargs) -> None:
            super().__init__(a, **kwargs)
            self.b = b

    foo: Foo = Bar.from_params(Params({"type": "foo", "a": 2, "b": "hi"}))  # type: ignore[assignment]
    assert foo.a == 2
    assert foo.b == "hi"
    assert foo.c is None

    # `extra` is supplied as an extra kwarg, `c` via the params.
    foo = Bar.from_params(  # type: ignore[assignment]
        Params({"type": "foo", "a": 2, "b": "hi", "c": {"2": "3"}}), extra="4"
    )
    assert foo.a == 2  # type: ignore[attr-defined]
    assert foo.b == "hi"  # type: ignore[attr-defined]
    assert foo.c == {"2": "3"}  # type: ignore[attr-defined]
    assert foo.extra == "4"  # type: ignore[attr-defined]
    def test_from_params_child_has_kwargs_base_implicit_constructor(self):
        """A child with ``**kwargs`` whose base has only the implicit (default)
        constructor still constructs cleanly from params."""

        class Foo(FromParams):
            pass

        class Bar(Foo):
            def __init__(self, a: int, **kwargs) -> None:
                self.a = a

        bar = Bar.from_params(Params({"a": 2}))
        assert bar.a == 2
    def test_from_params_has_args(self):
        """A constructor with a bare ``*args`` parameter must not break
        ``from_params`` construction of the named arguments."""

        class Foo(FromParams):
            def __init__(self, a: int, *args) -> None:
                self.a = a

        foo = Foo.from_params(Params({"a": 2}))
        assert foo.a == 2
    def test_from_params_with_dataclass(self):
        """Dataclasses work with ``from_params``, including strict type checking
        of field values (an int for a ``str`` field raises ``TypeError``)."""

        @dataclass
        class Foo(FromParams):
            x: int
            y: str

        assert Foo.from_params({"x": 1, "y": "2"}).x == 1
        with pytest.raises(TypeError):
            # y must be a str; 2 is an int.
            Foo.from_params({"x": 1, "y": 2})
    def test_to_params(self):
        """``to_params`` round-trips a nested FromParams dataclass back to the
        original plain-dict parameters."""

        @dataclass
        class Bar(FromParams):
            z: bool

        @dataclass
        class Foo(FromParams):
            x: int
            bar: Bar

        params_dict = {"x": 1, "bar": {"z": True}}
        # deepcopy so from_params can't mutate the dict we compare against below.
        foo = Foo.from_params(deepcopy(params_dict))
        assert foo.bar.z

        params = foo.to_params()
        assert params.as_dict() == params_dict
    def test_to_params_needs_custom_to_params(self):
        """``to_params`` raises ``NotImplementedError`` when a field holds an
        object (plain ``Bar``) that is not itself ``FromParams`` and so cannot
        be serialized automatically."""

        @dataclass
        class Bar:
            z: bool

        @dataclass
        class Foo(FromParams):
            x: int
            bar: Bar

        # bar is supplied programmatically, not from params.
        foo = Foo.from_params({"x": 1}, bar=Bar(z=True))
        with pytest.raises(NotImplementedError):
            foo.to_params()
    @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher")
    def test_type_hinting_generics_from_std_collections(self):
        """PEP 585 builtin generics (``list[...]``, ``dict[...]``) are resolved
        the same way as ``typing.List``/``typing.Dict`` annotations."""

        class Item(FromParams):
            def __init__(self, a: int) -> None:
                self.a = a

        class ClassWithStdGenerics(FromParams):
            def __init__(self, x: list[Item], y: dict[str, Item]) -> None:  # type: ignore[syntax]
                self.x = x
                self.y = y

        o = ClassWithStdGenerics.from_params({"x": [{"a": 1}], "y": {"b": {"a": 1}}})
        assert isinstance(o.x, list)
        assert isinstance(o.x[0], Item)
        assert isinstance(o.y["b"], Item)
    def test_with_non_from_params_generics(self):
        """A generic class that is not ``FromParams`` can still be constructed
        when it appears (parameterized) as an annotation on a FromParams class."""
        T = TypeVar("T")

        class Item(Generic[T]):
            def __init__(self, x: T):
                self.x = x

        class ClassWithGenerics(FromParams):
            def __init__(self, item: Item[T]):
                self.item = item

        o = ClassWithGenerics.from_params({"item": {"x": 1}})
        assert isinstance(o.item, Item)
    @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python 3.10 or higher")
    def test_with_union_pipe(self):
        """PEP 604 union syntax (``X | Y``) is handled like ``Union[X, Y]``:
        the dict value is constructed as an ``Item``."""

        class Item(FromParams):
            def __init__(self, a: int) -> None:
                self.a = a

        class ClassWithUnionType(FromParams):
            def __init__(self, x: Item | str):  # type: ignore[syntax]
                self.x = x

        o = ClassWithUnionType.from_params({"x": {"a": 1}})
        assert isinstance(o.x, Item)
    def test_from_params_with_function(self):
        """
        Tests that a function registered as a constructor for a registrable class
        will properly construct arguments.
        """

        class MyRegistrableClass(Registrable):
            def __init__(self, a: int, b: int):
                self.a = a
                self.b = b

        @dataclass
        class OptionsClass(FromParams):
            a: int
            b: int

        # Registering a *function* (not a class): its own signature drives
        # argument construction, so "options" is built as an OptionsClass.
        @MyRegistrableClass.register("func_constructor")  # type: ignore
        def constructor(options: OptionsClass) -> MyRegistrableClass:
            assert isinstance(options, OptionsClass)
            return MyRegistrableClass(options.a, options.b)

        MyRegistrableClass.from_params({"type": "func_constructor", "options": {"a": 1, "b": 2}})
    def test_from_params_passes_no_extra_args_in_factory_construction(self):
        """Extras given to ``Lazy.construct`` (here ``y``) must not be forwarded
        into a registered factory function that does not accept them."""

        class InnerBase(Registrable):
            pass

        from typing import Callable

        def innerbase_with_x_factory(cls) -> Callable[..., InnerBase]:
            # Factory only takes x (plus **kwargs); it never takes y or c.
            def factory(x: int, **kwargs) -> InnerBase:
                return cls(x=x, **kwargs)

            return factory

        class Inner(InnerBase):
            def __init__(self, x: int):
                self.x = x

        InnerBase.register("inner")(innerbase_with_x_factory(Inner))  # type: ignore[arg-type]

        class OuterBase(Registrable):
            default_implementation = "default"

            def __init__(self, y: str, i: InnerBase, c: int):
                self.i = i
                self.y = y
                self.c = c

        OuterBase.register("default")(OuterBase)

        config = {"c": 4, "i": {"type": "inner", "x": 5}}
        outer_lazy = Lazy(OuterBase, Params(config))
        # y is supplied at construct time; it must reach OuterBase but not the
        # inner factory.
        outer = outer_lazy.construct(y="placeholder")
        assert outer.i.x == 5  # type: ignore[attr-defined]
    def test_lazy_from_params_with_version(self):
        """``det_hash`` of a ``Lazy`` must change when the target class's
        ``VERSION`` changes, and differ between different registered constructors."""

        class Gizmo(Registrable):
            pass

        @Gizmo.register("widget")
        class WidgetGizmo(Gizmo, DetHashWithVersion):
            VERSION = "001"

            def __init__(self, x: int):
                self.x = x

            @classmethod
            def default(cls):
                return WidgetGizmo(0)

        # Same class registered again under a different name with an alternate
        # constructor (the classmethod "default").
        Gizmo.register("default_widget", "default")(WidgetGizmo)

        lazy = Lazy(Gizmo, params=Params({"type": "widget", "x": 1}))
        hash_before = det_hash(lazy)
        WidgetGizmo.VERSION = "001"  # unchanged version -> unchanged hash
        assert hash_before == det_hash(lazy)
        WidgetGizmo.VERSION = "002"  # bumped version -> hash must change
        assert hash_before != det_hash(lazy)
        assert lazy.construct().x == 1  # type: ignore[attr-defined]

        default_lazy = Lazy(
            Gizmo,
            params=Params(
                {
                    "type": "default_widget",
                }
            ),
        )
        # Different constructor -> different hash, even for the same class.
        assert hash_before != det_hash(default_lazy)
        assert det_hash(lazy) != det_hash(default_lazy)
        hash_before = det_hash(default_lazy)
        WidgetGizmo.VERSION = "003"
        assert hash_before != det_hash(default_lazy)
        assert default_lazy.construct().x == 0  # type: ignore[attr-defined]
    def test_from_params_that_takes_step_directly(self):
        """A FromParams class annotated with a ``Step`` subclass receives the
        constructed step object itself (not the step's result)."""

        class FakeStepBase(Step):
            def run(self, test_input: int) -> int:  # type: ignore
                return test_input

        @FakeStepBase.register("fake_step")
        class FakeStep(FakeStepBase):
            def run(self, test_input: int) -> int:  # type: ignore
                return test_input * 2

        class FromParamsWithStepInput(FromParams):
            def __init__(self, fake_step: FakeStepBase):
                self.fake_step = fake_step

        o = FromParamsWithStepInput.from_params(
            {"fake_step": {"type": "fake_step", "test_input": 1}}
        )
        assert isinstance(o.fake_step, FakeStep)
class MyClass(FromParams):
    """Module-level FromParams fixture: one required int and one defaulted bool."""

    def __init__(self, my_int: int, my_bool: bool = False) -> None:
        self.my_int = my_int
        self.my_bool = my_bool
class Foo(FromParams):
    """Module-level fixture with a fully-defaulted constructor."""

    def __init__(self, a: int = 1) -> None:
        self.a = a
class Bar(FromParams):
    """Module-level fixture that requires a nested ``Foo``."""

    def __init__(self, foo: Foo) -> None:
        self.foo = foo
class Baz(FromParams):
    """Module-level fixture taking a ``Lazy[Bar]``; ``bar`` constructs it on access."""

    def __init__(self, bar: Lazy[Bar]) -> None:
        self._bar = bar

    @property
    def bar(self):
        # NOTE: constructs a fresh Bar on every property access.
        return self._bar.construct()
================================================
FILE: tests/common/params_test.py
================================================
import json
import os
import re
from collections import OrderedDict
import pytest
from tango.common.exceptions import ConfigurationError
from tango.common.params import (
Params,
infer_and_cast,
remove_keys_from_params,
with_overrides,
)
from tango.common.testing import TangoTestCase
class TestParams(TangoTestCase):
    """Unit tests for ``tango.common.params``: loading, overrides, ordering,
    jsonnet features, env-var substitution, and helper utilities."""

    @pytest.mark.parametrize("extension", ["jsonnet", "yaml"])
    def test_load_from_file(self, extension):
        """Params load from both jsonnet and yaml fixture files."""
        filename = self.FIXTURES_ROOT / "common" / f"params_example.{extension}"
        params = Params.from_file(filename)
        assert params["model"]["type"] == "classifier"

    def test_replace_none(self):
        """The string "None" is converted to real ``None`` at any nesting depth."""
        params = Params({"a": "None", "b": [1.0, "None", 2], "c": {"d": "None"}})
        assert params["a"] is None
        assert params["b"][1] is None
        assert params["c"]["d"] is None

    def test_init_with_different_types(self):
        """Wrapping a Params in Params is a no-op for equality."""
        assert Params({"a": 1, "b": 2}) == Params(Params({"a": 1, "b": 2}))

    def test_bad_unicode_environment_variables(self):
        """An undecodable env var value must not break file loading."""
        filename = self.FIXTURES_ROOT / "common" / "params_example.jsonnet"
        os.environ["BAD_ENVIRONMENT_VARIABLE"] = "\udce2"
        Params.from_file(filename)
        del os.environ["BAD_ENVIRONMENT_VARIABLE"]

    def test_with_overrides(self):
        """Dotted override keys replace dict values, list items, and add new keys."""
        original = {
            "foo": {"bar": {"baz": 3}, "x": 0},
            "bar": ["a", "b", "c"],
            "baz": {"bar": 2, "y": 3, "x": [0, 1, 2]},
        }
        overrides = {
            "foo.bar": {"z": 2},
            "bar.0": "d",
            "baz.bar": 1,
            "baz.x": [0, 0],
            "z": 2,
        }
        assert with_overrides(original, overrides) == {
            # Note: "foo.bar" wholesale-replaces the sub-dict (baz is gone).
            "foo": {"bar": {"z": 2}, "x": 0},
            "bar": ["d", "b", "c"],
            "baz": {"bar": 1, "y": 3, "x": [0, 0]},
            "z": 2,
        }

    def test_bad_overrides(self):
        """Out-of-range list indices and indexing into scalars raise ValueError."""
        with pytest.raises(ValueError, match="contains unused keys"):
            with_overrides({"foo": [0, 1, 2]}, {"foo.3": 4})
        with pytest.raises(ValueError, match="expected list or dict"):
            with_overrides({"foo": 3}, {"foo.x": 2})

    @pytest.mark.parametrize("input_type", [dict, str])
    def test_overrides(self, input_type):
        """``from_file`` accepts overrides either as a dict or as a JSON string."""
        filename = self.FIXTURES_ROOT / "common" / "params_example.jsonnet"
        overrides = {
            "data_path": "train.txt",
            "model.type": "new_classifier",
            "model.layers.0.activation": "gelu",
            "model.layers.1": {"type": "classifier"},
        }
        params = Params.from_file(
            filename, overrides if input_type == dict else json.dumps(overrides)
        )
        assert params["data_path"] == "train.txt"
        assert params["model"]["type"] == "new_classifier"
        assert len(params["model"]["layers"]) == 2
        assert params["model"]["layers"][0]["activation"] == "gelu"
        assert params["model"]["layers"][1]["type"] == "classifier"

    def test_as_flat_dict(self):
        """Nested keys are flattened to dotted keys."""
        params = Params({"a": 10, "b": {"c": 20, "d": "stuff"}}).as_flat_dict()
        assert params == {"a": 10, "b.c": 20, "b.d": "stuff"}

    def test_jsonnet_features(self):
        """Jsonnet ``self`` references and object composition are evaluated."""
        config_file = self.TEST_DIR / "config.jsonnet"
        with open(config_file, "w") as f:
            f.write(
                """{
// This example is copied straight from the jsonnet docs
person1: {
name: "Alice",
welcome: "Hello " + self.name + "!",
},
person2: self.person1 { name: "Bob" },
}"""
            )
        params = Params.from_file(config_file)
        alice = params.pop("person1")
        bob = params.pop("person2")
        assert alice.as_dict() == {"name": "Alice", "welcome": "Hello Alice!"}
        assert bob.as_dict() == {"name": "Bob", "welcome": "Hello Bob!"}
        params.assert_empty("TestParams")

    def test_regexes_with_backslashes(self):
        """Backslashes must be doubled in jsonnet; a valid regex survives a
        read -> as_dict -> json.dumps -> read round trip."""
        bad_regex = self.TEST_DIR / "bad_regex.jsonnet"
        good_regex = self.TEST_DIR / "good_regex.jsonnet"
        with open(bad_regex, "w") as f:
            f.write(r'{"myRegex": "a\.b"}')
        with open(good_regex, "w") as f:
            f.write(r'{"myRegex": "a\\.b"}')
        with pytest.raises(RuntimeError):
            Params.from_file(bad_regex)
        params = Params.from_file(good_regex)
        regex = params["myRegex"]
        assert re.match(regex, "a.b")
        assert not re.match(regex, "a-b")
        # Check roundtripping
        good_regex2 = self.TEST_DIR / "good_regex2.jsonnet"
        with open(good_regex2, "w") as f:
            f.write(json.dumps(params.as_dict()))
        params2 = Params.from_file(good_regex2)
        assert params.as_dict() == params2.as_dict()

    def test_env_var_substitution(self):
        """``std.extVar`` resolves from the environment; missing vars raise."""
        substitutor = self.TEST_DIR / "substitutor.jsonnet"
        key = "TEST_ENV_VAR_SUBSTITUTION"
        assert os.environ.get(key) is None
        with open(substitutor, "w") as f:
            f.write(f'{{"path": std.extVar("{key}")}}')
        # raises without environment variable set
        with pytest.raises(RuntimeError):
            Params.from_file(substitutor)
        os.environ[key] = "PERFECT"
        params = Params.from_file(substitutor)
        assert params["path"] == "PERFECT"
        del os.environ[key]

    def test_as_ordered_dict(self):
        # keyD > keyC > keyE; keyDA > keyDB; Next all other keys alphabetically
        preference_orders = [["keyD", "keyC", "keyE"], ["keyDA", "keyDB"]]
        params = Params(
            {
                "keyC": "valC",
                "keyB": "valB",
                "keyA": "valA",
                "keyE": "valE",
                "keyD": {"keyDB": "valDB", "keyDA": "valDA"},
            }
        )
        ordered_params_dict = params.as_ordered_dict(preference_orders)
        expected_ordered_params_dict = OrderedDict(
            {
                "keyD": {"keyDA": "valDA", "keyDB": "valDB"},
                "keyC": "valC",
                "keyE": "valE",
                "keyA": "valA",
                "keyB": "valB",
            }
        )
        # Compare via json.dumps so key *order* is part of the assertion.
        assert json.dumps(ordered_params_dict) == json.dumps(expected_ordered_params_dict)

    def test_to_file(self):
        # Test to_file works with or without preference orders
        params_dict = {"keyA": "valA", "keyB": "valB"}
        expected_ordered_params_dict = OrderedDict({"keyB": "valB", "keyA": "valA"})
        params = Params(params_dict)
        file_path = self.TEST_DIR / "config.jsonnet"
        # check with preference orders
        params.to_file(file_path, [["keyB", "keyA"]])
        with open(file_path, "r") as handle:
            ordered_params_dict = OrderedDict(json.load(handle))
        assert json.dumps(expected_ordered_params_dict) == json.dumps(ordered_params_dict)
        # check without preference orders doesn't give error
        params.to_file(file_path)

    def test_infer_and_cast(self):
        """``infer_and_cast`` converts numeric/bool-looking strings; non-JSON
        values (like a type object) raise ValueError."""
        lots_of_strings = {
            "a": ["10", "1.3", "true"],
            "b": {"x": 10, "y": "20.1", "z": "other things"},
            "c": "just a string",
        }
        casted = {
            "a": [10, 1.3, True],
            "b": {"x": 10, "y": 20.1, "z": "other things"},
            "c": "just a string",
        }
        assert infer_and_cast(lots_of_strings) == casted
        contains_bad_data = {"x": 10, "y": int}
        with pytest.raises(ValueError, match="cannot infer type"):
            infer_and_cast(contains_bad_data)
        params = Params(lots_of_strings)
        assert params.as_dict() == lots_of_strings
        assert params.as_dict(infer_type_and_cast=True) == casted

    def test_pop_choice(self):
        """``pop_choice`` enforces membership, but dotted class names are allowed
        unless ``allow_class_names=False``."""
        choices = ["my_model", "other_model"]
        params = Params({"model": "my_model"})
        assert params.pop_choice("model", choices) == "my_model"
        params = Params({"model": "non_existent_model"})
        with pytest.raises(ConfigurationError):
            params.pop_choice("model", choices)
        params = Params({"model": "module.submodule.ModelName"})
        assert params.pop_choice("model", choices) == "module.submodule.ModelName"
        params = Params({"model": "module.submodule.ModelName"})
        with pytest.raises(ConfigurationError):
            params.pop_choice("model", choices, allow_class_names=False)

    def test_remove_keys_from_params(self):
        """``remove_keys_from_params`` strips a key from every nesting level."""
        filename = self.FIXTURES_ROOT / "common" / "params_example.jsonnet"
        params = Params.from_file(filename)
        assert params["model"]["layers"][0]["activation"] == "relu"
        assert params["model"]["layers"][1]["activation"] == "softmax"
        remove_keys_from_params(params, keys=["activation"])
        assert "activation" not in params["model"]["layers"][0]
        assert "activation" not in params["model"]["layers"][1]
================================================
FILE: tests/common/registrable_test.py
================================================
import pytest
from tango.common.exceptions import ConfigurationError
from tango.common.registrable import Registrable
from tango.common.testing import TangoTestCase
from tango.step import Step
class TestRegistrable(TangoTestCase):
    """Tests for the Registrable registry: registration, duplicates, overwrite
    with ``exist_ok``, name suggestions, reserved names, and module search."""

    def test_basic_functionality(self):
        class MockBaseClass(Registrable):
            pass

        assert "mock-1" not in MockBaseClass.list_available()

        @MockBaseClass.register("mock-1")
        class MockSubclass1(MockBaseClass):
            pass

        assert MockBaseClass in Registrable._registry
        assert MockBaseClass.by_name("mock-1") == MockSubclass1

        # Verify that registering under a name that already exists
        # causes a ConfigurationError.
        with pytest.raises(ConfigurationError):

            @MockBaseClass.register("mock-1")
            class MockAlternate(MockBaseClass):
                pass

        # Registering under a name that already exists should overwrite
        # if exist_ok=True.
        @MockBaseClass.register("mock-1", exist_ok=True)
        class MockAlternate2(MockBaseClass):
            pass

        assert MockBaseClass.by_name("mock-1") == MockAlternate2

        # Test that we get a suggestion when the name is close.
        with pytest.raises(ConfigurationError) as exc:
            MockBaseClass.by_name("mock_1")
        assert "did you mean 'mock-1'?" in str(exc.value)

    def test_registering_step_by_reserved_name(self):
        """"ref" is reserved for step references and cannot name a Step."""
        with pytest.raises(ConfigurationError, match="cannot use the name 'ref'"):

            @Step.register("ref")
            class BadStep(Step):
                pass

    def test_search_modules(self):
        """Searching a nonexistent module namespace must not raise."""
        Step.search_modules("foo-bar-baz-non-existent")
================================================
FILE: tests/common/sequences_test.py
================================================
import os
from tempfile import TemporaryDirectory
import pytest
from tango.common.sequences import (
ConcatenatedSequence,
MappedSequence,
ShuffledSequence,
SlicedSequence,
SqliteSparseSequence,
)
def assert_equal_including_exceptions(expected_fn, actual_fn):
    """Assert that ``actual_fn()`` matches ``expected_fn()``.

    If the reference callable raises, the callable under test must raise an
    exception of the same class; otherwise their return values must be equal.
    """
    try:
        want = expected_fn()
    except Exception as err:
        with pytest.raises(type(err)):
            actual_fn()
        return
    assert want == actual_fn()
def test_shuffled_sequence():
    """A ShuffledSequence keeps the same length and membership as its source."""
    seq = ShuffledSequence(list(range(10)))
    assert 5 in seq
    assert len(seq) == 10
def test_sliced_sequence():
    """SlicedSequence supports len, positive/negative indexing, and re-slicing."""
    seq = SlicedSequence(list(range(10)), slice(10))
    assert len(seq) == 10
    assert seq[0] == 0
    assert seq[-1] == 9
    # Slicing a SlicedSequence yields another sequence view.
    seq2 = seq[-2:]
    assert len(seq2) == 2
def test_concatenated_sequence():
    """Exhaustively compare a ConcatenatedSequence (with empty pieces mixed in)
    against the equivalent flat list for index(), __getitem__, count(), and
    __contains__, including out-of-range arguments and the missing item 999."""
    l1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    l2 = ConcatenatedSequence([0, 1], [], [2, 3, 4], [5, 6, 7, 8, 9], [])

    # __len__()
    assert len(l1) == len(l2)

    # index()
    for item in l1 + [999]:
        # no indices
        assert_equal_including_exceptions(lambda: l1.index(item), lambda: l2.index(item))

        # only start index
        for index in range(-15, 15):
            assert_equal_including_exceptions(
                lambda: l1.index(item, index), lambda: l2.index(item, index)
            )

        # start and stop index
        for start_index in range(-15, 15):
            for end_index in range(-15, 15):
                assert_equal_including_exceptions(
                    lambda: l1.index(item, start_index, end_index),
                    lambda: l2.index(item, start_index, end_index),
                )

    # __getitem__()
    for index in range(-15, 15):
        assert_equal_including_exceptions(lambda: l1[index], lambda: l2[index])
    for start_index in range(-15, 15):
        for end_index in range(-15, 15):
            assert_equal_including_exceptions(
                lambda: l1[start_index:end_index], lambda: list(l2[start_index:end_index])
            )

    # count()
    for item in l1 + [999]:
        assert_equal_including_exceptions(lambda: l1.count(item), lambda: l2.count(item))

    # __contains__()
    for item in l1 + [999]:
        assert_equal_including_exceptions(lambda: item in l1, lambda: item in l2)
def test_sqlite_sparse_sequence():
    """Exercise SqliteSparseSequence mutations: append, extend, insert, count,
    slicing, slice deletion, negative indexing, and clear."""
    with TemporaryDirectory(prefix="test_sparse_sequence-") as temp_dir:
        s = SqliteSparseSequence(os.path.join(temp_dir, "test.sqlite"))
        assert len(s) == 0
        s.extend([])
        assert len(s) == 0
        s.append("one")
        assert len(s) == 1
        s.extend(["two", "three"])
        s.insert(1, "two")
        assert s[1] == "two"
        assert s.count("two") == 2
        ss = s[1:3]
        assert list(ss) == ["two", "two"]
        del s[1:3]
        assert len(s) == 2
        assert s[-1] == "three"
        s.clear()
        assert len(s) == 0
def test_mapped_sequence():
    """MappedSequence applies the function lazily and agrees with itself
    whether accessed directly or through a slice."""
    my_very_long_sequence = ["John", "Paul", "George", "Ringo"]
    m = MappedSequence(lambda x: len(x), my_very_long_sequence)
    assert m[0] == 4
    assert len(m) == len(my_very_long_sequence)
    for i in range(len(m)):
        # Indexing directly must match indexing into a slice view.
        assert m[i] == m[i:][0]
================================================
FILE: tests/common/util_test.py
================================================
import os
import time
from pathlib import Path
import pytest
from flaky import flaky
from tango.common.testing import TangoTestCase
from tango.common.util import (
could_be_class_name,
find_integrations,
find_submodules,
resolve_module_name,
threaded_generator,
)
class TestResolveModuleName(TangoTestCase):
    """Tests for ``resolve_module_name``; each test builds a fake package layout
    under TEST_DIR, which setup_method makes the current working directory."""

    def setup_method(self):
        super().setup_method()
        # Save and switch cwd so relative paths resolve inside TEST_DIR.
        self._work_dir_restore = os.getcwd()
        os.chdir(self.TEST_DIR)

    def teardown_method(self):
        super().teardown_method()
        os.chdir(self._work_dir_restore)

    def test_with_package_init_file(self):
        """Pointing at a package ``__init__.py`` resolves to the dotted package name."""
        path = Path("fake_package/fake_module/__init__.py")
        (self.TEST_DIR / path.parent).mkdir(parents=True)
        open(path, "w").close()
        open(path.parent.parent / "__init__.py", "w").close()
        assert resolve_module_name(str(path)) == ("fake_package.fake_module", Path("."))

    def test_with_submodule(self):
        """Pointing at a package *directory* resolves the same way."""
        path = Path("fake_package/fake_module")
        (self.TEST_DIR / path).mkdir(parents=True)
        open(path / "__init__.py", "w").close()
        open(path.parent / "__init__.py", "w").close()
        assert resolve_module_name(str(path)) == ("fake_package.fake_module", Path("."))

    def test_with_module_in_child_directory(self):
        """A lone module in a non-package directory resolves with that directory
        as the import root."""
        path = Path("some_dir/fake_module.py")
        (self.TEST_DIR / path.parent).mkdir(parents=True)
        open(path, "w").close()
        assert resolve_module_name(str(path)) == ("fake_module", Path("./some_dir"))
def test_find_submodules():
    """``find_submodules`` honors recursion, a namespace filter, and excludes."""
    assert "tango.version" in set(find_submodules())
    assert "tango.common.registrable" in set(find_submodules())
    # Non-recursive: direct children only.
    assert "tango.common" in set(find_submodules(recursive=False))
    assert "tango.common.registrable" not in set(find_submodules(recursive=False))
    assert "tango.integrations.torch" in set(find_submodules("integrations"))
    assert "tango.integrations.torch" not in set(find_submodules(exclude={"tango.integrations*"}))
def test_find_integrations():
    """``find_integrations`` lists top-level integration modules, not their members."""
    integrations = set(find_integrations())
    assert "tango.integrations.torch" in integrations
    assert "tango.integrations.torch.format" not in integrations
@pytest.mark.parametrize(
    "name, result",
    [
        ("", False),
        ("foo.Bar", True),
        ("foo.Bar.", False),  # trailing dot is invalid
        ("1foo.Bar", False),  # identifiers can't start with a digit
        ("lib.my_package.MyClass", True),
    ],
)
def test_could_be_class_name(name: str, result: bool):
    """``could_be_class_name`` accepts only well-formed dotted identifiers."""
    assert could_be_class_name(name) is result
@flaky(max_runs=3)
def test_threaded_generator():
    """Smoke-test that ``threaded_generator`` overlaps producer and consumer
    sleeps. The time bound is deliberately loose (timing-based, hence @flaky)."""

    def generate_slowly():
        for i in range(10):
            yield i
            time.sleep(0.1)

    start = time.time()
    for i in threaded_generator(generate_slowly()):
        time.sleep(0.1)
    end = time.time()
    # NOTE(review): serial execution would take ~2s; this bound only guards
    # against pathological stalls, not against losing the concurrency win.
    assert end - start < 11
================================================
FILE: tests/end_to_end/test_dataset_dict_from_separate_steps.py
================================================
from typing import Any, Sequence
from tango import Format, JsonFormat, Step
from tango.common import DatasetDict
from tango.common.testing import run_experiment
@Step.register("train_data")
class TrainData(Step):
DETERMINISTIC = True
CACHEABLE = False
def run(self) -> Sequence[int]: # type: ignore
return list(range(10))
@Step.register("val_data")
class ValData(Step):
DETERMINISTIC = True
CACHEABLE = False
def run(self) -> Sequence[int]: # type: ignore
return list(range(10, 20))
@Step.register("save_data")
class SaveData(Step):
DETERMINISTIC = True
CACHEABLE = True
FORMAT: Format = JsonFormat()
def run(self, dataset_dict: DatasetDict) -> Any: # type: ignore
return dataset_dict.splits
def test_experiment():
    """A DatasetDict can be assembled from refs to two separate upstream steps."""
    with run_experiment(
        {
            "steps": {
                "train_data": {
                    "type": "train_data",
                },
                "val_data": {
                    "type": "val_data",
                },
                "saved_data": {
                    "type": "save_data",
                    # Each split is a reference to another step's output.
                    "dataset_dict": {
                        "splits": {
                            "train": {"type": "ref", "ref": "train_data"},
                            "val": {"type": "ref", "ref": "val_data"},
                        }
                    },
                },
            }
        }
    ) as run_dir:
        assert (run_dir / "saved_data").is_dir()
        fmt = JsonFormat()
        data = fmt.read(run_dir / "saved_data")
        assert data["train"] == list(range(10))
        assert data["val"] == list(range(10, 20))
================================================
FILE: tests/end_to_end/test_lazy_input_with_another_step.py
================================================
from dataclasses import dataclass
from tango import Format, JsonFormat, Step
from tango.common import FromParams, Lazy
from tango.common.testing import run_experiment
@dataclass
class Foo(FromParams):
    """Trivial FromParams payload used as a Lazy step input."""

    x: float  # NOTE(review): renamed nothing — field is `number` below; see original
@Step.register("generate_number")
class GenerateNumberStep(Step):
DETERMINISTIC = True
CACHEABLE = True
FORMAT: Format = JsonFormat()
def run(self) -> float: # type: ignore[override]
return 1.0
@Step.register("lazy_input")
class StepWithLazyInput(Step):
DETERMINISTIC = True
CACHEABLE = True
FORMAT: Format = JsonFormat()
def run(self, foo: Lazy[Foo]) -> float: # type: ignore[override]
foo = foo.construct()
assert isinstance(foo, Foo)
assert isinstance(foo.number, float)
return foo.number
def test_experiment():
    """A Lazy input whose field is a ref to another step resolves to that
    step's output (1.0) at construct time."""
    with run_experiment(
        {
            "steps": {
                "gen_number": {
                    "type": "generate_number",
                },
                "get_number": {
                    "type": "lazy_input",
                    "foo": {
                        "number": {
                            "type": "ref",
                            "ref": "gen_number",
                        }
                    },
                },
            }
        }
    ) as run_dir:
        assert (run_dir / "get_number").is_dir()
        fmt: Format = JsonFormat()
        data = fmt.read(run_dir / "get_number")
        assert data == 1.0
================================================
FILE: tests/end_to_end/test_multicore_cli.py
================================================
import pytest
from tango.common.exceptions import CliRunError
from tango.common.logging import initialize_logging, teardown_logging
from tango.common.testing import TangoTestCase
class TestExperiment(TangoTestCase):
    """Multicore CLI runs of a three-step pipeline where step1 is configured
    to fail (step2 depends on step1; step3 is independent)."""

    def setup_method(self):
        super().setup_method()
        initialize_logging()
        self.config = {
            "steps": {
                "step1": {
                    "type": "sleep-print-maybe-fail",
                    "string": "string_to_pass_down",
                    "seconds": 1,
                    "fail": True,
                },
                "step2": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step1"},
                    "seconds": 1,
                    "fail": False,
                },
                "step3": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This may or may not fail!",
                    "seconds": 3,
                    "fail": False,
                },
            }
        }

    def teardown_method(self):
        super().teardown_method()
        teardown_logging()

    def test_experiment(self, caplog):
        """step1 fails, so only step3 ends up cached in the run directory."""
        with pytest.raises(CliRunError):
            self.run(
                self.config,
                multicore=True,
                parallelism=2,
            )
        latest_outputs = self.TEST_DIR / "workspace" / "latest"
        # Count step dirs that actually produced a cached result.
        num_executed = 0
        for out in latest_outputs.iterdir():
            if (out / "cache-metadata.json").exists():
                num_executed += 1
        assert num_executed == 1

    def test_experiment_with_overrides(self, caplog):
        """Overriding step1.fail to False lets all three steps complete."""
        import json

        self.run(
            self.config,
            multicore=True,
            parallelism=2,
            overrides=json.dumps({"steps.step1.fail": False}),
        )
        latest_outputs = self.TEST_DIR / "workspace" / "latest"
        num_executed = 0
        for out in latest_outputs.iterdir():
            if (out / "cache-metadata.json").exists():
                num_executed += 1
        assert num_executed == 3
================================================
FILE: tests/end_to_end/test_non_cacheable_into_cacheable_multiple_runs.py
================================================
import random
from tango import Step
from tango.common.testing import TangoTestCase
@Step.register("give_me_a_number")
class GiveMeANumber(Step):
DETERMINISTIC = True
CACHEABLE = False
def run(self, what_number: int) -> int: # type: ignore
return what_number
@Step.register("random_int")
class RandomInt(Step):
DETERMINISTIC = True
CACHEABLE = True
def run(self, lower_bound: int, upper_bound: int) -> int: # type: ignore
return random.randint(lower_bound, upper_bound)
class TestExperiment(TangoTestCase):
    """Feeding an uncacheable step's output into a cacheable one, across two
    runs with different overrides, must execute cleanly both times."""

    def test_experiment(self, caplog):
        config = {
            "steps": {
                "a_number": {
                    "type": "give_me_a_number",
                    "what_number": 3,
                },
                "final_number": {
                    "type": "random_int",
                    "lower_bound": 0,
                    "upper_bound": {"type": "ref", "ref": "a_number"},
                },
            }
        }
        self.run(config)
        # Second run changes an input, forcing final_number to re-run.
        self.run(config, overrides={"steps.final_number.lower_bound": 1})
================================================
FILE: tests/end_to_end/test_registered_runs.py
================================================
from tango import Format, JsonFormat, Step
from tango.common.testing import TangoTestCase
@Step.register("return_a_number")
class ReturnANumber(Step):
DETERMINISTIC = True
CACHEABLE = True
FORMAT: Format = JsonFormat()
def run(self, what_number: int) -> int: # type: ignore
return what_number
class TestExperiment(TangoTestCase):
    """The "latest" symlinked run directory must always point at the output of
    the most recent run, even when inputs change between runs."""

    def test_experiment_updates_latest_run_output(self, caplog):
        config = {
            "steps": {
                "a_number": {
                    "type": "return_a_number",
                    "what_number": 3,
                },
            }
        }
        self.run(config)
        assert (self.TEST_DIR / "workspace" / "latest" / "a_number").exists()
        fmt: Format = JsonFormat()
        data = fmt.read(self.TEST_DIR / "workspace" / "latest" / "a_number")
        assert data == 3

        # Re-run with a different input; "latest" must reflect the new result.
        config = {
            "steps": {
                "a_number": {
                    "type": "return_a_number",
                    "what_number": 5,
                },
            }
        }
        self.run(config)
        data = fmt.read(self.TEST_DIR / "workspace" / "latest" / "a_number")
        assert data == 5
================================================
FILE: tests/end_to_end/test_run_single_step.py
================================================
from tango.common.testing import TangoTestCase
class TestRunSingleStep(TangoTestCase):
    """Running with ``step_name`` restricts execution to that step plus its
    transitive dependencies."""

    def test_run_single_step(self):
        config = {
            "steps": {
                "strA": {"type": "string", "result": "Hello, "},
                "strB": {"type": "string", "result": "World"},
                "concatenated": {
                    "type": "concat_strings",
                    "string1": {"type": "ref", "ref": "strA"},
                    "string2": {"type": "ref", "ref": "strB"},
                },
            }
        }
        num_other_files = 2  # out.log and stepinfo.json

        # Regular run contains all step outputs.
        self.run(config)
        latest_outputs = self.TEST_DIR / "workspace" / "latest"
        assert len(list(latest_outputs.iterdir())) == num_other_files + 3

        # Running a single step with no dependencies should have a single output.
        self.run(config, step_name="strB")
        latest_outputs = self.TEST_DIR / "workspace" / "latest"
        assert len(list(latest_outputs.iterdir())) == num_other_files + 1

        # Running a single step with one or more dependencies will also run the step's dependencies.
        self.run(config, step_name="concatenated")
        latest_outputs = self.TEST_DIR / "workspace" / "latest"
        assert len(list(latest_outputs.iterdir())) == num_other_files + 3
================================================
FILE: tests/end_to_end/test_step_indexing.py
================================================
from tango.common.testing import TangoTestCase
from tango.workspaces import LocalWorkspace
class TestStepIndexing(TangoTestCase):
    """A ref with a "key" indexes into the referenced step's sequence output."""

    def test_step_indexing(self):
        run_name = "run1"
        config = {
            "steps": {
                "list": {"type": "range_step", "start": 0, "end": 3},
                "added": {
                    "type": "add_numbers",
                    "a_number": 2,
                    # Take element [1] of the "list" step's output (== 1).
                    "b_number": {"type": "ref", "ref": "list", "key": 1},
                },
            }
        }
        self.run(config, name=run_name)
        workspace = LocalWorkspace(self.TEST_DIR / "workspace")
        result = workspace.step_result_for_run(run_name, "added")
        assert result == 3  # 2 + list[1]
================================================
FILE: tests/end_to_end/test_steps_that_fail.py
================================================
from collections import Counter
from typing import MutableMapping
import pytest
from tango import Step
from tango.common.exceptions import CliRunError
from tango.common.testing import TangoTestCase
# Module-level counter of how many times each step's run() executed; the test
# below uses it to verify which steps ran vs. hit the cache.
step_execution_count: MutableMapping[str, int] = Counter()


@Step.register("step_a")
class StepA(Step):
    """Identity step that records each execution in ``step_execution_count``."""

    def run(self, what_number: int) -> int:  # type: ignore
        global step_execution_count
        step_execution_count["a"] += 1
        return what_number
@Step.register("step_b")
class StepB(Step):
def run(self, what_number: int) -> int: # type: ignore
global step_execution_count
step_execution_count["b"] += 1
return what_number
# Toggle read by StepFail: the test flips this between runs.
step_should_fail: bool = True


@Step.register("step_fail")
class StepFail(Step):
    """Step that raises while ``step_should_fail`` is True, passes through
    otherwise; every attempt is counted."""

    def run(self, what_number: int) -> int:  # type: ignore
        global step_execution_count
        step_execution_count["fail"] += 1
        global step_should_fail
        if step_should_fail:
            raise RuntimeError("Step should fail")
        else:
            return what_number
class TestExperiment(TangoTestCase):
    """A failing step halts its dependents; a re-run after the failure is fixed
    reuses cached upstream results and only re-executes what it must."""

    def test_experiment(self, caplog):
        # ``step_should_fail`` is rebound below, so it needs a ``global``
        # declaration — exactly one (the original declared it twice).
        # ``step_execution_count`` is only mutated in place (Counter item
        # assignment), never rebound, so it needs no declaration here.
        global step_should_fail

        config = {
            "steps": {
                "a_number": {
                    "type": "step_a",
                    "what_number": 3,
                },
                "fail_number": {
                    "type": "step_fail",
                    "what_number": {"type": "ref", "ref": "a_number"},
                },
                "b_number": {
                    "type": "step_b",
                    "what_number": {"type": "ref", "ref": "fail_number"},
                },
            }
        }

        # First run: the failure propagates as a CliRunError.
        step_should_fail = True
        with pytest.raises(CliRunError):
            self.run(config)
        assert step_execution_count["a"] == 1
        assert step_execution_count["fail"] == 1
        assert step_execution_count["b"] == 0  # its input never materialized

        # Second run: "a" is cached, "fail" re-runs (now succeeding), "b" runs.
        step_should_fail = False
        self.run(config)
        assert step_execution_count["a"] == 1
        assert step_execution_count["fail"] == 2
        assert step_execution_count["b"] == 1
================================================
FILE: tests/end_to_end/test_uncacheable_leaf_steps.py
================================================
from tango import Step
from tango.common.testing import TangoTestCase, run_experiment
from tango.common.testing.steps import MakeNumber # noqa:F401
# Observed by the test below to confirm the uncacheable leaf step actually ran.
stored_number = None


@Step.register("store_number_globally")
class StoreNumberGlobally(Step):
    """Uncacheable leaf step that stashes its input in a module global."""

    DETERMINISTIC = True
    CACHEABLE = False

    def run(self, number: int) -> None:  # type: ignore
        # Rebinds the module global, hence the ``global`` declaration.
        global stored_number
        stored_number = number
class TestExperiment(TangoTestCase):
    """An uncacheable leaf step must still be executed (observed via the
    module-global side effect)."""

    def test_experiment(self, caplog):
        config = {
            "steps": {
                "a_number": {
                    "type": "make_number",
                    "what_number": 3,
                },
                "store_number": {
                    "type": "store_number_globally",
                    "number": {"type": "ref", "ref": "a_number"},
                },
            }
        }
        global stored_number
        assert stored_number is None
        self.run(config)
        assert stored_number == 3
class TestExperimentMulticore(TangoTestCase):
    """Same scenario under the multicore executor: globals don't cross process
    boundaries, so the side effect is observed through a file instead."""

    def test_experiment(self, caplog):
        file_name = self.TEST_DIR / "number_file.txt"
        assert not file_name.exists()
        with run_experiment(
            {
                "steps": {
                    "a_number": {
                        "type": "make_number",
                        "what_number": 3,
                    },
                    "store_number": {
                        "type": "store_number_in_file",
                        "number": {"type": "ref", "ref": "a_number"},
                        "file_name": str(file_name),
                    },
                }
            },
            multicore=True,
        ):
            with open(file_name) as file_ref:
                number = file_ref.read()
            assert int(number) == 3
================================================
FILE: tests/executor_test.py
================================================
from tango.common.testing import TangoTestCase
from tango.common.testing.steps import SleepPrintMaybeFail # noqa:F401
from tango.executor import Executor
from tango.step import Step
from tango.step_graph import StepGraph
from tango.workspaces import LocalWorkspace
@Step.register("sum_numbers")
class AdditionStep(Step):
DETERMINISTIC = True
CACHEABLE = True
def run(self, a: int, b: int) -> int: # type: ignore
return a + b
class TestExecutor(TangoTestCase):
    """Tests for the single-process ``Executor``: caching of successful steps
    and classification of failed / not-run steps."""

    def test_executor(self):
        workspace = LocalWorkspace(self.TEST_DIR)
        # A directly-constructed step with the same inputs as the graph's step;
        # deterministic steps share a unique id, so it keys the same cache entry.
        step = AdditionStep(a=1, b=2)
        step_graph = StepGraph.from_params({"sum": {"type": "sum_numbers", "a": 1, "b": 2}})
        executor = Executor(workspace)
        assert len(executor.workspace.step_cache) == 0
        output = executor.execute_step_graph(step_graph)
        assert "sum" in output.successful
        assert len(executor.workspace.step_cache) == 1
        assert executor.workspace.step_cache[step] == 3

    def test_executor_with_failing_steps(self):
        """One step fails, one succeeds, and the failing step's dependent is
        reported as not-run; only the success is cached."""
        workspace = LocalWorkspace(self.TEST_DIR)
        step_graph = StepGraph.from_params(
            {
                "successful_step": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This ran perfectly.",
                    "seconds": 0,
                    "fail": False,
                },
                "failing_step": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This should fail.",
                    "seconds": 0,
                    "fail": True,
                },
                "dependent_step": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "failing_step"},
                    "seconds": 0,
                    "fail": False,
                },
            }
        )
        executor = Executor(workspace)
        assert len(executor.workspace.step_cache) == 0
        output = executor.execute_step_graph(step_graph)
        assert "successful_step" in output.successful
        assert "failing_step" in output.failed
        assert "dependent_step" in output.not_run
        assert len(executor.workspace.step_cache) == 1
================================================
FILE: tests/executors/__init__.py
================================================
================================================
FILE: tests/executors/multicore_executor_test.py
================================================
import time
import pytest
from tango.common.logging import initialize_logging
from tango.common.testing import TangoTestCase
from tango.common.testing.steps import SleepPrintMaybeFail
from tango.executors.multicore_executor import MulticoreExecutor
from tango.step_graph import StepGraph
from tango.workspaces import LocalWorkspace
class TestMulticoreExecutor(TangoTestCase):
    """Tests for ``MulticoreExecutor``, which runs ready steps in parallel processes."""

    def setup_method(self):
        super().setup_method()
        # Worker processes are spawned by each test; set up logging first.
        initialize_logging()

    def test_simple_execution_in_parallel(self):
        # Two independent 5-second steps with parallelism=2 should overlap.
        step_graph = StepGraph(
            {
                "step1": SleepPrintMaybeFail(string="hello", seconds=5, fail=False),
                "step2": SleepPrintMaybeFail(string="hi", seconds=5, fail=False),
            }
        )
        executor = MulticoreExecutor(workspace=LocalWorkspace(self.TEST_DIR), parallelism=2)
        start_time = time.time()
        executor.execute_step_graph(step_graph)
        end_time = time.time()
        time_taken = end_time - start_time
        # Sequential execution would take ~10s; overlap keeps it under that.
        assert time_taken < 10  # TODO: will this be flaky?
        assert len(executor.workspace.step_cache) == 2

    def test_more_processes_ready_than_parallelism(self):
        # Three 5-second steps but only two worker slots: two waves of work,
        # hence total time between 10s and 20s.
        step_graph = StepGraph(
            {
                "step1": SleepPrintMaybeFail(string="hello", seconds=5, fail=False),
                "step2": SleepPrintMaybeFail(string="hi", seconds=5, fail=False),
                "step3": SleepPrintMaybeFail(string="howdy", seconds=5, fail=False),
            }
        )
        executor = MulticoreExecutor(workspace=LocalWorkspace(self.TEST_DIR), parallelism=2)
        start_time = time.time()
        executor.execute_step_graph(step_graph)
        end_time = time.time()
        time_taken = end_time - start_time
        assert 10 < time_taken < 20  # TODO: will this be flaky?
        assert len(executor.workspace.step_cache) == 3

    @pytest.mark.parametrize("parallelism", [1, 2, 3])
    def test_failing_step_no_downstream_task(self, parallelism):
        # step3 fails but nothing depends on it: step1 and step2 still cache.
        step_graph = StepGraph.from_params(
            {
                "step1": {
                    "type": "sleep-print-maybe-fail",
                    "string": "string_to_pass_down",
                    "seconds": 0,
                    "fail": False,
                },
                "step2": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step1"},
                    "seconds": 0,
                    "fail": False,
                },
                "step3": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This is going to fail!",
                    "seconds": 0,
                    "fail": True,
                },
            }
        )
        executor = MulticoreExecutor(
            workspace=LocalWorkspace(self.TEST_DIR),
            parallelism=parallelism,
        )
        executor.execute_step_graph(step_graph)
        assert len(executor.workspace.step_cache) == 2

    @pytest.mark.parametrize("parallelism", [1, 2, 3])
    def test_failing_step_with_downstream_task(self, parallelism):
        # step1 fails, so its dependent step2 never runs; only step3 is cached.
        step_graph = StepGraph.from_params(
            {
                "step1": {
                    "type": "sleep-print-maybe-fail",
                    "string": "string_to_pass_down",
                    "seconds": 0,
                    "fail": True,
                },
                "step2": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step1"},
                    "seconds": 0,
                    "fail": False,
                },
                "step3": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This is going to fail!",
                    "seconds": 0,
                    "fail": False,
                },
            }
        )
        executor = MulticoreExecutor(
            workspace=LocalWorkspace(self.TEST_DIR),
            parallelism=parallelism,
        )
        executor.execute_step_graph(step_graph)
        assert len(executor.workspace.step_cache) == 1

    @pytest.mark.parametrize("parallelism", [1, 2, 3])
    def test_failing_step_with_further_downstream_task(self, parallelism):
        # step1 fails; the whole step1 -> step2 -> step3 chain is skipped,
        # so nothing ends up in the cache.
        step_graph = StepGraph.from_params(
            {
                "step1": {
                    "type": "sleep-print-maybe-fail",
                    "string": "string_to_pass_down",
                    "seconds": 0,
                    "fail": True,
                },
                "step2": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step1"},
                    "seconds": 0,
                    "fail": False,
                },
                "step3": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step2"},
                    "seconds": 0,
                    "fail": False,
                },
            }
        )
        executor = MulticoreExecutor(
            workspace=LocalWorkspace(self.TEST_DIR),
            parallelism=parallelism,
        )
        executor.execute_step_graph(step_graph)
        assert len(executor.workspace.step_cache) == 0

    def test_uncacheable_failing_step_no_downstream_task(self):
        # The failing step3 is marked uncacheable; the two successes still cache.
        step_graph = StepGraph.from_params(
            {
                "step1": {
                    "type": "sleep-print-maybe-fail",
                    "string": "string_to_pass_down",
                    "seconds": 0,
                    "fail": False,
                },
                "step2": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step1"},
                    "seconds": 0,
                    "fail": False,
                },
                "step3": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This is going to fail!",
                    "seconds": 0,
                    "fail": True,
                    "cache_results": False,
                },
            }
        )
        executor = MulticoreExecutor(
            workspace=LocalWorkspace(self.TEST_DIR),
            parallelism=2,
        )
        executor.execute_step_graph(step_graph)
        assert len(executor.workspace.step_cache) == 2

    def test_uncacheable_failing_step_with_downstream_task(self):
        # Uncacheable step1 fails; dependent step2 is skipped; step3 caches.
        step_graph = StepGraph.from_params(
            {
                "step1": {
                    "type": "sleep-print-maybe-fail",
                    "string": "string_to_pass_down",
                    "seconds": 0,
                    "fail": True,
                    "cache_results": False,
                },
                "step2": {
                    "type": "sleep-print-maybe-fail",
                    "string": {"type": "ref", "ref": "step1"},
                    "seconds": 0,
                    "fail": False,
                },
                "step3": {
                    "type": "sleep-print-maybe-fail",
                    "string": "This is going to fail!",
                    "seconds": 0,
                    "fail": False,
                },
            }
        )
        executor = MulticoreExecutor(
            workspace=LocalWorkspace(self.TEST_DIR),
            parallelism=2,
        )
        executor.execute_step_graph(step_graph)
        assert len(executor.workspace.step_cache) == 1

    @pytest.mark.parametrize("parallelism", [1, 2, 3])
    def test_steps_with_their_own_multiprocessing(self, parallelism):
        # Steps that spawn their own worker processes must still complete
        # under the executor's process pool.
        step_graph = StepGraph.from_params(
            {
                "step1": {"type": "multiprocessing_step", "num_proc": 2},
                "step2": {"type": "multiprocessing_step", "num_proc": 3},
                "step3": {"type": "multiprocessing_step", "num_proc": 1},
            }
        )
        executor = MulticoreExecutor(
            workspace=LocalWorkspace(self.TEST_DIR),
            parallelism=parallelism,
        )
        executor.execute_step_graph(step_graph)
        assert len(executor.workspace.step_cache) == 3
================================================
FILE: tests/format_test.py
================================================
from typing import Dict, Iterable, Optional
import pytest
from tango.common.testing import TangoTestCase
from tango.format import _OPEN_FUNCTIONS, DillFormat, JsonFormat, TextFormat
class TestFormat(TangoTestCase):
    """Round-trip tests for the built-in ``Format`` implementations,
    parametrized over every supported compression scheme."""

    @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys())
    def test_dill_format(self, compress: Optional[str]):
        payload = "Hello, World!"
        fmt = DillFormat[str](compress)
        fmt.write(payload, self.TEST_DIR)
        assert fmt.read(self.TEST_DIR) == payload
        # The compression choice must survive serialization of the format itself.
        assert "compress" in fmt.to_params()

    @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys())
    def test_iterable_dill_format(self, compress: Optional[str]):
        fmt = DillFormat[Iterable[int]](compress)
        # Write a generator; reading back should reproduce its values.
        fmt.write((x + 1 for x in range(10)), self.TEST_DIR)
        restored = fmt.read(self.TEST_DIR)
        assert list(restored) == [x + 1 for x in range(10)]
        assert "compress" in fmt.to_params()

    @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys())
    def test_json_format(self, compress: Optional[str]):
        payload = {"Hello, World!": "Hi!"}
        fmt = JsonFormat[Dict[str, str]](compress)
        fmt.write(payload, self.TEST_DIR)
        assert fmt.read(self.TEST_DIR) == payload
        assert "compress" in fmt.to_params()

    @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys())
    def test_iterable_json_format(self, compress: Optional[str]):
        fmt = JsonFormat[Iterable[int]](compress)
        fmt.write((x + 1 for x in range(10)), self.TEST_DIR)
        restored = fmt.read(self.TEST_DIR)
        assert list(restored) == [x + 1 for x in range(10)]
        assert "compress" in fmt.to_params()

    def test_iterable_text_format(self):
        numbers = ["ichi", "ni", "san"]
        fmt = TextFormat()
        # Write from an iterator, read back as a sequence of lines.
        fmt.write(iter(numbers), self.TEST_DIR)
        assert list(fmt.read(self.TEST_DIR)) == numbers
================================================
FILE: tests/integrations/__init__.py
================================================
================================================
FILE: tests/integrations/beaker/__init__.py
================================================
================================================
FILE: tests/integrations/beaker/conftest.py
================================================
import os
import uuid
from pathlib import Path
from typing import Generator
import pytest
from beaker import Beaker
from tango.common import util
from tango.integrations.beaker.common import Constants
from tango.step import Step
@pytest.fixture(autouse=True)
def patched_cache_dir(tmp_path, monkeypatch) -> Path:
    """Point tango's cache directory at a per-test temporary path."""
    cache_dir = tmp_path
    monkeypatch.setattr(util, "tango_cache_dir", lambda: cache_dir)
    return cache_dir
@pytest.fixture(autouse=True)
def patched_unique_id_suffix(monkeypatch) -> str:
    """Tag every Step unique ID with a suffix unique to this CI run and test,
    so parallel runs against the shared Beaker workspace don't collide."""
    sha_part = os.environ.get("GITHUB_SHA", "")[:6]
    random_part = str(uuid.uuid1())[:6]
    UNIQUE_ID_SUFFIX = f"{sha_part}-{random_part}"
    monkeypatch.setattr(Step, "_UNIQUE_ID_SUFFIX", UNIQUE_ID_SUFFIX)
    return UNIQUE_ID_SUFFIX
@pytest.fixture(autouse=True)
def patched_constants_prefix(monkeypatch) -> str:
    """Prefix all Beaker artifact names with a run-unique string so that
    concurrent CI runs don't collide and teardown can find this run's artifacts."""
    PREFIX = os.environ.get("GITHUB_SHA", "A")[:6] + "-" + str(uuid.uuid1())[:6] + "-"
    monkeypatch.setattr(Constants, "STEP_ARTIFACT_PREFIX", "tango-step-" + PREFIX)
    monkeypatch.setattr(Constants, "RUN_ARTIFACT_PREFIX", "tango-run-" + PREFIX)
    monkeypatch.setattr(Constants, "ENTRYPOINT_DATASET_PREFIX", "tango-entrypoint-" + PREFIX)
    monkeypatch.setattr(Constants, "STEP_GRAPH_ARTIFACT_PREFIX", "tango-step-graph-" + PREFIX)
    # NOTE(review): STEP_EXPERIMENT_PREFIX uses the same "tango-step-" base as
    # STEP_ARTIFACT_PREFIX above — confirm that is intentional.
    monkeypatch.setattr(Constants, "STEP_EXPERIMENT_PREFIX", "tango-step-" + PREFIX)
    return PREFIX
@pytest.fixture
def beaker_workspace_name() -> str:
    """Name of the shared Beaker workspace used by these integration tests."""
    workspace = "ai2/tango-beaker-testing"
    return workspace
@pytest.fixture
def beaker_workspace(
    beaker_workspace_name: str, patched_unique_id_suffix: str, patched_constants_prefix: str
) -> Generator[str, None, None]:
    """Yield the shared Beaker workspace name, then delete what this test created.

    Cleanup matches datasets against the patched unique-id suffix and the
    patched constants prefix, so only this run's artifacts are removed.
    """
    beaker = Beaker.from_env(default_workspace=beaker_workspace_name)
    yield beaker_workspace_name
    # Remove experiments.
    # for experiment in beaker.workspace.experiments(match=patched_constants_prefix):
    #     beaker.experiment.delete(experiment)
    # Remove datasets.
    for dataset in beaker.workspace.datasets(match=patched_unique_id_suffix):
        beaker.dataset.delete(dataset)
    for dataset in beaker.workspace.datasets(match=patched_constants_prefix):
        beaker.dataset.delete(dataset)
================================================
FILE: tests/integrations/beaker/executor_test.py
================================================
import petname
import pytest
from beaker import DataMount
from tango.common.exceptions import ConfigurationError
from tango.common.testing import run_experiment
from tango.executor import Executor
from tango.integrations.beaker.executor import BeakerExecutor
from tango.integrations.beaker.workspace import BeakerWorkspace
from tango.settings import TangoGlobalSettings
from tango.workspaces import default_workspace
def test_from_params(beaker_workspace_name: str):
    """``Executor.from_params`` should build a BeakerExecutor and parse dataset mounts."""
    params = dict(
        type="beaker",
        beaker_workspace=beaker_workspace_name,
        beaker_image="ai2/conda",
        github_token="FAKE_TOKEN",
        datasets=[{"source": {"beaker": "some-dataset"}, "mount_path": "/input"}],
        budget="ai2/allennlp",
    )
    executor = Executor.from_params(
        params,
        workspace=BeakerWorkspace(workspace=beaker_workspace_name),
        clusters=["fake-cluster"],
    )
    assert isinstance(executor, BeakerExecutor)
    datasets = executor.datasets
    assert datasets is not None
    assert len(datasets) == 1
    mount = datasets[0]
    # The raw dict was parsed into a real DataMount object.
    assert isinstance(mount, DataMount)
    assert mount.source.beaker == "some-dataset"
def test_init_with_mem_workspace(beaker_workspace_name: str):
    """Constructing a BeakerExecutor on the in-memory default workspace must fail."""
    kwargs = dict(
        workspace=default_workspace,
        beaker_workspace=beaker_workspace_name,
        beaker_image="ai2/conda",
        github_token="FAKE_TOKEN",
        clusters=["fake-cluster"],
        budget="ai2/allennlp",
    )
    with pytest.raises(ConfigurationError, match="MemoryWorkspace"):
        BeakerExecutor(**kwargs)
@pytest.fixture
def settings(beaker_workspace_name: str) -> TangoGlobalSettings:
    """Global settings pointing both the workspace and the executor at Beaker."""
    workspace_config = {"type": "beaker", "beaker_workspace": beaker_workspace_name}
    executor_config = {
        "type": "beaker",
        "beaker_workspace": beaker_workspace_name,
        "install_cmd": "pip install .[beaker]",
        "clusters": ["ai2/allennlp-cirrascale", "ai2/general-cirrascale"],
        "budget": "ai2/allennlp",
    }
    return TangoGlobalSettings(workspace=workspace_config, executor=executor_config)
def test_beaker_executor(
    settings: TangoGlobalSettings, beaker_workspace_name: str, patched_unique_id_suffix: str
):
    """Submit a one-step experiment through the Beaker executor end to end
    and check the step shows up in the registered run."""
    # Random, human-readable run name so concurrent runs don't clash.
    run_name = petname.generate()
    with run_experiment(
        {"steps": {"hello": {"type": "string", "result": "Hello, World!"}}},
        settings=settings,
        workspace_url=f"beaker://{beaker_workspace_name}",
        name=run_name,
        multicore=None,
    ):
        workspace = BeakerWorkspace(workspace=beaker_workspace_name)
        assert "hello" in workspace.registered_run(run_name).steps
================================================
FILE: tests/integrations/beaker/step_cache_test.py
================================================
from tango.common.testing.steps import FloatStep
from tango.integrations.beaker.step_cache import BeakerStepCache
def test_step_cache(beaker_workspace: str):
    """Basic set/contains/len/get behavior of ``BeakerStepCache``."""
    cache = BeakerStepCache(beaker_workspace=beaker_workspace)
    cached_step = FloatStep(result=1.0)
    cache[cached_step] = 1.0
    assert cached_step in cache
    assert len(cache) == 1
    # A step with different params must not hit the same cache entry.
    assert FloatStep(result=2.0) not in cache
    assert cache[cached_step] == 1.0
================================================
FILE: tests/integrations/beaker/workspace_test.py
================================================
import pytest
from beaker import DatasetNotFound
from tango.common.testing.steps import FloatStep
from tango.integrations.beaker.workspace import BeakerWorkspace
from tango.step_info import StepState
from tango.workspace import Workspace
def test_from_url(beaker_workspace: str):
    """A ``beaker://`` URL should resolve to a ``BeakerWorkspace``."""
    # Fix: removed a leftover debug ``print(beaker_workspace)`` call.
    workspace = Workspace.from_url(f"beaker://{beaker_workspace}")
    assert isinstance(workspace, BeakerWorkspace)
def test_direct_usage(beaker_workspace: str):
    """Exercise the full step lifecycle directly against a Beaker workspace."""
    workspace = BeakerWorkspace(beaker_workspace)
    step = FloatStep(step_name="float", result=1.0)
    run = workspace.register_run([step])
    assert run.name in workspace.registered_runs()
    # Lifecycle: INCOMPLETE -> (step_starting) RUNNING -> (step_finished) COMPLETED.
    assert workspace.step_info(step).state == StepState.INCOMPLETE
    workspace.step_starting(step)
    assert workspace.step_info(step).state == StepState.RUNNING
    workspace.step_finished(step, 1.0)
    assert workspace.step_info(step).state == StepState.COMPLETED
    assert workspace.step_result_for_run(run.name, "float") == 1.0
def test_remove_step(beaker_workspace: str):
    """Removing a step deletes both its cache entry and its backing Beaker dataset."""
    # NOTE(review): the ``beaker_workspace`` fixture argument is immediately
    # shadowed by a hard-coded workspace name, which bypasses the fixture's
    # per-run cleanup — confirm this is intentional.
    beaker_workspace = "ai2/tango_remove_cache_test"
    workspace = BeakerWorkspace(beaker_workspace)
    step = FloatStep(step_name="float", result=1.0)
    workspace.step_starting(step)
    workspace.step_finished(step, 1.0)
    step_info = workspace.step_info(step)
    dataset_name = workspace.Constants.step_artifact_name(step_info)
    cache = workspace.step_cache
    # Before removal: the dataset exists and the step is cached.
    assert workspace.beaker.dataset.get(dataset_name) is not None
    assert step in cache
    workspace.remove_step(step.unique_id)
    # Re-fetch the cache after removal; both dataset and cache entry are gone.
    cache = workspace.step_cache
    dataset_name = workspace.Constants.step_artifact_name(step_info)
    with pytest.raises(DatasetNotFound):
        workspace.beaker.dataset.get(dataset_name)
    assert step not in cache
================================================
FILE: tests/integrations/datasets/__init__.py
================================================
================================================
FILE: tests/integrations/datasets/dataset_test.py
================================================
import datasets
from tango.common.sequences import MappedSequence
from tango.common.testing import TangoTestCase
from tango.integrations.datasets import (
DatasetRemixStep,
DatasetsFormat,
LoadDataset,
convert_to_tango_dataset_dict,
)
from tango.step import Step
class TestDatasets(TangoTestCase):
    """Tests for the Hugging Face ``datasets`` integration steps and converters."""

    def test_from_params_and_convert_to_tango_dataset_dict(self):
        # Build the load step from params, the way a config file would.
        step: LoadDataset = Step.from_params(  # type: ignore[assignment]
            {
                "type": "datasets::load",
                "path": "lhoestq/test",
                "cache_dir": str(self.TEST_DIR / "cache"),
            }
        )
        hf_dataset_dict = step.result()
        assert "train" in hf_dataset_dict
        # The converted object exposes the HF splits under ``.splits``.
        dataset_dict = convert_to_tango_dataset_dict(hf_dataset_dict)
        assert "train" in dataset_dict.splits

    def test_convert_to_tango_iterable_dataset_dict(self):
        def data_gen():
            for x in range(100):
                yield {"x": x}

        # Iterable (streaming-style) dataset dicts must convert as well.
        hf_dataset_dict = datasets.IterableDatasetDict(
            train=datasets.iterable_dataset.IterableDataset.from_generator(data_gen)
        )
        dataset_dict1 = convert_to_tango_dataset_dict(hf_dataset_dict)
        assert "train" in dataset_dict1.splits

    def test_load_concatenate_and_interleave(self):
        # Run a whole fixture config; HF cache dirs are redirected into TEST_DIR.
        result_dir = self.run(
            self.FIXTURES_ROOT / "integrations" / "datasets" / "config.json",
            overrides={
                "steps.train_data.cache_dir": str(self.TEST_DIR / "cache"),
                "steps.dev_data.cache_dir": str(self.TEST_DIR / "cache"),
            },
        )
        assert (result_dir / "train_data" / "data").is_dir()
        dataset = DatasetsFormat().read(result_dir / "train_data")
        assert len(dataset) == 2
def test_mapped_sequence_of_dataset():
    """``MappedSequence`` applies its function over an HF dataset, including slices."""
    ds = datasets.load_dataset("piqa", split="validation")
    goals = MappedSequence(lambda row: row["goal"], ds)  # type: ignore[arg-type]
    assert len(goals) == len(ds)  # type: ignore[arg-type]
    assert goals[0] == ds[0]["goal"]  # type: ignore[index]
    assert goals[:10][0] == ds[0]["goal"]  # type: ignore[index]
def test_datasets_dataset_remix():
    """``DatasetRemixStep`` builds new splits from slice-and-concat expressions."""
    dataset_dict = datasets.load_dataset("lhoestq/test")
    new_splits = {
        "all": "train + validation",
        "crossval_train": "train[:1] + validation[1:]",
        "crossval_test": "train[1:] + validation[:1]",
    }
    result = DatasetRemixStep().run(
        input=dataset_dict,  # type: ignore[arg-type]
        new_splits=new_splits,
    )
    expected_total = len(dataset_dict["train"]) + len(dataset_dict["validation"])  # type: ignore
    assert len(result["all"]) == expected_total  # type: ignore
    assert len(result["crossval_train"]) == 3
    assert len(result["crossval_test"]) == 2
================================================
FILE: tests/integrations/fairscale/__init__.py
================================================
================================================
FILE: tests/integrations/fairscale/train_test.py
================================================
from typing import Any, Dict
import pytest
import torch
from tango.common.logging import initialize_logging, teardown_logging
from tango.common.testing import TangoTestCase
class TestFairScaleTrain(TangoTestCase):
    """End-to-end training through the FairScale integration.

    Parametrized over FSDP, activation checkpointing, and AMP; the GPU-only
    combinations are skipped when fewer than two CUDA devices are available.
    """

    def setup_method(self):
        super().setup_method()
        initialize_logging(log_level="info")

    def teardown_method(self):
        # Fix: call the base-class teardown, for consistency with the other
        # integration test suites in this repo which all call
        # ``super().teardown_method()`` before their own cleanup.
        super().teardown_method()
        teardown_logging()

    @pytest.mark.parametrize(
        "fsdp",
        (
            pytest.param(
                True,
                id="fsdp=True",
                marks=[
                    pytest.mark.gpu,
                    pytest.mark.skipif(
                        torch.cuda.device_count() < 2, reason="Requires CUDA devices"
                    ),
                ],
            ),
            pytest.param(False, id="fsdp=False"),
        ),
    )
    @pytest.mark.parametrize(
        "activation_checkpoint",
        (
            pytest.param(True, id="checkpointing=True"),
            pytest.param(False, id="checkpointing=False"),
        ),
    )
    @pytest.mark.parametrize(
        "amp",
        (
            pytest.param(
                True,
                id="amp=True",
                marks=[
                    pytest.mark.gpu,
                    pytest.mark.skipif(
                        torch.cuda.device_count() < 2, reason="Requires CUDA devices"
                    ),
                ],
            ),
            pytest.param(False, id="amp=False"),
        ),
    )
    def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: bool):
        """Train a tiny GPT-2 from the fixture config under the given flags."""
        overrides: Dict[str, Any] = {
            "steps.trained_model.model.activation_checkpointing": activation_checkpoint,
        }
        training_engine: Dict[str, Any] = {
            "amp": amp,
            "optimizer": {
                "type": "torch::AdamW",
                "lr": 0.005,
                "betas": [0.9, 0.95],
                "eps": 1e-6,
            },
        }
        if fsdp:
            # FSDP: both the training engine and the model get the same config.
            training_engine["type"] = "fairscale"
            fsdp_config = {"reshard_after_forward": True, "mixed_precision": amp}
            training_engine["fsdp_config"] = fsdp_config
            overrides["steps.trained_model.model.fsdp_config"] = fsdp_config
        else:
            training_engine["type"] = "torch"
            overrides["steps.trained_model.model.fsdp_config"] = None
        overrides["steps.trained_model.training_engine"] = training_engine
        run_dir = self.run(
            self.FIXTURES_ROOT / "integrations" / "fairscale" / "config.jsonnet",
            include_package=["test_fixtures.integrations.fairscale.components"],
            overrides=overrides,
        )
        assert (run_dir / "trained_model").is_dir()
================================================
FILE: tests/integrations/flax/__init__.py
================================================
================================================
FILE: tests/integrations/flax/data_test.py
================================================
from typing import Dict
from transformers import AutoTokenizer
from tango.common.testing import TangoTestCase
from tango.integrations.flax import DataLoader, FlaxDataLoader
from tango.integrations.flax.util import get_PRNGkey
from tango.step import Step
class TestDataStep(TangoTestCase):
    """Tests for the flax ``DataLoader`` registrable and batch iteration."""

    def test_dataloader(self) -> None:
        # The default flax dataloader must be registered by name.
        assert "flax::dataloader" in DataLoader.list_available()

    def test_sample_data(self) -> None:
        # Load a small public dataset split through the datasets integration.
        step = Step.from_params(  # type: ignore[assignment]
            {
                "type": "datasets::load",
                "path": "lhoestq/demo1",
                "split": "train",
                "cache_dir": str(self.TEST_DIR / "cache"),
            }
        )
        dataset = step.result()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        column_names = dataset.column_names
        # Tokenize, then drop the original columns so only tokenizer outputs remain.
        dataset = dataset.map(
            lambda e: tokenizer(e["review"], truncation=True, padding="max_length")
        )
        dataset = dataset.remove_columns(column_names)
        data = FlaxDataLoader(dataset, batch_size=16)
        rng = get_PRNGkey()
        # Each yielded batch is a dict of model inputs.
        for batch in data(rng, do_distributed=False):
            assert isinstance(batch, Dict)
================================================
FILE: tests/integrations/flax/format_test.py
================================================
import os
from tango import Format
from tango.common.testing import TangoTestCase
from tango.integrations.flax.format import FlaxFormat
class TestFlaxFormat(TangoTestCase):
    """Round-trip test for ``FlaxFormat``.

    Renamed from ``TestTorchFormat``: the old name was a copy-paste leftover
    from the torch format tests and misdescribed what this suite covers.
    """

    def test_read_write(self):
        flax_format: FlaxFormat = Format.by_name("flax")()  # type: ignore[assignment]
        flax_format.write({"a": 1}, self.TEST_DIR)
        # The flax format writes its checkpoint as ``checkpoint_0``.
        assert os.path.exists(self.TEST_DIR / "checkpoint_0")
        data = flax_format.read(self.TEST_DIR)
        assert data == {"a": 1}
================================================
FILE: tests/integrations/flax/optim_test.py
================================================
from tango.integrations.flax.optim import LRScheduler, Optimizer
def test_all_optimizers_registered():
    """Optax optimizers should be registered under the ``optax::`` prefix."""
    available = Optimizer.list_available()
    assert "optax::adafactor" in available
def test_all_lr_schedulers_registered():
    """Optax schedules should be registered under the ``optax::`` prefix."""
    available = LRScheduler.list_available()
    assert "optax::constant_schedule" in available
================================================
FILE: tests/integrations/flax/train_test.py
================================================
from tango.common.logging import initialize_logging, teardown_logging
from tango.common.testing import TangoTestCase
class TestTrainStep(TangoTestCase):
    def setup_method(self):
        super().setup_method()
        initialize_logging(enable_cli_logs=True)

    def teardown_method(self):
        super().teardown_method()
        teardown_logging()

    def test_trainer(self):
        """End-to-end flax training run driven by the fixture config."""
        config_path = self.FIXTURES_ROOT / "integrations" / "flax" / "config.jsonnet"
        packages = [
            "test_fixtures.integrations.common",
            "test_fixtures.integrations.flax",
        ]
        result_dir = self.run(config_path, include_package=packages)
        # The trainer must have written at least one checkpoint file.
        checkpoint = (
            result_dir
            / "train"
            / "work"
            / "checkpoint_state_latest"
            / "checkpoint_0"
            / "checkpoint"
        )
        assert checkpoint.is_file()
================================================
FILE: tests/integrations/gs/__init__.py
================================================
================================================
FILE: tests/integrations/gs/step_cache_test.py
================================================
import os
import pytest
from tango.common.testing import TangoTestCase
from tango.common.testing.steps import FloatStep
from tango.integrations.gs.common import empty_bucket_folder
from tango.integrations.gs.step_cache import GSStepCache
# Target bucket for the tests; override via the GS_BUCKET_NAME env var.
GS_BUCKET_NAME = os.environ.get("GS_BUCKET_NAME", "allennlp-tango-bucket")
# A path below the bucket root, to exercise caches rooted in a subfolder.
GS_SUBFOLDER = f"{GS_BUCKET_NAME}/my-workspaces/workspace1"
class TestGSStepCache(TangoTestCase):
    def setup_method(self):
        super().setup_method()
        # Start from a clean bucket and folder so the length checks are exact.
        for location in (GS_BUCKET_NAME, GS_SUBFOLDER):
            empty_bucket_folder(location)

    def teardown_method(self):
        super().teardown_method()

    @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER])
    def test_step_cache(self, gs_path):
        """set/contains/len/get round trip against a GCS-backed step cache."""
        cache = GSStepCache(folder_name=gs_path)
        cached_step = FloatStep(result=1.0)
        cache[cached_step] = 1.0
        assert cached_step in cache
        assert len(cache) == 1
        # A step with different params must miss.
        assert FloatStep(result=2.0) not in cache
        assert cache[cached_step] == 1.0
================================================
FILE: tests/integrations/gs/workspace_test.py
================================================
import os
import pytest
from tango.common.testing import TangoTestCase
from tango.common.testing.steps import FloatStep
from tango.integrations.gs.common import empty_bucket_folder, empty_datastore
from tango.integrations.gs.workspace import GSWorkspace
from tango.step_info import StepState
from tango.workspace import Workspace
# Target bucket for the tests; override via the GS_BUCKET_NAME env var.
GS_BUCKET_NAME = os.environ.get("GS_BUCKET_NAME", "allennlp-tango-bucket")
# A path below the bucket root, to exercise workspaces rooted in a subfolder.
GS_SUBFOLDER = f"{GS_BUCKET_NAME}/my-workspaces/workspace1"
class TestGSWorkspace(TangoTestCase):
    """Integration tests for ``GSWorkspace`` against a real GCS bucket and datastore."""

    def setup_method(self):
        super().setup_method()
        # Clear both the bucket root and the nested folder, plus the matching
        # datastore entries, so every test starts from an empty workspace.
        empty_bucket_folder(GS_BUCKET_NAME)
        empty_bucket_folder(GS_SUBFOLDER)
        empty_datastore(GS_BUCKET_NAME)
        empty_datastore(GS_SUBFOLDER)

    def teardown_method(self):
        # No cleanup beyond the base class; the next setup clears GCS state.
        super().teardown_method()

    @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER])
    def test_from_url(self, gs_path: str):
        workspace = Workspace.from_url(f"gs://{gs_path}")
        assert isinstance(workspace, GSWorkspace)

    @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER])
    def test_from_params(self, gs_path: str):
        workspace = Workspace.from_params({"type": "gs", "workspace": gs_path})
        assert isinstance(workspace, GSWorkspace)

    @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER])
    def test_direct_usage(self, gs_path: str):
        # Lifecycle: INCOMPLETE -> (step_starting) RUNNING -> (step_finished) COMPLETED.
        workspace = GSWorkspace(gs_path)
        step = FloatStep(step_name="float", result=1.0)
        run = workspace.register_run([step])
        assert run.name in workspace.registered_runs()
        assert workspace.step_info(step).state == StepState.INCOMPLETE
        workspace.step_starting(step)
        assert workspace.step_info(step).state == StepState.RUNNING
        workspace.step_finished(step, 1.0)
        assert workspace.step_info(step).state == StepState.COMPLETED
        assert workspace.step_result_for_run(run.name, "float") == 1.0

    def test_remove_step(self):
        """Removing a step clears its cache entry, datastore record, and artifact."""
        workspace = GSWorkspace(GS_BUCKET_NAME)
        step = FloatStep(step_name="float", result=1.0)
        step_info = workspace.step_info(step)
        workspace.step_starting(step)
        workspace.step_finished(step, 1.0)
        bucket_artifact = workspace.Constants.step_artifact_name(step_info)
        ds_entity = workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id))
        cache = workspace.step_cache
        # Before removal: artifact, datastore entity, and cache entry all exist.
        assert workspace.client.artifacts(prefix=bucket_artifact) is not None
        assert ds_entity is not None
        assert step in cache
        workspace.remove_step(step.unique_id)
        # Re-read everything after removal.
        cache = workspace.step_cache
        ds_entity = workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id))
        with pytest.raises(Exception) as excinfo:
            workspace.client.artifacts(prefix=bucket_artifact)
        # NOTE(review): the missing artifact is detected by the exception text
        # mentioning "KeyError" — a fairly loose check; confirm the client's
        # actual exception type and tighten if possible.
        assert "KeyError" in str(excinfo)
        assert ds_entity is None
        assert step not in cache
================================================
FILE: tests/integrations/torch/__init__.py
================================================
================================================
FILE: tests/integrations/torch/data_test.py
================================================
import torch
from tango.integrations.torch.data import DataLoader, Sampler
def test_dataloader_from_params():
    """A DataLoader can be built from params with a plain list as the dataset."""
    params = {
        "dataset": list(range(10)),
        "batch_size": 2,
        "shuffle": True,
    }
    DataLoader.from_params(params)
def test_samplers_registered():
    """Torch's built-in samplers are registered under the ``torch::`` prefix."""
    available = Sampler.list_available()
    assert "torch::SequentialSampler" in available
def test_dataloader_from_params_with_sampler():
    """A sampler spec in the params is instantiated and attached to the loader."""
    loader = DataLoader.from_params(
        {
            "dataset": list(range(10)),
            "sampler": {
                "type": "torch::RandomSampler",
                "replacement": True,
            },
        }
    )
    sampler = loader.sampler
    assert isinstance(sampler, torch.utils.data.RandomSampler)
    # The nested sampler kwarg made it through construction.
    assert sampler.replacement
def test_dataloader_from_params_with_batch_sampler():
    """A nested batch-sampler spec (a sampler inside a sampler) also works."""
    loader = DataLoader.from_params(
        {
            "dataset": list(range(10)),
            "sampler": {
                "type": "torch::BatchSampler",
                "sampler": {
                    "type": "torch::RandomSampler",
                },
                "batch_size": 2,
                "drop_last": True,
            },
        }
    )
    assert isinstance(loader.sampler, torch.utils.data.BatchSampler)
================================================
FILE: tests/integrations/torch/det_hash_test.py
================================================
import numpy
import torch
from tango.common import det_hash
def test_numpy_det_hash():
    """``det_hash`` of equal numpy arrays should not depend on memory layout."""
    c_order = numpy.array([[1, 2], [3, 4]], order="C")
    k_order = numpy.array([[1, 2], [3, 4]], order="K")
    assert det_hash(c_order) == det_hash(k_order)
def test_torch_det_hash():
    """``det_hash`` of equal tensors should not depend on the source array's layout."""
    t1 = torch.tensor(numpy.array([[1, 2], [3, 4]], order="C"))
    t2 = torch.tensor(numpy.array([[1, 2], [3, 4]], order="K"))
    assert det_hash(t1) == det_hash(t2)
================================================
FILE: tests/integrations/torch/eval_test.py
================================================
from tango.common.testing import TangoTestCase
class TestEvalStep(TangoTestCase):
    def test_basic_eval(self):
        """Run the eval fixture config and check it writes its output file."""
        config_path = self.FIXTURES_ROOT / "integrations/torch/eval.jsonnet"
        packages = [
            "test_fixtures.integrations.common",
            "test_fixtures.integrations.torch",
        ]
        result_dir = self.run(config_path, include_package=packages)
        assert (result_dir / "eval" / "data.json").is_file()
================================================
FILE: tests/integrations/torch/format_test.py
================================================
import os
from tango import Format
from tango.common.testing import TangoTestCase
from tango.integrations.torch.format import TorchFormat
class TestTorchFormat(TangoTestCase):
    def test_read_write(self):
        """Write a dict with ``TorchFormat`` and read it back unchanged."""
        torch_format: TorchFormat = Format.by_name("torch")()  # type: ignore[assignment]
        payload = {"a": 1}
        torch_format.write(payload, self.TEST_DIR)
        # TorchFormat serializes everything into a single ``data.pt`` file.
        assert os.path.exists(self.TEST_DIR / "data.pt")
        assert torch_format.read(self.TEST_DIR) == payload
================================================
FILE: tests/integrations/torch/optim_test.py
================================================
from tango.integrations.torch.optim import LRScheduler, Optimizer
def test_all_optimizers_registered():
    """Torch optimizers are auto-registered under the ``torch::`` prefix."""
    available = Optimizer.list_available()
    assert "torch::Adagrad" in available
def test_all_lr_schedulers_registered():
    """Torch LR schedulers are auto-registered under the ``torch::`` prefix."""
    available = LRScheduler.list_available()
    assert "torch::ExponentialLR" in available
================================================
FILE: tests/integrations/torch/train_callback_test.py
================================================
from pathlib import Path
import pytest
from torch.optim import SGD
from tango.common import DatasetDict, Lazy
from tango.integrations.torch import (
DataLoader,
StopEarly,
StopEarlyCallback,
TorchTrainingEngine,
TrainConfig,
)
from tango.workspaces import MemoryWorkspace
from .training_engine_test import DummyModel
def test_stop_early_callback():
    """``StopEarlyCallback`` should raise ``StopEarly`` once more than
    ``patience`` steps pass without the validation metric improving."""
    workspace = MemoryWorkspace()
    train_config = TrainConfig(step_id="FakeStep-abc123", work_dir=Path("/tmp"))
    training_engine = TorchTrainingEngine(
        train_config=train_config, model=DummyModel(), optimizer=Lazy(SGD, lr=0.001)  # type: ignore
    )
    dataset_dict = DatasetDict(splits={"train": [1, 2, 3]})
    train_dataloader = Lazy(DataLoader)
    callback = StopEarlyCallback(
        patience=10,
        workspace=workspace,
        train_config=train_config,
        training_engine=training_engine,
        dataset_dict=dataset_dict,
        train_dataloader=train_dataloader,
    )
    # The metric improves (0.5 -> 0.6) at step 20...
    callback.post_val_loop(1, 1, 0.5, 0.5)
    callback.post_val_loop(2, 1, 0.5, 0.5)
    callback.post_val_loop(20, 1, 0.6, 0.6)
    # ...so by step 31 more than ``patience`` (10) steps have passed with no
    # new best, and the callback presumably stops training here.
    with pytest.raises(StopEarly):
        callback.post_val_loop(31, 1, 0.6, 0.6)
================================================
FILE: tests/integrations/torch/train_test.py
================================================
import json
import pytest
import torch.distributed as dist
from tango.common.logging import initialize_logging, teardown_logging
from tango.common.testing import TangoTestCase
class TestTrainStep(TangoTestCase):
def setup_method(self):
    """Per-test setup: enable CLI logging so training progress is visible."""
    super().setup_method()
    initialize_logging(enable_cli_logs=True)
def teardown_method(self):
super().teardown_method()
if dist.is_initialized():
dist.destroy_process_group()
teardown_logging()
@pytest.mark.parametrize("with_validation", [True, False])
def test_basic_train(self, with_validation: bool):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train.jsonnet",
include_package=[
"test_fixtures.integrations.common",
"test_fixtures.integrations.torch",
],
overrides=""
if with_validation
else json.dumps(
{"steps.train.validation_split": None, "steps.train.validate_every": None}
),
)
assert (result_dir / "train" / "data.pt").is_file()
assert (result_dir / "train" / "work" / "weights.pt").is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_latest" / "worker0_model.pt"
).is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_best" / "worker0_optimizer.pt"
).is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_best" / "worker0_trainer.pt"
).is_file()
@pytest.mark.parametrize("grad_acc", [1, 2])
def test_basic_train_with_epochs(self, grad_acc: int):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train.jsonnet",
include_package=[
"test_fixtures.integrations.common",
"test_fixtures.integrations.torch",
],
overrides=json.dumps(
{
"steps.train.train_steps": None,
"steps.train.train_epochs": 2,
"steps.train.validate_every": None,
"steps.train.grad_accum": grad_acc,
}
),
)
assert (result_dir / "train" / "data.pt").is_file()
# Make sure we trained for the right number of steps.
expected_steps = 16 // grad_acc
latest = result_dir / "train" / "work" / "checkpoint_state_latest"
assert latest.is_symlink()
last_step = result_dir / "train" / "work" / f"checkpoint_state_step{expected_steps}"
assert last_step.is_dir()
assert latest.samefile(last_step)
def test_basic_train_with_streaming_data(self):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train.jsonnet",
include_package=[
"test_fixtures.integrations.common",
"test_fixtures.integrations.torch",
],
)
assert (result_dir / "train" / "data.pt").is_file()
def test_train_distributed(self):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train_dist.jsonnet",
include_package=[
"test_fixtures.integrations.common",
"test_fixtures.integrations.torch",
],
)
assert (result_dir / "train" / "data.pt").is_file()
assert (result_dir / "train" / "work" / "weights.pt").is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_latest" / "worker0_model.pt"
).is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_best" / "worker0_model.pt"
).is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_latest" / "worker1_model.pt"
).is_file()
assert (
result_dir / "train" / "work" / "checkpoint_state_best" / "worker1_model.pt"
).is_file()
@pytest.mark.parametrize("grad_acc", [1, 2])
def test_train_distributed_with_epochs(self, grad_acc: int):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train_dist.jsonnet",
include_package=[
"test_fixtures.integrations.common",
"test_fixtures.integrations.torch",
],
overrides=json.dumps(
{
"steps.train.train_steps": None,
"steps.train.train_epochs": 2,
"steps.train.validate_every": None,
"steps.train.grad_accum": grad_acc,
}
),
)
assert (result_dir / "train" / "data.pt").is_file()
# Make sure we trained for the right number of steps.
expected_steps = 8 // grad_acc
latest = result_dir / "train" / "work" / "checkpoint_state_latest"
assert latest.is_symlink()
last_step = result_dir / "train" / "work" / f"checkpoint_state_step{expected_steps}"
assert last_step.is_dir()
assert latest.samefile(last_step)
================================================
FILE: tests/integrations/torch/training_engine_test.py
================================================
import time
import pytest
import torch
import torch.nn as nn
from torch.nn import MSELoss
from tango.common import DatasetDict, Lazy
from tango.common.testing import TangoTestCase
from tango.integrations.torch import (
DataLoader,
StopEarly,
TorchTrainStep,
TrainCallback,
)
from tango.integrations.torch.model import Model
from tango.integrations.torch.training_engine import TorchTrainingEngine
@Model.register("dummy_model")
class DummyModel(Model):
    """Minimal registrable model: a single 10 -> 1 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x, y=None):
        # `y` is accepted but unused.
        return self.linear(x)
@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
class TestTorchTrainingEngine(TangoTestCase):
    """GPU-only checks of TorchTrainingEngine checkpointing with AMP enabled."""

    def test_grad_scaler(self):
        engine = TorchTrainingEngine.from_params(
            {
                "train_config": {"step_id": "001", "work_dir": self.TEST_DIR},
                "model": {"type": "dummy_model"},
                "optimizer": {"type": "torch::Adam"},
                "amp": True,
            }
        )
        engine.save_checkpoint(self.TEST_DIR, {"training_steps": None})
        scaler_before = engine.grad_scaler
        engine.load_checkpoint(self.TEST_DIR)
        # The grad scaler state file should have been written, and the scaler
        # should survive the save/load round trip.
        assert (self.TEST_DIR / "worker0_grad_scaler.pt").is_file()
        assert engine.grad_scaler == scaler_before
class WorseningModel(Model):
    """A model whose loss grows with wall-clock time, so validation keeps getting worse."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(7, 1)
        self.loss = MSELoss()
        self.start_time = time.time()

    def forward(self, x, y):
        prediction = self.linear(x)
        # Sleep briefly so elapsed time (and hence the loss term below) keeps increasing.
        time.sleep(0.01)
        elapsed = time.time() - self.start_time
        return {"loss": self.loss(prediction, y) + elapsed}
class StopOnStepCallback(TrainCallback):
    """Raises StopEarly when the validation loop runs at the configured step."""

    def __init__(self, stop_on_step: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stop_on_step = stop_on_step

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        if step != self.stop_on_step:
            return
        raise StopEarly
def test_with_increasing_loss():
    """The train step should complete cleanly when a callback stops training early."""
    inputs = [torch.randn(7) for _ in range(100)]
    examples = [{"x": x, "y": x + 0.1} for x in inputs]
    dataset = DatasetDict(splits={"train": examples, "validation": examples}, metadata={})
    step = TorchTrainStep(
        model=WorseningModel(),
        training_engine=Lazy(TorchTrainingEngine, optimizer=Lazy(torch.optim.AdamW, lr=1e-5)),
        dataset_dict=dataset,
        train_dataloader=Lazy(DataLoader),
        train_steps=10,
        validation_steps=10,
        train_split="train",
        validation_split="validation",
        callbacks=[Lazy(StopOnStepCallback, stop_on_step=9)],
    )
    # Should not raise: the StopEarly from the callback is handled by the step.
    step.result()
================================================
FILE: tests/integrations/transformers/data_test.py
================================================
from transformers.data.data_collator import DataCollatorWithPadding, DefaultDataCollator
from tango.integrations.torch import DataCollator
from tango.integrations.transformers.data import * # noqa: F403,F401
def test_init_collator_no_tokenizer():
    """A collator that needs no tokenizer can be built from its type name alone."""
    built = DataCollator.from_params({"type": "transformers::DefaultDataCollator"})
    assert isinstance(built, DefaultDataCollator)
def test_init_collator_with_tokenizer():
    """A tokenizer sub-config should be resolved while constructing the collator."""
    params = {
        "type": "transformers::DataCollatorWithPadding",
        "tokenizer": {
            "pretrained_model_name_or_path": "epwalsh/bert-xsmall-dummy",
        },
    }
    built = DataCollator.from_params(params)
    assert isinstance(built, DataCollatorWithPadding)
================================================
FILE: tests/integrations/transformers/finetune_test.py
================================================
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from tango.common.testing import TangoTestCase
from tango.integrations.transformers import TokenizeText2TextData
class TestTokenizeText2TextData(TangoTestCase):
    """Tests of the TokenizeText2TextData step in seq2seq and concatenation modes."""

    @staticmethod
    def _toy_data() -> DatasetDict:
        # Two source/target pairs plus an extra field unrelated to tokenization.
        dataset = Dataset.from_dict(
            {"field1": ["hello", "hi"], "field2": ["world", "me"], "meta_field": [1, 0]}
        )
        return DatasetDict({"train": dataset})

    def test_tokenize_seq2seq(self):
        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
        tokenized = TokenizeText2TextData().run(
            data=self._toy_data(),
            tokenizer=tokenizer,
            source_field="field1",
            target_field="field2",
        )
        assert isinstance(tokenized, DatasetDict)
        train = tokenized["train"]
        assert len(train) == 2
        assert "input_ids" in train.column_names
        assert "labels" in train.column_names
        assert train[0]["input_ids"] == [21820, 1]

    def test_tokenize_concat(self):
        tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
        tokenized = TokenizeText2TextData().run(
            data=self._toy_data(),
            tokenizer=tokenizer,
            source_field="field1",
            target_field="field2",
            concat_source_target=True,
        )
        assert isinstance(tokenized, DatasetDict)
        train = tokenized["train"]
        assert len(train) == 2
        assert "input_ids" in train.column_names
        assert "labels" in train.column_names
        assert train[0]["input_ids"] == [31373, 50257, 6894, 50256]
        # The source-token positions are replaced with -100 in the labels.
        assert train[0]["labels"] == [-100, -100, 6894, 50256]
================================================
FILE: tests/integrations/transformers/ia3_test.py
================================================
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tango.integrations.transformers.ia3 import GPT_2_IA3_CONFIG, modify_with_ia3
def test_ia3():
    """Freshly-initialized IA3 adapters should leave model outputs numerically unchanged."""
    model_name = "sshleifer/tiny-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    batch = tokenizer(["A tiny test on a tiny model."], return_tensors="pt")
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    def run_forward(m):
        with torch.inference_mode():
            return m(
                input_ids=batch.input_ids,
                attention_mask=batch.attention_mask,
                labels=batch.input_ids,
            )

    old_outputs = run_forward(model)
    model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG)
    new_outputs = run_forward(model)
    assert torch.abs(old_outputs.logits - new_outputs.logits).mean() < 1e-10
    assert torch.abs(old_outputs.loss - new_outputs.loss) < 1e-10
================================================
FILE: tests/integrations/transformers/run_generation_test.py
================================================
from tango import Step
from tango.common import DatasetDict
from tango.common.testing import TangoTestCase
from tango.integrations.transformers import RunGenerationDataset
class TestRunGeneration(TangoTestCase):
    """Tests for the transformers run-generation steps."""

    def test_run_generation(self):
        step = Step.from_params(  # type: ignore[assignment]
            {
                "type": "transformers::run_generation",
                "prompts": ["Tango is the future of", "Everybody should be using Tango to"],
                "model": "sshleifer/tiny-gpt2",
            },
        )
        # One generation result per prompt.
        assert len(list(step.result())) == 2

    def test_run_generation_with_model(self):
        # Same as above, but the model comes from a sub-config rather than a name.
        step = Step.from_params(  # type: ignore[assignment]
            {
                "type": "transformers::run_generation",
                "prompts": ["Tango is the future of", "Everybody should be using Tango to"],
                "model": {
                    "type": "transformers::AutoModelForCausalLM::from_pretrained",
                    "pretrained_model_name_or_path": "sshleifer/tiny-gpt2",
                },
            },
        )
        assert len(list(step.result())) == 2

    def test_run_generation_dataset(self):
        dataset = DatasetDict(
            {
                "train": [
                    {"prompt": "Tango is the future of"},
                    {"prompt": "Everybody should be using Tango to"},
                ]
            },
            {},
        )
        step = RunGenerationDataset(
            model="sshleifer/tiny-gpt2", input=dataset, prompt_field="prompt"
        )
        result = step.result()
        assert len(result) == 1
        train_split = result["train"]
        assert len(train_split) == 2
        second_example = train_split[1]
        assert len(second_example) == 2
        assert second_example["prompt"] == "Everybody should be using Tango to"
        # Every generated continuation should start with the original prompt text.
        assert all(
            text.startswith("Everybody should be using Tango to")
            for text in second_example["prompt_generated"]
        )
================================================
FILE: tests/integrations/transformers/soft_prompt_test.py
================================================
import transformers
from tango.integrations.transformers import add_soft_prompt
def test_soft_prompt():
    """Adding a soft prompt should change what the model generates."""
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")
    prompt = "translate English to German: That is good."

    def generate() -> list:
        model.eval()
        output_ids = model.generate(
            tokenizer.encode(prompt, return_tensors="pt"), num_beams=10, num_return_sequences=5
        )
        return [tokenizer.decode(ids) for ids in output_ids]

    original_output = generate()
    add_soft_prompt(model, prompt_length=3)
    prompted_output = generate()
    assert original_output != prompted_output
def test_soft_prompt_twice():
    """Applying add_soft_prompt a second time should change generations again."""
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    def generate_text() -> str:
        model.eval()
        output = model.generate(tokenizer.encode("It was the best of times.", return_tensors="pt"))
        return tokenizer.decode(output[0])

    add_soft_prompt(model, prompt_length=2)
    first_output = generate_text()
    add_soft_prompt(model, prompt_length=5)
    second_output = generate_text()
    assert first_output != second_output
================================================
FILE: tests/integrations/wandb/__init__.py
================================================
================================================
FILE: tests/integrations/wandb/step_cache_test.py
================================================
import os
import pickle
import sys
import pytest
from tango import Step
from tango.integrations.wandb import WandbStepCache
# W&B entity and project used by these integration tests; the entity can be
# overridden via the WANDB_ENTITY environment variable (defaults to "allennlp").
WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "allennlp")
WANDB_PROJECT = "tango-workspace-testing"
class SomeFakeStep(Step):
    """A trivial deterministic, cacheable step used to probe the step cache."""

    DETERMINISTIC = True
    CACHEABLE = True

    def run(self) -> int:  # type: ignore
        """Always produce the constant 1."""
        return 1
def test_step_cache_artifact_not_found():
    """A step with no published artifact should not be reported as cached."""
    cache = WandbStepCache(project=WANDB_PROJECT, entity=WANDB_ENTITY)
    uncached_step = SomeFakeStep(step_name="hi there")
    assert uncached_step not in cache
@pytest.mark.parametrize(
    "protocol",
    [pytest.param(p, id=f"protocol={p}") for p in range(4)]
    + [
        pytest.param(
            5,
            id="protocol=5",
            marks=pytest.mark.skipif(
                sys.version_info < (3, 8), reason="Protocol 5 requires Python 3.8 or newer"
            ),
        ),
    ],
)
def test_pickling(protocol: int):
    """The cache must survive a pickle round trip at every supported protocol."""
    cache = WandbStepCache(project=WANDB_PROJECT, entity=WANDB_ENTITY)
    pickle.loads(pickle.dumps(cache, protocol=protocol))
================================================
FILE: tests/integrations/wandb/workspace_test.py
================================================
import json
import os
import pickle
import shutil
import sys
import uuid
import pytest
import wandb
from tango import Step, StepGraph, Workspace
from tango.common import Params, util
from tango.common.logging import initialize_logging, teardown_logging
from tango.common.testing import TangoTestCase
from tango.common.testing.steps import * # noqa: F403,F401
from tango.integrations.wandb import WandbWorkspace
from tango.step_info import StepState
# W&B entity and project used by these workspace tests; the entity can be
# overridden via the WANDB_ENTITY environment variable (defaults to "allennlp").
WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "allennlp")
WANDB_PROJECT = "tango-workspace-testing"
class TestWandbWorkspace(TangoTestCase):
    """Unit tests for WandbWorkspace that don't create any W&B runs."""

    # `setup_method()` is defined as a fixture so it can use other fixtures (monkeypatch).
    @pytest.fixture(autouse=True)
    def setup_method(self, monkeypatch):
        super().setup_method()
        # Redirect tango's cache directory into the per-test temp dir.
        monkeypatch.setattr(util, "tango_cache_dir", lambda: self.TEST_DIR)

    @pytest.mark.parametrize(
        "protocol",
        [pytest.param(p, id=f"protocol={p}") for p in range(4)]
        + [
            pytest.param(
                5,
                id="protocol=5",
                marks=pytest.mark.skipif(
                    sys.version_info < (3, 8), reason="Protocol 5 requires Python 3.8 or newer"
                ),
            ),
        ],
    )
    def test_pickle_workspace(self, protocol):
        """A workspace should survive pickling with its client and settings intact."""
        workspace = WandbWorkspace(project=WANDB_PROJECT, entity=WANDB_ENTITY)
        clone = pickle.loads(pickle.dumps(workspace, protocol=protocol))
        assert clone.wandb_client is not None
        for attribute in ("project", "entity", "steps_dir"):
            assert getattr(clone, attribute) == getattr(workspace, attribute)

    def test_from_url(self):
        """A `wandb://entity/project` URL should resolve to a matching WandbWorkspace."""
        workspace = Workspace.from_url(f"wandb://{WANDB_ENTITY}/{WANDB_PROJECT}")
        assert isinstance(workspace, WandbWorkspace)
        assert (workspace.entity, workspace.project) == (WANDB_ENTITY, WANDB_PROJECT)
class TestWandbWorkspaceUsage(TangoTestCase):
    """Integration tests that exercise a real WandbWorkspace end-to-end.

    Each test gets a unique ID suffix so the runs/artifacts it creates on W&B
    can be found and deleted again in ``teardown_method()``.
    """

    # Need to define the `setup_method()` as fixture so we can use other fixtures within it.
    @pytest.fixture(autouse=True)
    def setup_method(self, monkeypatch):
        super().setup_method()
        # Unique per-test suffix: first 6 chars of the CI commit SHA (empty locally)
        # plus the first 6 chars of a time-based UUID.
        self.UNIQUE_ID_SUFFIX = os.environ.get("GITHUB_SHA", "")[:6] + "-" + str(uuid.uuid1())[:6]
        # Patch tango_cache_dir()
        monkeypatch.setattr(util, "tango_cache_dir", lambda: self.TEST_DIR)
        # Patch Step unique IDs and W&B run IDs.
        monkeypatch.setattr(Step, "_UNIQUE_ID_SUFFIX", self.UNIQUE_ID_SUFFIX)
        monkeypatch.setattr(
            WandbWorkspace,
            "_generate_run_suite_id",
            lambda workspace: wandb.util.generate_id() + "-" + self.UNIQUE_ID_SUFFIX,
        )
        self.workspace = WandbWorkspace(project=WANDB_PROJECT, entity=WANDB_ENTITY)
        initialize_logging(enable_cli_logs=True)

    def teardown_method(self):
        super().teardown_method()
        # Delete W&B runs and their artifacts produced by the test.
        for wandb_run in self.workspace.wandb_client.runs(
            f"{WANDB_ENTITY}/{WANDB_PROJECT}",
        ):
            # Match either the run ID itself or the suite ID recorded in the run config,
            # since both are tagged with this test's unique suffix.
            if (
                self.UNIQUE_ID_SUFFIX in wandb_run.id
                or self.UNIQUE_ID_SUFFIX in wandb_run.config.get("_run_suite_id", "")
            ):
                wandb_run.delete(delete_artifacts=True)
        teardown_logging()

    def test_direct_usage(self):
        """Drive the workspace API by hand through a full step life cycle."""
        params = Params.from_file(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet")
        step_graph = StepGraph.from_params(params.pop("steps", keep_as_dict=True))
        tango_run = self.workspace.register_run(step for step in step_graph.values())
        # Test 'registered_run()' and 'registered_runs()' methods.
        assert self.workspace.registered_run(tango_run.name) == tango_run
        assert self.workspace.registered_runs()[tango_run.name] == tango_run
        hello_step = step_graph["hello"]
        hello_world_step = step_graph["hello_world"]
        # Test getting step info.
        step_info = self.workspace.step_info(hello_step)
        assert step_info.unique_id.endswith(self.UNIQUE_ID_SUFFIX)
        assert step_info.step_name == "hello"
        assert step_info.state == StepState.INCOMPLETE
        # Mark the "hello" step as starting.
        self.workspace.step_starting(hello_step)
        assert self.workspace.step_info(hello_step).state == StepState.RUNNING
        # Mark the "hello" step as finished.
        self.workspace.step_finished(hello_step, "hello")
        assert self.workspace.step_info(hello_step).state == StepState.COMPLETED
        # Make sure the result is in the cache, exists locally, and on W&B.
        cache = self.workspace.cache
        assert hello_step in cache
        assert cache.step_dir(hello_step).is_dir()
        assert cache.get_step_result_artifact(hello_step) is not None
        # Now make sure we can fetch the item from the cache, even if it's not in memory
        # or in the cache directory.
        if hello_step.unique_id in cache.weak_cache:
            del cache.weak_cache[hello_step.unique_id]
        if hello_step.unique_id in cache.strong_cache:
            del cache.strong_cache[hello_step.unique_id]
        shutil.rmtree(cache.step_dir(hello_step))
        assert hello_step in cache
        assert cache[hello_step] == "hello"
        # Now start the "hello_world" step and then mark it as failed.
        self.workspace.step_starting(hello_world_step)
        self.workspace.step_failed(hello_world_step, ValueError("oh no!"))
        assert self.workspace.step_info(hello_world_step).state == StepState.FAILED

    # NOTE(review): the id "singe-core" below looks like a typo for "single-core",
    # but changing it would rename the pytest test IDs, so it is left as-is.
    @pytest.mark.parametrize(
        "multicore", [pytest.param(True, id="multicore"), pytest.param(False, id="singe-core")]
    )
    @pytest.mark.parametrize(
        "distributed",
        [
            pytest.param(True, id="distributed"),
            pytest.param(False, id="single-device"),
        ],
    )
    def test_with_wandb_train_callback(self, multicore: bool, distributed: bool):
        """Run the torch train fixture with the wandb::log callback against a real W&B workspace."""
        self.run(
            self.FIXTURES_ROOT
            / "integrations"
            / "torch"
            / ("train.jsonnet" if not distributed else "train_dist.jsonnet"),
            include_package=[
                "test_fixtures.integrations.common",
                "test_fixtures.integrations.torch",
            ],
            overrides=json.dumps({"steps.train.callbacks": [{"type": "wandb::log"}]}),
            workspace_url=f"wandb://{WANDB_ENTITY}/{WANDB_PROJECT}",
            multicore=multicore,
        )
================================================
FILE: tests/main_test.py
================================================
import json
import os
import re
import subprocess
from pathlib import Path
from typing import List, Tuple
import click
import pytest
from tango.common.testing import TangoTestCase
from tango.settings import TangoGlobalSettings
from tango.version import VERSION
class TestRun(TangoTestCase):
    """End-to-end tests of the ``tango run`` CLI, invoked through subprocesses."""

    def clean_log_lines(
        self, log_lines: List[str], file_friendly_logging: bool = False
    ) -> List[str]:
        """Normalize raw log lines for comparison.

        Strips click styling, removes everything up to and including the log-level
        marker, drops the final whitespace-separated token of each such line, and
        filters out empty lines. With ``file_friendly_logging`` set, also asserts
        that every line was already unstyled.
        """
        out = []
        for line in log_lines:
            unstyled_line = click.unstyle(line)
            if file_friendly_logging:
                # File-friendly logging should never emit styled output.
                assert line == unstyled_line
            line = unstyled_line
            # With one capture group, re.split yields an odd number of parts;
            # when a level marker is present: [prefix, level, message...].
            parts = re.split(r"(DEBUG|INFO|WARNING|ERROR|CRITICAL)\s+", line)
            if len(parts) >= 3:
                line = "".join(parts[2:])
                # Drop the last whitespace-separated token of the line.
                line = re.sub(r"\s+[^ ]+$", "", line)
            elif len(parts) == 1:
                # No log-level marker; keep the line unchanged.
                line = parts[0]
            else:
                raise ValueError(str(parts))
            if line:
                out.append(line.strip())
        return out

    def check_logs(
        self,
        run_dir: Path,
        process_result: subprocess.CompletedProcess,
        file_friendly_logging: bool = False,
    ) -> Tuple[List[str], List[str]]:
        """Assert that stdout output (from "Starting new run" onward) is mirrored
        in the run's ``out.log``; return the raw and cleaned log-file lines."""
        stdout_lines = process_result.stdout.decode().replace("\r", "\n").split("\n")
        cleaned_stdout_lines = self.clean_log_lines(stdout_lines, file_friendly_logging)
        log_file = run_dir / "out.log"
        assert log_file.is_file()
        log_lines = open(log_file).read().split("\n")
        cleaned_log_lines = self.clean_log_lines(log_lines)
        # NOTE(review): the start index is located in the *raw* stdout lines but is
        # used to slice the *cleaned* list, which may be shorter — confirm intended.
        for line in cleaned_stdout_lines[
            next(i for i, line in enumerate(stdout_lines) if "Starting new run" in line) :
        ]:
            assert line in cleaned_log_lines
        return log_lines, cleaned_log_lines

    def test_version(self):
        """`tango --version` should exit 0 and print the package version."""
        result = subprocess.run(["tango", "--version"], capture_output=True, text=True)
        assert result.returncode == 0
        assert VERSION in result.stdout

    @pytest.mark.parametrize("log_level", ["debug", "info", "warning", "error"])
    @pytest.mark.parametrize("raise_error", (True, False))
    def test_logging_all_levels(self, log_level: str, raise_error):
        """Messages at or above --log-level appear exactly once; cli_logger messages always do."""
        cmd = [
            "tango",
            "--log-level",
            log_level,
            "run",
            str(self.FIXTURES_ROOT / "experiment" / "noisy.jsonnet"),
            "-w",
            str(self.TEST_DIR),
            "-o",
            json.dumps({"steps.noisy_step.raise_error": raise_error}),
        ]
        result = subprocess.run(cmd, capture_output=True)
        run_dir = next((self.TEST_DIR / "runs").iterdir())
        if raise_error:
            assert result.returncode == 1
        else:
            assert result.returncode == 0
        _, cleaned_log_lines = self.check_logs(run_dir, result)
        # Debug messages.
        assert cleaned_log_lines.count("debug message from cli_logger") == 1
        assert cleaned_log_lines.count("debug message") == (1 if log_level == "debug" else 0)
        # Info messages.
        assert cleaned_log_lines.count("info message from cli_logger") == 1
        assert cleaned_log_lines.count("info message") == (
            1 if log_level in {"debug", "info"} else 0
        )
        # Warning messages.
        assert cleaned_log_lines.count("warning message from cli_logger") == 1
        assert cleaned_log_lines.count("warning message") == (
            1 if log_level in {"debug", "info", "warning"} else 0
        )
        # Error messages.
        assert cleaned_log_lines.count("error message from cli_logger") == 1
        assert cleaned_log_lines.count("error message") == (
            1 if log_level in {"debug", "info", "warning", "error"} else 0
        )
        # Traceback.
        if raise_error:
            assert "Traceback (most recent call last):" in cleaned_log_lines
            assert "ValueError: Oh no!" in cleaned_log_lines

    def test_deterministic_experiment(self):
        """Re-running a deterministic experiment reuses the cache instead of recomputing."""
        cmd = [
            "tango",
            "run",
            str(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet"),
            "-w",
            str(self.TEST_DIR),
        ]
        result = subprocess.run(cmd, capture_output=True)
        assert result.returncode == 0
        assert len(os.listdir(self.TEST_DIR / "cache")) == 2
        run_dir = next((self.TEST_DIR / "runs").iterdir())
        assert (run_dir / "hello").is_dir()
        assert (run_dir / "hello" / "cache-metadata.json").is_file()
        assert (run_dir / "hello_world").is_dir()
        # Check logs.
        self.check_logs(run_dir, result)
        # Running again shouldn't create any more directories in the cache.
        result = subprocess.run(cmd)
        assert result.returncode == 0
        assert len(os.listdir(self.TEST_DIR / "cache")) == 2
        # We should see two runs now.
        assert len(os.listdir(self.TEST_DIR / "runs")) == 2

    def test_experiment_with_memory_workspace(self):
        cmd = [
            "tango",
            "run",
            str(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet"),
            "-w",
            "memory://",
        ]
        result = subprocess.run(cmd, capture_output=True)
        assert result.returncode == 0

    def test_experiment_with_default_workspace(self):
        # No "-w" flag: whatever the default workspace is gets used.
        cmd = [
            "tango",
            "run",
            str(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet"),
        ]
        result = subprocess.run(cmd, capture_output=True)
        assert result.returncode == 0

    def test_random_experiment(self):
        cmd = [
            "tango",
            "run",
            str(self.FIXTURES_ROOT / "experiment" / "random.jsonnet"),
            "-w",
            str(self.TEST_DIR),
        ]
        result = subprocess.run(cmd)
        assert result.returncode == 0

    def test_run_name(self):
        """The --name flag should show up as the run name in the first log line."""
        name = "unique-tango-run-name"
        cmd = [
            "tango",
            "run",
            str(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet"),
            "-w",
            str(self.TEST_DIR),
            "--name",
            name,
        ]
        result = subprocess.run(cmd, capture_output=True)
        run_dir = next((self.TEST_DIR / "runs").iterdir())
        _, clean_log_lines = self.check_logs(run_dir, result)
        assert result.returncode == 0
        assert f"Starting new run {name}" == clean_log_lines[0]

    @pytest.mark.parametrize("parallelism", [1, 2])
    @pytest.mark.parametrize("start_method", ["fork", "spawn"])
    @pytest.mark.parametrize("file_friendly_logging", [True, False])
    def test_experiment_with_logging_and_multiprocessing(
        self, parallelism, start_method, file_friendly_logging
    ):
        """Step logs, tqdm output, and worker-subprocess logs all make it into the
        log file across start methods, parallelism levels, and logging modes."""
        cmd = (
            [
                "tango",
                "--log-level",
                "info",
                "--start-method",
                start_method,
            ]
            + ([] if not file_friendly_logging else ["--file-friendly-logging"])
            + [
                "run",
                str(self.FIXTURES_ROOT / "experiment" / "logging_check.jsonnet"),
                "-w",
                str(self.TEST_DIR),
                "-j",
                str(parallelism),
            ]
        )
        result = subprocess.run(cmd, capture_output=True)
        run_dir = next((self.TEST_DIR / "runs").iterdir())
        _, clean_log_lines = self.check_logs(run_dir, result, file_friendly_logging)
        all_logs = "\n".join(clean_log_lines)
        assert "[step stringA] 0 - This is a logging test." in clean_log_lines
        assert "[step stringC] 0 - This is also a logging test." in clean_log_lines
        assert (
            "[step final_string] 0 - This is a logging test. This is being logged."
            in clean_log_lines
        )
        # Make sure tqdm output makes it into the log file.
        assert "[step stringA] log progress: 100%" in all_logs
        assert "[step stringC] log progress: 100%" in all_logs
        assert "[step final_string] log progress: 100%" in all_logs
        # And logs from steps that contain multiprocessing themselves.
        assert "[step multiprocessing_result rank 0] Hello from worker 0!" in all_logs
        assert "[step multiprocessing_result rank 1] Hello from worker 1!" in all_logs
        assert (
            "[step multiprocessing_result rank 0] Hello from the cli logger in worker 0!"
            in all_logs
        )
        assert (
            "[step multiprocessing_result rank 1] Hello from the cli logger in worker 1!"
            in all_logs
        )
        assert "[step multiprocessing_result] progress from main process: 100%" in all_logs
class TestSettings(TangoTestCase):
    """Tests of the ``tango settings`` CLI, run from inside the per-test directory."""

    def setup_method(self):
        super().setup_method()
        # Run everything from TEST_DIR so the generated `tango.yml` lands there.
        self._wd_backup = os.getcwd()
        os.chdir(self.TEST_DIR)
        subprocess.run("tango settings init -p ./tango.yml".split(" "), check=True)

    def teardown_method(self):
        os.chdir(self._wd_backup)
        super().teardown_method()

    @property
    def settings(self) -> TangoGlobalSettings:
        # Re-read the settings file so each assertion sees the latest on-disk state.
        return TangoGlobalSettings.from_file(self.TEST_DIR / "tango.yml")

    def _run_settings_command(self, command: str, should_fail: bool = False) -> None:
        """Run a tango CLI command, asserting failure when `should_fail` is set."""
        args = command.split(" ")
        if should_fail:
            with pytest.raises(subprocess.CalledProcessError):
                subprocess.run(args, check=True)
        else:
            subprocess.run(args, check=True)

    def test_settings_set_workspace(self):
        self._run_settings_command("tango settings set workspace ./workspace")
        assert self.settings.workspace == {
            "type": "local",
            "dir": str((self.TEST_DIR / "workspace").resolve()),
        }

    def test_settings_set_include_package(self):
        self._run_settings_command("tango settings set include-package tango.steps")
        assert self.settings.include_package == ["tango.steps"]

    def test_settings_set_include_package_invalid(self):
        self._run_settings_command("tango settings set include-package foo", should_fail=True)

    def test_settings_set_environment(self):
        self._run_settings_command("tango settings set env FOO BAR")
        assert self.settings.environment == {"FOO": "BAR"}

    def test_settings_set_environment_blocked_var(self):
        # Setting a reserved variable should be rejected by the CLI.
        self._run_settings_command("tango settings set env TANGO_LOG_LEVEL info", should_fail=True)
================================================
FILE: tests/step_caches/__init__.py
================================================
================================================
FILE: tests/step_caches/local_step_cache_test.py
================================================
import pickle
import sys
import pytest
from tango.common.testing import TangoTestCase
from tango.step import Step
from tango.step_caches.local_step_cache import LocalStepCache
class DummyStep(Step):
    """Identity step: returns whatever integer it is given."""

    def run(self, x: int) -> int:  # type: ignore[override]
        """Return ``x`` unchanged."""
        return x
class TestLocalStepCache(TangoTestCase):
    """Tests for LocalStepCache, in particular that pickling drops the in-memory cache."""

    @pytest.mark.parametrize(
        "protocol",
        [pytest.param(p, id=f"protocol={p}") for p in range(4)]
        + [
            pytest.param(
                5,
                id="protocol=5",
                marks=pytest.mark.skipif(
                    sys.version_info < (3, 8), reason="Protocol 5 requires Python 3.8 or newer"
                ),
            ),
        ],
    )
    def test_pickling(self, protocol: int):
        step = DummyStep(step_name="dummy", x=1)
        cache = LocalStepCache(self.TEST_DIR)
        cache[step] = 1
        assert step in cache
        assert step.unique_id in cache.strong_cache
        round_tripped = pickle.loads(pickle.dumps(cache, protocol=protocol))
        # The in-memory (strong) cache does not survive pickling, but the step
        # is still found in the unpickled cache.
        assert step.unique_id not in round_tripped.strong_cache
        assert step in round_tripped
================================================
FILE: tests/step_graph_test.py
================================================
import re
from copy import deepcopy
from tempfile import NamedTemporaryFile
import pytest
from tango.common.exceptions import ConfigurationError
from tango.common.testing import TangoTestCase
from tango.common.testing.steps import ( # noqa: F401
AddNumbersStep,
ConcatStringsStep,
StringStep,
)
from tango.step_graph import StepGraph
class TestStepGraph(TangoTestCase):
    """Tests for building, ordering, and (de)serializing StepGraphs."""

    def test_ordered_steps(self):
        """Steps should be ordered so that dependencies come before dependents."""
        step_graph = StepGraph.from_params(
            {
                "stepB": {
                    "type": "add_numbers",
                    "a_number": 2,
                    "b_number": 3,
                },
                "stepC": {
                    "type": "add_numbers",
                    # "stepC" depends on "stepB" via this step reference.
                    "a_number": {"type": "ref", "ref": "stepB"},
                    "b_number": 5,
                },
                "stepA": {
                    "type": "add_numbers",
                    "a_number": 3,
                    "b_number": 1,
                },
            }
        )
        result = StepGraph.ordered_steps(step_graph.parsed_steps)
        assert [res.name for res in result] == ["stepB", "stepC", "stepA"]

    def test_from_file(self):
        step_graph = StepGraph.from_file(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet")
        assert "hello" in step_graph
        assert "hello_world" in step_graph

    def test_missing_type(self):
        """A step config without a "type" key should fail with a ConfigurationError."""
        with pytest.raises(ConfigurationError, match=re.escape('key "type" is required')):
            StepGraph.from_params(
                {
                    "step3": {
                        "a_number": 3,
                        "b_number": 1,
                    },
                }
            )

    def test_direct_construction(self):
        step_a = AddNumbersStep(a_number=3, b_number=2, step_name="stepA")
        step_b = AddNumbersStep(a_number=step_a, b_number=2, step_name="stepB")
        step_graph = StepGraph({"stepA": step_a, "stepB": step_b})
        assert list(step_graph.parsed_steps.keys()) == ["stepA", "stepB"]

    def test_direct_construction_missing_dependency(self):
        """Omitting a step that another step depends on should raise."""
        step_a = AddNumbersStep(a_number=3, b_number=2, step_name="stepA")
        step_b = AddNumbersStep(a_number=step_a, b_number=2, step_name="stepB")
        with pytest.raises(ConfigurationError, match="Or a missing dependency"):
            # "stepB" depends on "stepA", which is left out of the graph.
            StepGraph({"stepB": step_b})

    def test_to_file(self):
        """A graph should round-trip through to_file()/from_file() unchanged."""
        step_graph = StepGraph.from_file(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet")
        with NamedTemporaryFile(
            prefix="test-step-graph-to-file-", suffix=".jsonnet", dir=self.TEST_DIR
        ) as file_ref:
            step_graph.to_file(file_ref.name)
            new_step_graph = StepGraph.from_file(file_ref.name)
            assert step_graph == new_step_graph

    def test_to_file_without_config(self):
        """Round-tripping should also work for graphs built directly from step objects."""
        from tango.format import JsonFormat

        step_a = AddNumbersStep(a_number=3, b_number=2, step_name="stepA", cache_results=False)
        step_b = AddNumbersStep(
            a_number=step_a, b_number=2, step_name="stepB", step_format=JsonFormat("gz")
        )
        step_graph = StepGraph({"stepA": step_a, "stepB": step_b})
        with NamedTemporaryFile(
            prefix="test-step-graph-to-file-without-config", suffix=".jsonnet", dir=self.TEST_DIR
        ) as file_ref:
            step_graph.to_file(file_ref.name)
            new_step_graph = StepGraph.from_file(file_ref.name)
            assert step_graph == new_step_graph

    def test_with_step_indexer(self):
        """A "ref" with a "key" creates a dependency and still round-trips to config."""
        config = {
            "list": {"type": "range_step", "start": 0, "end": 3},
            "added": {
                "type": "add_numbers",
                "a_number": 2,
                "b_number": {"type": "ref", "ref": "list", "key": 1},
            },
        }
        step_graph = StepGraph.from_params(deepcopy(config))  # type: ignore[arg-type]
        assert [s.name for s in step_graph["added"].dependencies] == ["list"]
        assert step_graph.to_config() == config

    def test_with_forced_dependencies(self):
        """`step_extra_dependencies` should add edges beyond the data-flow references."""
        config = {
            "some_string": {
                "type": "string",
                "result": "I should run second",
                "step_extra_dependencies": [{"type": "ref", "ref": "other_string"}],
            },
            "other_string": {"type": "string", "result": "I should run first"},
            "added": {
                "type": "concat_strings",
                "string1": "Some string:",
                "string2": {"type": "ref", "ref": "some_string"},
            },
        }
        step_graph = StepGraph.from_params(deepcopy(config))  # type: ignore[arg-type]
        assert step_graph["some_string"].dependencies == {step_graph["other_string"]}
        # The forced edge shows up transitively in "added"'s recursive dependencies.
        assert step_graph["added"].recursive_dependencies == {
            step_graph["other_string"],
            step_graph["some_string"],
        }
================================================
FILE: tests/step_info_test.py
================================================
import json
from pathlib import Path
from typing import Any
from tango.common.testing.steps import FloatStep
from tango.step import Step
from tango.step_graph import StepGraph
from tango.step_info import StepInfo
def test_step_info():
    """Build a ``StepInfo`` from a step, sanity-check its captured environment
    metadata, and verify JSON (de)serialization is lossless."""
    step_info = StepInfo.new_from_step(FloatStep(step_name="float", result=1.0))

    # Git metadata is only populated when running inside a checkout.
    if (Path.cwd() / ".git").exists():
        git_info = step_info.environment.git
        assert git_info is not None
        assert git_info.commit is not None
        assert git_info.remote is not None
        assert "allenai/tango" in git_info.remote

    # The captured pip environment should always be present.
    assert step_info.environment.packages is not None

    # Serialization followed by deserialization must round-trip exactly.
    round_tripped = StepInfo.from_json_dict(json.loads(json.dumps(step_info.to_json_dict())))
    assert round_tripped == step_info
def test_step_info_with_step_dependency():
    """A ``StepInfo`` whose step references upstream steps must keep its config
    as a raw ``dict`` rather than parsing it into a ``Step``."""

    @Step.register("foo", exist_ok=True)
    class FooStep(Step):
        def run(self, bar: Any) -> str:  # type: ignore
            return "foo" + bar

    @Step.register("bar", exist_ok=True)
    class BarStep(Step):
        def run(self) -> str:  # type: ignore
            return "Hey!"

    graph = StepGraph.from_params(
        {
            "foo": {"type": "foo", "bar": {"type": "ref", "ref": "bar"}},
            "bar": {"type": "bar"},
        }
    )

    # Round-trip the info for the dependent step through JSON.
    info = StepInfo.new_from_step(graph["foo"])
    restored = StepInfo.from_json_dict(json.loads(json.dumps(info.to_json_dict())))
    # Because "foo" refers to "bar", the config cannot be parsed to a Step.
    assert isinstance(restored.config, dict)
================================================
FILE: tests/step_test.py
================================================
import collections
from typing import Any, Dict, Mapping, MutableMapping
import pytest
from tango import StepGraph
from tango.common import Params, Registrable
from tango.common.exceptions import ConfigurationError
from tango.common.from_params import FromParams
from tango.common.testing import TangoTestCase
from tango.step import FunctionalStep, Step, step
from tango.workspaces import MemoryWorkspace
class TestStep(TangoTestCase):
    """Covers ``Step`` construction from params, unique-id hashing knobs
    (SKIP_ID_ARGUMENTS, SKIP_DEFAULT_ARGUMENTS, massage_kwargs, defaults),
    refs nested inside object configs, and the ``@step`` functional API."""

    def test_from_params(self):
        built = Step.from_params({"type": "float", "result": 3})
        assert 3 == built.result()

    def test_from_params_wrong_type(self):
        # A non-float "result" should be rejected during construction.
        with pytest.raises(TypeError):
            Step.from_params({"type": "float", "result": "not a float"})

    def test_step_with_from_params_input(self):
        class Bar(FromParams):
            def __init__(self, x: int):
                self.x = x

        @Step.register("foo", exist_ok=True)
        class FooStep(Step):
            def run(self, bar: Bar) -> Bar:  # type: ignore
                return bar

        # The nested dict is deserialized into a Bar instance.
        built = Step.from_params({"type": "foo", "bar": {"x": 1}})
        assert 1 == built.result().x

    def test_no_hash_arguments(self):
        @Step.register("no_hash_step")
        class SkipArgStep(Step):
            SKIP_ID_ARGUMENTS = {"arg"}

            def run(self, arg: str) -> int:  # type: ignore
                return 5

        # "arg" is excluded from the unique id, so different values collide.
        assert SkipArgStep(arg="foo").unique_id == SkipArgStep(arg="bar").unique_id

    def test_skip_default_arguments(self):
        class SkipArgStep(Step):
            def run(self) -> int:  # type: ignore
                return 5

        baseline_id = SkipArgStep().unique_id

        class SkipArgStep(Step):
            SKIP_DEFAULT_ARGUMENTS = {"arg": 5}

            def run(self, arg: int = 5) -> int:  # type: ignore
                return arg

        # Leaving "arg" at its skip-default keeps the hash stable...
        assert baseline_id == SkipArgStep().unique_id
        assert baseline_id == SkipArgStep(arg=5).unique_id
        # ...but any other value changes it.
        assert baseline_id != SkipArgStep(arg=6).unique_id

    def test_massage_kwargs(self):
        class CountLettersStep(Step):
            @classmethod
            def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
                # Normalize case before hashing so "FOO" and "foo" match.
                normalized = dict(kwargs)
                normalized["text"] = normalized["text"].lower()
                return normalized

            def run(self, text: str) -> Mapping[str, int]:  # type: ignore
                counter: MutableMapping[str, int] = collections.Counter(text.lower())
                return counter

        upper_case = CountLettersStep(text="FOO")
        lower_case = CountLettersStep(text="foo")
        assert lower_case.unique_id == upper_case.unique_id
        assert lower_case.result() == upper_case.result()

    def test_default_args(self):
        class DefaultArgStep(Step[int]):
            def run(self, left: int, right: int = 0) -> int:  # type: ignore
                return left + right

        with_default = DefaultArgStep(left=1)
        without_default = DefaultArgStep(left=1, right=0)
        # Passing a default explicitly must not change the unique id.
        assert without_default.unique_id == with_default.unique_id
        assert without_default.result() == with_default.result()

    def test_steps_in_params(self):
        class Widget(Registrable):
            def __init__(self, x: int):
                self.x = x

        @Widget.register("gizmo")
        class GizmoWidget(Widget):
            def __init__(self, x: int):
                super().__init__(x * x)

        @Step.register("consumer")
        class WidgetConsumerStep(Step):
            def run(self, widget: Widget):  # type: ignore
                return widget.x

        @Step.register("producer")
        class WidgetProducerStep(Step):
            def run(self, x: int) -> Widget:  # type: ignore
                return GizmoWidget(x)

        # A "ref" appearing directly as a step argument creates a dependency.
        graph = StepGraph.from_params(
            {
                "widget_producer": Params({"type": "producer", "x": 4}),
                "widget_consumer": Params(
                    {"type": "consumer", "widget": {"type": "ref", "ref": "widget_producer"}}
                ),
            }
        )
        assert len(graph["widget_consumer"].dependencies) > 0

        class WidgetHolder(Registrable):
            def __init__(self, widget: Widget):
                self.widget = widget

        @WidgetHolder.register("gizmo")
        class GizmoWidgetHolder(WidgetHolder):
            def __init__(self, gizmo: GizmoWidget):
                super().__init__(gizmo)

        @Step.register("holder_consumer")
        class WidgetHolderConsumerStep(Step):
            def run(self, widget_holder: WidgetHolder) -> int:  # type: ignore
                return widget_holder.widget.x

        # The same holds for a "ref" buried inside a nested object config.
        graph = StepGraph.from_params(
            {
                "widget_producer": Params({"type": "producer", "x": 4}),
                "holder_consumer": Params(
                    {
                        "type": "holder_consumer",
                        "widget_holder": {
                            "type": "gizmo",
                            "gizmo": {"type": "ref", "ref": "widget_producer"},
                        },
                    }
                ),
            }
        )
        assert len(graph["holder_consumer"].dependencies) > 0

    def test_functional_step(self):
        class Bar(FromParams):
            def __init__(self, x: int):
                self.x = x

        @step(exist_ok=True)
        def foo(bar: Bar) -> int:
            return bar.x

        # The decorator produces a FunctionalStep subclass...
        assert issubclass(foo, FunctionalStep)
        assert foo().run(Bar(x=1)) == 1
        # ...that can also be built from params, with "bar" deserialized.
        built = Step.from_params({"type": "foo", "bar": {"x": 1}})
        assert isinstance(built, FunctionalStep)
        assert isinstance(built.kwargs["bar"], Bar)

    def test_bound_functional_step(self):
        class Bar(FromParams):
            def __init__(self, x: int):
                self.x = x

        @step(exist_ok=True, bind=True)
        def foo(self, bar: Bar) -> int:
            assert self.work_dir.is_dir()
            return bar.x

        built = Step.from_params({"type": "foo", "bar": {"x": 1}})
        assert isinstance(built, FunctionalStep)
        assert built.result(MemoryWorkspace()) == 1

    def test_bound_functional_step_missing_self(self):
        # bind=True requires the first parameter to literally be "self".
        @step(exist_ok=True, bind=True)
        def foo(x: int) -> int:
            return x

        with pytest.raises(ConfigurationError):
            Step.from_params({"type": "foo", "x": 1})

        @step(exist_ok=True, bind=True)
        def bar(s, x: int) -> int:
            return x

        with pytest.raises(ConfigurationError):
            Step.from_params({"type": "bar", "x": 1})
================================================
FILE: tests/steps/__init__.py
================================================
================================================
FILE: tests/steps/dataset_remix_test.py
================================================
from tango.common.dataset_dict import DatasetDict
from tango.steps.dataset_remix import DatasetRemixStep
def test_dataset_remix_step():
    """Remix a ``DatasetDict`` into new splits and check that the combined
    "all_train" split has the expected size."""
    splits = {
        "train": list(range(10)),
        "dev": list(range(10, 15)),
        "test": list(range(15, 20)),
    }
    dataset_dict = DatasetDict(splits)
    remixed = DatasetRemixStep("remix").run(
        input=dataset_dict,
        new_splits={
            "all_train": "train + dev",
            "cross_val_train": "train[:8]",
            "cross_val_dev": "train[-2:]",
        },
    )
    expected_size = len(dataset_dict["train"]) + len(dataset_dict["dev"])
    assert len(remixed["all_train"]) == expected_size
================================================
FILE: tests/steps/shell_step_test.py
================================================
import pytest
from tango.common.testing import TangoTestCase
from tango.steps.shell_step import ShellStep, make_registrable
class TestShellStep(TangoTestCase):
    """Tests for ``ShellStep``: command output capture, failure handling,
    output-path validation, and usage from a full experiment config."""

    def test_shell_step(self):
        output = ShellStep().run("echo hello")
        assert isinstance(output, str)
        assert output == "hello\n"

    def test_shell_step_failure(self):
        # A command with a non-zero exit status surfaces as RuntimeError.
        with pytest.raises(RuntimeError):
            ShellStep().run("ls -l non_existent_path")

    def test_shell_step_with_output_path(self, caplog):
        target = self.TEST_DIR / "test-folder"
        ShellStep().run(f"mkdir {target}", output_path=target)
        assert f"Output found at: {target}" in caplog.text

    def test_shell_step_different_validation(self, caplog):
        @make_registrable(exist_ok=True)
        def validate_func(path):
            """
            Validates that the file contents of the `path` are a json string.
            """
            import json

            with open(path) as f:
                json.load(f)

        target = self.TEST_DIR / "hello.json"
        command = f"python3 -c \"import json; print(json.dumps({{'a': 23}}))\" > {target}"
        # shell=True is needed here for the output redirection to work.
        ShellStep().run(command, output_path=target, validate_output=validate_func, shell=True)
        assert f"Output found at: {target}" in caplog.text

    def test_shell_step_in_config(self, caplog):
        target = str(self.TEST_DIR / "test-folder")
        config = {
            "steps": {
                "create_dir": {
                    "type": "shell_step",
                    "shell_command": f"mkdir {target}",
                    "output_path": target,
                    "validate_output": {"type": "check_path_existence"},
                },
            }
        }
        # Regular run contains all step outputs.
        self.run(config)
        assert f"Output found at: {target}" in caplog.text
================================================
FILE: tests/workspaces/__init__.py
================================================
================================================
FILE: tests/workspaces/local_workspace_test.py
================================================
from shutil import copytree
import pytest
from sqlitedict import SqliteDict
from tango import Step
from tango.common.testing import TangoTestCase
from tango.step_info import StepState
from tango.workspaces import LocalWorkspace
class AdditionStep(Step):
    """A minimal test step that returns the sum of two integers."""

    def run(self, a: int, b: int) -> int:  # type: ignore
        return b + a
class TestLocalWorkspace(TangoTestCase):
    """Tests for ``LocalWorkspace``: step-state tracking through a run,
    dependency registration, v1-to-v2 directory upgrades, and step removal."""

    def test_local_workspace_one_step(self):
        workspace = LocalWorkspace(self.TEST_DIR)
        step = AdditionStep(a=1, b=2)

        # Looking up an unknown unique id raises, because the workspace has
        # never seen this step before.
        # NOTE: the assert that used to follow this call inside the ``raises``
        # block was unreachable dead code (the call raises first) — removed.
        with pytest.raises(KeyError):
            workspace.step_info(step.unique_id)

        # Looking up by step object works even before the step has run.
        step_info = workspace.step_info(step)
        assert step_info.state == StepState.INCOMPLETE

        result = step.result(workspace)
        assert result == 3

        # After running, both lookup styles report the step as completed.
        step_info = workspace.step_info(step.unique_id)
        assert step_info.state == StepState.COMPLETED
        step_info = workspace.step_info(step)
        assert step_info.state == StepState.COMPLETED

    def test_local_workspace_two_steps(self):
        workspace = LocalWorkspace(self.TEST_DIR)
        step1 = AdditionStep(a=1, b=2)
        step2 = AdditionStep(a=step1, b=3)

        # Looking up the downstream step registers it (and its dependency).
        step_info = workspace.step_info(step2)
        assert step_info.state == StepState.INCOMPLETE
        step_info = workspace.step_info(step2.unique_id)
        assert step_info.state == StepState.INCOMPLETE
        assert step1.unique_id in step_info.dependencies

        # The upstream step is also known to the workspace now.
        step_info = workspace.step_info(step1.unique_id)
        assert step_info.state == StepState.INCOMPLETE
        step_info = workspace.step_info(step1)
        assert step_info.state == StepState.INCOMPLETE

        result = step2.result(workspace)
        assert result == 6

        # Running the downstream step completes the whole chain.
        for step in [step1, step2]:
            step_info = workspace.step_info(step.unique_id)
            assert step_info.state == StepState.COMPLETED
            step_info = workspace.step_info(step)
            assert step_info.state == StepState.COMPLETED

    def test_local_workspace_upgrade_v1_to_v2(self):
        workspace_dir = self.TEST_DIR / "workspace"
        copytree(
            self.FIXTURES_ROOT / "v1_local_workspace",
            workspace_dir,
            symlinks=True,
        )
        workspace = LocalWorkspace(workspace_dir)
        step_info = workspace.step_info("SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz")
        assert step_info.state == StepState.COMPLETED

        # Walk the full dependency closure to make sure every transitive
        # dependency was carried over by the upgrade (lookups would raise
        # if any were missing).
        dependencies = list(step_info.dependencies)
        while len(dependencies) > 0:
            step_info = workspace.step_info(dependencies.pop())
            dependencies.extend(step_info.dependencies)

    def test_remove_step(self):
        workspace = LocalWorkspace(self.TEST_DIR)
        step = AdditionStep(a=1, b=2)
        workspace.step_starting(step)
        workspace.step_finished(step, 1.0)

        # The finished step is recorded in both the info db and the cache.
        with SqliteDict(workspace.step_info_file) as d:
            assert step.unique_id in d
        cache = workspace.step_cache
        assert step in cache

        workspace.remove_step(step.unique_id)

        # Removal purges it from both places.
        with SqliteDict(workspace.step_info_file) as d:
            assert step.unique_id not in d
        cache = workspace.step_cache
        assert step not in cache
================================================
FILE: tests/workspaces/memory_workspace_test.py
================================================
from tango.common.testing.steps import FloatStep
from tango.workspaces import MemoryWorkspace
def test_remove_step():
    """Removing a step from a ``MemoryWorkspace`` purges it from both the
    info map and the step cache."""
    workspace = MemoryWorkspace()
    step = FloatStep(step_name="float", result=1.0)

    workspace.step_starting(step)
    workspace.step_finished(step, 1.0)
    # The finished step is tracked and cached.
    assert step.unique_id in workspace.unique_id_to_info
    assert step in workspace.step_cache

    workspace.remove_step(step.unique_id)
    # After removal, neither place knows about it.
    assert step.unique_id not in workspace.unique_id_to_info
    assert step not in workspace.step_cache